Unverified Commit 1d0f3c21 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Move accelerate to a soft-dependency (#1134)



* finish

* finish

* Update src/diffusers/modeling_utils.py

* Update src/diffusers/pipeline_utils.py
Co-authored-by: default avatarAnton Lozhkov <anton@huggingface.co>

* more fixes

* fix
Co-authored-by: default avatarAnton Lozhkov <anton@huggingface.co>
parent c62b3a2e
from .utils import ( from .utils import (
is_accelerate_available,
is_flax_available, is_flax_available,
is_inflect_available, is_inflect_available,
is_onnx_available, is_onnx_available,
...@@ -17,13 +16,6 @@ from .onnx_utils import OnnxRuntimeModel ...@@ -17,13 +16,6 @@ from .onnx_utils import OnnxRuntimeModel
from .utils import logging from .utils import logging
# This will create an extra dummy file "dummy_torch_and_accelerate_objects.py"
# TODO: (patil-suraj, anton-l) maybe import everything under is_torch_and_accelerate_available
if is_torch_available() and not is_accelerate_available():
error_msg = "Please install the `accelerate` library to use Diffusers with PyTorch. You can do so by running `pip install diffusers[torch]`. Or if torch is already installed, you can run `pip install accelerate`." # noqa: E501
raise ImportError(error_msg)
if is_torch_available(): if is_torch_available():
from .modeling_utils import ModelMixin from .modeling_utils import ModelMixin
from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
......
...@@ -21,15 +21,20 @@ from typing import Callable, List, Optional, Tuple, Union ...@@ -21,15 +21,20 @@ from typing import Callable, List, Optional, Tuple, Union
import torch import torch
from torch import Tensor, device from torch import Tensor, device
import accelerate
from accelerate.utils import set_module_tensor_to_device
from accelerate.utils.versions import is_torch_version
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError from requests import HTTPError
from . import __version__ from . import __version__
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, WEIGHTS_NAME, logging from .utils import (
CONFIG_NAME,
DIFFUSERS_CACHE,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
WEIGHTS_NAME,
is_accelerate_available,
is_torch_version,
logging,
)
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -41,6 +46,12 @@ else: ...@@ -41,6 +46,12 @@ else:
_LOW_CPU_MEM_USAGE_DEFAULT = False _LOW_CPU_MEM_USAGE_DEFAULT = False
if is_accelerate_available():
import accelerate
from accelerate.utils import set_module_tensor_to_device
from accelerate.utils.versions import is_torch_version
def get_parameter_device(parameter: torch.nn.Module): def get_parameter_device(parameter: torch.nn.Module):
try: try:
return next(parameter.parameters()).device return next(parameter.parameters()).device
...@@ -319,6 +330,21 @@ class ModelMixin(torch.nn.Module): ...@@ -319,6 +330,21 @@ class ModelMixin(torch.nn.Module):
device_map = kwargs.pop("device_map", None) device_map = kwargs.pop("device_map", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False
logger.warn(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if device_map is not None and not is_accelerate_available():
raise NotImplementedError(
"Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
" `device_map=None`. You can install accelerate with `pip install accelerate`."
)
# Check if we can handle device_map and dispatching the weights # Check if we can handle device_map and dispatching the weights
if device_map is not None and not is_torch_version(">=", "1.9.0"): if device_map is not None and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError( raise NotImplementedError(
......
...@@ -25,7 +25,6 @@ import torch ...@@ -25,7 +25,6 @@ import torch
import diffusers import diffusers
import PIL import PIL
from accelerate.utils.versions import is_torch_version
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from packaging import version from packaging import version
from PIL import Image from PIL import Image
...@@ -43,6 +42,8 @@ from .utils import ( ...@@ -43,6 +42,8 @@ from .utils import (
WEIGHTS_NAME, WEIGHTS_NAME,
BaseOutput, BaseOutput,
deprecate, deprecate,
is_accelerate_available,
is_torch_version,
is_transformers_available, is_transformers_available,
logging, logging,
) )
...@@ -397,6 +398,15 @@ class DiffusionPipeline(ConfigMixin): ...@@ -397,6 +398,15 @@ class DiffusionPipeline(ConfigMixin):
device_map = kwargs.pop("device_map", None) device_map = kwargs.pop("device_map", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False
logger.warn(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if device_map is not None and not is_torch_version(">=", "1.9.0"): if device_map is not None and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError( raise NotImplementedError(
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
......
...@@ -31,6 +31,7 @@ from .import_utils import ( ...@@ -31,6 +31,7 @@ from .import_utils import (
is_scipy_available, is_scipy_available,
is_tf_available, is_tf_available,
is_torch_available, is_torch_available,
is_torch_version,
is_transformers_available, is_transformers_available,
is_unidecode_available, is_unidecode_available,
requires_backends, requires_backends,
......
...@@ -272,21 +272,6 @@ class ScoreSdeVePipeline(metaclass=DummyObject): ...@@ -272,21 +272,6 @@ class ScoreSdeVePipeline(metaclass=DummyObject):
requires_backends(cls, ["torch"]) requires_backends(cls, ["torch"])
class VQDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class DDIMScheduler(metaclass=DummyObject): class DDIMScheduler(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
......
# This file is autogenerated by the command `make fix-copies`, do not edit.
# flake8: noqa
from ..utils import DummyObject, requires_backends
class ModelMixin(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class AutoencoderKL(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class Transformer2DModel(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class UNet1DModel(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class UNet2DConditionModel(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class UNet2DModel(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class VQModel(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
def get_constant_schedule(*args, **kwargs):
requires_backends(get_constant_schedule, ["torch", "accelerate"])
def get_constant_schedule_with_warmup(*args, **kwargs):
requires_backends(get_constant_schedule_with_warmup, ["torch", "accelerate"])
def get_cosine_schedule_with_warmup(*args, **kwargs):
requires_backends(get_cosine_schedule_with_warmup, ["torch", "accelerate"])
def get_cosine_with_hard_restarts_schedule_with_warmup(*args, **kwargs):
requires_backends(get_cosine_with_hard_restarts_schedule_with_warmup, ["torch", "accelerate"])
def get_linear_schedule_with_warmup(*args, **kwargs):
requires_backends(get_linear_schedule_with_warmup, ["torch", "accelerate"])
def get_polynomial_decay_schedule_with_warmup(*args, **kwargs):
requires_backends(get_polynomial_decay_schedule_with_warmup, ["torch", "accelerate"])
def get_scheduler(*args, **kwargs):
requires_backends(get_scheduler, ["torch", "accelerate"])
class DiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class DanceDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class DDIMPipeline(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class DDPMPipeline(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class KarrasVePipeline(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class LDMPipeline(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class PNDMPipeline(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class RePaintPipeline(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class ScoreSdeVePipeline(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class DDIMScheduler(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class DDPMScheduler(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class EulerAncestralDiscreteScheduler(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class EulerDiscreteScheduler(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class IPNDMScheduler(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class KarrasVeScheduler(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class PNDMScheduler(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class RePaintScheduler(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class SchedulerMixin(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class ScoreSdeVeScheduler(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class VQDiffusionScheduler(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class EMAModel(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
...@@ -15,11 +15,14 @@ ...@@ -15,11 +15,14 @@
Import utilities: Utilities related to imports and our lazy inits. Import utilities: Utilities related to imports and our lazy inits.
""" """
import importlib.util import importlib.util
import operator as op
import os import os
import sys import sys
from collections import OrderedDict from collections import OrderedDict
from typing import Union
from packaging import version from packaging import version
from packaging.version import Version, parse
from . import logging from . import logging
...@@ -40,6 +43,8 @@ USE_TF = os.environ.get("USE_TF", "AUTO").upper() ...@@ -40,6 +43,8 @@ USE_TF = os.environ.get("USE_TF", "AUTO").upper()
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper() USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}
_torch_version = "N/A" _torch_version = "N/A"
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
_torch_available = importlib.util.find_spec("torch") is not None _torch_available = importlib.util.find_spec("torch") is not None
...@@ -309,3 +314,36 @@ class DummyObject(type): ...@@ -309,3 +314,36 @@ class DummyObject(type):
if key.startswith("_"): if key.startswith("_"):
return super().__getattr__(cls, key) return super().__getattr__(cls, key)
requires_backends(cls, cls._backends) requires_backends(cls, cls._backends)
# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319
def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str):
"""
Args:
Compares a library version to some requirement using a given operation.
library_or_version (`str` or `packaging.version.Version`):
A library name or a version to check.
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`.
requirement_version (`str`):
The version to compare the library version against
"""
if operation not in STR_OPERATION_TO_FUNC.keys():
raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}")
operation = STR_OPERATION_TO_FUNC[operation]
if isinstance(library_or_version, str):
library_or_version = parse(importlib_metadata.version(library_or_version))
return operation(library_or_version, parse(requirement_version))
# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338
def is_torch_version(operation: str, version: str):
"""
Args:
Compares the current PyTorch version to a given reference with an operation.
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A string version of PyTorch
"""
return compare_versions(parse(_torch_version), operation, version)
...@@ -52,13 +52,13 @@ class CheckDummiesTester(unittest.TestCase): ...@@ -52,13 +52,13 @@ class CheckDummiesTester(unittest.TestCase):
def test_read_init(self): def test_read_init(self):
objects = read_init() objects = read_init()
# We don't assert on the exact list of keys to allow for smooth grow of backend-specific objects # We don't assert on the exact list of keys to allow for smooth grow of backend-specific objects
self.assertIn("torch_and_accelerate", objects) self.assertIn("torch", objects)
self.assertIn("torch_and_transformers", objects) self.assertIn("torch_and_transformers", objects)
self.assertIn("flax_and_transformers", objects) self.assertIn("flax_and_transformers", objects)
self.assertIn("torch_and_transformers_and_onnx", objects) self.assertIn("torch_and_transformers_and_onnx", objects)
# Likewise, we can't assert on the exact content of a key # Likewise, we can't assert on the exact content of a key
self.assertIn("UNet2DModel", objects["torch_and_accelerate"]) self.assertIn("UNet2DModel", objects["torch"])
self.assertIn("FlaxUNet2DConditionModel", objects["flax"]) self.assertIn("FlaxUNet2DConditionModel", objects["flax"])
self.assertIn("StableDiffusionPipeline", objects["torch_and_transformers"]) self.assertIn("StableDiffusionPipeline", objects["torch_and_transformers"])
self.assertIn("FlaxStableDiffusionPipeline", objects["flax_and_transformers"]) self.assertIn("FlaxStableDiffusionPipeline", objects["flax_and_transformers"])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment