Commit 919e27d3 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

re-add super.__init__ for all PyTorch modules

parent ad9d2525
...@@ -65,6 +65,7 @@ class UNetConditionalModel(ModelMixin, ConfigMixin): ...@@ -65,6 +65,7 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
center_input_sample=False, center_input_sample=False,
resnet_num_groups=30, resnet_num_groups=30,
): ):
super().__init__()
self.image_size = image_size self.image_size = image_size
time_embed_dim = block_channels[0] * 4 time_embed_dim = block_channels[0] * 4
......
...@@ -61,6 +61,7 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin): ...@@ -61,6 +61,7 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
center_input_sample=False, center_input_sample=False,
resnet_num_groups=32, resnet_num_groups=32,
): ):
super().__init__()
self.image_size = image_size self.image_size = image_size
time_embed_dim = block_channels[0] * 4 time_embed_dim = block_channels[0] * 4
......
...@@ -400,6 +400,7 @@ class VQModel(ModelMixin, ConfigMixin): ...@@ -400,6 +400,7 @@ class VQModel(ModelMixin, ConfigMixin):
resamp_with_conv=True, resamp_with_conv=True,
give_pre_end=False, give_pre_end=False,
): ):
super().__init__()
# pass init params to Encoder # pass init params to Encoder
self.encoder = Encoder( self.encoder = Encoder(
...@@ -477,6 +478,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin): ...@@ -477,6 +478,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
resamp_with_conv=True, resamp_with_conv=True,
give_pre_end=False, give_pre_end=False,
): ):
super().__init__()
# pass init params to Encoder # pass init params to Encoder
self.encoder = Encoder( self.encoder = Encoder(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment