Unverified commit a2090375 authored by Suraj Patil, committed by GitHub

[VAE] fix the downsample block in Encoder. (#156)

* pass downsample_padding in encoder

* update tests
parent c4a3b09a
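Context for the change below: the encoder's downsample block is now created with downsample_padding=0, so the stride-2 downsampling convolution pads its input asymmetrically (right/bottom only) rather than symmetrically. The following is a minimal plain-PyTorch sketch of that difference, not the library's actual Downsample2D code; the (0, 1, 0, 1) pad order is an assumption used purely for illustration.

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 64, 32, 32)
weight = torch.randn(64, 64, 3, 3)

# downsample_padding=1: symmetric zero padding handled inside the stride-2 conv
sym = F.conv2d(x, weight, stride=2, padding=1)

# downsample_padding=0: pad (left=0, right=1, top=0, bottom=1) first, then conv with no padding
# (assumed pad order; shown only to illustrate the asymmetric-padding idea)
asym = F.conv2d(F.pad(x, (0, 1, 0, 1)), weight, stride=2, padding=0)

print(sym.shape, asym.shape)  # both are (1, 64, 16, 16), but the border handling differs
```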
@@ -40,6 +40,7 @@ class Encoder(nn.Module):
                 out_channels=output_channel,
                 add_downsample=not is_final_block,
                 resnet_eps=1e-6,
+                downsample_padding=0,
                 resnet_act_fn=act_fn,
                 attn_num_head_channels=None,
                 temb_channels=None,
@@ -555,11 +555,11 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
-            "block_out_channels": [64],
+            "block_out_channels": [32, 64],
             "in_channels": 3,
             "out_channels": 3,
-            "down_block_types": ["DownEncoderBlock2D"],
-            "up_block_types": ["UpDecoderBlock2D"],
+            "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
+            "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
             "latent_channels": 3,
         }
         inputs_dict = self.dummy_input
@@ -595,7 +595,7 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
         output_slice = output[0, -1, -3:, -3:].flatten()
         # fmt: off
-        expected_output_slice = torch.tensor([-1.1321, 0.1056, 0.3505, -0.6461, -0.2014, 0.0419, -0.5763, -0.8462, -0.4218])
+        expected_output_slice = torch.tensor([-0.0153, -0.4044, -0.1880, -0.5161, -0.2418, -0.4072, -0.1612, -0.0633, -0.0143])
         # fmt: on
         self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
@@ -623,22 +623,11 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
-            "ch": 64,
-            "ch_mult": (1,),
-            "embed_dim": 4,
-            "in_channels": 3,
-            "attn_resolutions": [],
-            "num_res_blocks": 1,
-            "out_ch": 3,
-            "resolution": 32,
-            "z_channels": 4,
-        }
-        init_dict = {
-            "block_out_channels": [64],
+            "block_out_channels": [32, 64],
             "in_channels": 3,
             "out_channels": 3,
-            "down_block_types": ["DownEncoderBlock2D"],
-            "up_block_types": ["UpDecoderBlock2D"],
+            "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
+            "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
             "latent_channels": 4,
         }
         inputs_dict = self.dummy_input
@@ -674,7 +663,7 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
         output_slice = output[0, -1, -3:, -3:].flatten()
         # fmt: off
-        expected_output_slice = torch.tensor([-0.3900, -0.2800, 0.1281, -0.4449, -0.4890, -0.0207, 0.0784, -0.1258, -0.0409])
+        expected_output_slice = torch.tensor([-4.0078e-01, -3.8304e-04, -1.2681e-01, -1.1462e-01, 2.0095e-01, 1.0893e-01, -8.8248e-02, -3.0361e-01, -9.8646e-03])
         # fmt: on
         self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
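The updated test config uses two down/up blocks so the encoder actually performs a downsampling step, which is what exercises the fixed downsample_padding. A rough usage sketch under that config is below; it assumes the AutoencoderKL constructor accepts these keyword arguments (as the test does), and the exact return wrapper of the forward call may vary across diffusers versions.

```python
import torch
from diffusers import AutoencoderKL

# Build the KL autoencoder with the same init args as the updated test:
# two encoder blocks mean one real downsample layer is created.
model = AutoencoderKL(
    block_out_channels=[32, 64],
    in_channels=3,
    out_channels=3,
    down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
    up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
    latent_channels=4,
)

sample = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    output = model(sample)  # reconstruction; exact return type depends on the diffusers version
```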