GitLab · renzhc / diffusers_dcu · Commits · 08c85229

Commit 08c85229, authored Jun 20, 2022 by Patrick von Platen
Parent: 2b8bc91c

add license disclaimers to schedulers

Showing 6 changed files with 123 additions and 84 deletions (+123, -84). Besides the new license disclaimers and copyright attributions in the three scheduler files, the diff carries mechanical style normalization (single quotes to double quotes, trailing commas, rewrapped constructor calls, import reordering) in the model, pipeline, and test files.
src/diffusers/models/unet_rl.py              +69  -48
src/diffusers/pipelines/grad_tts_utils.py     +1   -0
src/diffusers/schedulers/scheduling_ddim.py   +6   -1
src/diffusers/schedulers/scheduling_ddpm.py   +5   -1
src/diffusers/schedulers/scheduling_pndm.py   +7   -2
tests/test_modeling_utils.py                 +35  -32
src/diffusers/models/unet_rl.py

 # model adapted from diffuser https://github.com/jannerm/diffuser/blob/main/diffuser/models/temporal.py
+import math
+
 import torch
 import torch.nn as nn
 import einops
 from einops.layers.torch import Rearrange
-import math


 class SinusoidalPosEmb(nn.Module):
     def __init__(self, dim):
@@ -20,6 +23,7 @@ class SinusoidalPosEmb(nn.Module):
         emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
         return emb

+
 class Downsample1d(nn.Module):
     def __init__(self, dim):
         super().__init__()
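The body of SinusoidalPosEmb is collapsed in this view apart from the final concatenation. For orientation, a minimal sketch of the standard sinusoidal timestep embedding this class computes, reconstructed from the upstream diffuser code referenced in the file header; illustrative, not the verbatim file contents:

import math

import torch
import torch.nn as nn


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        # Half the channels carry sin terms, half carry cos terms.
        half_dim = self.dim // 2
        # Log-spaced frequencies, as in the Transformer position encoding.
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=x.device) * -emb)
        # Outer product: one row per timestep, one column per frequency.
        emb = x[:, None] * emb[None, :]
        # The line visible in the hunk above.
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


# A batch of 4 integer timesteps -> a [4, 32] embedding.
t = torch.tensor([0, 10, 100, 1000])
print(SinusoidalPosEmb(32)(t).shape)  # torch.Size([4, 32])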
@@ -28,6 +32,7 @@ class Downsample1d(nn.Module):
     def forward(self, x):
         return self.conv(x)

+
 class Upsample1d(nn.Module):
     def __init__(self, dim):
         super().__init__()
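The constructors of Downsample1d and Upsample1d are also collapsed here. In the upstream diffuser repository they each wrap a single strided convolution; a sketch under that assumption:

import torch
import torch.nn as nn


class Downsample1d(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # Stride-2 conv halves the horizon (temporal) axis.
        self.conv = nn.Conv1d(dim, dim, 3, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Upsample1d(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # Stride-2 transposed conv doubles the horizon axis back.
        self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)

    def forward(self, x):
        return self.conv(x)


x = torch.randn(2, 8, 16)  # [batch, channels, horizon]
down = Downsample1d(8)(x)
print(down.shape)                 # torch.Size([2, 8, 8])
print(Upsample1d(8)(down).shape)  # torch.Size([2, 8, 16])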
@@ -36,57 +41,61 @@ class Upsample1d(nn.Module):
     def forward(self, x):
         return self.conv(x)


 class Conv1dBlock(nn.Module):
-    '''
+    """
     Conv1d --> GroupNorm --> Mish
-    '''
+    """

     def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
         super().__init__()

         self.block = nn.Sequential(
             nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
-            Rearrange('batch channels horizon -> batch channels 1 horizon'),
+            Rearrange("batch channels horizon -> batch channels 1 horizon"),
             nn.GroupNorm(n_groups, out_channels),
-            Rearrange('batch channels 1 horizon -> batch channels horizon'),
+            Rearrange("batch channels 1 horizon -> batch channels horizon"),
             nn.Mish(),
         )

     def forward(self, x):
         return self.block(x)


 class ResidualTemporalBlock(nn.Module):
     def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
         super().__init__()

-        self.blocks = nn.ModuleList([
-            Conv1dBlock(inp_channels, out_channels, kernel_size),
-            Conv1dBlock(out_channels, out_channels, kernel_size),
-        ])
+        self.blocks = nn.ModuleList(
+            [
+                Conv1dBlock(inp_channels, out_channels, kernel_size),
+                Conv1dBlock(out_channels, out_channels, kernel_size),
+            ]
+        )

         self.time_mlp = nn.Sequential(
             nn.Mish(),
             nn.Linear(embed_dim, out_channels),
-            Rearrange('batch t -> batch t 1'),
+            Rearrange("batch t -> batch t 1"),
         )

-        self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1) \
-            if inp_channels != out_channels else nn.Identity()
+        self.residual_conv = (
+            nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
+        )

     def forward(self, x, t):
-        '''
+        """
         x : [ batch_size x inp_channels x horizon ]
         t : [ batch_size x embed_dim ]
         returns:
         out : [ batch_size x out_channels x horizon ]
-        '''
+        """
         out = self.blocks[0](x) + self.time_mlp(t)
         out = self.blocks[1](out)
         return out + self.residual_conv(x)


 class TemporalUnet(nn.Module):
     def __init__(
         self,
         horizon,
         ...
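A quick shape check for the blocks reformatted above (hypothetical sizes, assuming the definitions in this hunk): Conv1dBlock pads with kernel_size // 2 so the horizon is preserved, and the Rearrange("batch t -> batch t 1") in time_mlp broadcasts one bias per channel across every timestep.

import torch

# Assumes Conv1dBlock and ResidualTemporalBlock as defined in the hunk above.
block = ResidualTemporalBlock(inp_channels=4, out_channels=8, embed_dim=16, horizon=32)

x = torch.randn(2, 4, 32)  # [batch_size, inp_channels, horizon]
t = torch.randn(2, 16)     # [batch_size, embed_dim]

out = block(x, t)
print(out.shape)  # torch.Size([2, 8, 32]) -- channels projected, horizon preserved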
@@ -99,7 +108,7 @@ class TemporalUnet(nn.Module):
         dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
         in_out = list(zip(dims[:-1], dims[1:]))
-        print(f'[ models/temporal ] Channel dimensions: {in_out}')
+        print(f"[ models/temporal ] Channel dimensions: {in_out}")

         time_dim = dim
         self.time_mlp = nn.Sequential(
             ...
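The dims / in_out lines at the top of this hunk fix the channel pair of every U-Net stage. A worked example with hypothetical values (transition_dim=14 is made up; dim=32 and dim_mults=(1, 2, 4, 8) follow the upstream model's defaults):

transition_dim, dim, dim_mults = 14, 32, (1, 2, 4, 8)

dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))

print(dims)    # [14, 32, 64, 128, 256]
print(in_out)  # [(14, 32), (32, 64), (64, 128), (128, 256)]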
@@ -117,11 +126,15 @@ class TemporalUnet(nn.Module):
         for ind, (dim_in, dim_out) in enumerate(in_out):
             is_last = ind >= (num_resolutions - 1)

-            self.downs.append(nn.ModuleList([
-                ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=horizon),
-                ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=horizon),
-                Downsample1d(dim_out) if not is_last else nn.Identity()
-            ]))
+            self.downs.append(
+                nn.ModuleList(
+                    [
+                        ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=horizon),
+                        ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=horizon),
+                        Downsample1d(dim_out) if not is_last else nn.Identity(),
+                    ]
+                )
+            )

             if not is_last:
                 horizon = horizon // 2
@@ -133,11 +146,15 @@ class TemporalUnet(nn.Module):
         for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
             is_last = ind >= (num_resolutions - 1)

-            self.ups.append(nn.ModuleList([
-                ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=horizon),
-                ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=horizon),
-                Upsample1d(dim_in) if not is_last else nn.Identity()
-            ]))
+            self.ups.append(
+                nn.ModuleList(
+                    [
+                        ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=horizon),
+                        ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=horizon),
+                        Upsample1d(dim_in) if not is_last else nn.Identity(),
+                    ]
+                )
+            )

             if not is_last:
                 horizon = horizon * 2
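The two loops are mirror images: each non-final encoder stage halves horizon and each non-final decoder stage doubles it, so the input horizon should be divisible by 2 ** (num_resolutions - 1). A small sanity check with hypothetical numbers:

horizon, num_resolutions = 128, 4

h = horizon
down_horizons = []
for ind in range(num_resolutions):
    is_last = ind >= (num_resolutions - 1)
    down_horizons.append(h)
    if not is_last:
        h = h // 2  # mirrors the Downsample1d stages
print(down_horizons)  # [128, 64, 32, 16]

for ind in range(num_resolutions - 1):
    h = h * 2  # mirrors the Upsample1d stages
print(h)  # 128 -- the decoder restores the input horizon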
@@ -148,11 +165,11 @@ class TemporalUnet(nn.Module):
         )

     def forward(self, x, cond, time):
-        '''
+        """
         x : [ batch x horizon x transition ]
-        '''
+        """

-        x = einops.rearrange(x, 'b h t -> b t h')
+        x = einops.rearrange(x, "b h t -> b t h")

         t = self.time_mlp(time)
         h = []
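The rearrange pattern moves the transition (feature) axis into the Conv1d channel slot. An equivalent plain-PyTorch spelling, for clarity:

import torch
import einops

x = torch.randn(2, 128, 14)  # [batch, horizon, transition]

a = einops.rearrange(x, "b h t -> b t h")  # the call in the hunk above
b = x.permute(0, 2, 1)                     # plain PyTorch equivalent

print(a.shape)            # torch.Size([2, 14, 128])
print(torch.equal(a, b))  # True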
@@ -174,11 +191,11 @@ class TemporalUnet(nn.Module):
         x = self.final_conv(x)

-        x = einops.rearrange(x, 'b t h -> b h t')
+        x = einops.rearrange(x, "b t h -> b h t")
         return x


 class TemporalValue(nn.Module):
     def __init__(
         self,
         horizon,
         ...
@@ -207,11 +224,15 @@ class TemporalValue(nn.Module):
         print(in_out)
         for dim_in, dim_out in in_out:
-            self.blocks.append(nn.ModuleList([
-                ResidualTemporalBlock(dim_in, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
-                ResidualTemporalBlock(dim_out, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
-                Downsample1d(dim_out)
-            ]))
+            self.blocks.append(
+                nn.ModuleList(
+                    [
+                        ResidualTemporalBlock(dim_in, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
+                        ResidualTemporalBlock(dim_out, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
+                        Downsample1d(dim_out),
+                    ]
+                )
+            )

             horizon = horizon // 2
@@ -224,11 +245,11 @@ class TemporalValue(nn.Module):
         )

     def forward(self, x, cond, time, *args):
-        '''
+        """
         x : [ batch x horizon x transition ]
-        '''
+        """

-        x = einops.rearrange(x, 'b h t -> b t h')
+        x = einops.rearrange(x, "b h t -> b t h")

         t = self.time_mlp(time)
src/diffusers/pipelines/grad_tts_utils.py

@@ -233,6 +233,7 @@ def english_cleaners(text):
     text = collapse_whitespace(text)
     return text

+
 try:
     _inflect = inflect.engine()
 except:
src/diffusers/schedulers/scheduling_ddim.py

-# Copyright 2022 The HuggingFace Team. All rights reserved.
+# Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -11,6 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
+# and https://github.com/hojonathanho/diffusion
+
 import math

 import numpy as np

@@ -31,6 +35,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
     :param max_beta: the maximum beta to use; use values lower than 1 to
                      prevent singularities.
     """
+
     def alpha_bar(time_step):
         return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
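The alpha_bar closure above is the cosine schedule from Nichol & Dhariwal's improved-DDPM work; its enclosing betas_for_alpha_bar (largely collapsed in this view) converts it into per-step betas. A minimal sketch of that conversion, reconstructed for illustration rather than copied verbatim from the file:

import math


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
    def alpha_bar(time_step):
        # Cosine schedule: decays smoothly from ~1 at t=0 toward 0 at t=1.
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        # beta_t = 1 - alpha_bar(t2) / alpha_bar(t1), capped by max_beta
        # to avoid the singularity as alpha_bar(t1) -> 0.
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return betas


betas = betas_for_alpha_bar(1000)
print(betas[0], betas[-1])  # tiny at the start, approaching max_beta at the end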
src/diffusers/schedulers/scheduling_ddpm.py

-# Copyright 2022 The HuggingFace Team. All rights reserved.
+# Copyright 2022 UC Berkely Team and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -11,6 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
+
 import math

 import numpy as np

@@ -31,6 +34,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
     :param max_beta: the maximum beta to use; use values lower than 1 to
                      prevent singularities.
     """
+
     def alpha_bar(time_step):
         return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
src/diffusers/schedulers/scheduling_pndm.py

-# Copyright 2022 The HuggingFace Team. All rights reserved.
+# Copyright 2022 Zhejiang University Team and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -11,9 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import numpy as np
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
+
+import math
+
+import numpy as np

 from ..configuration_utils import ConfigMixin
 from .scheduling_utils import SchedulerMixin

@@ -30,6 +34,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
     :param max_beta: the maximum beta to use; use values lower than 1 to
                      prevent singularities.
     """
+
     def alpha_bar(time_step):
         return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
tests/test_modeling_utils.py

@@ -17,11 +17,11 @@
 import inspect
 import tempfile
 import unittest

 import numpy as np
-import pytest
 import torch

+import pytest
 from diffusers import (
     BDDM,
     DDIM,
@@ -30,10 +30,10 @@ from diffusers import (
     PNDM,
     DDIMScheduler,
     DDPMScheduler,
+    GLIDESuperResUNetModel,
     LatentDiffusion,
     PNDMScheduler,
     UNetModel,
-    GLIDESuperResUNetModel
 )
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.pipeline_utils import DiffusionPipeline
@@ -180,7 +180,7 @@ class ModelTesterMixin:
         model.to(torch_device)
         model.train()

         output = model(**inputs_dict)
-        noise = torch.randn((inputs_dict["x"].shape[0], ) + self.get_output_shape).to(torch_device)
+        noise = torch.randn((inputs_dict["x"].shape[0],) + self.get_output_shape).to(torch_device)
         loss = torch.nn.functional.mse_loss(output, noise)
         loss.backward()
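The changed line concatenates a one-element batch tuple with the model's per-sample output shape; the fix is purely the (x,) spelling without the stray space. The tuple arithmetic, with hypothetical shapes:

import torch

batch_size = 4
output_shape = (3, 32, 32)  # hypothetical per-sample shape, standing in for self.get_output_shape

shape = (batch_size,) + output_shape  # tuple concatenation -> (4, 3, 32, 32)
noise = torch.randn(shape)
print(noise.shape)  # torch.Size([4, 3, 32, 32])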
@@ -249,6 +249,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
         print(output_slice)
         self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))

+
 class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
     model_class = GLIDESuperResUNetModel
@@ -278,7 +279,7 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
             "attention_resolutions": (2,),
             "channel_mult": (1, 2),
             "in_channels": 6,
             "out_channels": 6,
             "model_channels": 32,
@@ -287,7 +288,7 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
             "num_res_blocks": 2,
             "resblock_updown": True,
             "resolution": 32,
-            "use_scale_shift_norm": True
+            "use_scale_shift_norm": True,
         }

         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
@@ -308,7 +309,9 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
self
.
assertEqual
(
output
.
shape
,
expected_shape
,
"Input and output shapes do not match"
)
def
test_from_pretrained_hub
(
self
):
model
,
loading_info
=
GLIDESuperResUNetModel
.
from_pretrained
(
"fusing/glide-super-res-dummy"
,
output_loading_info
=
True
)
model
,
loading_info
=
GLIDESuperResUNetModel
.
from_pretrained
(
"fusing/glide-super-res-dummy"
,
output_loading_info
=
True
)
self
.
assertIsNotNone
(
model
)
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
...
...
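For context on the call being rewrapped: with output_loading_info=True, from_pretrained returns the model together with a dict describing checkpoint mismatches, and an empty missing_keys list means every model parameter found a tensor in the downloaded weights. A sketch mirroring the test above (same model class and hub id as the diff; written outside unittest for brevity):

from diffusers import GLIDESuperResUNetModel

model, loading_info = GLIDESuperResUNetModel.from_pretrained(
    "fusing/glide-super-res-dummy", output_loading_info=True
)

# Empty missing_keys: every model parameter matched a checkpoint tensor.
assert model is not None
assert len(loading_info["missing_keys"]) == 0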