renzhc/diffusers_dcu · commit a2b72faf

Authored Jun 27, 2022 by Patrick von Platen
Merge branch 'main' of https://github.com/huggingface/diffusers into main
Parents: c9504bba 26ea58d4

Showing 6 changed files with 91 additions and 98 deletions.
src/diffusers/models/resnet.py        +15 -4
src/diffusers/models/unet.py           +2 -20
src/diffusers/models/unet_glide.py     +6 -31
src/diffusers/models/unet_grad_tts.py  +3 -11
src/diffusers/models/unet_ldm.py       +9 -31
tests/test_layers_utils.py            +56 -1
src/diffusers/models/resnet.py

@@ -101,7 +101,7 @@ class Downsample(nn.Module):
         downsampling occurs in the inner-two dimensions.
     """

-    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+    def __init__(self, channels, use_conv=False, dims=2, out_channels=None, padding=1, name="conv"):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
@@ -109,18 +109,29 @@ class Downsample(nn.Module):
         self.dims = dims
         self.padding = padding
         stride = 2 if dims != 3 else (1, 2, 2)
+        self.name = name

         if use_conv:
-            self.down = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
+            conv = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
         else:
             assert self.channels == self.out_channels
-            self.down = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+            conv = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+        if name == "conv":
+            self.conv = conv
+        else:
+            self.op = conv

     def forward(self, x):
         assert x.shape[1] == self.channels
         if self.use_conv and self.padding == 0 and self.dims == 2:
             pad = (0, 1, 0, 1)
             x = F.pad(x, pad, mode="constant", value=0)
-        return self.down(x)
+
+        if self.name == "conv":
+            return self.conv(x)
+        else:
+            return self.op(x)


 class UNetUpsample(nn.Module):
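With this change there is a single Downsample implementation shared by all the UNets in the repo; whether it downsamples with a strided 3x3 convolution or with average pooling is controlled by use_conv. A minimal usage sketch (not part of the commit; the output shapes are the ones asserted by the new tests below):

import torch

from diffusers.models.resnet import Downsample

x = torch.randn(1, 32, 64, 64)

# Strided 3x3 convolution path (learnable), registered as `self.conv` by default.
down_conv = Downsample(channels=32, use_conv=True, padding=1)
print(down_conv(x).shape)  # torch.Size([1, 32, 32, 32])

# Average-pooling path (parameter-free); requires in/out channel counts to match.
down_pool = Downsample(channels=32, use_conv=False)
print(down_pool(x).shape)  # torch.Size([1, 32, 32, 32])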
src/diffusers/models/unet.py

@@ -31,7 +31,7 @@ from tqdm import tqdm
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .embeddings import get_timestep_embedding
-from .resnet import Upsample
+from .resnet import Downsample, Upsample


 def nonlinearity(x):
@@ -43,24 +43,6 @@ def Normalize(in_channels):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


-class Downsample(nn.Module):
-    def __init__(self, in_channels, with_conv):
-        super().__init__()
-        self.with_conv = with_conv
-        if self.with_conv:
-            # no asymmetric padding in torch conv, must do it ourselves
-            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
-
-    def forward(self, x):
-        if self.with_conv:
-            pad = (0, 1, 0, 1)
-            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
-            x = self.conv(x)
-        else:
-            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
-        return x
-
-
 class ResnetBlock(nn.Module):
     def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
         super().__init__()
@@ -207,7 +189,7 @@ class UNetModel(ModelMixin, ConfigMixin):
             down.block = block
             down.attn = attn
             if i_level != self.num_resolutions - 1:
-                down.downsample = Downsample(block_in, resamp_with_conv)
+                down.downsample = Downsample(block_in, use_conv=resamp_with_conv, padding=0)
                 curr_res = curr_res // 2
             self.down.append(down)
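The removed unet.py Downsample padded asymmetrically with (0, 1, 0, 1) before a stride-2, padding-0 convolution; the shared class reproduces this when constructed with padding=0, because its forward applies the same pad whenever use_conv is set, padding == 0 and dims == 2 (and falls back to 2x2 average pooling when use_conv is False, matching the old else branch). A sketch of the equivalence for the conv path (not part of the commit; the explicit weight copy is only there to make the outputs comparable):

import torch
import torch.nn.functional as F

from diffusers.models.resnet import Downsample

torch.manual_seed(0)
x = torch.randn(1, 32, 64, 64)

# Old unet.py behaviour: manual asymmetric pad, then Conv2d(kernel_size=3, stride=2, padding=0).
conv = torch.nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=0)
old_out = conv(F.pad(x, (0, 1, 0, 1), mode="constant", value=0))

# New call site: Downsample(block_in, use_conv=resamp_with_conv, padding=0),
# which performs the same pad internally before its own strided conv.
down = Downsample(channels=32, use_conv=True, padding=0)
down.conv.load_state_dict(conv.state_dict())  # reuse identical weights for the comparison

assert old_out.shape == down(x).shape == (1, 32, 32, 32)
assert torch.allclose(old_out, down(x))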
src/diffusers/models/unet_glide.py

@@ -8,7 +8,7 @@ import torch.nn.functional as F
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .embeddings import get_timestep_embedding
-from .resnet import Upsample
+from .resnet import Downsample, Upsample


 def convert_module_to_f16(l):
@@ -124,33 +124,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
         return x


-class Downsample(nn.Module):
-    """
-    A downsampling layer with an optional convolution.
-
-    :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
-    applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
-        downsampling occurs in the inner-two dimensions.
-    """
-
-    def __init__(self, channels, use_conv, dims=2, out_channels=None):
-        super().__init__()
-        self.channels = channels
-        self.out_channels = out_channels or channels
-        self.use_conv = use_conv
-        self.dims = dims
-        stride = 2 if dims != 3 else (1, 2, 2)
-        if use_conv:
-            self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=1)
-        else:
-            assert self.channels == self.out_channels
-            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
-
-    def forward(self, x):
-        assert x.shape[1] == self.channels
-        return self.op(x)
-
-
 class ResBlock(TimestepBlock):
     """
     A residual block that can optionally change the number of channels.
@@ -198,8 +171,8 @@ class ResBlock(TimestepBlock):
             self.h_upd = Upsample(channels, use_conv=False, dims=dims)
             self.x_upd = Upsample(channels, use_conv=False, dims=dims)
         elif down:
-            self.h_upd = Downsample(channels, False, dims)
-            self.x_upd = Downsample(channels, False, dims)
+            self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
+            self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
         else:
             self.h_upd = self.x_upd = nn.Identity()
@@ -450,7 +423,9 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                            down=True,
                        )
                        if resblock_updown
-                       else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+                       else Downsample(
+                           ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
+                       )
                    )
                )
                ch = out_ch
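The GLIDE-local class registered its layer as self.op, so every call site here passes name="op" to the shared implementation; the module's parameter names, and therefore the keys of pretrained GLIDE checkpoints, stay the same. A small sketch of what name controls (the wrapping ModuleDict is hypothetical, just to show how the prefix composes):

import torch.nn as nn

from diffusers.models.resnet import Downsample

# name="op" stores the strided conv under `op`, matching the removed GLIDE class.
down = Downsample(channels=64, use_conv=True, dims=2, padding=1, name="op")
print(sorted(down.state_dict().keys()))  # ['op.bias', 'op.weight']

# Inside a parent module the usual prefixing applies, so existing checkpoints
# keep lining up key-for-key (e.g. "downsample.op.weight").
block = nn.ModuleDict({"downsample": down})
print(sorted(block.state_dict().keys()))  # ['downsample.op.bias', 'downsample.op.weight']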
src/diffusers/models/unet_grad_tts.py

 import torch
+from numpy import pad

 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .embeddings import get_timestep_embedding
-from .resnet import Upsample
+from .resnet import Downsample, Upsample


 class Mish(torch.nn.Module):
@@ -11,15 +12,6 @@ class Mish(torch.nn.Module):
         return x * torch.tanh(torch.nn.functional.softplus(x))


-class Downsample(torch.nn.Module):
-    def __init__(self, dim):
-        super(Downsample, self).__init__()
-        self.conv = torch.nn.Conv2d(dim, dim, 3, 2, 1)
-
-    def forward(self, x):
-        return self.conv(x)
-
-
 class Rezero(torch.nn.Module):
     def __init__(self, fn):
         super(Rezero, self).__init__()
@@ -141,7 +133,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
                        ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
                        ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
                        Residual(Rezero(LinearAttention(dim_out))),
-                       Downsample(dim_out) if not is_last else torch.nn.Identity(),
+                       Downsample(dim_out, use_conv=True, padding=1) if not is_last else torch.nn.Identity(),
                    ]
                )
            )
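The grad-tts Downsample was simply torch.nn.Conv2d(dim, dim, 3, 2, 1), which the shared class reproduces with use_conv=True and padding=1: a 3x3 convolution with stride 2 and padding 1 halves even spatial sizes, since H_out = floor((H + 2 - 3) / 2) + 1 = H / 2. A quick shape check, as a sketch with an arbitrary dim:

import torch

from diffusers.models.resnet import Downsample

dim = 64
x = torch.randn(1, dim, 80, 80)

old = torch.nn.Conv2d(dim, dim, 3, 2, 1)                  # removed grad-tts layer
new = Downsample(channels=dim, use_conv=True, padding=1)  # replacement call site

# Same downsampling behaviour: 80x80 -> 40x40 (weights differ, shapes match).
assert old(x).shape == new(x).shape == (1, dim, 40, 40)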
src/diffusers/models/unet_ldm.py

@@ -10,7 +10,7 @@ import torch.nn.functional as F
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .embeddings import get_timestep_embedding
-from .resnet import Upsample
+from .resnet import Downsample, Upsample


 def exists(val):
@@ -392,32 +392,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
         return x


-class Downsample(nn.Module):
-    """
-    A downsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param
-    use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D.
-    If 3D, then
-        downsampling occurs in the inner-two dimensions.
-    """
-
-    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
-        super().__init__()
-        self.channels = channels
-        self.out_channels = out_channels or channels
-        self.use_conv = use_conv
-        self.dims = dims
-        stride = 2 if dims != 3 else (1, 2, 2)
-        if use_conv:
-            self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
-        else:
-            assert self.channels == self.out_channels
-            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
-
-    def forward(self, x):
-        assert x.shape[1] == self.channels
-        return self.op(x)
-
-
 class ResBlock(TimestepBlock):
     """
     A residual block that can optionally change the number of channels. :param channels: the number of input channels.
@@ -464,8 +438,8 @@ class ResBlock(TimestepBlock):
             self.h_upd = Upsample(channels, use_conv=False, dims=dims)
             self.x_upd = Upsample(channels, use_conv=False, dims=dims)
         elif down:
-            self.h_upd = Downsample(channels, False, dims)
-            self.x_upd = Downsample(channels, False, dims)
+            self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
+            self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
         else:
             self.h_upd = self.x_upd = nn.Identity()
@@ -820,7 +794,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                            down=True,
                        )
                        if resblock_updown
-                       else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+                       else Downsample(
+                           ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
+                       )
                    )
                )
                ch = out_ch
@@ -1089,7 +1065,9 @@ class EncoderUNetModel(nn.Module):
                            down=True,
                        )
                        if resblock_updown
-                       else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+                       else Downsample(
+                           ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
+                       )
                    )
                )
                ch = out_ch
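In ResBlock the down path keeps use_conv=False, so the shared Downsample falls back to its parameter-free average-pooling branch (which asserts channels == out_channels); name="op" only changes the attribute name, and since pooling has no weights, checkpoints are untouched on this path. A brief sketch (channel count chosen arbitrarily):

import torch

from diffusers.models.resnet import Downsample

# Pooling variant used for ResBlock's h_upd / x_upd when down=True.
h_upd = Downsample(channels=128, use_conv=False, dims=2, padding=1, name="op")

x = torch.randn(2, 128, 16, 16)
print(h_upd(x).shape)           # torch.Size([2, 128, 8, 8])
print(len(h_upd.state_dict()))  # 0 -- average pooling has no parameters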
tests/test_layers_utils.py

@@ -22,7 +22,7 @@ import numpy as np
 import torch

 from diffusers.models.embeddings import get_timestep_embedding
-from diffusers.models.resnet import Upsample
+from diffusers.models.resnet import Downsample, Upsample
 from diffusers.testing_utils import floats_tensor, slow, torch_device
@@ -164,3 +164,58 @@ class UpsampleBlockTests(unittest.TestCase):
         output_slice = upsampled[0, -1, -3:, -3:]
         expected_slice = torch.tensor([-0.3028, -0.1582, 0.0071, 0.0350, -0.4799, -0.1139, 0.1056, -0.1153, -0.1046])
         assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+
+class DownsampleBlockTests(unittest.TestCase):
+    def test_downsample_default(self):
+        torch.manual_seed(0)
+        sample = torch.randn(1, 32, 64, 64)
+        downsample = Downsample(channels=32, use_conv=False)
+        with torch.no_grad():
+            downsampled = downsample(sample)
+
+        assert downsampled.shape == (1, 32, 32, 32)
+        output_slice = downsampled[0, -1, -3:, -3:]
+        expected_slice = torch.tensor([-0.0513, -0.3889, 0.0640, 0.0836, -0.5460, -0.0341, -0.0169, -0.6967, 0.1179])
+        max_diff = (output_slice.flatten() - expected_slice).abs().sum().item()
+        assert max_diff <= 1e-3
+        # assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-1)
+
+    def test_downsample_with_conv(self):
+        torch.manual_seed(0)
+        sample = torch.randn(1, 32, 64, 64)
+        downsample = Downsample(channels=32, use_conv=True)
+        with torch.no_grad():
+            downsampled = downsample(sample)
+
+        assert downsampled.shape == (1, 32, 32, 32)
+        output_slice = downsampled[0, -1, -3:, -3:]
+        expected_slice = torch.tensor(
+            [0.9267, 0.5878, 0.3337, 1.2321, -0.1191, -0.3984, -0.7532, -0.0715, -0.3913],
+        )
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+    def test_downsample_with_conv_pad1(self):
+        torch.manual_seed(0)
+        sample = torch.randn(1, 32, 64, 64)
+        downsample = Downsample(channels=32, use_conv=True, padding=1)
+        with torch.no_grad():
+            downsampled = downsample(sample)
+
+        assert downsampled.shape == (1, 32, 32, 32)
+        output_slice = downsampled[0, -1, -3:, -3:]
+        expected_slice = torch.tensor([0.9267, 0.5878, 0.3337, 1.2321, -0.1191, -0.3984, -0.7532, -0.0715, -0.3913])
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+    def test_downsample_with_conv_out_dim(self):
+        torch.manual_seed(0)
+        sample = torch.randn(1, 32, 64, 64)
+        downsample = Downsample(channels=32, use_conv=True, out_channels=16)
+        with torch.no_grad():
+            downsampled = downsample(sample)
+
+        assert downsampled.shape == (1, 16, 32, 32)
+        output_slice = downsampled[0, -1, -3:, -3:]
+        expected_slice = torch.tensor([-0.6586, 0.5985, 0.0721, 0.1256, -0.1492, 0.4436, -0.2544, 0.5021, 1.1522])
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
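The expected slices above are reproducible because each test seeds the generator before drawing the input and before constructing the block, so both the sample and any conv initialization come from the same seeded stream. A sketch of regenerating one slice locally (CPU; the printed values should match the conv-based test above):

import torch

from diffusers.models.resnet import Downsample

torch.manual_seed(0)  # fixes the input *and* the conv init below
sample = torch.randn(1, 32, 64, 64)
downsample = Downsample(channels=32, use_conv=True)

with torch.no_grad():
    out = downsample(sample)

# Bottom-right 3x3 patch of the last channel, as sliced in the tests.
print(out[0, -1, -3:, -3:].flatten())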