renzhc / diffusers_dcu · Commits

Commit a677565f
Authored Jun 17, 2022 by Patrick von Platen

    Merge branch 'main' of https://github.com/huggingface/diffusers

Parents: ff885b0e, d182a6ad

Showing 6 changed files with 412 additions and 27 deletions (+412, -27)
Files changed:
    src/diffusers/models/unet.py                  +4    -4
    src/diffusers/models/unet_rl.py               +242  -0
    src/diffusers/pipelines/grad_tts_utils.py     +5    -1
    src/diffusers/schedulers/scheduling_ddpm.py   +8    -2
    src/diffusers/schedulers/scheduling_utils.py  +10   -0
    tests/test_modeling_utils.py                  +143  -20
src/diffusers/models/unet.py

@@ -287,14 +287,14 @@ class UNetModel(ModelMixin, ConfigMixin):
         self.norm_out = Normalize(block_in)
         self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

-    def forward(self, x, t):
+    def forward(self, x, timesteps):
         assert x.shape[2] == x.shape[3] == self.resolution

-        if not torch.is_tensor(t):
-            t = torch.tensor([t], dtype=torch.long, device=x.device)
+        if not torch.is_tensor(timesteps):
+            timesteps = torch.tensor([timesteps], dtype=torch.long, device=x.device)

         # timestep embedding
-        temb = get_timestep_embedding(t, self.ch)
+        temb = get_timestep_embedding(timesteps, self.ch)
         temb = self.temb.dense[0](temb)
         temb = nonlinearity(temb)
         temb = self.temb.dense[1](temb)
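The only change in this file is the rename of forward's second argument from t to timesteps; the scalar-to-tensor promotion it performs is unchanged. A minimal, runnable sketch of just that promotion step, with a dummy input shape assumed for illustration (not part of this diff):

    import torch

    x = torch.randn(1, 3, 32, 32)    # dummy input; shape assumed for illustration
    timesteps = 10                   # a plain Python int is accepted
    if not torch.is_tensor(timesteps):
        timesteps = torch.tensor([timesteps], dtype=torch.long, device=x.device)
    print(timesteps)                 # tensor([10])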
src/diffusers/models/unet_rl.py (new file, mode 100644)

# model adapted from diffuser https://github.com/jannerm/diffuser/blob/main/diffuser/models/temporal.py
import torch
import torch.nn as nn
import einops
from einops.layers.torch import Rearrange
import math


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class Downsample1d(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv1d(dim, dim, 3, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Upsample1d(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Conv1dBlock(nn.Module):
    '''
        Conv1d --> GroupNorm --> Mish
    '''

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        self.block = nn.Sequential(
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
            Rearrange('batch channels horizon -> batch channels 1 horizon'),
            nn.GroupNorm(n_groups, out_channels),
            Rearrange('batch channels 1 horizon -> batch channels horizon'),
            nn.Mish(),
        )

    def forward(self, x):
        return self.block(x)


class ResidualTemporalBlock(nn.Module):
    def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
        super().__init__()

        self.blocks = nn.ModuleList([
            Conv1dBlock(inp_channels, out_channels, kernel_size),
            Conv1dBlock(out_channels, out_channels, kernel_size),
        ])

        self.time_mlp = nn.Sequential(
            nn.Mish(),
            nn.Linear(embed_dim, out_channels),
            Rearrange('batch t -> batch t 1'),
        )

        self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1) \
            if inp_channels != out_channels else nn.Identity()

    def forward(self, x, t):
        '''
            x : [ batch_size x inp_channels x horizon ]
            t : [ batch_size x embed_dim ]
            returns:
            out : [ batch_size x out_channels x horizon ]
        '''
        out = self.blocks[0](x) + self.time_mlp(t)
        out = self.blocks[1](out)
        return out + self.residual_conv(x)


class TemporalUnet(nn.Module):
    def __init__(
        self,
        horizon,
        transition_dim,
        cond_dim,
        dim=32,
        dim_mults=(1, 2, 4, 8),
    ):
        super().__init__()

        dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        print(f'[ models/temporal ] Channel dimensions: {in_out}')

        time_dim = dim
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            nn.Mish(),
            nn.Linear(dim * 4, dim),
        )

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        print(in_out)
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(nn.ModuleList([
                ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=horizon),
                ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=horizon),
                Downsample1d(dim_out) if not is_last else nn.Identity()
            ]))

            if not is_last:
                horizon = horizon // 2

        mid_dim = dims[-1]
        self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon)
        self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(nn.ModuleList([
                ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=horizon),
                ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=horizon),
                Upsample1d(dim_in) if not is_last else nn.Identity()
            ]))

            if not is_last:
                horizon = horizon * 2

        self.final_conv = nn.Sequential(
            Conv1dBlock(dim, dim, kernel_size=5),
            nn.Conv1d(dim, transition_dim, 1),
        )

    def forward(self, x, cond, time):
        '''
            x : [ batch x horizon x transition ]
        '''

        x = einops.rearrange(x, 'b h t -> b t h')

        t = self.time_mlp(time)
        h = []

        for resnet, resnet2, downsample in self.downs:
            x = resnet(x, t)
            x = resnet2(x, t)
            h.append(x)
            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_block2(x, t)

        for resnet, resnet2, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, t)
            x = resnet2(x, t)
            x = upsample(x)

        x = self.final_conv(x)

        x = einops.rearrange(x, 'b t h -> b h t')
        return x


class TemporalValue(nn.Module):
    def __init__(
        self,
        horizon,
        transition_dim,
        cond_dim,
        dim=32,
        time_dim=None,
        out_dim=1,
        dim_mults=(1, 2, 4, 8),
    ):
        super().__init__()

        dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        time_dim = time_dim or dim
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            nn.Mish(),
            nn.Linear(dim * 4, dim),
        )

        self.blocks = nn.ModuleList([])

        print(in_out)
        for dim_in, dim_out in in_out:
            self.blocks.append(nn.ModuleList([
                ResidualTemporalBlock(dim_in, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
                ResidualTemporalBlock(dim_out, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
                Downsample1d(dim_out)
            ]))

            horizon = horizon // 2

        fc_dim = dims[-1] * max(horizon, 1)

        self.final_block = nn.Sequential(
            nn.Linear(fc_dim + time_dim, fc_dim // 2),
            nn.Mish(),
            nn.Linear(fc_dim // 2, out_dim),
        )

    def forward(self, x, cond, time, *args):
        '''
            x : [ batch x horizon x transition ]
        '''

        x = einops.rearrange(x, 'b h t -> b t h')

        t = self.time_mlp(time)

        for resnet, resnet2, downsample in self.blocks:
            x = resnet(x, t)
            x = resnet2(x, t)
            x = downsample(x)

        x = x.view(len(x), -1)

        out = self.final_block(torch.cat([x, t], dim=-1))
        return out
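As a sanity check on the shapes this module expects, here is a minimal sketch of a dummy forward pass through TemporalUnet, assuming the module is importable as diffusers.models.unet_rl and using an illustrative horizon of 32 and transition dimension of 14 (none of these values come from the commit itself):

    import torch
    from diffusers.models.unet_rl import TemporalUnet   # import path assumed from the file location above

    model = TemporalUnet(horizon=32, transition_dim=14, cond_dim=0, dim=32, dim_mults=(1, 2, 4, 8))
    x = torch.randn(2, 32, 14)              # [ batch x horizon x transition ]
    time = torch.tensor([10, 10])           # one diffusion step index per batch element
    out = model(x, cond=None, time=time)    # cond is accepted but unused in the code above
    print(out.shape)                        # expected: torch.Size([2, 32, 14]), same layout as the input

Note that the horizon is halved at each non-final downsampling stage (32 → 16 → 8 → 4 here) and doubled again on the way up, so the chosen horizon should be divisible by 2 ** (len(dim_mults) - 1).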
src/diffusers/pipelines/grad_tts_utils.py

@@ -233,8 +233,12 @@ def english_cleaners(text):
     text = collapse_whitespace(text)
     return text


-_inflect = inflect.engine()
+try:
+    _inflect = inflect.engine()
+except:
+    print("inflect is not installed")
+    _inflect = None
 _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
 _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
 _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
src/diffusers/schedulers/scheduling_ddpm.py

@@ -105,12 +105,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         # hacks - were probs added for training stability
         if self.config.variance_type == "fixed_small":
             variance = self.clip(variance, min_value=1e-20)
+        # for rl-diffuser https://arxiv.org/abs/2205.09991
+        elif self.config.variance_type == "fixed_small_log":
+            variance = self.log(self.clip(variance, min_value=1e-20))
         elif self.config.variance_type == "fixed_large":
             variance = self.get_beta(t)

         return variance

-    def step(self, residual, sample, t):
+    def step(self, residual, sample, t, predict_epsilon=True):
         # 1. compute alphas, betas
         alpha_prod_t = self.get_alpha_prod(t)
         alpha_prod_t_prev = self.get_alpha_prod(t - 1)

@@ -119,7 +122,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         # 2. compute predicted original sample from predicted noise also called
         # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
-        pred_original_sample = (sample - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)
+        if predict_epsilon:
+            pred_original_sample = (sample - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)
+        else:
+            pred_original_sample = residual

         # 3. Clip "predicted x_0"
         if self.config.clip_sample:
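For readers who want the predicted-x0 math in isolation, here is a standalone sketch restating the computation that the new predict_epsilon flag toggles; it is not the scheduler's API, just the formula above with beta_prod_t expanded as 1 - alpha_prod_t (the standard DDPM relation):

    import torch

    def predicted_original_sample(sample, residual, alpha_prod_t, predict_epsilon=True):
        # standalone restatement of step 2 above, for illustration only
        beta_prod_t = 1 - alpha_prod_t
        if predict_epsilon:
            # residual is interpreted as predicted noise epsilon (formula (15) of the DDPM paper)
            return (sample - beta_prod_t ** 0.5 * residual) / alpha_prod_t ** 0.5
        # residual is interpreted as the predicted x_0 itself (the rl-diffuser convention)
        return residual

    x0 = predicted_original_sample(torch.randn(2, 4), torch.randn(2, 4), alpha_prod_t=torch.tensor(0.9))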
src/diffusers/schedulers/scheduling_utils.py

@@ -64,3 +64,13 @@ class SchedulerMixin:
             return torch.clamp(tensor, min_value, max_value)

         raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+    def log(self, tensor):
+        tensor_format = getattr(self, "tensor_format", "pt")
+
+        if tensor_format == "np":
+            return np.log(tensor)
+        elif tensor_format == "pt":
+            return torch.log(tensor)
+
+        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
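The new log helper follows the same tensor_format dispatch as the existing clip helper shown above; a small self-contained sketch of that pattern on a stand-in class (only the tensor_format attribute name and the np/pt branches are taken from the diff):

    import numpy as np
    import torch

    class _FormatDemo:
        # stand-in for SchedulerMixin, illustration only
        def __init__(self, tensor_format):
            self.tensor_format = tensor_format

        def log(self, tensor):
            if self.tensor_format == "np":
                return np.log(tensor)
            elif self.tensor_format == "pt":
                return torch.log(tensor)
            raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")

    print(_FormatDemo("pt").log(torch.tensor([1.0, float(np.e)])))   # tensor([0., 1.])
    print(_FormatDemo("np").log(np.array([1.0, np.e])))              # [0. 1.]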
tests/test_modeling_utils.py

@@ -14,8 +14,10 @@
 # limitations under the License.

+import inspect
 import tempfile
 import unittest

+import numpy as np
 import torch

@@ -82,7 +84,108 @@ class ConfigTester(unittest.TestCase):
         assert config == new_config


-class ModelTesterMixin(unittest.TestCase):
+class ModelTesterMixin:
+    def test_from_pretrained_save_pretrained(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_pretrained(tmpdirname)
+            new_model = UNetModel.from_pretrained(tmpdirname)
+            new_model.to(torch_device)
+
+        with torch.no_grad():
+            image = model(**inputs_dict)
+            new_image = new_model(**inputs_dict)
+
+        max_diff = (image - new_image).abs().sum().item()
+        self.assertLessEqual(max_diff, 1e-5, "Models give different forward passes")
+
+    def test_determinism(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+        with torch.no_grad():
+            first = model(**inputs_dict)
+            second = model(**inputs_dict)
+
+        out_1 = first.cpu().numpy()
+        out_2 = second.cpu().numpy()
+        out_1 = out_1[~np.isnan(out_1)]
+        out_2 = out_2[~np.isnan(out_2)]
+        max_diff = np.amax(np.abs(out_1 - out_2))
+        self.assertLessEqual(max_diff, 1e-5)
+
+    def test_output(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        with torch.no_grad():
+            output = model(**inputs_dict)
+
+        self.assertIsNotNone(output)
+        expected_shape = inputs_dict["x"].shape
+        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+
+    def test_forward_signature(self):
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+
+        model = self.model_class(**init_dict)
+        signature = inspect.signature(model.forward)
+        # signature.parameters is an OrderedDict => so arg_names order is deterministic
+        arg_names = [*signature.parameters.keys()]
+
+        expected_arg_names = ["x", "timesteps"]
+        self.assertListEqual(arg_names[:2], expected_arg_names)
+
+    def test_model_from_config(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        # test if the model can be loaded from the config
+        # and has all the expected shape
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_config(tmpdirname)
+            new_model = self.model_class.from_config(tmpdirname)
+            new_model.to(torch_device)
+            new_model.eval()
+
+        # check if all paramters shape are the same
+        for param_name in model.state_dict().keys():
+            param_1 = model.state_dict()[param_name]
+            param_2 = new_model.state_dict()[param_name]
+            self.assertEqual(param_1.shape, param_2.shape)
+
+        with torch.no_grad():
+            output_1 = model(**inputs_dict)
+            output_2 = new_model(**inputs_dict)
+
+        self.assertEqual(output_1.shape, output_2.shape)
+
+    def test_training(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.train()
+        output = model(**inputs_dict)
+
+        noise = torch.randn(inputs_dict["x"].shape).to(torch_device)
+        loss = torch.nn.functional.mse_loss(output, noise)
+        loss.backward()
+
+
+class UnetModelTests(ModelTesterMixin, unittest.TestCase):
+    model_class = UNetModel
+
     @property
     def dummy_input(self):
         batch_size = 4

@@ -92,31 +195,51 @@ class ModelTesterMixin(unittest.TestCase):
         noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
         time_step = torch.tensor([10]).to(torch_device)

-        return (noise, time_step)
+        return {"x": noise, "timesteps": time_step}

-    def test_from_pretrained_save_pretrained(self):
-        model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32)
-        model.to(torch_device)
+    def prepare_init_args_and_inputs_for_common(self):
+        init_dict = {
+            "ch": 32,
+            "ch_mult": (1, 2),
+            "num_res_blocks": 2,
+            "attn_resolutions": (16,),
+            "resolution": 32,
+        }
+        inputs_dict = self.dummy_input
+        return init_dict, inputs_dict

-        with tempfile.TemporaryDirectory() as tmpdirname:
-            model.save_pretrained(tmpdirname)
-            new_model = UNetModel.from_pretrained(tmpdirname)
-            new_model.to(torch_device)
+    def test_from_pretrained_hub(self):
+        model, loading_info = UNetModel.from_pretrained("fusing/ddpm_dummy", output_loading_info=True)
+        self.assertIsNotNone(model)
+        self.assertEqual(len(loading_info["missing_keys"]), 0)

-        dummy_input = self.dummy_input
+        model.to(torch_device)
+        image = model(**self.dummy_input)

-        image = model(*dummy_input)
-        new_image = new_model(*dummy_input)
+        assert image is not None, "Make sure output is not None"

-        assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
+    def test_output_pretrained(self):
+        model = UNetModel.from_pretrained("fusing/ddpm_dummy")
+        model.eval()

-    def test_from_pretrained_hub(self):
-        model = UNetModel.from_pretrained("fusing/ddpm_dummy")
-        model.to(torch_device)
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)

-        image = model(*self.dummy_input)
+        noise = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution)
+        print(noise.shape)
+        time_step = torch.tensor([10])

-        assert image is not None, "Make sure output is not None"
+        with torch.no_grad():
+            output = model(noise, time_step)
+
+        output_slice = output[0, -1, -3:, -3:].flatten()
+        # fmt: off
+        expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053])
+        # fmt: on
+        print(output_slice)
+        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))


 class PipelineTesterMixin(unittest.TestCase):
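The reworked ModelTesterMixin only relies on the concrete test class providing a model_class attribute and a prepare_init_args_and_inputs_for_common method. As a purely hypothetical sketch of that wiring (class name, init values, and input shapes invented here; TemporalUnet as written lacks ModelMixin, so the save/load and forward-signature tests in the mixin would still need overriding):

    import unittest
    import torch

    class TemporalUnetModelTests(ModelTesterMixin, unittest.TestCase):   # hypothetical, not in this commit
        model_class = TemporalUnet   # assumes TemporalUnet is imported in this test module

        @property
        def dummy_input(self):
            x = torch.randn(2, 32, 14)   # [ batch x horizon x transition ], sizes chosen for illustration
            return {"x": x, "cond": None, "time": torch.tensor([10, 10])}

        def prepare_init_args_and_inputs_for_common(self):
            init_dict = {"horizon": 32, "transition_dim": 14, "cond_dim": 0}
            return init_dict, self.dummy_input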