renzhc / diffusers_dcu · Commits

Commit 23904d54, authored Jul 01, 2022 by patil-suraj
Merge branch 'main' of https://github.com/huggingface/diffusers into conversion-scripts
Parents: 32b93da8, c691bb2f

Showing 9 changed files with 476 additions and 1334 deletions (+476, -1334)
setup.py                                              +1    -1
src/diffusers/dependency_versions_table.py            +1    -0
src/diffusers/models/resnet.py                      +153  -276
src/diffusers/models/unet.py                          +0   -42
src/diffusers/models/unet_glide.py                   +53   -60
src/diffusers/models/unet_ldm.py                    +141  -507
src/diffusers/models/unet_rl.py                       +3   -21
src/diffusers/models/unet_sde_score_estimation.py   +104  -405
tests/test_modeling_utils.py                         +20   -22
setup.py  (view file @ 23904d54)

@@ -88,7 +88,7 @@ _deps = [
     "requests",
     "torch>=1.4",
     "tensorboard",
-    "modelcards=0.1.4"
+    "modelcards==0.1.4"
 ]
 # this is a lookup table with items like:
src/diffusers/dependency_versions_table.py  (view file @ 23904d54)

@@ -14,4 +14,5 @@ deps = {
     "requests": "requests",
     "torch": "torch>=1.4",
     "tensorboard": "tensorboard",
+    "modelcards": "modelcards==0.1.4",
 }
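Aside (not part of the commit): the only functional change in these two files is the extra "=" that turns the entry into a valid "==" version pin. The single-"=" form is rejected by the same requirement parser pip itself uses (the packaging library); the package name below is taken from the hunks above.

from packaging.requirements import InvalidRequirement, Requirement

for spec in ("modelcards==0.1.4", "modelcards=0.1.4"):
    try:
        req = Requirement(spec)
        print(spec, "-> ok:", req.name, str(req.specifier))
    except InvalidRequirement as err:
        print(spec, "-> rejected:", err)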
src/diffusers/models/resnet.py  (view file @ 23904d54)

-import string
 from abc import abstractmethod
+from functools import partial

 import numpy as np
 import torch
@@ -79,18 +79,25 @@ class Upsample(nn.Module):
         upsampling occurs in the inner-two dimensions.
     """

-    def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None):
+    def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None, name="conv"):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
         self.use_conv = use_conv
         self.dims = dims
         self.use_conv_transpose = use_conv_transpose
+        self.name = name

+        conv = None
         if use_conv_transpose:
-            self.conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
+            conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
         elif use_conv:
-            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
+            conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
+
+        if name == "conv":
+            self.conv = conv
+        else:
+            self.Conv2d_0 = conv

     def forward(self, x):
         assert x.shape[1] == self.channels

@@ -103,7 +110,10 @@ class Upsample(nn.Module):
             x = F.interpolate(x, scale_factor=2.0, mode="nearest")

         if self.use_conv:
-            x = self.conv(x)
+            if self.name == "conv":
+                x = self.conv(x)
+            else:
+                x = self.Conv2d_0(x)

         return x
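Aside (not part of the commit): the new name argument only changes which attribute the convolution is registered under, which in turn changes the state_dict keys that a converted checkpoint has to match. A minimal sketch with a simplified stand-in module (not the real diffusers class):

import torch.nn as nn


class TinyUpsample(nn.Module):
    def __init__(self, channels, name="conv"):
        super().__init__()
        conv = nn.Conv2d(channels, channels, 3, padding=1)
        # same trick as in the hunk above: register the conv under either name
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv


print(list(TinyUpsample(4, name="conv").state_dict()))      # ['conv.weight', 'conv.bias']
print(list(TinyUpsample(4, name="Conv2d_0").state_dict()))  # ['Conv2d_0.weight', 'Conv2d_0.bias']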
@@ -135,6 +145,8 @@ class Downsample(nn.Module):
         if name == "conv":
             self.conv = conv
+        elif name == "Conv2d_0":
+            self.Conv2d_0 = conv
         else:
             self.op = conv

@@ -146,6 +158,8 @@ class Downsample(nn.Module):
         if self.name == "conv":
             return self.conv(x)
+        elif self.name == "Conv2d_0":
+            return self.Conv2d_0(x)
         else:
             return self.op(x)
@@ -162,110 +176,7 @@ class Downsample(nn.Module):
 # RESNETS

-# unet_glide.py & unet_ldm.py
-class ResBlock(TimestepBlock):
-    """
-    A residual block that can optionally change the number of channels.
-    :param channels: the number of input channels. :param emb_channels: the number of timestep embedding channels.
-    :param dropout: the rate of dropout. :param out_channels: if specified, the number of out channels. :param
-    use_conv: if True and out_channels is specified, use a spatial
-        convolution instead of a smaller 1x1 convolution to change the channels in the skip connection.
-    :param dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing
-    on this module. :param up: if True, use this block for upsampling. :param down: if True, use this block for
-    downsampling.
-    """
-
-    def __init__(
-        self,
-        channels,
-        emb_channels,
-        dropout,
-        out_channels=None,
-        use_conv=False,
-        use_scale_shift_norm=False,
-        dims=2,
-        use_checkpoint=False,
-        up=False,
-        down=False,
-    ):
-        super().__init__()
-        self.channels = channels
-        self.emb_channels = emb_channels
-        self.dropout = dropout
-        self.out_channels = out_channels or channels
-        self.use_conv = use_conv
-        self.use_checkpoint = use_checkpoint
-        self.use_scale_shift_norm = use_scale_shift_norm
-
-        self.in_layers = nn.Sequential(
-            normalization(channels, swish=1.0),
-            nn.Identity(),
-            conv_nd(dims, channels, self.out_channels, 3, padding=1),
-        )
-
-        self.updown = up or down
-
-        if up:
-            self.h_upd = Upsample(channels, use_conv=False, dims=dims)
-            self.x_upd = Upsample(channels, use_conv=False, dims=dims)
-        elif down:
-            self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
-            self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
-        else:
-            self.h_upd = self.x_upd = nn.Identity()
-
-        self.emb_layers = nn.Sequential(
-            nn.SiLU(),
-            linear(
-                emb_channels,
-                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
-            ),
-        )
-        self.out_layers = nn.Sequential(
-            normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
-            nn.SiLU() if use_scale_shift_norm else nn.Identity(),
-            nn.Dropout(p=dropout),
-            zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
-        )
-
-        if self.out_channels == channels:
-            self.skip_connection = nn.Identity()
-        elif use_conv:
-            self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
-        else:
-            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
-
-    def forward(self, x, emb):
-        """
-        Apply the block to a Tensor, conditioned on a timestep embedding.
-        :param x: an [N x C x ...] Tensor of features. :param emb: an [N x emb_channels] Tensor of timestep embeddings.
-        :return: an [N x C x ...] Tensor of outputs.
-        """
-        if self.updown:
-            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
-            h = in_rest(x)
-            h = self.h_upd(h)
-            x = self.x_upd(x)
-            h = in_conv(h)
-        else:
-            h = self.in_layers(x)
-        emb_out = self.emb_layers(emb).type(h.dtype)
-        while len(emb_out.shape) < len(h.shape):
-            emb_out = emb_out[..., None]
-        if self.use_scale_shift_norm:
-            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
-            scale, shift = torch.chunk(emb_out, 2, dim=1)
-            h = out_norm(h) * (1 + scale) + shift
-            h = out_rest(h)
-        else:
-            h = h + emb_out
-            h = self.out_layers(h)
-        return self.skip_connection(x) + h
-
-
-# unet.py and unet_grad_tts.py
+# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py
 class ResnetBlock(nn.Module):
     def __init__(
         self,
@@ -279,7 +190,12 @@ class ResnetBlock(nn.Module):
         pre_norm=True,
         eps=1e-6,
         non_linearity="swish",
+        time_embedding_norm="default",
+        up=False,
+        down=False,
         overwrite_for_grad_tts=False,
+        overwrite_for_ldm=False,
+        overwrite_for_glide=False,
     ):
         super().__init__()
         self.pre_norm = pre_norm

@@ -287,6 +203,9 @@ class ResnetBlock(nn.Module):
         out_channels = in_channels if out_channels is None else out_channels
         self.out_channels = out_channels
         self.use_conv_shortcut = conv_shortcut
+        self.time_embedding_norm = time_embedding_norm
+        self.up = up
+        self.down = down

         if self.pre_norm:
             self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
@@ -294,23 +213,38 @@ class ResnetBlock(nn.Module):
             self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)

         self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
-        self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+
+        if time_embedding_norm == "default":
+            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+        elif time_embedding_norm == "scale_shift":
+            self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
+
         self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
         self.dropout = torch.nn.Dropout(dropout)
         self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

         if non_linearity == "swish":
             self.nonlinearity = nonlinearity
         elif non_linearity == "mish":
             self.nonlinearity = Mish()
+        elif non_linearity == "silu":
+            self.nonlinearity = nn.SiLU()
+
+        if up:
+            self.h_upd = Upsample(in_channels, use_conv=False, dims=2)
+            self.x_upd = Upsample(in_channels, use_conv=False, dims=2)
+        elif down:
+            self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
+            self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")

         if self.in_channels != self.out_channels:
-            self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+            if self.use_conv_shortcut:
+                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+            else:
+                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

+        # TODO(SURAJ, PATRICK): ALL OF THE FOLLOWING OF THE INIT METHOD CAN BE DELETED ONCE WEIGHTS ARE CONVERTED
         self.is_overwritten = False
+        self.overwrite_for_glide = overwrite_for_glide
         self.overwrite_for_grad_tts = overwrite_for_grad_tts
+        self.overwrite_for_ldm = overwrite_for_ldm or overwrite_for_glide
         if self.overwrite_for_grad_tts:
             dim = in_channels
             dim_out = out_channels
@@ -324,6 +258,37 @@ class ResnetBlock(nn.Module):
                 self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
             else:
                 self.res_conv = torch.nn.Identity()
+        elif self.overwrite_for_ldm:
+            dims = 2
+            # eps = 1e-5
+            # non_linearity = "silu"
+            # overwrite_for_ldm
+            channels = in_channels
+            emb_channels = temb_channels
+            use_scale_shift_norm = False
+
+            self.in_layers = nn.Sequential(
+                normalization(channels, swish=1.0),
+                nn.Identity(),
+                conv_nd(dims, channels, self.out_channels, 3, padding=1),
+            )
+            self.emb_layers = nn.Sequential(
+                nn.SiLU(),
+                linear(
+                    emb_channels,
+                    2 * self.out_channels if self.time_embedding_norm == "scale_shift" else self.out_channels,
+                ),
+            )
+            self.out_layers = nn.Sequential(
+                normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
+                nn.SiLU() if use_scale_shift_norm else nn.Identity(),
+                nn.Dropout(p=dropout),
+                zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
+            )
+
+            if self.out_channels == in_channels:
+                self.skip_connection = nn.Identity()
+            else:
+                self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

     def set_weights_grad_tts(self):
         self.conv1.weight.data = self.block1.block[0].weight.data
@@ -343,27 +308,64 @@ class ResnetBlock(nn.Module):
             self.nin_shortcut.weight.data = self.res_conv.weight.data
             self.nin_shortcut.bias.data = self.res_conv.bias.data

-    def forward(self, x, temb, mask=None):
+    def set_weights_ldm(self):
+        self.norm1.weight.data = self.in_layers[0].weight.data
+        self.norm1.bias.data = self.in_layers[0].bias.data
+
+        self.conv1.weight.data = self.in_layers[-1].weight.data
+        self.conv1.bias.data = self.in_layers[-1].bias.data
+
+        self.temb_proj.weight.data = self.emb_layers[-1].weight.data
+        self.temb_proj.bias.data = self.emb_layers[-1].bias.data
+
+        self.norm2.weight.data = self.out_layers[0].weight.data
+        self.norm2.bias.data = self.out_layers[0].bias.data
+
+        self.conv2.weight.data = self.out_layers[-1].weight.data
+        self.conv2.bias.data = self.out_layers[-1].bias.data
+
+        if self.in_channels != self.out_channels:
+            self.nin_shortcut.weight.data = self.skip_connection.weight.data
+            self.nin_shortcut.bias.data = self.skip_connection.bias.data
+
+    def forward(self, x, temb, mask=1.0):
+        # TODO(Patrick) eventually this class should be split into multiple classes
+        # too many if else statements
         if self.overwrite_for_grad_tts and not self.is_overwritten:
             self.set_weights_grad_tts()
             self.is_overwritten = True
+        elif self.overwrite_for_ldm and not self.is_overwritten:
+            self.set_weights_ldm()
+            self.is_overwritten = True

         h = x
-        h = h * mask if mask is not None else h
+        h = h * mask
         if self.pre_norm:
             h = self.norm1(h)
             h = self.nonlinearity(h)

+        if self.up or self.down:
+            x = self.x_upd(x)
+            h = self.h_upd(h)
+
         h = self.conv1(h)

         if not self.pre_norm:
             h = self.norm1(h)
             h = self.nonlinearity(h)
-        h = h * mask if mask is not None else h
+        h = h * mask

-        h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
+        temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]

+        if self.time_embedding_norm == "scale_shift":
+            scale, shift = torch.chunk(temb, 2, dim=1)
+
+            h = h * mask if mask is not None else h
+            h = self.norm2(h)
+            h = h + h * scale + shift
+            h = self.nonlinearity(h)
+        elif self.time_embedding_norm == "default":
+            h = h + temb
+            h = h * mask
         if self.pre_norm:
             h = self.norm2(h)
             h = self.nonlinearity(h)

@@ -374,13 +376,10 @@ class ResnetBlock(nn.Module):
         if not self.pre_norm:
             h = self.norm2(h)
             h = self.nonlinearity(h)
-        h = h * mask if mask is not None else h
+        h = h * mask

-        x = x * mask if mask is not None else x
+        x = x * mask
         if self.in_channels != self.out_channels:
-            if self.use_conv_shortcut:
-                x = self.conv_shortcut(x)
-            else:
-                x = self.nin_shortcut(x)
+            x = self.nin_shortcut(x)

         return x + h
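Aside (not part of the commit): in the new "scale_shift" branch, h = h + h * scale + shift is the familiar FiLM-style modulation written differently. A standalone check on plain tensors (names here are illustrative only):

import torch

h = torch.randn(2, 8, 4, 4)
temb = torch.randn(2, 16)[:, :, None, None]  # a projection to 2 * channels
scale, shift = torch.chunk(temb, 2, dim=1)

out_a = h + h * scale + shift
out_b = h * (1 + scale) + shift
print(torch.allclose(out_a, out_b))  # True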
@@ -394,10 +393,6 @@ class Block(torch.nn.Module):
             torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
         )

-    def forward(self, x, mask):
-        output = self.block(x * mask)
-        return output * mask
-

 # unet_score_estimation.py
 class ResnetBlockBigGANpp(nn.Module):
@@ -424,17 +419,29 @@ class ResnetBlockBigGANpp(nn.Module):
         self.fir = fir
         self.fir_kernel = fir_kernel

-        self.Conv_0 = conv3x3(in_ch, out_ch)
+        if self.up:
+            if self.fir:
+                self.upsample = partial(upsample_2d, k=self.fir_kernel, factor=2)
+            else:
+                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
+        elif self.down:
+            if self.fir:
+                self.downsample = partial(downsample_2d, k=self.fir_kernel, factor=2)
+            else:
+                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
+
+        self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
         if temb_dim is not None:
             self.Dense_0 = nn.Linear(temb_dim, out_ch)
-            self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
+            self.Dense_0.weight.data = variance_scaling()(self.Dense_0.weight.shape)
             nn.init.zeros_(self.Dense_0.bias)

         self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
         self.Dropout_0 = nn.Dropout(dropout)
-        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
+        self.Conv_1 = conv2d(out_ch, out_ch, init_scale=init_scale, kernel_size=3, padding=1)
         if in_ch != out_ch or up or down:
-            self.Conv_2 = conv1x1(in_ch, out_ch)
+            # 1x1 convolution with DDPM initialization.
+            self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)

         self.skip_rescale = skip_rescale
         self.act = act
@@ -445,19 +452,11 @@ class ResnetBlockBigGANpp(nn.Module):
         h = self.act(self.GroupNorm_0(x))

         if self.up:
-            if self.fir:
-                h = upsample_2d(h, self.fir_kernel, factor=2)
-                x = upsample_2d(x, self.fir_kernel, factor=2)
-            else:
-                h = naive_upsample_2d(h, factor=2)
-                x = naive_upsample_2d(x, factor=2)
+            h = self.upsample(h)
+            x = self.upsample(x)
         elif self.down:
-            if self.fir:
-                h = downsample_2d(h, self.fir_kernel, factor=2)
-                x = downsample_2d(x, self.fir_kernel, factor=2)
-            else:
-                h = naive_downsample_2d(h, factor=2)
-                x = naive_downsample_2d(x, factor=2)
+            h = self.downsample(h)
+            x = self.downsample(x)

         h = self.Conv_0(h)
         # Add bias to each feature map conditioned on the time embedding
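Aside (not part of the commit): the new __init__ pre-binds the resampling ops with functools.partial, so the forward pass can call them with the tensor alone. A minimal sketch of the non-FIR path:

from functools import partial

import torch
import torch.nn.functional as F

upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)

x = torch.randn(1, 3, 8, 8)
print(upsample(x).shape)    # torch.Size([1, 3, 16, 16])
print(downsample(x).shape)  # torch.Size([1, 3, 4, 4])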
@@ -476,62 +475,6 @@ class ResnetBlockBigGANpp(nn.Module):
             return (x + h) / np.sqrt(2.0)

-# unet_score_estimation.py
-class ResnetBlockDDPMpp(nn.Module):
-    """ResBlock adapted from DDPM."""
-
-    def __init__(
-        self,
-        act,
-        in_ch,
-        out_ch=None,
-        temb_dim=None,
-        conv_shortcut=False,
-        dropout=0.1,
-        skip_rescale=False,
-        init_scale=0.0,
-    ):
-        super().__init__()
-        out_ch = out_ch if out_ch else in_ch
-        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
-        self.Conv_0 = conv3x3(in_ch, out_ch)
-        if temb_dim is not None:
-            self.Dense_0 = nn.Linear(temb_dim, out_ch)
-            self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
-            nn.init.zeros_(self.Dense_0.bias)
-        self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
-        self.Dropout_0 = nn.Dropout(dropout)
-        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
-        if in_ch != out_ch:
-            if conv_shortcut:
-                self.Conv_2 = conv3x3(in_ch, out_ch)
-            else:
-                self.NIN_0 = NIN(in_ch, out_ch)
-
-        self.skip_rescale = skip_rescale
-        self.act = act
-        self.out_ch = out_ch
-        self.conv_shortcut = conv_shortcut
-
-    def forward(self, x, temb=None):
-        h = self.act(self.GroupNorm_0(x))
-        h = self.Conv_0(h)
-        if temb is not None:
-            h += self.Dense_0(self.act(temb))[:, :, None, None]
-        h = self.act(self.GroupNorm_1(h))
-        h = self.Dropout_0(h)
-        h = self.Conv_1(h)
-        if x.shape[1] != self.out_ch:
-            if self.conv_shortcut:
-                x = self.Conv_2(x)
-            else:
-                x = self.NIN_0(x)
-        if not self.skip_rescale:
-            return x + h
-        else:
-            return (x + h) / np.sqrt(2.0)
-

 # unet_rl.py
 class ResidualTemporalBlock(nn.Module):
     def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
@@ -649,32 +592,17 @@ class RearrangeDim(nn.Module):
         raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")

-def conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):
-    """1x1 convolution with DDPM initialization."""
-    conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
-    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
-    nn.init.zeros_(conv.bias)
-    return conv
-
-
-def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):
-    """3x3 convolution with DDPM initialization."""
-    conv = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias)
-    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
+def conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
+    """nXn convolution with DDPM initialization."""
+    conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
+    conv.weight.data = variance_scaling(init_scale)(conv.weight.data.shape)
     nn.init.zeros_(conv.bias)
     return conv


-def default_init(scale=1.0):
-    """The same initialization used in DDPM."""
-    scale = 1e-10 if scale == 0 else scale
-    return variance_scaling(scale, "fan_avg", "uniform")
-
-
-def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
+def variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
     """Ported from JAX."""
+    scale = 1e-10 if scale == 0 else scale

     def _compute_fans(shape, in_axis=1, out_axis=0):
         receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
@@ -684,21 +612,9 @@ def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=tor
     def init(shape, dtype=dtype, device=device):
         fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
-        if mode == "fan_in":
-            denominator = fan_in
-        elif mode == "fan_out":
-            denominator = fan_out
-        elif mode == "fan_avg":
-            denominator = (fan_in + fan_out) / 2
-        else:
-            raise ValueError("invalid mode for variance scaling initializer: {}".format(mode))
+        denominator = (fan_in + fan_out) / 2
         variance = scale / denominator
-        if distribution == "normal":
-            return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
-        elif distribution == "uniform":
-            return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
-        else:
-            raise ValueError("invalid distribution for variance scaling initializer")
+        return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)

     return init
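Aside (not part of the commit): the simplified initializer keeps only the fan_avg/uniform path. Assuming a conv weight laid out as (out, in, kh, kw) so that fan_in and fan_out are both in_channels * kh * kw here, the empirical variance of such a draw should sit near scale / ((fan_in + fan_out) / 2):

import numpy as np
import torch

fan_in = fan_out = 256 * 3 * 3              # weight of shape (256, 256, 3, 3)
variance = 1.0 / ((fan_in + fan_out) / 2)   # scale / denominator with scale = 1.0
w = (torch.rand(256, 256, 3, 3) * 2.0 - 1.0) * np.sqrt(3 * variance)
print(float(w.var()), variance)             # the two values should roughly agree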
@@ -796,31 +712,6 @@ def downsample_2d(x, k=None, factor=2, gain=1):
     return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))

-def naive_upsample_2d(x, factor=2):
-    _N, C, H, W = x.shape
-    x = torch.reshape(x, (-1, C, H, 1, W, 1))
-    x = x.repeat(1, 1, 1, factor, 1, factor)
-    return torch.reshape(x, (-1, C, H * factor, W * factor))
-
-
-def naive_downsample_2d(x, factor=2):
-    _N, C, H, W = x.shape
-    x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor))
-    return torch.mean(x, dim=(3, 5))
-
-
-class NIN(nn.Module):
-    def __init__(self, in_dim, num_units, init_scale=0.1):
-        super().__init__()
-        self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)
-        self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
-
-    def forward(self, x):
-        x = x.permute(0, 2, 3, 1)
-        y = contract_inner(x, self.W) + self.b
-        return y.permute(0, 3, 1, 2)
-

 def _setup_kernel(k):
     k = np.asarray(k, dtype=np.float32)
     if k.ndim == 1:
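Aside (not part of the commit): the removed "naive" helpers are numerically equivalent to the framework ops used elsewhere in this file, which is presumably why they could go. A standalone check:

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 4, 4)

up_naive = torch.reshape(x, (-1, 3, 4, 1, 4, 1)).repeat(1, 1, 1, 2, 1, 2).reshape(-1, 3, 8, 8)
print(torch.allclose(up_naive, F.interpolate(x, scale_factor=2.0, mode="nearest")))  # True

down_naive = torch.mean(torch.reshape(x, (-1, 3, 2, 2, 2, 2)), dim=(3, 5))
print(torch.allclose(down_naive, F.avg_pool2d(x, kernel_size=2, stride=2)))  # True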
@@ -829,17 +720,3 @@ def _setup_kernel(k):
     assert k.ndim == 2
     assert k.shape[0] == k.shape[1]
     return k
-
-
-def contract_inner(x, y):
-    """tensordot(x, y, 1)."""
-    x_chars = list(string.ascii_lowercase[: len(x.shape)])
-    y_chars = list(string.ascii_lowercase[len(x.shape) : len(y.shape) + len(x.shape)])
-    y_chars[0] = x_chars[-1]  # first axis of y and last of x get summed
-    out_chars = x_chars[:-1] + y_chars[1:]
-    return _einsum(x_chars, y_chars, out_chars, x, y)
-
-
-def _einsum(a, b, c, x, y):
-    einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
-    return torch.einsum(einsum_str, x, y)
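Aside (not part of the commit): the removed contract_inner/_einsum pair is just a last-axis-to-first-axis contraction, i.e. torch.tensordot(x, y, dims=1). A quick check with arbitrary shapes:

import torch

x = torch.randn(2, 5, 5, 8)
w = torch.randn(8, 16)
print(torch.allclose(torch.tensordot(x, w, dims=1), torch.einsum("abcd,de->abce", x, w)))  # True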
src/diffusers/models/unet.py  (view file @ 23904d54)

@@ -34,48 +34,6 @@ def Normalize(in_channels):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

-# class ResnetBlock(nn.Module):
-#     def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
-#         super().__init__()
-#         self.in_channels = in_channels
-#         out_channels = in_channels if out_channels is None else out_channels
-#         self.out_channels = out_channels
-#         self.use_conv_shortcut = conv_shortcut
-#
-#         self.norm1 = Normalize(in_channels)
-#         self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
-#         self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
-#         self.norm2 = Normalize(out_channels)
-#         self.dropout = torch.nn.Dropout(dropout)
-#         self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
-#         if self.in_channels != self.out_channels:
-#             if self.use_conv_shortcut:
-#                 self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
-#             else:
-#                 self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
-#
-#     def forward(self, x, temb):
-#         h = x
-#         h = self.norm1(h)
-#         h = nonlinearity(h)
-#         h = self.conv1(h)
-#
-#         h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
-#
-#         h = self.norm2(h)
-#         h = nonlinearity(h)
-#         h = self.dropout(h)
-#         h = self.conv2(h)
-#
-#         if self.in_channels != self.out_channels:
-#             if self.use_conv_shortcut:
-#                 x = self.conv_shortcut(x)
-#             else:
-#                 x = self.nin_shortcut(x)
-#
-#         return x + h

 class UNetModel(ModelMixin, ConfigMixin):
     def __init__(
         self,
src/diffusers/models/unet_glide.py  (view file @ 23904d54)

@@ -6,7 +6,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
-from .resnet import Downsample, ResBlock, TimestepBlock, Upsample
+from .resnet import Downsample, ResnetBlock, TimestepBlock, Upsample


 def convert_module_to_f16(l):

@@ -29,19 +29,6 @@ def convert_module_to_f32(l):
         l.bias.data = l.bias.data.float()


-def avg_pool_nd(dims, *args, **kwargs):
-    """
-    Create a 1D, 2D, or 3D average pooling module.
-    """
-    if dims == 1:
-        return nn.AvgPool1d(*args, **kwargs)
-    elif dims == 2:
-        return nn.AvgPool2d(*args, **kwargs)
-    elif dims == 3:
-        return nn.AvgPool3d(*args, **kwargs)
-    raise ValueError(f"unsupported dimensions: {dims}")
-
-
 def conv_nd(dims, *args, **kwargs):
     """
     Create a 1D, 2D, or 3D convolution module.

@@ -101,7 +88,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
     def forward(self, x, emb, encoder_out=None):
         for layer in self:
-            if isinstance(layer, TimestepBlock):
+            if isinstance(layer, TimestepBlock) or isinstance(layer, ResnetBlock):
                 x = layer(x, emb)
             elif isinstance(layer, AttentionBlock):
                 x = layer(x, encoder_out)
@@ -190,14 +177,15 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
         for level, mult in enumerate(channel_mult):
             for _ in range(num_res_blocks):
                 layers = [
-                    ResBlock(
-                        ch,
-                        time_embed_dim,
-                        dropout,
-                        out_channels=int(mult * model_channels),
-                        dims=dims,
-                        use_checkpoint=use_checkpoint,
-                        use_scale_shift_norm=use_scale_shift_norm,
+                    ResnetBlock(
+                        in_channels=ch,
+                        out_channels=mult * model_channels,
+                        dropout=dropout,
+                        temb_channels=time_embed_dim,
+                        eps=1e-5,
+                        non_linearity="silu",
+                        time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
+                        overwrite_for_glide=True,
                     )
                 ]
                 ch = int(mult * model_channels)

@@ -218,14 +206,15 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                 out_ch = ch
                 self.input_blocks.append(
                     TimestepEmbedSequential(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
-                            out_channels=out_ch,
-                            dims=dims,
-                            use_checkpoint=use_checkpoint,
-                            use_scale_shift_norm=use_scale_shift_norm,
+                        ResnetBlock(
+                            in_channels=ch,
+                            out_channels=out_ch,
+                            dropout=dropout,
+                            temb_channels=time_embed_dim,
+                            eps=1e-5,
+                            non_linearity="silu",
+                            time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
+                            overwrite_for_glide=True,
                             down=True,
                         )
                         if resblock_updown
@@ -240,13 +229,14 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
         self._feature_size += ch
         self.middle_block = TimestepEmbedSequential(
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
+            ResnetBlock(
+                in_channels=ch,
+                dropout=dropout,
+                temb_channels=time_embed_dim,
+                eps=1e-5,
+                non_linearity="silu",
+                time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
+                overwrite_for_glide=True,
             ),
             AttentionBlock(
                 ch,

@@ -255,13 +245,14 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                 num_head_channels=num_head_channels,
                 encoder_channels=transformer_dim,
             ),
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
+            ResnetBlock(
+                in_channels=ch,
+                dropout=dropout,
+                temb_channels=time_embed_dim,
+                eps=1e-5,
+                non_linearity="silu",
+                time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
+                overwrite_for_glide=True,
             ),
         )
         self._feature_size += ch
@@ -271,15 +262,16 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
             for i in range(num_res_blocks + 1):
                 ich = input_block_chans.pop()
                 layers = [
-                    ResBlock(
-                        ch + ich,
-                        time_embed_dim,
-                        dropout,
-                        out_channels=model_channels * mult,
-                        dims=dims,
-                        use_checkpoint=use_checkpoint,
-                        use_scale_shift_norm=use_scale_shift_norm,
-                    )
+                    ResnetBlock(
+                        in_channels=ch + ich,
+                        out_channels=int(model_channels * mult),
+                        dropout=dropout,
+                        temb_channels=time_embed_dim,
+                        eps=1e-5,
+                        non_linearity="silu",
+                        time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
+                        overwrite_for_glide=True,
+                    ),
                 ]
                 ch = int(model_channels * mult)
                 if ds in attention_resolutions:

@@ -295,14 +287,15 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                 if level and i == num_res_blocks:
                     out_ch = ch
                     layers.append(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
-                            out_channels=out_ch,
-                            dims=dims,
-                            use_checkpoint=use_checkpoint,
-                            use_scale_shift_norm=use_scale_shift_norm,
+                        ResnetBlock(
+                            in_channels=ch,
+                            out_channels=out_ch,
+                            dropout=dropout,
+                            temb_channels=time_embed_dim,
+                            eps=1e-5,
+                            non_linearity="silu",
+                            time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
+                            overwrite_for_glide=True,
                             up=True,
                         )
                         if resblock_updown
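Aside (not part of the commit): the isinstance() changes in TimestepEmbedSequential only decide which layers receive the timestep embedding; ResnetBlock is now routed the same way as TimestepBlock subclasses. A minimal sketch with toy stand-in classes (none of these names exist in diffusers):

import torch.nn as nn


class ToyResnetBlock(nn.Module):
    def forward(self, x, emb):
        return x + emb


class ToyAttention(nn.Module):
    def forward(self, x, encoder_out=None):
        return x


class ToyTimestepEmbedSequential(nn.Sequential):
    def forward(self, x, emb, encoder_out=None):
        for layer in self:
            if isinstance(layer, ToyResnetBlock):
                x = layer(x, emb)          # gets the timestep embedding
            else:
                x = layer(x, encoder_out)  # attention-style layers do not
        return x


block = ToyTimestepEmbedSequential(ToyResnetBlock(), ToyAttention())
print(block(1.0, 0.5))  # 1.5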
src/diffusers/models/unet_ldm.py  (view file @ 23904d54)

@@ -10,7 +10,10 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
-from .resnet import Downsample, ResBlock, TimestepBlock, Upsample
+from .resnet import Downsample, ResnetBlock, TimestepBlock, Upsample
+
+
+# from .resnet import ResBlock


 def exists(val):
@@ -75,182 +78,6 @@ def Normalize(in_channels):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

-# class LinearAttention(nn.Module):
-#     def __init__(self, dim, heads=4, dim_head=32):
-#         super().__init__()
-#         self.heads = heads
-#         hidden_dim = dim_head * heads
-#         self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
-#         self.to_out = nn.Conv2d(hidden_dim, dim, 1)
-#
-#     def forward(self, x):
-#         b, c, h, w = x.shape
-#         qkv = self.to_qkv(x)
-#         q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
-#         import ipdb; ipdb.set_trace()
-#         k = k.softmax(dim=-1)
-#         context = torch.einsum("bhdn,bhen->bhde", k, v)
-#         out = torch.einsum("bhde,bhdn->bhen", context, q)
-#         out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
-#         return self.to_out(out)
-#
-# class SpatialSelfAttention(nn.Module):
-#     def __init__(self, in_channels):
-#         super().__init__()
-#         self.in_channels = in_channels
-#
-#         self.norm = Normalize(in_channels)
-#         self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-#         self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-#         self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-#         self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-#
-#     def forward(self, x):
-#         h_ = x
-#         h_ = self.norm(h_)
-#         q = self.q(h_)
-#         k = self.k(h_)
-#         v = self.v(h_)
-#
-#         compute attention
-#         b, c, h, w = q.shape
-#         q = rearrange(q, "b c h w -> b (h w) c")
-#         k = rearrange(k, "b c h w -> b c (h w)")
-#         w_ = torch.einsum("bij,bjk->bik", q, k)
-#
-#         w_ = w_ * (int(c) ** (-0.5))
-#         w_ = torch.nn.functional.softmax(w_, dim=2)
-#
-#         attend to values
-#         v = rearrange(v, "b c h w -> b c (h w)")
-#         w_ = rearrange(w_, "b i j -> b j i")
-#         h_ = torch.einsum("bij,bjk->bik", v, w_)
-#         h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
-#         h_ = self.proj_out(h_)
-#
-#         return x + h_
-#
-
-class CrossAttention(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
-        super().__init__()
-        inner_dim = dim_head * heads
-        context_dim = default(context_dim, query_dim)
-
-        self.scale = dim_head**-0.5
-        self.heads = heads
-
-        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
-        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
-        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
-
-        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
-
-    def reshape_heads_to_batch_dim(self, tensor):
-        batch_size, seq_len, dim = tensor.shape
-        head_size = self.heads
-        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
-        return tensor
-
-    def reshape_batch_dim_to_heads(self, tensor):
-        batch_size, seq_len, dim = tensor.shape
-        head_size = self.heads
-        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
-        return tensor
-
-    def forward(self, x, context=None, mask=None):
-        batch_size, sequence_length, dim = x.shape
-        h = self.heads
-
-        q = self.to_q(x)
-        context = default(context, x)
-        k = self.to_k(context)
-        v = self.to_v(context)
-
-        q = self.reshape_heads_to_batch_dim(q)
-        k = self.reshape_heads_to_batch_dim(k)
-        v = self.reshape_heads_to_batch_dim(v)
-
-        sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
-
-        if exists(mask):
-            mask = mask.reshape(batch_size, -1)
-            max_neg_value = -torch.finfo(sim.dtype).max
-            mask = mask[:, None, :].repeat(h, 1, 1)
-            sim.masked_fill_(~mask, max_neg_value)
-
-        # attention, what we cannot get enough of
-        attn = sim.softmax(dim=-1)
-
-        out = torch.einsum("b i j, b j d -> b i d", attn, v)
-        out = self.reshape_batch_dim_to_heads(out)
-        return self.to_out(out)
-
-
-class BasicTransformerBlock(nn.Module):
-    def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
-        super().__init__()
-        self.attn1 = CrossAttention(
-            query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
-        )  # is a self-attention
-        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
-        self.attn2 = CrossAttention(
-            query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
-        )  # is self-attn if context is none
-        self.norm1 = nn.LayerNorm(dim)
-        self.norm2 = nn.LayerNorm(dim)
-        self.norm3 = nn.LayerNorm(dim)
-        self.checkpoint = checkpoint
-
-    def forward(self, x, context=None):
-        x = self.attn1(self.norm1(x)) + x
-        x = self.attn2(self.norm2(x), context=context) + x
-        x = self.ff(self.norm3(x)) + x
-        return x
-
-
-class SpatialTransformer(nn.Module):
-    """
-    Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
-    standard transformer action. Finally, reshape to image
-    """
-
-    def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
-        super().__init__()
-        self.in_channels = in_channels
-        inner_dim = n_heads * d_head
-        self.norm = Normalize(in_channels)
-
-        self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
-
-        self.transformer_blocks = nn.ModuleList(
-            [
-                BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
-                for d in range(depth)
-            ]
-        )
-
-        self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
-
-    def forward(self, x, context=None):
-        # note: if no context is given, cross-attention defaults to self-attention
-        b, c, h, w = x.shape
-        x_in = x
-        x = self.norm(x)
-        x = self.proj_in(x)
-        x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
-        for block in self.transformer_blocks:
-            x = block(x, context=context)
-        x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
-        x = self.proj_out(x)
-        return x + x_in
-

 def convert_module_to_f16(l):
     """
     Convert primitive modules to float16.
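Aside (not part of the commit): the removed CrossAttention folds the attention heads into the batch dimension before the einsum and unfolds them afterwards. A standalone sketch of that reshaping round-trip (illustrative variable names only):

import torch

batch, seq, heads, dim_head = 2, 7, 8, 64
x = torch.randn(batch, seq, heads * dim_head)

# reshape_heads_to_batch_dim
t = x.reshape(batch, seq, heads, dim_head).permute(0, 2, 1, 3).reshape(batch * heads, seq, dim_head)
# reshape_batch_dim_to_heads
y = t.reshape(batch, heads, seq, dim_head).permute(0, 2, 1, 3).reshape(batch, seq, heads * dim_head)

print(t.shape, torch.equal(x, y))  # torch.Size([16, 7, 64]) True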
@@ -271,19 +98,6 @@ def convert_module_to_f32(l):
         l.bias.data = l.bias.data.float()


-def avg_pool_nd(dims, *args, **kwargs):
-    """
-    Create a 1D, 2D, or 3D average pooling module.
-    """
-    if dims == 1:
-        return nn.AvgPool1d(*args, **kwargs)
-    elif dims == 2:
-        return nn.AvgPool2d(*args, **kwargs)
-    elif dims == 3:
-        return nn.AvgPool3d(*args, **kwargs)
-    raise ValueError(f"unsupported dimensions: {dims}")
-
-
 def conv_nd(dims, *args, **kwargs):
     """
     Create a 1D, 2D, or 3D convolution module.
@@ -327,36 +141,6 @@ def normalization(channels, swish=0.0):
     return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)

-class AttentionPool2d(nn.Module):
-    """
-    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
-    """
-
-    def __init__(
-        self,
-        spacial_dim: int,
-        embed_dim: int,
-        num_heads_channels: int,
-        output_dim: int = None,
-    ):
-        super().__init__()
-        self.positional_embedding = nn.Parameter(torch.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
-        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
-        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
-        self.num_heads = embed_dim // num_heads_channels
-        self.attention = QKVAttention(self.num_heads)
-
-    def forward(self, x):
-        b, c, *_spatial = x.shape
-        x = x.reshape(b, c, -1)  # NC(HW)
-        x = torch.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)
-        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)
-        x = self.qkv_proj(x)
-        x = self.attention(x)
-        x = self.c_proj(x)
-        return x[:, :, 0]
-

 class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
     """
     A sequential module that passes timestep embeddings to the children that support it as an extra input.
@@ -364,7 +148,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
     def forward(self, x, emb, context=None):
         for layer in self:
-            if isinstance(layer, TimestepBlock):
+            if isinstance(layer, TimestepBlock) or isinstance(layer, ResnetBlock):
                 x = layer(x, emb)
             elif isinstance(layer, SpatialTransformer):
                 x = layer(x, context)
@@ -373,39 +157,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
         return x

-class QKVAttention(nn.Module):
-    """
-    A module which performs QKV attention and splits in a different order.
-    """
-
-    def __init__(self, n_heads):
-        super().__init__()
-        self.n_heads = n_heads
-
-    def forward(self, qkv):
-        """
-        Apply QKV attention. :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x
-        T] tensor after attention.
-        """
-        bs, width, length = qkv.shape
-        assert width % (3 * self.n_heads) == 0
-        ch = width // (3 * self.n_heads)
-        q, k, v = qkv.chunk(3, dim=1)
-        scale = 1 / math.sqrt(math.sqrt(ch))
-        weight = torch.einsum(
-            "bct,bcs->bts",
-            (q * scale).view(bs * self.n_heads, ch, length),
-            (k * scale).view(bs * self.n_heads, ch, length),
-        )  # More stable with f16 than dividing afterwards
-        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
-        a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
-        return a.reshape(bs, -1, length)
-
-    @staticmethod
-    def count_flops(model, _x, y):
-        return count_flops_attn(model, _x, y)
-

 def count_flops_attn(model, _x, y):
     """
     A counter for the `thop` package to count the operations in an attention operation. Meant to be used like:
@@ -559,14 +310,14 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
         for level, mult in enumerate(channel_mult):
             for _ in range(num_res_blocks):
                 layers = [
-                    ResBlock(
-                        ch,
-                        time_embed_dim,
-                        dropout,
-                        out_channels=mult * model_channels,
-                        dims=dims,
-                        use_checkpoint=use_checkpoint,
-                        use_scale_shift_norm=use_scale_shift_norm,
+                    ResnetBlock(
+                        in_channels=ch,
+                        out_channels=mult * model_channels,
+                        dropout=dropout,
+                        temb_channels=time_embed_dim,
+                        eps=1e-5,
+                        non_linearity="silu",
+                        overwrite_for_ldm=True,
                     )
                 ]
                 ch = mult * model_channels
@@ -599,20 +350,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                 out_ch = ch
                 self.input_blocks.append(
                     TimestepEmbedSequential(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
-                            out_channels=out_ch,
-                            dims=dims,
-                            use_checkpoint=use_checkpoint,
-                            use_scale_shift_norm=use_scale_shift_norm,
-                            down=True,
-                        )
-                        if resblock_updown
-                        else Downsample(
-                            ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
-                        )
+                        Downsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op")
                     )
                 )
                 ch = out_ch
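Aside (not part of the commit): with the resblock_updown branch dropped, downsampling in the input blocks is always the Downsample module. As a rough stand-in for its conv path, a stride-2, padding-1 convolution halves the spatial size:

import torch
import torch.nn as nn

down = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
print(down(torch.randn(1, 64, 32, 32)).shape)  # torch.Size([1, 128, 16, 16])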
@@ -629,13 +367,14 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
             # num_heads = 1
             dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
         self.middle_block = TimestepEmbedSequential(
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
+            ResnetBlock(
+                in_channels=ch,
+                out_channels=None,
+                dropout=dropout,
+                temb_channels=time_embed_dim,
+                eps=1e-5,
+                non_linearity="silu",
+                overwrite_for_ldm=True,
             ),
             AttentionBlock(
                 ch,

@@ -646,13 +385,14 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
             )
             if not use_spatial_transformer
             else SpatialTransformer(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim),
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
+            ResnetBlock(
+                in_channels=ch,
+                out_channels=None,
+                dropout=dropout,
+                temb_channels=time_embed_dim,
+                eps=1e-5,
+                non_linearity="silu",
+                overwrite_for_ldm=True,
             ),
         )
         self._feature_size += ch
@@ -662,15 +402,15 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
             for i in range(num_res_blocks + 1):
                 ich = input_block_chans.pop()
                 layers = [
-                    ResBlock(
-                        ch + ich,
-                        time_embed_dim,
-                        dropout,
-                        out_channels=model_channels * mult,
-                        dims=dims,
-                        use_checkpoint=use_checkpoint,
-                        use_scale_shift_norm=use_scale_shift_norm,
-                    )
+                    ResnetBlock(
+                        in_channels=ch + ich,
+                        out_channels=model_channels * mult,
+                        dropout=dropout,
+                        temb_channels=time_embed_dim,
+                        eps=1e-5,
+                        non_linearity="silu",
+                        overwrite_for_ldm=True,
+                    ),
                 ]
                 ch = model_channels * mult
                 if ds in attention_resolutions:
@@ -697,20 +437,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                     )
                 if level and i == num_res_blocks:
                     out_ch = ch
                     layers.append(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
-                            out_channels=out_ch,
-                            dims=dims,
-                            use_checkpoint=use_checkpoint,
-                            use_scale_shift_norm=use_scale_shift_norm,
-                            up=True,
-                        )
-                        if resblock_updown
-                        else Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch)
-                    )
+                        Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch))
                     ds //= 2
                 self.output_blocks.append(TimestepEmbedSequential(*layers))
                 self._feature_size += ch
...
@@ -777,212 +504,119 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
...
@@ -777,212 +504,119 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
return
self
.
out
(
h
)
return
self
.
out
(
h
)
 
 
-class EncoderUNetModel(nn.Module):
-    """
-    The half UNet model with attention and timestep embedding. For usage, see UNet.
-    """
-
-    def __init__(
-        self, image_size, in_channels, model_channels, out_channels, num_res_blocks, attention_resolutions,
-        dropout=0, channel_mult=(1, 2, 4, 8), conv_resample=True, dims=2, use_checkpoint=False, use_fp16=False,
-        num_heads=1, num_head_channels=-1, num_heads_upsample=-1, use_scale_shift_norm=False,
-        resblock_updown=False, use_new_attention_order=False, pool="adaptive", *args, **kwargs,
-    ):
-        super().__init__()
-
-        if num_heads_upsample == -1:
-            num_heads_upsample = num_heads
-
-        self.in_channels = in_channels
-        self.model_channels = model_channels
-        self.out_channels = out_channels
-        self.num_res_blocks = num_res_blocks
-        self.attention_resolutions = attention_resolutions
-        self.dropout = dropout
-        self.channel_mult = channel_mult
-        self.conv_resample = conv_resample
-        self.use_checkpoint = use_checkpoint
-        self.dtype = torch.float16 if use_fp16 else torch.float32
-        self.num_heads = num_heads
-        self.num_head_channels = num_head_channels
-        self.num_heads_upsample = num_heads_upsample
-
-        time_embed_dim = model_channels * 4
-        self.time_embed = nn.Sequential(
-            linear(model_channels, time_embed_dim),
-            nn.SiLU(),
-            linear(time_embed_dim, time_embed_dim),
-        )
-
-        self.input_blocks = nn.ModuleList(
-            [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
-        )
-        self._feature_size = model_channels
-        input_block_chans = [model_channels]
-        ch = model_channels
-        ds = 1
-        for level, mult in enumerate(channel_mult):
-            for _ in range(num_res_blocks):
-                layers = [
-                    ResBlock(
-                        ch, time_embed_dim, dropout, out_channels=mult * model_channels, dims=dims,
-                        use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm,
-                    )
-                ]
-                ch = mult * model_channels
-                if ds in attention_resolutions:
-                    layers.append(
-                        AttentionBlock(
-                            ch, use_checkpoint=use_checkpoint, num_heads=num_heads,
-                            num_head_channels=num_head_channels, use_new_attention_order=use_new_attention_order,
-                        )
-                    )
-                self.input_blocks.append(TimestepEmbedSequential(*layers))
-                self._feature_size += ch
-                input_block_chans.append(ch)
-            if level != len(channel_mult) - 1:
-                out_ch = ch
-                self.input_blocks.append(
-                    TimestepEmbedSequential(
-                        ResBlock(
-                            ch, time_embed_dim, dropout, out_channels=out_ch, dims=dims,
-                            use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, down=True,
-                        )
-                        if resblock_updown
-                        else Downsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op")
-                    )
-                )
-                ch = out_ch
-                input_block_chans.append(ch)
-                ds *= 2
-                self._feature_size += ch
-
-        self.middle_block = TimestepEmbedSequential(
-            ResBlock(ch, time_embed_dim, dropout, dims=dims, use_checkpoint=use_checkpoint,
-                     use_scale_shift_norm=use_scale_shift_norm),
-            AttentionBlock(ch, use_checkpoint=use_checkpoint, num_heads=num_heads,
-                           num_head_channels=num_head_channels, use_new_attention_order=use_new_attention_order),
-            ResBlock(ch, time_embed_dim, dropout, dims=dims, use_checkpoint=use_checkpoint,
-                     use_scale_shift_norm=use_scale_shift_norm),
-        )
-        self._feature_size += ch
-        self.pool = pool
-        if pool == "adaptive":
-            self.out = nn.Sequential(
-                normalization(ch),
-                nn.SiLU(),
-                nn.AdaptiveAvgPool2d((1, 1)),
-                zero_module(conv_nd(dims, ch, out_channels, 1)),
-                nn.Flatten(),
-            )
-        elif pool == "attention":
-            assert num_head_channels != -1
-            self.out = nn.Sequential(
-                normalization(ch),
-                nn.SiLU(),
-                AttentionPool2d((image_size // ds), ch, num_head_channels, out_channels),
-            )
-        elif pool == "spatial":
-            self.out = nn.Sequential(nn.Linear(self._feature_size, 2048), nn.ReLU(), nn.Linear(2048, self.out_channels))
-        elif pool == "spatial_v2":
-            self.out = nn.Sequential(
-                nn.Linear(self._feature_size, 2048),
-                normalization(2048),
-                nn.SiLU(),
-                nn.Linear(2048, self.out_channels),
-            )
-        else:
-            raise NotImplementedError(f"Unexpected {pool} pooling")
-
-    def convert_to_fp16(self):
-        """
-        Convert the torso of the model to float16.
-        """
-        self.input_blocks.apply(convert_module_to_f16)
-        self.middle_block.apply(convert_module_to_f16)
-
-    def convert_to_fp32(self):
-        """
-        Convert the torso of the model to float32.
-        """
-        self.input_blocks.apply(convert_module_to_f32)
-        self.middle_block.apply(convert_module_to_f32)
-
-    def forward(self, x, timesteps):
-        """
-        Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. :param timesteps: a 1-D batch
-        of timesteps. :return: an [N x K] Tensor of outputs.
-        """
-        emb = self.time_embed(
-            get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
-        )
-
-        results = []
-        h = x.type(self.dtype)
-        for module in self.input_blocks:
-            h = module(h, emb)
-            if self.pool.startswith("spatial"):
-                results.append(h.type(x.dtype).mean(dim=(2, 3)))
-        h = self.middle_block(h, emb)
-        if self.pool.startswith("spatial"):
-            results.append(h.type(x.dtype).mean(dim=(2, 3)))
-            h = torch.cat(results, axis=-1)
-            return self.out(h)
-        else:
-            h = h.type(x.dtype)
-            return self.out(h)
+class SpatialTransformer(nn.Module):
+    """
+    Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
+    standard transformer action. Finally, reshape to image
+    """
+
+    def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
+        super().__init__()
+        self.in_channels = in_channels
+        inner_dim = n_heads * d_head
+        self.norm = Normalize(in_channels)
+
+        self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
+                for d in range(depth)
+            ]
+        )
+
+        self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
+
+    def forward(self, x, context=None):
+        # note: if no context is given, cross-attention defaults to self-attention
+        b, c, h, w = x.shape
+        x_in = x
+        x = self.norm(x)
+        x = self.proj_in(x)
+        x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
+        for block in self.transformer_blocks:
+            x = block(x, context=context)
+        x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
+        x = self.proj_out(x)
+        return x + x_in
+
+
+class BasicTransformerBlock(nn.Module):
+    def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
+        super().__init__()
+        self.attn1 = CrossAttention(
+            query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
+        )  # is a self-attention
+        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+        self.attn2 = CrossAttention(
+            query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
+        )  # is self-attn if context is none
+        self.norm1 = nn.LayerNorm(dim)
+        self.norm2 = nn.LayerNorm(dim)
+        self.norm3 = nn.LayerNorm(dim)
+        self.checkpoint = checkpoint
+
+    def forward(self, x, context=None):
+        x = self.attn1(self.norm1(x)) + x
+        x = self.attn2(self.norm2(x), context=context) + x
+        x = self.ff(self.norm3(x)) + x
+        return x
+
+
+class CrossAttention(nn.Module):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+        super().__init__()
+        inner_dim = dim_head * heads
+        context_dim = default(context_dim, query_dim)
+
+        self.scale = dim_head**-0.5
+        self.heads = heads
+
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+
+    def reshape_heads_to_batch_dim(self, tensor):
+        batch_size, seq_len, dim = tensor.shape
+        head_size = self.heads
+        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        return tensor
+
+    def reshape_batch_dim_to_heads(self, tensor):
+        batch_size, seq_len, dim = tensor.shape
+        head_size = self.heads
+        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+        return tensor
+
+    def forward(self, x, context=None, mask=None):
+        batch_size, sequence_length, dim = x.shape
+        h = self.heads
+
+        q = self.to_q(x)
+        context = default(context, x)
+        k = self.to_k(context)
+        v = self.to_v(context)
+
+        q = self.reshape_heads_to_batch_dim(q)
+        k = self.reshape_heads_to_batch_dim(k)
+        v = self.reshape_heads_to_batch_dim(v)
+
+        sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
+
+        if exists(mask):
+            mask = mask.reshape(batch_size, -1)
+            max_neg_value = -torch.finfo(sim.dtype).max
+            mask = mask[:, None, :].repeat(h, 1, 1)
+            sim.masked_fill_(~mask, max_neg_value)
+
+        # attention, what we cannot get enough of
+        attn = sim.softmax(dim=-1)
+
+        out = torch.einsum("b i j, b j d -> b i d", attn, v)
+        out = self.reshape_batch_dim_to_heads(out)
+        return self.to_out(out)
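Note: CrossAttention above works on flattened spatial tokens of shape (batch, seq_len, dim); SpatialTransformer handles the reshape from image layout and back. A self-contained sketch of the attention arithmetic with made-up sizes (plain PyTorch, not the library classes):

    import torch

    batch, seq_len, heads, dim_head = 2, 16, 4, 8
    q = torch.randn(batch * heads, seq_len, dim_head)
    k = torch.randn(batch * heads, seq_len, dim_head)
    v = torch.randn(batch * heads, seq_len, dim_head)

    sim = torch.einsum("b i d, b j d -> b i j", q, k) * dim_head**-0.5  # scaled dot-product scores
    attn = sim.softmax(dim=-1)
    out = torch.einsum("b i j, b j d -> b i d", attn, v)
    print(out.shape)  # torch.Size([8, 16, 8]); the module folds heads back to (2, 16, 32)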
src/diffusers/models/unet_rl.py
View file @ 23904d54
...
@@ -6,7 +6,7 @@ import torch.nn as nn
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .embeddings import get_timestep_embedding
-from .resnet import ResidualTemporalBlock
+from .resnet import Downsample, ResidualTemporalBlock, Upsample
 
 
 class SinusoidalPosEmb(nn.Module):
...
@@ -18,24 +18,6 @@ class SinusoidalPosEmb(nn.Module):
         return get_timestep_embedding(x, self.dim)
 
 
-class Downsample1d(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
-
-    def forward(self, x):
-        return self.conv(x)
-
-
-class Upsample1d(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
-
-    def forward(self, x):
-        return self.conv(x)
-
-
 class RearrangeDim(nn.Module):
     def __init__(self):
         super().__init__()
...
@@ -114,7 +96,7 @@ class TemporalUNet(ModelMixin, ConfigMixin):  # (nn.Module):
                 [
                     ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=training_horizon),
                     ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=training_horizon),
-                    Downsample1d(dim_out) if not is_last else nn.Identity(),
+                    Downsample(dim_out, use_conv=True, dims=1) if not is_last else nn.Identity(),
                 ]
             )
         )
...
@@ -134,7 +116,7 @@ class TemporalUNet(ModelMixin, ConfigMixin):  # (nn.Module):
                 [
                     ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=training_horizon),
                     ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=training_horizon),
-                    Upsample1d(dim_in) if not is_last else nn.Identity(),
+                    Upsample(dim_in, use_conv_transpose=True, dims=1) if not is_last else nn.Identity(),
                 ]
             )
         )
...
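Note: the temporal UNet now reuses the shared Downsample/Upsample blocks with dims=1, which mirror the deleted Downsample1d/Upsample1d (a stride-2 Conv1d and a stride-2 ConvTranspose1d). A small shape sketch, assuming resnet.Downsample keeps its default padding of 1:

    import torch
    from diffusers.models.resnet import Downsample, Upsample

    x = torch.randn(1, 32, 64)  # (batch, channels, horizon)
    down = Downsample(32, use_conv=True, dims=1)        # Conv1d(32, 32, 3, stride=2, padding=1)
    up = Upsample(32, use_conv_transpose=True, dims=1)  # ConvTranspose1d(32, 32, 4, stride=2, padding=1)
    print(down(x).shape)       # torch.Size([1, 32, 32])
    print(up(down(x)).shape)   # torch.Size([1, 32, 64])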
src/diffusers/models/unet_sde_score_estimation.py
View file @ 23904d54
...
@@ -17,7 +17,6 @@
 import functools
 import math
-import string
 
 import numpy as np
 import torch
...
@@ -28,116 +27,21 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import GaussianFourierProjection, get_timestep_embedding
-from .resnet import ResnetBlockBigGANpp, ResnetBlockDDPMpp
+from .resnet import Downsample, ResnetBlockBigGANpp, Upsample, downsample_2d, upfirdn2d, upsample_2d
 
 
-def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
-    return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
-
-
-def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
-    _, channel, in_h, in_w = input.shape
-    input = input.reshape(-1, in_h, in_w, 1)
-
-    _, in_h, in_w, minor = input.shape
-    kernel_h, kernel_w = kernel.shape
-
-    out = input.view(-1, in_h, 1, in_w, 1, minor)
-    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
-    out = out.view(-1, in_h * up_y, in_w * up_x, minor)
-
-    out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
-    out = out[
-        :,
-        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
-        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
-        :,
-    ]
-
-    out = out.permute(0, 3, 1, 2)
-    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
-    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
-    out = F.conv2d(out, w)
-    out = out.reshape(
-        -1,
-        minor,
-        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
-        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
-    )
-    out = out.permute(0, 2, 3, 1)
-    out = out[:, ::down_y, ::down_x, :]
-
-    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
-    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
-
-    return out.view(-1, channel, out_h, out_w)
-
-
-# Function ported from StyleGAN2
-def get_weight(module, shape, weight_var="weight", kernel_init=None):
-    """Get/create weight tensor for a convolution or fully-connected layer."""
-    return module.param(weight_var, kernel_init, shape)
-
-
-class Conv2d(nn.Module):
-    """Conv2d layer with optimal upsampling and downsampling (StyleGAN2)."""
-
-    def __init__(
-        self, in_ch, out_ch, kernel, up=False, down=False, resample_kernel=(1, 3, 3, 1), use_bias=True, kernel_init=None,
-    ):
-        super().__init__()
-        assert not (up and down)
-        assert kernel >= 1 and kernel % 2 == 1
-        self.weight = nn.Parameter(torch.zeros(out_ch, in_ch, kernel, kernel))
-        if kernel_init is not None:
-            self.weight.data = kernel_init(self.weight.data.shape)
-        if use_bias:
-            self.bias = nn.Parameter(torch.zeros(out_ch))
-        self.up = up
-        self.down = down
-        self.resample_kernel = resample_kernel
-        self.kernel = kernel
-        self.use_bias = use_bias
-
-    def forward(self, x):
-        if self.up:
-            x = upsample_conv_2d(x, self.weight, k=self.resample_kernel)
-        elif self.down:
-            x = conv_downsample_2d(x, self.weight, k=self.resample_kernel)
-        else:
-            x = F.conv2d(x, self.weight, stride=1, padding=self.kernel // 2)
-
-        if self.use_bias:
-            x = x + self.bias.reshape(1, -1, 1, 1)
-
-        return x
-
-
-def naive_upsample_2d(x, factor=2):
-    _N, C, H, W = x.shape
-    x = torch.reshape(x, (-1, C, H, 1, W, 1))
-    x = x.repeat(1, 1, 1, factor, 1, factor)
-    return torch.reshape(x, (-1, C, H * factor, W * factor))
-
-
-def naive_downsample_2d(x, factor=2):
-    _N, C, H, W = x.shape
-    x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor))
-    return torch.mean(x, dim=(3, 5))
+def _setup_kernel(k):
+    k = np.asarray(k, dtype=np.float32)
+    if k.ndim == 1:
+        k = np.outer(k, k)
+    k /= np.sum(k)
+    assert k.ndim == 2
+    assert k.shape[0] == k.shape[1]
+    return k
 
 
-def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
-    """Fused `upsample_2d()` followed by `tf.nn.conv2d()`.
+def _upsample_conv_2d(x, w, k=None, factor=2, gain=1):
+    """Fused `upsample_2d()` followed by `Conv2d()`.
 
     Args:
     Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
...
@@ -176,13 +80,13 @@ def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
     # Determine data dimensions.
     stride = [1, 1, factor, factor]
-    output_shape = ((_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW)
+    output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
     output_padding = (
-        output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convH,
-        output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convW,
+        output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
+        output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
     )
     assert output_padding[0] >= 0 and output_padding[1] >= 0
-    num_groups = _shape(x, 1) // inC
+    num_groups = x.shape[1] // inC
 
     # Transpose weights.
     w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
...
@@ -190,21 +94,12 @@ def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
     w = torch.reshape(w, (num_groups * inC, -1, convH, convW))
     x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)
-    # Original TF code.
-    # x = tf.nn.conv2d_transpose(
-    #     x,
-    #     w,
-    #     output_shape=output_shape,
-    #     strides=stride,
-    #     padding='VALID',
-    #     data_format=data_format)
-    # JAX equivalent
 
     return upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
 
 
-def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
-    """Fused `tf.nn.conv2d()` followed by `downsample_2d()`.
+def _conv_downsample_2d(x, w, k=None, factor=2, gain=1):
+    """Fused `Conv2d()` followed by `downsample_2d()`.
 
     Args:
     Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
...
@@ -235,138 +130,9 @@ def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
     return F.conv2d(x, w, stride=s, padding=0)
 
 
-def _setup_kernel(k):
-    k = np.asarray(k, dtype=np.float32)
-    if k.ndim == 1:
-        k = np.outer(k, k)
-    k /= np.sum(k)
-    assert k.ndim == 2
-    assert k.shape[0] == k.shape[1]
-    return k
-
-
-def _shape(x, dim):
-    return x.shape[dim]
-
-
-def upsample_2d(x, k=None, factor=2, gain=1):
-    r"""Upsample a batch of 2D images with the given filter.
-
-    Args:
-    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
-    filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
-    `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
-    a: multiple of the upsampling factor.
-        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
-        k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
-        corresponds to nearest-neighbor upsampling.
-        factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
-
-    Returns:
-        Tensor of the shape `[N, C, H * factor, W * factor]`
-    """
-    assert isinstance(factor, int) and factor >= 1
-    if k is None:
-        k = [1] * factor
-    k = _setup_kernel(k) * (gain * (factor**2))
-    p = k.shape[0] - factor
-    return upfirdn2d(x, torch.tensor(k, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
-
-
-def downsample_2d(x, k=None, factor=2, gain=1):
-    r"""Downsample a batch of 2D images with the given filter.
-
-    Args:
-    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
-    given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
-    specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
-    shape is a multiple of the downsampling factor.
-        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
-        k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
-        corresponds to average pooling.
-        factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
-
-    Returns:
-        Tensor of the shape `[N, C, H // factor, W // factor]`
-    """
-    assert isinstance(factor, int) and factor >= 1
-    if k is None:
-        k = [1] * factor
-    k = _setup_kernel(k) * gain
-    p = k.shape[0] - factor
-    return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
-
-
-def conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):
-    """1x1 convolution with DDPM initialization."""
-    conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
-    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
-    nn.init.zeros_(conv.bias)
-    return conv
-
-
-def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):
-    """3x3 convolution with DDPM initialization."""
-    conv = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias)
-    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
-    nn.init.zeros_(conv.bias)
-    return conv
-
-
-def _einsum(a, b, c, x, y):
-    einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
-    return torch.einsum(einsum_str, x, y)
-
-
-def contract_inner(x, y):
-    """tensordot(x, y, 1)."""
-    x_chars = list(string.ascii_lowercase[: len(x.shape)])
-    y_chars = list(string.ascii_lowercase[len(x.shape) : len(y.shape) + len(x.shape)])
-    y_chars[0] = x_chars[-1]  # first axis of y and last of x get summed
-    out_chars = x_chars[:-1] + y_chars[1:]
-    return _einsum(x_chars, y_chars, out_chars, x, y)
-
-
-class NIN(nn.Module):
-    def __init__(self, in_dim, num_units, init_scale=0.1):
-        super().__init__()
-        self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)
-        self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
-
-    def forward(self, x):
-        x = x.permute(0, 2, 3, 1)
-        y = contract_inner(x, self.W) + self.b
-        return y.permute(0, 3, 1, 2)
-
-
-def get_act(nonlinearity):
-    """Get activation functions from the config file."""
-    if nonlinearity.lower() == "elu":
-        return nn.ELU()
-    elif nonlinearity.lower() == "relu":
-        return nn.ReLU()
-    elif nonlinearity.lower() == "lrelu":
-        return nn.LeakyReLU(negative_slope=0.2)
-    elif nonlinearity.lower() == "swish":
-        return nn.SiLU()
-    else:
-        raise NotImplementedError("activation function does not exist!")
-
-
-def default_init(scale=1.0):
-    """The same initialization used in DDPM."""
-    scale = 1e-10 if scale == 0 else scale
-    return variance_scaling(scale, "fan_avg", "uniform")
-
-
-def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
+def _variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
     """Ported from JAX."""
+    scale = 1e-10 if scale == 0 else scale
 
     def _compute_fans(shape, in_axis=1, out_axis=0):
         receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
...
@@ -376,31 +142,35 @@ def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=tor
     def init(shape, dtype=dtype, device=device):
         fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
-        if mode == "fan_in":
-            denominator = fan_in
-        elif mode == "fan_out":
-            denominator = fan_out
-        elif mode == "fan_avg":
-            denominator = (fan_in + fan_out) / 2
-        else:
-            raise ValueError("invalid mode for variance scaling initializer: {}".format(mode))
+        denominator = (fan_in + fan_out) / 2
         variance = scale / denominator
-        if distribution == "normal":
-            return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
-        elif distribution == "uniform":
-            return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
-        else:
-            raise ValueError("invalid distribution for variance scaling initializer")
+        return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
 
     return init
 
 
+def Conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
+    """nXn convolution with DDPM initialization."""
+    conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
+    conv.weight.data = _variance_scaling(init_scale)(conv.weight.data.shape)
+    nn.init.zeros_(conv.bias)
+    return conv
+
+
+def Linear(dim_in, dim_out):
+    linear = nn.Linear(dim_in, dim_out)
+    linear.weight.data = _variance_scaling()(linear.weight.shape)
+    nn.init.zeros_(linear.bias)
+    return linear
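Note: Conv2d and Linear above are thin factories around torch.nn layers whose weights are re-initialized with the fan-average uniform scheme from _variance_scaling and whose biases are zeroed. A minimal re-implementation of that initialization in plain PyTorch (the helper name ddpm_init_ is made up for this sketch):

    import numpy as np
    import torch
    import torch.nn as nn

    def ddpm_init_(tensor, scale=1.0):
        # fan-average, uniform variance scaling, mirroring _variance_scaling above
        scale = 1e-10 if scale == 0 else scale
        shape = tensor.shape
        receptive_field_size = tensor.numel() / shape[0] / shape[1]
        fan_in = shape[1] * receptive_field_size
        fan_out = shape[0] * receptive_field_size
        bound = float(np.sqrt(3 * scale / ((fan_in + fan_out) / 2)))
        with torch.no_grad():
            return tensor.uniform_(-bound, bound)

    conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
    ddpm_init_(conv.weight, scale=1.0)
    nn.init.zeros_(conv.bias)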
 class Combine(nn.Module):
     """Combine information from skip connections."""
 
     def __init__(self, dim1, dim2, method="cat"):
         super().__init__()
-        self.Conv_0 = conv1x1(dim1, dim2)
+        # 1x1 convolution with DDPM initialization.
+        self.Conv_0 = Conv2d(dim1, dim2, kernel_size=1, padding=0)
         self.method = method
 
     def forward(self, x, y):
...
@@ -413,80 +183,42 @@ class Combine(nn.Module):
             raise ValueError(f"Method {self.method} not recognized.")
 
 
-class Upsample(nn.Module):
-    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
+class FirUpsample(nn.Module):
+    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
         super().__init__()
-        out_ch = out_ch if out_ch else in_ch
-        if not fir:
-            if with_conv:
-                self.Conv_0 = conv3x3(in_ch, out_ch)
-        else:
-            if with_conv:
-                self.Conv2d_0 = Conv2d(
-                    in_ch, out_ch, kernel=3, up=True, resample_kernel=fir_kernel, use_bias=True, kernel_init=default_init(),
-                )
-        self.fir = fir
-        self.with_conv = with_conv
+        out_channels = out_channels if out_channels else channels
+        if use_conv:
+            self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.use_conv = use_conv
         self.fir_kernel = fir_kernel
-        self.out_ch = out_ch
+        self.out_channels = out_channels
 
     def forward(self, x):
-        B, C, H, W = x.shape
-        if not self.fir:
-            h = F.interpolate(x, (H * 2, W * 2), "nearest")
-            if self.with_conv:
-                h = self.Conv_0(h)
-        else:
-            if not self.with_conv:
-                h = upsample_2d(x, self.fir_kernel, factor=2)
-            else:
-                h = self.Conv2d_0(x)
+        if self.use_conv:
+            h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
+            h = h + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
+        else:
+            h = upsample_2d(x, self.fir_kernel, factor=2)
 
         return h
 
 
-class Downsample(nn.Module):
-    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
+class FirDownsample(nn.Module):
+    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
         super().__init__()
-        out_ch = out_ch if out_ch else in_ch
-        if not fir:
-            if with_conv:
-                self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0)
-        else:
-            if with_conv:
-                self.Conv2d_0 = Conv2d(
-                    in_ch, out_ch, kernel=3, down=True, resample_kernel=fir_kernel, use_bias=True, kernel_init=default_init(),
-                )
-        self.fir = fir
-        self.fir_kernel = fir_kernel
-        self.with_conv = with_conv
-        self.out_ch = out_ch
+        out_channels = out_channels if out_channels else channels
+        if use_conv:
+            self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.fir_kernel = fir_kernel
+        self.use_conv = use_conv
+        self.out_channels = out_channels
 
     def forward(self, x):
-        B, C, H, W = x.shape
-        if not self.fir:
-            if self.with_conv:
-                x = F.pad(x, (0, 1, 0, 1))
-                x = self.Conv_0(x)
-            else:
-                x = F.avg_pool2d(x, 2, stride=2)
-        else:
-            if not self.with_conv:
-                x = downsample_2d(x, self.fir_kernel, factor=2)
-            else:
-                x = self.Conv2d_0(x)
+        if self.use_conv:
+            x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
+            x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
+        else:
+            x = downsample_2d(x, self.fir_kernel, factor=2)
 
         return x
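Note: FirUpsample/FirDownsample keep the StyleGAN2-style FIR resampling path, while the plain Upsample/Downsample blocks cover the non-FIR case. The no-conv branch reduces to the upsample_2d/downsample_2d helpers that the new import line pulls in from resnet.py, as in this rough sketch (assuming those helpers are exported there as the import indicates):

    import torch
    from diffusers.models.resnet import downsample_2d, upsample_2d

    x = torch.randn(1, 8, 16, 16)
    fir_kernel = (1, 3, 3, 1)
    y = upsample_2d(x, fir_kernel, factor=2)     # torch.Size([1, 8, 32, 32])
    z = downsample_2d(y, fir_kernel, factor=2)   # torch.Size([1, 8, 16, 16])
    print(y.shape, z.shape)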
...
@@ -496,10 +228,9 @@ class NCSNpp(ModelMixin, ConfigMixin):
     def __init__(
         self,
+        centered=False,
         image_size=1024,
         num_channels=3,
-        attention_type="ddpm",
-        centered=False,
         attn_resolutions=(16,),
         ch_mult=(1, 2, 4, 8, 16, 32, 32, 32),
         conditional=True,
...
@@ -511,24 +242,20 @@ class NCSNpp(ModelMixin, ConfigMixin):
         fourier_scale=16,
         init_scale=0.0,
         nf=16,
-        nonlinearity="swish",
-        normalization="GroupNorm",
         num_res_blocks=1,
         progressive="output_skip",
         progressive_combine="sum",
         progressive_input="input_skip",
         resamp_with_conv=True,
-        resblock_type="biggan",
         scale_by_sigma=True,
         skip_rescale=True,
         continuous=True,
     ):
         super().__init__()
         self.register_to_config(
+            centered=centered,
             image_size=image_size,
             num_channels=num_channels,
-            attention_type=attention_type,
-            centered=centered,
             attn_resolutions=attn_resolutions,
             ch_mult=ch_mult,
             conditional=conditional,
...
@@ -540,19 +267,16 @@ class NCSNpp(ModelMixin, ConfigMixin):
             fourier_scale=fourier_scale,
             init_scale=init_scale,
             nf=nf,
-            nonlinearity=nonlinearity,
-            normalization=normalization,
             num_res_blocks=num_res_blocks,
             progressive=progressive,
             progressive_combine=progressive_combine,
             progressive_input=progressive_input,
             resamp_with_conv=resamp_with_conv,
-            resblock_type=resblock_type,
             scale_by_sigma=scale_by_sigma,
             skip_rescale=skip_rescale,
             continuous=continuous,
         )
-        self.act = act = get_act(nonlinearity)
+        self.act = act = nn.SiLU()
         self.nf = nf
         self.num_res_blocks = num_res_blocks
...
@@ -562,7 +286,6 @@ class NCSNpp(ModelMixin, ConfigMixin):
         self.conditional = conditional
         self.skip_rescale = skip_rescale
-        self.resblock_type = resblock_type
         self.progressive = progressive
         self.progressive_input = progressive_input
         self.embedding_type = embedding_type
...
@@ -585,40 +308,31 @@ class NCSNpp(ModelMixin, ConfigMixin):
         else:
             raise ValueError(f"embedding type {embedding_type} unknown.")
 
         if conditional:
-            modules.append(nn.Linear(embed_dim, nf * 4))
-            modules[-1].weight.data = default_init()(modules[-1].weight.shape)
-            nn.init.zeros_(modules[-1].bias)
-            modules.append(nn.Linear(nf * 4, nf * 4))
-            modules[-1].weight.data = default_init()(modules[-1].weight.shape)
-            nn.init.zeros_(modules[-1].bias)
+            modules.append(Linear(embed_dim, nf * 4))
+            modules.append(Linear(nf * 4, nf * 4))
 
         AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
 
-        Up_sample = functools.partial(Upsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
+        if self.fir:
+            Up_sample = functools.partial(FirUpsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
+        else:
+            Up_sample = functools.partial(Upsample, name="Conv2d_0")
 
         if progressive == "output_skip":
-            self.pyramid_upsample = Up_sample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
+            self.pyramid_upsample = Up_sample(channels=None, use_conv=False)
         elif progressive == "residual":
-            pyramid_upsample = functools.partial(Up_sample, fir=fir, fir_kernel=fir_kernel, with_conv=True)
+            pyramid_upsample = functools.partial(Up_sample, use_conv=True)
 
-        Down_sample = functools.partial(Downsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
+        if self.fir:
+            Down_sample = functools.partial(FirDownsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
+        else:
+            Down_sample = functools.partial(Downsample, padding=0, name="Conv2d_0")
 
         if progressive_input == "input_skip":
-            self.pyramid_downsample = Down_sample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
+            self.pyramid_downsample = Down_sample(channels=None, use_conv=False)
         elif progressive_input == "residual":
-            pyramid_downsample = functools.partial(Down_sample, fir=fir, fir_kernel=fir_kernel, with_conv=True)
+            pyramid_downsample = functools.partial(Down_sample, use_conv=True)
 
-        if resblock_type == "ddpm":
-            ResnetBlock = functools.partial(
-                ResnetBlockDDPMpp, act=act, dropout=dropout, init_scale=init_scale, skip_rescale=skip_rescale, temb_dim=nf * 4,
-            )
-        elif resblock_type == "biggan":
-            ResnetBlock = functools.partial(
-                ResnetBlockBigGANpp,
-                act=act,
+        ResnetBlock = functools.partial(
+            ResnetBlockBigGANpp,
+            act=act,
...
@@ -630,16 +344,13 @@ class NCSNpp(ModelMixin, ConfigMixin):
             temb_dim=nf * 4,
         )
-        else:
-            raise ValueError(f"resblock type {resblock_type} unrecognized.")
 
         # Downsampling block
         channels = num_channels
         if progressive_input != "none":
             input_pyramid_ch = channels
 
-        modules.append(conv3x3(channels, nf))
+        modules.append(Conv2d(channels, nf, kernel_size=3, padding=1))
         hs_c = [nf]
 
         in_ch = nf
...
@@ -655,9 +366,6 @@ class NCSNpp(ModelMixin, ConfigMixin):
                 hs_c.append(in_ch)
 
             if i_level != self.num_resolutions - 1:
-                if resblock_type == "ddpm":
-                    modules.append(Downsample(in_ch=in_ch))
-                else:
-                    modules.append(ResnetBlock(down=True, in_ch=in_ch))
+                modules.append(ResnetBlock(down=True, in_ch=in_ch))
 
                 if progressive_input == "input_skip":
...
@@ -666,7 +374,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
                     in_ch *= 2
 
                 elif progressive_input == "residual":
-                    modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch))
+                    modules.append(pyramid_downsample(channels=input_pyramid_ch, out_channels=in_ch))
                     input_pyramid_ch = in_ch
 
                 hs_c.append(in_ch)
...
@@ -691,36 +399,35 @@ class NCSNpp(ModelMixin, ConfigMixin):
             if i_level == self.num_resolutions - 1:
                 if progressive == "output_skip":
                     modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
-                    modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
+                    modules.append(Conv2d(in_ch, channels, init_scale=init_scale, kernel_size=3, padding=1))
                     pyramid_ch = channels
                 elif progressive == "residual":
                     modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
-                    modules.append(conv3x3(in_ch, in_ch, bias=True))
+                    modules.append(Conv2d(in_ch, in_ch, bias=True, kernel_size=3, padding=1))
                     pyramid_ch = in_ch
                 else:
                     raise ValueError(f"{progressive} is not a valid name.")
             else:
                 if progressive == "output_skip":
                     modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
-                    modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale))
+                    modules.append(Conv2d(in_ch, channels, bias=True, init_scale=init_scale, kernel_size=3, padding=1))
                     pyramid_ch = channels
                 elif progressive == "residual":
-                    modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
+                    modules.append(pyramid_upsample(channels=pyramid_ch, out_channels=in_ch))
                     pyramid_ch = in_ch
                 else:
                     raise ValueError(f"{progressive} is not a valid name")
 
             if i_level != 0:
-                if resblock_type == "ddpm":
-                    modules.append(Upsample(in_ch=in_ch))
-                else:
-                    modules.append(ResnetBlock(in_ch=in_ch, up=True))
+                modules.append(ResnetBlock(in_ch=in_ch, up=True))
 
         assert not hs_c
 
         if progressive != "output_skip":
             modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
-            modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
+            modules.append(Conv2d(in_ch, channels, init_scale=init_scale))
 
         self.all_modules = nn.ModuleList(modules)
...
@@ -751,8 +458,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
         else:
             temb = None
 
-        if not self.config.centered:
-            # If input data is in [0, 1]
+        # If input data is in [0, 1]
+        if not self.config.centered:
             x = 2 * x - 1.0
 
         # Downsampling block
...
@@ -774,10 +481,6 @@ class NCSNpp(ModelMixin, ConfigMixin):
                 hs.append(h)
             if i_level != self.num_resolutions - 1:
-                if self.resblock_type == "ddpm":
-                    h = modules[m_idx](hs[-1])
-                    m_idx += 1
-                else:
-                    h = modules[m_idx](hs[-1], temb)
+                h = modules[m_idx](hs[-1], temb)
                 m_idx += 1
...
@@ -851,10 +554,6 @@ class NCSNpp(ModelMixin, ConfigMixin):
                 raise ValueError(f"{self.progressive} is not a valid name")
 
             if i_level != 0:
-                if self.resblock_type == "ddpm":
-                    h = modules[m_idx](h)
-                    m_idx += 1
-                else:
-                    h = modules[m_idx](h, temb)
+                h = modules[m_idx](h, temb)
                 m_idx += 1
...
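Note: the centered flag only controls whether [0, 1] inputs are rescaled to [-1, 1] before the first convolution, i.e. the x = 2 * x - 1.0 line kept above. A one-line sketch of that mapping:

    import torch

    x = torch.rand(1, 3, 32, 32)  # data in [0, 1] (centered=False)
    x = 2 * x - 1.0               # now in [-1, 1]
    print(x.min().item() >= -1.0, x.max().item() <= 1.0)  # True True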
tests/test_modeling_utils.py
View file @ 23904d54
...
@@ -259,7 +259,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
         # fmt: off
         expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053])
         # fmt: on
-        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
+        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
 
 
 class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):
...
@@ -607,7 +607,7 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
         expected_output_slice = torch.tensor([-0.0690, -0.0531, 0.0633, -0.0660, -0.0541, 0.0650, -0.0656, -0.0555, 0.0617])
         # fmt: on
-        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
+        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
 
 
 class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -678,7 +678,7 @@ class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase):
         expected_output_slice = torch.tensor([-0.2714, 0.1042, -0.0794, -0.2820, 0.0803, -0.0811, -0.2345, 0.0580, -0.0584])
         # fmt: on
-        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
+        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
 
 
 class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -742,18 +742,18 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
         num_channels = 3
         sizes = (32, 32)
 
-        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
-        time_step = torch.tensor(batch_size * [10]).to(torch_device)
+        noise = torch.ones((batch_size, num_channels) + sizes).to(torch_device)
+        time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)
 
         with torch.no_grad():
             output = model(noise, time_step)
 
         output_slice = output[0, -3:, -3:, -1].flatten().cpu()
         # fmt: off
-        expected_output_slice = torch.tensor([3.1909e-07, -8.5393e-08, 4.8460e-07, -4.5550e-07, -1.3205e-06, -6.3475e-07, 9.7837e-07, 2.9974e-07, 1.2345e-06])
+        expected_output_slice = torch.tensor([0.1315, 0.0741, 0.0393, 0.0455, 0.0556, 0.0180, -0.0832, -0.0644, -0.0856])
         # fmt: on
-        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
+        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
 
     def test_output_pretrained_ve_large(self):
         model = NCSNpp.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy")
...
@@ -768,21 +768,21 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
         num_channels = 3
         sizes = (32, 32)
 
-        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
-        time_step = torch.tensor(batch_size * [10]).to(torch_device)
+        noise = torch.ones((batch_size, num_channels) + sizes).to(torch_device)
+        time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)
 
         with torch.no_grad():
             output = model(noise, time_step)
 
         output_slice = output[0, -3:, -3:, -1].flatten().cpu()
         # fmt: off
-        expected_output_slice = torch.tensor([-8.3299e-07, -9.0431e-07, 4.0585e-08, 9.7563e-07, 1.0280e-06, 1.0133e-06, 1.4979e-06, -2.9716e-07, -6.1817e-07])
+        expected_output_slice = torch.tensor([-0.0325, -0.0900, -0.0869, -0.0332, -0.0725, -0.0270, -0.0101, 0.0227, 0.0256])
         # fmt: on
-        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
+        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
 
     def test_output_pretrained_vp(self):
-        model = NCSNpp.from_pretrained("fusing/ddpm-cifar10-vp-dummy")
+        model = NCSNpp.from_pretrained("fusing/cifar10-ddpmpp-vp")
         model.eval()
         model.to(torch_device)
...
@@ -794,18 +794,18 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
         num_channels = 3
         sizes = (32, 32)
 
-        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
-        time_step = torch.tensor(batch_size * [10]).to(torch_device)
+        noise = torch.randn((batch_size, num_channels) + sizes).to(torch_device)
+        time_step = torch.tensor(batch_size * [9.0]).to(torch_device)
 
         with torch.no_grad():
             output = model(noise, time_step)
 
         output_slice = output[0, -3:, -3:, -1].flatten().cpu()
         # fmt: off
-        expected_output_slice = torch.tensor([-3.9086e-07, -1.1001e-05, 1.8881e-06, 1.1106e-05, 1.6629e-06, 2.9820e-06, 8.4978e-06, 8.0253e-07, 1.5435e-06])
+        expected_output_slice = torch.tensor([0.3303, -0.2275, -2.8872, -0.1309, -1.2861, 3.4567, -1.0083, 2.5325, -1.3866])
         # fmt: on
-        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
+        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
 
 
 class VQModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -878,10 +878,9 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
         output_slice = output[0, -1, -3:, -3:].flatten()
         # fmt: off
         expected_output_slice = torch.tensor([-1.1321, 0.1056, 0.3505, -0.6461, -0.2014, 0.0419, -0.5763, -0.8462, -0.4218])
         # fmt: on
-        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
+        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
 
 
 class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase):
...
@@ -950,10 +949,9 @@ class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase):
         output_slice = output[0, -1, -3:, -3:].flatten()
         # fmt: off
         expected_output_slice = torch.tensor([-0.0814, -0.0229, -0.1320, -0.4123, -0.0366, -0.3473, 0.0438, -0.1662, 0.1750])
         # fmt: on
-        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
+        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
 
 
 class PipelineTesterMixin(unittest.TestCase):
...
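Note: the test changes replace absolute tolerances with relative ones in torch.allclose, which makes the comparison scale-independent now that the expected NCSNpp outputs are no longer ~1e-6 in magnitude. A small sketch of the difference (allclose checks |a - b| <= atol + rtol * |b| elementwise):

    import torch

    a = torch.tensor([1.000, 100.0])
    b = torch.tensor([1.001, 100.1])
    print(torch.allclose(a, b, atol=1e-3, rtol=0))  # False: 0.1 > 1e-3 on the second entry
    print(torch.allclose(a, b, rtol=1e-2, atol=0))  # True: both diffs are within 1% of b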