Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
c524244f
Unverified
Commit
c524244f
authored
Jul 03, 2022
by
Patrick von Platen
Committed by
GitHub
Jul 03, 2022
Browse files
[Resnet] Remove unnecessary functions / classes (#67)
Remove unnecessary functions / classes
parent
d224c637
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
10 additions
and
35 deletions
+10
-35
src/diffusers/models/resnet.py
src/diffusers/models/resnet.py
+4
-29
src/diffusers/models/unet_glide.py
src/diffusers/models/unet_glide.py
+3
-3
src/diffusers/models/unet_ldm.py
src/diffusers/models/unet_ldm.py
+3
-3
No files found.
src/diffusers/models/resnet.py
View file @
c524244f
from
abc
import
abstractmethod
from
functools
import
partial
import
numpy
as
np
...
...
@@ -46,30 +45,6 @@ def conv_transpose_nd(dims, *args, **kwargs):
raise
ValueError
(
f
"unsupported dimensions:
{
dims
}
"
)
def
Normalize
(
in_channels
,
num_groups
=
32
,
eps
=
1e-6
):
return
torch
.
nn
.
GroupNorm
(
num_groups
=
num_groups
,
num_channels
=
in_channels
,
eps
=
eps
,
affine
=
True
)
def
nonlinearity
(
x
,
swish
=
1.0
):
# swish
if
swish
==
1.0
:
return
F
.
silu
(
x
)
else
:
return
x
*
F
.
sigmoid
(
x
*
float
(
swish
))
class
TimestepBlock
(
nn
.
Module
):
"""
Any module where forward() takes timestep embeddings as a second argument.
"""
@
abstractmethod
def
forward
(
self
,
x
,
emb
):
"""
Apply the module to `x` given `emb` timestep embeddings.
"""
class
Upsample
(
nn
.
Module
):
"""
An upsampling layer with an optional convolution.
...
...
@@ -216,9 +191,9 @@ class ResnetBlock2D(nn.Module):
groups_out
=
groups
if
self
.
pre_norm
:
self
.
norm1
=
N
or
malize
(
in_
ch
a
nn
els
,
num_groups
=
groups
,
eps
=
eps
)
self
.
norm1
=
t
orch
.
nn
.
GroupNorm
(
num_groups
=
groups
,
num_channels
=
in_channels
,
eps
=
eps
,
affine
=
True
)
else
:
self
.
norm1
=
N
or
malize
(
out_
ch
a
nn
els
,
num_groups
=
groups
,
eps
=
eps
)
self
.
norm1
=
t
orch
.
nn
.
GroupNorm
(
num_groups
=
groups
,
num_channels
=
out_channels
,
eps
=
eps
,
affine
=
True
)
self
.
conv1
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
...
...
@@ -227,12 +202,12 @@ class ResnetBlock2D(nn.Module):
elif
time_embedding_norm
==
"scale_shift"
and
temb_channels
>
0
:
self
.
temb_proj
=
torch
.
nn
.
Linear
(
temb_channels
,
2
*
out_channels
)
self
.
norm2
=
N
or
malize
(
out_
ch
a
nn
els
,
num_groups
=
groups_out
,
eps
=
eps
)
self
.
norm2
=
t
orch
.
nn
.
GroupNorm
(
num_groups
=
groups_out
,
num_channels
=
out_channels
,
eps
=
eps
,
affine
=
True
)
self
.
dropout
=
torch
.
nn
.
Dropout
(
dropout
)
self
.
conv2
=
torch
.
nn
.
Conv2d
(
out_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
if
non_linearity
==
"swish"
:
self
.
nonlinearity
=
nonlinearity
self
.
nonlinearity
=
lambda
x
:
F
.
silu
(
x
)
elif
non_linearity
==
"mish"
:
self
.
nonlinearity
=
Mish
()
elif
non_linearity
==
"silu"
:
...
...
src/diffusers/models/unet_glide.py
View file @
c524244f
...
...
@@ -6,7 +6,7 @@ from ..configuration_utils import ConfigMixin
from
..modeling_utils
import
ModelMixin
from
.attention
import
AttentionBlock
from
.embeddings
import
get_timestep_embedding
from
.resnet
import
Downsample
,
ResnetBlock2D
,
TimestepBlock
,
Upsample
from
.resnet
import
Downsample
,
ResnetBlock2D
,
Upsample
def
convert_module_to_f16
(
l
):
...
...
@@ -81,14 +81,14 @@ def zero_module(module):
return
module
class
TimestepEmbedSequential
(
nn
.
Sequential
,
TimestepBlock
):
class
TimestepEmbedSequential
(
nn
.
Sequential
):
"""
A sequential module that passes timestep embeddings to the children that support it as an extra input.
"""
def
forward
(
self
,
x
,
emb
,
encoder_out
=
None
):
for
layer
in
self
:
if
isinstance
(
layer
,
Timestep
Block
)
or
isinstance
(
layer
,
ResnetBlock2D
):
if
isinstance
(
layer
,
Resnet
Block
2D
)
or
isinstance
(
layer
,
TimestepEmbedSequential
):
x
=
layer
(
x
,
emb
)
elif
isinstance
(
layer
,
AttentionBlock
):
x
=
layer
(
x
,
encoder_out
)
...
...
src/diffusers/models/unet_ldm.py
View file @
c524244f
...
...
@@ -10,7 +10,7 @@ from ..configuration_utils import ConfigMixin
from
..modeling_utils
import
ModelMixin
from
.attention
import
AttentionBlock
from
.embeddings
import
get_timestep_embedding
from
.resnet
import
Downsample
,
ResnetBlock2D
,
TimestepBlock
,
Upsample
from
.resnet
import
Downsample
,
ResnetBlock2D
,
Upsample
# from .resnet import ResBlock
...
...
@@ -141,14 +141,14 @@ def normalization(channels, swish=0.0):
return
GroupNorm32
(
num_channels
=
channels
,
num_groups
=
32
,
swish
=
swish
)
class
TimestepEmbedSequential
(
nn
.
Sequential
,
TimestepBlock
):
class
TimestepEmbedSequential
(
nn
.
Sequential
):
"""
A sequential module that passes timestep embeddings to the children that support it as an extra input.
"""
def
forward
(
self
,
x
,
emb
,
context
=
None
):
for
layer
in
self
:
if
isinstance
(
layer
,
Timestep
Block
)
or
isinstance
(
layer
,
ResnetBlock2D
):
if
isinstance
(
layer
,
Resnet
Block
2D
)
or
isinstance
(
layer
,
TimestepEmbedSequential
):
x
=
layer
(
x
,
emb
)
elif
isinstance
(
layer
,
SpatialTransformer
):
x
=
layer
(
x
,
context
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment