renzhc / diffusers_dcu · Commits

Commit b09b152f
Authored Jun 21, 2022 by anton-l

    Merge branch 'main' of github.com:huggingface/diffusers

Parents: a2117cb7, 4497e78d
Changes: 49 files in the full commit; this page shows 9 changed files with 831 additions and 189 deletions (+831 −189).
Changed files on this page:

    src/diffusers/schedulers/scheduling_grad_tts.py       +1   −3
    src/diffusers/schedulers/scheduling_pndm.py           +110 −54
    src/diffusers/schedulers/scheduling_utils.py          +10  −24
    src/diffusers/utils/__init__.py                       +58  −9
    src/diffusers/utils/dummy_transformers_objects.py     +48  −0
    tests/test_modeling_utils.py                          +378 −15
    tests/test_scheduler.py                               +206 −58
    utils/check_dummies.py                                +12  −18
    utils/check_table.py                                  +8   −8
src/diffusers/schedulers/scheduling_grad_tts.py

@@ -30,8 +30,6 @@ class GradTTSScheduler(SchedulerMixin, ConfigMixin):
             beta_start=beta_start,
             beta_end=beta_end,
         )
-        self.timesteps = int(timesteps)
-
         self.set_format(tensor_format=tensor_format)

     def sample_noise(self, timestep):
@@ -46,4 +44,4 @@ class GradTTSScheduler(SchedulerMixin, ConfigMixin):
         return xt

     def __len__(self):
-        return self.timesteps
+        return len(self.config.timesteps)
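
Note: here and in scheduling_pndm.py below, the commit stops mirroring constructor arguments onto ad-hoc instance attributes (self.timesteps) and reads them back from self.config instead. A minimal sketch of why that works, assuming (as elsewhere in diffusers) that ConfigMixin records the __init__ arguments via a register_to_config step — the class and method names below are illustrative, not the library's exact internals:

    class FrozenConfig(dict):
        """Read-only attribute view over registered init arguments (illustrative)."""
        def __getattr__(self, name):
            try:
                return self[name]
            except KeyError:
                raise AttributeError(name)

    class ConfigMixinSketch:
        def register_to_config(self, **kwargs):
            # Store every init argument once; accessors can then read
            # self.config.timesteps instead of a duplicated self.timesteps.
            self.config = FrozenConfig(kwargs)

    class SchedulerSketch(ConfigMixinSketch):
        def __init__(self, timesteps=1000):
            self.register_to_config(timesteps=timesteps)

        def __len__(self):
            return self.config.timesteps

    print(len(SchedulerSketch()))  # 1000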
src/diffusers/schedulers/scheduling_pndm.py

-# Copyright 2022 The HuggingFace Team. All rights reserved.
+# Copyright 2022 Zhejiang University Team and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -11,12 +11,39 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
+
 import math

 import numpy as np

 from ..configuration_utils import ConfigMixin
-from .scheduling_utils import SchedulerMixin, betas_for_alpha_bar, linear_beta_schedule
+from .scheduling_utils import SchedulerMixin
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
+    """
+    Create a beta schedule that discretizes the given alpha_t_bar function,
+    which defines the cumulative product of (1-beta) over time from t = [0,1].
+
+    :param num_diffusion_timesteps: the number of betas to produce.
+    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+                      produces the cumulative product of (1-beta) up to that
+                      part of the diffusion process.
+    :param max_beta: the maximum beta to use; use values lower than 1 to
+                     prevent singularities.
+    """
+
+    def alpha_bar(time_step):
+        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
+
+    betas = []
+    for i in range(num_diffusion_timesteps):
+        t1 = i / num_diffusion_timesteps
+        t2 = (i + 1) / num_diffusion_timesteps
+        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+    return np.array(betas, dtype=np.float32)
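
Note: the schedule moved in from scheduling_utils.py now hard-codes the GLIDE cosine alpha_bar instead of taking it as an argument (the docstring's ":param alpha_bar:" is a leftover from the parameterized version). As a quick sanity check, the first betas it produces can be computed by hand — a standalone sketch, not part of the diff:

    import math

    def alpha_bar(t):
        # Cumulative product of (1 - beta) as a function of t in [0, 1]
        # (the "squaredcos_cap_v2" choice hard-coded above).
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    N = 1000
    for i in range(3):
        t1, t2 = i / N, (i + 1) / N
        beta = min(1 - alpha_bar(t2) / alpha_bar(t1), 0.999)
        print(f"beta[{i}] ~ {beta:.2e}")
    # Betas start around 4e-5 and grow toward max_beta as alpha_bar decays.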
 class PNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -35,16 +62,12 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
             beta_end=beta_end,
             beta_schedule=beta_schedule,
         )
-        self.timesteps = int(timesteps)

         if beta_schedule == "linear":
-            self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
+            self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32)
         elif beta_schedule == "squaredcos_cap_v2":
             # GLIDE cosine schedule
-            self.betas = betas_for_alpha_bar(
-                timesteps,
-                lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
-            )
+            self.betas = betas_for_alpha_bar(timesteps)
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
...
@@ -57,55 +80,58 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         # For now we only support F-PNDM, i.e. the runge-kutta method
         # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
-        # mainly at equations (12) and (13) and the Algorithm 2.
+        # mainly at formula (9), (12), (13) and the Algorithm 2.
         self.pndm_order = 4

         # running values
         self.cur_residual = 0
         self.cur_sample = None
         self.ets = []
-        self.warmup_time_steps = {}
+        self.prk_time_steps = {}
         self.time_steps = {}
-        self.set_prk_mode()

-    def get_alpha(self, time_step):
-        return self.alphas[time_step]
-
-    def get_beta(self, time_step):
-        return self.betas[time_step]
-
-    def get_alpha_prod(self, time_step):
-        if time_step < 0:
-            return self.one
-        return self.alphas_cumprod[time_step]
-
-    def get_warmup_time_steps(self, num_inference_steps):
-        if num_inference_steps in self.warmup_time_steps:
-            return self.warmup_time_steps[num_inference_steps]
-
-        inference_step_times = list(range(0, self.timesteps, self.timesteps // num_inference_steps))
-        warmup_time_steps = np.array(inference_step_times[-self.pndm_order :]).repeat(2) + np.tile(
-            np.array([0, self.timesteps // num_inference_steps // 2]), self.pndm_order
-        )
-        self.warmup_time_steps[num_inference_steps] = list(reversed(warmup_time_steps[:-1].repeat(2)[1:-1]))
-        return self.warmup_time_steps[num_inference_steps]
+    def get_prk_time_steps(self, num_inference_steps):
+        if num_inference_steps in self.prk_time_steps:
+            return self.prk_time_steps[num_inference_steps]
+
+        inference_step_times = list(range(0, self.config.timesteps, self.config.timesteps // num_inference_steps))
+        prk_time_steps = np.array(inference_step_times[-self.pndm_order :]).repeat(2) + np.tile(
+            np.array([0, self.config.timesteps // num_inference_steps // 2]), self.pndm_order
+        )
+        self.prk_time_steps[num_inference_steps] = list(reversed(prk_time_steps[:-1].repeat(2)[1:-1]))
+        return self.prk_time_steps[num_inference_steps]
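
Note: the repeat(2)/np.tile construction interleaves each of the last pndm_order coarse times with its midpoint, producing the Runge–Kutta warmup ("prk") schedule; step_prk then maps four consecutive t values onto one base entry via t // 4 * 4. A standalone sketch of the same computation with the defaults timesteps=1000 and num_inference_steps=50:

    import numpy as np

    timesteps, pndm_order, num_inference_steps = 1000, 4, 50
    stride = timesteps // num_inference_steps                 # 20
    inference_step_times = list(range(0, timesteps, stride))  # [0, 20, ..., 980]

    # Last 4 coarse times [920, 940, 960, 980], each doubled, then offset by
    # [0, stride // 2] so the midpoints 930...990 are interleaved.
    prk = np.array(inference_step_times[-pndm_order:]).repeat(2) + np.tile(
        np.array([0, stride // 2]), pndm_order
    )
    prk = list(reversed(prk[:-1].repeat(2)[1:-1]))
    print(prk)  # [980, 970, 970, 960, 960, 950, 950, 940, 940, 930, 930, 920]

The 12-entry result matches the "at least 12 iterations" the new step_plms guard below insists on: three full PRK steps of four model evaluations each.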
     def get_time_steps(self, num_inference_steps):
         if num_inference_steps in self.time_steps:
             return self.time_steps[num_inference_steps]

-        inference_step_times = list(range(0, self.timesteps, self.timesteps // num_inference_steps))
+        inference_step_times = list(range(0, self.config.timesteps, self.config.timesteps // num_inference_steps))
         self.time_steps[num_inference_steps] = list(reversed(inference_step_times[:-3]))
         return self.time_steps[num_inference_steps]

-    def set_prk_mode(self):
-        self.mode = "prk"
-
-    def set_plms_mode(self):
-        self.mode = "plms"
-
-    def step(self, *args, **kwargs):
-        if self.mode == "prk":
-            return self.step_prk(*args, **kwargs)
-        if self.mode == "plms":
-            return self.step_plms(*args, **kwargs)
-        raise ValueError(f"mode {self.mode} does not exist.")
-
     def step_prk(self, residual, sample, t, num_inference_steps):
-        warmup_time_steps = self.get_warmup_time_steps(num_inference_steps)
-
-        t_prev = warmup_time_steps[t // 4 * 4]
-        t_next = warmup_time_steps[min(t + 1, len(warmup_time_steps) - 1)]
+        # TODO(Patrick) - need to rethink whether the "warmup" way is the correct API design here
+        prk_time_steps = self.get_prk_time_steps(num_inference_steps)
+
+        t_orig = prk_time_steps[t // 4 * 4]
+        t_orig_prev = prk_time_steps[min(t + 1, len(prk_time_steps) - 1)]

         if t % 4 == 0:
             self.cur_residual += 1 / 6 * residual
...
@@ -119,33 +145,63 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
             residual = self.cur_residual + 1 / 6 * residual
             self.cur_residual = 0

-        return self.transfer(self.cur_sample, t_prev, t_next, residual)
+        # cur_sample should not be `None`
+        cur_sample = self.cur_sample if self.cur_sample is not None else sample
+
+        return self.get_prev_sample(cur_sample, t_orig, t_orig_prev, residual)

     def step_plms(self, residual, sample, t, num_inference_steps):
+        if len(self.ets) < 3:
+            raise ValueError(
+                f"{self.__class__} can only be run AFTER scheduler has been run "
+                "in 'prk' mode for at least 12 iterations "
+                "See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py "
+                "for more information."
+            )
+
         timesteps = self.get_time_steps(num_inference_steps)

-        t_prev = timesteps[t]
-        t_next = timesteps[min(t + 1, len(timesteps) - 1)]
+        t_orig = timesteps[t]
+        t_orig_prev = timesteps[min(t + 1, len(timesteps) - 1)]
+
         self.ets.append(residual)

         residual = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])

-        return self.transfer(sample, t_prev, t_next, residual)
+        return self.get_prev_sample(sample, t_orig, t_orig_prev, residual)
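
Note: the (55, −59, 37, −9)/24 weights in step_plms are the classical fourth-order Adams–Bashforth coefficients — the "linear multistep" half of PNDM. Each PLMS step reuses the last four stored residuals instead of evaluating the model four times, which is why the PRK warmup must fill self.ets first. A tiny self-contained illustration of the same combination on the scalar ODE x' = -x (my example, not from the diff):

    import math

    def f(x):
        return -x

    h, x = 0.01, 1.0
    history = []
    for _ in range(3):           # bootstrap the first steps (PNDM uses the
        history.append(f(x))     # Runge-Kutta "prk" phase for this)
        x += h * f(x)
    history.append(f(x))

    for _ in range(4, 101):
        # Same (55, -59, 37, -9) / 24 combination as step_plms above.
        x += h / 24 * (55 * history[-1] - 59 * history[-2] + 37 * history[-3] - 9 * history[-4])
        history.append(f(x))

    print(x, math.exp(-1.0))  # tracks exp(-1) to ~1e-4; the crude Euler bootstrap dominates the error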
-    def transfer(self, x, t, t_next, et):
-        # TODO(Patrick): clean up to be compatible with numpy and give better names
-        alphas_cump = self.alphas_cumprod.to(x.device)
-
-        at = alphas_cump[t + 1].view(-1, 1, 1, 1)
-        at_next = alphas_cump[t_next + 1].view(-1, 1, 1, 1)
-
-        x_delta = (at_next - at) * (
-            (1 / (at.sqrt() * (at.sqrt() + at_next.sqrt()))) * x
-            - 1 / (at.sqrt() * (((1 - at_next) * at).sqrt() + ((1 - at) * at_next).sqrt())) * et
-        )
-
-        x_next = x + x_delta
-        return x_next
+    def get_prev_sample(self, sample, t_orig, t_orig_prev, residual):
+        # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
+        # this function computes x_(t−δ) using the formula of (9)
+        # Note that x_t needs to be added to both sides of the equation
+
+        # Notation (<variable name> -> <name in paper>
+        # alpha_prod_t -> α_t
+        # alpha_prod_t_prev -> α_(t−δ)
+        # beta_prod_t -> (1 - α_t)
+        # beta_prod_t_prev -> (1 - α_(t−δ))
+        # sample -> x_t
+        # residual -> e_θ(x_t, t)
+        # prev_sample -> x_(t−δ)
+        alpha_prod_t = self.alphas_cumprod[t_orig + 1]
+        alpha_prod_t_prev = self.alphas_cumprod[t_orig_prev + 1]
+        beta_prod_t = 1 - alpha_prod_t
+        beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+        # corresponds to (α_(t−δ) - α_t) divided by
+        # denominator of x_t in formula (9) and plus 1
+        # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) =
+        # sqrt(α_(t−δ)) / sqrt(α_t))
+        sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5)
+
+        # corresponds to denominator of e_θ(x_t, t) in formula (9)
+        residual_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + (
+            alpha_prod_t * beta_prod_t * alpha_prod_t_prev
+        ) ** (0.5)
+
+        # full formula (9)
+        prev_sample = sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * residual / residual_denom_coeff
+
+        return prev_sample

     def __len__(self):
-        return self.timesteps
+        return self.config.timesteps
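
Note: written out (in LaTeX notation), the transfer rule that get_prev_sample implements — transcribing the code and its notation comments, with \bar\alpha standing for the cumulative product alphas_cumprod — is formula (9) of the PNDM paper:

    x_{t-\delta}
        = \sqrt{\frac{\bar\alpha_{t-\delta}}{\bar\alpha_t}}\, x_t
        - \frac{(\bar\alpha_{t-\delta} - \bar\alpha_t)\,\epsilon_\theta(x_t, t)}
               {\bar\alpha_t \sqrt{1 - \bar\alpha_{t-\delta}}
                + \sqrt{\bar\alpha_t (1 - \bar\alpha_t)\, \bar\alpha_{t-\delta}}}

The removed transfer() computed the algebraically equivalent "x plus delta" form; the new version folds the x_t terms together using the identity noted in the comments, (ᾱ_(t−δ) − ᾱ_t) / (√ᾱ_t (√ᾱ_(t−δ) + √ᾱ_t)) + 1 = √(ᾱ_(t−δ)/ᾱ_t).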
src/diffusers/schedulers/scheduling_utils.py

@@ -18,30 +18,6 @@ import torch

 SCHEDULER_CONFIG_NAME = "scheduler_config.json"


-def linear_beta_schedule(timesteps, beta_start, beta_end):
-    return np.linspace(beta_start, beta_end, timesteps, dtype=np.float32)
-
-
-def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
-    """
-    Create a beta schedule that discretizes the given alpha_t_bar function,
-    which defines the cumulative product of (1-beta) over time from t = [0,1].
-
-    :param num_diffusion_timesteps: the number of betas to produce.
-    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
-                      produces the cumulative product of (1-beta) up to that
-                      part of the diffusion process.
-    :param max_beta: the maximum beta to use; use values lower than 1 to
-                     prevent singularities.
-    """
-    betas = []
-    for i in range(num_diffusion_timesteps):
-        t1 = i / num_diffusion_timesteps
-        t2 = (i + 1) / num_diffusion_timesteps
-        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
-    return np.array(betas, dtype=np.float32)
-
-
 class SchedulerMixin:

     config_name = SCHEDULER_CONFIG_NAME

...
@@ -64,3 +40,13 @@ class SchedulerMixin:
             return torch.clamp(tensor, min_value, max_value)

         raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+    def log(self, tensor):
+        tensor_format = getattr(self, "tensor_format", "pt")
+
+        if tensor_format == "np":
+            return np.log(tensor)
+        elif tensor_format == "pt":
+            return torch.log(tensor)
+
+        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
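
Note: the new log follows the same dispatch convention as the existing clip above it: the mixin switches on a tensor_format attribute so the same scheduler math runs on NumPy arrays and PyTorch tensors alike, defaulting to "pt" when set_format was never called. A minimal standalone sketch of the convention:

    import numpy as np
    import torch

    class TinySchedulerMixin:
        tensor_format = "pt"  # normally set via set_format(tensor_format=...)

        def log(self, tensor):
            tensor_format = getattr(self, "tensor_format", "pt")
            if tensor_format == "np":
                return np.log(tensor)
            elif tensor_format == "pt":
                return torch.log(tensor)
            raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")

    print(TinySchedulerMixin().log(torch.tensor([1.0])))  # tensor([0.])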
src/diffusers/utils/__init__.py

-#!/usr/bin/env python
-# coding=utf-8
-# flake8: noqa
-# There's no way to ignore "F401 '...' imported but unused" warnings in this
-# module, but to preserve other warnings. So, don't check this module at all.
-
-import os
-
 # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
...
@@ -20,8 +11,18 @@ import os
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import importlib
+import os
+from collections import OrderedDict
+
+import importlib_metadata
 from requests.exceptions import HTTPError

+from .logging import get_logger
+
+
+logger = get_logger(__name__)
+
+
 hf_cache_home = os.path.expanduser(
     os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
 )
...
@@ -36,6 +37,18 @@ DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
 HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))

+_transformers_available = importlib.util.find_spec("transformers") is not None
+try:
+    _transformers_version = importlib_metadata.version("transformers")
+    logger.debug(f"Successfully imported transformers version {_transformers_version}")
+except importlib_metadata.PackageNotFoundError:
+    _transformers_available = False
+
+
+def is_transformers_available():
+    return _transformers_available
+
+
 class RepositoryNotFoundError(HTTPError):
     """
     Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does
...
@@ -49,3 +62,39 @@ class EntryNotFoundError(HTTPError):
 class RevisionNotFoundError(HTTPError):
     """Raised when trying to access a hf.co URL with a valid repository but an invalid revision."""
+
+
+TRANSFORMERS_IMPORT_ERROR = """
+{0} requires the transformers library but it was not found in your environment. You can install it with pip:
+`pip install transformers`
+"""
+
+
+BACKENDS_MAPPING = OrderedDict(
+    [
+        ("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)),
+    ]
+)
+
+
+def requires_backends(obj, backends):
+    if not isinstance(backends, (list, tuple)):
+        backends = [backends]
+
+    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
+    checks = (BACKENDS_MAPPING[backend] for backend in backends)
+    failed = [msg.format(name) for available, msg in checks if not available()]
+    if failed:
+        raise ImportError("".join(failed))
+
+
+class DummyObject(type):
+    """
+    Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by
+    `requires_backend` each time a user tries to access any method of that class.
+    """
+
+    def __getattr__(cls, key):
+        if key.startswith("_"):
+            return super().__getattr__(cls, key)
+        requires_backends(cls, cls._backends)
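
Note: together, requires_backends and the DummyObject metaclass let the package expose placeholder classes when an optional dependency is missing — importing succeeds, but any real use raises a readable ImportError. A self-contained sketch of the mechanism with a made-up backend (the "fakelib" name and the always-False availability flag are illustrative only):

    from collections import OrderedDict

    _FAKE_IMPORT_ERROR = """
    {0} requires the fakelib library but it was not found in your environment.
    """
    # Pretend the optional backend is missing.
    BACKENDS = OrderedDict([("fakelib", (lambda: False, _FAKE_IMPORT_ERROR))])

    def requires_backends(obj, backends):
        name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
        failed = [BACKENDS[b][1].format(name) for b in backends if not BACKENDS[b][0]()]
        if failed:
            raise ImportError("".join(failed))

    class DummyObject(type):
        def __getattr__(cls, key):
            if key.startswith("_"):
                raise AttributeError(key)  # the real DummyObject defers to super() here
            requires_backends(cls, cls._backends)

    class SomeModel(metaclass=DummyObject):
        _backends = ["fakelib"]
        def __init__(self, *args, **kwargs):
            requires_backends(self, ["fakelib"])

    print(SomeModel)       # merely naming the class is fine
    try:
        SomeModel()        # instantiating raises the helpful error
    except ImportError as e:
        print(e)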
src/diffusers/utils/dummy_transformers_objects.py  (new file, mode 100644)

+# This file is autogenerated by the command `make fix-copies`, do not edit.
+# flake8: noqa
+from ..utils import DummyObject, requires_backends
+
+
+class GLIDESuperResUNetModel(metaclass=DummyObject):
+    _backends = ["transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["transformers"])
+
+
+class GLIDETextToImageUNetModel(metaclass=DummyObject):
+    _backends = ["transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["transformers"])
+
+
+class GLIDEUNetModel(metaclass=DummyObject):
+    _backends = ["transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["transformers"])
+
+
+class UNetGradTTSModel(metaclass=DummyObject):
+    _backends = ["transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["transformers"])
+
+
+GLIDE = None
+
+
+class GradTTS(metaclass=DummyObject):
+    _backends = ["transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["transformers"])
+
+
+class LatentDiffusion(metaclass=DummyObject):
+    _backends = ["transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["transformers"])
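
Note: this page of the diff does not show src/diffusers/__init__.py, but the usual pattern these autogenerated dummies plug into — and what utils/check_dummies.py below verifies — is a guarded import at the top level. A hypothetical excerpt, for orientation only (the module paths are assumptions, not the file's actual contents):

    # Hypothetical excerpt of src/diffusers/__init__.py (not shown in this diff):
    from .utils import is_transformers_available

    if is_transformers_available():
        from .models.unet_glide import GLIDESuperResUNetModel  # real implementation
    else:
        from .utils.dummy_transformers_objects import GLIDESuperResUNetModel  # placeholder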
tests/test_modeling_utils.py

@@ -14,11 +14,14 @@
 # limitations under the License.

+import inspect
 import tempfile
 import unittest

+import numpy as np
 import torch

+import pytest
+
 from diffusers import (
     BDDM,
     DDIM,
...
@@ -27,9 +30,12 @@ from diffusers import (
     PNDM,
     DDIMScheduler,
     DDPMScheduler,
+    GLIDESuperResUNetModel,
     LatentDiffusion,
     PNDMScheduler,
     UNetModel,
+    UNetLDMModel,
+    UNetGradTTSModel,
 )
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.pipeline_utils import DiffusionPipeline
...
@@ -82,7 +88,108 @@ class ConfigTester(unittest.TestCase):
         assert config == new_config

-class ModelTesterMixin(unittest.TestCase):
+class ModelTesterMixin:
+    def test_from_pretrained_save_pretrained(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_pretrained(tmpdirname)
+            new_model = self.model_class.from_pretrained(tmpdirname)
+            new_model.to(torch_device)
+
+        with torch.no_grad():
+            image = model(**inputs_dict)
+            new_image = new_model(**inputs_dict)
+
+        max_diff = (image - new_image).abs().sum().item()
+        self.assertLessEqual(max_diff, 1e-5, "Models give different forward passes")
+
+    def test_determinism(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+        with torch.no_grad():
+            first = model(**inputs_dict)
+            second = model(**inputs_dict)
+
+        out_1 = first.cpu().numpy()
+        out_2 = second.cpu().numpy()
+        out_1 = out_1[~np.isnan(out_1)]
+        out_2 = out_2[~np.isnan(out_2)]
+        max_diff = np.amax(np.abs(out_1 - out_2))
+        self.assertLessEqual(max_diff, 1e-5)
+
+    def test_output(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        with torch.no_grad():
+            output = model(**inputs_dict)
+
+        self.assertIsNotNone(output)
+        expected_shape = inputs_dict["x"].shape
+        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+
+    def test_forward_signature(self):
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+
+        model = self.model_class(**init_dict)
+        signature = inspect.signature(model.forward)
+        # signature.parameters is an OrderedDict => so arg_names order is deterministic
+        arg_names = [*signature.parameters.keys()]
+
+        expected_arg_names = ["x", "timesteps"]
+        self.assertListEqual(arg_names[:2], expected_arg_names)
+
+    def test_model_from_config(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        # test if the model can be loaded from the config
+        # and has all the expected shape
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_config(tmpdirname)
+            new_model = self.model_class.from_config(tmpdirname)
+            new_model.to(torch_device)
+            new_model.eval()
+
+        # check if all paramters shape are the same
+        for param_name in model.state_dict().keys():
+            param_1 = model.state_dict()[param_name]
+            param_2 = new_model.state_dict()[param_name]
+            self.assertEqual(param_1.shape, param_2.shape)
+
+        with torch.no_grad():
+            output_1 = model(**inputs_dict)
+            output_2 = new_model(**inputs_dict)
+            self.assertEqual(output_1.shape, output_2.shape)
+
+    def test_training(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.train()
+        output = model(**inputs_dict)
+        noise = torch.randn((inputs_dict["x"].shape[0],) + self.get_output_shape).to(torch_device)
+        loss = torch.nn.functional.mse_loss(output, noise)
+        loss.backward()
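
Note: dropping unittest.TestCase from ModelTesterMixin is deliberate. If the mixin itself were a TestCase, unittest would also collect and run its test_* methods on the bare mixin, where self.model_class and the dummy_input helpers do not exist; only the concrete per-model subclasses below mix TestCase back in. A minimal sketch of the pattern:

    import unittest

    class CommonTests:                      # plain mixin: not collected by unittest
        def test_has_answer(self):
            self.assertEqual(self.subject.answer(), 42)

    class _Answer:
        def answer(self):
            return 42

    class AnswerTests(CommonTests, unittest.TestCase):   # collected: runs test_has_answer
        def setUp(self):
            self.subject = _Answer()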
+class UnetModelTests(ModelTesterMixin, unittest.TestCase):
+    model_class = UNetModel
+
     @property
     def dummy_input(self):
         batch_size = 4
...
@@ -92,32 +199,289 @@ class ModelTesterMixin(unittest.TestCase):
         noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
         time_step = torch.tensor([10]).to(torch_device)

-        return (noise, time_step)
-
-    def test_from_pretrained_save_pretrained(self):
-        model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32)
-        model.to(torch_device)
-
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            model.save_pretrained(tmpdirname)
-            new_model = UNetModel.from_pretrained(tmpdirname)
-            new_model.to(torch_device)
-
-        dummy_input = self.dummy_input
-
-        image = model(*dummy_input)
-        new_image = new_model(*dummy_input)
-
-        assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
-
-    def test_from_pretrained_hub(self):
-        model = UNetModel.from_pretrained("fusing/ddpm_dummy")
-        model.to(torch_device)
-
-        image = model(*self.dummy_input)
-
-        assert image is not None, "Make sure output is not None"
+        return {"x": noise, "timesteps": time_step}
+
+    @property
+    def get_input_shape(self):
+        return (3, 32, 32)
+
+    @property
+    def get_output_shape(self):
+        return (3, 32, 32)
+
+    def prepare_init_args_and_inputs_for_common(self):
+        init_dict = {
+            "ch": 32,
+            "ch_mult": (1, 2),
+            "num_res_blocks": 2,
+            "attn_resolutions": (16,),
+            "resolution": 32,
+        }
+        inputs_dict = self.dummy_input
+        return init_dict, inputs_dict
+
+    def test_from_pretrained_hub(self):
+        model, loading_info = UNetModel.from_pretrained("fusing/ddpm_dummy", output_loading_info=True)
+        self.assertIsNotNone(model)
+        self.assertEqual(len(loading_info["missing_keys"]), 0)
+
+        model.to(torch_device)
+        image = model(**self.dummy_input)
+
+        assert image is not None, "Make sure output is not None"
+
+    def test_output_pretrained(self):
+        model = UNetModel.from_pretrained("fusing/ddpm_dummy")
+        model.eval()
+
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+
+        noise = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution)
+        time_step = torch.tensor([10])
+
+        with torch.no_grad():
+            output = model(noise, time_step)
+
+        output_slice = output[0, -1, -3:, -3:].flatten()
+        # fmt: off
+        expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053])
+        # fmt: on
+        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
+
+
+class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
+    model_class = GLIDESuperResUNetModel
+
+    @property
+    def dummy_input(self):
+        batch_size = 4
+        num_channels = 6
+        sizes = (32, 32)
+        low_res_size = (4, 4)
+
+        torch_device = "cpu"
+
+        noise = torch.randn((batch_size, num_channels // 2) + sizes).to(torch_device)
+        low_res = torch.randn((batch_size, 3) + low_res_size).to(torch_device)
+        time_step = torch.tensor([10] * noise.shape[0], device=torch_device)
+
+        return {"x": noise, "timesteps": time_step, "low_res": low_res}
+
+    @property
+    def get_input_shape(self):
+        return (3, 32, 32)
+
+    @property
+    def get_output_shape(self):
+        return (6, 32, 32)
+
+    def prepare_init_args_and_inputs_for_common(self):
+        init_dict = {
+            "attention_resolutions": (2,),
+            "channel_mult": (1, 2),
+            "in_channels": 6,
+            "out_channels": 6,
+            "model_channels": 32,
+            "num_head_channels": 8,
+            "num_heads_upsample": 1,
+            "num_res_blocks": 2,
+            "resblock_updown": True,
+            "resolution": 32,
+            "use_scale_shift_norm": True,
+        }
+        inputs_dict = self.dummy_input
+        return init_dict, inputs_dict
+
+    def test_output(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        with torch.no_grad():
+            output = model(**inputs_dict)
+
+        output, _ = torch.split(output, 3, dim=1)
+
+        self.assertIsNotNone(output)
+        expected_shape = inputs_dict["x"].shape
+        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+
+    def test_from_pretrained_hub(self):
+        model, loading_info = GLIDESuperResUNetModel.from_pretrained(
+            "fusing/glide-super-res-dummy", output_loading_info=True
+        )
+        self.assertIsNotNone(model)
+        self.assertEqual(len(loading_info["missing_keys"]), 0)
+
+        model.to(torch_device)
+        image = model(**self.dummy_input)
+
+        assert image is not None, "Make sure output is not None"
+
+    def test_output_pretrained(self):
+        model = GLIDESuperResUNetModel.from_pretrained("fusing/glide-super-res-dummy")
+
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+
+        noise = torch.randn(1, 3, 64, 64)
+        low_res = torch.randn(1, 3, 4, 4)
+        time_step = torch.tensor([42] * noise.shape[0])
+
+        with torch.no_grad():
+            output = model(noise, time_step, low_res)
+
+        output, _ = torch.split(output, 3, dim=1)
+        output_slice = output[0, -1, -3:, -3:].flatten()
+        # fmt: off
+        expected_output_slice = torch.tensor([-22.8782, -23.2652, -15.3966, -22.8034, -23.3159, -15.5640, -15.3970, -15.4614, -10.4370])
+        # fmt: on
+        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
+
+
+class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
+    model_class = UNetLDMModel
+
+    @property
+    def dummy_input(self):
+        batch_size = 4
+        num_channels = 4
+        sizes = (32, 32)
+
+        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
+        time_step = torch.tensor([10]).to(torch_device)
+
+        return {"x": noise, "timesteps": time_step}
+
+    @property
+    def get_input_shape(self):
+        return (4, 32, 32)
+
+    @property
+    def get_output_shape(self):
+        return (4, 32, 32)
+
+    def prepare_init_args_and_inputs_for_common(self):
+        init_dict = {
+            "image_size": 32,
+            "in_channels": 4,
+            "out_channels": 4,
+            "model_channels": 32,
+            "num_res_blocks": 2,
+            "attention_resolutions": (16,),
+            "channel_mult": (1, 2),
+            "num_heads": 2,
+            "conv_resample": True,
+        }
+        inputs_dict = self.dummy_input
+        return init_dict, inputs_dict
+
+    def test_from_pretrained_hub(self):
+        model, loading_info = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy", output_loading_info=True)
+        self.assertIsNotNone(model)
+        self.assertEqual(len(loading_info["missing_keys"]), 0)
+
+        model.to(torch_device)
+        image = model(**self.dummy_input)
+
+        assert image is not None, "Make sure output is not None"
+
+    def test_output_pretrained(self):
+        model = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy")
+        model.eval()
+
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+
+        noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
+        time_step = torch.tensor([10] * noise.shape[0])
+
+        with torch.no_grad():
+            output = model(noise, time_step)
+
+        output_slice = output[0, -1, -3:, -3:].flatten()
+        # fmt: off
+        expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800])
+        # fmt: on
+        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
+
+
+class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
+    model_class = UNetGradTTSModel
+
+    @property
+    def dummy_input(self):
+        batch_size = 4
+        num_features = 32
+        seq_len = 16
+
+        noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
+        condition = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
+        mask = floats_tensor((batch_size, 1, seq_len)).to(torch_device)
+        time_step = torch.tensor([10] * batch_size).to(torch_device)
+
+        return {"x": noise, "timesteps": time_step, "mu": condition, "mask": mask}
+
+    @property
+    def get_input_shape(self):
+        return (4, 32, 16)
+
+    @property
+    def get_output_shape(self):
+        return (4, 32, 16)
+
+    def prepare_init_args_and_inputs_for_common(self):
+        init_dict = {
+            "dim": 64,
+            "groups": 4,
+            "dim_mults": (1, 2),
+            "n_feats": 32,
+            "pe_scale": 1000,
+            "n_spks": 1,
+        }
+        inputs_dict = self.dummy_input
+        return init_dict, inputs_dict
+
+    def test_from_pretrained_hub(self):
+        model, loading_info = UNetGradTTSModel.from_pretrained("fusing/unet-grad-tts-dummy", output_loading_info=True)
+        self.assertIsNotNone(model)
+        self.assertEqual(len(loading_info["missing_keys"]), 0)
+
+        model.to(torch_device)
+        image = model(**self.dummy_input)
+
+        assert image is not None, "Make sure output is not None"
+
+    def test_output_pretrained(self):
+        model = UNetGradTTSModel.from_pretrained("fusing/unet-grad-tts-dummy")
+        model.eval()
+
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+
+        num_features = model.config.n_feats
+        seq_len = 16
+        noise = torch.randn((1, num_features, seq_len))
+        condition = torch.randn((1, num_features, seq_len))
+        mask = torch.randn((1, 1, seq_len))
+        time_step = torch.tensor([10])
+
+        with torch.no_grad():
+            output = model(noise, time_step, condition, mask)
+
+        output_slice = output[0, -3:, -3:].flatten()
+        # fmt: off
+        expected_output_slice = torch.tensor([-0.0690, -0.0531, 0.0633, -0.0660, -0.0541, 0.0650, -0.0656, -0.0555, 0.0617])
+        # fmt: on
+        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))


 class PipelineTesterMixin(unittest.TestCase):
     def test_from_pretrained_save_pretrained(self):
...
@@ -223,7 +587,6 @@ class PipelineTesterMixin(unittest.TestCase):
         image = ldm([prompt], generator=generator, num_inference_steps=20)

         image_slice = image[0, -1, -3:, -3:].cpu()
-        print(image_slice.shape)

         assert image.shape == (1, 3, 256, 256)
         expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458])
...
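
Note: the new test_output_pretrained tests all follow the same regression pattern — seed the RNGs, run one deterministic forward pass, and compare a fixed 3×3 corner of the output against hard-coded reference values with torch.allclose(..., atol=1e-3). A condensed sketch of the pattern (the model and expected numbers are placeholders):

    import torch

    def check_regression(model, make_inputs, expected_slice, atol=1e-3):
        # Deterministic forward pass...
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)
        with torch.no_grad():
            output = model(*make_inputs())
        # ...compared on a small corner slice so the stored reference stays tiny.
        output_slice = output[0, -1, -3:, -3:].flatten()
        assert torch.allclose(output_slice, expected_slice, atol=atol)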
tests/test_scheduler.py

    [Diff collapsed in this page capture; +206 −58 per the summary above.]
utils/check_dummies.py

@@ -20,10 +20,10 @@ import re
 # All paths are set with the intent you should run this script from the root of the repo with the command
 # python utils/check_dummies.py

-PATH_TO_TRANSFORMERS = "src/transformers"
+PATH_TO_DIFFUSERS = "src/diffusers"

 # Matches is_xxx_available()
-_re_backend = re.compile(r"is\_([a-z_]*)_available()")
+_re_backend = re.compile(r"if is\_([a-z_]*)_available\(\)")
 # Matches from xxx import bla
 _re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
 _re_test_backend = re.compile(r"^\s+if\s+not\s+is\_[a-z]*\_available\(\)")
...
@@ -50,36 +50,30 @@ def {0}(*args, **kwargs):

 def find_backend(line):
     """Find one (or multiple) backend in a code line of the init."""
-    if _re_test_backend.search(line) is None:
-        return None
-    backends = [b[0] for b in _re_backend.findall(line)]
-    backends.sort()
-    return "_and_".join(backends)
+    backends = _re_backend.findall(line)
+    if len(backends) == 0:
+        return None
+
+    return backends[0]


 def read_init():
     """Read the init and extracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects."""
-    with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f:
+    with open(os.path.join(PATH_TO_DIFFUSERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f:
         lines = f.readlines()

     # Get to the point we do the actual imports for type checking
     line_index = 0
-    while not lines[line_index].startswith("if TYPE_CHECKING"):
-        line_index += 1
     backend_specific_objects = {}
     # Go through the end of the file
     while line_index < len(lines):
         # If the line is an if is_backend_available, we grab all objects associated.
         backend = find_backend(lines[line_index])
         if backend is not None:
-            while not lines[line_index].startswith("    else:"):
-                line_index += 1
-            line_index += 1
-            objects = []
+            objects = []
+            line_index += 1
             # Until we unindent, add backend objects to the list
-            while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8):
+            while not lines[line_index].startswith("else:"):
                 line = lines[line_index]
                 single_line_import_search = _re_single_line_import.search(line)
                 if single_line_import_search is not None:
...
@@ -129,7 +123,7 @@ def check_dummies(overwrite=False):
     short_names = {"torch": "pt"}

     # Locate actual dummy modules and read their content.
-    path = os.path.join(PATH_TO_TRANSFORMERS, "utils")
+    path = os.path.join(PATH_TO_DIFFUSERS, "utils")
     dummy_file_paths = {
         backend: os.path.join(path, f"dummy_{short_names.get(backend, backend)}_objects.py")
         for backend in dummy_files.keys()
...
@@ -147,7 +141,7 @@ def check_dummies(overwrite=False):
     if dummy_files[backend] != actual_dummies[backend]:
         if overwrite:
             print(
-                f"Updating transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py as the main "
+                f"Updating diffusers.utils.dummy_{short_names.get(backend, backend)}_objects.py as the main "
                 "__init__ has new objects."
             )
             with open(dummy_file_paths[backend], "w", encoding="utf-8", newline="\n") as f:
...
@@ -155,7 +149,7 @@ def check_dummies(overwrite=False):
         else:
             raise ValueError(
                 "The main __init__ has objects that are not present in "
-                f"transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py. Run `make fix-copies` "
+                f"diffusers.utils.dummy_{short_names.get(backend, backend)}_objects.py. Run `make fix-copies` "
                 "to fix this."
            )
...
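
Note: with the new regex anchored on "if is_xxx_available()", find_backend keys each guarded import block in __init__.py by its backend name. A quick standalone illustration of both pieces against fabricated init lines (mirrors the new code above):

    import re

    _re_backend = re.compile(r"if is\_([a-z_]*)_available\(\)")

    def find_backend(line):
        backends = _re_backend.findall(line)
        if len(backends) == 0:
            return None
        return backends[0]

    print(find_backend("if is_transformers_available():"))   # "transformers"
    print(find_backend("from .utils import DummyObject"))    # None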
utils/check_table.py

@@ -22,7 +22,7 @@ import re
 # All paths are set with the intent you should run this script from the root of the repo with the command
 # python utils/check_table.py

-TRANSFORMERS_PATH = "src/transformers"
+TRANSFORMERS_PATH = "src/diffusers"
 PATH_TO_DOCS = "docs/source/en"
 REPO_PATH = "."
...
@@ -62,13 +62,13 @@ _re_flax_models = re.compile(r"Flax(.*)(?:Model|Encoder|Decoder|ForConditionalGe
 _re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")

-# This is to make sure the transformers module imported is the one in the repo.
+# This is to make sure the diffusers module imported is the one in the repo.
 spec = importlib.util.spec_from_file_location(
-    "transformers",
+    "diffusers",
     os.path.join(TRANSFORMERS_PATH, "__init__.py"),
     submodule_search_locations=[TRANSFORMERS_PATH],
 )
-transformers_module = spec.loader.load_module()
+diffusers_module = spec.loader.load_module()

 # Thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python
...
@@ -88,10 +88,10 @@ def _center_text(text, width):
 def get_model_table_from_auto_modules():
     """Generates an up-to-date model table from the content of the auto modules."""
     # Dictionary model names to config.
-    config_maping_names = transformers_module.models.auto.configuration_auto.CONFIG_MAPPING_NAMES
+    config_maping_names = diffusers_module.models.auto.configuration_auto.CONFIG_MAPPING_NAMES
     model_name_to_config = {
         name: config_maping_names[code]
-        for code, name in transformers_module.MODEL_NAMES_MAPPING.items()
+        for code, name in diffusers_module.MODEL_NAMES_MAPPING.items()
         if code in config_maping_names
     }
     model_name_to_prefix = {name: config.replace("ConfigMixin", "") for name, config in model_name_to_config.items()}
...
@@ -103,8 +103,8 @@ def get_model_table_from_auto_modules():
     tf_models = collections.defaultdict(bool)
     flax_models = collections.defaultdict(bool)
-    # Let's lookup through all transformers object (once).
-    for attr_name in dir(transformers_module):
+    # Let's lookup through all diffusers object (once).
+    for attr_name in dir(diffusers_module):
         lookup_dict = None
         if attr_name.endswith("Tokenizer"):
             lookup_dict = slow_tokenizers
...
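
Note: the spec_from_file_location trick is how the script imports the in-repo package rather than whatever copy happens to be installed in site-packages. A standalone sketch of the same mechanism (the path is illustrative):

    import importlib.util
    import os

    # Load src/diffusers/__init__.py under the module name "diffusers",
    # shadowing any installed version for the duration of the script.
    repo_pkg = "src/diffusers"  # illustrative path
    spec = importlib.util.spec_from_file_location(
        "diffusers",
        os.path.join(repo_pkg, "__init__.py"),
        submodule_search_locations=[repo_pkg],
    )
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # modern replacement for the deprecated load_module()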