Commit 707341ae authored by William Berman's avatar William Berman Committed by Will Berman
Browse files

resnet skip time activation and output scale factor

parent 26b4319a
...@@ -459,6 +459,7 @@ class ResnetBlock2D(nn.Module): ...@@ -459,6 +459,7 @@ class ResnetBlock2D(nn.Module):
pre_norm=True, pre_norm=True,
eps=1e-6, eps=1e-6,
non_linearity="swish", non_linearity="swish",
skip_time_act=False,
time_embedding_norm="default", # default, scale_shift, ada_group time_embedding_norm="default", # default, scale_shift, ada_group
kernel=None, kernel=None,
output_scale_factor=1.0, output_scale_factor=1.0,
...@@ -479,6 +480,7 @@ class ResnetBlock2D(nn.Module): ...@@ -479,6 +480,7 @@ class ResnetBlock2D(nn.Module):
self.down = down self.down = down
self.output_scale_factor = output_scale_factor self.output_scale_factor = output_scale_factor
self.time_embedding_norm = time_embedding_norm self.time_embedding_norm = time_embedding_norm
self.skip_time_act = skip_time_act
if groups_out is None: if groups_out is None:
groups_out = groups groups_out = groups
...@@ -570,7 +572,9 @@ class ResnetBlock2D(nn.Module): ...@@ -570,7 +572,9 @@ class ResnetBlock2D(nn.Module):
hidden_states = self.conv1(hidden_states) hidden_states = self.conv1(hidden_states)
if self.time_emb_proj is not None: if self.time_emb_proj is not None:
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] if not self.skip_time_act:
temb = self.nonlinearity(temb)
temb = self.time_emb_proj(temb)[:, :, None, None]
if temb is not None and self.time_embedding_norm == "default": if temb is not None and self.time_embedding_norm == "default":
hidden_states = hidden_states + temb hidden_states = hidden_states + temb
......
...@@ -42,6 +42,8 @@ def get_down_block( ...@@ -42,6 +42,8 @@ def get_down_block(
only_cross_attention=False, only_cross_attention=False,
upcast_attention=False, upcast_attention=False,
resnet_time_scale_shift="default", resnet_time_scale_shift="default",
resnet_skip_time_act=False,
resnet_out_scale_factor=1.0,
): ):
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
if down_block_type == "DownBlock2D": if down_block_type == "DownBlock2D":
...@@ -68,6 +70,8 @@ def get_down_block( ...@@ -68,6 +70,8 @@ def get_down_block(
resnet_act_fn=resnet_act_fn, resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups, resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift, resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
output_scale_factor=resnet_out_scale_factor,
) )
elif down_block_type == "AttnDownBlock2D": elif down_block_type == "AttnDownBlock2D":
return AttnDownBlock2D( return AttnDownBlock2D(
...@@ -119,6 +123,8 @@ def get_down_block( ...@@ -119,6 +123,8 @@ def get_down_block(
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels, attn_num_head_channels=attn_num_head_channels,
resnet_time_scale_shift=resnet_time_scale_shift, resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
output_scale_factor=resnet_out_scale_factor,
) )
elif down_block_type == "SkipDownBlock2D": elif down_block_type == "SkipDownBlock2D":
return SkipDownBlock2D( return SkipDownBlock2D(
...@@ -214,6 +220,8 @@ def get_up_block( ...@@ -214,6 +220,8 @@ def get_up_block(
only_cross_attention=False, only_cross_attention=False,
upcast_attention=False, upcast_attention=False,
resnet_time_scale_shift="default", resnet_time_scale_shift="default",
resnet_skip_time_act=False,
resnet_out_scale_factor=1.0,
): ):
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
if up_block_type == "UpBlock2D": if up_block_type == "UpBlock2D":
...@@ -241,6 +249,8 @@ def get_up_block( ...@@ -241,6 +249,8 @@ def get_up_block(
resnet_act_fn=resnet_act_fn, resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups, resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift, resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
output_scale_factor=resnet_out_scale_factor,
) )
elif up_block_type == "CrossAttnUpBlock2D": elif up_block_type == "CrossAttnUpBlock2D":
if cross_attention_dim is None: if cross_attention_dim is None:
...@@ -279,6 +289,8 @@ def get_up_block( ...@@ -279,6 +289,8 @@ def get_up_block(
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels, attn_num_head_channels=attn_num_head_channels,
resnet_time_scale_shift=resnet_time_scale_shift, resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
output_scale_factor=resnet_out_scale_factor,
) )
elif up_block_type == "AttnUpBlock2D": elif up_block_type == "AttnUpBlock2D":
return AttnUpBlock2D( return AttnUpBlock2D(
...@@ -562,6 +574,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module): ...@@ -562,6 +574,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
attn_num_head_channels=1, attn_num_head_channels=1,
output_scale_factor=1.0, output_scale_factor=1.0,
cross_attention_dim=1280, cross_attention_dim=1280,
skip_time_act=False,
): ):
super().__init__() super().__init__()
...@@ -585,6 +598,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module): ...@@ -585,6 +598,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
non_linearity=resnet_act_fn, non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor, output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm, pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
) )
] ]
attentions = [] attentions = []
...@@ -615,6 +629,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module): ...@@ -615,6 +629,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
non_linearity=resnet_act_fn, non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor, output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm, pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
) )
) )
...@@ -1247,6 +1262,7 @@ class ResnetDownsampleBlock2D(nn.Module): ...@@ -1247,6 +1262,7 @@ class ResnetDownsampleBlock2D(nn.Module):
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
output_scale_factor=1.0, output_scale_factor=1.0,
add_downsample=True, add_downsample=True,
skip_time_act=False,
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
...@@ -1265,6 +1281,7 @@ class ResnetDownsampleBlock2D(nn.Module): ...@@ -1265,6 +1281,7 @@ class ResnetDownsampleBlock2D(nn.Module):
non_linearity=resnet_act_fn, non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor, output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm, pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
) )
) )
...@@ -1284,6 +1301,7 @@ class ResnetDownsampleBlock2D(nn.Module): ...@@ -1284,6 +1301,7 @@ class ResnetDownsampleBlock2D(nn.Module):
non_linearity=resnet_act_fn, non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor, output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm, pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
down=True, down=True,
) )
] ]
...@@ -1337,6 +1355,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module): ...@@ -1337,6 +1355,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
cross_attention_dim=1280, cross_attention_dim=1280,
output_scale_factor=1.0, output_scale_factor=1.0,
add_downsample=True, add_downsample=True,
skip_time_act=False,
): ):
super().__init__() super().__init__()
...@@ -1362,6 +1381,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module): ...@@ -1362,6 +1381,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
non_linearity=resnet_act_fn, non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor, output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm, pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
) )
) )
attentions.append( attentions.append(
...@@ -1394,6 +1414,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module): ...@@ -1394,6 +1414,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
non_linearity=resnet_act_fn, non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor, output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm, pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
down=True, down=True,
) )
] ]
...@@ -2237,6 +2258,7 @@ class ResnetUpsampleBlock2D(nn.Module): ...@@ -2237,6 +2258,7 @@ class ResnetUpsampleBlock2D(nn.Module):
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
output_scale_factor=1.0, output_scale_factor=1.0,
add_upsample=True, add_upsample=True,
skip_time_act=False,
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
...@@ -2257,6 +2279,7 @@ class ResnetUpsampleBlock2D(nn.Module): ...@@ -2257,6 +2279,7 @@ class ResnetUpsampleBlock2D(nn.Module):
non_linearity=resnet_act_fn, non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor, output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm, pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
) )
) )
...@@ -2276,6 +2299,7 @@ class ResnetUpsampleBlock2D(nn.Module): ...@@ -2276,6 +2299,7 @@ class ResnetUpsampleBlock2D(nn.Module):
non_linearity=resnet_act_fn, non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor, output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm, pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
up=True, up=True,
) )
] ]
...@@ -2329,6 +2353,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module): ...@@ -2329,6 +2353,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
cross_attention_dim=1280, cross_attention_dim=1280,
output_scale_factor=1.0, output_scale_factor=1.0,
add_upsample=True, add_upsample=True,
skip_time_act=False,
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
...@@ -2355,6 +2380,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module): ...@@ -2355,6 +2380,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
non_linearity=resnet_act_fn, non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor, output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm, pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
) )
) )
attentions.append( attentions.append(
...@@ -2387,6 +2413,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module): ...@@ -2387,6 +2413,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
non_linearity=resnet_act_fn, non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor, output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm, pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
up=True, up=True,
) )
] ]
......
...@@ -146,6 +146,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) ...@@ -146,6 +146,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
num_class_embeds: Optional[int] = None, num_class_embeds: Optional[int] = None,
upcast_attention: bool = False, upcast_attention: bool = False,
resnet_time_scale_shift: str = "default", resnet_time_scale_shift: str = "default",
resnet_skip_time_act: bool = False,
resnet_out_scale_factor: int = 1.0,
time_embedding_type: str = "positional", time_embedding_type: str = "positional",
timestep_post_act: Optional[str] = None, timestep_post_act: Optional[str] = None,
time_cond_proj_dim: Optional[int] = None, time_cond_proj_dim: Optional[int] = None,
...@@ -291,6 +293,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) ...@@ -291,6 +293,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
only_cross_attention=only_cross_attention[i], only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention, upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift, resnet_time_scale_shift=resnet_time_scale_shift,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
) )
self.down_blocks.append(down_block) self.down_blocks.append(down_block)
...@@ -321,6 +325,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) ...@@ -321,6 +325,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
attn_num_head_channels=attention_head_dim[-1], attn_num_head_channels=attention_head_dim[-1],
resnet_groups=norm_num_groups, resnet_groups=norm_num_groups,
resnet_time_scale_shift=resnet_time_scale_shift, resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
) )
elif mid_block_type is None: elif mid_block_type is None:
self.mid_block = None self.mid_block = None
...@@ -369,6 +374,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) ...@@ -369,6 +374,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
only_cross_attention=only_cross_attention[i], only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention, upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift, resnet_time_scale_shift=resnet_time_scale_shift,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
) )
self.up_blocks.append(up_block) self.up_blocks.append(up_block)
prev_output_channel = output_channel prev_output_channel = output_channel
......
...@@ -232,6 +232,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -232,6 +232,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
num_class_embeds: Optional[int] = None, num_class_embeds: Optional[int] = None,
upcast_attention: bool = False, upcast_attention: bool = False,
resnet_time_scale_shift: str = "default", resnet_time_scale_shift: str = "default",
resnet_skip_time_act: bool = False,
resnet_out_scale_factor: int = 1.0,
time_embedding_type: str = "positional", time_embedding_type: str = "positional",
timestep_post_act: Optional[str] = None, timestep_post_act: Optional[str] = None,
time_cond_proj_dim: Optional[int] = None, time_cond_proj_dim: Optional[int] = None,
...@@ -382,6 +384,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -382,6 +384,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
only_cross_attention=only_cross_attention[i], only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention, upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift, resnet_time_scale_shift=resnet_time_scale_shift,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
) )
self.down_blocks.append(down_block) self.down_blocks.append(down_block)
...@@ -412,6 +416,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -412,6 +416,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
attn_num_head_channels=attention_head_dim[-1], attn_num_head_channels=attention_head_dim[-1],
resnet_groups=norm_num_groups, resnet_groups=norm_num_groups,
resnet_time_scale_shift=resnet_time_scale_shift, resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
) )
elif mid_block_type is None: elif mid_block_type is None:
self.mid_block = None self.mid_block = None
...@@ -460,6 +465,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -460,6 +465,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
only_cross_attention=only_cross_attention[i], only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention, upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift, resnet_time_scale_shift=resnet_time_scale_shift,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
) )
self.up_blocks.append(up_block) self.up_blocks.append(up_block)
prev_output_channel = output_channel prev_output_channel = output_channel
...@@ -1434,6 +1441,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module): ...@@ -1434,6 +1441,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
attn_num_head_channels=1, attn_num_head_channels=1,
output_scale_factor=1.0, output_scale_factor=1.0,
cross_attention_dim=1280, cross_attention_dim=1280,
skip_time_act=False,
): ):
super().__init__() super().__init__()
...@@ -1457,6 +1465,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module): ...@@ -1457,6 +1465,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
non_linearity=resnet_act_fn, non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor, output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm, pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
) )
] ]
attentions = [] attentions = []
...@@ -1487,6 +1496,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module): ...@@ -1487,6 +1496,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
non_linearity=resnet_act_fn, non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor, output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm, pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
) )
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment