"examples/vscode:/vscode.git/clone" did not exist on "66bf7ea5be7099c8a47b9cba135f276d55247447"
Unverified Commit 5b20d3b3 authored by Chenguo Lin's avatar Chenguo Lin Committed by GitHub
Browse files

fix the parameter naming in `self.downsamplers` (#1108)


Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 2c108693
...@@ -462,7 +462,7 @@ class AttnDownBlock2D(nn.Module): ...@@ -462,7 +462,7 @@ class AttnDownBlock2D(nn.Module):
self.downsamplers = nn.ModuleList( self.downsamplers = nn.ModuleList(
[ [
Downsample2D( Downsample2D(
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
) )
] ]
) )
...@@ -546,7 +546,7 @@ class CrossAttnDownBlock2D(nn.Module): ...@@ -546,7 +546,7 @@ class CrossAttnDownBlock2D(nn.Module):
self.downsamplers = nn.ModuleList( self.downsamplers = nn.ModuleList(
[ [
Downsample2D( Downsample2D(
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
) )
] ]
) )
...@@ -651,7 +651,7 @@ class DownBlock2D(nn.Module): ...@@ -651,7 +651,7 @@ class DownBlock2D(nn.Module):
self.downsamplers = nn.ModuleList( self.downsamplers = nn.ModuleList(
[ [
Downsample2D( Downsample2D(
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
) )
] ]
) )
...@@ -729,7 +729,7 @@ class DownEncoderBlock2D(nn.Module): ...@@ -729,7 +729,7 @@ class DownEncoderBlock2D(nn.Module):
self.downsamplers = nn.ModuleList( self.downsamplers = nn.ModuleList(
[ [
Downsample2D( Downsample2D(
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
) )
] ]
) )
...@@ -801,7 +801,7 @@ class AttnDownEncoderBlock2D(nn.Module): ...@@ -801,7 +801,7 @@ class AttnDownEncoderBlock2D(nn.Module):
self.downsamplers = nn.ModuleList( self.downsamplers = nn.ModuleList(
[ [
Downsample2D( Downsample2D(
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
) )
] ]
) )
...@@ -886,7 +886,7 @@ class AttnSkipDownBlock2D(nn.Module): ...@@ -886,7 +886,7 @@ class AttnSkipDownBlock2D(nn.Module):
down=True, down=True,
kernel="fir", kernel="fir",
) )
self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)]) self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)])
self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1)) self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
else: else:
self.resnet_down = None self.resnet_down = None
...@@ -966,7 +966,7 @@ class SkipDownBlock2D(nn.Module): ...@@ -966,7 +966,7 @@ class SkipDownBlock2D(nn.Module):
down=True, down=True,
kernel="fir", kernel="fir",
) )
self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)]) self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)])
self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1)) self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
else: else:
self.resnet_down = None self.resnet_down = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment