Commit 23904d54, authored Jul 01, 2022 by patil-suraj

Merge branch 'main' of https://github.com/huggingface/diffusers into conversion-scripts

Parents: 32b93da8 c691bb2f
Changes: 9 changed files with 476 additions and 1334 deletions (+476, -1334).
setup.py (+1, -1)
src/diffusers/dependency_versions_table.py (+1, -0)
src/diffusers/models/resnet.py (+153, -276)
src/diffusers/models/unet.py (+0, -42)
src/diffusers/models/unet_glide.py (+53, -60)
src/diffusers/models/unet_ldm.py (+141, -507)
src/diffusers/models/unet_rl.py (+3, -21)
src/diffusers/models/unet_sde_score_estimation.py (+104, -405)
tests/test_modeling_utils.py (+20, -22)
setup.py

...
@@ -88,7 +88,7 @@ _deps = [
     "requests",
     "torch>=1.4",
     "tensorboard",
-    "modelcards=0.1.4"
+    "modelcards==0.1.4"
 ]
 # this is a lookup table with items like:
...
src/diffusers/dependency_versions_table.py

...
@@ -14,4 +14,5 @@ deps = {
     "requests": "requests",
     "torch": "torch>=1.4",
     "tensorboard": "tensorboard",
+    "modelcards": "modelcards==0.1.4",
 }
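The generated table above mirrors the pinned `_deps` list in setup.py. A minimal sketch (not the repository's exact code; the `re.split` helper is an illustrative assumption) of how a name-to-specifier lookup like `deps` can be derived from such a list:

import re

# Illustrative only: split each specifier at the first version operator to get
# the bare package name, producing entries like {"modelcards": "modelcards==0.1.4"}.
_deps = ["requests", "torch>=1.4", "tensorboard", "modelcards==0.1.4"]
deps = {re.split(r"[!=<>~]", spec, maxsplit=1)[0]: spec for spec in _deps}
print(deps["torch"])  # torch>=1.4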
src/diffusers/models/resnet.py

import string
from abc import abstractmethod
from functools import partial

import numpy as np
import torch
...
@@ -79,18 +79,25 @@ class Upsample(nn.Module):
         upsampling occurs in the inner-two dimensions.
     """

-    def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None):
+    def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None, name="conv"):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
         self.use_conv = use_conv
         self.dims = dims
         self.use_conv_transpose = use_conv_transpose
+        self.name = name

+        conv = None
         if use_conv_transpose:
-            self.conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
+            conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
         elif use_conv:
-            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
+            conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
+
+        if name == "conv":
+            self.conv = conv
+        else:
+            self.Conv2d_0 = conv

     def forward(self, x):
         assert x.shape[1] == self.channels
...
@@ -103,7 +110,10 @@ class Upsample(nn.Module):
             x = F.interpolate(x, scale_factor=2.0, mode="nearest")

         if self.use_conv:
-            x = self.conv(x)
+            if self.name == "conv":
+                x = self.conv(x)
+            else:
+                x = self.Conv2d_0(x)

         return x
...
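The new `name` argument appears to control only which attribute the convolution is registered under, so converted checkpoints keep their original state_dict keys ("Conv2d_0" in the score-SDE models, "conv" elsewhere). A minimal runnable sketch with a simplified stand-in module (TinyUpsample is hypothetical, not part of the diff):

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyUpsample(nn.Module):
    # Simplified stand-in for the Upsample above: `name` decides whether the
    # conv's parameters appear as "conv.*" or "Conv2d_0.*" in the state dict.
    def __init__(self, channels, name="conv"):
        super().__init__()
        self.name = name
        conv = nn.Conv2d(channels, channels, 3, padding=1)
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(x) if self.name == "conv" else self.Conv2d_0(x)

print(list(TinyUpsample(4, name="Conv2d_0").state_dict()))  # ['Conv2d_0.weight', 'Conv2d_0.bias']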
@@ -135,6 +145,8 @@ class Downsample(nn.Module):
         if name == "conv":
             self.conv = conv
+        elif name == "Conv2d_0":
+            self.Conv2d_0 = conv
         else:
             self.op = conv
...
@@ -146,6 +158,8 @@ class Downsample(nn.Module):
         if self.name == "conv":
             return self.conv(x)
+        elif self.name == "Conv2d_0":
+            return self.Conv2d_0(x)
         else:
             return self.op(x)
...
@@ -162,110 +176,7 @@ class Downsample(nn.Module):
-# RESNETS
-
-# unet_glide.py & unet_ldm.py
-class ResBlock(TimestepBlock):
-    """
-    A residual block that can optionally change the number of channels.
-
-    :param channels: the number of input channels. :param emb_channels: the number of timestep embedding channels.
-    :param dropout: the rate of dropout. :param out_channels: if specified, the number of out channels. :param
-    use_conv: if True and out_channels is specified, use a spatial
-        convolution instead of a smaller 1x1 convolution to change the channels in the skip connection.
-    :param dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing
-    on this module. :param up: if True, use this block for upsampling. :param down: if True, use this block for
-    downsampling.
-    """
-
-    def __init__(
-        self,
-        channels,
-        emb_channels,
-        dropout,
-        out_channels=None,
-        use_conv=False,
-        use_scale_shift_norm=False,
-        dims=2,
-        use_checkpoint=False,
-        up=False,
-        down=False,
-    ):
-        super().__init__()
-        self.channels = channels
-        self.emb_channels = emb_channels
-        self.dropout = dropout
-        self.out_channels = out_channels or channels
-        self.use_conv = use_conv
-        self.use_checkpoint = use_checkpoint
-        self.use_scale_shift_norm = use_scale_shift_norm
-
-        self.in_layers = nn.Sequential(
-            normalization(channels, swish=1.0),
-            nn.Identity(),
-            conv_nd(dims, channels, self.out_channels, 3, padding=1),
-        )
-
-        self.updown = up or down
-
-        if up:
-            self.h_upd = Upsample(channels, use_conv=False, dims=dims)
-            self.x_upd = Upsample(channels, use_conv=False, dims=dims)
-        elif down:
-            self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
-            self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
-        else:
-            self.h_upd = self.x_upd = nn.Identity()
-
-        self.emb_layers = nn.Sequential(
-            nn.SiLU(),
-            linear(
-                emb_channels,
-                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
-            ),
-        )
-        self.out_layers = nn.Sequential(
-            normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
-            nn.SiLU() if use_scale_shift_norm else nn.Identity(),
-            nn.Dropout(p=dropout),
-            zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
-        )
-
-        if self.out_channels == channels:
-            self.skip_connection = nn.Identity()
-        elif use_conv:
-            self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
-        else:
-            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
-
-    def forward(self, x, emb):
-        """
-        Apply the block to a Tensor, conditioned on a timestep embedding.
-
-        :param x: an [N x C x ...] Tensor of features. :param emb: an [N x emb_channels] Tensor of timestep embeddings.
-        :return: an [N x C x ...] Tensor of outputs.
-        """
-        if self.updown:
-            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
-            h = in_rest(x)
-            h = self.h_upd(h)
-            x = self.x_upd(x)
-            h = in_conv(h)
-        else:
-            h = self.in_layers(x)
-        emb_out = self.emb_layers(emb).type(h.dtype)
-        while len(emb_out.shape) < len(h.shape):
-            emb_out = emb_out[..., None]
-        if self.use_scale_shift_norm:
-            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
-            scale, shift = torch.chunk(emb_out, 2, dim=1)
-            h = out_norm(h) * (1 + scale) + shift
-            h = out_rest(h)
-        else:
-            h = h + emb_out
-            h = self.out_layers(h)
-        return self.skip_connection(x) + h
-
-
-# unet.py and unet_grad_tts.py
+# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py
 class ResnetBlock(nn.Module):
     def __init__(
         self,
...
@@ -279,7 +190,12 @@ class ResnetBlock(nn.Module):
         pre_norm=True,
         eps=1e-6,
         non_linearity="swish",
+        time_embedding_norm="default",
+        up=False,
+        down=False,
         overwrite_for_grad_tts=False,
         overwrite_for_ldm=False,
+        overwrite_for_glide=False,
     ):
         super().__init__()
         self.pre_norm = pre_norm
...
@@ -287,6 +203,9 @@ class ResnetBlock(nn.Module):
         out_channels = in_channels if out_channels is None else out_channels
         self.out_channels = out_channels
         self.use_conv_shortcut = conv_shortcut
+        self.time_embedding_norm = time_embedding_norm
+        self.up = up
+        self.down = down

         if self.pre_norm:
             self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
...
@@ -294,23 +213,38 @@ class ResnetBlock(nn.Module):
             self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)

         self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

-        self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+        if time_embedding_norm == "default":
+            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+        elif time_embedding_norm == "scale_shift":
+            self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)

         self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
         self.dropout = torch.nn.Dropout(dropout)
         self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

         if non_linearity == "swish":
             self.nonlinearity = nonlinearity
         elif non_linearity == "mish":
             self.nonlinearity = Mish()
         elif non_linearity == "silu":
             self.nonlinearity = nn.SiLU()

+        if up:
+            self.h_upd = Upsample(in_channels, use_conv=False, dims=2)
+            self.x_upd = Upsample(in_channels, use_conv=False, dims=2)
+        elif down:
+            self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
+            self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")

         if self.in_channels != self.out_channels:
             if self.use_conv_shortcut:
                 self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
             else:
                 self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

         # TODO(SURAJ, PATRICK): ALL OF THE FOLLOWING OF THE INIT METHOD CAN BE DELETED ONCE WEIGHTS ARE CONVERTED
         self.is_overwritten = False
+        self.overwrite_for_glide = overwrite_for_glide
         self.overwrite_for_grad_tts = overwrite_for_grad_tts
-        self.overwrite_for_ldm = overwrite_for_ldm
+        self.overwrite_for_ldm = overwrite_for_ldm or overwrite_for_glide

         if self.overwrite_for_grad_tts:
             dim = in_channels
             dim_out = out_channels
...
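Context for the branch above: with time_embedding_norm="scale_shift" the projection doubles its width so the embedding can later be split into a per-channel scale and shift. A minimal self-contained sketch of that split (shapes chosen for illustration):

import torch

# With "scale_shift", temb_proj maps temb_channels -> 2 * out_channels so the
# projected embedding can be chunked into scale and shift tensors.
temb_channels, out_channels = 512, 128
temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)

temb = torch.randn(4, temb_channels)
scale, shift = torch.chunk(temb_proj(temb)[:, :, None, None], 2, dim=1)
print(scale.shape, shift.shape)  # torch.Size([4, 128, 1, 1]) twice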
@@ -324,6 +258,37 @@ class ResnetBlock(nn.Module):
                 self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
             else:
                 self.res_conv = torch.nn.Identity()
+        elif self.overwrite_for_ldm:
+            dims = 2
+            # eps = 1e-5
+            # non_linearity = "silu"
+            # overwrite_for_ldm
+            channels = in_channels
+            emb_channels = temb_channels
+            use_scale_shift_norm = False
+
+            self.in_layers = nn.Sequential(
+                normalization(channels, swish=1.0),
+                nn.Identity(),
+                conv_nd(dims, channels, self.out_channels, 3, padding=1),
+            )
+            self.emb_layers = nn.Sequential(
+                nn.SiLU(),
+                linear(
+                    emb_channels,
+                    2 * self.out_channels if self.time_embedding_norm == "scale_shift" else self.out_channels,
+                ),
+            )
+            self.out_layers = nn.Sequential(
+                normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
+                nn.SiLU() if use_scale_shift_norm else nn.Identity(),
+                nn.Dropout(p=dropout),
+                zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
+            )
+
+            if self.out_channels == in_channels:
+                self.skip_connection = nn.Identity()
+            else:
+                self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

     def set_weights_grad_tts(self):
         self.conv1.weight.data = self.block1.block[0].weight.data
...
@@ -343,30 +308,67 @@ class ResnetBlock(nn.Module):
             self.nin_shortcut.weight.data = self.res_conv.weight.data
             self.nin_shortcut.bias.data = self.res_conv.bias.data

-    def forward(self, x, temb, mask=None):
+    def set_weights_ldm(self):
+        self.norm1.weight.data = self.in_layers[0].weight.data
+        self.norm1.bias.data = self.in_layers[0].bias.data
+
+        self.conv1.weight.data = self.in_layers[-1].weight.data
+        self.conv1.bias.data = self.in_layers[-1].bias.data
+
+        self.temb_proj.weight.data = self.emb_layers[-1].weight.data
+        self.temb_proj.bias.data = self.emb_layers[-1].bias.data
+
+        self.norm2.weight.data = self.out_layers[0].weight.data
+        self.norm2.bias.data = self.out_layers[0].bias.data
+
+        self.conv2.weight.data = self.out_layers[-1].weight.data
+        self.conv2.bias.data = self.out_layers[-1].bias.data
+
+        if self.in_channels != self.out_channels:
+            self.nin_shortcut.weight.data = self.skip_connection.weight.data
+            self.nin_shortcut.bias.data = self.skip_connection.bias.data
+
+    def forward(self, x, temb, mask=1.0):
+        # TODO(Patrick) eventually this class should be split into multiple classes
+        # too many if else statements
         if self.overwrite_for_grad_tts and not self.is_overwritten:
             self.set_weights_grad_tts()
             self.is_overwritten = True
+        elif self.overwrite_for_ldm and not self.is_overwritten:
+            self.set_weights_ldm()
+            self.is_overwritten = True

         h = x
-        h = h * mask if mask is not None else h
+        h = h * mask
         if self.pre_norm:
             h = self.norm1(h)
             h = self.nonlinearity(h)

+        if self.up or self.down:
+            x = self.x_upd(x)
+            h = self.h_upd(h)
+
         h = self.conv1(h)

         if not self.pre_norm:
             h = self.norm1(h)
             h = self.nonlinearity(h)
-        h = h * mask if mask is not None else h
+        h = h * mask

-        h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
-        h = h * mask if mask is not None else h
-        if self.pre_norm:
-            h = self.norm2(h)
-            h = self.nonlinearity(h)
+        temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
+
+        if self.time_embedding_norm == "scale_shift":
+            scale, shift = torch.chunk(temb, 2, dim=1)
+
+            h = h * mask
+            if self.pre_norm:
+                h = self.norm2(h)
+
+            h = h + h * scale + shift
+            h = self.nonlinearity(h)
+        elif self.time_embedding_norm == "default":
+            h = h + temb
+            h = h * mask
+            if self.pre_norm:
+                h = self.norm2(h)
+            h = self.nonlinearity(h)

         h = self.dropout(h)
         h = self.conv2(h)
...
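The set_weights_ldm/set_weights_grad_tts pattern above ports parameters from modules laid out like the original checkpoint into the unified block on first use. A minimal sketch of the copying idea under the same assumptions (matching shapes on both sides):

import torch.nn as nn

# src stands in for a checkpoint-layout module (e.g. in_layers[-1]); dst for
# the unified block's module (e.g. conv1). Copying .data aliases the tensors,
# so the unified block reproduces the pretrained weights exactly.
src = nn.Conv2d(4, 4, 3, padding=1)
dst = nn.Conv2d(4, 4, 3, padding=1)

dst.weight.data = src.weight.data
dst.bias.data = src.bias.data
assert (dst.weight == src.weight).all()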
@@ -374,14 +376,11 @@ class ResnetBlock(nn.Module):
         if not self.pre_norm:
             h = self.norm2(h)
             h = self.nonlinearity(h)
-        h = h * mask if mask is not None else h
+        h = h * mask

-        x = x * mask if mask is not None else x
+        x = x * mask
         if self.in_channels != self.out_channels:
             if self.use_conv_shortcut:
                 x = self.conv_shortcut(x)
             else:
                 x = self.nin_shortcut(x)

         return x + h
...
@@ -394,10 +393,6 @@ class Block(torch.nn.Module):
             torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
         )

     def forward(self, x, mask):
         output = self.block(x * mask)
         return output * mask


 # unet_score_estimation.py
 class ResnetBlockBigGANpp(nn.Module):
...
@@ -424,17 +419,29 @@ class ResnetBlockBigGANpp(nn.Module):
         self.fir = fir
         self.fir_kernel = fir_kernel

-        self.Conv_0 = conv3x3(in_ch, out_ch)
+        if self.up:
+            if self.fir:
+                self.upsample = partial(upsample_2d, k=self.fir_kernel, factor=2)
+            else:
+                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
+        elif self.down:
+            if self.fir:
+                self.downsample = partial(downsample_2d, k=self.fir_kernel, factor=2)
+            else:
+                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
+
+        self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
         if temb_dim is not None:
             self.Dense_0 = nn.Linear(temb_dim, out_ch)
-            self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
+            self.Dense_0.weight.data = variance_scaling()(self.Dense_0.weight.shape)
             nn.init.zeros_(self.Dense_0.bias)

         self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
         self.Dropout_0 = nn.Dropout(dropout)
-        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
+        self.Conv_1 = conv2d(out_ch, out_ch, init_scale=init_scale, kernel_size=3, padding=1)
         if in_ch != out_ch or up or down:
-            self.Conv_2 = conv1x1(in_ch, out_ch)
+            # 1x1 convolution with DDPM initialization.
+            self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)

         self.skip_rescale = skip_rescale
         self.act = act
...
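The refactor binds the resampling op once in __init__ with functools.partial so forward() no longer branches on `fir`. A small self-contained sketch of that pattern:

import torch
import torch.nn.functional as F
from functools import partial

# The op and its arguments are frozen up front; callers just apply the result.
upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)

x = torch.randn(1, 3, 8, 8)
print(upsample(x).shape)    # torch.Size([1, 3, 16, 16])
print(downsample(x).shape)  # torch.Size([1, 3, 4, 4])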
@@ -445,19 +452,11 @@ class ResnetBlockBigGANpp(nn.Module):
         h = self.act(self.GroupNorm_0(x))

         if self.up:
-            if self.fir:
-                h = upsample_2d(h, self.fir_kernel, factor=2)
-                x = upsample_2d(x, self.fir_kernel, factor=2)
-            else:
-                h = naive_upsample_2d(h, factor=2)
-                x = naive_upsample_2d(x, factor=2)
+            h = self.upsample(h)
+            x = self.upsample(x)
         elif self.down:
-            if self.fir:
-                h = downsample_2d(h, self.fir_kernel, factor=2)
-                x = downsample_2d(x, self.fir_kernel, factor=2)
-            else:
-                h = naive_downsample_2d(h, factor=2)
-                x = naive_downsample_2d(x, factor=2)
+            h = self.downsample(h)
+            x = self.downsample(x)

         h = self.Conv_0(h)
         # Add bias to each feature map conditioned on the time embedding
...
@@ -476,62 +475,6 @@ class ResnetBlockBigGANpp(nn.Module):
             return (x + h) / np.sqrt(2.0)


-# unet_score_estimation.py
-class ResnetBlockDDPMpp(nn.Module):
-    """ResBlock adapted from DDPM."""
-
-    def __init__(
-        self,
-        act,
-        in_ch,
-        out_ch=None,
-        temb_dim=None,
-        conv_shortcut=False,
-        dropout=0.1,
-        skip_rescale=False,
-        init_scale=0.0,
-    ):
-        super().__init__()
-        out_ch = out_ch if out_ch else in_ch
-        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
-        self.Conv_0 = conv3x3(in_ch, out_ch)
-        if temb_dim is not None:
-            self.Dense_0 = nn.Linear(temb_dim, out_ch)
-            self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
-            nn.init.zeros_(self.Dense_0.bias)
-
-        self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
-        self.Dropout_0 = nn.Dropout(dropout)
-        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
-        if in_ch != out_ch:
-            if conv_shortcut:
-                self.Conv_2 = conv3x3(in_ch, out_ch)
-            else:
-                self.NIN_0 = NIN(in_ch, out_ch)
-
-        self.skip_rescale = skip_rescale
-        self.act = act
-        self.out_ch = out_ch
-        self.conv_shortcut = conv_shortcut
-
-    def forward(self, x, temb=None):
-        h = self.act(self.GroupNorm_0(x))
-        h = self.Conv_0(h)
-        if temb is not None:
-            h += self.Dense_0(self.act(temb))[:, :, None, None]
-        h = self.act(self.GroupNorm_1(h))
-        h = self.Dropout_0(h)
-        h = self.Conv_1(h)
-        if x.shape[1] != self.out_ch:
-            if self.conv_shortcut:
-                x = self.Conv_2(x)
-            else:
-                x = self.NIN_0(x)
-        if not self.skip_rescale:
-            return x + h
-        else:
-            return (x + h) / np.sqrt(2.0)
-
-
 # unet_rl.py
 class ResidualTemporalBlock(nn.Module):
     def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
...
@@ -649,32 +592,17 @@ class RearrangeDim(nn.Module):
         raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")


-def conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):
-    """1x1 convolution with DDPM initialization."""
-    conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
-    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
+def conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
+    """nXn convolution with DDPM initialization."""
+    conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
+    conv.weight.data = variance_scaling(init_scale)(conv.weight.data.shape)
     nn.init.zeros_(conv.bias)
     return conv


-def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):
-    """3x3 convolution with DDPM initialization."""
-    conv = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias)
-    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
-    nn.init.zeros_(conv.bias)
-    return conv
-
-
-def default_init(scale=1.0):
-    """The same initialization used in DDPM."""
-    scale = 1e-10 if scale == 0 else scale
-    return variance_scaling(scale, "fan_avg", "uniform")
-
-
-def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
+def variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
     """Ported from JAX."""
+    scale = 1e-10 if scale == 0 else scale

     def _compute_fans(shape, in_axis=1, out_axis=0):
         receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
...
@@ -684,21 +612,9 @@ def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=tor
     def init(shape, dtype=dtype, device=device):
         fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
-        if mode == "fan_in":
-            denominator = fan_in
-        elif mode == "fan_out":
-            denominator = fan_out
-        elif mode == "fan_avg":
-            denominator = (fan_in + fan_out) / 2
-        else:
-            raise ValueError("invalid mode for variance scaling initializer: {}".format(mode))
+        denominator = (fan_in + fan_out) / 2
         variance = scale / denominator
-        if distribution == "normal":
-            return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
-        elif distribution == "uniform":
-            return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
-        else:
-            raise ValueError("invalid distribution for variance scaling initializer")
+        return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)

     return init
...
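The simplified initializer hard-codes the old "fan_avg"/"uniform" path: uniform samples on [-a, a] with a = sqrt(3 * variance) have variance exactly `scale / ((fan_in + fan_out) / 2)`. A quick numerical check, assuming the same (fan_out, fan_in, kh, kw) weight layout:

import numpy as np
import torch

shape = (64, 32, 3, 3)
receptive = np.prod(shape) / shape[1] / shape[0]        # kh * kw
fan_in, fan_out = shape[1] * receptive, shape[0] * receptive
variance = 1.0 / ((fan_in + fan_out) / 2)

w = (torch.rand(*shape) * 2.0 - 1.0) * np.sqrt(3 * variance)
print(float(w.var()), variance)  # empirical variance should sit near the target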
@@ -796,31 +712,6 @@ def downsample_2d(x, k=None, factor=2, gain=1):
     return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))


-def naive_upsample_2d(x, factor=2):
-    _N, C, H, W = x.shape
-    x = torch.reshape(x, (-1, C, H, 1, W, 1))
-    x = x.repeat(1, 1, 1, factor, 1, factor)
-    return torch.reshape(x, (-1, C, H * factor, W * factor))
-
-
-def naive_downsample_2d(x, factor=2):
-    _N, C, H, W = x.shape
-    x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor))
-    return torch.mean(x, dim=(3, 5))
-
-
-class NIN(nn.Module):
-    def __init__(self, in_dim, num_units, init_scale=0.1):
-        super().__init__()
-        self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)
-        self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
-
-    def forward(self, x):
-        x = x.permute(0, 2, 3, 1)
-        y = contract_inner(x, self.W) + self.b
-        return y.permute(0, 3, 1, 2)
-
-
 def _setup_kernel(k):
     k = np.asarray(k, dtype=np.float32)
     if k.ndim == 1:
...
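The removed naive_* helpers are nearest-neighbor upsampling and box-average downsampling expressed purely with reshape. A self-contained round-trip check of that equivalence:

import torch

# 2x nearest upsample via reshape/repeat, then 2x box-average downsample via
# reshape/mean; averaging the replicated 2x2 blocks recovers the original.
x = torch.arange(16.0).reshape(1, 1, 4, 4)
up = x.reshape(-1, 1, 4, 1, 4, 1).repeat(1, 1, 1, 2, 1, 2).reshape(-1, 1, 8, 8)
down = torch.mean(up.reshape(-1, 1, 4, 2, 4, 2), dim=(3, 5))
print(torch.equal(down, x))  # True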
@@ -829,17 +720,3 @@ def _setup_kernel(k):
     assert k.ndim == 2
     assert k.shape[0] == k.shape[1]
     return k


-def contract_inner(x, y):
-    """tensordot(x, y, 1)."""
-    x_chars = list(string.ascii_lowercase[: len(x.shape)])
-    y_chars = list(string.ascii_lowercase[len(x.shape) : len(y.shape) + len(x.shape)])
-    y_chars[0] = x_chars[-1]  # first axis of y and last of x get summed
-    out_chars = x_chars[:-1] + y_chars[1:]
-    return _einsum(x_chars, y_chars, out_chars, x, y)
-
-
-def _einsum(a, b, c, x, y):
-    einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
-    return torch.einsum(einsum_str, x, y)
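As its docstring says, the removed contract_inner built an einsum equivalent to tensordot(x, y, 1): the last axis of x is contracted with the first axis of y. A quick demonstration of the equivalence:

import torch

x = torch.randn(2, 5, 7, 4)
y = torch.randn(4, 3)

out_einsum = torch.einsum("abcd,de->abce", x, y)
out_tensordot = torch.tensordot(x, y, dims=1)
print(torch.allclose(out_einsum, out_tensordot))  # True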
src/diffusers/models/unet.py

...
@@ -34,48 +34,6 @@ def Normalize(in_channels):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


-# class ResnetBlock(nn.Module):
-#     def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
-#         super().__init__()
-#         self.in_channels = in_channels
-#         out_channels = in_channels if out_channels is None else out_channels
-#         self.out_channels = out_channels
-#         self.use_conv_shortcut = conv_shortcut
-#
-#         self.norm1 = Normalize(in_channels)
-#         self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
-#         self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
-#         self.norm2 = Normalize(out_channels)
-#         self.dropout = torch.nn.Dropout(dropout)
-#         self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
-#         if self.in_channels != self.out_channels:
-#             if self.use_conv_shortcut:
-#                 self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
-#             else:
-#                 self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
-#
-#     def forward(self, x, temb):
-#         h = x
-#         h = self.norm1(h)
-#         h = nonlinearity(h)
-#         h = self.conv1(h)
-#
-#         h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
-#
-#         h = self.norm2(h)
-#         h = nonlinearity(h)
-#         h = self.dropout(h)
-#         h = self.conv2(h)
-#
-#         if self.in_channels != self.out_channels:
-#             if self.use_conv_shortcut:
-#                 x = self.conv_shortcut(x)
-#             else:
-#                 x = self.nin_shortcut(x)
-#
-#         return x + h
-
-
 class UNetModel(ModelMixin, ConfigMixin):
     def __init__(
         self,
...
src/diffusers/models/unet_glide.py

...
@@ -6,7 +6,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
-from .resnet import Downsample, ResBlock, TimestepBlock, Upsample
+from .resnet import Downsample, ResnetBlock, TimestepBlock, Upsample


 def convert_module_to_f16(l):
...
@@ -29,19 +29,6 @@ def convert_module_to_f32(l):
         l.bias.data = l.bias.data.float()


-def avg_pool_nd(dims, *args, **kwargs):
-    """
-    Create a 1D, 2D, or 3D average pooling module.
-    """
-    if dims == 1:
-        return nn.AvgPool1d(*args, **kwargs)
-    elif dims == 2:
-        return nn.AvgPool2d(*args, **kwargs)
-    elif dims == 3:
-        return nn.AvgPool3d(*args, **kwargs)
-    raise ValueError(f"unsupported dimensions: {dims}")
-
-
 def conv_nd(dims, *args, **kwargs):
     """
     Create a 1D, 2D, or 3D convolution module.
...
@@ -101,7 +88,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
     def forward(self, x, emb, encoder_out=None):
         for layer in self:
-            if isinstance(layer, TimestepBlock):
+            if isinstance(layer, TimestepBlock) or isinstance(layer, ResnetBlock):
                 x = layer(x, emb)
             elif isinstance(layer, AttentionBlock):
                 x = layer(x, encoder_out)
...
@@ -190,14 +177,15 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
         for level, mult in enumerate(channel_mult):
             for _ in range(num_res_blocks):
                 layers = [
-                    ResBlock(
-                        ch,
-                        time_embed_dim,
-                        dropout,
-                        out_channels=int(mult * model_channels),
-                        dims=dims,
-                        use_checkpoint=use_checkpoint,
-                        use_scale_shift_norm=use_scale_shift_norm,
+                    ResnetBlock(
+                        in_channels=ch,
+                        out_channels=mult * model_channels,
+                        dropout=dropout,
+                        temb_channels=time_embed_dim,
+                        eps=1e-5,
+                        non_linearity="silu",
+                        time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
+                        overwrite_for_glide=True,
                     )
                 ]
                 ch = int(mult * model_channels)
...
@@ -218,14 +206,15 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                 out_ch = ch
                 self.input_blocks.append(
                     TimestepEmbedSequential(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
+                        ResnetBlock(
+                            in_channels=ch,
                             out_channels=out_ch,
-                            dims=dims,
-                            use_checkpoint=use_checkpoint,
-                            use_scale_shift_norm=use_scale_shift_norm,
+                            dropout=dropout,
+                            temb_channels=time_embed_dim,
+                            eps=1e-5,
+                            non_linearity="silu",
+                            time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
+                            overwrite_for_glide=True,
                             down=True,
                         )
                         if resblock_updown
...
@@ -240,13 +229,14 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
             self._feature_size += ch

         self.middle_block = TimestepEmbedSequential(
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
+            ResnetBlock(
+                in_channels=ch,
+                dropout=dropout,
+                temb_channels=time_embed_dim,
+                eps=1e-5,
+                non_linearity="silu",
+                time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
+                overwrite_for_glide=True,
             ),
             AttentionBlock(
                 ch,
...
@@ -255,13 +245,14 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                 num_head_channels=num_head_channels,
                 encoder_channels=transformer_dim,
             ),
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
+            ResnetBlock(
+                in_channels=ch,
+                dropout=dropout,
+                temb_channels=time_embed_dim,
+                eps=1e-5,
+                non_linearity="silu",
+                time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
+                overwrite_for_glide=True,
             ),
         )
         self._feature_size += ch
...
@@ -271,15 +262,16 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
             for i in range(num_res_blocks + 1):
                 ich = input_block_chans.pop()
                 layers = [
-                    ResBlock(
-                        ch + ich,
-                        time_embed_dim,
-                        dropout,
-                        out_channels=int(model_channels * mult),
-                        dims=dims,
-                        use_checkpoint=use_checkpoint,
-                        use_scale_shift_norm=use_scale_shift_norm,
-                    )
+                    ResnetBlock(
+                        in_channels=ch + ich,
+                        out_channels=model_channels * mult,
+                        dropout=dropout,
+                        temb_channels=time_embed_dim,
+                        eps=1e-5,
+                        non_linearity="silu",
+                        time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
+                        overwrite_for_glide=True,
+                    ),
                 ]
                 ch = int(model_channels * mult)
                 if ds in attention_resolutions:
...
@@ -295,14 +287,15 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                 if level and i == num_res_blocks:
                     out_ch = ch
                     layers.append(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
+                        ResnetBlock(
+                            in_channels=ch,
                             out_channels=out_ch,
-                            dims=dims,
-                            use_checkpoint=use_checkpoint,
-                            use_scale_shift_norm=use_scale_shift_norm,
+                            dropout=dropout,
+                            temb_channels=time_embed_dim,
+                            eps=1e-5,
+                            non_linearity="silu",
+                            time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
+                            overwrite_for_glide=True,
                             up=True,
                         )
                         if resblock_updown
...
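Every ResBlock call site in this file is rewritten with the same positional-to-keyword mapping. A hedged sketch that makes the translation explicit; the helper function below is hypothetical, not part of the diff:

# Hypothetical helper mirroring how each old positional ResBlock call above is
# rewritten: same values, new keyword names, plus the glide-specific flags.
def resblock_args_to_resnetblock_kwargs(ch, time_embed_dim, dropout, out_channels=None, use_scale_shift_norm=False):
    return dict(
        in_channels=ch,
        out_channels=out_channels,
        dropout=dropout,
        temb_channels=time_embed_dim,
        eps=1e-5,
        non_linearity="silu",
        time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
        overwrite_for_glide=True,
    )

print(resblock_args_to_resnetblock_kwargs(128, 512, 0.0, out_channels=256, use_scale_shift_norm=True))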
src/diffusers/models/unet_ldm.py

...
@@ -10,7 +10,10 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
-from .resnet import Downsample, ResBlock, TimestepBlock, Upsample
+from .resnet import Downsample, ResnetBlock, TimestepBlock, Upsample
+
+# from .resnet import ResBlock


 def exists(val):
...
@@ -75,182 +78,6 @@ def Normalize(in_channels):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


-# class LinearAttention(nn.Module):
-#     def __init__(self, dim, heads=4, dim_head=32):
-#         super().__init__()
-#         self.heads = heads
-#         hidden_dim = dim_head * heads
-#         self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
-#         self.to_out = nn.Conv2d(hidden_dim, dim, 1)
-#
-#     def forward(self, x):
-#         b, c, h, w = x.shape
-#         qkv = self.to_qkv(x)
-#         q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
-#         import ipdb; ipdb.set_trace()
-#         k = k.softmax(dim=-1)
-#         context = torch.einsum("bhdn,bhen->bhde", k, v)
-#         out = torch.einsum("bhde,bhdn->bhen", context, q)
-#         out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
-#         return self.to_out(out)
-#
-# class SpatialSelfAttention(nn.Module):
-#     def __init__(self, in_channels):
-#         super().__init__()
-#         self.in_channels = in_channels
-#
-#         self.norm = Normalize(in_channels)
-#         self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-#         self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-#         self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-#         self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-#
-#     def forward(self, x):
-#         h_ = x
-#         h_ = self.norm(h_)
-#         q = self.q(h_)
-#         k = self.k(h_)
-#         v = self.v(h_)
-#
-#         # compute attention
-#         b, c, h, w = q.shape
-#         q = rearrange(q, "b c h w -> b (h w) c")
-#         k = rearrange(k, "b c h w -> b c (h w)")
-#         w_ = torch.einsum("bij,bjk->bik", q, k)
-#
-#         w_ = w_ * (int(c) ** (-0.5))
-#         w_ = torch.nn.functional.softmax(w_, dim=2)
-#
-#         # attend to values
-#         v = rearrange(v, "b c h w -> b c (h w)")
-#         w_ = rearrange(w_, "b i j -> b j i")
-#         h_ = torch.einsum("bij,bjk->bik", v, w_)
-#         h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
-#         h_ = self.proj_out(h_)
-#
-#         return x + h_
-#
-
-class CrossAttention(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
-        super().__init__()
-        inner_dim = dim_head * heads
-        context_dim = default(context_dim, query_dim)
-
-        self.scale = dim_head**-0.5
-        self.heads = heads
-
-        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
-        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
-        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
-
-        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
-
-    def reshape_heads_to_batch_dim(self, tensor):
-        batch_size, seq_len, dim = tensor.shape
-        head_size = self.heads
-        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
-        return tensor
-
-    def reshape_batch_dim_to_heads(self, tensor):
-        batch_size, seq_len, dim = tensor.shape
-        head_size = self.heads
-        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
-        return tensor
-
-    def forward(self, x, context=None, mask=None):
-        batch_size, sequence_length, dim = x.shape
-        h = self.heads
-
-        q = self.to_q(x)
-        context = default(context, x)
-        k = self.to_k(context)
-        v = self.to_v(context)
-
-        q = self.reshape_heads_to_batch_dim(q)
-        k = self.reshape_heads_to_batch_dim(k)
-        v = self.reshape_heads_to_batch_dim(v)
-
-        sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
-
-        if exists(mask):
-            mask = mask.reshape(batch_size, -1)
-            max_neg_value = -torch.finfo(sim.dtype).max
-            mask = mask[:, None, :].repeat(h, 1, 1)
-            sim.masked_fill_(~mask, max_neg_value)
-
-        # attention, what we cannot get enough of
-        attn = sim.softmax(dim=-1)
-
-        out = torch.einsum("b i j, b j d -> b i d", attn, v)
-        out = self.reshape_batch_dim_to_heads(out)
-        return self.to_out(out)
-
-
-class BasicTransformerBlock(nn.Module):
-    def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
-        super().__init__()
-        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)  # is a self-attention
-        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
-        self.attn2 = CrossAttention(
-            query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
-        )  # is self-attn if context is none
-        self.norm1 = nn.LayerNorm(dim)
-        self.norm2 = nn.LayerNorm(dim)
-        self.norm3 = nn.LayerNorm(dim)
-        self.checkpoint = checkpoint
-
-    def forward(self, x, context=None):
-        x = self.attn1(self.norm1(x)) + x
-        x = self.attn2(self.norm2(x), context=context) + x
-        x = self.ff(self.norm3(x)) + x
-        return x
-
-
-class SpatialTransformer(nn.Module):
-    """
-    Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
-    standard transformer action. Finally, reshape to image
-    """
-
-    def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
-        super().__init__()
-        self.in_channels = in_channels
-        inner_dim = n_heads * d_head
-        self.norm = Normalize(in_channels)
-
-        self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
-
-        self.transformer_blocks = nn.ModuleList(
-            [
-                BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
-                for d in range(depth)
-            ]
-        )
-
-        self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
-
-    def forward(self, x, context=None):
-        # note: if no context is given, cross-attention defaults to self-attention
-        b, c, h, w = x.shape
-        x_in = x
-        x = self.norm(x)
-        x = self.proj_in(x)
-        x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
-        for block in self.transformer_blocks:
-            x = block(x, context=context)
-        x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
-        x = self.proj_out(x)
-        return x + x_in
-
-
 def convert_module_to_f16(l):
     """
     Convert primitive modules to float16.
...
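The attention classes above are removed here and re-added lower in the file (see the final hunk of this file). Their head bookkeeping reduces to one invertible reshape pair: (B, T, H*D) is folded to (B*H, T, D) so a single batched einsum covers all heads. A self-contained check of that round trip:

import torch

B, T, H, D = 2, 5, 4, 8
x = torch.randn(B, T, H * D)

folded = x.reshape(B, T, H, D).permute(0, 2, 1, 3).reshape(B * H, T, D)
unfolded = folded.reshape(B, H, T, D).permute(0, 2, 1, 3).reshape(B, T, H * D)
print(torch.equal(x, unfolded))  # True: the two reshapes are inverses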
@@ -271,19 +98,6 @@ def convert_module_to_f32(l):
         l.bias.data = l.bias.data.float()


-def avg_pool_nd(dims, *args, **kwargs):
-    """
-    Create a 1D, 2D, or 3D average pooling module.
-    """
-    if dims == 1:
-        return nn.AvgPool1d(*args, **kwargs)
-    elif dims == 2:
-        return nn.AvgPool2d(*args, **kwargs)
-    elif dims == 3:
-        return nn.AvgPool3d(*args, **kwargs)
-    raise ValueError(f"unsupported dimensions: {dims}")
-
-
 def conv_nd(dims, *args, **kwargs):
     """
     Create a 1D, 2D, or 3D convolution module.
...
@@ -327,36 +141,6 @@ def normalization(channels, swish=0.0):
     return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)


-class AttentionPool2d(nn.Module):
-    """
-    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
-    """
-
-    def __init__(
-        self,
-        spacial_dim: int,
-        embed_dim: int,
-        num_heads_channels: int,
-        output_dim: int = None,
-    ):
-        super().__init__()
-        self.positional_embedding = nn.Parameter(torch.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
-        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
-        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
-        self.num_heads = embed_dim // num_heads_channels
-        self.attention = QKVAttention(self.num_heads)
-
-    def forward(self, x):
-        b, c, *_spatial = x.shape
-        x = x.reshape(b, c, -1)  # NC(HW)
-        x = torch.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)
-        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)
-        x = self.qkv_proj(x)
-        x = self.attention(x)
-        x = self.c_proj(x)
-        return x[:, :, 0]
-
-
 class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
     """
     A sequential module that passes timestep embeddings to the children that support it as an extra input.
...
@@ -364,7 +148,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
     def forward(self, x, emb, context=None):
         for layer in self:
-            if isinstance(layer, TimestepBlock):
+            if isinstance(layer, TimestepBlock) or isinstance(layer, ResnetBlock):
                 x = layer(x, emb)
             elif isinstance(layer, SpatialTransformer):
                 x = layer(x, context)
...
@@ -373,39 +157,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
         return x


-class QKVAttention(nn.Module):
-    """
-    A module which performs QKV attention and splits in a different order.
-    """
-
-    def __init__(self, n_heads):
-        super().__init__()
-        self.n_heads = n_heads
-
-    def forward(self, qkv):
-        """
-        Apply QKV attention. :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x
-        T] tensor after attention.
-        """
-        bs, width, length = qkv.shape
-        assert width % (3 * self.n_heads) == 0
-        ch = width // (3 * self.n_heads)
-        q, k, v = qkv.chunk(3, dim=1)
-        scale = 1 / math.sqrt(math.sqrt(ch))
-        weight = torch.einsum(
-            "bct,bcs->bts",
-            (q * scale).view(bs * self.n_heads, ch, length),
-            (k * scale).view(bs * self.n_heads, ch, length),
-        )  # More stable with f16 than dividing afterwards
-        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
-        a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
-        return a.reshape(bs, -1, length)
-
-    @staticmethod
-    def count_flops(model, _x, y):
-        return count_flops_attn(model, _x, y)
-
-
 def count_flops_attn(model, _x, y):
     """
     A counter for the `thop` package to count the operations in an attention operation. Meant to be used like:
...
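As the removed comment notes, multiplying q and k each by 1/sqrt(sqrt(ch)) before the einsum equals the usual 1/sqrt(ch) scaling but keeps intermediates smaller, which is friendlier to float16. A quick numerical confirmation:

import math
import torch

ch = 64
q, k = torch.randn(2, ch, 10), torch.randn(2, ch, 10)
scale = 1 / math.sqrt(math.sqrt(ch))

w1 = torch.einsum("bct,bcs->bts", q * scale, k * scale)
w2 = torch.einsum("bct,bcs->bts", q, k) / math.sqrt(ch)
print(torch.allclose(w1, w2, atol=1e-5))  # True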
@@ -559,14 +310,14 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
         for level, mult in enumerate(channel_mult):
             for _ in range(num_res_blocks):
                 layers = [
-                    ResBlock(
-                        ch,
-                        time_embed_dim,
-                        dropout,
+                    ResnetBlock(
+                        in_channels=ch,
                         out_channels=mult * model_channels,
-                        dims=dims,
-                        use_checkpoint=use_checkpoint,
-                        use_scale_shift_norm=use_scale_shift_norm,
+                        dropout=dropout,
+                        temb_channels=time_embed_dim,
+                        eps=1e-5,
+                        non_linearity="silu",
+                        overwrite_for_ldm=True,
                     )
                 ]
                 ch = mult * model_channels
...
@@ -599,20 +350,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                 out_ch = ch
                 self.input_blocks.append(
                     TimestepEmbedSequential(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
-                            out_channels=out_ch,
-                            dims=dims,
-                            use_checkpoint=use_checkpoint,
-                            use_scale_shift_norm=use_scale_shift_norm,
-                            down=True,
-                        )
-                        if resblock_updown
-                        else Downsample(
-                            ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
-                        )
+                        Downsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op")
                     )
                 )
                 ch = out_ch
...
@@ -629,13 +367,14 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
             # num_heads = 1
             dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
         self.middle_block = TimestepEmbedSequential(
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
+            ResnetBlock(
+                in_channels=ch,
+                out_channels=None,
+                dropout=dropout,
+                temb_channels=time_embed_dim,
+                eps=1e-5,
+                non_linearity="silu",
+                overwrite_for_ldm=True,
             ),
             AttentionBlock(
                 ch,
...
@@ -646,13 +385,14 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
             )
             if not use_spatial_transformer
            else SpatialTransformer(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim),
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
+            ResnetBlock(
+                in_channels=ch,
+                out_channels=None,
+                dropout=dropout,
+                temb_channels=time_embed_dim,
+                eps=1e-5,
+                non_linearity="silu",
+                overwrite_for_ldm=True,
             ),
         )
         self._feature_size += ch
...
@@ -662,15 +402,15 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
             for i in range(num_res_blocks + 1):
                 ich = input_block_chans.pop()
                 layers = [
-                    ResBlock(
-                        ch + ich,
-                        time_embed_dim,
-                        dropout,
+                    ResnetBlock(
+                        in_channels=ch + ich,
                         out_channels=model_channels * mult,
-                        dims=dims,
-                        use_checkpoint=use_checkpoint,
-                        use_scale_shift_norm=use_scale_shift_norm,
-                    )
+                        dropout=dropout,
+                        temb_channels=time_embed_dim,
+                        eps=1e-5,
+                        non_linearity="silu",
+                        overwrite_for_ldm=True,
+                    ),
                 ]
                 ch = model_channels * mult
                 if ds in attention_resolutions:
...
@@ -697,20 +437,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                 )
                 if level and i == num_res_blocks:
                     out_ch = ch
-                    layers.append(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
-                            out_channels=out_ch,
-                            dims=dims,
-                            use_checkpoint=use_checkpoint,
-                            use_scale_shift_norm=use_scale_shift_norm,
-                            up=True,
-                        )
-                        if resblock_updown
-                        else Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch)
-                    )
+                    layers.append(Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch))
                     ds //= 2
                 self.output_blocks.append(TimestepEmbedSequential(*layers))
         self._feature_size += ch
...
@@ -777,212 +504,119 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
         return self.out(h)


-class EncoderUNetModel(nn.Module):
-    """
-    The half UNet model with attention and timestep embedding. For usage, see UNet.
-    """
-
-    def __init__(
-        self,
-        image_size,
-        in_channels,
-        model_channels,
-        out_channels,
-        num_res_blocks,
-        attention_resolutions,
-        dropout=0,
-        channel_mult=(1, 2, 4, 8),
-        conv_resample=True,
-        dims=2,
-        use_checkpoint=False,
-        use_fp16=False,
-        num_heads=1,
-        num_head_channels=-1,
-        num_heads_upsample=-1,
-        use_scale_shift_norm=False,
-        resblock_updown=False,
-        use_new_attention_order=False,
-        pool="adaptive",
-        *args,
-        **kwargs,
-    ):
-        super().__init__()
-        if num_heads_upsample == -1:
-            num_heads_upsample = num_heads
-
-        self.in_channels = in_channels
-        self.model_channels = model_channels
-        self.out_channels = out_channels
-        self.num_res_blocks = num_res_blocks
-        self.attention_resolutions = attention_resolutions
-        self.dropout = dropout
-        self.channel_mult = channel_mult
-        self.conv_resample = conv_resample
-        self.use_checkpoint = use_checkpoint
-        self.dtype = torch.float16 if use_fp16 else torch.float32
-        self.num_heads = num_heads
-        self.num_head_channels = num_head_channels
-        self.num_heads_upsample = num_heads_upsample
-
-        time_embed_dim = model_channels * 4
-        self.time_embed = nn.Sequential(
-            linear(model_channels, time_embed_dim),
-            nn.SiLU(),
-            linear(time_embed_dim, time_embed_dim),
-        )
-
-        self.input_blocks = nn.ModuleList(
-            [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
-        )
-        self._feature_size = model_channels
-        input_block_chans = [model_channels]
-        ch = model_channels
-        ds = 1
-        for level, mult in enumerate(channel_mult):
-            for _ in range(num_res_blocks):
-                layers = [
-                    ResBlock(
-                        ch,
-                        time_embed_dim,
-                        dropout,
-                        out_channels=mult * model_channels,
-                        dims=dims,
-                        use_checkpoint=use_checkpoint,
-                        use_scale_shift_norm=use_scale_shift_norm,
-                    )
-                ]
-                ch = mult * model_channels
-                if ds in attention_resolutions:
-                    layers.append(
-                        AttentionBlock(
-                            ch,
-                            use_checkpoint=use_checkpoint,
-                            num_heads=num_heads,
-                            num_head_channels=num_head_channels,
-                            use_new_attention_order=use_new_attention_order,
-                        )
-                    )
-                self.input_blocks.append(TimestepEmbedSequential(*layers))
-                self._feature_size += ch
-                input_block_chans.append(ch)
-            if level != len(channel_mult) - 1:
-                out_ch = ch
-                self.input_blocks.append(
-                    TimestepEmbedSequential(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
-                            out_channels=out_ch,
-                            dims=dims,
-                            use_checkpoint=use_checkpoint,
-                            use_scale_shift_norm=use_scale_shift_norm,
-                            down=True,
-                        )
-                        if resblock_updown
-                        else Downsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op")
-                    )
-                )
-                ch = out_ch
-                input_block_chans.append(ch)
-                ds *= 2
-                self._feature_size += ch
-
-        self.middle_block = TimestepEmbedSequential(
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
-            ),
-            AttentionBlock(
-                ch,
-                use_checkpoint=use_checkpoint,
-                num_heads=num_heads,
-                num_head_channels=num_head_channels,
-                use_new_attention_order=use_new_attention_order,
-            ),
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
-            ),
-        )
-        self._feature_size += ch
-        self.pool = pool
-        if pool == "adaptive":
-            self.out = nn.Sequential(
-                normalization(ch),
-                nn.SiLU(),
-                nn.AdaptiveAvgPool2d((1, 1)),
-                zero_module(conv_nd(dims, ch, out_channels, 1)),
-                nn.Flatten(),
-            )
-        elif pool == "attention":
-            assert num_head_channels != -1
-            self.out = nn.Sequential(
-                normalization(ch),
-                nn.SiLU(),
-                AttentionPool2d((image_size // ds), ch, num_head_channels, out_channels),
-            )
-        elif pool == "spatial":
-            self.out = nn.Sequential(
-                nn.Linear(self._feature_size, 2048),
-                nn.ReLU(),
-                nn.Linear(2048, self.out_channels),
-            )
-        elif pool == "spatial_v2":
-            self.out = nn.Sequential(
-                nn.Linear(self._feature_size, 2048),
-                normalization(2048),
-                nn.SiLU(),
-                nn.Linear(2048, self.out_channels),
-            )
-        else:
-            raise NotImplementedError(f"Unexpected {pool} pooling")
-
-    def convert_to_fp16(self):
-        """
-        Convert the torso of the model to float16.
-        """
-        self.input_blocks.apply(convert_module_to_f16)
-        self.middle_block.apply(convert_module_to_f16)
-
-    def convert_to_fp32(self):
-        """
-        Convert the torso of the model to float32.
-        """
-        self.input_blocks.apply(convert_module_to_f32)
-        self.middle_block.apply(convert_module_to_f32)
-
-    def forward(self, x, timesteps):
-        """
-        Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. :param timesteps: a 1-D batch
-        of timesteps. :return: an [N x K] Tensor of outputs.
-        """
-        emb = self.time_embed(
-            get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
-        )
-
-        results = []
-        h = x.type(self.dtype)
-        for module in self.input_blocks:
-            h = module(h, emb)
-            if self.pool.startswith("spatial"):
-                results.append(h.type(x.dtype).mean(dim=(2, 3)))
-        h = self.middle_block(h, emb)
-        if self.pool.startswith("spatial"):
-            results.append(h.type(x.dtype).mean(dim=(2, 3)))
-            h = torch.cat(results, axis=-1)
-            return self.out(h)
-        else:
-            h = h.type(x.dtype)
-            return self.out(h)
+class SpatialTransformer(nn.Module):
+    """
+    Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
+    standard transformer action. Finally, reshape to image
+    """
+
+    def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
+        super().__init__()
+        self.in_channels = in_channels
+        inner_dim = n_heads * d_head
+        self.norm = Normalize(in_channels)
+
+        self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
+                for d in range(depth)
+            ]
+        )
+
+        self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
+
+    def forward(self, x, context=None):
+        # note: if no context is given, cross-attention defaults to self-attention
+        b, c, h, w = x.shape
+        x_in = x
+        x = self.norm(x)
+        x = self.proj_in(x)
+        x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
+        for block in self.transformer_blocks:
+            x = block(x, context=context)
+        x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
+        x = self.proj_out(x)
+        return x + x_in
+
+
+class BasicTransformerBlock(nn.Module):
+    def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
+        super().__init__()
+        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)  # is a self-attention
+        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+        self.attn2 = CrossAttention(
+            query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
+        )  # is self-attn if context is none
+        self.norm1 = nn.LayerNorm(dim)
+        self.norm2 = nn.LayerNorm(dim)
+        self.norm3 = nn.LayerNorm(dim)
+        self.checkpoint = checkpoint
+
+    def forward(self, x, context=None):
+        x = self.attn1(self.norm1(x)) + x
+        x = self.attn2(self.norm2(x), context=context) + x
+        x = self.ff(self.norm3(x)) + x
+        return x
+
+
+class CrossAttention(nn.Module):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+        super().__init__()
+        inner_dim = dim_head * heads
+        context_dim = default(context_dim, query_dim)
+
+        self.scale = dim_head**-0.5
+        self.heads = heads
+
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+
+    def reshape_heads_to_batch_dim(self, tensor):
+        batch_size, seq_len, dim = tensor.shape
+        head_size = self.heads
+        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        return tensor
+
+    def reshape_batch_dim_to_heads(self, tensor):
+        batch_size, seq_len, dim = tensor.shape
+        head_size = self.heads
+        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+        return tensor
+
+    def forward(self, x, context=None, mask=None):
+        batch_size, sequence_length, dim = x.shape
+        h = self.heads
+
+        q = self.to_q(x)
+        context = default(context, x)
+        k = self.to_k(context)
+        v = self.to_v(context)
+
+        q = self.reshape_heads_to_batch_dim(q)
+        k = self.reshape_heads_to_batch_dim(k)
+        v = self.reshape_heads_to_batch_dim(v)
+
+        sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
+
+        if exists(mask):
+            mask = mask.reshape(batch_size, -1)
+            max_neg_value = -torch.finfo(sim.dtype).max
+            mask = mask[:, None, :].repeat(h, 1, 1)
+            sim.masked_fill_(~mask, max_neg_value)
+
+        # attention, what we cannot get enough of
+        attn = sim.softmax(dim=-1)
+
+        out = torch.einsum("b i j, b j d -> b i d", attn, v)
+        out = self.reshape_batch_dim_to_heads(out)
+        return self.to_out(out)
src/diffusers/models/unet_rl.py

...
@@ -6,7 +6,7 @@ import torch.nn as nn
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .embeddings import get_timestep_embedding
-from .resnet import ResidualTemporalBlock
+from .resnet import Downsample, ResidualTemporalBlock, Upsample


 class SinusoidalPosEmb(nn.Module):
...
@@ -18,24 +18,6 @@ class SinusoidalPosEmb(nn.Module):
         return get_timestep_embedding(x, self.dim)


-class Downsample1d(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
-
-    def forward(self, x):
-        return self.conv(x)
-
-
-class Upsample1d(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
-
-    def forward(self, x):
-        return self.conv(x)
-
-
 class RearrangeDim(nn.Module):
     def __init__(self):
         super().__init__()
...
@@ -114,7 +96,7 @@ class TemporalUNet(ModelMixin, ConfigMixin):  # (nn.Module):
                     [
                         ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=training_horizon),
                         ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=training_horizon),
-                        Downsample1d(dim_out) if not is_last else nn.Identity(),
+                        Downsample(dim_out, use_conv=True, dims=1) if not is_last else nn.Identity(),
                     ]
                 )
             )
...
@@ -134,7 +116,7 @@ class TemporalUNet(ModelMixin, ConfigMixin):  # (nn.Module):
                     [
                         ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=training_horizon),
                         ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=training_horizon),
-                        Upsample1d(dim_in) if not is_last else nn.Identity(),
+                        Upsample(dim_in, use_conv_transpose=True, dims=1) if not is_last else nn.Identity(),
                     ]
                 )
             )
...
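The swap works because the shared Downsample/Upsample modules with dims=1 reproduce the removed 1-D helpers: a stride-2 Conv1d(kernel=3, padding=1) halves the sequence length, and a stride-2 ConvTranspose1d(kernel=4, padding=1) doubles it. A shape check:

import torch
import torch.nn as nn

x = torch.randn(1, 8, 16)
down = nn.Conv1d(8, 8, 3, 2, 1)           # as in the removed Downsample1d
up = nn.ConvTranspose1d(8, 8, 4, 2, 1)    # as in the removed Upsample1d
print(down(x).shape)  # torch.Size([1, 8, 8])
print(up(x).shape)    # torch.Size([1, 8, 32])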
src/diffusers/models/unet_sde_score_estimation.py

...
@@ -17,7 +17,6 @@
 import functools
 import math
-import string

 import numpy as np
 import torch
...
@@ -28,116 +27,21 @@ from ..configuration_utils import ConfigMixin
from
..modeling_utils
import
ModelMixin
from
.attention
import
AttentionBlock
from
.embeddings
import
GaussianFourierProjection
,
get_timestep_embedding
from
.resnet
import
ResnetBlockBigGANpp
,
ResnetBlockDDPMpp
from
.resnet
import
Downsample
,
ResnetBlockBigGANpp
,
Upsample
,
downsample_2d
,
upfirdn2d
,
upsample_2d
def
upfirdn2d
(
input
,
kernel
,
up
=
1
,
down
=
1
,
pad
=
(
0
,
0
)):
return
upfirdn2d_native
(
input
,
kernel
,
up
,
up
,
down
,
down
,
pad
[
0
],
pad
[
1
],
pad
[
0
],
pad
[
1
])
def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
    _, channel, in_h, in_w = input.shape
    input = input.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    out = input.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    return out.view(-1, channel, out_h, out_w)
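upfirdn2d performs the classic up-FIR-down resampling in one pass: zero-stuff by the upsampling factor, pad, convolve with the flipped FIR kernel, then keep every down-th sample. A naive reference for the up-only case, assuming only torch (the helper name is illustrative, not from this file):

import torch
import torch.nn.functional as F

def upfir2d_reference(x, kernel, up=2, pad=(1, 0)):
    b, c, h, w = x.shape
    y = torch.zeros(b, c, h * up, w * up, dtype=x.dtype)
    y[:, :, ::up, ::up] = x                            # zero-stuff between samples
    y = F.pad(y, [pad[0], pad[1], pad[0], pad[1]])     # pad once, before filtering
    w_k = torch.flip(kernel, [0, 1]).expand(c, 1, *kernel.shape)
    return F.conv2d(y, w_k, groups=c)                  # per-channel FIR filtering

x = torch.randn(1, 3, 8, 8)
k = torch.ones(2, 2) / 4.0                             # box filter
y = upfir2d_reference(x, k)                            # -> (1, 3, 16, 16)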
# Function ported from StyleGAN2
def get_weight(module, shape, weight_var="weight", kernel_init=None):
    """Get/create weight tensor for a convolution or fully-connected layer."""

    return module.param(weight_var, kernel_init, shape)
class Conv2d(nn.Module):
    """Conv2d layer with optimal upsampling and downsampling (StyleGAN2)."""

    def __init__(
        self,
        in_ch,
        out_ch,
        kernel,
        up=False,
        down=False,
        resample_kernel=(1, 3, 3, 1),
        use_bias=True,
        kernel_init=None,
    ):
        super().__init__()
        assert not (up and down)
        assert kernel >= 1 and kernel % 2 == 1
        self.weight = nn.Parameter(torch.zeros(out_ch, in_ch, kernel, kernel))
        if kernel_init is not None:
            self.weight.data = kernel_init(self.weight.data.shape)
        if use_bias:
            self.bias = nn.Parameter(torch.zeros(out_ch))

        self.up = up
        self.down = down
        self.resample_kernel = resample_kernel
        self.kernel = kernel
        self.use_bias = use_bias

    def forward(self, x):
        if self.up:
            x = upsample_conv_2d(x, self.weight, k=self.resample_kernel)
        elif self.down:
            x = conv_downsample_2d(x, self.weight, k=self.resample_kernel)
        else:
            x = F.conv2d(x, self.weight, stride=1, padding=self.kernel // 2)

        if self.use_bias:
            x = x + self.bias.reshape(1, -1, 1, 1)

        return x
def naive_upsample_2d(x, factor=2):
    _N, C, H, W = x.shape
    x = torch.reshape(x, (-1, C, H, 1, W, 1))
    x = x.repeat(1, 1, 1, factor, 1, factor)
    return torch.reshape(x, (-1, C, H * factor, W * factor))
def naive_downsample_2d(x, factor=2):
    _N, C, H, W = x.shape
    x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor))
    return torch.mean(x, dim=(3, 5))
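These two helpers are plain pixel replication and block averaging, so they can be sanity-checked against the torch built-ins; a sketch assuming the two functions above are in scope:

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 4, 4)
assert torch.allclose(naive_upsample_2d(x), F.interpolate(x, scale_factor=2.0, mode="nearest"))
assert torch.allclose(naive_downsample_2d(x), F.avg_pool2d(x, 2))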
def _setup_kernel(k):
    k = np.asarray(k, dtype=np.float32)
    if k.ndim == 1:
        k = np.outer(k, k)
    k /= np.sum(k)
    assert k.ndim == 2
    assert k.shape[0] == k.shape[1]
    return k
def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
    """Fused `upsample_2d()` followed by `tf.nn.conv2d()`.
def _upsample_conv_2d(x, w, k=None, factor=2, gain=1):
    """Fused `upsample_2d()` followed by `Conv2d()`.
Args:
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
...
...
@@ -176,13 +80,13 @@ def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
    # Determine data dimensions.
    stride = [1, 1, factor, factor]
    output_shape = ((_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW)
    output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
    output_padding = (
        output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convH,
        output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convW,
        output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
        output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
    )
    assert output_padding[0] >= 0 and output_padding[1] >= 0
    num_groups = _shape(x, 1) // inC
    num_groups = x.shape[1] // inC

    # Transpose weights.
    w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
...
...
@@ -190,21 +94,12 @@ def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
    w = torch.reshape(w, (num_groups * inC, -1, convH, convW))

    x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)
    # Original TF code.
    # x = tf.nn.conv2d_transpose(
    #     x,
    #     w,
    #     output_shape=output_shape,
    #     strides=stride,
    #     padding='VALID',
    #     data_format=data_format)
    # JAX equivalent

    return upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
    """Fused `tf.nn.conv2d()` followed by `downsample_2d()`.
def _conv_downsample_2d(x, w, k=None, factor=2, gain=1):
    """Fused `Conv2d()` followed by `downsample_2d()`.
Args:
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
...
...
@@ -235,138 +130,9 @@ def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
    return F.conv2d(x, w, stride=s, padding=0)
def _setup_kernel(k):
    k = np.asarray(k, dtype=np.float32)
    if k.ndim == 1:
        k = np.outer(k, k)
    k /= np.sum(k)
    assert k.ndim == 2
    assert k.shape[0] == k.shape[1]
    return k


def _shape(x, dim):
    return x.shape[dim]
def upsample_2d(x, k=None, factor=2, gain=1):
    r"""Upsample a batch of 2D images with the given filter.

    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
    filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
    `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
    a multiple of the upsampling factor.

    Args:
        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
          corresponds to nearest-neighbor upsampling.
        factor: Integer upsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        Tensor of the shape `[N, C, H * factor, W * factor]`
    """
    assert isinstance(factor, int) and factor >= 1
    if k is None:
        k = [1] * factor
    k = _setup_kernel(k) * (gain * (factor**2))
    p = k.shape[0] - factor
    return upfirdn2d(x, torch.tensor(k, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
def downsample_2d(x, k=None, factor=2, gain=1):
    r"""Downsample a batch of 2D images with the given filter.

    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
    given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
    specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
    shape is a multiple of the downsampling factor.

    Args:
        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
          corresponds to average pooling.
        factor: Integer downsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        Tensor of the shape `[N, C, H // factor, W // factor]`
    """
    assert isinstance(factor, int) and factor >= 1
    if k is None:
        k = [1] * factor
    k = _setup_kernel(k) * gain
    p = k.shape[0] - factor
    return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
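With the default kernel k = [1] * factor, the docstrings' claims can be verified directly: upsample_2d reduces to nearest-neighbor interpolation and downsample_2d to average pooling. A sketch, assuming the two functions above are in scope:

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
assert torch.allclose(upsample_2d(x, factor=2), F.interpolate(x, scale_factor=2.0, mode="nearest"), atol=1e-6)
assert torch.allclose(downsample_2d(x, factor=2), F.avg_pool2d(x, 2), atol=1e-6)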
def conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):
    """1x1 convolution with DDPM initialization."""
    conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
    nn.init.zeros_(conv.bias)
    return conv


def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):
    """3x3 convolution with DDPM initialization."""
    conv = nn.Conv2d(
        in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias
    )
    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
    nn.init.zeros_(conv.bias)
    return conv
def _einsum(a, b, c, x, y):
    einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
    return torch.einsum(einsum_str, x, y)


def contract_inner(x, y):
    """tensordot(x, y, 1)."""
    x_chars = list(string.ascii_lowercase[: len(x.shape)])
    y_chars = list(string.ascii_lowercase[len(x.shape) : len(y.shape) + len(x.shape)])
    y_chars[0] = x_chars[-1]  # first axis of y and last of x get summed
    out_chars = x_chars[:-1] + y_chars[1:]
    return _einsum(x_chars, y_chars, out_chars, x, y)
class NIN(nn.Module):
    def __init__(self, in_dim, num_units, init_scale=0.1):
        super().__init__()
        self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)
        self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        y = contract_inner(x, self.W) + self.b
        return y.permute(0, 3, 1, 2)
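contract_inner is tensordot(x, y, 1), so NIN applies a per-pixel linear map, i.e. a 1x1 convolution computed in NHWC layout. A sketch of the equivalence, assuming only torch (shapes are illustrative):

import torch
import torch.nn.functional as F

x = torch.randn(2, 16, 8, 8)                      # NCHW input
W, b = torch.randn(16, 32), torch.zeros(32)       # NIN parameters (in_dim, num_units)
y = torch.einsum("bhwc,cd->bhwd", x.permute(0, 2, 3, 1), W) + b
y = y.permute(0, 3, 1, 2)                         # back to NCHW, shape (2, 32, 8, 8)
y_ref = F.conv2d(x, W.t().reshape(32, 16, 1, 1), b)
assert torch.allclose(y, y_ref, atol=1e-5)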
def get_act(nonlinearity):
    """Get activation functions from the config file."""

    if nonlinearity.lower() == "elu":
        return nn.ELU()
    elif nonlinearity.lower() == "relu":
        return nn.ReLU()
    elif nonlinearity.lower() == "lrelu":
        return nn.LeakyReLU(negative_slope=0.2)
    elif nonlinearity.lower() == "swish":
        return nn.SiLU()
    else:
        raise NotImplementedError("activation function does not exist!")
def default_init(scale=1.0):
    """The same initialization used in DDPM."""
    scale = 1e-10 if scale == 0 else scale
    return variance_scaling(scale, "fan_avg", "uniform")


def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
def _variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
    """Ported from JAX."""
    scale = 1e-10 if scale == 0 else scale

    def _compute_fans(shape, in_axis=1, out_axis=0):
        receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
...
...
@@ -376,31 +142,35 @@ def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=tor
    def init(shape, dtype=dtype, device=device):
        fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
        if mode == "fan_in":
            denominator = fan_in
        elif mode == "fan_out":
            denominator = fan_out
        elif mode == "fan_avg":
            denominator = (fan_in + fan_out) / 2
        else:
            raise ValueError("invalid mode for variance scaling initializer: {}".format(mode))
        denominator = (fan_in + fan_out) / 2
        variance = scale / denominator
        if distribution == "normal":
            return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
        elif distribution == "uniform":
            return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
        else:
            raise ValueError("invalid distribution for variance scaling initializer")
        return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)

    return init
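For a conv weight of shape (out_ch, in_ch, kh, kw), _compute_fans yields fan_in = in_ch * kh * kw and fan_out = out_ch * kh * kw, so the "fan_avg"/"uniform" combination used by default_init draws from U(-bound, bound) with bound = sqrt(3 * scale / ((fan_in + fan_out) / 2)). Worked numbers as a sketch (shape is illustrative):

import numpy as np

shape = (64, 32, 3, 3)                                        # (out_ch, in_ch, kh, kw)
receptive_field_size = np.prod(shape) / shape[1] / shape[0]   # kh * kw = 9
fan_in = shape[1] * receptive_field_size                      # 288
fan_out = shape[0] * receptive_field_size                     # 576
bound = np.sqrt(3 * 1.0 / ((fan_in + fan_out) / 2))           # ~0.0833 for scale = 1.0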
def Conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
    """nXn convolution with DDPM initialization."""
    conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
    conv.weight.data = _variance_scaling(init_scale)(conv.weight.data.shape)
    nn.init.zeros_(conv.bias)
    return conv


def Linear(dim_in, dim_out):
    linear = nn.Linear(dim_in, dim_out)
    linear.weight.data = _variance_scaling()(linear.weight.shape)
    nn.init.zeros_(linear.bias)
    return linear
class Combine(nn.Module):
    """Combine information from skip connections."""

    def __init__(self, dim1, dim2, method="cat"):
        super().__init__()
        self.Conv_0 = conv1x1(dim1, dim2)
        # 1x1 convolution with DDPM initialization.
        self.Conv_0 = Conv2d(dim1, dim2, kernel_size=1, padding=0)
        self.method = method

    def forward(self, x, y):
...
...
@@ -413,80 +183,42 @@ class Combine(nn.Module):
raise ValueError(f"Method {self.method} not recognized.")
class Upsample(nn.Module):
    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
class FirUpsample(nn.Module):
    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_ch = out_ch if out_ch else in_ch
        if not fir:
            if with_conv:
                self.Conv_0 = conv3x3(in_ch, out_ch)
        else:
            if with_conv:
                self.Conv2d_0 = Conv2d(
                    in_ch,
                    out_ch,
                    kernel=3,
                    up=True,
                    resample_kernel=fir_kernel,
                    use_bias=True,
                    kernel_init=default_init(),
                )
        self.fir = fir
        self.with_conv = with_conv
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.use_conv = use_conv
        self.fir_kernel = fir_kernel
        self.out_ch = out_ch
        self.out_channels = out_channels

    def forward(self, x):
        B, C, H, W = x.shape
        if not self.fir:
            h = F.interpolate(x, (H * 2, W * 2), "nearest")
            if self.with_conv:
                h = self.Conv_0(h)
        if self.use_conv:
            h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
            h = h + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            if not self.with_conv:
                h = upsample_2d(x, self.fir_kernel, factor=2)
            else:
                h = self.Conv2d_0(x)
            h = upsample_2d(x, self.fir_kernel, factor=2)

        return h
class Downsample(nn.Module):
    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
class FirDownsample(nn.Module):
    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_ch = out_ch if out_ch else in_ch
        if not fir:
            if with_conv:
                self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0)
        else:
            if with_conv:
                self.Conv2d_0 = Conv2d(
                    in_ch,
                    out_ch,
                    kernel=3,
                    down=True,
                    resample_kernel=fir_kernel,
                    use_bias=True,
                    kernel_init=default_init(),
                )
        self.fir = fir
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.fir_kernel = fir_kernel
        self.with_conv = with_conv
        self.out_ch = out_ch
        self.use_conv = use_conv
        self.out_channels = out_channels

    def forward(self, x):
        B, C, H, W = x.shape
        if not self.fir:
            if self.with_conv:
                x = F.pad(x, (0, 1, 0, 1))
                x = self.Conv_0(x)
            else:
                x = F.avg_pool2d(x, 2, stride=2)
        if self.use_conv:
            x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
            x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            if not self.with_conv:
                x = downsample_2d(x, self.fir_kernel, factor=2)
            else:
                x = self.Conv2d_0(x)
            x = downsample_2d(x, self.fir_kernel, factor=2)

        return x
...
...
@@ -496,10 +228,9 @@ class NCSNpp(ModelMixin, ConfigMixin):
def __init__(
    self,
    centered=False,
    image_size=1024,
    num_channels=3,
    attention_type="ddpm",
    centered=False,
    attn_resolutions=(16,),
    ch_mult=(1, 2, 4, 8, 16, 32, 32, 32),
    conditional=True,
...
...
@@ -511,24 +242,20 @@ class NCSNpp(ModelMixin, ConfigMixin):
    fourier_scale=16,
    init_scale=0.0,
    nf=16,
    nonlinearity="swish",
    normalization="GroupNorm",
    num_res_blocks=1,
    progressive="output_skip",
    progressive_combine="sum",
    progressive_input="input_skip",
    resamp_with_conv=True,
    resblock_type="biggan",
    scale_by_sigma=True,
    skip_rescale=True,
    continuous=True,
):
    super().__init__()
    self.register_to_config(
        centered=centered,
        image_size=image_size,
        num_channels=num_channels,
        attention_type=attention_type,
        centered=centered,
        attn_resolutions=attn_resolutions,
        ch_mult=ch_mult,
        conditional=conditional,
...
...
@@ -540,19 +267,16 @@ class NCSNpp(ModelMixin, ConfigMixin):
        fourier_scale=fourier_scale,
        init_scale=init_scale,
        nf=nf,
        nonlinearity=nonlinearity,
        normalization=normalization,
        num_res_blocks=num_res_blocks,
        progressive=progressive,
        progressive_combine=progressive_combine,
        progressive_input=progressive_input,
        resamp_with_conv=resamp_with_conv,
        resblock_type=resblock_type,
        scale_by_sigma=scale_by_sigma,
        skip_rescale=skip_rescale,
        continuous=continuous,
    )
    self.act = act = get_act(nonlinearity)
    self.act = act = nn.SiLU()
    self.nf = nf
    self.num_res_blocks = num_res_blocks
...
...
@@ -562,7 +286,6 @@ class NCSNpp(ModelMixin, ConfigMixin):
self.conditional = conditional
self.skip_rescale = skip_rescale
self.resblock_type = resblock_type
self.progressive = progressive
self.progressive_input = progressive_input
self.embedding_type = embedding_type
...
...
@@ -585,53 +308,41 @@ class NCSNpp(ModelMixin, ConfigMixin):
else:
    raise ValueError(f"embedding type {embedding_type} unknown.")

if conditional:
    modules.append(nn.Linear(embed_dim, nf * 4))
    modules[-1].weight.data = default_init()(modules[-1].weight.shape)
    nn.init.zeros_(modules[-1].bias)
    modules.append(nn.Linear(nf * 4, nf * 4))
    modules[-1].weight.data = default_init()(modules[-1].weight.shape)
    nn.init.zeros_(modules[-1].bias)
    modules.append(Linear(embed_dim, nf * 4))
    modules.append(Linear(nf * 4, nf * 4))

AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))

Up_sample = functools.partial(Upsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
if self.fir:
    Up_sample = functools.partial(FirUpsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
else:
    Up_sample = functools.partial(Upsample, name="Conv2d_0")

if progressive == "output_skip":
    self.pyramid_upsample = Up_sample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
    self.pyramid_upsample = Up_sample(channels=None, use_conv=False)
elif progressive == "residual":
    pyramid_upsample = functools.partial(Up_sample, fir=fir, fir_kernel=fir_kernel, with_conv=True)
    pyramid_upsample = functools.partial(Up_sample, use_conv=True)

Down_sample = functools.partial(Downsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
if self.fir:
    Down_sample = functools.partial(FirDownsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
else:
    Down_sample = functools.partial(Downsample, padding=0, name="Conv2d_0")

if progressive_input == "input_skip":
    self.pyramid_downsample = Down_sample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
    self.pyramid_downsample = Down_sample(channels=None, use_conv=False)
elif progressive_input == "residual":
    pyramid_downsample = functools.partial(Down_sample, fir=fir, fir_kernel=fir_kernel, with_conv=True)

if resblock_type == "ddpm":
    ResnetBlock = functools.partial(
        ResnetBlockDDPMpp,
        act=act,
        dropout=dropout,
        init_scale=init_scale,
        skip_rescale=skip_rescale,
        temb_dim=nf * 4,
    )
elif resblock_type == "biggan":
    ResnetBlock = functools.partial(
        ResnetBlockBigGANpp,
        act=act,
        dropout=dropout,
        fir=fir,
        fir_kernel=fir_kernel,
        init_scale=init_scale,
        skip_rescale=skip_rescale,
        temb_dim=nf * 4,
    )
    pyramid_downsample = functools.partial(Down_sample, use_conv=True)
else:
    raise ValueError(f"resblock type {resblock_type} unrecognized.")

ResnetBlock = functools.partial(
    ResnetBlockBigGANpp,
    act=act,
    dropout=dropout,
    fir=fir,
    fir_kernel=fir_kernel,
    init_scale=init_scale,
    skip_rescale=skip_rescale,
    temb_dim=nf * 4,
)

# Downsampling block
...
...
@@ -639,7 +350,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
if progressive_input != "none":
    input_pyramid_ch = channels
    modules.append(conv3x3(channels, nf))
    modules.append(Conv2d(channels, nf, kernel_size=3, padding=1))
hs_c = [nf]

in_ch = nf
...
...
@@ -655,10 +366,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
hs_c.append(in_ch)

if i_level != self.num_resolutions - 1:
    if resblock_type == "ddpm":
        modules.append(Downsample(in_ch=in_ch))
    else:
        modules.append(ResnetBlock(down=True, in_ch=in_ch))
    modules.append(ResnetBlock(down=True, in_ch=in_ch))

    if progressive_input == "input_skip":
        modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
...
...
@@ -666,7 +374,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
        in_ch *= 2

    elif progressive_input == "residual":
        modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch))
        modules.append(pyramid_downsample(channels=input_pyramid_ch, out_channels=in_ch))
        input_pyramid_ch = in_ch

    hs_c.append(in_ch)
...
...
@@ -691,36 +399,35 @@ class NCSNpp(ModelMixin, ConfigMixin):
if i_level == self.num_resolutions - 1:
    if progressive == "output_skip":
        modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
        modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
        modules.append(Conv2d(in_ch, channels, init_scale=init_scale, kernel_size=3, padding=1))
        pyramid_ch = channels
    elif progressive == "residual":
        modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
        modules.append(conv3x3(in_ch, in_ch, bias=True))
        modules.append(Conv2d(in_ch, in_ch, bias=True, kernel_size=3, padding=1))
        pyramid_ch = in_ch
    else:
        raise ValueError(f"{progressive} is not a valid name.")
else:
    if progressive == "output_skip":
        modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
        modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale))
        modules.append(
            Conv2d(in_ch, channels, bias=True, init_scale=init_scale, kernel_size=3, padding=1)
        )
        pyramid_ch = channels
    elif progressive == "residual":
        modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
        modules.append(pyramid_upsample(channels=pyramid_ch, out_channels=in_ch))
        pyramid_ch = in_ch
    else:
        raise ValueError(f"{progressive} is not a valid name")

if i_level != 0:
    if resblock_type == "ddpm":
        modules.append(Upsample(in_ch=in_ch))
    else:
        modules.append(ResnetBlock(in_ch=in_ch, up=True))
    modules.append(ResnetBlock(in_ch=in_ch, up=True))

assert not hs_c

if progressive != "output_skip":
    modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
    modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
    modules.append(Conv2d(in_ch, channels, init_scale=init_scale))

self.all_modules = nn.ModuleList(modules)
...
...
@@ -751,8 +458,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
else:
    temb = None

# If input data is in [0, 1]
if not self.config.centered:
    # If input data is in [0, 1]
    x = 2 * x - 1.0

# Downsampling block
...
...
@@ -774,12 +481,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
hs.append(h)

if i_level != self.num_resolutions - 1:
    if self.resblock_type == "ddpm":
        h = modules[m_idx](hs[-1])
        m_idx += 1
    else:
        h = modules[m_idx](hs[-1], temb)
        m_idx += 1
    h = modules[m_idx](hs[-1], temb)
    m_idx += 1

    if self.progressive_input == "input_skip":
        input_pyramid = self.pyramid_downsample(input_pyramid)
...
...
@@ -851,12 +554,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
    raise ValueError(f"{self.progressive} is not a valid name")

if i_level != 0:
    if self.resblock_type == "ddpm":
        h = modules[m_idx](h)
        m_idx += 1
    else:
        h = modules[m_idx](h, temb)
        m_idx += 1
    h = modules[m_idx](h, temb)
    m_idx += 1

assert not hs
...
...
tests/test_modeling_utils.py
View file @
23904d54
...
...
@@ -259,7 +259,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
# fmt: off
expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053])
# fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
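The switch from atol to rtol changes the comparison from a fixed absolute bound to one that scales with the magnitude of the expected values: torch.allclose(a, b, rtol=r, atol=t) checks |a - b| <= t + r * |b| elementwise, and atol keeps its 1e-8 default when only rtol is passed. Illustrative sketch:

import torch

a = torch.tensor([1.000, 100.0])
b = torch.tensor([1.001, 100.5])
assert torch.allclose(a, b, rtol=1e-2)              # 0.5 <= 1e-8 + 1e-2 * 100.5
assert not torch.allclose(a, b, rtol=0, atol=1e-3)  # 0.5 > 1e-3 on the second entry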
class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):
...
...
@@ -607,7 +607,7 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
expected_output_slice = torch.tensor([-0.0690, -0.0531, 0.0633, -0.0660, -0.0541, 0.0650, -0.0656, -0.0555, 0.0617])
# fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase):
...
...
@@ -678,7 +678,7 @@ class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase):
expected_output_slice = torch.tensor([-0.2714, 0.1042, -0.0794, -0.2820, 0.0803, -0.0811, -0.2345, 0.0580, -0.0584])
# fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
...
...
@@ -742,18 +742,18 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
num_channels = 3
sizes = (32, 32)

noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor(batch_size * [10]).to(torch_device)
noise = torch.ones((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)

with torch.no_grad():
    output = model(noise, time_step)

output_slice = output[0, -3:, -3:, -1].flatten().cpu()
# fmt: off
expected_output_slice = torch.tensor([3.1909e-07, -8.5393e-08, 4.8460e-07, -4.5550e-07, -1.3205e-06, -6.3475e-07, 9.7837e-07, 2.9974e-07, 1.2345e-06])
expected_output_slice = torch.tensor([0.1315, 0.0741, 0.0393, 0.0455, 0.0556, 0.0180, -0.0832, -0.0644, -0.0856])
# fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
def test_output_pretrained_ve_large(self):
    model = NCSNpp.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy")
...
...
@@ -768,21 +768,21 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
num_channels = 3
sizes = (32, 32)

noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor(batch_size * [10]).to(torch_device)
noise = torch.ones((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)

with torch.no_grad():
    output = model(noise, time_step)

output_slice = output[0, -3:, -3:, -1].flatten().cpu()
# fmt: off
expected_output_slice = torch.tensor([-8.3299e-07, -9.0431e-07, 4.0585e-08, 9.7563e-07, 1.0280e-06, 1.0133e-06, 1.4979e-06, -2.9716e-07, -6.1817e-07])
expected_output_slice = torch.tensor([-0.0325, -0.0900, -0.0869, -0.0332, -0.0725, -0.0270, -0.0101, 0.0227, 0.0256])
# fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
def test_output_pretrained_vp(self):
    model = NCSNpp.from_pretrained("fusing/ddpm-cifar10-vp-dummy")
    model = NCSNpp.from_pretrained("fusing/cifar10-ddpmpp-vp")
    model.eval()
    model.to(torch_device)
...
...
@@ -794,18 +794,18 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
num_channels = 3
sizes = (32, 32)

noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor(batch_size * [10]).to(torch_device)
noise = torch.randn((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor(batch_size * [9.0]).to(torch_device)

with torch.no_grad():
    output = model(noise, time_step)

output_slice = output[0, -3:, -3:, -1].flatten().cpu()
# fmt: off
expected_output_slice = torch.tensor([-3.9086e-07, -1.1001e-05, 1.8881e-06, 1.1106e-05, 1.6629e-06, 2.9820e-06, 8.4978e-06, 8.0253e-07, 1.5435e-06])
expected_output_slice = torch.tensor([0.3303, -0.2275, -2.8872, -0.1309, -1.2861, 3.4567, -1.0083, 2.5325, -1.3866])
# fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
class VQModelTests(ModelTesterMixin, unittest.TestCase):
...
...
@@ -878,10 +878,9 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
output_slice = output[0, -1, -3:, -3:].flatten()
# fmt: off
expected_output_slice = torch.tensor([-1.1321, 0.1056, 0.3505, -0.6461, -0.2014, 0.0419, -0.5763, -0.8462, -0.4218])
expected_output_slice = torch.tensor([-1.1321, 0.1056, 0.3505, -0.6461, -0.2014, 0.0419, -0.5763, -0.8462, -0.4218])
# fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase):
...
...
@@ -950,10 +949,9 @@ class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase):
output_slice = output[0, -1, -3:, -3:].flatten()
# fmt: off
expected_output_slice = torch.tensor([-0.0814, -0.0229, -0.1320, -0.4123, -0.0366, -0.3473, 0.0438, -0.1662, 0.1750])
expected_output_slice = torch.tensor([-0.0814, -0.0229, -0.1320, -0.4123, -0.0366, -0.3473, 0.0438, -0.1662, 0.1750])
# fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
class PipelineTesterMixin(unittest.TestCase):
...
...