Commit 3286dac6
Authored Jun 27, 2022 by anton-l

    Merge remote-tracking branch 'origin/main'

Parents: 1cf7933e, ee010726

Showing 7 changed files with 149 additions and 233 deletions
src/diffusers/models/resnet.py        +7   -87
src/diffusers/models/unet.py          +2   -15
src/diffusers/models/unet_glide.py    +4   -33
src/diffusers/models/unet_grad_tts.py +2   -10
src/diffusers/models/unet_ldm.py      +60  -86
tests/test_layers_utils.py            +51  -0
tests/test_modeling_utils.py          +23  -2
src/diffusers/models/resnet.py

import torch
import torch.nn as nn
import torch.nn.functional as F

...

@@ -29,6 +28,7 @@ def conv_nd(dims, *args, **kwargs):
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def conv_transpose_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.

...

@@ -64,7 +64,7 @@ class Upsample(nn.Module):
        upsampling occurs in the inner-two dimensions.
    """

-    def __init__(self, channels, use_conv, use_conv_transpose=False, dims=2, out_channels=None):
+    def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels

...

@@ -73,7 +73,7 @@ class Upsample(nn.Module):
        self.use_conv_transpose = use_conv_transpose

        if use_conv_transpose:
-            self.conv = conv_transpose_nd(dims, channels, out_channels, 4, 2, 1)
+            self.conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)

...

@@ -125,87 +125,7 @@ class Downsample(nn.Module):
        return self.down(x)


-class UNetUpsample(nn.Module):
-    # TODO (patil-suraj): needs test
-    def __init__(self, in_channels, with_conv):
-        super().__init__()
-        self.with_conv = with_conv
-        if self.with_conv:
-            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
-
-    def forward(self, x):
-        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
-        if self.with_conv:
-            x = self.conv(x)
-        return x
-
-
-class GlideUpsample(nn.Module):
-    """
-    An upsampling layer with an optional convolution.
-
-    :param channels: channels in the inputs and outputs.
-    :param use_conv: a bool determining if a convolution is applied.
-    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
-        upsampling occurs in the inner-two dimensions.
-    """
-
-    def __init__(self, channels, use_conv, dims=2, out_channels=None):
-        super().__init__()
-        self.channels = channels
-        self.out_channels = out_channels or channels
-        self.use_conv = use_conv
-        self.dims = dims
-        if use_conv:
-            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
-
-    def forward(self, x):
-        assert x.shape[1] == self.channels
-        if self.dims == 3:
-            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
-        else:
-            x = F.interpolate(x, scale_factor=2, mode="nearest")
-        if self.use_conv:
-            x = self.conv(x)
-        return x
-
-
-class LDMUpsample(nn.Module):
-    """
-    An upsampling layer with an optional convolution.
-
-    :param channels: channels in the inputs and outputs.
-    :param use_conv: a bool determining if a convolution is applied.
-    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
-        upsampling occurs in the inner-two dimensions.
-    """
-
-    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
-        super().__init__()
-        self.channels = channels
-        self.out_channels = out_channels or channels
-        self.use_conv = use_conv
-        self.dims = dims
-        if use_conv:
-            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
-
-    def forward(self, x):
-        assert x.shape[1] == self.channels
-        if self.dims == 3:
-            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
-        else:
-            x = F.interpolate(x, scale_factor=2, mode="nearest")
-        if self.use_conv:
-            x = self.conv(x)
-        return x
-
-
-class GradTTSUpsample(torch.nn.Module):
-    def __init__(self, dim):
-        super(Upsample, self).__init__()
-        self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)
-
-    def forward(self, x):
-        return self.conv(x)
-
-
class Upsample1d(nn.Module):
    def __init__(self, dim):
        super().__init__()

...
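Taken together, the consolidated Upsample above exposes three modes: plain nearest-neighbor interpolation, interpolation followed by a 3x3 convolution (optionally changing the channel count), and a stride-2 transposed convolution. A minimal usage sketch, mirroring the shapes exercised by the new tests this commit adds in tests/test_layers_utils.py:

import torch
from diffusers.models.resnet import Upsample

sample = torch.randn(1, 32, 32, 32)  # (batch, channels, height, width)

up_nearest = Upsample(channels=32)                               # nearest-neighbor only (use_conv now defaults to False)
up_conv = Upsample(channels=32, use_conv=True, out_channels=64)  # interpolate, then 3x3 conv to 64 channels
up_transpose = Upsample(channels=32, use_conv_transpose=True)    # transposed conv, kernel 4 / stride 2 / padding 1

with torch.no_grad():
    assert up_nearest(sample).shape == (1, 32, 64, 64)
    assert up_conv(sample).shape == (1, 64, 64, 64)
    assert up_transpose(sample).shape == (1, 32, 64, 64)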
src/diffusers/models/unet.py

...

@@ -31,6 +31,7 @@ from tqdm import tqdm
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
+from .resnet import Upsample


def nonlinearity(x):

...

@@ -42,20 +43,6 @@ def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


-class Upsample(nn.Module):
-    def __init__(self, in_channels, with_conv):
-        super().__init__()
-        self.with_conv = with_conv
-        if self.with_conv:
-            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
-
-    def forward(self, x):
-        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
-        if self.with_conv:
-            x = self.conv(x)
-        return x
-
-
class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()

...

@@ -259,7 +246,7 @@ class UNetModel(ModelMixin, ConfigMixin):
            up.block = block
            up.attn = attn
            if i_level != 0:
-                up.upsample = Upsample(block_in, resamp_with_conv)
+                up.upsample = Upsample(block_in, use_conv=resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

...
src/diffusers/models/unet_glide.py

...

@@ -8,6 +8,7 @@ import torch.nn.functional as F
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
+from .resnet import Upsample


def convert_module_to_f16(l):

...

@@ -125,36 +126,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
        return x


-class Upsample(nn.Module):
-    """
-    An upsampling layer with an optional convolution.
-
-    :param channels: channels in the inputs and outputs.
-    :param use_conv: a bool determining if a convolution is applied.
-    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
-        upsampling occurs in the inner-two dimensions.
-    """
-
-    def __init__(self, channels, use_conv, dims=2, out_channels=None):
-        super().__init__()
-        self.channels = channels
-        self.out_channels = out_channels or channels
-        self.use_conv = use_conv
-        self.dims = dims
-        if use_conv:
-            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
-
-    def forward(self, x):
-        assert x.shape[1] == self.channels
-        if self.dims == 3:
-            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
-        else:
-            x = F.interpolate(x, scale_factor=2, mode="nearest")
-        if self.use_conv:
-            x = self.conv(x)
-        return x
-
-
class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.

...

@@ -231,8 +202,8 @@ class ResBlock(TimestepBlock):
        self.updown = up or down

        if up:
-            self.h_upd = Upsample(channels, False, dims)
-            self.x_upd = Upsample(channels, False, dims)
+            self.h_upd = Upsample(channels, use_conv=False, dims=dims)
+            self.x_upd = Upsample(channels, use_conv=False, dims=dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)

...

@@ -567,7 +538,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                        up=True,
                    )
                    if resblock_updown
-                    else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+                    else Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch)
                )
                ds //= 2
            self.output_blocks.append(TimestepEmbedSequential(*layers))

...
src/diffusers/models/unet_grad_tts.py

...

@@ -10,6 +10,7 @@ except:
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
+from .resnet import Upsample


class Mish(torch.nn.Module):

...

@@ -17,15 +18,6 @@ class Mish(torch.nn.Module):
        return x * torch.tanh(torch.nn.functional.softplus(x))


-class Upsample(torch.nn.Module):
-    def __init__(self, dim):
-        super(Upsample, self).__init__()
-        self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)
-
-    def forward(self, x):
-        return self.conv(x)
-
-
class Downsample(torch.nn.Module):
    def __init__(self, dim):
        super(Downsample, self).__init__()

...

@@ -166,7 +158,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
                        ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
                        ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
                        Residual(Rezero(LinearAttention(dim_in))),
-                        Upsample(dim_in),
+                        Upsample(dim_in, use_conv_transpose=True),
                    ]
                )
            )

...
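For reference, the deleted GradTTS-local Upsample was a bare ConvTranspose2d(dim, dim, 4, 2, 1), and with use_conv_transpose=True the shared resnet.Upsample builds the same 4/2/1 transposed convolution through conv_transpose_nd. A minimal shape-equivalence sketch (assuming the default dims=2):

import torch
from diffusers.models.resnet import Upsample

dim_in = 16
x = torch.randn(1, dim_in, 8, 8)

old_style = torch.nn.ConvTranspose2d(dim_in, dim_in, 4, 2, 1)  # what the deleted class built
new_style = Upsample(dim_in, use_conv_transpose=True)          # what the shared module now builds

# Both double the spatial resolution: kernel 4, stride 2, padding 1.
assert old_style(x).shape == new_style(x).shape == (1, dim_in, 16, 16)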
src/diffusers/models/unet_ldm.py

...

@@ -17,6 +17,7 @@ except:
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
+from .resnet import Upsample


def exists(val):

...

@@ -81,60 +82,62 @@ def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


-class LinearAttention(nn.Module):
-    def __init__(self, dim, heads=4, dim_head=32):
-        super().__init__()
-        self.heads = heads
-        hidden_dim = dim_head * heads
-        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
-        self.to_out = nn.Conv2d(hidden_dim, dim, 1)
-
-    def forward(self, x):
-        b, c, h, w = x.shape
-        qkv = self.to_qkv(x)
-        q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
-        k = k.softmax(dim=-1)
-        context = torch.einsum("bhdn,bhen->bhde", k, v)
-        out = torch.einsum("bhde,bhdn->bhen", context, q)
-        out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
-        return self.to_out(out)
+# class LinearAttention(nn.Module):
+#     def __init__(self, dim, heads=4, dim_head=32):
+#         super().__init__()
+#         self.heads = heads
+#         hidden_dim = dim_head * heads
+#         self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
+#         self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+#
+#     def forward(self, x):
+#         b, c, h, w = x.shape
+#         qkv = self.to_qkv(x)
+#         q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
+#         import ipdb; ipdb.set_trace()
+#         k = k.softmax(dim=-1)
+#         context = torch.einsum("bhdn,bhen->bhde", k, v)
+#         out = torch.einsum("bhde,bhdn->bhen", context, q)
+#         out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
+#         return self.to_out(out)


-class SpatialSelfAttention(nn.Module):
-    def __init__(self, in_channels):
-        super().__init__()
-        self.in_channels = in_channels
-
-        self.norm = Normalize(in_channels)
-        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-
-    def forward(self, x):
-        h_ = x
-        h_ = self.norm(h_)
-        q = self.q(h_)
-        k = self.k(h_)
-        v = self.v(h_)
-
-        # compute attention
-        b, c, h, w = q.shape
-        q = rearrange(q, "b c h w -> b (h w) c")
-        k = rearrange(k, "b c h w -> b c (h w)")
-        w_ = torch.einsum("bij,bjk->bik", q, k)
-
-        w_ = w_ * (int(c) ** (-0.5))
-        w_ = torch.nn.functional.softmax(w_, dim=2)
-
-        # attend to values
-        v = rearrange(v, "b c h w -> b c (h w)")
-        w_ = rearrange(w_, "b i j -> b j i")
-        h_ = torch.einsum("bij,bjk->bik", v, w_)
-        h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
-        h_ = self.proj_out(h_)
-
-        return x + h_
+# class SpatialSelfAttention(nn.Module):
+#     def __init__(self, in_channels):
+#         super().__init__()
+#         self.in_channels = in_channels
+#
+#         self.norm = Normalize(in_channels)
+#         self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+#         self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+#         self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+#         self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+#
+#     def forward(self, x):
+#         h_ = x
+#         h_ = self.norm(h_)
+#         q = self.q(h_)
+#         k = self.k(h_)
+#         v = self.v(h_)
+#
+#         # compute attention
+#         b, c, h, w = q.shape
+#         q = rearrange(q, "b c h w -> b (h w) c")
+#         k = rearrange(k, "b c h w -> b c (h w)")
+#         w_ = torch.einsum("bij,bjk->bik", q, k)
+#
+#         w_ = w_ * (int(c) ** (-0.5))
+#         w_ = torch.nn.functional.softmax(w_, dim=2)
+#
+#         # attend to values
+#         v = rearrange(v, "b c h w -> b c (h w)")
+#         w_ = rearrange(w_, "b i j -> b j i")
+#         h_ = torch.einsum("bij,bjk->bik", v, w_)
+#         h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
+#         h_ = self.proj_out(h_)
+#
+#         return x + h_


class CrossAttention(nn.Module):

...

@@ -377,35 +380,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
        return x


-class Upsample(nn.Module):
-    """
-    An upsampling layer with an optional convolution.
-
-    :param channels: channels in the inputs and outputs.
-    :param use_conv: a bool determining if a convolution is applied.
-    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
-        upsampling occurs in the inner-two dimensions.
-    """
-
-    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
-        super().__init__()
-        self.channels = channels
-        self.out_channels = out_channels or channels
-        self.use_conv = use_conv
-        self.dims = dims
-        if use_conv:
-            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
-
-    def forward(self, x):
-        assert x.shape[1] == self.channels
-        if self.dims == 3:
-            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
-        else:
-            x = F.interpolate(x, scale_factor=2, mode="nearest")
-        if self.use_conv:
-            x = self.conv(x)
-        return x
-
-
class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.

...

@@ -480,8 +454,8 @@ class ResBlock(TimestepBlock):
        self.updown = up or down

        if up:
-            self.h_upd = Upsample(channels, False, dims)
-            self.x_upd = Upsample(channels, False, dims)
+            self.h_upd = Upsample(channels, use_conv=False, dims=dims)
+            self.x_upd = Upsample(channels, use_conv=False, dims=dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)

...

@@ -948,7 +922,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                        up=True,
                    )
                    if resblock_updown
-                    else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+                    else Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch)
                )
                ds //= 2
            self.output_blocks.append(TimestepEmbedSequential(*layers))

...
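A note on the LinearAttention block commented out above: it softmax-normalizes the keys over spatial positions and contracts keys with values before touching the queries, so the cost is linear rather than quadratic in the number of positions. A standalone sketch of that einsum flow, with assumed toy sizes:

import torch

b, heads, c, n = 1, 4, 32, 64  # batch, heads, channels per head, spatial positions (h * w)
q = torch.randn(b, heads, c, n)
k = torch.randn(b, heads, c, n).softmax(dim=-1)  # normalize keys over positions
v = torch.randn(b, heads, c, n)

# Contract keys with values first: O(n * c^2) work instead of O(n^2 * c).
context = torch.einsum("bhdn,bhen->bhde", k, v)    # (b, heads, c, c)
out = torch.einsum("bhde,bhdn->bhen", context, q)  # (b, heads, c, n)
assert out.shape == (b, heads, c, n)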
tests/test_layers_utils.py

...

@@ -22,6 +22,7 @@ import numpy as np
import torch

from diffusers.models.embeddings import get_timestep_embedding
+from diffusers.models.resnet import Upsample
from diffusers.testing_utils import floats_tensor, slow, torch_device

...

@@ -113,3 +114,53 @@ class EmbeddingsTests(unittest.TestCase):
                torch.tensor([-0.9801, -0.9464, -0.9349, -0.3952, 0.8887, -0.9709, 0.5299, -0.2853, -0.9927]),
                1e-3,
            )
+
+
+class UpsampleBlockTests(unittest.TestCase):
+    def test_upsample_default(self):
+        torch.manual_seed(0)
+        sample = torch.randn(1, 32, 32, 32)
+        upsample = Upsample(channels=32, use_conv=False)
+        with torch.no_grad():
+            upsampled = upsample(sample)
+
+        assert upsampled.shape == (1, 32, 64, 64)
+        output_slice = upsampled[0, -1, -3:, -3:]
+        expected_slice = torch.tensor([-0.2173, -1.2079, -1.2079, 0.2952, 1.1254, 1.1254, 0.2952, 1.1254, 1.1254])
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+    def test_upsample_with_conv(self):
+        torch.manual_seed(0)
+        sample = torch.randn(1, 32, 32, 32)
+        upsample = Upsample(channels=32, use_conv=True)
+        with torch.no_grad():
+            upsampled = upsample(sample)
+
+        assert upsampled.shape == (1, 32, 64, 64)
+        output_slice = upsampled[0, -1, -3:, -3:]
+        expected_slice = torch.tensor([0.7145, 1.3773, 0.3492, 0.8448, 1.0839, -0.3341, 0.5956, 0.1250, -0.4841])
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+    def test_upsample_with_conv_out_dim(self):
+        torch.manual_seed(0)
+        sample = torch.randn(1, 32, 32, 32)
+        upsample = Upsample(channels=32, use_conv=True, out_channels=64)
+        with torch.no_grad():
+            upsampled = upsample(sample)
+
+        assert upsampled.shape == (1, 64, 64, 64)
+        output_slice = upsampled[0, -1, -3:, -3:]
+        expected_slice = torch.tensor([0.2703, 0.1656, -0.2538, -0.0553, -0.2984, 0.1044, 0.1155, 0.2579, 0.7755])
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+    def test_upsample_with_transpose(self):
+        torch.manual_seed(0)
+        sample = torch.randn(1, 32, 32, 32)
+        upsample = Upsample(channels=32, use_conv=False, use_conv_transpose=True)
+        with torch.no_grad():
+            upsampled = upsample(sample)
+
+        assert upsampled.shape == (1, 32, 64, 64)
+        output_slice = upsampled[0, -1, -3:, -3:]
+        expected_slice = torch.tensor([-0.3028, -0.1582, 0.0071, 0.0350, -0.4799, -0.1139, 0.1056, -0.1153, -0.1046])
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
tests/test_modeling_utils.py

...

@@ -21,7 +21,7 @@ import unittest
import numpy as np
import torch

-from diffusers import (
+from diffusers import (  # GradTTSPipeline,
    BDDMPipeline,
    DDIMPipeline,
    DDIMScheduler,

...

@@ -30,7 +30,6 @@ from diffusers import (
    GlidePipeline,
    GlideSuperResUNetModel,
    GlideTextToImageUNetModel,
-    GradTTSPipeline,
    GradTTSScheduler,
    LatentDiffusionPipeline,
    NCSNpp,

...

@@ -511,6 +510,28 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))

+    def test_output_pretrained_spatial_transformer(self):
+        model = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy-spatial")
+        model.eval()
+
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+
+        noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
+        context = torch.ones((1, 16, 64), dtype=torch.float32)
+        time_step = torch.tensor([10] * noise.shape[0])
+
+        with torch.no_grad():
+            output = model(noise, time_step, context=context)
+
+        output_slice = output[0, -1, -3:, -3:].flatten()
+        # fmt: off
+        expected_output_slice = torch.tensor([61.3445, 56.9005, 29.4339, 59.5497, 60.7375, 34.1719, 48.1951, 42.6569, 25.0890])
+        # fmt: on
+
+        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
+

class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
    model_class = UNetGradTTSModel

...