"vscode:/vscode.git/clone" did not exist on "648090e26eda508ac37d8657a91fc3d759bcbd8f"
Commit 81e71447 authored by patil-suraj's avatar patil-suraj
Browse files

remove naive up/down sample

parent c9bd4d43
......@@ -957,19 +957,6 @@ def downsample_2d(x, k=None, factor=2, gain=1):
return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
def naive_upsample_2d(x, factor=2):
_N, C, H, W = x.shape
x = torch.reshape(x, (-1, C, H, 1, W, 1))
x = x.repeat(1, 1, 1, factor, 1, factor)
return torch.reshape(x, (-1, C, H * factor, W * factor))
def naive_downsample_2d(x, factor=2):
_N, C, H, W = x.shape
x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor))
return torch.mean(x, dim=(3, 5))
class NIN(nn.Module):
def __init__(self, in_dim, num_units, init_scale=0.1):
super().__init__()
......
......@@ -123,19 +123,6 @@ class Conv2d(nn.Module):
return x
def naive_upsample_2d(x, factor=2):
_N, C, H, W = x.shape
x = torch.reshape(x, (-1, C, H, 1, W, 1))
x = x.repeat(1, 1, 1, factor, 1, factor)
return torch.reshape(x, (-1, C, H * factor, W * factor))
def naive_downsample_2d(x, factor=2):
_N, C, H, W = x.shape
x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor))
return torch.mean(x, dim=(3, 5))
def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
"""Fused `upsample_2d()` followed by `tf.nn.conv2d()`.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment