Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
99904459
Unverified
Commit
99904459
authored
Nov 30, 2022
by
Anton Lozhkov
Committed by
GitHub
Nov 30, 2022
Browse files
Bump to 0.10.0.dev0 + deprecations (#1490)
parent
eeeb28a9
Changes
16
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
20 additions
and
155 deletions
+20
-155
examples/unconditional_image_generation/train_unconditional.py
...les/unconditional_image_generation/train_unconditional.py
+1
-6
setup.py
setup.py
+1
-1
src/diffusers/__init__.py
src/diffusers/__init__.py
+1
-1
src/diffusers/hub_utils.py
src/diffusers/hub_utils.py
+2
-118
src/diffusers/modeling_utils.py
src/diffusers/modeling_utils.py
+0
-14
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
+1
-1
src/diffusers/schedulers/scheduling_ddim.py
src/diffusers/schedulers/scheduling_ddim.py
+1
-1
src/diffusers/schedulers/scheduling_ddim_flax.py
src/diffusers/schedulers/scheduling_ddim_flax.py
+1
-1
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm.py
+2
-2
src/diffusers/schedulers/scheduling_ddpm_flax.py
src/diffusers/schedulers/scheduling_ddpm_flax.py
+2
-2
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+1
-1
src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
...ffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
+1
-1
tests/pipelines/ddpm/test_ddpm.py
tests/pipelines/ddpm/test_ddpm.py
+1
-1
tests/test_config.py
tests/test_config.py
+1
-1
tests/test_scheduler.py
tests/test_scheduler.py
+2
-2
tests/test_scheduler_flax.py
tests/test_scheduler_flax.py
+2
-2
No files found.
examples/unconditional_image_generation/train_unconditional.py
View file @
99904459
...
@@ -14,7 +14,6 @@ from datasets import load_dataset
...
@@ -14,7 +14,6 @@ from datasets import load_dataset
from
diffusers
import
DDPMPipeline
,
DDPMScheduler
,
UNet2DModel
,
__version__
from
diffusers
import
DDPMPipeline
,
DDPMScheduler
,
UNet2DModel
,
__version__
from
diffusers.optimization
import
get_scheduler
from
diffusers.optimization
import
get_scheduler
from
diffusers.training_utils
import
EMAModel
from
diffusers.training_utils
import
EMAModel
from
diffusers.utils
import
deprecate
from
huggingface_hub
import
HfFolder
,
Repository
,
whoami
from
huggingface_hub
import
HfFolder
,
Repository
,
whoami
from
packaging
import
version
from
packaging
import
version
from
torchvision.transforms
import
(
from
torchvision.transforms
import
(
...
@@ -417,11 +416,7 @@ def main(args):
...
@@ -417,11 +416,7 @@ def main(args):
scheduler
=
noise_scheduler
,
scheduler
=
noise_scheduler
,
)
)
deprecate
(
"todo: remove this check"
,
"0.10.0"
,
"when the most used version is >= 0.8.0"
)
generator
=
torch
.
Generator
(
device
=
pipeline
.
device
).
manual_seed
(
0
)
if
diffusers_version
<
version
.
parse
(
"0.8.0"
):
generator
=
torch
.
manual_seed
(
0
)
else
:
generator
=
torch
.
Generator
(
device
=
pipeline
.
device
).
manual_seed
(
0
)
# run pipeline in inference (sample random noise and denoise)
# run pipeline in inference (sample random noise and denoise)
images
=
pipeline
(
images
=
pipeline
(
generator
=
generator
,
generator
=
generator
,
...
...
setup.py
View file @
99904459
...
@@ -214,7 +214,7 @@ install_requires = [
...
@@ -214,7 +214,7 @@ install_requires = [
setup
(
setup
(
name
=
"diffusers"
,
name
=
"diffusers"
,
version
=
"0.
9.
0"
,
# expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version
=
"0.
10.0.dev
0"
,
# expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description
=
"Diffusers"
,
description
=
"Diffusers"
,
long_description
=
open
(
"README.md"
,
"r"
,
encoding
=
"utf-8"
).
read
(),
long_description
=
open
(
"README.md"
,
"r"
,
encoding
=
"utf-8"
).
read
(),
long_description_content_type
=
"text/markdown"
,
long_description_content_type
=
"text/markdown"
,
...
...
src/diffusers/__init__.py
View file @
99904459
...
@@ -9,7 +9,7 @@ from .utils import (
...
@@ -9,7 +9,7 @@ from .utils import (
)
)
__version__
=
"0.
9.
0"
__version__
=
"0.
10.0.dev
0"
from
.configuration_utils
import
ConfigMixin
from
.configuration_utils
import
ConfigMixin
from
.onnx_utils
import
OnnxRuntimeModel
from
.onnx_utils
import
OnnxRuntimeModel
...
...
src/diffusers/hub_utils.py
View file @
99904459
...
@@ -15,16 +15,15 @@
...
@@ -15,16 +15,15 @@
import
os
import
os
import
shutil
import
sys
import
sys
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Dict
,
Optional
,
Union
from
typing
import
Dict
,
Optional
,
Union
from
uuid
import
uuid4
from
uuid
import
uuid4
from
huggingface_hub
import
HfFolder
,
Repository
,
whoami
from
huggingface_hub
import
HfFolder
,
whoami
from
.
import
__version__
from
.
import
__version__
from
.utils
import
ENV_VARS_TRUE_VALUES
,
deprecate
,
logging
from
.utils
import
ENV_VARS_TRUE_VALUES
,
logging
from
.utils.import_utils
import
(
from
.utils.import_utils
import
(
_flax_version
,
_flax_version
,
_jax_version
,
_jax_version
,
...
@@ -83,121 +82,6 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
...
@@ -83,121 +82,6 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
return
f
"
{
organization
}
/
{
model_id
}
"
return
f
"
{
organization
}
/
{
model_id
}
"
def
init_git_repo
(
args
,
at_init
:
bool
=
False
):
"""
Args:
Initializes a git repo in `args.hub_model_id`.
at_init (`bool`, *optional*, defaults to `False`):
Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is `True`
and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped out.
"""
deprecation_message
=
(
"Please use `huggingface_hub.Repository`. "
"See `examples/unconditional_image_generation/train_unconditional.py` for an example."
)
deprecate
(
"init_git_repo()"
,
"0.10.0"
,
deprecation_message
)
if
hasattr
(
args
,
"local_rank"
)
and
args
.
local_rank
not
in
[
-
1
,
0
]:
return
hub_token
=
args
.
hub_token
if
hasattr
(
args
,
"hub_token"
)
else
None
use_auth_token
=
True
if
hub_token
is
None
else
hub_token
if
not
hasattr
(
args
,
"hub_model_id"
)
or
args
.
hub_model_id
is
None
:
repo_name
=
Path
(
args
.
output_dir
).
absolute
().
name
else
:
repo_name
=
args
.
hub_model_id
if
"/"
not
in
repo_name
:
repo_name
=
get_full_repo_name
(
repo_name
,
token
=
hub_token
)
try
:
repo
=
Repository
(
args
.
output_dir
,
clone_from
=
repo_name
,
use_auth_token
=
use_auth_token
,
private
=
args
.
hub_private_repo
,
)
except
EnvironmentError
:
if
args
.
overwrite_output_dir
and
at_init
:
# Try again after wiping output_dir
shutil
.
rmtree
(
args
.
output_dir
)
repo
=
Repository
(
args
.
output_dir
,
clone_from
=
repo_name
,
use_auth_token
=
use_auth_token
,
)
else
:
raise
repo
.
git_pull
()
# By default, ignore the checkpoint folders
if
not
os
.
path
.
exists
(
os
.
path
.
join
(
args
.
output_dir
,
".gitignore"
)):
with
open
(
os
.
path
.
join
(
args
.
output_dir
,
".gitignore"
),
"w"
,
encoding
=
"utf-8"
)
as
writer
:
writer
.
writelines
([
"checkpoint-*/"
])
return
repo
def
push_to_hub
(
args
,
pipeline
,
repo
:
Repository
,
commit_message
:
Optional
[
str
]
=
"End of training"
,
blocking
:
bool
=
True
,
**
kwargs
,
)
->
str
:
"""
Parameters:
Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.
commit_message (`str`, *optional*, defaults to `"End of training"`):
Message to commit while pushing.
blocking (`bool`, *optional*, defaults to `True`):
Whether the function should return only when the `git push` has finished.
kwargs:
Additional keyword arguments passed along to [`create_model_card`].
Returns:
The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of the
commit and an object to track the progress of the commit if `blocking=True`
"""
deprecation_message
=
(
"Please use `huggingface_hub.Repository` and `Repository.push_to_hub()`. "
"See `examples/unconditional_image_generation/train_unconditional.py` for an example."
)
deprecate
(
"push_to_hub()"
,
"0.10.0"
,
deprecation_message
)
if
not
hasattr
(
args
,
"hub_model_id"
)
or
args
.
hub_model_id
is
None
:
model_name
=
Path
(
args
.
output_dir
).
name
else
:
model_name
=
args
.
hub_model_id
.
split
(
"/"
)[
-
1
]
output_dir
=
args
.
output_dir
os
.
makedirs
(
output_dir
,
exist_ok
=
True
)
logger
.
info
(
f
"Saving pipeline checkpoint to
{
output_dir
}
"
)
pipeline
.
save_pretrained
(
output_dir
)
# Only push from one node.
if
hasattr
(
args
,
"local_rank"
)
and
args
.
local_rank
not
in
[
-
1
,
0
]:
return
# Cancel any async push in progress if blocking=True. The commits will all be pushed together.
if
(
blocking
and
len
(
repo
.
command_queue
)
>
0
and
repo
.
command_queue
[
-
1
]
is
not
None
and
not
repo
.
command_queue
[
-
1
].
is_done
):
repo
.
command_queue
[
-
1
].
_process
.
kill
()
git_head_commit_url
=
repo
.
push_to_hub
(
commit_message
=
commit_message
,
blocking
=
blocking
,
auto_lfs_prune
=
True
)
# push separately the model card to be independent from the rest of the model
create_model_card
(
args
,
model_name
=
model_name
)
try
:
repo
.
push_to_hub
(
commit_message
=
"update model card README.md"
,
blocking
=
blocking
,
auto_lfs_prune
=
True
)
except
EnvironmentError
as
exc
:
logger
.
error
(
f
"Error pushing update to the model card. Please read logs and retry.
\n
$
{
exc
}
"
)
return
git_head_commit_url
def
create_model_card
(
args
,
model_name
):
def
create_model_card
(
args
,
model_name
):
if
not
is_modelcards_available
:
if
not
is_modelcards_available
:
raise
ValueError
(
raise
ValueError
(
...
...
src/diffusers/modeling_utils.py
View file @
99904459
...
@@ -666,20 +666,6 @@ class ModelMixin(torch.nn.Module):
...
@@ -666,20 +666,6 @@ class ModelMixin(torch.nn.Module):
return
sum
(
p
.
numel
()
for
p
in
self
.
parameters
()
if
p
.
requires_grad
or
not
only_trainable
)
return
sum
(
p
.
numel
()
for
p
in
self
.
parameters
()
if
p
.
requires_grad
or
not
only_trainable
)
def
unwrap_model
(
model
:
torch
.
nn
.
Module
)
->
torch
.
nn
.
Module
:
"""
Recursively unwraps a model from potential containers (as used in distributed training).
Args:
model (`torch.nn.Module`): The model to unwrap.
"""
# since there could be multiple levels of wrapping, unwrap recursively
if
hasattr
(
model
,
"module"
):
return
unwrap_model
(
model
.
module
)
else
:
return
model
def
_get_model_file
(
def
_get_model_file
(
pretrained_model_name_or_path
,
pretrained_model_name_or_path
,
*
,
*
,
...
...
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
View file @
99904459
...
@@ -73,7 +73,7 @@ class DDPMPipeline(DiffusionPipeline):
...
@@ -73,7 +73,7 @@ class DDPMPipeline(DiffusionPipeline):
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.1
0
.0"
,
message
,
take_from
=
kwargs
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.1
1
.0"
,
message
,
take_from
=
kwargs
)
if
predict_epsilon
is
not
None
:
if
predict_epsilon
is
not
None
:
new_config
=
dict
(
self
.
scheduler
.
config
)
new_config
=
dict
(
self
.
scheduler
.
config
)
...
...
src/diffusers/schedulers/scheduling_ddim.py
View file @
99904459
...
@@ -134,7 +134,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -134,7 +134,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
" DDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.1
0
.0"
,
message
,
take_from
=
kwargs
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.1
1
.0"
,
message
,
take_from
=
kwargs
)
if
predict_epsilon
is
not
None
:
if
predict_epsilon
is
not
None
:
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
...
...
src/diffusers/schedulers/scheduling_ddim_flax.py
View file @
99904459
...
@@ -138,7 +138,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -138,7 +138,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" FlaxDDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
" FlaxDDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.1
0
.0"
,
message
,
take_from
=
kwargs
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.1
1
.0"
,
message
,
take_from
=
kwargs
)
if
predict_epsilon
is
not
None
:
if
predict_epsilon
is
not
None
:
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
...
...
src/diffusers/schedulers/scheduling_ddpm.py
View file @
99904459
...
@@ -125,7 +125,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -125,7 +125,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.1
0
.0"
,
message
,
take_from
=
kwargs
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.1
1
.0"
,
message
,
take_from
=
kwargs
)
if
predict_epsilon
is
not
None
:
if
predict_epsilon
is
not
None
:
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
...
@@ -255,7 +255,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -255,7 +255,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.1
0
.0"
,
message
,
take_from
=
kwargs
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.1
1
.0"
,
message
,
take_from
=
kwargs
)
if
predict_epsilon
is
not
None
:
if
predict_epsilon
is
not
None
:
new_config
=
dict
(
self
.
config
)
new_config
=
dict
(
self
.
config
)
new_config
[
"prediction_type"
]
=
"epsilon"
if
predict_epsilon
else
"sample"
new_config
[
"prediction_type"
]
=
"epsilon"
if
predict_epsilon
else
"sample"
...
...
src/diffusers/schedulers/scheduling_ddpm_flax.py
View file @
99904459
...
@@ -132,7 +132,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -132,7 +132,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" FlaxDDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
" FlaxDDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.1
0
.0"
,
message
,
take_from
=
kwargs
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.1
1
.0"
,
message
,
take_from
=
kwargs
)
if
predict_epsilon
is
not
None
:
if
predict_epsilon
is
not
None
:
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
...
@@ -239,7 +239,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -239,7 +239,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" FlaxDDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
" FlaxDDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.1
0
.0"
,
message
,
take_from
=
kwargs
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.1
1
.0"
,
message
,
take_from
=
kwargs
)
if
predict_epsilon
is
not
None
:
if
predict_epsilon
is
not
None
:
new_config
=
dict
(
self
.
config
)
new_config
=
dict
(
self
.
config
)
new_config
[
"prediction_type"
]
=
"epsilon"
if
predict_epsilon
else
"sample"
new_config
[
"prediction_type"
]
=
"epsilon"
if
predict_epsilon
else
"sample"
...
...
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
View file @
99904459
...
@@ -142,7 +142,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -142,7 +142,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DPMSolverMultistepScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
" DPMSolverMultistepScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.1
0
.0"
,
message
,
take_from
=
kwargs
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.1
1
.0"
,
message
,
take_from
=
kwargs
)
if
predict_epsilon
is
not
None
:
if
predict_epsilon
is
not
None
:
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
...
...
src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
View file @
99904459
...
@@ -177,7 +177,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -177,7 +177,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" FlaxDPMSolverMultistepScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
" FlaxDPMSolverMultistepScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.1
0
.0"
,
message
,
take_from
=
kwargs
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.1
1
.0"
,
message
,
take_from
=
kwargs
)
if
predict_epsilon
is
not
None
:
if
predict_epsilon
is
not
None
:
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
...
...
tests/pipelines/ddpm/test_ddpm.py
View file @
99904459
...
@@ -69,7 +69,7 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
...
@@ -69,7 +69,7 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
assert
np
.
abs
(
image_from_tuple_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
assert
np
.
abs
(
image_from_tuple_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
def
test_inference_deprecated_predict_epsilon
(
self
):
def
test_inference_deprecated_predict_epsilon
(
self
):
deprecate
(
"remove this test"
,
"0.1
0
.0"
,
"remove"
)
deprecate
(
"remove this test"
,
"0.1
1
.0"
,
"remove"
)
unet
=
self
.
dummy_uncond_unet
unet
=
self
.
dummy_uncond_unet
scheduler
=
DDPMScheduler
(
predict_epsilon
=
False
)
scheduler
=
DDPMScheduler
(
predict_epsilon
=
False
)
...
...
tests/test_config.py
View file @
99904459
...
@@ -203,7 +203,7 @@ class ConfigTester(unittest.TestCase):
...
@@ -203,7 +203,7 @@ class ConfigTester(unittest.TestCase):
ddpm_2
=
DDPMScheduler
.
from_pretrained
(
"google/ddpm-celebahq-256"
,
beta_start
=
88
)
ddpm_2
=
DDPMScheduler
.
from_pretrained
(
"google/ddpm-celebahq-256"
,
beta_start
=
88
)
with
CaptureLogger
(
logger
)
as
cap_logger
:
with
CaptureLogger
(
logger
)
as
cap_logger
:
deprecate
(
"remove this case"
,
"0.1
0
.0"
,
"remove"
)
deprecate
(
"remove this case"
,
"0.1
1
.0"
,
"remove"
)
ddpm_3
=
DDPMScheduler
.
from_pretrained
(
ddpm_3
=
DDPMScheduler
.
from_pretrained
(
"hf-internal-testing/tiny-stable-diffusion-torch"
,
"hf-internal-testing/tiny-stable-diffusion-torch"
,
subfolder
=
"scheduler"
,
subfolder
=
"scheduler"
,
...
...
tests/test_scheduler.py
View file @
99904459
...
@@ -639,12 +639,12 @@ class DDPMSchedulerTest(SchedulerCommonTest):
...
@@ -639,12 +639,12 @@ class DDPMSchedulerTest(SchedulerCommonTest):
self
.
check_over_configs
(
prediction_type
=
prediction_type
)
self
.
check_over_configs
(
prediction_type
=
prediction_type
)
def
test_deprecated_predict_epsilon
(
self
):
def
test_deprecated_predict_epsilon
(
self
):
deprecate
(
"remove this test"
,
"0.1
0
.0"
,
"remove"
)
deprecate
(
"remove this test"
,
"0.1
1
.0"
,
"remove"
)
for
predict_epsilon
in
[
True
,
False
]:
for
predict_epsilon
in
[
True
,
False
]:
self
.
check_over_configs
(
predict_epsilon
=
predict_epsilon
)
self
.
check_over_configs
(
predict_epsilon
=
predict_epsilon
)
def
test_deprecated_epsilon
(
self
):
def
test_deprecated_epsilon
(
self
):
deprecate
(
"remove this test"
,
"0.1
0
.0"
,
"remove"
)
deprecate
(
"remove this test"
,
"0.1
1
.0"
,
"remove"
)
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_config
=
self
.
get_scheduler_config
()
scheduler_config
=
self
.
get_scheduler_config
()
...
...
tests/test_scheduler_flax.py
View file @
99904459
...
@@ -626,12 +626,12 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
...
@@ -626,12 +626,12 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
self
.
check_over_configs
(
prediction_type
=
prediction_type
)
self
.
check_over_configs
(
prediction_type
=
prediction_type
)
def
test_deprecated_predict_epsilon
(
self
):
def
test_deprecated_predict_epsilon
(
self
):
deprecate
(
"remove this test"
,
"0.1
0
.0"
,
"remove"
)
deprecate
(
"remove this test"
,
"0.1
1
.0"
,
"remove"
)
for
predict_epsilon
in
[
True
,
False
]:
for
predict_epsilon
in
[
True
,
False
]:
self
.
check_over_configs
(
predict_epsilon
=
predict_epsilon
)
self
.
check_over_configs
(
predict_epsilon
=
predict_epsilon
)
def
test_deprecated_predict_epsilon_to_prediction_type
(
self
):
def
test_deprecated_predict_epsilon_to_prediction_type
(
self
):
deprecate
(
"remove this test"
,
"0.1
0
.0"
,
"remove"
)
deprecate
(
"remove this test"
,
"0.1
1
.0"
,
"remove"
)
for
scheduler_class
in
self
.
scheduler_classes
:
for
scheduler_class
in
self
.
scheduler_classes
:
scheduler_config
=
self
.
get_scheduler_config
(
predict_epsilon
=
True
)
scheduler_config
=
self
.
get_scheduler_config
(
predict_epsilon
=
True
)
scheduler
=
scheduler_class
.
from_config
(
scheduler_config
)
scheduler
=
scheduler_class
.
from_config
(
scheduler_config
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment