Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
a2b72faf
Commit
a2b72faf
authored
Jun 27, 2022
by
Patrick von Platen
Browse files
Merge branch 'main' of
https://github.com/huggingface/diffusers
into main
parents
c9504bba
26ea58d4
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
91 additions
and
98 deletions
+91
-98
src/diffusers/models/resnet.py
src/diffusers/models/resnet.py
+15
-4
src/diffusers/models/unet.py
src/diffusers/models/unet.py
+2
-20
src/diffusers/models/unet_glide.py
src/diffusers/models/unet_glide.py
+6
-31
src/diffusers/models/unet_grad_tts.py
src/diffusers/models/unet_grad_tts.py
+3
-11
src/diffusers/models/unet_ldm.py
src/diffusers/models/unet_ldm.py
+9
-31
tests/test_layers_utils.py
tests/test_layers_utils.py
+56
-1
No files found.
src/diffusers/models/resnet.py
View file @
a2b72faf
...
@@ -101,7 +101,7 @@ class Downsample(nn.Module):
...
@@ -101,7 +101,7 @@ class Downsample(nn.Module):
downsampling occurs in the inner-two dimensions.
downsampling occurs in the inner-two dimensions.
"""
"""
def
__init__
(
self
,
channels
,
use_conv
,
dims
=
2
,
out_channels
=
None
,
padding
=
1
):
def
__init__
(
self
,
channels
,
use_conv
=
False
,
dims
=
2
,
out_channels
=
None
,
padding
=
1
,
name
=
"conv"
):
super
().
__init__
()
super
().
__init__
()
self
.
channels
=
channels
self
.
channels
=
channels
self
.
out_channels
=
out_channels
or
channels
self
.
out_channels
=
out_channels
or
channels
...
@@ -109,18 +109,29 @@ class Downsample(nn.Module):
...
@@ -109,18 +109,29 @@ class Downsample(nn.Module):
self
.
dims
=
dims
self
.
dims
=
dims
self
.
padding
=
padding
self
.
padding
=
padding
stride
=
2
if
dims
!=
3
else
(
1
,
2
,
2
)
stride
=
2
if
dims
!=
3
else
(
1
,
2
,
2
)
self
.
name
=
name
if
use_conv
:
if
use_conv
:
self
.
down
=
conv_nd
(
dims
,
self
.
channels
,
self
.
out_channels
,
3
,
stride
=
stride
,
padding
=
padding
)
conv
=
conv_nd
(
dims
,
self
.
channels
,
self
.
out_channels
,
3
,
stride
=
stride
,
padding
=
padding
)
else
:
else
:
assert
self
.
channels
==
self
.
out_channels
assert
self
.
channels
==
self
.
out_channels
self
.
down
=
avg_pool_nd
(
dims
,
kernel_size
=
stride
,
stride
=
stride
)
conv
=
avg_pool_nd
(
dims
,
kernel_size
=
stride
,
stride
=
stride
)
if
name
==
"conv"
:
self
.
conv
=
conv
else
:
self
.
op
=
conv
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
assert
x
.
shape
[
1
]
==
self
.
channels
assert
x
.
shape
[
1
]
==
self
.
channels
if
self
.
use_conv
and
self
.
padding
==
0
and
self
.
dims
==
2
:
if
self
.
use_conv
and
self
.
padding
==
0
and
self
.
dims
==
2
:
pad
=
(
0
,
1
,
0
,
1
)
pad
=
(
0
,
1
,
0
,
1
)
x
=
F
.
pad
(
x
,
pad
,
mode
=
"constant"
,
value
=
0
)
x
=
F
.
pad
(
x
,
pad
,
mode
=
"constant"
,
value
=
0
)
return
self
.
down
(
x
)
if
self
.
name
==
"conv"
:
return
self
.
conv
(
x
)
else
:
return
self
.
op
(
x
)
class
UNetUpsample
(
nn
.
Module
):
class
UNetUpsample
(
nn
.
Module
):
...
...
src/diffusers/models/unet.py
View file @
a2b72faf
...
@@ -31,7 +31,7 @@ from tqdm import tqdm
...
@@ -31,7 +31,7 @@ from tqdm import tqdm
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
from
.embeddings
import
get_timestep_embedding
from
.embeddings
import
get_timestep_embedding
from
.resnet
import
Upsample
from
.resnet
import
Downsample
,
Upsample
def
nonlinearity
(
x
):
def
nonlinearity
(
x
):
...
@@ -43,24 +43,6 @@ def Normalize(in_channels):
...
@@ -43,24 +43,6 @@ def Normalize(in_channels):
return
torch
.
nn
.
GroupNorm
(
num_groups
=
32
,
num_channels
=
in_channels
,
eps
=
1e-6
,
affine
=
True
)
return
torch
.
nn
.
GroupNorm
(
num_groups
=
32
,
num_channels
=
in_channels
,
eps
=
1e-6
,
affine
=
True
)
class
Downsample
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
with_conv
):
super
().
__init__
()
self
.
with_conv
=
with_conv
if
self
.
with_conv
:
# no asymmetric padding in torch conv, must do it ourselves
self
.
conv
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
3
,
stride
=
2
,
padding
=
0
)
def
forward
(
self
,
x
):
if
self
.
with_conv
:
pad
=
(
0
,
1
,
0
,
1
)
x
=
torch
.
nn
.
functional
.
pad
(
x
,
pad
,
mode
=
"constant"
,
value
=
0
)
x
=
self
.
conv
(
x
)
else
:
x
=
torch
.
nn
.
functional
.
avg_pool2d
(
x
,
kernel_size
=
2
,
stride
=
2
)
return
x
class
ResnetBlock
(
nn
.
Module
):
class
ResnetBlock
(
nn
.
Module
):
def
__init__
(
self
,
*
,
in_channels
,
out_channels
=
None
,
conv_shortcut
=
False
,
dropout
,
temb_channels
=
512
):
def
__init__
(
self
,
*
,
in_channels
,
out_channels
=
None
,
conv_shortcut
=
False
,
dropout
,
temb_channels
=
512
):
super
().
__init__
()
super
().
__init__
()
...
@@ -207,7 +189,7 @@ class UNetModel(ModelMixin, ConfigMixin):
...
@@ -207,7 +189,7 @@ class UNetModel(ModelMixin, ConfigMixin):
down
.
block
=
block
down
.
block
=
block
down
.
attn
=
attn
down
.
attn
=
attn
if
i_level
!=
self
.
num_resolutions
-
1
:
if
i_level
!=
self
.
num_resolutions
-
1
:
down
.
downsample
=
Downsample
(
block_in
,
resamp_with_conv
)
down
.
downsample
=
Downsample
(
block_in
,
use_conv
=
resamp_with_conv
,
padding
=
0
)
curr_res
=
curr_res
//
2
curr_res
=
curr_res
//
2
self
.
down
.
append
(
down
)
self
.
down
.
append
(
down
)
...
...
src/diffusers/models/unet_glide.py
View file @
a2b72faf
...
@@ -8,7 +8,7 @@ import torch.nn.functional as F
...
@@ -8,7 +8,7 @@ import torch.nn.functional as F
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
from
.embeddings
import
get_timestep_embedding
from
.embeddings
import
get_timestep_embedding
from
.resnet
import
Upsample
from
.resnet
import
Downsample
,
Upsample
def
convert_module_to_f16
(
l
):
def
convert_module_to_f16
(
l
):
...
@@ -124,33 +124,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
...
@@ -124,33 +124,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
return
x
return
x
class
Downsample
(
nn
.
Module
):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
def
__init__
(
self
,
channels
,
use_conv
,
dims
=
2
,
out_channels
=
None
):
super
().
__init__
()
self
.
channels
=
channels
self
.
out_channels
=
out_channels
or
channels
self
.
use_conv
=
use_conv
self
.
dims
=
dims
stride
=
2
if
dims
!=
3
else
(
1
,
2
,
2
)
if
use_conv
:
self
.
op
=
conv_nd
(
dims
,
self
.
channels
,
self
.
out_channels
,
3
,
stride
=
stride
,
padding
=
1
)
else
:
assert
self
.
channels
==
self
.
out_channels
self
.
op
=
avg_pool_nd
(
dims
,
kernel_size
=
stride
,
stride
=
stride
)
def
forward
(
self
,
x
):
assert
x
.
shape
[
1
]
==
self
.
channels
return
self
.
op
(
x
)
class
ResBlock
(
TimestepBlock
):
class
ResBlock
(
TimestepBlock
):
"""
"""
A residual block that can optionally change the number of channels.
A residual block that can optionally change the number of channels.
...
@@ -198,8 +171,8 @@ class ResBlock(TimestepBlock):
...
@@ -198,8 +171,8 @@ class ResBlock(TimestepBlock):
self
.
h_upd
=
Upsample
(
channels
,
use_conv
=
False
,
dims
=
dims
)
self
.
h_upd
=
Upsample
(
channels
,
use_conv
=
False
,
dims
=
dims
)
self
.
x_upd
=
Upsample
(
channels
,
use_conv
=
False
,
dims
=
dims
)
self
.
x_upd
=
Upsample
(
channels
,
use_conv
=
False
,
dims
=
dims
)
elif
down
:
elif
down
:
self
.
h_upd
=
Downsample
(
channels
,
False
,
dims
)
self
.
h_upd
=
Downsample
(
channels
,
use_conv
=
False
,
dims
=
dims
,
padding
=
1
,
name
=
"op"
)
self
.
x_upd
=
Downsample
(
channels
,
False
,
dims
)
self
.
x_upd
=
Downsample
(
channels
,
use_conv
=
False
,
dims
=
dims
,
padding
=
1
,
name
=
"op"
)
else
:
else
:
self
.
h_upd
=
self
.
x_upd
=
nn
.
Identity
()
self
.
h_upd
=
self
.
x_upd
=
nn
.
Identity
()
...
@@ -450,7 +423,9 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
...
@@ -450,7 +423,9 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
down
=
True
,
down
=
True
,
)
)
if
resblock_updown
if
resblock_updown
else
Downsample
(
ch
,
conv_resample
,
dims
=
dims
,
out_channels
=
out_ch
)
else
Downsample
(
ch
,
use_conv
=
conv_resample
,
dims
=
dims
,
out_channels
=
out_ch
,
padding
=
1
,
name
=
"op"
)
)
)
)
)
ch
=
out_ch
ch
=
out_ch
...
...
src/diffusers/models/unet_grad_tts.py
View file @
a2b72faf
import
torch
import
torch
from
numpy
import
pad
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
from
.embeddings
import
get_timestep_embedding
from
.embeddings
import
get_timestep_embedding
from
.resnet
import
Upsample
from
.resnet
import
Downsample
,
Upsample
class
Mish
(
torch
.
nn
.
Module
):
class
Mish
(
torch
.
nn
.
Module
):
...
@@ -11,15 +12,6 @@ class Mish(torch.nn.Module):
...
@@ -11,15 +12,6 @@ class Mish(torch.nn.Module):
return
x
*
torch
.
tanh
(
torch
.
nn
.
functional
.
softplus
(
x
))
return
x
*
torch
.
tanh
(
torch
.
nn
.
functional
.
softplus
(
x
))
class
Downsample
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
dim
):
super
(
Downsample
,
self
).
__init__
()
self
.
conv
=
torch
.
nn
.
Conv2d
(
dim
,
dim
,
3
,
2
,
1
)
def
forward
(
self
,
x
):
return
self
.
conv
(
x
)
class
Rezero
(
torch
.
nn
.
Module
):
class
Rezero
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
fn
):
def
__init__
(
self
,
fn
):
super
(
Rezero
,
self
).
__init__
()
super
(
Rezero
,
self
).
__init__
()
...
@@ -141,7 +133,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
...
@@ -141,7 +133,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
ResnetBlock
(
dim_in
,
dim_out
,
time_emb_dim
=
dim
),
ResnetBlock
(
dim_in
,
dim_out
,
time_emb_dim
=
dim
),
ResnetBlock
(
dim_out
,
dim_out
,
time_emb_dim
=
dim
),
ResnetBlock
(
dim_out
,
dim_out
,
time_emb_dim
=
dim
),
Residual
(
Rezero
(
LinearAttention
(
dim_out
))),
Residual
(
Rezero
(
LinearAttention
(
dim_out
))),
Downsample
(
dim_out
)
if
not
is_last
else
torch
.
nn
.
Identity
(),
Downsample
(
dim_out
,
use_conv
=
True
,
padding
=
1
)
if
not
is_last
else
torch
.
nn
.
Identity
(),
]
]
)
)
)
)
...
...
src/diffusers/models/unet_ldm.py
View file @
a2b72faf
...
@@ -10,7 +10,7 @@ import torch.nn.functional as F
...
@@ -10,7 +10,7 @@ import torch.nn.functional as F
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
from
.embeddings
import
get_timestep_embedding
from
.embeddings
import
get_timestep_embedding
from
.resnet
import
Upsample
from
.resnet
import
Downsample
,
Upsample
def
exists
(
val
):
def
exists
(
val
):
...
@@ -392,32 +392,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
...
@@ -392,32 +392,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
return
x
return
x
class
Downsample
(
nn
.
Module
):
"""
A downsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param
use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D.
If 3D, then
downsampling occurs in the inner-two dimensions.
"""
def
__init__
(
self
,
channels
,
use_conv
,
dims
=
2
,
out_channels
=
None
,
padding
=
1
):
super
().
__init__
()
self
.
channels
=
channels
self
.
out_channels
=
out_channels
or
channels
self
.
use_conv
=
use_conv
self
.
dims
=
dims
stride
=
2
if
dims
!=
3
else
(
1
,
2
,
2
)
if
use_conv
:
self
.
op
=
conv_nd
(
dims
,
self
.
channels
,
self
.
out_channels
,
3
,
stride
=
stride
,
padding
=
padding
)
else
:
assert
self
.
channels
==
self
.
out_channels
self
.
op
=
avg_pool_nd
(
dims
,
kernel_size
=
stride
,
stride
=
stride
)
def
forward
(
self
,
x
):
assert
x
.
shape
[
1
]
==
self
.
channels
return
self
.
op
(
x
)
class
ResBlock
(
TimestepBlock
):
class
ResBlock
(
TimestepBlock
):
"""
"""
A residual block that can optionally change the number of channels. :param channels: the number of input channels.
A residual block that can optionally change the number of channels. :param channels: the number of input channels.
...
@@ -464,8 +438,8 @@ class ResBlock(TimestepBlock):
...
@@ -464,8 +438,8 @@ class ResBlock(TimestepBlock):
self
.
h_upd
=
Upsample
(
channels
,
use_conv
=
False
,
dims
=
dims
)
self
.
h_upd
=
Upsample
(
channels
,
use_conv
=
False
,
dims
=
dims
)
self
.
x_upd
=
Upsample
(
channels
,
use_conv
=
False
,
dims
=
dims
)
self
.
x_upd
=
Upsample
(
channels
,
use_conv
=
False
,
dims
=
dims
)
elif
down
:
elif
down
:
self
.
h_upd
=
Downsample
(
channels
,
False
,
dims
)
self
.
h_upd
=
Downsample
(
channels
,
use_conv
=
False
,
dims
=
dims
,
padding
=
1
,
name
=
"op"
)
self
.
x_upd
=
Downsample
(
channels
,
False
,
dims
)
self
.
x_upd
=
Downsample
(
channels
,
use_conv
=
False
,
dims
=
dims
,
padding
=
1
,
name
=
"op"
)
else
:
else
:
self
.
h_upd
=
self
.
x_upd
=
nn
.
Identity
()
self
.
h_upd
=
self
.
x_upd
=
nn
.
Identity
()
...
@@ -820,7 +794,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
...
@@ -820,7 +794,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
down
=
True
,
down
=
True
,
)
)
if
resblock_updown
if
resblock_updown
else
Downsample
(
ch
,
conv_resample
,
dims
=
dims
,
out_channels
=
out_ch
)
else
Downsample
(
ch
,
use_conv
=
conv_resample
,
dims
=
dims
,
out_channels
=
out_ch
,
padding
=
1
,
name
=
"op"
)
)
)
)
)
ch
=
out_ch
ch
=
out_ch
...
@@ -1089,7 +1065,9 @@ class EncoderUNetModel(nn.Module):
...
@@ -1089,7 +1065,9 @@ class EncoderUNetModel(nn.Module):
down
=
True
,
down
=
True
,
)
)
if
resblock_updown
if
resblock_updown
else
Downsample
(
ch
,
conv_resample
,
dims
=
dims
,
out_channels
=
out_ch
)
else
Downsample
(
ch
,
use_conv
=
conv_resample
,
dims
=
dims
,
out_channels
=
out_ch
,
padding
=
1
,
name
=
"op"
)
)
)
)
)
ch
=
out_ch
ch
=
out_ch
...
...
tests/test_layers_utils.py
View file @
a2b72faf
...
@@ -22,7 +22,7 @@ import numpy as np
...
@@ -22,7 +22,7 @@ import numpy as np
import
torch
import
torch
from
diffusers.models.embeddings
import
get_timestep_embedding
from
diffusers.models.embeddings
import
get_timestep_embedding
from
diffusers.models.resnet
import
Upsample
from
diffusers.models.resnet
import
Downsample
,
Upsample
from
diffusers.testing_utils
import
floats_tensor
,
slow
,
torch_device
from
diffusers.testing_utils
import
floats_tensor
,
slow
,
torch_device
...
@@ -164,3 +164,58 @@ class UpsampleBlockTests(unittest.TestCase):
...
@@ -164,3 +164,58 @@ class UpsampleBlockTests(unittest.TestCase):
output_slice
=
upsampled
[
0
,
-
1
,
-
3
:,
-
3
:]
output_slice
=
upsampled
[
0
,
-
1
,
-
3
:,
-
3
:]
expected_slice
=
torch
.
tensor
([
-
0.3028
,
-
0.1582
,
0.0071
,
0.0350
,
-
0.4799
,
-
0.1139
,
0.1056
,
-
0.1153
,
-
0.1046
])
expected_slice
=
torch
.
tensor
([
-
0.3028
,
-
0.1582
,
0.0071
,
0.0350
,
-
0.4799
,
-
0.1139
,
0.1056
,
-
0.1153
,
-
0.1046
])
assert
torch
.
allclose
(
output_slice
.
flatten
(),
expected_slice
,
atol
=
1e-3
)
assert
torch
.
allclose
(
output_slice
.
flatten
(),
expected_slice
,
atol
=
1e-3
)
class
DownsampleBlockTests
(
unittest
.
TestCase
):
def
test_downsample_default
(
self
):
torch
.
manual_seed
(
0
)
sample
=
torch
.
randn
(
1
,
32
,
64
,
64
)
downsample
=
Downsample
(
channels
=
32
,
use_conv
=
False
)
with
torch
.
no_grad
():
downsampled
=
downsample
(
sample
)
assert
downsampled
.
shape
==
(
1
,
32
,
32
,
32
)
output_slice
=
downsampled
[
0
,
-
1
,
-
3
:,
-
3
:]
expected_slice
=
torch
.
tensor
([
-
0.0513
,
-
0.3889
,
0.0640
,
0.0836
,
-
0.5460
,
-
0.0341
,
-
0.0169
,
-
0.6967
,
0.1179
])
max_diff
=
(
output_slice
.
flatten
()
-
expected_slice
).
abs
().
sum
().
item
()
assert
max_diff
<=
1e-3
# assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-1)
def
test_downsample_with_conv
(
self
):
torch
.
manual_seed
(
0
)
sample
=
torch
.
randn
(
1
,
32
,
64
,
64
)
downsample
=
Downsample
(
channels
=
32
,
use_conv
=
True
)
with
torch
.
no_grad
():
downsampled
=
downsample
(
sample
)
assert
downsampled
.
shape
==
(
1
,
32
,
32
,
32
)
output_slice
=
downsampled
[
0
,
-
1
,
-
3
:,
-
3
:]
expected_slice
=
torch
.
tensor
(
[
0.9267
,
0.5878
,
0.3337
,
1.2321
,
-
0.1191
,
-
0.3984
,
-
0.7532
,
-
0.0715
,
-
0.3913
],
)
assert
torch
.
allclose
(
output_slice
.
flatten
(),
expected_slice
,
atol
=
1e-3
)
def
test_downsample_with_conv_pad1
(
self
):
torch
.
manual_seed
(
0
)
sample
=
torch
.
randn
(
1
,
32
,
64
,
64
)
downsample
=
Downsample
(
channels
=
32
,
use_conv
=
True
,
padding
=
1
)
with
torch
.
no_grad
():
downsampled
=
downsample
(
sample
)
assert
downsampled
.
shape
==
(
1
,
32
,
32
,
32
)
output_slice
=
downsampled
[
0
,
-
1
,
-
3
:,
-
3
:]
expected_slice
=
torch
.
tensor
([
0.9267
,
0.5878
,
0.3337
,
1.2321
,
-
0.1191
,
-
0.3984
,
-
0.7532
,
-
0.0715
,
-
0.3913
])
assert
torch
.
allclose
(
output_slice
.
flatten
(),
expected_slice
,
atol
=
1e-3
)
def
test_downsample_with_conv_out_dim
(
self
):
torch
.
manual_seed
(
0
)
sample
=
torch
.
randn
(
1
,
32
,
64
,
64
)
downsample
=
Downsample
(
channels
=
32
,
use_conv
=
True
,
out_channels
=
16
)
with
torch
.
no_grad
():
downsampled
=
downsample
(
sample
)
assert
downsampled
.
shape
==
(
1
,
16
,
32
,
32
)
output_slice
=
downsampled
[
0
,
-
1
,
-
3
:,
-
3
:]
expected_slice
=
torch
.
tensor
([
-
0.6586
,
0.5985
,
0.0721
,
0.1256
,
-
0.1492
,
0.4436
,
-
0.2544
,
0.5021
,
1.1522
])
assert
torch
.
allclose
(
output_slice
.
flatten
(),
expected_slice
,
atol
=
1e-3
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment