"...git@developer.sourcefind.cn:OpenDAS/diffusers.git" did not exist on "a81334e3f01550f79417af81caa92a9da58f135b"
Commit 707341ae authored by William Berman's avatar William Berman Committed by Will Berman
Browse files

resnet skip time activation and output scale factor

parent 26b4319a
......@@ -459,6 +459,7 @@ class ResnetBlock2D(nn.Module):
pre_norm=True,
eps=1e-6,
non_linearity="swish",
skip_time_act=False,
time_embedding_norm="default", # default, scale_shift, ada_group
kernel=None,
output_scale_factor=1.0,
......@@ -479,6 +480,7 @@ class ResnetBlock2D(nn.Module):
self.down = down
self.output_scale_factor = output_scale_factor
self.time_embedding_norm = time_embedding_norm
self.skip_time_act = skip_time_act
if groups_out is None:
groups_out = groups
......@@ -570,7 +572,9 @@ class ResnetBlock2D(nn.Module):
hidden_states = self.conv1(hidden_states)
if self.time_emb_proj is not None:
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
if not self.skip_time_act:
temb = self.nonlinearity(temb)
temb = self.time_emb_proj(temb)[:, :, None, None]
if temb is not None and self.time_embedding_norm == "default":
hidden_states = hidden_states + temb
......
......@@ -42,6 +42,8 @@ def get_down_block(
only_cross_attention=False,
upcast_attention=False,
resnet_time_scale_shift="default",
resnet_skip_time_act=False,
resnet_out_scale_factor=1.0,
):
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
if down_block_type == "DownBlock2D":
......@@ -68,6 +70,8 @@ def get_down_block(
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
output_scale_factor=resnet_out_scale_factor,
)
elif down_block_type == "AttnDownBlock2D":
return AttnDownBlock2D(
......@@ -119,6 +123,8 @@ def get_down_block(
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
output_scale_factor=resnet_out_scale_factor,
)
elif down_block_type == "SkipDownBlock2D":
return SkipDownBlock2D(
......@@ -214,6 +220,8 @@ def get_up_block(
only_cross_attention=False,
upcast_attention=False,
resnet_time_scale_shift="default",
resnet_skip_time_act=False,
resnet_out_scale_factor=1.0,
):
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
if up_block_type == "UpBlock2D":
......@@ -241,6 +249,8 @@ def get_up_block(
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
output_scale_factor=resnet_out_scale_factor,
)
elif up_block_type == "CrossAttnUpBlock2D":
if cross_attention_dim is None:
......@@ -279,6 +289,8 @@ def get_up_block(
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
output_scale_factor=resnet_out_scale_factor,
)
elif up_block_type == "AttnUpBlock2D":
return AttnUpBlock2D(
......@@ -562,6 +574,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
attn_num_head_channels=1,
output_scale_factor=1.0,
cross_attention_dim=1280,
skip_time_act=False,
):
super().__init__()
......@@ -585,6 +598,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
)
]
attentions = []
......@@ -615,6 +629,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
)
)
......@@ -1247,6 +1262,7 @@ class ResnetDownsampleBlock2D(nn.Module):
resnet_pre_norm: bool = True,
output_scale_factor=1.0,
add_downsample=True,
skip_time_act=False,
):
super().__init__()
resnets = []
......@@ -1265,6 +1281,7 @@ class ResnetDownsampleBlock2D(nn.Module):
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
)
)
......@@ -1284,6 +1301,7 @@ class ResnetDownsampleBlock2D(nn.Module):
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
down=True,
)
]
......@@ -1337,6 +1355,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
cross_attention_dim=1280,
output_scale_factor=1.0,
add_downsample=True,
skip_time_act=False,
):
super().__init__()
......@@ -1362,6 +1381,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
)
)
attentions.append(
......@@ -1394,6 +1414,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
down=True,
)
]
......@@ -2237,6 +2258,7 @@ class ResnetUpsampleBlock2D(nn.Module):
resnet_pre_norm: bool = True,
output_scale_factor=1.0,
add_upsample=True,
skip_time_act=False,
):
super().__init__()
resnets = []
......@@ -2257,6 +2279,7 @@ class ResnetUpsampleBlock2D(nn.Module):
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
)
)
......@@ -2276,6 +2299,7 @@ class ResnetUpsampleBlock2D(nn.Module):
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
up=True,
)
]
......@@ -2329,6 +2353,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
cross_attention_dim=1280,
output_scale_factor=1.0,
add_upsample=True,
skip_time_act=False,
):
super().__init__()
resnets = []
......@@ -2355,6 +2380,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
)
)
attentions.append(
......@@ -2387,6 +2413,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
up=True,
)
]
......
......@@ -146,6 +146,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
resnet_skip_time_act: bool = False,
resnet_out_scale_factor: int = 1.0,
time_embedding_type: str = "positional",
timestep_post_act: Optional[str] = None,
time_cond_proj_dim: Optional[int] = None,
......@@ -291,6 +293,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
)
self.down_blocks.append(down_block)
......@@ -321,6 +325,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
attn_num_head_channels=attention_head_dim[-1],
resnet_groups=norm_num_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
)
elif mid_block_type is None:
self.mid_block = None
......@@ -369,6 +374,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
......
......@@ -232,6 +232,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
resnet_skip_time_act: bool = False,
resnet_out_scale_factor: int = 1.0,
time_embedding_type: str = "positional",
timestep_post_act: Optional[str] = None,
time_cond_proj_dim: Optional[int] = None,
......@@ -382,6 +384,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
)
self.down_blocks.append(down_block)
......@@ -412,6 +416,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
attn_num_head_channels=attention_head_dim[-1],
resnet_groups=norm_num_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
)
elif mid_block_type is None:
self.mid_block = None
......@@ -460,6 +465,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
......@@ -1434,6 +1441,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
attn_num_head_channels=1,
output_scale_factor=1.0,
cross_attention_dim=1280,
skip_time_act=False,
):
super().__init__()
......@@ -1457,6 +1465,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
)
]
attentions = []
......@@ -1487,6 +1496,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
)
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment