Unverified Commit 2d52e81c authored by Will Berman, committed by GitHub

unet time embedding activation function (#3048)

* unet time embedding activation function

* typo act_fn -> time_embedding_act_fn

* flatten conditional
parent 52c4d32d
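
For context, a minimal usage sketch of the new flag. Apart from `time_embedding_act_fn`, every constructor argument below is an illustrative assumption chosen to keep the model tiny, not something taken from this commit:

```python
import torch
from diffusers import UNet2DConditionModel

# Only time_embedding_act_fn comes from this commit; the remaining
# arguments are assumptions for a small, fast-to-build model.
unet = UNet2DConditionModel(
    sample_size=16,
    in_channels=4,
    out_channels=4,
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    block_out_channels=(32, 64),
    cross_attention_dim=32,
    attention_head_dim=4,
    time_embedding_act_fn="silu",  # new option: silu | mish | gelu | swish
)

sample = torch.randn(1, 4, 16, 16)
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 77, 32)

# The activation fires once on the combined time (+ class) embedding
# before it is handed to the down/mid/up blocks.
out = unet(sample, timestep, encoder_hidden_states).sample
print(out.shape)  # torch.Size([1, 4, 16, 16])
```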
@@ -16,6 +16,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from ..configuration_utils import ConfigMixin, register_to_config
@@ -101,6 +102,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
            class conditioning with `class_embed_type` equal to `None`.
        time_embedding_type (`str`, *optional*, default to `positional`):
            The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
        time_embedding_act_fn (`str`, *optional*, default to `None`):
            Optional activation function to apply only once to the time embeddings before they are passed to the rest
            of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`.
        timestep_post_act (`str`, *optional*, default to `None`):
            The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
        time_cond_proj_dim (`int`, *optional*, default to `None`):
@@ -157,6 +161,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
        resnet_skip_time_act: bool = False,
        resnet_out_scale_factor: int = 1.0,
        time_embedding_type: str = "positional",
        time_embedding_act_fn: Optional[str] = None,
        timestep_post_act: Optional[str] = None,
        time_cond_proj_dim: Optional[int] = None,
        conv_in_kernel: int = 3,
@@ -267,6 +272,19 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
        else:
            self.class_embedding = None
        if time_embedding_act_fn is None:
            self.time_embed_act = None
        elif time_embedding_act_fn == "swish":
            self.time_embed_act = lambda x: F.silu(x)
        elif time_embedding_act_fn == "mish":
            self.time_embed_act = nn.Mish()
        elif time_embedding_act_fn == "silu":
            self.time_embed_act = nn.SiLU()
        elif time_embedding_act_fn == "gelu":
            self.time_embed_act = nn.GELU()
        else:
            raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}")
        self.down_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])
@@ -657,6 +675,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
            else:
                emb = emb + class_emb
        if self.time_embed_act is not None:
            emb = self.time_embed_act(emb)
        if self.encoder_hid_proj is not None:
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
...
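
The flattened conditional above maps each config string to a module (or, for `"swish"`, a plain callable). A standalone sketch of the same selection logic, using `resolve_time_embedding_act` as a hypothetical helper name for illustration only; note that `"swish"` and `"silu"` compute the same `x * sigmoid(x)` nonlinearity:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def resolve_time_embedding_act(time_embedding_act_fn):
    # Mirrors the flattened conditional in the diff above:
    # "swish" is a functional alias for SiLU, the others are modules.
    if time_embedding_act_fn is None:
        return None
    elif time_embedding_act_fn == "swish":
        return lambda x: F.silu(x)
    elif time_embedding_act_fn == "mish":
        return nn.Mish()
    elif time_embedding_act_fn == "silu":
        return nn.SiLU()
    elif time_embedding_act_fn == "gelu":
        return nn.GELU()
    else:
        raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}")

x = torch.randn(2, 8)
swish = resolve_time_embedding_act("swish")
silu = resolve_time_embedding_act("silu")
# Both compute x * sigmoid(x), so the outputs match exactly.
assert torch.equal(swish(x), silu(x))
```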
@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...models import ModelMixin
@@ -182,6 +183,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
            class conditioning with `class_embed_type` equal to `None`.
        time_embedding_type (`str`, *optional*, default to `positional`):
            The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
        time_embedding_act_fn (`str`, *optional*, default to `None`):
            Optional activation function to apply only once to the time embeddings before they are passed to the rest
            of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`.
        timestep_post_act (`str`, *optional*, default to `None`):
            The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
        time_cond_proj_dim (`int`, *optional*, default to `None`):
@@ -243,6 +247,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
        resnet_skip_time_act: bool = False,
        resnet_out_scale_factor: int = 1.0,
        time_embedding_type: str = "positional",
        time_embedding_act_fn: Optional[str] = None,
        timestep_post_act: Optional[str] = None,
        time_cond_proj_dim: Optional[int] = None,
        conv_in_kernel: int = 3,
@@ -359,6 +364,19 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
        else:
            self.class_embedding = None
        if time_embedding_act_fn is None:
            self.time_embed_act = None
        elif time_embedding_act_fn == "swish":
            self.time_embed_act = lambda x: F.silu(x)
        elif time_embedding_act_fn == "mish":
            self.time_embed_act = nn.Mish()
        elif time_embedding_act_fn == "silu":
            self.time_embed_act = nn.SiLU()
        elif time_embedding_act_fn == "gelu":
            self.time_embed_act = nn.GELU()
        else:
            raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}")
        self.down_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])
@@ -752,6 +770,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
            else:
                emb = emb + class_emb
        if self.time_embed_act is not None:
            emb = self.time_embed_act(emb)
        if self.encoder_hid_proj is not None:
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
...
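
In both files the hook fires in the forward pass right after the optional class embedding is added and before the embedding reaches any blocks. A minimal sketch of that ordering on a bare tensor, assuming `time_embed_act` was resolved to `nn.SiLU()` (shapes are illustrative):

```python
import torch
import torch.nn as nn

time_embed_act = nn.SiLU()  # resolved from time_embedding_act_fn="silu"

emb = torch.randn(2, 1280)        # time embedding, illustrative shape
class_emb = torch.randn(2, 1280)  # optional class embedding

# Order of operations from the forward() hunks above:
emb = emb + class_emb
if time_embed_act is not None:
    emb = time_embed_act(emb)  # applied once, before the unet blocks

print(emb.shape)  # torch.Size([2, 1280])
```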