Unverified Commit c0493723 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Remove the usage of numpy in up/down sample_2d (#503)



* Fix PT up/down sample_2d

* empty commit

* style

* style
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent c727a6a5
from functools import partial
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
......@@ -134,10 +133,10 @@ class FirUpsample2D(nn.Module):
kernel = [1] * factor
# setup kernel
kernel = np.asarray(kernel, dtype=np.float32)
kernel = torch.tensor(kernel, dtype=torch.float32)
if kernel.ndim == 1:
kernel = np.outer(kernel, kernel)
kernel /= np.sum(kernel)
kernel = torch.outer(kernel, kernel)
kernel /= torch.sum(kernel)
kernel = kernel * (gain * (factor**2))
......@@ -219,10 +218,10 @@ class FirDownsample2D(nn.Module):
kernel = [1] * factor
# setup kernel
kernel = np.asarray(kernel, dtype=np.float32)
kernel = torch.tensor(kernel, dtype=torch.float32)
if kernel.ndim == 1:
kernel = np.outer(kernel, kernel)
kernel /= np.sum(kernel)
kernel = torch.outer(kernel, kernel)
kernel /= torch.sum(kernel)
kernel = kernel * gain
......@@ -391,16 +390,14 @@ def upsample_2d(x, kernel=None, factor=2, gain=1):
if kernel is None:
kernel = [1] * factor
kernel = np.asarray(kernel, dtype=np.float32)
kernel = torch.tensor(kernel, dtype=torch.float32)
if kernel.ndim == 1:
kernel = np.outer(kernel, kernel)
kernel /= np.sum(kernel)
kernel = torch.outer(kernel, kernel)
kernel /= torch.sum(kernel)
kernel = kernel * (gain * (factor**2))
p = kernel.shape[0] - factor
return upfirdn2d_native(
x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)
)
return upfirdn2d_native(x, kernel.to(device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
def downsample_2d(x, kernel=None, factor=2, gain=1):
......@@ -425,14 +422,14 @@ def downsample_2d(x, kernel=None, factor=2, gain=1):
if kernel is None:
kernel = [1] * factor
kernel = np.asarray(kernel, dtype=np.float32)
kernel = torch.tensor(kernel, dtype=torch.float32)
if kernel.ndim == 1:
kernel = np.outer(kernel, kernel)
kernel /= np.sum(kernel)
kernel = torch.outer(kernel, kernel)
kernel /= torch.sum(kernel)
kernel = kernel * gain
p = kernel.shape[0] - factor
return upfirdn2d_native(x, torch.tensor(kernel, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
return upfirdn2d_native(x, kernel.to(device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment