Unverified commit 2d52e81c, authored by Will Berman and committed by GitHub

unet time embedding activation function (#3048)

* unet time embedding activation function

* typo act_fn -> time_embedding_act_fn

* flatten conditional
parent 52c4d32d
@@ -16,6 +16,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from ..configuration_utils import ConfigMixin, register_to_config
@@ -101,6 +102,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
class conditioning with `class_embed_type` equal to `None`.
time_embedding_type (`str`, *optional*, default to `positional`):
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
time_embedding_act_fn (`str`, *optional*, default to `None`):
Optional activation function to use only once on the time embeddings before they are passed to the rest
of the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
timestep_post_act (`str`, *optional*, default to `None`):
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
time_cond_proj_dim (`int`, *optional*, default to `None`):
@@ -157,6 +161,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
resnet_skip_time_act: bool = False,
resnet_out_scale_factor: int = 1.0,
time_embedding_type: str = "positional",
time_embedding_act_fn: Optional[str] = None,
timestep_post_act: Optional[str] = None,
time_cond_proj_dim: Optional[int] = None,
conv_in_kernel: int = 3,
@@ -267,6 +272,19 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
else:
self.class_embedding = None
if time_embedding_act_fn is None:
self.time_embed_act = None
elif time_embedding_act_fn == "swish":
self.time_embed_act = lambda x: F.silu(x)
elif time_embedding_act_fn == "mish":
self.time_embed_act = nn.Mish()
elif time_embedding_act_fn == "silu":
self.time_embed_act = nn.SiLU()
elif time_embedding_act_fn == "gelu":
self.time_embed_act = nn.GELU()
else:
raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}")
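For reference, a minimal usage sketch of the new argument (assuming a diffusers build that includes this commit; all other constructor arguments are left at their defaults):

from diffusers import UNet2DConditionModel

# Pick the one-shot activation applied to the combined time embedding.
# Accepted values are "silu", "mish", "gelu", and "swish"; anything else
# hits the ValueError branch above.
unet = UNet2DConditionModel(time_embedding_act_fn="silu")
assert unet.time_embed_act is not None  # an nn.SiLU() module was registered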
self.down_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])
@@ -657,6 +675,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
else:
emb = emb + class_emb
if self.time_embed_act is not None:
emb = self.time_embed_act(emb)
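The hook fires exactly once per forward pass, after class conditioning has been added to `emb` and before any down/mid/up block consumes it. A standalone toy illustration of that placement (shapes here are hypothetical):

import torch
import torch.nn as nn

time_embed_act = nn.Mish()  # what time_embedding_act_fn="mish" selects
emb = torch.randn(2, 1280)  # combined timestep (+ class) embedding: (batch, time_embed_dim)
emb = time_embed_act(emb)   # applied once; downstream blocks see the activated embedding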
if self.encoder_hid_proj is not None:
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...models import ModelMixin
@@ -182,6 +183,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
class conditioning with `class_embed_type` equal to `None`.
time_embedding_type (`str`, *optional*, default to `positional`):
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
time_embedding_act_fn (`str`, *optional*, default to `None`):
Optional activation function to use only once on the time embeddings before they are passed to the rest
of the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
timestep_post_act (`str`, *optional*, default to `None`):
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
time_cond_proj_dim (`int`, *optional*, default to `None`):
@@ -243,6 +247,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
resnet_skip_time_act: bool = False,
resnet_out_scale_factor: int = 1.0,
time_embedding_type: str = "positional",
time_embedding_act_fn: Optional[str] = None,
timestep_post_act: Optional[str] = None,
time_cond_proj_dim: Optional[int] = None,
conv_in_kernel: int = 3,
@@ -359,6 +364,19 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
else:
self.class_embedding = None
if time_embedding_act_fn is None:
self.time_embed_act = None
elif time_embedding_act_fn == "swish":
self.time_embed_act = lambda x: F.silu(x)
elif time_embedding_act_fn == "mish":
self.time_embed_act = nn.Mish()
elif time_embedding_act_fn == "silu":
self.time_embed_act = nn.SiLU()
elif time_embedding_act_fn == "gelu":
self.time_embed_act = nn.GELU()
else:
raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}")
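The selection logic above is duplicated verbatim from `UNet2DConditionModel`. Purely as a sketch, the branch could be factored into a lookup (hypothetical helper, not part of this commit; note that `swish` and `silu` name the same function, so both can map to `nn.SiLU()`):

import torch.nn as nn

def get_time_embed_act(name):
    # Hypothetical refactor of the duplicated elif chain above.
    if name is None:
        return None
    table = {"swish": nn.SiLU(), "silu": nn.SiLU(), "mish": nn.Mish(), "gelu": nn.GELU()}
    if name not in table:
        raise ValueError(f"Unsupported activation function: {name}")
    return table[name]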
self.down_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])
@@ -752,6 +770,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
else:
emb = emb + class_emb
if self.time_embed_act is not None:
emb = self.time_embed_act(emb)
if self.encoder_hid_proj is not None:
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)