Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
c524244f
Unverified
Commit
c524244f
authored
Jul 03, 2022
by
Patrick von Platen
Committed by
GitHub
Jul 03, 2022
Browse files
[Resnet] Remove unnecessary functions / classes (#67)
Remove unnecessary functions / classes
parent
d224c637
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
10 additions
and
35 deletions
+10
-35
src/diffusers/models/resnet.py
src/diffusers/models/resnet.py
+4
-29
src/diffusers/models/unet_glide.py
src/diffusers/models/unet_glide.py
+3
-3
src/diffusers/models/unet_ldm.py
src/diffusers/models/unet_ldm.py
+3
-3
No files found.
src/diffusers/models/resnet.py
View file @
c524244f
from
abc
import
abstractmethod
from
functools
import
partial
import
numpy
as
np
...
...
@@ -46,30 +45,6 @@ def conv_transpose_nd(dims, *args, **kwargs):
raise
ValueError
(
f
"unsupported dimensions:
{
dims
}
"
)
def Normalize(in_channels, num_groups=32, eps=1e-6):
    """Build a group-normalization layer for ``in_channels`` channels.

    Args:
        in_channels: Number of channels in the normalized input.
        num_groups: Number of groups to split the channels into (default 32).
        eps: Numerical-stability epsilon added to the variance (default 1e-6).

    Returns:
        A ``torch.nn.GroupNorm`` module with learnable affine parameters.
    """
    norm_layer = torch.nn.GroupNorm(
        num_channels=in_channels,
        num_groups=num_groups,
        eps=eps,
        affine=True,
    )
    return norm_layer
def nonlinearity(x, swish=1.0):
    """Apply a (scaled) swish activation to ``x``.

    Args:
        x: Input tensor.
        swish: Scale factor inside the sigmoid. ``1.0`` gives the standard
            SiLU activation ``x * sigmoid(x)``.

    Returns:
        Tensor of the same shape as ``x`` with the activation applied.
    """
    if swish == 1.0:
        # Fast path: standard SiLU via the fused functional kernel.
        return F.silu(x)
    # Generalized swish: x * sigmoid(swish * x).
    # Fix: use torch.sigmoid -- F.sigmoid is deprecated (removed-in-practice
    # alias that warns since PyTorch 1.0).
    return x * torch.sigmoid(x * float(swish))
class TimestepBlock(nn.Module):
    """Abstract base for modules whose forward pass also consumes a
    timestep embedding as a second positional argument.

    Subclasses must override :meth:`forward`.
    """

    @abstractmethod
    def forward(self, x, emb):
        """Run the module on ``x``, conditioned on the timestep embedding ``emb``."""
class
Upsample
(
nn
.
Module
):
"""
An upsampling layer with an optional convolution.
...
...
@@ -216,9 +191,9 @@ class ResnetBlock2D(nn.Module):
groups_out
=
groups
if
self
.
pre_norm
:
self
.
norm1
=
N
or
malize
(
in_
ch
a
nn
els
,
num_groups
=
groups
,
eps
=
eps
)
self
.
norm1
=
t
orch
.
nn
.
GroupNorm
(
num_groups
=
groups
,
num_channels
=
in_channels
,
eps
=
eps
,
affine
=
True
)
else
:
self
.
norm1
=
N
or
malize
(
out_
ch
a
nn
els
,
num_groups
=
groups
,
eps
=
eps
)
self
.
norm1
=
t
orch
.
nn
.
GroupNorm
(
num_groups
=
groups
,
num_channels
=
out_channels
,
eps
=
eps
,
affine
=
True
)
self
.
conv1
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
...
...
@@ -227,12 +202,12 @@ class ResnetBlock2D(nn.Module):
elif
time_embedding_norm
==
"scale_shift"
and
temb_channels
>
0
:
self
.
temb_proj
=
torch
.
nn
.
Linear
(
temb_channels
,
2
*
out_channels
)
self
.
norm2
=
N
or
malize
(
out_
ch
a
nn
els
,
num_groups
=
groups_out
,
eps
=
eps
)
self
.
norm2
=
t
orch
.
nn
.
GroupNorm
(
num_groups
=
groups_out
,
num_channels
=
out_channels
,
eps
=
eps
,
affine
=
True
)
self
.
dropout
=
torch
.
nn
.
Dropout
(
dropout
)
self
.
conv2
=
torch
.
nn
.
Conv2d
(
out_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
if
non_linearity
==
"swish"
:
self
.
nonlinearity
=
nonlinearity
self
.
nonlinearity
=
lambda
x
:
F
.
silu
(
x
)
elif
non_linearity
==
"mish"
:
self
.
nonlinearity
=
Mish
()
elif
non_linearity
==
"silu"
:
...
...
src/diffusers/models/unet_glide.py
View file @
c524244f
...
...
@@ -6,7 +6,7 @@ from ..configuration_utils import ConfigMixin
from
..modeling_utils
import
ModelMixin
from
.attention
import
AttentionBlock
from
.embeddings
import
get_timestep_embedding
from
.resnet
import
Downsample
,
ResnetBlock2D
,
TimestepBlock
,
Upsample
from
.resnet
import
Downsample
,
ResnetBlock2D
,
Upsample
def
convert_module_to_f16
(
l
):
...
...
@@ -81,14 +81,14 @@ def zero_module(module):
return
module
class
TimestepEmbedSequential
(
nn
.
Sequential
,
TimestepBlock
):
class
TimestepEmbedSequential
(
nn
.
Sequential
):
"""
A sequential module that passes timestep embeddings to the children that support it as an extra input.
"""
def
forward
(
self
,
x
,
emb
,
encoder_out
=
None
):
for
layer
in
self
:
if
isinstance
(
layer
,
Timestep
Block
)
or
isinstance
(
layer
,
ResnetBlock2D
):
if
isinstance
(
layer
,
Resnet
Block
2D
)
or
isinstance
(
layer
,
TimestepEmbedSequential
):
x
=
layer
(
x
,
emb
)
elif
isinstance
(
layer
,
AttentionBlock
):
x
=
layer
(
x
,
encoder_out
)
...
...
src/diffusers/models/unet_ldm.py
View file @
c524244f
...
...
@@ -10,7 +10,7 @@ from ..configuration_utils import ConfigMixin
from
..modeling_utils
import
ModelMixin
from
.attention
import
AttentionBlock
from
.embeddings
import
get_timestep_embedding
from
.resnet
import
Downsample
,
ResnetBlock2D
,
TimestepBlock
,
Upsample
from
.resnet
import
Downsample
,
ResnetBlock2D
,
Upsample
# from .resnet import ResBlock
...
...
@@ -141,14 +141,14 @@ def normalization(channels, swish=0.0):
return
GroupNorm32
(
num_channels
=
channels
,
num_groups
=
32
,
swish
=
swish
)
class
TimestepEmbedSequential
(
nn
.
Sequential
,
TimestepBlock
):
class
TimestepEmbedSequential
(
nn
.
Sequential
):
"""
A sequential module that passes timestep embeddings to the children that support it as an extra input.
"""
def
forward
(
self
,
x
,
emb
,
context
=
None
):
for
layer
in
self
:
if
isinstance
(
layer
,
Timestep
Block
)
or
isinstance
(
layer
,
ResnetBlock2D
):
if
isinstance
(
layer
,
Resnet
Block
2D
)
or
isinstance
(
layer
,
TimestepEmbedSequential
):
x
=
layer
(
x
,
emb
)
elif
isinstance
(
layer
,
SpatialTransformer
):
x
=
layer
(
x
,
context
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment