OpenDAS / diffusers · Commits

Commit 41ae6708 (unverified), authored Jun 05, 2023 by Will Berman, committed via GitHub on Jun 05, 2023
Parent: 462956be

move activation dispatches into helper function (#3656)

* move activation dispatches into helper function
* tests
Showing 8 changed files with 89 additions and 101 deletions
src/diffusers/models/activations.py                                +12   -0
src/diffusers/models/attention.py                                   +6   -9
src/diffusers/models/embeddings.py                                   +4  -15
src/diffusers/models/resnet.py                                       +2  -13
src/diffusers/models/unet_1d_blocks.py                              +11  -25
src/diffusers/models/unet_2d_condition.py                            +3  -20
src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py    +3  -19
tests/models/test_activations.py                                    +48   -0
src/diffusers/models/activations.py  (new file, 0 → 100644)

+from torch import nn
+
+
+def get_activation(act_fn):
+    if act_fn in ["swish", "silu"]:
+        return nn.SiLU()
+    elif act_fn == "mish":
+        return nn.Mish()
+    elif act_fn == "gelu":
+        return nn.GELU()
+    else:
+        raise ValueError(f"Unsupported activation function: {act_fn}")

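For context, a minimal usage sketch of the new helper (not part of the commit): the DummyBlock module below is hypothetical and only illustrates the pattern the remaining files switch to, replacing their inline if/elif activation dispatches with a single get_activation call.

import torch
from torch import nn

from diffusers.models.activations import get_activation  # helper added by this commit


class DummyBlock(nn.Module):
    # Hypothetical module, not part of diffusers; shown only to illustrate the pattern.
    def __init__(self, act_fn: str = "silu"):
        super().__init__()
        # One call replaces the per-module if/elif chains removed in the files below.
        self.act = get_activation(act_fn)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(x)


block = DummyBlock("mish")
print(block.act)                       # Mish()
print(block(torch.zeros(2, 3)).shape)  # torch.Size([2, 3])
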
src/diffusers/models/attention.py

@@ -18,6 +18,7 @@ import torch.nn.functional as F
 from torch import nn
 
 from ..utils import maybe_allow_in_graph
+from .activations import get_activation
 from .attention_processor import Attention
 from .embeddings import CombinedTimestepLabelEmbeddings

@@ -345,15 +346,11 @@ class AdaGroupNorm(nn.Module):
         super().__init__()
         self.num_groups = num_groups
         self.eps = eps
-        self.act = None
-        if act_fn == "swish":
-            self.act = lambda x: F.silu(x)
-        elif act_fn == "mish":
-            self.act = nn.Mish()
-        elif act_fn == "silu":
-            self.act = nn.SiLU()
-        elif act_fn == "gelu":
-            self.act = nn.GELU()
+        if act_fn is None:
+            self.act = None
+        else:
+            self.act = get_activation(act_fn)
 
         self.linear = nn.Linear(embedding_dim, out_dim * 2)

src/diffusers/models/embeddings.py

@@ -18,6 +18,8 @@ import numpy as np
 import torch
 from torch import nn
 
+from .activations import get_activation
+
 
 def get_timestep_embedding(
     timesteps: torch.Tensor,

@@ -171,14 +173,7 @@ class TimestepEmbedding(nn.Module):
         else:
             self.cond_proj = None
 
-        if act_fn == "silu":
-            self.act = nn.SiLU()
-        elif act_fn == "mish":
-            self.act = nn.Mish()
-        elif act_fn == "gelu":
-            self.act = nn.GELU()
-        else:
-            raise ValueError(f"{act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'")
+        self.act = get_activation(act_fn)
 
         if out_dim is not None:
             time_embed_dim_out = out_dim

@@ -188,14 +183,8 @@ class TimestepEmbedding(nn.Module):
         if post_act_fn is None:
             self.post_act = None
-        elif post_act_fn == "silu":
-            self.post_act = nn.SiLU()
-        elif post_act_fn == "mish":
-            self.post_act = nn.Mish()
-        elif post_act_fn == "gelu":
-            self.post_act = nn.GELU()
         else:
-            raise ValueError(f"{post_act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'")
+            self.post_act = get_activation(post_act_fn)
 
     def forward(self, sample, condition=None):
         if condition is not None:

src/diffusers/models/resnet.py

@@ -20,6 +20,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+from .activations import get_activation
 from .attention import AdaGroupNorm
 from .attention_processor import SpatialNorm

@@ -558,14 +559,7 @@ class ResnetBlock2D(nn.Module):
         conv_2d_out_channels = conv_2d_out_channels or out_channels
         self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
 
-        if non_linearity == "swish":
-            self.nonlinearity = lambda x: F.silu(x)
-        elif non_linearity == "mish":
-            self.nonlinearity = nn.Mish()
-        elif non_linearity == "silu":
-            self.nonlinearity = nn.SiLU()
-        elif non_linearity == "gelu":
-            self.nonlinearity = nn.GELU()
+        self.nonlinearity = get_activation(non_linearity)
 
         self.upsample = self.downsample = None
         if self.up:

@@ -646,11 +640,6 @@ class ResnetBlock2D(nn.Module):
         return output_tensor
 
 
-class Mish(torch.nn.Module):
-    def forward(self, hidden_states):
-        return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
-
-
 # unet_rl.py
 def rearrange_dims(tensor):
     if len(tensor.shape) == 2:

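The resnet.py hunk above also drops the file-local Mish module in favor of PyTorch's built-in nn.Mish, which computes the same function, x * tanh(softplus(x)). A quick equivalence check, as a sketch that is not part of the commit (it assumes torch >= 1.9, where nn.Mish was introduced):

import torch
import torch.nn.functional as F
from torch import nn

x = torch.randn(1000)

# The deleted local class computed: hidden_states * tanh(softplus(hidden_states)).
manual = x * torch.tanh(F.softplus(x))

# PyTorch's built-in Mish computes the same thing.
builtin = nn.Mish()(x)

print(torch.allclose(manual, builtin, atol=1e-6))  # True
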
src/diffusers/models/unet_1d_blocks.py

@@ -17,6 +17,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 
+from .activations import get_activation
 from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims

@@ -55,14 +56,10 @@ class DownResnetBlock1D(nn.Module):
         self.resnets = nn.ModuleList(resnets)
 
-        if non_linearity == "swish":
-            self.nonlinearity = lambda x: F.silu(x)
-        elif non_linearity == "mish":
-            self.nonlinearity = nn.Mish()
-        elif non_linearity == "silu":
-            self.nonlinearity = nn.SiLU()
-        else:
+        if non_linearity is None:
             self.nonlinearity = None
+        else:
+            self.nonlinearity = get_activation(non_linearity)
 
         self.downsample = None
         if add_downsample:

@@ -119,14 +116,10 @@ class UpResnetBlock1D(nn.Module):
         self.resnets = nn.ModuleList(resnets)
 
-        if non_linearity == "swish":
-            self.nonlinearity = lambda x: F.silu(x)
-        elif non_linearity == "mish":
-            self.nonlinearity = nn.Mish()
-        elif non_linearity == "silu":
-            self.nonlinearity = nn.SiLU()
-        else:
+        if non_linearity is None:
             self.nonlinearity = None
+        else:
+            self.nonlinearity = get_activation(non_linearity)
 
         self.upsample = None
         if add_upsample:

@@ -194,14 +187,10 @@ class MidResTemporalBlock1D(nn.Module):
         self.resnets = nn.ModuleList(resnets)
 
-        if non_linearity == "swish":
-            self.nonlinearity = lambda x: F.silu(x)
-        elif non_linearity == "mish":
-            self.nonlinearity = nn.Mish()
-        elif non_linearity == "silu":
-            self.nonlinearity = nn.SiLU()
-        else:
+        if non_linearity is None:
             self.nonlinearity = None
+        else:
+            self.nonlinearity = get_activation(non_linearity)
 
         self.upsample = None
         if add_upsample:

@@ -232,10 +221,7 @@ class OutConv1DBlock(nn.Module):
         super().__init__()
         self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
         self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
-        if act_fn == "silu":
-            self.final_conv1d_act = nn.SiLU()
-        if act_fn == "mish":
-            self.final_conv1d_act = nn.Mish()
+        self.final_conv1d_act = get_activation(act_fn)
         self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)
 
     def forward(self, hidden_states, temb=None):

src/diffusers/models/unet_2d_condition.py

@@ -16,12 +16,12 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 import torch.utils.checkpoint
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..loaders import UNet2DConditionLoadersMixin
 from ..utils import BaseOutput, logging
+from .activations import get_activation
 from .attention_processor import AttentionProcessor, AttnProcessor
 from .embeddings import (
     GaussianFourierProjection,

@@ -338,16 +338,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         if time_embedding_act_fn is None:
             self.time_embed_act = None
-        elif time_embedding_act_fn == "swish":
-            self.time_embed_act = lambda x: F.silu(x)
-        elif time_embedding_act_fn == "mish":
-            self.time_embed_act = nn.Mish()
-        elif time_embedding_act_fn == "silu":
-            self.time_embed_act = nn.SiLU()
-        elif time_embedding_act_fn == "gelu":
-            self.time_embed_act = nn.GELU()
         else:
-            raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}")
+            self.time_embed_act = get_activation(time_embedding_act_fn)
 
         self.down_blocks = nn.ModuleList([])
         self.up_blocks = nn.ModuleList([])

@@ -501,16 +493,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                 num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
             )
 
-            if act_fn == "swish":
-                self.conv_act = lambda x: F.silu(x)
-            elif act_fn == "mish":
-                self.conv_act = nn.Mish()
-            elif act_fn == "silu":
-                self.conv_act = nn.SiLU()
-            elif act_fn == "gelu":
-                self.conv_act = nn.GELU()
-            else:
-                raise ValueError(f"Unsupported activation function: {act_fn}")
+            self.conv_act = get_activation(act_fn)
 
         else:
             self.conv_norm_out = None

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

@@ -7,6 +7,7 @@ import torch.nn.functional as F
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...models import ModelMixin
+from ...models.activations import get_activation
 from ...models.attention import Attention
 from ...models.attention_processor import (
     AttentionProcessor,

@@ -441,16 +442,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         if time_embedding_act_fn is None:
             self.time_embed_act = None
-        elif time_embedding_act_fn == "swish":
-            self.time_embed_act = lambda x: F.silu(x)
-        elif time_embedding_act_fn == "mish":
-            self.time_embed_act = nn.Mish()
-        elif time_embedding_act_fn == "silu":
-            self.time_embed_act = nn.SiLU()
-        elif time_embedding_act_fn == "gelu":
-            self.time_embed_act = nn.GELU()
         else:
-            raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}")
+            self.time_embed_act = get_activation(time_embedding_act_fn)
 
         self.down_blocks = nn.ModuleList([])
         self.up_blocks = nn.ModuleList([])

@@ -604,16 +597,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
             )
 
-            if act_fn == "swish":
-                self.conv_act = lambda x: F.silu(x)
-            elif act_fn == "mish":
-                self.conv_act = nn.Mish()
-            elif act_fn == "silu":
-                self.conv_act = nn.SiLU()
-            elif act_fn == "gelu":
-                self.conv_act = nn.GELU()
-            else:
-                raise ValueError(f"Unsupported activation function: {act_fn}")
+            self.conv_act = get_activation(act_fn)
 
         else:
             self.conv_norm_out = None

tests/models/test_activations.py  (new file, 0 → 100644)

+import unittest
+
+import torch
+from torch import nn
+
+from diffusers.models.activations import get_activation
+
+
+class ActivationsTests(unittest.TestCase):
+    def test_swish(self):
+        act = get_activation("swish")
+
+        self.assertIsInstance(act, nn.SiLU)
+
+        self.assertEqual(act(torch.tensor(-100, dtype=torch.float32)).item(), 0)
+        self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20)
+
+    def test_silu(self):
+        act = get_activation("silu")
+
+        self.assertIsInstance(act, nn.SiLU)
+
+        self.assertEqual(act(torch.tensor(-100, dtype=torch.float32)).item(), 0)
+        self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20)
+
+    def test_mish(self):
+        act = get_activation("mish")
+
+        self.assertIsInstance(act, nn.Mish)
+
+        self.assertEqual(act(torch.tensor(-200, dtype=torch.float32)).item(), 0)
+        self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20)
+
+    def test_gelu(self):
+        act = get_activation("gelu")
+
+        self.assertIsInstance(act, nn.GELU)
+
+        self.assertEqual(act(torch.tensor(-100, dtype=torch.float32)).item(), 0)
+        self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20)
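
The new tests cover the four supported names; the else branch of get_activation raises a ValueError for anything else. A small additional sketch, not part of the commit, exercising that error path (the test class name is hypothetical):

import unittest

from diffusers.models.activations import get_activation


class UnsupportedActivationTest(unittest.TestCase):
    # Hypothetical test, not part of the commit; "relu" is not in the supported list above.
    def test_unsupported_name_raises(self):
        with self.assertRaises(ValueError):
            get_activation("relu")


if __name__ == "__main__":
    unittest.main()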