"git@developer.sourcefind.cn:OpenDAS/ktransformers.git" did not exist on "19d4a50b1cb514f9090c5bcbf5d9893da8b48674"
Unverified Commit d144c46a authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

[UNet2DConditionModel, UNet2DModel] pass norm_num_groups to all the blocks (#442)

* pass norm_num_groups to unet blocs and attention

* fix UNet2DConditionModel

* add norm_num_groups arg in vae

* add tests

* remove comment

* Apply suggestions from code review
parent b34be039
...@@ -113,6 +113,7 @@ class SpatialTransformer(nn.Module): ...@@ -113,6 +113,7 @@ class SpatialTransformer(nn.Module):
d_head: int, d_head: int,
depth: int = 1, depth: int = 1,
dropout: float = 0.0, dropout: float = 0.0,
num_groups: int = 32,
context_dim: Optional[int] = None, context_dim: Optional[int] = None,
): ):
super().__init__() super().__init__()
...@@ -120,7 +121,7 @@ class SpatialTransformer(nn.Module): ...@@ -120,7 +121,7 @@ class SpatialTransformer(nn.Module):
self.d_head = d_head self.d_head = d_head
self.in_channels = in_channels self.in_channels = in_channels
inner_dim = n_heads * d_head inner_dim = n_heads * d_head
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
......
...@@ -114,6 +114,7 @@ class UNet2DModel(ModelMixin, ConfigMixin): ...@@ -114,6 +114,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
add_downsample=not is_final_block, add_downsample=not is_final_block,
resnet_eps=norm_eps, resnet_eps=norm_eps,
resnet_act_fn=act_fn, resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attn_num_head_channels=attention_head_dim, attn_num_head_channels=attention_head_dim,
downsample_padding=downsample_padding, downsample_padding=downsample_padding,
) )
...@@ -151,6 +152,7 @@ class UNet2DModel(ModelMixin, ConfigMixin): ...@@ -151,6 +152,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
add_upsample=not is_final_block, add_upsample=not is_final_block,
resnet_eps=norm_eps, resnet_eps=norm_eps,
resnet_act_fn=act_fn, resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attn_num_head_channels=attention_head_dim, attn_num_head_channels=attention_head_dim,
) )
self.up_blocks.append(up_block) self.up_blocks.append(up_block)
......
...@@ -114,6 +114,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): ...@@ -114,6 +114,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
add_downsample=not is_final_block, add_downsample=not is_final_block,
resnet_eps=norm_eps, resnet_eps=norm_eps,
resnet_act_fn=act_fn, resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim, attn_num_head_channels=attention_head_dim,
downsample_padding=downsample_padding, downsample_padding=downsample_padding,
...@@ -153,6 +154,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): ...@@ -153,6 +154,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
add_upsample=not is_final_block, add_upsample=not is_final_block,
resnet_eps=norm_eps, resnet_eps=norm_eps,
resnet_act_fn=act_fn, resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim, attn_num_head_channels=attention_head_dim,
) )
......
...@@ -31,6 +31,7 @@ def get_down_block( ...@@ -31,6 +31,7 @@ def get_down_block(
resnet_eps, resnet_eps,
resnet_act_fn, resnet_act_fn,
attn_num_head_channels, attn_num_head_channels,
resnet_groups=None,
cross_attention_dim=None, cross_attention_dim=None,
downsample_padding=None, downsample_padding=None,
): ):
...@@ -44,6 +45,7 @@ def get_down_block( ...@@ -44,6 +45,7 @@ def get_down_block(
add_downsample=add_downsample, add_downsample=add_downsample,
resnet_eps=resnet_eps, resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn, resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding, downsample_padding=downsample_padding,
) )
elif down_block_type == "AttnDownBlock2D": elif down_block_type == "AttnDownBlock2D":
...@@ -55,6 +57,7 @@ def get_down_block( ...@@ -55,6 +57,7 @@ def get_down_block(
add_downsample=add_downsample, add_downsample=add_downsample,
resnet_eps=resnet_eps, resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn, resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding, downsample_padding=downsample_padding,
attn_num_head_channels=attn_num_head_channels, attn_num_head_channels=attn_num_head_channels,
) )
...@@ -69,6 +72,7 @@ def get_down_block( ...@@ -69,6 +72,7 @@ def get_down_block(
add_downsample=add_downsample, add_downsample=add_downsample,
resnet_eps=resnet_eps, resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn, resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding, downsample_padding=downsample_padding,
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels, attn_num_head_channels=attn_num_head_channels,
...@@ -104,6 +108,7 @@ def get_down_block( ...@@ -104,6 +108,7 @@ def get_down_block(
add_downsample=add_downsample, add_downsample=add_downsample,
resnet_eps=resnet_eps, resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn, resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding, downsample_padding=downsample_padding,
) )
...@@ -119,6 +124,7 @@ def get_up_block( ...@@ -119,6 +124,7 @@ def get_up_block(
resnet_eps, resnet_eps,
resnet_act_fn, resnet_act_fn,
attn_num_head_channels, attn_num_head_channels,
resnet_groups=None,
cross_attention_dim=None, cross_attention_dim=None,
): ):
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
...@@ -132,6 +138,7 @@ def get_up_block( ...@@ -132,6 +138,7 @@ def get_up_block(
add_upsample=add_upsample, add_upsample=add_upsample,
resnet_eps=resnet_eps, resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn, resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
) )
elif up_block_type == "CrossAttnUpBlock2D": elif up_block_type == "CrossAttnUpBlock2D":
if cross_attention_dim is None: if cross_attention_dim is None:
...@@ -145,6 +152,7 @@ def get_up_block( ...@@ -145,6 +152,7 @@ def get_up_block(
add_upsample=add_upsample, add_upsample=add_upsample,
resnet_eps=resnet_eps, resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn, resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels, attn_num_head_channels=attn_num_head_channels,
) )
...@@ -158,6 +166,7 @@ def get_up_block( ...@@ -158,6 +166,7 @@ def get_up_block(
add_upsample=add_upsample, add_upsample=add_upsample,
resnet_eps=resnet_eps, resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn, resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
attn_num_head_channels=attn_num_head_channels, attn_num_head_channels=attn_num_head_channels,
) )
elif up_block_type == "SkipUpBlock2D": elif up_block_type == "SkipUpBlock2D":
...@@ -191,6 +200,7 @@ def get_up_block( ...@@ -191,6 +200,7 @@ def get_up_block(
add_upsample=add_upsample, add_upsample=add_upsample,
resnet_eps=resnet_eps, resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn, resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
) )
raise ValueError(f"{up_block_type} does not exist.") raise ValueError(f"{up_block_type} does not exist.")
...@@ -323,6 +333,7 @@ class UNetMidBlock2DCrossAttn(nn.Module): ...@@ -323,6 +333,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
in_channels // attn_num_head_channels, in_channels // attn_num_head_channels,
depth=1, depth=1,
context_dim=cross_attention_dim, context_dim=cross_attention_dim,
num_groups=resnet_groups,
) )
) )
resnets.append( resnets.append(
...@@ -414,6 +425,7 @@ class AttnDownBlock2D(nn.Module): ...@@ -414,6 +425,7 @@ class AttnDownBlock2D(nn.Module):
num_head_channels=attn_num_head_channels, num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor, rescale_output_factor=output_scale_factor,
eps=resnet_eps, eps=resnet_eps,
num_groups=resnet_groups,
) )
) )
...@@ -498,6 +510,7 @@ class CrossAttnDownBlock2D(nn.Module): ...@@ -498,6 +510,7 @@ class CrossAttnDownBlock2D(nn.Module):
out_channels // attn_num_head_channels, out_channels // attn_num_head_channels,
depth=1, depth=1,
context_dim=cross_attention_dim, context_dim=cross_attention_dim,
num_groups=resnet_groups,
) )
) )
self.attentions = nn.ModuleList(attentions) self.attentions = nn.ModuleList(attentions)
...@@ -966,6 +979,7 @@ class AttnUpBlock2D(nn.Module): ...@@ -966,6 +979,7 @@ class AttnUpBlock2D(nn.Module):
num_head_channels=attn_num_head_channels, num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor, rescale_output_factor=output_scale_factor,
eps=resnet_eps, eps=resnet_eps,
num_groups=resnet_groups,
) )
) )
...@@ -1047,6 +1061,7 @@ class CrossAttnUpBlock2D(nn.Module): ...@@ -1047,6 +1061,7 @@ class CrossAttnUpBlock2D(nn.Module):
out_channels // attn_num_head_channels, out_channels // attn_num_head_channels,
depth=1, depth=1,
context_dim=cross_attention_dim, context_dim=cross_attention_dim,
num_groups=resnet_groups,
) )
) )
self.attentions = nn.ModuleList(attentions) self.attentions = nn.ModuleList(attentions)
......
...@@ -59,6 +59,7 @@ class Encoder(nn.Module): ...@@ -59,6 +59,7 @@ class Encoder(nn.Module):
down_block_types=("DownEncoderBlock2D",), down_block_types=("DownEncoderBlock2D",),
block_out_channels=(64,), block_out_channels=(64,),
layers_per_block=2, layers_per_block=2,
norm_num_groups=32,
act_fn="silu", act_fn="silu",
double_z=True, double_z=True,
): ):
...@@ -86,6 +87,7 @@ class Encoder(nn.Module): ...@@ -86,6 +87,7 @@ class Encoder(nn.Module):
resnet_eps=1e-6, resnet_eps=1e-6,
downsample_padding=0, downsample_padding=0,
resnet_act_fn=act_fn, resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attn_num_head_channels=None, attn_num_head_channels=None,
temb_channels=None, temb_channels=None,
) )
...@@ -99,13 +101,12 @@ class Encoder(nn.Module): ...@@ -99,13 +101,12 @@ class Encoder(nn.Module):
output_scale_factor=1, output_scale_factor=1,
resnet_time_scale_shift="default", resnet_time_scale_shift="default",
attn_num_head_channels=None, attn_num_head_channels=None,
resnet_groups=32, resnet_groups=norm_num_groups,
temb_channels=None, temb_channels=None,
) )
# out # out
num_groups_out = 32 self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups_out, eps=1e-6)
self.conv_act = nn.SiLU() self.conv_act = nn.SiLU()
conv_out_channels = 2 * out_channels if double_z else out_channels conv_out_channels = 2 * out_channels if double_z else out_channels
...@@ -138,6 +139,7 @@ class Decoder(nn.Module): ...@@ -138,6 +139,7 @@ class Decoder(nn.Module):
up_block_types=("UpDecoderBlock2D",), up_block_types=("UpDecoderBlock2D",),
block_out_channels=(64,), block_out_channels=(64,),
layers_per_block=2, layers_per_block=2,
norm_num_groups=32,
act_fn="silu", act_fn="silu",
): ):
super().__init__() super().__init__()
...@@ -156,7 +158,7 @@ class Decoder(nn.Module): ...@@ -156,7 +158,7 @@ class Decoder(nn.Module):
output_scale_factor=1, output_scale_factor=1,
resnet_time_scale_shift="default", resnet_time_scale_shift="default",
attn_num_head_channels=None, attn_num_head_channels=None,
resnet_groups=32, resnet_groups=norm_num_groups,
temb_channels=None, temb_channels=None,
) )
...@@ -178,6 +180,7 @@ class Decoder(nn.Module): ...@@ -178,6 +180,7 @@ class Decoder(nn.Module):
add_upsample=not is_final_block, add_upsample=not is_final_block,
resnet_eps=1e-6, resnet_eps=1e-6,
resnet_act_fn=act_fn, resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attn_num_head_channels=None, attn_num_head_channels=None,
temb_channels=None, temb_channels=None,
) )
...@@ -185,8 +188,7 @@ class Decoder(nn.Module): ...@@ -185,8 +188,7 @@ class Decoder(nn.Module):
prev_output_channel = output_channel prev_output_channel = output_channel
# out # out
num_groups_out = 32 self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=1e-6)
self.conv_act = nn.SiLU() self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
...@@ -405,6 +407,7 @@ class VQModel(ModelMixin, ConfigMixin): ...@@ -405,6 +407,7 @@ class VQModel(ModelMixin, ConfigMixin):
latent_channels: int = 3, latent_channels: int = 3,
sample_size: int = 32, sample_size: int = 32,
num_vq_embeddings: int = 256, num_vq_embeddings: int = 256,
norm_num_groups: int = 32,
): ):
super().__init__() super().__init__()
...@@ -416,6 +419,7 @@ class VQModel(ModelMixin, ConfigMixin): ...@@ -416,6 +419,7 @@ class VQModel(ModelMixin, ConfigMixin):
block_out_channels=block_out_channels, block_out_channels=block_out_channels,
layers_per_block=layers_per_block, layers_per_block=layers_per_block,
act_fn=act_fn, act_fn=act_fn,
norm_num_groups=norm_num_groups,
double_z=False, double_z=False,
) )
...@@ -433,6 +437,7 @@ class VQModel(ModelMixin, ConfigMixin): ...@@ -433,6 +437,7 @@ class VQModel(ModelMixin, ConfigMixin):
block_out_channels=block_out_channels, block_out_channels=block_out_channels,
layers_per_block=layers_per_block, layers_per_block=layers_per_block,
act_fn=act_fn, act_fn=act_fn,
norm_num_groups=norm_num_groups,
) )
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput: def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
...@@ -509,6 +514,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin): ...@@ -509,6 +514,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
layers_per_block: int = 1, layers_per_block: int = 1,
act_fn: str = "silu", act_fn: str = "silu",
latent_channels: int = 4, latent_channels: int = 4,
norm_num_groups: int = 32,
sample_size: int = 32, sample_size: int = 32,
): ):
super().__init__() super().__init__()
...@@ -521,6 +527,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin): ...@@ -521,6 +527,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
block_out_channels=block_out_channels, block_out_channels=block_out_channels,
layers_per_block=layers_per_block, layers_per_block=layers_per_block,
act_fn=act_fn, act_fn=act_fn,
norm_num_groups=norm_num_groups,
double_z=True, double_z=True,
) )
...@@ -531,6 +538,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin): ...@@ -531,6 +538,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
up_block_types=up_block_types, up_block_types=up_block_types,
block_out_channels=block_out_channels, block_out_channels=block_out_channels,
layers_per_block=layers_per_block, layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups,
act_fn=act_fn, act_fn=act_fn,
) )
......
...@@ -99,6 +99,26 @@ class ModelTesterMixin: ...@@ -99,6 +99,26 @@ class ModelTesterMixin:
expected_shape = inputs_dict["sample"].shape expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_forward_with_norm_groups(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["norm_num_groups"] = 16
init_dict["block_out_channels"] = (16, 32)
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_forward_signature(self): def test_forward_signature(self):
init_dict, _ = self.prepare_init_args_and_inputs_for_common() init_dict, _ = self.prepare_init_args_and_inputs_for_common()
......
...@@ -293,3 +293,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -293,3 +293,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
# fmt: on # fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2)) self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
def test_forward_with_norm_groups(self):
# not required for this model
pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment