Commit 41ae6708 (renzhc/diffusers_dcu)

move activation dispatches into helper function (#3656)

* move activation dispatches into helper function
* tests

Unverified commit, authored Jun 05, 2023 by Will Berman and committed by GitHub on Jun 05, 2023. Parent: 462956be.
Showing 8 changed files with 89 additions and 101 deletions (+89, -101).
src/diffusers/models/activations.py  (+12, -0)
src/diffusers/models/attention.py  (+6, -9)
src/diffusers/models/embeddings.py  (+4, -15)
src/diffusers/models/resnet.py  (+2, -13)
src/diffusers/models/unet_1d_blocks.py  (+11, -25)
src/diffusers/models/unet_2d_condition.py  (+3, -20)
src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py  (+3, -19)
tests/models/test_activations.py  (+48, -0)
src/diffusers/models/activations.py (new file, mode 100644)

from torch import nn


def get_activation(act_fn):
    if act_fn in ["swish", "silu"]:
        return nn.SiLU()
    elif act_fn == "mish":
        return nn.Mish()
    elif act_fn == "gelu":
        return nn.GELU()
    else:
        raise ValueError(f"Unsupported activation function: {act_fn}")
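A minimal usage sketch of the new helper (illustrative only; the tensor and the unsupported name below are not part of the commit):

import torch

from diffusers.models.activations import get_activation

act = get_activation("swish")  # "swish" and "silu" both resolve to nn.SiLU()
out = act(torch.randn(2, 4))   # same result as torch.nn.functional.silu

try:
    get_activation("prelu")    # any name outside swish/silu/mish/gelu
except ValueError as err:
    print(err)                 # Unsupported activation function: prelu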
src/diffusers/models/attention.py

@@ -18,6 +18,7 @@ import torch.nn.functional as F
 from torch import nn

 from ..utils import maybe_allow_in_graph
+from .activations import get_activation
 from .attention_processor import Attention
 from .embeddings import CombinedTimestepLabelEmbeddings

@@ -345,15 +346,11 @@ class AdaGroupNorm(nn.Module):
         super().__init__()
         self.num_groups = num_groups
         self.eps = eps
-        self.act = None
-        if act_fn == "swish":
-            self.act = lambda x: F.silu(x)
-        elif act_fn == "mish":
-            self.act = nn.Mish()
-        elif act_fn == "silu":
-            self.act = nn.SiLU()
-        elif act_fn == "gelu":
-            self.act = nn.GELU()
+
+        if act_fn is None:
+            self.act = None
+        else:
+            self.act = get_activation(act_fn)

         self.linear = nn.Linear(embedding_dim, out_dim * 2)
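One detail visible in this hunk (an observation from the diff, not stated in the commit message): the old AdaGroupNorm chain silently left self.act as None for any unrecognized string, whereas get_activation raises ValueError, so after the change only an explicit act_fn=None disables the activation. The same None-guarded dispatch reappears in the unet_1d_blocks.py hunks below. A short sketch of the pattern, using a hypothetical helper name:

from diffusers.models.activations import get_activation

def build_optional_activation(act_fn):
    # Hypothetical free function mirroring the refactored constructors:
    # None keeps the activation disabled, anything else must be a supported name.
    if act_fn is None:
        return None
    return get_activation(act_fn)  # raises ValueError for unsupported names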
src/diffusers/models/embeddings.py

@@ -18,6 +18,8 @@ import numpy as np
 import torch
 from torch import nn

+from .activations import get_activation
+

 def get_timestep_embedding(
     timesteps: torch.Tensor,

@@ -171,14 +173,7 @@ class TimestepEmbedding(nn.Module):
         else:
             self.cond_proj = None

-        if act_fn == "silu":
-            self.act = nn.SiLU()
-        elif act_fn == "mish":
-            self.act = nn.Mish()
-        elif act_fn == "gelu":
-            self.act = nn.GELU()
-        else:
-            raise ValueError(f"{act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'")
+        self.act = get_activation(act_fn)

         if out_dim is not None:
             time_embed_dim_out = out_dim

@@ -188,14 +183,8 @@ class TimestepEmbedding(nn.Module):
         if post_act_fn is None:
             self.post_act = None
-        elif post_act_fn == "silu":
-            self.post_act = nn.SiLU()
-        elif post_act_fn == "mish":
-            self.post_act = nn.Mish()
-        elif post_act_fn == "gelu":
-            self.post_act = nn.GELU()
         else:
-            raise ValueError(f"{post_act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'")
+            self.post_act = get_activation(post_act_fn)

     def forward(self, sample, condition=None):
         if condition is not None:
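Two consequences of this hunk are worth spelling out: unsupported names now raise the helper's "Unsupported activation function: ..." message instead of the old per-class wording, and "swish" becomes accepted for TimestepEmbedding since the helper maps it to nn.SiLU(). A construction sketch (the in_channels and time_embed_dim argument names are assumed from the class's usual signature, not shown in this hunk):

from diffusers.models.embeddings import TimestepEmbedding

# Both self.act and the optional self.post_act are now built by get_activation.
emb = TimestepEmbedding(in_channels=320, time_embed_dim=1280, act_fn="silu", post_act_fn="gelu")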
src/diffusers/models/resnet.py

@@ -20,6 +20,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

+from .activations import get_activation
 from .attention import AdaGroupNorm
 from .attention_processor import SpatialNorm

@@ -558,14 +559,7 @@ class ResnetBlock2D(nn.Module):
         conv_2d_out_channels = conv_2d_out_channels or out_channels
         self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)

-        if non_linearity == "swish":
-            self.nonlinearity = lambda x: F.silu(x)
-        elif non_linearity == "mish":
-            self.nonlinearity = nn.Mish()
-        elif non_linearity == "silu":
-            self.nonlinearity = nn.SiLU()
-        elif non_linearity == "gelu":
-            self.nonlinearity = nn.GELU()
+        self.nonlinearity = get_activation(non_linearity)

         self.upsample = self.downsample = None
         if self.up:

@@ -646,11 +640,6 @@ class ResnetBlock2D(nn.Module):
         return output_tensor


-class Mish(torch.nn.Module):
-    def forward(self, hidden_states):
-        return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
-
-
 # unet_rl.py
 def rearrange_dims(tensor):
     if len(tensor.shape) == 2:
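The last hunk drops resnet.py's hand-rolled Mish module; its formula, x * tanh(softplus(x)), is exactly what torch.nn.Mish computes, so get_activation("mish") is a drop-in replacement. A quick equivalence check (illustrative, not part of the commit):

import torch
import torch.nn.functional as F
from torch import nn

x = torch.randn(8)
# nn.Mish (available since PyTorch 1.9) matches the removed module's closed form.
assert torch.allclose(nn.Mish()(x), x * torch.tanh(F.softplus(x)))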
src/diffusers/models/unet_1d_blocks.py

@@ -17,6 +17,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn

+from .activations import get_activation
 from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims

@@ -55,14 +56,10 @@ class DownResnetBlock1D(nn.Module):
         self.resnets = nn.ModuleList(resnets)

-        if non_linearity == "swish":
-            self.nonlinearity = lambda x: F.silu(x)
-        elif non_linearity == "mish":
-            self.nonlinearity = nn.Mish()
-        elif non_linearity == "silu":
-            self.nonlinearity = nn.SiLU()
-        else:
+        if non_linearity is None:
             self.nonlinearity = None
+        else:
+            self.nonlinearity = get_activation(non_linearity)

         self.downsample = None
         if add_downsample:

@@ -119,14 +116,10 @@ class UpResnetBlock1D(nn.Module):
         self.resnets = nn.ModuleList(resnets)

-        if non_linearity == "swish":
-            self.nonlinearity = lambda x: F.silu(x)
-        elif non_linearity == "mish":
-            self.nonlinearity = nn.Mish()
-        elif non_linearity == "silu":
-            self.nonlinearity = nn.SiLU()
-        else:
+        if non_linearity is None:
             self.nonlinearity = None
+        else:
+            self.nonlinearity = get_activation(non_linearity)

         self.upsample = None
         if add_upsample:

@@ -194,14 +187,10 @@ class MidResTemporalBlock1D(nn.Module):
         self.resnets = nn.ModuleList(resnets)

-        if non_linearity == "swish":
-            self.nonlinearity = lambda x: F.silu(x)
-        elif non_linearity == "mish":
-            self.nonlinearity = nn.Mish()
-        elif non_linearity == "silu":
-            self.nonlinearity = nn.SiLU()
-        else:
+        if non_linearity is None:
             self.nonlinearity = None
+        else:
+            self.nonlinearity = get_activation(non_linearity)

         self.upsample = None
         if add_upsample:

@@ -232,10 +221,7 @@ class OutConv1DBlock(nn.Module):
         super().__init__()
         self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
         self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
-        if act_fn == "silu":
-            self.final_conv1d_act = nn.SiLU()
-        if act_fn == "mish":
-            self.final_conv1d_act = nn.Mish()
+        self.final_conv1d_act = get_activation(act_fn)
         self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)

     def forward(self, hidden_states, temb=None):
src/diffusers/models/unet_2d_condition.py

@@ -16,12 +16,12 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 import torch.utils.checkpoint

 from ..configuration_utils import ConfigMixin, register_to_config
 from ..loaders import UNet2DConditionLoadersMixin
 from ..utils import BaseOutput, logging
+from .activations import get_activation
 from .attention_processor import AttentionProcessor, AttnProcessor
 from .embeddings import (
     GaussianFourierProjection,

@@ -338,16 +338,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         if time_embedding_act_fn is None:
             self.time_embed_act = None
-        elif time_embedding_act_fn == "swish":
-            self.time_embed_act = lambda x: F.silu(x)
-        elif time_embedding_act_fn == "mish":
-            self.time_embed_act = nn.Mish()
-        elif time_embedding_act_fn == "silu":
-            self.time_embed_act = nn.SiLU()
-        elif time_embedding_act_fn == "gelu":
-            self.time_embed_act = nn.GELU()
         else:
-            raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}")
+            self.time_embed_act = get_activation(time_embedding_act_fn)

         self.down_blocks = nn.ModuleList([])
         self.up_blocks = nn.ModuleList([])

@@ -501,16 +493,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
             )

-            if act_fn == "swish":
-                self.conv_act = lambda x: F.silu(x)
-            elif act_fn == "mish":
-                self.conv_act = nn.Mish()
-            elif act_fn == "silu":
-                self.conv_act = nn.SiLU()
-            elif act_fn == "gelu":
-                self.conv_act = nn.GELU()
-            else:
-                raise ValueError(f"Unsupported activation function: {act_fn}")
+            self.conv_act = get_activation(act_fn)
         else:
             self.conv_norm_out = None
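As in the resnet.py and unet_1d_blocks.py hunks above, "swish" previously dispatched here to a bare lambda x: F.silu(x); the helper returns an nn.SiLU() module instead. The outputs are identical, but the activation now registers as a proper submodule of the UNet rather than an anonymous function. A small numerical check (illustrative only):

import torch
import torch.nn.functional as F
from torch import nn

x = torch.randn(3, 5)
# nn.SiLU is a module wrapper around the same silu computation.
assert torch.allclose(nn.SiLU()(x), F.silu(x))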
src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

@@ -7,6 +7,7 @@ import torch.nn.functional as F
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...models import ModelMixin
+from ...models.activations import get_activation
 from ...models.attention import Attention
 from ...models.attention_processor import (
     AttentionProcessor,

@@ -441,16 +442,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         if time_embedding_act_fn is None:
             self.time_embed_act = None
-        elif time_embedding_act_fn == "swish":
-            self.time_embed_act = lambda x: F.silu(x)
-        elif time_embedding_act_fn == "mish":
-            self.time_embed_act = nn.Mish()
-        elif time_embedding_act_fn == "silu":
-            self.time_embed_act = nn.SiLU()
-        elif time_embedding_act_fn == "gelu":
-            self.time_embed_act = nn.GELU()
         else:
-            raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}")
+            self.time_embed_act = get_activation(time_embedding_act_fn)

         self.down_blocks = nn.ModuleList([])
         self.up_blocks = nn.ModuleList([])

@@ -604,16 +597,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
             )

-            if act_fn == "swish":
-                self.conv_act = lambda x: F.silu(x)
-            elif act_fn == "mish":
-                self.conv_act = nn.Mish()
-            elif act_fn == "silu":
-                self.conv_act = nn.SiLU()
-            elif act_fn == "gelu":
-                self.conv_act = nn.GELU()
-            else:
-                raise ValueError(f"Unsupported activation function: {act_fn}")
+            self.conv_act = get_activation(act_fn)
         else:
             self.conv_norm_out = None
tests/models/test_activations.py (new file, mode 100644)

import unittest

import torch
from torch import nn

from diffusers.models.activations import get_activation


class ActivationsTests(unittest.TestCase):
    def test_swish(self):
        act = get_activation("swish")

        self.assertIsInstance(act, nn.SiLU)

        self.assertEqual(act(torch.tensor(-100, dtype=torch.float32)).item(), 0)
        self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0)
        self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0)
        self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20)

    def test_silu(self):
        act = get_activation("silu")

        self.assertIsInstance(act, nn.SiLU)

        self.assertEqual(act(torch.tensor(-100, dtype=torch.float32)).item(), 0)
        self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0)
        self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0)
        self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20)

    def test_mish(self):
        act = get_activation("mish")

        self.assertIsInstance(act, nn.Mish)

        self.assertEqual(act(torch.tensor(-200, dtype=torch.float32)).item(), 0)
        self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0)
        self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0)
        self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20)

    def test_gelu(self):
        act = get_activation("gelu")

        self.assertIsInstance(act, nn.GELU)

        self.assertEqual(act(torch.tensor(-100, dtype=torch.float32)).item(), 0)
        self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0)
        self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0)
        self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20)
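Each test checks the concrete module type returned by get_activation plus a few spot values: saturation at large negative inputs, zero at zero, and near-identity at +20. The exact equalities at the extremes hold because float32 saturates there; for instance sigmoid(20) rounds to exactly 1.0, so silu(20) is exactly 20. A standalone reproduction of one spot check, with the usual local run command as a comment (the command is conventional, not part of the diff):

import torch
from torch import nn

# silu(20) = 20 * sigmoid(20); sigmoid(20) rounds to exactly 1.0 in float32.
assert nn.SiLU()(torch.tensor(20.0)).item() == 20.0

# Typical local invocation from the repository root:
#   python -m pytest tests/models/test_activations.py -q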