Commit 554b374d authored by Patrick von Platen

Merge branch 'main' of https://github.com/huggingface/diffusers into main

parents d5ab55e4 a0520193
@@ -19,7 +19,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, logging
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging
from .scheduling_utils import SchedulerMixin
@@ -52,8 +52,8 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
-[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
-[`~ConfigMixin.from_config`] functions.
+[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+[`~SchedulerMixin.from_pretrained`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
@@ -67,14 +67,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
-_compatible_classes = [
-"DDIMScheduler",
-"DDPMScheduler",
-"LMSDiscreteScheduler",
-"PNDMScheduler",
-"EulerDiscreteScheduler",
-"DPMSolverMultistepScheduler",
-]
+_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
@register_to_config
def __init__(
...
@@ -19,7 +19,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, logging
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging
from .scheduling_utils import SchedulerMixin
@@ -53,8 +53,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
-[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
-[`~ConfigMixin.from_config`] functions.
+[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+[`~SchedulerMixin.from_pretrained`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
@@ -68,14 +68,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
-_compatible_classes = [
-"DDIMScheduler",
-"DDPMScheduler",
-"LMSDiscreteScheduler",
-"PNDMScheduler",
-"EulerAncestralDiscreteScheduler",
-"DPMSolverMultistepScheduler",
-]
+_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
@register_to_config
def __init__(
...
@@ -28,8 +28,8 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
-[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
-[`~ConfigMixin.from_config`] functions.
+[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+[`~SchedulerMixin.from_pretrained`] functions.
For more details, see the original paper: https://arxiv.org/abs/2202.09778
...
@@ -56,8 +56,8 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
-[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
-[`~ConfigMixin.from_config`] functions.
+[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+[`~SchedulerMixin.from_pretrained`] functions.
For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
...
@@ -67,8 +67,8 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
-[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
-[`~ConfigMixin.from_config`] functions.
+[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+[`~SchedulerMixin.from_pretrained`] functions.
For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
...
@@ -21,7 +21,7 @@ import torch
from scipy import integrate
from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
from .scheduling_utils import SchedulerMixin
@@ -52,8 +52,8 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
-[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
-[`~ConfigMixin.from_config`] functions.
+[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+[`~SchedulerMixin.from_pretrained`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
@@ -67,14 +67,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
-_compatible_classes = [
-"DDIMScheduler",
-"DDPMScheduler",
-"PNDMScheduler",
-"EulerDiscreteScheduler",
-"EulerAncestralDiscreteScheduler",
-"DPMSolverMultistepScheduler",
-]
+_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
@register_to_config
def __init__(
...
@@ -20,7 +20,12 @@ import jax.numpy as jnp
from scipy import integrate
from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
+from .scheduling_utils_flax import (
+    _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
+    FlaxSchedulerMixin,
+    FlaxSchedulerOutput,
+    broadcast_to_shape_from_left,
+)
@flax.struct.dataclass
@@ -49,8 +54,8 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
-[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
-[`~ConfigMixin.from_config`] functions.
+[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+[`~SchedulerMixin.from_pretrained`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
@@ -63,6 +68,8 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
"""
+_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
@property
def has_state(self):
return True
...
@@ -21,6 +21,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
from .scheduling_utils import SchedulerMixin, SchedulerOutput
@@ -60,8 +61,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
-[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
-[`~ConfigMixin.from_config`] functions.
+[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+[`~SchedulerMixin.from_pretrained`] functions.
For more details, see the original paper: https://arxiv.org/abs/2202.09778
@@ -88,14 +89,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
"""
-_compatible_classes = [
-"DDIMScheduler",
-"DDPMScheduler",
-"LMSDiscreteScheduler",
-"EulerDiscreteScheduler",
-"EulerAncestralDiscreteScheduler",
-"DPMSolverMultistepScheduler",
-]
+_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
@register_to_config
def __init__(
...
@@ -23,7 +23,12 @@ import jax
import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
+from .scheduling_utils_flax import (
+    _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
+    FlaxSchedulerMixin,
+    FlaxSchedulerOutput,
+    broadcast_to_shape_from_left,
+)
def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray:
@@ -87,8 +92,8 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
-[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
-[`~ConfigMixin.from_config`] functions.
+[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+[`~SchedulerMixin.from_pretrained`] functions.
For more details, see the original paper: https://arxiv.org/abs/2202.09778
@@ -114,6 +119,8 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
stable diffusion.
"""
+_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
@property
def has_state(self):
return True
...
@@ -77,8 +77,8 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
-[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
-[`~ConfigMixin.from_config`] functions.
+[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+[`~SchedulerMixin.from_pretrained`] functions.
For more details, see the original paper: https://arxiv.org/pdf/2201.09865.pdf
...
@@ -50,8 +50,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
-[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
-[`~ConfigMixin.from_config`] functions.
+[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+[`~SchedulerMixin.from_pretrained`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
...
@@ -64,8 +64,8 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
-[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
-[`~ConfigMixin.from_config`] functions.
+[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+[`~SchedulerMixin.from_pretrained`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
...
@@ -29,8 +29,8 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
-[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
-[`~ConfigMixin.from_config`] functions.
+[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+[`~SchedulerMixin.from_pretrained`] functions.
For more information, see the original paper: https://arxiv.org/abs/2011.13456
...
@@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
from dataclasses import dataclass
from typing import Any, Dict, Optional, Union
import torch
@@ -38,6 +41,114 @@ class SchedulerOutput(BaseOutput):
class SchedulerMixin:
"""
Mixin containing common functions for the schedulers.
Class attributes:
- **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that
`from_config` can be used from a class different than the one used to save the config (should be overridden
by parent class).
"""
config_name = SCHEDULER_CONFIG_NAME
_compatibles = []
has_compatibles = True
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Dict[str, Any] = None,
subfolder: Optional[str] = None,
return_unused_kwargs=False,
**kwargs,
):
r"""
Instantiate a Scheduler class from a pre-defined JSON configuration file inside a directory or Hub repo.
Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
Can be either:
- A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
organization name, like `google/ddpm-celebahq-256`.
- A path to a *directory* containing the scheduler configuration saved using
[`~SchedulerMixin.save_pretrained`], e.g., `./my_model_directory/`.
subfolder (`str`, *optional*):
In case the relevant files are located inside a subfolder of the model repo (either remote in
huggingface.co or downloaded locally), you can specify the folder name here.
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
Whether kwargs that are not consumed by the Python class should be returned or not.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
file exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
output_loading_info(`bool`, *optional*, defaults to `False`):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only(`bool`, *optional*, defaults to `False`):
Whether or not to only look at local files (i.e., do not try to download the model).
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `transformers-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
<Tip>
It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
models](https://huggingface.co/docs/hub/models-gated#gated-models).
</Tip>
<Tip>
Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
use this method in a firewalled environment.
</Tip>
"""
config, kwargs = cls.load_config(
pretrained_model_name_or_path=pretrained_model_name_or_path,
subfolder=subfolder,
return_unused_kwargs=True,
**kwargs,
)
return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
"""
Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the
[`~SchedulerMixin.from_pretrained`] class method.
Args:
save_directory (`str` or `os.PathLike`):
Directory where the configuration JSON file will be saved (will be created if it does not exist).
"""
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
@property
def compatibles(self):
"""
Returns all schedulers that are compatible with this scheduler
Returns:
`List[SchedulerMixin]`: List of compatible schedulers
"""
return self._get_compatibles()
@classmethod
def _get_compatibles(cls):
compatible_classes_str = list(set([cls.__name__] + cls._compatibles))
diffusers_library = importlib.import_module(__name__.split(".")[0])
compatible_classes = [
getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
]
return compatible_classes
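For orientation, a minimal usage sketch of the new PyTorch `SchedulerMixin.from_pretrained` / `save_pretrained` round trip and the `compatibles` property; the Hub repo id and local path are illustrative assumptions, not part of this commit:

```python
# Hypothetical sketch of the API added above; repo id and save path are assumptions.
from diffusers import DDIMScheduler, PNDMScheduler

# Load a scheduler whose scheduler_config.json lives in a "scheduler/" subfolder of a Hub repo.
scheduler = PNDMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")

# Inspect which scheduler classes declare themselves compatible via _compatibles.
print([cls.__name__ for cls in scheduler.compatibles])

# Instantiate a compatible scheduler from the same config
# (assumes ConfigMixin.from_config accepts a config dict, as in later diffusers releases).
ddim = DDIMScheduler.from_config(scheduler.config)

# Save the configuration locally and reload it with from_pretrained.
scheduler.save_pretrained("./my_scheduler")
scheduler = PNDMScheduler.from_pretrained("./my_scheduler")
```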
@@ -11,15 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
from dataclasses import dataclass
-from typing import Tuple
+from typing import Any, Dict, Optional, Tuple, Union
import jax.numpy as jnp
-from ..utils import BaseOutput
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
SCHEDULER_CONFIG_NAME = "scheduler_config.json"
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = ["Flax" + c for c in _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS]
@dataclass
@@ -39,9 +42,123 @@ class FlaxSchedulerOutput(BaseOutput):
class FlaxSchedulerMixin:
"""
Mixin containing common functions for the schedulers.
Class attributes:
- **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that
`from_config` can be used from a class different than the one used to save the config (should be overridden
by parent class).
"""
config_name = SCHEDULER_CONFIG_NAME
_compatibles = []
has_compatibles = True
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Dict[str, Any] = None,
subfolder: Optional[str] = None,
return_unused_kwargs=False,
**kwargs,
):
r"""
Instantiate a Scheduler class from a pre-defined JSON-file.
Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
Can be either:
- A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
organization name, like `google/ddpm-celebahq-256`.
- A path to a *directory* containing model weights saved using [`~SchedulerMixin.save_pretrained`],
e.g., `./my_model_directory/`.
subfolder (`str`, *optional*):
In case the relevant files are located inside a subfolder of the model repo (either remote in
huggingface.co or downloaded locally), you can specify the folder name here.
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
Whether kwargs that are not consumed by the Python class should be returned or not.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
file exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
output_loading_info(`bool`, *optional*, defaults to `False`):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only(`bool`, *optional*, defaults to `False`):
Whether or not to only look at local files (i.e., do not try to download the model).
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `transformers-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
<Tip>
It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
models](https://huggingface.co/docs/hub/models-gated#gated-models).
</Tip>
<Tip>
Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
use this method in a firewalled environment.
</Tip>
"""
config, kwargs = cls.load_config(
pretrained_model_name_or_path=pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
)
scheduler, unused_kwargs = cls.from_config(config, return_unused_kwargs=True, **kwargs)
if hasattr(scheduler, "create_state") and getattr(scheduler, "has_state", False):
state = scheduler.create_state()
if return_unused_kwargs:
return scheduler, state, unused_kwargs
return scheduler, state
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
"""
Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the
[`~FlaxSchedulerMixin.from_pretrained`] class method.
Args:
save_directory (`str` or `os.PathLike`):
Directory where the configuration JSON file will be saved (will be created if it does not exist).
"""
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
@property
def compatibles(self):
"""
Returns all schedulers that are compatible with this scheduler
Returns:
`List[SchedulerMixin]`: List of compatible schedulers
"""
return self._get_compatibles()
@classmethod
def _get_compatibles(cls):
compatible_classes_str = list(set([cls.__name__] + cls._compatibles))
diffusers_library = importlib.import_module(__name__.split(".")[0])
compatible_classes = [
getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
]
return compatible_classes
def broadcast_to_shape_from_left(x: jnp.ndarray, shape: Tuple[int]) -> jnp.ndarray:
...
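A corresponding sketch for the Flax mixin, assuming a scheduler such as `FlaxPNDMScheduler` that defines `create_state()`; the repo id and local path are illustrative assumptions:

```python
# Hypothetical sketch: FlaxSchedulerMixin.from_pretrained returns the scheduler
# together with its initial state when has_state is True (as in FlaxPNDMScheduler above).
from diffusers import FlaxPNDMScheduler

scheduler, scheduler_state = FlaxPNDMScheduler.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="scheduler"
)

# The configuration round-trips through save_pretrained / from_pretrained.
scheduler.save_pretrained("./flax_scheduler")
scheduler, scheduler_state = FlaxPNDMScheduler.from_pretrained("./flax_scheduler")
```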
@@ -112,8 +112,8 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
-[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
-[`~ConfigMixin.from_config`] functions.
+[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+[`~SchedulerMixin.from_pretrained`] functions.
For more details, see the original paper: https://arxiv.org/abs/2111.14822
...
@@ -72,3 +72,13 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
DIFFUSERS_CACHE = default_cache_path
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = [
"DDIMScheduler",
"DDPMScheduler",
"PNDMScheduler",
"LMSDiscreteScheduler",
"EulerDiscreteScheduler",
"EulerAncestralDiscreteScheduler",
"DPMSolverMultistepScheduler",
]
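Since every Stable Diffusion scheduler now copies this shared constant into its `_compatibles` list, the schedulers should all report the same compatibility set. A small sketch of that expectation (the equality check is an assumption about intended behavior, not a test from this commit):

```python
# Hypothetical check: schedulers that copy _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
# resolve to the same set of compatible classes via the new compatibles property.
from diffusers import EulerDiscreteScheduler, LMSDiscreteScheduler

euler = EulerDiscreteScheduler()  # default config values
lms = LMSDiscreteScheduler()

euler_names = {cls.__name__ for cls in euler.compatibles}
lms_names = {cls.__name__ for cls in lms.compatibles}
assert euler_names == lms_names  # both derive from the shared list above
```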
@@ -18,13 +18,120 @@ import unittest
import torch
from diffusers import UNet1DModel
-from diffusers.utils import slow, torch_device
+from diffusers.utils import floats_tensor, slow, torch_device
from ..test_modeling_common import ModelTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False
-class UnetModel1DTests(unittest.TestCase):
+class UNet1DModelTests(ModelTesterMixin, unittest.TestCase):
model_class = UNet1DModel
@property
def dummy_input(self):
batch_size = 4
num_features = 14
seq_len = 16
noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
time_step = torch.tensor([10] * batch_size).to(torch_device)
return {"sample": noise, "timestep": time_step}
@property
def input_shape(self):
return (4, 14, 16)
@property
def output_shape(self):
return (4, 14, 16)
def test_ema_training(self):
pass
def test_training(self):
pass
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_determinism(self):
super().test_determinism()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_outputs_equivalence(self):
super().test_outputs_equivalence()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_from_pretrained_save_pretrained(self):
super().test_from_pretrained_save_pretrained()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_model_from_pretrained(self):
super().test_model_from_pretrained()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_output(self):
super().test_output()
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"block_out_channels": (32, 64, 128, 256),
"in_channels": 14,
"out_channels": 14,
"time_embedding_type": "positional",
"use_timestep_embedding": True,
"flip_sin_to_cos": False,
"freq_shift": 1.0,
"out_block_type": "OutConv1DBlock",
"mid_block_type": "MidResTemporalBlock1D",
"down_block_types": ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
"up_block_types": ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D"),
"act_fn": "mish",
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_from_pretrained_hub(self):
model, loading_info = UNet1DModel.from_pretrained(
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="unet"
)
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)
model.to(torch_device)
image = model(**self.dummy_input)
assert image is not None, "Make sure output is not None"
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_output_pretrained(self):
model = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32", subfolder="unet")
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
num_features = model.in_channels
seq_len = 16
noise = torch.randn((1, seq_len, num_features)).permute(
0, 2, 1
) # match original, we can update values and remove
time_step = torch.full((num_features,), 0)
with torch.no_grad():
output = model(noise, time_step).sample.permute(0, 2, 1)
output_slice = output[0, -3:, -3:].flatten()
# fmt: off
expected_output_slice = torch.tensor([-2.137172, 1.1426016, 0.3688687, -0.766922, 0.7303146, 0.11038864, -0.4760633, 0.13270172, 0.02591348])
# fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
def test_forward_with_norm_groups(self):
# Not implemented yet for this UNet
pass
@slow
def test_unet_1d_maestro(self):
model_id = "harmonai/maestro-150k"
@@ -43,3 +150,127 @@ class UnetModel1DTests(unittest.TestCase):
assert (output_sum - 224.0896).abs() < 4e-2
assert (output_max - 0.0607).abs() < 4e-4
class UNetRLModelTests(ModelTesterMixin, unittest.TestCase):
model_class = UNet1DModel
@property
def dummy_input(self):
batch_size = 4
num_features = 14
seq_len = 16
noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
time_step = torch.tensor([10] * batch_size).to(torch_device)
return {"sample": noise, "timestep": time_step}
@property
def input_shape(self):
return (4, 14, 16)
@property
def output_shape(self):
return (4, 14, 1)
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_determinism(self):
super().test_determinism()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_outputs_equivalence(self):
super().test_outputs_equivalence()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_from_pretrained_save_pretrained(self):
super().test_from_pretrained_save_pretrained()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_model_from_pretrained(self):
super().test_model_from_pretrained()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_output(self):
# UNetRL is a value function, so its output shape differs from the base UNet1DModel
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
expected_shape = torch.Size((inputs_dict["sample"].shape[0], 1))
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_ema_training(self):
pass
def test_training(self):
pass
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 14,
"out_channels": 14,
"down_block_types": ["DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"],
"up_block_types": [],
"out_block_type": "ValueFunction",
"mid_block_type": "ValueFunctionMidBlock1D",
"block_out_channels": [32, 64, 128, 256],
"layers_per_block": 1,
"downsample_each_block": True,
"use_timestep_embedding": True,
"freq_shift": 1.0,
"flip_sin_to_cos": False,
"time_embedding_type": "positional",
"act_fn": "mish",
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_from_pretrained_hub(self):
value_function, vf_loading_info = UNet1DModel.from_pretrained(
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
)
self.assertIsNotNone(value_function)
self.assertEqual(len(vf_loading_info["missing_keys"]), 0)
value_function.to(torch_device)
image = value_function(**self.dummy_input)
assert image is not None, "Make sure output is not None"
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_output_pretrained(self):
value_function, vf_loading_info = UNet1DModel.from_pretrained(
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
)
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
num_features = value_function.in_channels
seq_len = 14
noise = torch.randn((1, seq_len, num_features)).permute(
0, 2, 1
) # match original, we can update values and remove
time_step = torch.full((num_features,), 0)
with torch.no_grad():
output = value_function(noise, time_step).sample
# fmt: off
expected_output_slice = torch.tensor([165.25] * seq_len)
# fmt: on
self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3))
def test_forward_with_norm_groups(self):
# Not implemented yet for this UNet
pass
@@ -44,6 +44,10 @@ class PipelineFastTests(unittest.TestCase):
sample_rate=16_000,
in_channels=2,
out_channels=2,
flip_sin_to_cos=True,
use_timestep_embedding=False,
time_embedding_type="fourier",
mid_block_type="UNetMidBlock1D",
down_block_types=["DownBlock1DNoSkip"] + ["DownBlock1D"] + ["AttnDownBlock1D"],
up_block_types=["AttnUpBlock1D"] + ["UpBlock1D"] + ["UpBlock1DNoSkip"],
)
...
@@ -75,7 +75,7 @@ class DDIMPipelineIntegrationTests(unittest.TestCase):
model_id = "google/ddpm-ema-bedroom-256"
unet = UNet2DModel.from_pretrained(model_id)
-scheduler = DDIMScheduler.from_config(model_id)
+scheduler = DDIMScheduler.from_pretrained(model_id)
ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
ddpm.to(torch_device)
...