Commit b9de7172 authored by patil-suraj's avatar patil-suraj
Browse files

add Downsample

parent ee010726
...@@ -103,7 +103,7 @@ class Downsample(nn.Module): ...@@ -103,7 +103,7 @@ class Downsample(nn.Module):
downsampling occurs in the inner-two dimensions. downsampling occurs in the inner-two dimensions.
""" """
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): def __init__(self, channels, use_conv=False, dims=2, out_channels=None, padding=1, name="conv"):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
self.out_channels = out_channels or channels self.out_channels = out_channels or channels
...@@ -111,18 +111,29 @@ class Downsample(nn.Module): ...@@ -111,18 +111,29 @@ class Downsample(nn.Module):
self.dims = dims self.dims = dims
self.padding = padding self.padding = padding
stride = 2 if dims != 3 else (1, 2, 2) stride = 2 if dims != 3 else (1, 2, 2)
self.name = name
if use_conv: if use_conv:
self.down = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding) conv = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
else: else:
assert self.channels == self.out_channels assert self.channels == self.out_channels
self.down = avg_pool_nd(dims, kernel_size=stride, stride=stride) conv = avg_pool_nd(dims, kernel_size=stride, stride=stride)
if name == "conv":
self.conv = conv
else:
self.op = conv
def forward(self, x): def forward(self, x):
assert x.shape[1] == self.channels assert x.shape[1] == self.channels
if self.use_conv and self.padding == 0 and self.dims == 2: if self.use_conv and self.padding == 0 and self.dims == 2:
pad = (0, 1, 0, 1) pad = (0, 1, 0, 1)
x = F.pad(x, pad, mode="constant", value=0) x = F.pad(x, pad, mode="constant", value=0)
return self.down(x)
if self.name == "conv":
return self.conv(x)
else:
return self.op(x)
# TODO (patil-suraj): needs test # TODO (patil-suraj): needs test
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment