Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
99904459
Unverified
Commit
99904459
authored
Nov 30, 2022
by
Anton Lozhkov
Committed by
GitHub
Nov 30, 2022
Browse files
Bump to 0.10.0.dev0 + deprecations (#1490)
parent
eeeb28a9
Changes
16
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
20 additions
and
155 deletions
+20
-155
examples/unconditional_image_generation/train_unconditional.py
...les/unconditional_image_generation/train_unconditional.py
+1
-6
setup.py
setup.py
+1
-1
src/diffusers/__init__.py
src/diffusers/__init__.py
+1
-1
src/diffusers/hub_utils.py
src/diffusers/hub_utils.py
+2
-118
src/diffusers/modeling_utils.py
src/diffusers/modeling_utils.py
+0
-14
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
+1
-1
src/diffusers/schedulers/scheduling_ddim.py
src/diffusers/schedulers/scheduling_ddim.py
+1
-1
src/diffusers/schedulers/scheduling_ddim_flax.py
src/diffusers/schedulers/scheduling_ddim_flax.py
+1
-1
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm.py
+2
-2
src/diffusers/schedulers/scheduling_ddpm_flax.py
src/diffusers/schedulers/scheduling_ddpm_flax.py
+2
-2
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+1
-1
src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
...ffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
+1
-1
tests/pipelines/ddpm/test_ddpm.py
tests/pipelines/ddpm/test_ddpm.py
+1
-1
tests/test_config.py
tests/test_config.py
+1
-1
tests/test_scheduler.py
tests/test_scheduler.py
+2
-2
tests/test_scheduler_flax.py
tests/test_scheduler_flax.py
+2
-2
No files found.
examples/unconditional_image_generation/train_unconditional.py
View file @
99904459
...
@@ -14,7 +14,6 @@ from datasets import load_dataset
...
@@ -14,7 +14,6 @@ from datasets import load_dataset
from
diffusers
import
DDPMPipeline
,
DDPMScheduler
,
UNet2DModel
,
__version__
from
diffusers
import
DDPMPipeline
,
DDPMScheduler
,
UNet2DModel
,
__version__
from
diffusers.optimization
import
get_scheduler
from
diffusers.optimization
import
get_scheduler
from
diffusers.training_utils
import
EMAModel
from
diffusers.training_utils
import
EMAModel
from
diffusers.utils
import
deprecate
from
huggingface_hub
import
HfFolder
,
Repository
,
whoami
from
huggingface_hub
import
HfFolder
,
Repository
,
whoami
from
packaging
import
version
from
packaging
import
version
from
torchvision.transforms
import
(
from
torchvision.transforms
import
(
...
@@ -417,11 +416,7 @@ def main(args):
...
@@ -417,11 +416,7 @@ def main(args):
scheduler
=
noise_scheduler
,
scheduler
=
noise_scheduler
,
)
)
deprecate
(
"todo: remove this check"
,
"0.10.0"
,
"when the most used version is >= 0.8.0"
)
generator
=
torch
.
Generator
(
device
=
pipeline
.
device
).
manual_seed
(
0
)
if
diffusers_version
<
version
.
parse
(
"0.8.0"
):
generator
=
torch
.
manual_seed
(
0
)
else
:
generator
=
torch
.
Generator
(
device
=
pipeline
.
device
).
manual_seed
(
0
)
# run pipeline in inference (sample random noise and denoise)
# run pipeline in inference (sample random noise and denoise)
images
=
pipeline
(
images
=
pipeline
(
generator
=
generator
,
generator
=
generator
,
...
...
setup.py
View file @
99904459
...
@@ -214,7 +214,7 @@ install_requires = [
...
@@ -214,7 +214,7 @@ install_requires = [
setup
(
setup
(
name
=
"diffusers"
,
name
=
"diffusers"
,
version
=
"0.
9.
0"
,
# expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version
=
"0.
10.0.dev
0"
,
# expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description
=
"Diffusers"
,
description
=
"Diffusers"
,
long_description
=
open
(
"README.md"
,
"r"
,
encoding
=
"utf-8"
).
read
(),
long_description
=
open
(
"README.md"
,
"r"
,
encoding
=
"utf-8"
).
read
(),
long_description_content_type
=
"text/markdown"
,
long_description_content_type
=
"text/markdown"
,
...
...
src/diffusers/__init__.py
View file @
99904459
...
@@ -9,7 +9,7 @@ from .utils import (
...
@@ -9,7 +9,7 @@ from .utils import (
)
)
__version__
=
"0.
9.
0"
__version__
=
"0.
10.0.dev
0"
from
.configuration_utils
import
ConfigMixin
from
.configuration_utils
import
ConfigMixin
from
.onnx_utils
import
OnnxRuntimeModel
from
.onnx_utils
import
OnnxRuntimeModel
...
...
src/diffusers/hub_utils.py
View file @
99904459
...
@@ -15,16 +15,15 @@
...
@@ -15,16 +15,15 @@
import
os
import
os
import
shutil
import
sys
import
sys
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Dict
,
Optional
,
Union
from
typing
import
Dict
,
Optional
,
Union
from
uuid
import
uuid4
from
uuid
import
uuid4
from
huggingface_hub
import
HfFolder
,
Repository
,
whoami
from
huggingface_hub
import
HfFolder
,
whoami
from
.
import
__version__
from
.
import
__version__
from
.utils
import
ENV_VARS_TRUE_VALUES
,
deprecate
,
logging
from
.utils
import
ENV_VARS_TRUE_VALUES
,
logging
from
.utils.import_utils
import
(
from
.utils.import_utils
import
(
_flax_version
,
_flax_version
,
_jax_version
,
_jax_version
,
...
@@ -83,121 +82,6 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
...
@@ -83,121 +82,6 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
return
f
"
{
organization
}
/
{
model_id
}
"
return
f
"
{
organization
}
/
{
model_id
}
"
def
init_git_repo
(
args
,
at_init
:
bool
=
False
):
"""
Args:
Initializes a git repo in `args.hub_model_id`.
at_init (`bool`, *optional*, defaults to `False`):
Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is `True`
and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped out.
"""
deprecation_message
=
(
"Please use `huggingface_hub.Repository`. "
"See `examples/unconditional_image_generation/train_unconditional.py` for an example."
)
deprecate
(
"init_git_repo()"
,
"0.10.0"
,
deprecation_message
)
if
hasattr
(
args
,
"local_rank"
)
and
args
.
local_rank
not
in
[
-
1
,
0
]:
return
hub_token
=
args
.
hub_token
if
hasattr
(
args
,
"hub_token"
)
else
None
use_auth_token
=
True
if
hub_token
is
None
else
hub_token
if
not
hasattr
(
args
,
"hub_model_id"
)
or
args
.
hub_model_id
is
None
:
repo_name
=
Path
(
args
.
output_dir
).
absolute
().
name
else
:
repo_name
=
args
.
hub_model_id
if
"/"
not
in
repo_name
:
repo_name
=
get_full_repo_name
(
repo_name
,
token
=
hub_token
)
try
:
repo
=
Repository
(
args
.
output_dir
,
clone_from
=
repo_name
,
use_auth_token
=
use_auth_token
,
private
=
args
.
hub_private_repo
,
)
except
EnvironmentError
:
if
args
.
overwrite_output_dir
and
at_init
:
# Try again after wiping output_dir
shutil
.
rmtree
(
args
.
output_dir
)
repo
=
Repository
(
args
.
output_dir
,
clone_from
=
repo_name
,
use_auth_token
=
use_auth_token
,
)
else
:
raise
repo
.
git_pull
()
# By default, ignore the checkpoint folders
if
not
os
.
path
.
exists
(
os
.
path
.
join
(
args
.
output_dir
,
".gitignore"
)):
with
open
(
os
.
path
.
join
(
args
.
output_dir
,
".gitignore"
),
"w"
,
encoding
=
"utf-8"
)
as
writer
:
writer
.
writelines
([
"checkpoint-*/"
])
return
repo
def
push_to_hub
(
args
,
pipeline
,
repo
:
Repository
,
commit_message
:
Optional
[
str
]
=
"End of training"
,
blocking
:
bool
=
True
,
**
kwargs
,
)
->
str
:
"""
Parameters:
Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.
commit_message (`str`, *optional*, defaults to `"End of training"`):
Message to commit while pushing.
blocking (`bool`, *optional*, defaults to `True`):
Whether the function should return only when the `git push` has finished.
kwargs:
Additional keyword arguments passed along to [`create_model_card`].
Returns:
The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of the
commit and an object to track the progress of the commit if `blocking=True`
"""
deprecation_message
=
(
"Please use `huggingface_hub.Repository` and `Repository.push_to_hub()`. "
"See `examples/unconditional_image_generation/train_unconditional.py` for an example."
)
deprecate
(
"push_to_hub()"
,
"0.10.0"
,
deprecation_message
)
if
not
hasattr
(
args
,
"hub_model_id"
)
or
args
.
hub_model_id
is
None
:
model_name
=
Path
(
args
.
output_dir
).
name
else
:
model_name
=
args
.
hub_model_id
.
split
(
"/"
)[
-
1
]
output_dir
=
args
.
output_dir
os
.
makedirs
(
output_dir
,
exist_ok
=
True
)
logger
.
info
(
f
"Saving pipeline checkpoint to
{
output_dir
}
"
)
pipeline
.
save_pretrained
(
output_dir
)
# Only push from one node.
if
hasattr
(
args
,
"local_rank"
)
and
args
.
local_rank
not
in
[
-
1
,
0
]:
return
# Cancel any async push in progress if blocking=True. The commits will all be pushed together.
if
(
blocking
and
len
(
repo
.
command_queue
)
>
0
and
repo
.
command_queue
[
-
1
]
is
not
None
and
not
repo
.
command_queue
[
-
1
].
is_done
):
repo
.
command_queue
[
-
1
].
_process
.
kill
()
git_head_commit_url
=
repo
.
push_to_hub
(
commit_message
=
commit_message
,
blocking
=
blocking
,
auto_lfs_prune
=
True
)
# push separately the model card to be independent from the rest of the model
create_model_card
(
args
,
model_name
=
model_name
)
try
:
repo
.
push_to_hub
(
commit_message
=
"update model card README.md"
,
blocking
=
blocking
,
auto_lfs_prune
=
True
)
except
EnvironmentError
as
exc
:
logger
.
error
(
f
"Error pushing update to the model card. Please read logs and retry.
\n
$
{
exc
}
"
)
return
git_head_commit_url
def
create_model_card
(
args
,
model_name
):
def
create_model_card
(
args
,
model_name
):
if
not
is_modelcards_available
:
if
not
is_modelcards_available
:
raise
ValueError
(
raise
ValueError
(
...
...
src/diffusers/modeling_utils.py
View file @
99904459
...
@@ -666,20 +666,6 @@ class ModelMixin(torch.nn.Module):
...
@@ -666,20 +666,6 @@ class ModelMixin(torch.nn.Module):
return
sum
(
p
.
numel
()
for
p
in
self
.
parameters
()
if
p
.
requires_grad
or
not
only_trainable
)
return
sum
(
p
.
numel
()
for
p
in
self
.
parameters
()
if
p
.
requires_grad
or
not
only_trainable
)
def
unwrap_model
(
model
:
torch
.
nn
.
Module
)
->
torch
.
nn
.
Module
:
"""
Recursively unwraps a model from potential containers (as used in distributed training).
Args:
model (`torch.nn.Module`): The model to unwrap.
"""
# since there could be multiple levels of wrapping, unwrap recursively
if
hasattr
(
model
,
"module"
):
return
unwrap_model
(
model
.
module
)
else
:
return
model
def
_get_model_file
(
def
_get_model_file
(
pretrained_model_name_or_path
,
pretrained_model_name_or_path
,
*
,
*
,
...
...
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
View file @
99904459
...
@@ -73,7 +73,7 @@ class DDPMPipeline(DiffusionPipeline):
...
@@ -73,7 +73,7 @@ class DDPMPipeline(DiffusionPipeline):
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.1
0
.0"
,
message
,
take_from
=
kwargs
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.1
1
.0"
,
message
,
take_from
=
kwargs
)
if
predict_epsilon
is
not
None
:
if
predict_epsilon
is
not
None
:
new_config
=
dict
(
self
.
scheduler
.
config
)
new_config
=
dict
(
self
.
scheduler
.
config
)
...
...
src/diffusers/schedulers/scheduling_ddim.py
View file @
99904459
...
@@ -134,7 +134,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -134,7 +134,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
" DDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.1
0
.0"
,
message
,
take_from
=
kwargs
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.1
1
.0"
,
message
,
take_from
=
kwargs
)
if
predict_epsilon
is
not
None
:
if
predict_epsilon
is
not
None
:
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
...
...
src/diffusers/schedulers/scheduling_ddim_flax.py
View file @
99904459
...
@@ -138,7 +138,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -138,7 +138,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" FlaxDDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
" FlaxDDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.1
0
.0"
,
message
,
take_from
=
kwargs
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.1
1
.0"
,
message
,
take_from
=
kwargs
)
if
predict_epsilon
is
not
None
:
if
predict_epsilon
is
not
None
:
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
...
...
src/diffusers/schedulers/scheduling_ddpm.py
View file @
99904459
...
@@ -125,7 +125,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -125,7 +125,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.1
0
.0"
,
message
,
take_from
=
kwargs
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.1
1
.0"
,
message
,
take_from
=
kwargs
)
if
predict_epsilon
is
not
None
:
if
predict_epsilon
is
not
None
:
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
...
@@ -255,7 +255,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -255,7 +255,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.1
0
.0"
,
message
,
take_from
=
kwargs
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.1
1
.0"
,
message
,
take_from
=
kwargs
)
if
predict_epsilon
is
not
None
:
if
predict_epsilon
is
not
None
:
new_config
=
dict
(
self
.
config
)
new_config
=
dict
(
self
.
config
)
new_config
[
"prediction_type"
]
=
"epsilon"
if
predict_epsilon
else
"sample"
new_config
[
"prediction_type"
]
=
"epsilon"
if
predict_epsilon
else
"sample"
...
...
src/diffusers/schedulers/scheduling_ddpm_flax.py
View file @
99904459
...
@@ -132,7 +132,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -132,7 +132,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" FlaxDDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
" FlaxDDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.1
0
.0"
,
message
,
take_from
=
kwargs
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.1
1
.0"
,
message
,
take_from
=
kwargs
)
if
predict_epsilon
is
not
None
:
if
predict_epsilon
is
not
None
:
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
...
@@ -239,7 +239,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -239,7 +239,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" FlaxDDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
" FlaxDDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.1
0
.0"
,
message
,
take_from
=
kwargs
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.1
1
.0"
,
message
,
take_from
=
kwargs
)
if
predict_epsilon
is
not
None
:
if
predict_epsilon
is
not
None
:
new_config
=
dict
(
self
.
config
)
new_config
=
dict
(
self
.
config
)
new_config
[
"prediction_type"
]
=
"epsilon"
if
predict_epsilon
else
"sample"
new_config
[
"prediction_type"
]
=
"epsilon"
if
predict_epsilon
else
"sample"
...
...
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
View file @
99904459
...
@@ -142,7 +142,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -142,7 +142,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DPMSolverMultistepScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
" DPMSolverMultistepScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.1
0
.0"
,
message
,
take_from
=
kwargs
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.1
1
.0"
,
message
,
take_from
=
kwargs
)
if
predict_epsilon
is
not
None
:
if
predict_epsilon
is
not
None
:
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
...
...
src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
View file @
99904459
...
@@ -177,7 +177,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -177,7 +177,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" FlaxDPMSolverMultistepScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
" FlaxDPMSolverMultistepScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.1
0
.0"
,
message
,
take_from
=
kwargs
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.1
1
.0"
,
message
,
take_from
=
kwargs
)
if
predict_epsilon
is
not
None
:
if
predict_epsilon
is
not
None
:
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
...
...
tests/pipelines/ddpm/test_ddpm.py
View file @
99904459
...
@@ -69,7 +69,7 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
...
@@ -69,7 +69,7 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
assert
np
.
abs
(
image_from_tuple_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
assert
np
.
abs
(
image_from_tuple_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
def
test_inference_deprecated_predict_epsilon
(
self
):
def
test_inference_deprecated_predict_epsilon
(
self
):
deprecate
(
"remove this test"
,
"0.1
0
.0"
,
"remove"
)
deprecate
(
"remove this test"
,
"0.1
1
.0"
,
"remove"
)
unet
=
self
.
dummy_uncond_unet
unet
=
self
.
dummy_uncond_unet
scheduler
=
DDPMScheduler
(
predict_epsilon
=
False
)
scheduler
=
DDPMScheduler
(
predict_epsilon
=
False
)
...
...
tests/test_config.py
View file @
99904459
...
@@ -203,7 +203,7 @@ class ConfigTester(unittest.TestCase):
...
@@ -203,7 +203,7 @@ class ConfigTester(unittest.TestCase):
ddpm_2
=
DDPMScheduler
.
from_pretrained
(
"google/ddpm-celebahq-256"
,
beta_start
=
88
)
ddpm_2
=
DDPMScheduler
.
from_pretrained
(
"google/ddpm-celebahq-256"
,
beta_start
=
88
)
with
CaptureLogger
(
logger
)
as
cap_logger
:
with
CaptureLogger
(
logger
)
as
cap_logger
:
deprecate
(
"remove this case"
,
"0.1
0
.0"
,
"remove"
)
deprecate
(
"remove this case"
,
"0.1
1
.0"
,
"remove"
)
ddpm_3
=
DDPMScheduler
.
from_pretrained
(
ddpm_3
=
DDPMScheduler
.
from_pretrained
(
"hf-internal-testing/tiny-stable-diffusion-torch"
,
"hf-internal-testing/tiny-stable-diffusion-torch"
,
subfolder
=
"scheduler"
,
subfolder
=
"scheduler"
,
...
...
tests/test_scheduler.py
View file @
99904459
...
@@ -639,12 +639,12 @@ class DDPMSchedulerTest(SchedulerCommonTest):
...
@@ -639,12 +639,12 @@ class DDPMSchedulerTest(SchedulerCommonTest):
self
.
check_over_configs
(
prediction_type
=
prediction_type
)
self
.
check_over_configs
(
prediction_type
=
prediction_type
)
def
test_deprecated_predict_epsilon
(
self
):
def
test_deprecated_predict_epsilon
(
self
):
deprecate
(
"remove this test"
,
"0.1
0
.0"
,
"remove"
)
deprecate
(
"remove this test"
,
"0.1
1
.0"
,
"remove"
)
for
predict_epsilon
in
[
True
,
False
]:
for
predict_epsilon
in
[
True
,
False
]:
self
.
check_over_configs
(
predict_epsilon
=
predict_epsilon
)
self
.
check_over_configs
(
predict_epsilon
=
predict_epsilon
)
def
test_deprecated_epsilon
(
self
):
def
test_deprecated_epsilon
(
self
):
deprecate
(
"remove this test"
,
"0.1
0
.0"
,
"remove"
)
deprecate
(
"remove this test"
,
"0.1
1
.0"
,
"remove"
)
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_config
=
self
.
get_scheduler_config
()
scheduler_config
=
self
.
get_scheduler_config
()
...
...
tests/test_scheduler_flax.py
View file @
99904459
...
@@ -626,12 +626,12 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
...
@@ -626,12 +626,12 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
self
.
check_over_configs
(
prediction_type
=
prediction_type
)
self
.
check_over_configs
(
prediction_type
=
prediction_type
)
def
test_deprecated_predict_epsilon
(
self
):
def
test_deprecated_predict_epsilon
(
self
):
deprecate
(
"remove this test"
,
"0.1
0
.0"
,
"remove"
)
deprecate
(
"remove this test"
,
"0.1
1
.0"
,
"remove"
)
for
predict_epsilon
in
[
True
,
False
]:
for
predict_epsilon
in
[
True
,
False
]:
self
.
check_over_configs
(
predict_epsilon
=
predict_epsilon
)
self
.
check_over_configs
(
predict_epsilon
=
predict_epsilon
)
def
test_deprecated_predict_epsilon_to_prediction_type
(
self
):
def
test_deprecated_predict_epsilon_to_prediction_type
(
self
):
deprecate
(
"remove this test"
,
"0.1
0
.0"
,
"remove"
)
deprecate
(
"remove this test"
,
"0.1
1
.0"
,
"remove"
)
for
scheduler_class
in
self
.
scheduler_classes
:
for
scheduler_class
in
self
.
scheduler_classes
:
scheduler_config
=
self
.
get_scheduler_config
(
predict_epsilon
=
True
)
scheduler_config
=
self
.
get_scheduler_config
(
predict_epsilon
=
True
)
scheduler
=
scheduler_class
.
from_config
(
scheduler_config
)
scheduler
=
scheduler_class
.
from_config
(
scheduler_config
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment