Unverified Commit f0c74e9a authored by Will Berman's avatar Will Berman Committed by GitHub
Browse files

Add unet act fn to other model components (#3136)

Adding act fn config to the unet timestep class embedding and conv
activation.

The custom activation defaults to silu which is the default
activation function for both the conv act and the timestep class
embeddings so default behavior is not changed.

The only unet which uses the custom activation is the stable diffusion
latent upscaler https://huggingface.co/stabilityai/sd-x2-latent-upscaler/blob/main/unet/config.json
(I ran a script against the hub to confirm).
The latent upscaler does not use the conv activation nor the timestep
class embeddings so we don't change its behavior.
parent 4bc157ff
...@@ -248,7 +248,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) ...@@ -248,7 +248,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
if class_embed_type is None and num_class_embeds is not None: if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
elif class_embed_type == "timestep": elif class_embed_type == "timestep":
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
elif class_embed_type == "identity": elif class_embed_type == "identity":
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
elif class_embed_type == "projection": elif class_embed_type == "projection":
...@@ -437,7 +437,18 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) ...@@ -437,7 +437,18 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
self.conv_norm_out = nn.GroupNorm( self.conv_norm_out = nn.GroupNorm(
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
) )
self.conv_act = nn.SiLU()
if act_fn == "swish":
self.conv_act = lambda x: F.silu(x)
elif act_fn == "mish":
self.conv_act = nn.Mish()
elif act_fn == "silu":
self.conv_act = nn.SiLU()
elif act_fn == "gelu":
self.conv_act = nn.GELU()
else:
raise ValueError(f"Unsupported activation function: {act_fn}")
else: else:
self.conv_norm_out = None self.conv_norm_out = None
self.conv_act = None self.conv_act = None
......
...@@ -345,7 +345,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -345,7 +345,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
if class_embed_type is None and num_class_embeds is not None: if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
elif class_embed_type == "timestep": elif class_embed_type == "timestep":
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
elif class_embed_type == "identity": elif class_embed_type == "identity":
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
elif class_embed_type == "projection": elif class_embed_type == "projection":
...@@ -534,7 +534,18 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -534,7 +534,18 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
self.conv_norm_out = nn.GroupNorm( self.conv_norm_out = nn.GroupNorm(
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
) )
self.conv_act = nn.SiLU()
if act_fn == "swish":
self.conv_act = lambda x: F.silu(x)
elif act_fn == "mish":
self.conv_act = nn.Mish()
elif act_fn == "silu":
self.conv_act = nn.SiLU()
elif act_fn == "gelu":
self.conv_act = nn.GELU()
else:
raise ValueError(f"Unsupported activation function: {act_fn}")
else: else:
self.conv_norm_out = None self.conv_norm_out = None
self.conv_act = None self.conv_act = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment