"docs/git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "48d5b095a254bf46c7312ef2ebbb13991bde57d1"
Commit 707341ae authored by William Berman's avatar William Berman Committed by Will Berman
Browse files

resnet skip time activation and output scale factor

parent 26b4319a
...@@ -459,6 +459,7 @@ class ResnetBlock2D(nn.Module): ...@@ -459,6 +459,7 @@ class ResnetBlock2D(nn.Module):
pre_norm=True, pre_norm=True,
eps=1e-6, eps=1e-6,
non_linearity="swish", non_linearity="swish",
skip_time_act=False,
time_embedding_norm="default", # default, scale_shift, ada_group time_embedding_norm="default", # default, scale_shift, ada_group
kernel=None, kernel=None,
output_scale_factor=1.0, output_scale_factor=1.0,
...@@ -479,6 +480,7 @@ class ResnetBlock2D(nn.Module): ...@@ -479,6 +480,7 @@ class ResnetBlock2D(nn.Module):
self.down = down self.down = down
self.output_scale_factor = output_scale_factor self.output_scale_factor = output_scale_factor
self.time_embedding_norm = time_embedding_norm self.time_embedding_norm = time_embedding_norm
self.skip_time_act = skip_time_act
if groups_out is None: if groups_out is None:
groups_out = groups groups_out = groups
...@@ -570,7 +572,9 @@ class ResnetBlock2D(nn.Module): ...@@ -570,7 +572,9 @@ class ResnetBlock2D(nn.Module):
hidden_states = self.conv1(hidden_states) hidden_states = self.conv1(hidden_states)
if self.time_emb_proj is not None: if self.time_emb_proj is not None:
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] if not self.skip_time_act:
temb = self.nonlinearity(temb)
temb = self.time_emb_proj(temb)[:, :, None, None]
if temb is not None and self.time_embedding_norm == "default": if temb is not None and self.time_embedding_norm == "default":
hidden_states = hidden_states + temb hidden_states = hidden_states + temb
......
...@@ -42,6 +42,8 @@ def get_down_block( ...@@ -42,6 +42,8 @@ def get_down_block(
only_cross_attention=False, only_cross_attention=False,
upcast_attention=False, upcast_attention=False,
resnet_time_scale_shift="default", resnet_time_scale_shift="default",
resnet_skip_time_act=False,
resnet_out_scale_factor=1.0,
): ):
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
if down_block_type == "DownBlock2D": if down_block_type == "DownBlock2D":
...@@ -68,6 +70,8 @@ def get_down_block( ...@@ -68,6 +70,8 @@ def get_down_block(
resnet_act_fn=resnet_act_fn, resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups, resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift, resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
output_scale_factor=resnet_out_scale_factor,
) )
elif down_block_type == "AttnDownBlock2D": elif down_block_type == "AttnDownBlock2D":
return AttnDownBlock2D( return AttnDownBlock2D(
...@@ -119,6 +123,8 @@ def get_down_block( ...@@ -119,6 +123,8 @@ def get_down_block(
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels, attn_num_head_channels=attn_num_head_channels,
resnet_time_scale_shift=resnet_time_scale_shift, resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
output_scale_factor=resnet_out_scale_factor,
) )
elif down_block_type == "SkipDownBlock2D": elif down_block_type == "SkipDownBlock2D":
return SkipDownBlock2D( return SkipDownBlock2D(
...@@ -214,6 +220,8 @@ def get_up_block( ...@@ -214,6 +220,8 @@ def get_up_block(
only_cross_attention=False, only_cross_attention=False,
upcast_attention=False, upcast_attention=False,
resnet_time_scale_shift="default", resnet_time_scale_shift="default",
resnet_skip_time_act=False,
resnet_out_scale_factor=1.0,
): ):
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
if up_block_type == "UpBlock2D": if up_block_type == "UpBlock2D":
...@@ -241,6 +249,8 @@ def get_up_block( ...@@ -241,6 +249,8 @@ def get_up_block(
resnet_act_fn=resnet_act_fn, resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups, resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift, resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
output_scale_factor=resnet_out_scale_factor,
) )
elif up_block_type == "CrossAttnUpBlock2D": elif up_block_type == "CrossAttnUpBlock2D":
if cross_attention_dim is None: if cross_attention_dim is None:
...@@ -279,6 +289,8 @@ def get_up_block( ...@@ -279,6 +289,8 @@ def get_up_block(
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels, attn_num_head_channels=attn_num_head_channels,
resnet_time_scale_shift=resnet_time_scale_shift, resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
output_scale_factor=resnet_out_scale_factor,
) )
elif up_block_type == "AttnUpBlock2D": elif up_block_type == "AttnUpBlock2D":
return AttnUpBlock2D( return AttnUpBlock2D(
...@@ -562,6 +574,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module): ...@@ -562,6 +574,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
attn_num_head_channels=1, attn_num_head_channels=1,
output_scale_factor=1.0, output_scale_factor=1.0,
cross_attention_dim=1280, cross_attention_dim=1280,
skip_time_act=False,
): ):
super().__init__() super().__init__()
...@@ -585,6 +598,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module): ...@@ -585,6 +598,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
non_linearity=resnet_act_fn, non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor, output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm, pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
) )
] ]
attentions = [] attentions = []
...@@ -615,6 +629,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module): ...@@ -615,6 +629,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
non_linearity=resnet_act_fn, non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor, output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm, pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
) )
) )
...@@ -1247,6 +1262,7 @@ class ResnetDownsampleBlock2D(nn.Module): ...@@ -1247,6 +1262,7 @@ class ResnetDownsampleBlock2D(nn.Module):
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
output_scale_factor=1.0, output_scale_factor=1.0,
add_downsample=True, add_downsample=True,
skip_time_act=False,
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
...@@ -1265,6 +1281,7 @@ class ResnetDownsampleBlock2D(nn.Module): ...@@ -1265,6 +1281,7 @@ class ResnetDownsampleBlock2D(nn.Module):
non_linearity=resnet_act_fn, non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor, output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm, pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
) )
) )
...@@ -1284,6 +1301,7 @@ class ResnetDownsampleBlock2D(nn.Module): ...@@ -1284,6 +1301,7 @@ class ResnetDownsampleBlock2D(nn.Module):
non_linearity=resnet_act_fn, non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor, output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm, pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
down=True, down=True,
) )
] ]
...@@ -1337,6 +1355,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module): ...@@ -1337,6 +1355,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
cross_attention_dim=1280, cross_attention_dim=1280,
output_scale_factor=1.0, output_scale_factor=1.0,
add_downsample=True, add_downsample=True,
skip_time_act=False,
): ):
super().__init__() super().__init__()
...@@ -1362,6 +1381,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module): ...@@ -1362,6 +1381,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
non_linearity=resnet_act_fn, non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor, output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm, pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
) )
) )
attentions.append( attentions.append(
...@@ -1394,6 +1414,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module): ...@@ -1394,6 +1414,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
non_linearity=resnet_act_fn, non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor, output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm, pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
down=True, down=True,
) )
] ]
...@@ -2237,6 +2258,7 @@ class ResnetUpsampleBlock2D(nn.Module): ...@@ -2237,6 +2258,7 @@ class ResnetUpsampleBlock2D(nn.Module):
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
output_scale_factor=1.0, output_scale_factor=1.0,
add_upsample=True, add_upsample=True,
skip_time_act=False,
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
...@@ -2257,6 +2279,7 @@ class ResnetUpsampleBlock2D(nn.Module): ...@@ -2257,6 +2279,7 @@ class ResnetUpsampleBlock2D(nn.Module):
non_linearity=resnet_act_fn, non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor, output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm, pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
) )
) )
...@@ -2276,6 +2299,7 @@ class ResnetUpsampleBlock2D(nn.Module): ...@@ -2276,6 +2299,7 @@ class ResnetUpsampleBlock2D(nn.Module):
non_linearity=resnet_act_fn, non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor, output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm, pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
up=True, up=True,
) )
] ]
...@@ -2329,6 +2353,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module): ...@@ -2329,6 +2353,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
cross_attention_dim=1280, cross_attention_dim=1280,
output_scale_factor=1.0, output_scale_factor=1.0,
add_upsample=True, add_upsample=True,
skip_time_act=False,
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
...@@ -2355,6 +2380,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module): ...@@ -2355,6 +2380,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
non_linearity=resnet_act_fn, non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor, output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm, pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
) )
) )
attentions.append( attentions.append(
...@@ -2387,6 +2413,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module): ...@@ -2387,6 +2413,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
non_linearity=resnet_act_fn, non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor, output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm, pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
up=True, up=True,
) )
] ]
......
...@@ -146,6 +146,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) ...@@ -146,6 +146,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
num_class_embeds: Optional[int] = None, num_class_embeds: Optional[int] = None,
upcast_attention: bool = False, upcast_attention: bool = False,
resnet_time_scale_shift: str = "default", resnet_time_scale_shift: str = "default",
resnet_skip_time_act: bool = False,
resnet_out_scale_factor: int = 1.0,
time_embedding_type: str = "positional", time_embedding_type: str = "positional",
timestep_post_act: Optional[str] = None, timestep_post_act: Optional[str] = None,
time_cond_proj_dim: Optional[int] = None, time_cond_proj_dim: Optional[int] = None,
...@@ -291,6 +293,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) ...@@ -291,6 +293,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
only_cross_attention=only_cross_attention[i], only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention, upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift, resnet_time_scale_shift=resnet_time_scale_shift,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
) )
self.down_blocks.append(down_block) self.down_blocks.append(down_block)
...@@ -321,6 +325,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) ...@@ -321,6 +325,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
attn_num_head_channels=attention_head_dim[-1], attn_num_head_channels=attention_head_dim[-1],
resnet_groups=norm_num_groups, resnet_groups=norm_num_groups,
resnet_time_scale_shift=resnet_time_scale_shift, resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
) )
elif mid_block_type is None: elif mid_block_type is None:
self.mid_block = None self.mid_block = None
...@@ -369,6 +374,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) ...@@ -369,6 +374,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
only_cross_attention=only_cross_attention[i], only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention, upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift, resnet_time_scale_shift=resnet_time_scale_shift,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
) )
self.up_blocks.append(up_block) self.up_blocks.append(up_block)
prev_output_channel = output_channel prev_output_channel = output_channel
......
...@@ -232,6 +232,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -232,6 +232,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
num_class_embeds: Optional[int] = None, num_class_embeds: Optional[int] = None,
upcast_attention: bool = False, upcast_attention: bool = False,
resnet_time_scale_shift: str = "default", resnet_time_scale_shift: str = "default",
resnet_skip_time_act: bool = False,
resnet_out_scale_factor: int = 1.0,
time_embedding_type: str = "positional", time_embedding_type: str = "positional",
timestep_post_act: Optional[str] = None, timestep_post_act: Optional[str] = None,
time_cond_proj_dim: Optional[int] = None, time_cond_proj_dim: Optional[int] = None,
...@@ -382,6 +384,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -382,6 +384,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
only_cross_attention=only_cross_attention[i], only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention, upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift, resnet_time_scale_shift=resnet_time_scale_shift,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
) )
self.down_blocks.append(down_block) self.down_blocks.append(down_block)
...@@ -412,6 +416,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -412,6 +416,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
attn_num_head_channels=attention_head_dim[-1], attn_num_head_channels=attention_head_dim[-1],
resnet_groups=norm_num_groups, resnet_groups=norm_num_groups,
resnet_time_scale_shift=resnet_time_scale_shift, resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
) )
elif mid_block_type is None: elif mid_block_type is None:
self.mid_block = None self.mid_block = None
...@@ -460,6 +465,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -460,6 +465,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
only_cross_attention=only_cross_attention[i], only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention, upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift, resnet_time_scale_shift=resnet_time_scale_shift,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
) )
self.up_blocks.append(up_block) self.up_blocks.append(up_block)
prev_output_channel = output_channel prev_output_channel = output_channel
...@@ -1434,6 +1441,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module): ...@@ -1434,6 +1441,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
attn_num_head_channels=1, attn_num_head_channels=1,
output_scale_factor=1.0, output_scale_factor=1.0,
cross_attention_dim=1280, cross_attention_dim=1280,
skip_time_act=False,
): ):
super().__init__() super().__init__()
...@@ -1457,6 +1465,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module): ...@@ -1457,6 +1465,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
non_linearity=resnet_act_fn, non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor, output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm, pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
) )
] ]
attentions = [] attentions = []
...@@ -1487,6 +1496,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module): ...@@ -1487,6 +1496,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
non_linearity=resnet_act_fn, non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor, output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm, pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
) )
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment