renzhc / diffusers_dcu · Commits

Commit 08c85229, authored Jun 20, 2022 by Patrick von Platen
Commit message: add license disclaimers to schedulers
Parent: 2b8bc91c

Showing 6 changed files with 123 additions and 84 deletions (+123 −84)
Changed files:
  src/diffusers/models/unet_rl.py              +69 −48
  src/diffusers/pipelines/grad_tts_utils.py     +1 −0
  src/diffusers/schedulers/scheduling_ddim.py   +6 −1
  src/diffusers/schedulers/scheduling_ddpm.py   +5 −1
  src/diffusers/schedulers/scheduling_pndm.py   +7 −2
  tests/test_modeling_utils.py                 +35 −32
src/diffusers/models/unet_rl.py

 # model adapted from diffuser https://github.com/jannerm/diffuser/blob/main/diffuser/models/temporal.py
+import math
+
 import torch
 import torch.nn as nn
+
 import einops
 from einops.layers.torch import Rearrange
-import math
+
 
 class SinusoidalPosEmb(nn.Module):
     def __init__(self, dim):
@@ -20,6 +23,7 @@ class SinusoidalPosEmb(nn.Module):
         emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
         return emb
 
+
 class Downsample1d(nn.Module):
     def __init__(self, dim):
         super().__init__()
@@ -28,6 +32,7 @@ class Downsample1d(nn.Module):
     def forward(self, x):
         return self.conv(x)
 
+
 class Upsample1d(nn.Module):
     def __init__(self, dim):
         super().__init__()
@@ -36,57 +41,61 @@ class Upsample1d(nn.Module):
     def forward(self, x):
         return self.conv(x)
 
+
 class Conv1dBlock(nn.Module):
-    '''
+    """
     Conv1d --> GroupNorm --> Mish
-    '''
+    """
 
     def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
         super().__init__()
 
         self.block = nn.Sequential(
             nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
-            Rearrange('batch channels horizon -> batch channels 1 horizon'),
+            Rearrange("batch channels horizon -> batch channels 1 horizon"),
             nn.GroupNorm(n_groups, out_channels),
-            Rearrange('batch channels 1 horizon -> batch channels horizon'),
+            Rearrange("batch channels 1 horizon -> batch channels horizon"),
             nn.Mish(),
         )
 
     def forward(self, x):
         return self.block(x)
 
 class ResidualTemporalBlock(nn.Module):
     def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
         super().__init__()
 
-        self.blocks = nn.ModuleList([
-            Conv1dBlock(inp_channels, out_channels, kernel_size),
-            Conv1dBlock(out_channels, out_channels, kernel_size),
-        ])
+        self.blocks = nn.ModuleList(
+            [
+                Conv1dBlock(inp_channels, out_channels, kernel_size),
+                Conv1dBlock(out_channels, out_channels, kernel_size),
+            ]
+        )
 
         self.time_mlp = nn.Sequential(
             nn.Mish(),
             nn.Linear(embed_dim, out_channels),
-            Rearrange('batch t -> batch t 1'),
+            Rearrange("batch t -> batch t 1"),
         )
 
-        self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1) \
-            if inp_channels != out_channels else nn.Identity()
+        self.residual_conv = (
+            nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
+        )
 
     def forward(self, x, t):
-        '''
+        """
         x : [ batch_size x inp_channels x horizon ]
         t : [ batch_size x embed_dim ]
         returns:
         out : [ batch_size x out_channels x horizon ]
-        '''
+        """
         out = self.blocks[0](x) + self.time_mlp(t)
         out = self.blocks[1](out)
         return out + self.residual_conv(x)
 
 class TemporalUnet(nn.Module):
     def __init__(
         self,
         horizon,
@@ -99,7 +108,7 @@ class TemporalUnet(nn.Module):
         dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
         in_out = list(zip(dims[:-1], dims[1:]))
-        print(f'[ models/temporal ] Channel dimensions: {in_out}')
+        print(f"[ models/temporal ] Channel dimensions: {in_out}")
 
         time_dim = dim
         self.time_mlp = nn.Sequential(
@@ -117,11 +126,15 @@ class TemporalUnet(nn.Module):
         for ind, (dim_in, dim_out) in enumerate(in_out):
             is_last = ind >= (num_resolutions - 1)
 
-            self.downs.append(nn.ModuleList([
-                ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=horizon),
-                ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=horizon),
-                Downsample1d(dim_out) if not is_last else nn.Identity()
-            ]))
+            self.downs.append(
+                nn.ModuleList(
+                    [
+                        ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=horizon),
+                        ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=horizon),
+                        Downsample1d(dim_out) if not is_last else nn.Identity(),
+                    ]
+                )
+            )
 
             if not is_last:
                 horizon = horizon // 2
@@ -133,11 +146,15 @@ class TemporalUnet(nn.Module):
         for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
             is_last = ind >= (num_resolutions - 1)
 
-            self.ups.append(nn.ModuleList([
-                ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=horizon),
-                ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=horizon),
-                Upsample1d(dim_in) if not is_last else nn.Identity()
-            ]))
+            self.ups.append(
+                nn.ModuleList(
+                    [
+                        ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=horizon),
+                        ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=horizon),
+                        Upsample1d(dim_in) if not is_last else nn.Identity(),
+                    ]
+                )
+            )
 
             if not is_last:
                 horizon = horizon * 2
@@ -148,11 +165,11 @@ class TemporalUnet(nn.Module):
         )
 
     def forward(self, x, cond, time):
-        '''
+        """
         x : [ batch x horizon x transition ]
-        '''
+        """
-        x = einops.rearrange(x, 'b h t -> b t h')
+        x = einops.rearrange(x, "b h t -> b t h")
 
         t = self.time_mlp(time)
         h = []
@@ -174,11 +191,11 @@ class TemporalUnet(nn.Module):
         x = self.final_conv(x)
 
-        x = einops.rearrange(x, 'b t h -> b h t')
+        x = einops.rearrange(x, "b t h -> b h t")
         return x
 
 class TemporalValue(nn.Module):
     def __init__(
         self,
         horizon,
@@ -207,11 +224,15 @@ class TemporalValue(nn.Module):
         print(in_out)
         for dim_in, dim_out in in_out:
-            self.blocks.append(nn.ModuleList([
-                ResidualTemporalBlock(dim_in, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
-                ResidualTemporalBlock(dim_out, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
-                Downsample1d(dim_out)
-            ]))
+            self.blocks.append(
+                nn.ModuleList(
+                    [
+                        ResidualTemporalBlock(dim_in, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
+                        ResidualTemporalBlock(dim_out, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
+                        Downsample1d(dim_out),
+                    ]
+                )
+            )
 
             horizon = horizon // 2
@@ -224,11 +245,11 @@ class TemporalValue(nn.Module):
         )
 
     def forward(self, x, cond, time, *args):
-        '''
+        """
         x : [ batch x horizon x transition ]
-        '''
+        """
-        x = einops.rearrange(x, 'b h t -> b t h')
+        x = einops.rearrange(x, "b h t -> b t h")
 
         t = self.time_mlp(time)
@@ -239,4 +260,4 @@ class TemporalValue(nn.Module):
         x = x.view(len(x), -1)
         out = self.final_block(torch.cat([x, t], dim=-1))
-        return out
\ No newline at end of file
+        return out
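Almost all of the churn in unet_rl.py is mechanical (black/isort formatting and single-to-double quote normalization), so behaviour should be unchanged. As a quick sanity check of the reformatted Conv1dBlock, here is a minimal standalone sketch; the class body is restated from the diff above, while the batch size, channel counts, and horizon are illustrative values, not from the commit:

import torch
import torch.nn as nn
from einops.layers.torch import Rearrange

class Conv1dBlock(nn.Module):
    """Conv1d --> GroupNorm --> Mish, as in the reformatted file above."""

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
            Rearrange("batch channels horizon -> batch channels 1 horizon"),
            nn.GroupNorm(n_groups, out_channels),
            Rearrange("batch channels 1 horizon -> batch channels horizon"),
            nn.Mish(),
        )

    def forward(self, x):
        return self.block(x)

x = torch.randn(2, 14, 32)            # (batch, channels, horizon); sizes chosen for illustration
block = Conv1dBlock(14, 64, kernel_size=5)
print(block(x).shape)                 # torch.Size([2, 64, 32]); the padding preserves the horizon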
src/diffusers/pipelines/grad_tts_utils.py

@@ -233,6 +233,7 @@ def english_cleaners(text):
     text = collapse_whitespace(text)
     return text
 
+
 try:
     _inflect = inflect.engine()
 except:
src/diffusers/schedulers/scheduling_ddim.py

-# Copyright 2022 The HuggingFace Team. All rights reserved.
+# Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,6 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
+# and https://github.com/hojonathanho/diffusion
+
 import math
 import numpy as np
@@ -31,6 +35,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
     :param max_beta: the maximum beta to use; use values lower than 1 to
                      prevent singularities.
     """
+
     def alpha_bar(time_step):
         return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
src/diffusers/schedulers/scheduling_ddpm.py

-# Copyright 2022 The HuggingFace Team. All rights reserved.
+# Copyright 2022 UC Berkely Team and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,6 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
+
 import math
 import numpy as np
@@ -31,6 +34,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
     :param max_beta: the maximum beta to use; use values lower than 1 to
                      prevent singularities.
     """
+
     def alpha_bar(time_step):
         return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
src/diffusers/schedulers/scheduling_pndm.py

-# Copyright 2022 The HuggingFace Team. All rights reserved.
+# Copyright 2022 Zhejiang University Team and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,9 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import numpy as np
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
+
 import math
+
+import numpy as np
 
 from ..configuration_utils import ConfigMixin
 from .scheduling_utils import SchedulerMixin
@@ -30,6 +34,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
     :param max_beta: the maximum beta to use; use values lower than 1 to
                      prevent singularities.
     """
+
     def alpha_bar(time_step):
         return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
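All three scheduler files touch the same betas_for_alpha_bar helper, whose cosine alpha_bar function appears in the hunks above. For context, here is a minimal sketch of how such an alpha_bar is typically turned into a discrete beta schedule; the loop body is the standard construction from the cosine-schedule literature, not quoted from this commit:

import math
import numpy as np

def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
    # alpha_bar(t) gives the cumulative product of (1 - beta) up to time t in [0, 1].
    def alpha_bar(time_step):
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        # beta_i = 1 - alpha_bar(t2) / alpha_bar(t1), clipped at max_beta to avoid singularities
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas, dtype=np.float64)

print(betas_for_alpha_bar(1000)[:3])  # betas start small and grow toward max_beta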
tests/test_modeling_utils.py

@@ -17,11 +17,11 @@
 import inspect
 import tempfile
 import unittest
 
 import numpy as np
+import pytest
 import torch
 
-import pytest
 from diffusers import (
     BDDM,
     DDIM,
@@ -30,10 +30,10 @@ from diffusers import (
     PNDM,
     DDIMScheduler,
     DDPMScheduler,
+    GLIDESuperResUNetModel,
     LatentDiffusion,
     PNDMScheduler,
     UNetModel,
-    GLIDESuperResUNetModel
 )
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.pipeline_utils import DiffusionPipeline
@@ -105,7 +105,7 @@ class ModelTesterMixin:
         max_diff = (image - new_image).abs().sum().item()
         self.assertLessEqual(max_diff, 1e-5, "Models give different forward passes")
 
     def test_determinism(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict)
@@ -121,7 +121,7 @@ class ModelTesterMixin:
         out_2 = out_2[~np.isnan(out_2)]
         max_diff = np.amax(np.abs(out_1 - out_2))
         self.assertLessEqual(max_diff, 1e-5)
 
     def test_output(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict)
@@ -130,11 +130,11 @@ class ModelTesterMixin:
         with torch.no_grad():
             output = model(**inputs_dict)
 
         self.assertIsNotNone(output)
         expected_shape = inputs_dict["x"].shape
         self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
 
     def test_forward_signature(self):
         init_dict, _ = self.prepare_init_args_and_inputs_for_common()
@@ -145,14 +145,14 @@ class ModelTesterMixin:
         expected_arg_names = ["x", "timesteps"]
         self.assertListEqual(arg_names[:2], expected_arg_names)
 
     def test_model_from_config(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict)
         model.to(torch_device)
         model.eval()
 
         # test if the model can be loaded from the config
         # and has all the expected shape
         with tempfile.TemporaryDirectory() as tmpdirname:
@@ -160,17 +160,17 @@ class ModelTesterMixin:
             new_model = self.model_class.from_config(tmpdirname)
             new_model.to(torch_device)
             new_model.eval()
 
         # check if all paramters shape are the same
         for param_name in model.state_dict().keys():
             param_1 = model.state_dict()[param_name]
             param_2 = new_model.state_dict()[param_name]
             self.assertEqual(param_1.shape, param_2.shape)
 
         with torch.no_grad():
             output_1 = model(**inputs_dict)
             output_2 = new_model(**inputs_dict)
 
         self.assertEqual(output_1.shape, output_2.shape)
 
     def test_training(self):
@@ -180,7 +180,7 @@ class ModelTesterMixin:
         model.to(torch_device)
         model.train()
         output = model(**inputs_dict)
-        noise = torch.randn((inputs_dict["x"].shape[0], ) + self.get_output_shape).to(torch_device)
+        noise = torch.randn((inputs_dict["x"].shape[0],) + self.get_output_shape).to(torch_device)
         loss = torch.nn.functional.mse_loss(output, noise)
         loss.backward()
@@ -198,11 +198,11 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
         time_step = torch.tensor([10]).to(torch_device)
 
         return {"x": noise, "timesteps": time_step}
 
     @property
     def get_input_shape(self):
         return (3, 32, 32)
 
     @property
     def get_output_shape(self):
         return (3, 32, 32)
@@ -217,7 +217,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
     def test_from_pretrained_hub(self):
         model, loading_info = UNetModel.from_pretrained("fusing/ddpm_dummy", output_loading_info=True)
         self.assertIsNotNone(model)
@@ -227,7 +227,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
         image = model(**self.dummy_input)
         assert image is not None, "Make sure output is not None"
 
     def test_output_pretrained(self):
         model = UNetModel.from_pretrained("fusing/ddpm_dummy")
         model.eval()
@@ -235,13 +235,13 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
         torch.manual_seed(0)
         if torch.cuda.is_available():
             torch.cuda.manual_seed_all(0)
 
         noise = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution)
         time_step = torch.tensor([10])
 
         with torch.no_grad():
             output = model(noise, time_step)
 
         output_slice = output[0, -1, -3:, -3:].flatten()
         # fmt: off
         expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053])
@@ -249,6 +249,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
         print(output_slice)
         self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
 
+
 class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
     model_class = GLIDESuperResUNetModel
@@ -266,19 +267,19 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
         time_step = torch.tensor([10] * noise.shape[0], device=torch_device)
 
         return {"x": noise, "timesteps": time_step, "low_res": low_res}
 
     @property
     def get_input_shape(self):
         return (3, 32, 32)
 
     @property
     def get_output_shape(self):
         return (6, 32, 32)
 
     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
             "attention_resolutions": (2,),
             "channel_mult": (1, 2),
             "in_channels": 6,
             "out_channels": 6,
             "model_channels": 32,
@@ -287,7 +288,7 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
             "num_res_blocks": 2,
             "resblock_updown": True,
             "resolution": 32,
-            "use_scale_shift_norm": True
+            "use_scale_shift_norm": True,
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
@@ -302,13 +303,15 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
             output = model(**inputs_dict)
             output, _ = torch.split(output, 3, dim=1)
 
         self.assertIsNotNone(output)
         expected_shape = inputs_dict["x"].shape
         self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
 
     def test_from_pretrained_hub(self):
-        model, loading_info = GLIDESuperResUNetModel.from_pretrained("fusing/glide-super-res-dummy", output_loading_info=True)
+        model, loading_info = GLIDESuperResUNetModel.from_pretrained(
+            "fusing/glide-super-res-dummy", output_loading_info=True
+        )
         self.assertIsNotNone(model)
         self.assertEqual(len(loading_info["missing_keys"]), 0)
@@ -316,7 +319,7 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
         image = model(**self.dummy_input)
         assert image is not None, "Make sure output is not None"
 
     # TODO (patil-suraj): Check why GLIDESuperResUNetModel always outputs zero
     @unittest.skip("GLIDESuperResUNetModel always outputs zero")
     def test_output_pretrained(self):
@@ -326,14 +329,14 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
         torch.manual_seed(0)
         if torch.cuda.is_available():
             torch.cuda.manual_seed_all(0)
 
         noise = torch.randn(1, 3, 32, 32)
         low_res = torch.randn(1, 3, 4, 4)
         time_step = torch.tensor([42] * noise.shape[0])
 
         with torch.no_grad():
             output = model(noise, time_step, low_res)
             output, _ = torch.split(output, 3, dim=1)
 
         output_slice = output[0, -1, -3:, -3:].flatten()
         # fmt: off
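In test_training above, the target noise shape is built by concatenating a 1-tuple of the batch size with the per-sample output shape. A standalone illustration of that pattern follows; the batch size here is an assumed value, while the output shape matches UnetModelTests.get_output_shape above:

import torch

batch_size = 4                      # assumed; the test reads it from inputs_dict["x"].shape[0]
output_shape = (3, 32, 32)          # matches UnetModelTests.get_output_shape

# (4,) + (3, 32, 32) -> (4, 3, 32, 32): one noise target per sample in the batch
noise = torch.randn((batch_size,) + output_shape)
print(noise.shape)                  # torch.Size([4, 3, 32, 32])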