Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
a677565f
Commit
a677565f
authored
Jun 17, 2022
by
Patrick von Platen
Browse files
Merge branch 'main' of
https://github.com/huggingface/diffusers
parents
ff885b0e
d182a6ad
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
412 additions
and
27 deletions
+412
-27
src/diffusers/models/unet.py
src/diffusers/models/unet.py
+4
-4
src/diffusers/models/unet_rl.py
src/diffusers/models/unet_rl.py
+242
-0
src/diffusers/pipelines/grad_tts_utils.py
src/diffusers/pipelines/grad_tts_utils.py
+5
-1
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm.py
+8
-2
src/diffusers/schedulers/scheduling_utils.py
src/diffusers/schedulers/scheduling_utils.py
+10
-0
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+143
-20
No files found.
src/diffusers/models/unet.py
View file @
a677565f
...
...
@@ -287,14 +287,14 @@ class UNetModel(ModelMixin, ConfigMixin):
self
.
norm_out
=
Normalize
(
block_in
)
self
.
conv_out
=
torch
.
nn
.
Conv2d
(
block_in
,
out_ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
x
,
t
):
def
forward
(
self
,
x
,
t
imesteps
):
assert
x
.
shape
[
2
]
==
x
.
shape
[
3
]
==
self
.
resolution
if
not
torch
.
is_tensor
(
t
):
t
=
torch
.
tensor
([
t
],
dtype
=
torch
.
long
,
device
=
x
.
device
)
if
not
torch
.
is_tensor
(
t
imesteps
):
t
imesteps
=
torch
.
tensor
([
t
imesteps
],
dtype
=
torch
.
long
,
device
=
x
.
device
)
# timestep embedding
temb
=
get_timestep_embedding
(
t
,
self
.
ch
)
temb
=
get_timestep_embedding
(
t
imesteps
,
self
.
ch
)
temb
=
self
.
temb
.
dense
[
0
](
temb
)
temb
=
nonlinearity
(
temb
)
temb
=
self
.
temb
.
dense
[
1
](
temb
)
...
...
src/diffusers/models/unet_rl.py
0 → 100644
View file @
a677565f
# model adapted from diffuser https://github.com/jannerm/diffuser/blob/main/diffuser/models/temporal.py
import
torch
import
torch.nn
as
nn
import
einops
from
einops.layers.torch
import
Rearrange
import
math
class
SinusoidalPosEmb
(
nn
.
Module
):
def
__init__
(
self
,
dim
):
super
().
__init__
()
self
.
dim
=
dim
def
forward
(
self
,
x
):
device
=
x
.
device
half_dim
=
self
.
dim
//
2
emb
=
math
.
log
(
10000
)
/
(
half_dim
-
1
)
emb
=
torch
.
exp
(
torch
.
arange
(
half_dim
,
device
=
device
)
*
-
emb
)
emb
=
x
[:,
None
]
*
emb
[
None
,
:]
emb
=
torch
.
cat
((
emb
.
sin
(),
emb
.
cos
()),
dim
=-
1
)
return
emb
class
Downsample1d
(
nn
.
Module
):
def
__init__
(
self
,
dim
):
super
().
__init__
()
self
.
conv
=
nn
.
Conv1d
(
dim
,
dim
,
3
,
2
,
1
)
def
forward
(
self
,
x
):
return
self
.
conv
(
x
)
class
Upsample1d
(
nn
.
Module
):
def
__init__
(
self
,
dim
):
super
().
__init__
()
self
.
conv
=
nn
.
ConvTranspose1d
(
dim
,
dim
,
4
,
2
,
1
)
def
forward
(
self
,
x
):
return
self
.
conv
(
x
)
class
Conv1dBlock
(
nn
.
Module
):
'''
Conv1d --> GroupNorm --> Mish
'''
def
__init__
(
self
,
inp_channels
,
out_channels
,
kernel_size
,
n_groups
=
8
):
super
().
__init__
()
self
.
block
=
nn
.
Sequential
(
nn
.
Conv1d
(
inp_channels
,
out_channels
,
kernel_size
,
padding
=
kernel_size
//
2
),
Rearrange
(
'batch channels horizon -> batch channels 1 horizon'
),
nn
.
GroupNorm
(
n_groups
,
out_channels
),
Rearrange
(
'batch channels 1 horizon -> batch channels horizon'
),
nn
.
Mish
(),
)
def
forward
(
self
,
x
):
return
self
.
block
(
x
)
class
ResidualTemporalBlock
(
nn
.
Module
):
def
__init__
(
self
,
inp_channels
,
out_channels
,
embed_dim
,
horizon
,
kernel_size
=
5
):
super
().
__init__
()
self
.
blocks
=
nn
.
ModuleList
([
Conv1dBlock
(
inp_channels
,
out_channels
,
kernel_size
),
Conv1dBlock
(
out_channels
,
out_channels
,
kernel_size
),
])
self
.
time_mlp
=
nn
.
Sequential
(
nn
.
Mish
(),
nn
.
Linear
(
embed_dim
,
out_channels
),
Rearrange
(
'batch t -> batch t 1'
),
)
self
.
residual_conv
=
nn
.
Conv1d
(
inp_channels
,
out_channels
,
1
)
\
if
inp_channels
!=
out_channels
else
nn
.
Identity
()
def
forward
(
self
,
x
,
t
):
'''
x : [ batch_size x inp_channels x horizon ]
t : [ batch_size x embed_dim ]
returns:
out : [ batch_size x out_channels x horizon ]
'''
out
=
self
.
blocks
[
0
](
x
)
+
self
.
time_mlp
(
t
)
out
=
self
.
blocks
[
1
](
out
)
return
out
+
self
.
residual_conv
(
x
)
class
TemporalUnet
(
nn
.
Module
):
def
__init__
(
self
,
horizon
,
transition_dim
,
cond_dim
,
dim
=
32
,
dim_mults
=
(
1
,
2
,
4
,
8
),
):
super
().
__init__
()
dims
=
[
transition_dim
,
*
map
(
lambda
m
:
dim
*
m
,
dim_mults
)]
in_out
=
list
(
zip
(
dims
[:
-
1
],
dims
[
1
:]))
print
(
f
'[ models/temporal ] Channel dimensions:
{
in_out
}
'
)
time_dim
=
dim
self
.
time_mlp
=
nn
.
Sequential
(
SinusoidalPosEmb
(
dim
),
nn
.
Linear
(
dim
,
dim
*
4
),
nn
.
Mish
(),
nn
.
Linear
(
dim
*
4
,
dim
),
)
self
.
downs
=
nn
.
ModuleList
([])
self
.
ups
=
nn
.
ModuleList
([])
num_resolutions
=
len
(
in_out
)
print
(
in_out
)
for
ind
,
(
dim_in
,
dim_out
)
in
enumerate
(
in_out
):
is_last
=
ind
>=
(
num_resolutions
-
1
)
self
.
downs
.
append
(
nn
.
ModuleList
([
ResidualTemporalBlock
(
dim_in
,
dim_out
,
embed_dim
=
time_dim
,
horizon
=
horizon
),
ResidualTemporalBlock
(
dim_out
,
dim_out
,
embed_dim
=
time_dim
,
horizon
=
horizon
),
Downsample1d
(
dim_out
)
if
not
is_last
else
nn
.
Identity
()
]))
if
not
is_last
:
horizon
=
horizon
//
2
mid_dim
=
dims
[
-
1
]
self
.
mid_block1
=
ResidualTemporalBlock
(
mid_dim
,
mid_dim
,
embed_dim
=
time_dim
,
horizon
=
horizon
)
self
.
mid_block2
=
ResidualTemporalBlock
(
mid_dim
,
mid_dim
,
embed_dim
=
time_dim
,
horizon
=
horizon
)
for
ind
,
(
dim_in
,
dim_out
)
in
enumerate
(
reversed
(
in_out
[
1
:])):
is_last
=
ind
>=
(
num_resolutions
-
1
)
self
.
ups
.
append
(
nn
.
ModuleList
([
ResidualTemporalBlock
(
dim_out
*
2
,
dim_in
,
embed_dim
=
time_dim
,
horizon
=
horizon
),
ResidualTemporalBlock
(
dim_in
,
dim_in
,
embed_dim
=
time_dim
,
horizon
=
horizon
),
Upsample1d
(
dim_in
)
if
not
is_last
else
nn
.
Identity
()
]))
if
not
is_last
:
horizon
=
horizon
*
2
self
.
final_conv
=
nn
.
Sequential
(
Conv1dBlock
(
dim
,
dim
,
kernel_size
=
5
),
nn
.
Conv1d
(
dim
,
transition_dim
,
1
),
)
def
forward
(
self
,
x
,
cond
,
time
):
'''
x : [ batch x horizon x transition ]
'''
x
=
einops
.
rearrange
(
x
,
'b h t -> b t h'
)
t
=
self
.
time_mlp
(
time
)
h
=
[]
for
resnet
,
resnet2
,
downsample
in
self
.
downs
:
x
=
resnet
(
x
,
t
)
x
=
resnet2
(
x
,
t
)
h
.
append
(
x
)
x
=
downsample
(
x
)
x
=
self
.
mid_block1
(
x
,
t
)
x
=
self
.
mid_block2
(
x
,
t
)
for
resnet
,
resnet2
,
upsample
in
self
.
ups
:
x
=
torch
.
cat
((
x
,
h
.
pop
()),
dim
=
1
)
x
=
resnet
(
x
,
t
)
x
=
resnet2
(
x
,
t
)
x
=
upsample
(
x
)
x
=
self
.
final_conv
(
x
)
x
=
einops
.
rearrange
(
x
,
'b t h -> b h t'
)
return
x
class
TemporalValue
(
nn
.
Module
):
def
__init__
(
self
,
horizon
,
transition_dim
,
cond_dim
,
dim
=
32
,
time_dim
=
None
,
out_dim
=
1
,
dim_mults
=
(
1
,
2
,
4
,
8
),
):
super
().
__init__
()
dims
=
[
transition_dim
,
*
map
(
lambda
m
:
dim
*
m
,
dim_mults
)]
in_out
=
list
(
zip
(
dims
[:
-
1
],
dims
[
1
:]))
time_dim
=
time_dim
or
dim
self
.
time_mlp
=
nn
.
Sequential
(
SinusoidalPosEmb
(
dim
),
nn
.
Linear
(
dim
,
dim
*
4
),
nn
.
Mish
(),
nn
.
Linear
(
dim
*
4
,
dim
),
)
self
.
blocks
=
nn
.
ModuleList
([])
print
(
in_out
)
for
dim_in
,
dim_out
in
in_out
:
self
.
blocks
.
append
(
nn
.
ModuleList
([
ResidualTemporalBlock
(
dim_in
,
dim_out
,
kernel_size
=
5
,
embed_dim
=
time_dim
,
horizon
=
horizon
),
ResidualTemporalBlock
(
dim_out
,
dim_out
,
kernel_size
=
5
,
embed_dim
=
time_dim
,
horizon
=
horizon
),
Downsample1d
(
dim_out
)
]))
horizon
=
horizon
//
2
fc_dim
=
dims
[
-
1
]
*
max
(
horizon
,
1
)
self
.
final_block
=
nn
.
Sequential
(
nn
.
Linear
(
fc_dim
+
time_dim
,
fc_dim
//
2
),
nn
.
Mish
(),
nn
.
Linear
(
fc_dim
//
2
,
out_dim
),
)
def
forward
(
self
,
x
,
cond
,
time
,
*
args
):
'''
x : [ batch x horizon x transition ]
'''
x
=
einops
.
rearrange
(
x
,
'b h t -> b t h'
)
t
=
self
.
time_mlp
(
time
)
for
resnet
,
resnet2
,
downsample
in
self
.
blocks
:
x
=
resnet
(
x
,
t
)
x
=
resnet2
(
x
,
t
)
x
=
downsample
(
x
)
x
=
x
.
view
(
len
(
x
),
-
1
)
out
=
self
.
final_block
(
torch
.
cat
([
x
,
t
],
dim
=-
1
))
return
out
\ No newline at end of file
src/diffusers/pipelines/grad_tts_utils.py
View file @
a677565f
...
...
@@ -233,8 +233,12 @@ def english_cleaners(text):
text
=
collapse_whitespace
(
text
)
return
text
try
:
_inflect
=
inflect
.
engine
()
except
:
print
(
"inflect is not installed"
)
_inflect
=
None
_inflect
=
inflect
.
engine
()
_comma_number_re
=
re
.
compile
(
r
"([0-9][0-9\,]+[0-9])"
)
_decimal_number_re
=
re
.
compile
(
r
"([0-9]+\.[0-9]+)"
)
_pounds_re
=
re
.
compile
(
r
"£([0-9\,]*[0-9]+)"
)
...
...
src/diffusers/schedulers/scheduling_ddpm.py
View file @
a677565f
...
...
@@ -105,12 +105,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
# hacks - were probs added for training stability
if
self
.
config
.
variance_type
==
"fixed_small"
:
variance
=
self
.
clip
(
variance
,
min_value
=
1e-20
)
# for rl-diffuser https://arxiv.org/abs/2205.09991
elif
self
.
config
.
variance_type
==
"fixed_small_log"
:
variance
=
self
.
log
(
self
.
clip
(
variance
,
min_value
=
1e-20
))
elif
self
.
config
.
variance_type
==
"fixed_large"
:
variance
=
self
.
get_beta
(
t
)
return
variance
def
step
(
self
,
residual
,
sample
,
t
):
def
step
(
self
,
residual
,
sample
,
t
,
predict_epsilon
=
True
):
# 1. compute alphas, betas
alpha_prod_t
=
self
.
get_alpha_prod
(
t
)
alpha_prod_t_prev
=
self
.
get_alpha_prod
(
t
-
1
)
...
...
@@ -119,7 +122,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
# 2. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
pred_original_sample
=
(
sample
-
beta_prod_t
**
(
0.5
)
*
residual
)
/
alpha_prod_t
**
(
0.5
)
if
predict_epsilon
:
pred_original_sample
=
(
sample
-
beta_prod_t
**
(
0.5
)
*
residual
)
/
alpha_prod_t
**
(
0.5
)
else
:
pred_original_sample
=
residual
# 3. Clip "predicted x_0"
if
self
.
config
.
clip_sample
:
...
...
src/diffusers/schedulers/scheduling_utils.py
View file @
a677565f
...
...
@@ -64,3 +64,13 @@ class SchedulerMixin:
return
torch
.
clamp
(
tensor
,
min_value
,
max_value
)
raise
ValueError
(
f
"`self.tensor_format`:
{
self
.
tensor_format
}
is not valid."
)
def
log
(
self
,
tensor
):
tensor_format
=
getattr
(
self
,
"tensor_format"
,
"pt"
)
if
tensor_format
==
"np"
:
return
np
.
log
(
tensor
)
elif
tensor_format
==
"pt"
:
return
torch
.
log
(
tensor
)
raise
ValueError
(
f
"`self.tensor_format`:
{
self
.
tensor_format
}
is not valid."
)
tests/test_modeling_utils.py
View file @
a677565f
...
...
@@ -14,8 +14,10 @@
# limitations under the License.
import
inspect
import
tempfile
import
unittest
import
numpy
as
np
import
torch
...
...
@@ -82,41 +84,162 @@ class ConfigTester(unittest.TestCase):
assert
config
==
new_config
class
ModelTesterMixin
(
unittest
.
TestCase
):
@
property
def
dummy_input
(
self
):
batch_size
=
4
num_channels
=
3
sizes
=
(
32
,
32
)
noise
=
floats_tensor
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]).
to
(
torch_device
)
return
(
noise
,
time_step
)
class
ModelTesterMixin
:
def
test_from_pretrained_save_pretrained
(
self
):
model
=
UNetModel
(
ch
=
32
,
ch_mult
=
(
1
,
2
),
num_res_blocks
=
2
,
attn_resolutions
=
(
16
,),
resolution
=
32
)
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
init_dict
)
model
.
to
(
torch_device
)
model
.
eval
()
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
model
.
save_pretrained
(
tmpdirname
)
new_model
=
UNetModel
.
from_pretrained
(
tmpdirname
)
new_model
.
to
(
torch_device
)
dummy_input
=
self
.
dummy_input
with
torch
.
no_grad
():
image
=
model
(
**
inputs_dict
)
new_image
=
new_model
(
**
inputs_dict
)
max_diff
=
(
image
-
new_image
).
abs
().
sum
().
item
()
self
.
assertLessEqual
(
max_diff
,
1e-5
,
"Models give different forward passes"
)
def
test_determinism
(
self
):
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
init_dict
)
model
.
to
(
torch_device
)
model
.
eval
()
with
torch
.
no_grad
():
first
=
model
(
**
inputs_dict
)
second
=
model
(
**
inputs_dict
)
out_1
=
first
.
cpu
().
numpy
()
out_2
=
second
.
cpu
().
numpy
()
out_1
=
out_1
[
~
np
.
isnan
(
out_1
)]
out_2
=
out_2
[
~
np
.
isnan
(
out_2
)]
max_diff
=
np
.
amax
(
np
.
abs
(
out_1
-
out_2
))
self
.
assertLessEqual
(
max_diff
,
1e-5
)
def
test_output
(
self
):
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
init_dict
)
model
.
to
(
torch_device
)
model
.
eval
()
with
torch
.
no_grad
():
output
=
model
(
**
inputs_dict
)
self
.
assertIsNotNone
(
output
)
expected_shape
=
inputs_dict
[
"x"
].
shape
self
.
assertEqual
(
output
.
shape
,
expected_shape
,
"Input and output shapes do not match"
)
def
test_forward_signature
(
self
):
init_dict
,
_
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
init_dict
)
signature
=
inspect
.
signature
(
model
.
forward
)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names
=
[
*
signature
.
parameters
.
keys
()]
expected_arg_names
=
[
"x"
,
"timesteps"
]
self
.
assertListEqual
(
arg_names
[:
2
],
expected_arg_names
)
def
test_model_from_config
(
self
):
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
init_dict
)
model
.
to
(
torch_device
)
model
.
eval
()
# test if the model can be loaded from the config
# and has all the expected shape
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
model
.
save_config
(
tmpdirname
)
new_model
=
self
.
model_class
.
from_config
(
tmpdirname
)
new_model
.
to
(
torch_device
)
new_model
.
eval
()
# check if all paramters shape are the same
for
param_name
in
model
.
state_dict
().
keys
():
param_1
=
model
.
state_dict
()[
param_name
]
param_2
=
new_model
.
state_dict
()[
param_name
]
self
.
assertEqual
(
param_1
.
shape
,
param_2
.
shape
)
with
torch
.
no_grad
():
output_1
=
model
(
**
inputs_dict
)
output_2
=
new_model
(
**
inputs_dict
)
self
.
assertEqual
(
output_1
.
shape
,
output_2
.
shape
)
def
test_training
(
self
):
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
init_dict
)
model
.
to
(
torch_device
)
model
.
train
()
output
=
model
(
**
inputs_dict
)
noise
=
torch
.
randn
(
inputs_dict
[
"x"
].
shape
).
to
(
torch_device
)
loss
=
torch
.
nn
.
functional
.
mse_loss
(
output
,
noise
)
loss
.
backward
()
image
=
model
(
*
dummy_input
)
new_image
=
new_model
(
*
dummy_input
)
class
UnetModelTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
UNetModel
assert
(
image
-
new_image
).
abs
().
sum
()
<
1e-5
,
"Models don't give the same forward pass"
@
property
def
dummy_input
(
self
):
batch_size
=
4
num_channels
=
3
sizes
=
(
32
,
32
)
noise
=
floats_tensor
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]).
to
(
torch_device
)
return
{
"x"
:
noise
,
"timesteps"
:
time_step
}
def
prepare_init_args_and_inputs_for_common
(
self
):
init_dict
=
{
"ch"
:
32
,
"ch_mult"
:
(
1
,
2
),
"num_res_blocks"
:
2
,
"attn_resolutions"
:
(
16
,),
"resolution"
:
32
,
}
inputs_dict
=
self
.
dummy_input
return
init_dict
,
inputs_dict
def
test_from_pretrained_hub
(
self
):
model
=
UNetModel
.
from_pretrained
(
"fusing/ddpm_dummy"
)
model
.
to
(
torch_device
)
model
,
loading_info
=
UNetModel
.
from_pretrained
(
"fusing/ddpm_dummy"
,
output_loading_info
=
True
)
self
.
assertIsNotNone
(
model
)
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
image
=
model
(
*
self
.
dummy_input
)
model
.
to
(
torch_device
)
image
=
model
(
**
self
.
dummy_input
)
assert
image
is
not
None
,
"Make sure output is not None"
def
test_output_pretrained
(
self
):
model
=
UNetModel
.
from_pretrained
(
"fusing/ddpm_dummy"
)
model
.
eval
()
torch
.
manual_seed
(
0
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed_all
(
0
)
noise
=
torch
.
randn
(
1
,
model
.
config
.
in_channels
,
model
.
config
.
resolution
,
model
.
config
.
resolution
)
print
(
noise
.
shape
)
time_step
=
torch
.
tensor
([
10
])
with
torch
.
no_grad
():
output
=
model
(
noise
,
time_step
)
output_slice
=
output
[
0
,
-
1
,
-
3
:,
-
3
:].
flatten
()
# fmt: off
expected_output_slice
=
torch
.
tensor
([
0.2891
,
-
0.1899
,
0.2595
,
-
0.6214
,
0.0968
,
-
0.2622
,
0.4688
,
0.1311
,
0.0053
])
# fmt: on
print
(
output_slice
)
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
class
PipelineTesterMixin
(
unittest
.
TestCase
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment