Unverified commit 41ae6708, authored by Will Berman and committed by GitHub

move activation dispatches into helper function (#3656)

* move activation dispatches into helper function

* tests
parent 462956be
from torch import nn


def get_activation(act_fn):
    if act_fn in ["swish", "silu"]:
        return nn.SiLU()
    elif act_fn == "mish":
        return nn.Mish()
    elif act_fn == "gelu":
        return nn.GELU()
    else:
        raise ValueError(f"Unsupported activation function: {act_fn}")
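For context, a minimal usage sketch of the new helper (not part of the diff; ToyBlock and its sizes are made up for illustration): a layer resolves its activation by name with one call instead of an inline if/elif chain, which is exactly the pattern the hunks below switch to.

# Hypothetical example module, not from the PR.
import torch
from torch import nn

from diffusers.models.activations import get_activation


class ToyBlock(nn.Module):
    def __init__(self, dim, act_fn="silu"):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        # One call replaces the per-module dispatch; unsupported names raise ValueError.
        self.act = get_activation(act_fn)

    def forward(self, x):
        return self.act(self.proj(x))


print(ToyBlock(8, "mish")(torch.randn(2, 8)).shape)  # torch.Size([2, 8])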
@@ -18,6 +18,7 @@ import torch.nn.functional as F
 from torch import nn

 from ..utils import maybe_allow_in_graph
+from .activations import get_activation
 from .attention_processor import Attention
 from .embeddings import CombinedTimestepLabelEmbeddings
@@ -345,15 +346,11 @@ class AdaGroupNorm(nn.Module):
         super().__init__()
         self.num_groups = num_groups
         self.eps = eps
-        self.act = None
-        if act_fn == "swish":
-            self.act = lambda x: F.silu(x)
-        elif act_fn == "mish":
-            self.act = nn.Mish()
-        elif act_fn == "silu":
-            self.act = nn.SiLU()
-        elif act_fn == "gelu":
-            self.act = nn.GELU()
+        if act_fn is None:
+            self.act = None
+        else:
+            self.act = get_activation(act_fn)

         self.linear = nn.Linear(embedding_dim, out_dim * 2)
...
@@ -18,6 +18,8 @@ import numpy as np
 import torch
 from torch import nn

+from .activations import get_activation

 def get_timestep_embedding(
     timesteps: torch.Tensor,
@@ -171,14 +173,7 @@ class TimestepEmbedding(nn.Module):
         else:
             self.cond_proj = None

-        if act_fn == "silu":
-            self.act = nn.SiLU()
-        elif act_fn == "mish":
-            self.act = nn.Mish()
-        elif act_fn == "gelu":
-            self.act = nn.GELU()
-        else:
-            raise ValueError(f"{act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'")
+        self.act = get_activation(act_fn)

         if out_dim is not None:
             time_embed_dim_out = out_dim
@@ -188,14 +183,8 @@
         if post_act_fn is None:
             self.post_act = None
-        elif post_act_fn == "silu":
-            self.post_act = nn.SiLU()
-        elif post_act_fn == "mish":
-            self.post_act = nn.Mish()
-        elif post_act_fn == "gelu":
-            self.post_act = nn.GELU()
         else:
-            raise ValueError(f"{post_act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'")
+            self.post_act = get_activation(post_act_fn)

     def forward(self, sample, condition=None):
         if condition is not None:
...
@@ -20,6 +20,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

+from .activations import get_activation
 from .attention import AdaGroupNorm
 from .attention_processor import SpatialNorm
@@ -558,14 +559,7 @@ class ResnetBlock2D(nn.Module):
         conv_2d_out_channels = conv_2d_out_channels or out_channels
         self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)

-        if non_linearity == "swish":
-            self.nonlinearity = lambda x: F.silu(x)
-        elif non_linearity == "mish":
-            self.nonlinearity = nn.Mish()
-        elif non_linearity == "silu":
-            self.nonlinearity = nn.SiLU()
-        elif non_linearity == "gelu":
-            self.nonlinearity = nn.GELU()
+        self.nonlinearity = get_activation(non_linearity)

         self.upsample = self.downsample = None
         if self.up:
@@ -646,11 +640,6 @@ class ResnetBlock2D(nn.Module):
         return output_tensor


-class Mish(torch.nn.Module):
-    def forward(self, hidden_states):
-        return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
-
-
 # unet_rl.py
 def rearrange_dims(tensor):
     if len(tensor.shape) == 2:
...
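A quick sanity check (illustration only, not part of the diff) that dropping the hand-rolled Mish module above changes nothing numerically: the formula it used matches PyTorch's built-in nn.Mish, which is what get_activation("mish") now returns.

import torch
import torch.nn.functional as F
from torch import nn

x = torch.randn(4, 3)
manual = x * torch.tanh(F.softplus(x))  # formula from the deleted Mish class
builtin = nn.Mish()(x)                  # what get_activation("mish") returns
print(torch.allclose(manual, builtin, atol=1e-6))  # expected: True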
@@ -17,6 +17,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn

+from .activations import get_activation
 from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims
@@ -55,14 +56,10 @@ class DownResnetBlock1D(nn.Module):
         self.resnets = nn.ModuleList(resnets)

-        if non_linearity == "swish":
-            self.nonlinearity = lambda x: F.silu(x)
-        elif non_linearity == "mish":
-            self.nonlinearity = nn.Mish()
-        elif non_linearity == "silu":
-            self.nonlinearity = nn.SiLU()
-        else:
+        if non_linearity is None:
             self.nonlinearity = None
+        else:
+            self.nonlinearity = get_activation(non_linearity)

         self.downsample = None
         if add_downsample:
@@ -119,14 +116,10 @@ class UpResnetBlock1D(nn.Module):
         self.resnets = nn.ModuleList(resnets)

-        if non_linearity == "swish":
-            self.nonlinearity = lambda x: F.silu(x)
-        elif non_linearity == "mish":
-            self.nonlinearity = nn.Mish()
-        elif non_linearity == "silu":
-            self.nonlinearity = nn.SiLU()
-        else:
+        if non_linearity is None:
             self.nonlinearity = None
+        else:
+            self.nonlinearity = get_activation(non_linearity)

         self.upsample = None
         if add_upsample:
@@ -194,14 +187,10 @@ class MidResTemporalBlock1D(nn.Module):
         self.resnets = nn.ModuleList(resnets)

-        if non_linearity == "swish":
-            self.nonlinearity = lambda x: F.silu(x)
-        elif non_linearity == "mish":
-            self.nonlinearity = nn.Mish()
-        elif non_linearity == "silu":
-            self.nonlinearity = nn.SiLU()
-        else:
+        if non_linearity is None:
             self.nonlinearity = None
+        else:
+            self.nonlinearity = get_activation(non_linearity)

         self.upsample = None
         if add_upsample:
@@ -232,10 +221,7 @@ class OutConv1DBlock(nn.Module):
         super().__init__()
         self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
         self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
-        if act_fn == "silu":
-            self.final_conv1d_act = nn.SiLU()
-        if act_fn == "mish":
-            self.final_conv1d_act = nn.Mish()
+        self.final_conv1d_act = get_activation(act_fn)
         self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)

     def forward(self, hidden_states, temb=None):
...
@@ -16,12 +16,12 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 import torch.utils.checkpoint

 from ..configuration_utils import ConfigMixin, register_to_config
 from ..loaders import UNet2DConditionLoadersMixin
 from ..utils import BaseOutput, logging
+from .activations import get_activation
 from .attention_processor import AttentionProcessor, AttnProcessor
 from .embeddings import (
     GaussianFourierProjection,
@@ -338,16 +338,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         if time_embedding_act_fn is None:
             self.time_embed_act = None
-        elif time_embedding_act_fn == "swish":
-            self.time_embed_act = lambda x: F.silu(x)
-        elif time_embedding_act_fn == "mish":
-            self.time_embed_act = nn.Mish()
-        elif time_embedding_act_fn == "silu":
-            self.time_embed_act = nn.SiLU()
-        elif time_embedding_act_fn == "gelu":
-            self.time_embed_act = nn.GELU()
         else:
-            raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}")
+            self.time_embed_act = get_activation(time_embedding_act_fn)

         self.down_blocks = nn.ModuleList([])
         self.up_blocks = nn.ModuleList([])
@@ -501,16 +493,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
             )

-            if act_fn == "swish":
-                self.conv_act = lambda x: F.silu(x)
-            elif act_fn == "mish":
-                self.conv_act = nn.Mish()
-            elif act_fn == "silu":
-                self.conv_act = nn.SiLU()
-            elif act_fn == "gelu":
-                self.conv_act = nn.GELU()
-            else:
-                raise ValueError(f"Unsupported activation function: {act_fn}")
+            self.conv_act = get_activation(act_fn)
         else:
             self.conv_norm_out = None
...
@@ -7,6 +7,7 @@ import torch.nn.functional as F
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...models import ModelMixin
+from ...models.activations import get_activation
 from ...models.attention import Attention
 from ...models.attention_processor import (
     AttentionProcessor,
@@ -441,16 +442,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         if time_embedding_act_fn is None:
             self.time_embed_act = None
-        elif time_embedding_act_fn == "swish":
-            self.time_embed_act = lambda x: F.silu(x)
-        elif time_embedding_act_fn == "mish":
-            self.time_embed_act = nn.Mish()
-        elif time_embedding_act_fn == "silu":
-            self.time_embed_act = nn.SiLU()
-        elif time_embedding_act_fn == "gelu":
-            self.time_embed_act = nn.GELU()
         else:
-            raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}")
+            self.time_embed_act = get_activation(time_embedding_act_fn)

         self.down_blocks = nn.ModuleList([])
         self.up_blocks = nn.ModuleList([])
@@ -604,16 +597,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
             )

-            if act_fn == "swish":
-                self.conv_act = lambda x: F.silu(x)
-            elif act_fn == "mish":
-                self.conv_act = nn.Mish()
-            elif act_fn == "silu":
-                self.conv_act = nn.SiLU()
-            elif act_fn == "gelu":
-                self.conv_act = nn.GELU()
-            else:
-                raise ValueError(f"Unsupported activation function: {act_fn}")
+            self.conv_act = get_activation(act_fn)
         else:
             self.conv_norm_out = None
...
import unittest

import torch
from torch import nn

from diffusers.models.activations import get_activation


class ActivationsTests(unittest.TestCase):
    def test_swish(self):
        act = get_activation("swish")

        self.assertIsInstance(act, nn.SiLU)

        self.assertEqual(act(torch.tensor(-100, dtype=torch.float32)).item(), 0)
        self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0)
        self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0)
        self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20)

    def test_silu(self):
        act = get_activation("silu")

        self.assertIsInstance(act, nn.SiLU)

        self.assertEqual(act(torch.tensor(-100, dtype=torch.float32)).item(), 0)
        self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0)
        self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0)
        self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20)

    def test_mish(self):
        act = get_activation("mish")

        self.assertIsInstance(act, nn.Mish)

        self.assertEqual(act(torch.tensor(-200, dtype=torch.float32)).item(), 0)
        self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0)
        self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0)
        self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20)

    def test_gelu(self):
        act = get_activation("gelu")

        self.assertIsInstance(act, nn.GELU)

        self.assertEqual(act(torch.tensor(-100, dtype=torch.float32)).item(), 0)
        self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0)
        self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0)
        self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20)
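The tests above cover only the supported names; as a hedged illustration (not part of the PR), the helper's error branch could be exercised the same way, since get_activation raises ValueError for anything outside swish/silu/mish/gelu.

import unittest

from diffusers.models.activations import get_activation


class ActivationErrorTest(unittest.TestCase):
    # Hypothetical extra test for illustration only; not in the commit.
    def test_unsupported_name_raises(self):
        with self.assertRaises(ValueError):
            get_activation("relu")  # "relu" is not handled by the helper shown above


if __name__ == "__main__":
    unittest.main()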