Unverified Commit e5810e68 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Variant] Add "variant" as input kwarg so to have better UX when downloading...


[Variant] Add "variant" as input kwarg so as to have better UX when downloading no_ema or fp16 weights (#2305)

* [Variant] Add variant loading mechanism

* clean

* improve further

* up

* add tests

* add some first tests

* up

* up

* use path splitext

* add deprecate

* deprecation warnings

* improve docs

* up

* up

* up

* fix tests

* Apply suggestions from code review
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Apply suggestions from code review
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* correct code format

* fix warning

* finish

* Apply suggestions from code review
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Apply suggestions from code review
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Update docs/source/en/using-diffusers/loading.mdx
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Apply suggestions from code review
Co-authored-by: Will Berman <wlbberman@gmail.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* correct loading docs

* finish

---------
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Will Berman <wlbberman@gmail.com>
parent e3ddbe25
This diff is collapsed.
......@@ -16,18 +16,21 @@
import inspect
import os
import warnings
from functools import partial
from typing import Callable, List, Optional, Tuple, Union
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from packaging import version
from requests import HTTPError
from torch import Tensor, device
from .. import __version__
from ..utils import (
CONFIG_NAME,
DEPRECATED_REVISION_ARGS,
DIFFUSERS_CACHE,
FLAX_WEIGHTS_NAME,
HF_HUB_OFFLINE,
......@@ -89,12 +92,12 @@ def get_parameter_dtype(parameter: torch.nn.Module):
return first_tuple[1].dtype
def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
"""
Reads a checkpoint file, returning properly formatted errors if they arise.
"""
try:
if os.path.basename(checkpoint_file) == WEIGHTS_NAME:
if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant):
return torch.load(checkpoint_file, map_location="cpu")
else:
return safetensors.torch.load_file(checkpoint_file, device="cpu")
......@@ -141,6 +144,15 @@ def _load_state_dict_into_model(model_to_load, state_dict):
return error_msgs
def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
    """
    Insert `variant` in front of the last extension of `weights_name`.

    For example `pytorch_model.bin` with `variant="fp16"` becomes `pytorch_model.fp16.bin`.

    Args:
        weights_name (`str`): File name of the weights.
        variant (`str`, *optional*): Variant tag to insert. When `None`, `weights_name` is returned unchanged.

    Returns:
        `str`: The file name with the variant tag inserted.
    """
    if variant is None:
        return weights_name

    stem, dot, extension = weights_name.rpartition(".")
    if not dot:
        # No extension at all: append the variant rather than prepending it ("model" -> "model.fp16").
        return f"{weights_name}.{variant}"
    return f"{stem}.{variant}.{extension}"
class ModelMixin(torch.nn.Module):
r"""
Base class for all models.
......@@ -250,6 +262,7 @@ class ModelMixin(torch.nn.Module):
is_main_process: bool = True,
save_function: Callable = None,
safe_serialization: bool = False,
variant: Optional[str] = None,
):
"""
Save a model and its configuration file to a directory, so that it can be re-loaded using the
......@@ -268,6 +281,8 @@ class ModelMixin(torch.nn.Module):
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `False`):
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
variant (`str`, *optional*):
If specified, weights are saved in the format pytorch_model.<variant>.bin.
"""
if safe_serialization and not is_safetensors_available():
raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
......@@ -292,6 +307,7 @@ class ModelMixin(torch.nn.Module):
state_dict = model_to_save.state_dict()
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
weights_name = _add_variant(weights_name, variant)
# Save the model
save_function(state_dict, os.path.join(save_directory, weights_name))
......@@ -371,6 +387,9 @@ class ModelMixin(torch.nn.Module):
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
setting this argument to `True` will raise an error.
variant (`str`, *optional*):
If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
ignored when using `from_flax`.
<Tip>
......@@ -401,6 +420,7 @@ class ModelMixin(torch.nn.Module):
subfolder = kwargs.pop("subfolder", None)
device_map = kwargs.pop("device_map", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
variant = kwargs.pop("variant", None)
if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False
......@@ -488,7 +508,7 @@ class ModelMixin(torch.nn.Module):
try:
model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=SAFETENSORS_WEIGHTS_NAME,
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
......@@ -504,7 +524,7 @@ class ModelMixin(torch.nn.Module):
if model_file is None:
model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=WEIGHTS_NAME,
weights_name=_add_variant(WEIGHTS_NAME, variant),
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
......@@ -538,7 +558,7 @@ class ModelMixin(torch.nn.Module):
# if device_map is None, load the state dict and move the params from meta device to the cpu
if device_map is None:
param_device = "cpu"
state_dict = load_state_dict(model_file)
state_dict = load_state_dict(model_file, variant=variant)
# move the params from meta device to cpu
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
if len(missing_keys) > 0:
......@@ -587,7 +607,7 @@ class ModelMixin(torch.nn.Module):
)
model = cls.from_config(config, **unused_kwargs)
state_dict = load_state_dict(model_file)
state_dict = load_state_dict(model_file, variant=variant)
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
model,
......@@ -800,8 +820,38 @@ def _get_model_file(
f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
)
else:
# 1. First check if deprecated way of loading from branches is used
if (
revision in DEPRECATED_REVISION_ARGS
and (weights_name == WEIGHTS_NAME or weights_name == SAFETENSORS_WEIGHTS_NAME)
and version.parse(version.parse(__version__).base_version) >= version.parse("0.15.0")
):
try:
model_file = hf_hub_download(
pretrained_model_name_or_path,
filename=_add_variant(weights_name, revision),
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
subfolder=subfolder,
revision=revision,
)
warnings.warn(
f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
FutureWarning,
)
return model_file
except: # noqa: E722
warnings.warn(
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have a {_add_variant(weights_name)} file in the 'main' branch of {pretrained_model_name_or_path}. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {_add_variant(weights_name)}' so that the correct variant file can be added.",
FutureWarning,
)
try:
# Load from URL or cache if already cached
# 2. Load model file as usual
model_file = hf_hub_download(
pretrained_model_name_or_path,
filename=weights_name,
......
......@@ -17,6 +17,8 @@
import importlib
import inspect
import os
import re
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
......@@ -31,15 +33,16 @@ from tqdm.auto import tqdm
import diffusers
from .. import __version__
from ..configuration_utils import ConfigMixin
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from ..utils import (
CONFIG_NAME,
DEPRECATED_REVISION_ARGS,
DIFFUSERS_CACHE,
FLAX_WEIGHTS_NAME,
HF_HUB_OFFLINE,
ONNX_WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME,
BaseOutput,
deprecate,
......@@ -56,6 +59,11 @@ from ..utils import (
if is_transformers_available():
import transformers
from transformers import PreTrainedModel
from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
from ..utils import FLAX_WEIGHTS_NAME, ONNX_WEIGHTS_NAME
INDEX_FILE = "diffusion_pytorch_model.bin"
......@@ -120,15 +128,16 @@ class AudioPipelineOutput(BaseOutput):
audios: np.ndarray
def is_safetensors_compatible(info) -> bool:
filenames = set(sibling.rfilename for sibling in info.siblings)
def is_safetensors_compatible(filenames, variant=None) -> bool:
pt_filenames = set(filename for filename in filenames if filename.endswith(".bin"))
is_safetensors_compatible = any(file.endswith(".safetensors") for file in filenames)
for pt_filename in pt_filenames:
_variant = f".{variant}" if (variant is not None and variant in pt_filename) else ""
prefix, raw = os.path.split(pt_filename)
if raw == "pytorch_model.bin":
if raw == f"pytorch_model{_variant}.bin":
# transformers specific
sf_filename = os.path.join(prefix, "model.safetensors")
sf_filename = os.path.join(prefix, f"model{_variant}.safetensors")
else:
sf_filename = pt_filename[: -len(".bin")] + ".safetensors"
if is_safetensors_compatible and sf_filename not in filenames:
......@@ -137,6 +146,41 @@ def is_safetensors_compatible(info) -> bool:
return is_safetensors_compatible
def variant_compatible_siblings(info, variant=None) -> Tuple[Set[str], Set[str]]:
    """
    Select which weight files of a repository should be used for a given `variant`.

    Args:
        info: Object whose `siblings` attribute is an iterable of entries exposing an `rfilename` string.
        variant (`str`, *optional*): Variant tag (e.g. `"fp16"`). When `None`, no variant files are selected.

    Returns:
        A tuple `(usable_filenames, variant_filenames)` of two sets of repository file paths:
        `variant_filenames` holds every weight file carrying the `variant` tag, and `usable_filenames` holds
        those files plus every non-variant weight file for which no variant counterpart was found.
    """
    filenames = {sibling.rfilename for sibling in info.siblings}
    weight_names = [WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME, FLAX_WEIGHTS_NAME, ONNX_WEIGHTS_NAME]

    if is_transformers_available():
        weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]

    # File names are matched literally: escape them so that "." does not match an arbitrary character.
    non_variant_file_regex = re.compile("|".join(re.escape(w) for w in weight_names))
    non_variant_filenames = {f for f in filenames if non_variant_file_regex.match(f.split("/")[-1]) is not None}

    if variant is None:
        # Without a variant every plain weight file is usable and there are no variant files.
        return non_variant_filenames, set()

    # model_pytorch, diffusion_model_pytorch, ...
    weight_prefixes = "|".join(re.escape(w.split(".")[0]) for w in weight_names)
    # .bin, .safetensors, ...
    weight_suffixes = "|".join(re.escape(w.split(".")[-1]) for w in weight_names)
    # `variant` is user input, so it is escaped as well.
    variant_file_regex = re.compile("(" + weight_prefixes + r")(\." + re.escape(variant) + r"\.)(" + weight_suffixes + ")")

    variant_filenames = {f for f in filenames if variant_file_regex.match(f.split("/")[-1]) is not None}

    usable_filenames = set(variant_filenames)
    for f in non_variant_filenames:
        # Only split the base name so that dots inside folder names do not corrupt the variant path.
        folder, sep, basename = f.rpartition("/")
        parts = basename.split(".")
        variant_filename = f"{folder}{sep}{parts[0]}.{variant}.{parts[1]}"
        if variant_filename not in usable_filenames:
            usable_filenames.add(f)

    return usable_filenames, variant_filenames
class DiffusionPipeline(ConfigMixin):
r"""
Base class for all models.
......@@ -194,6 +238,7 @@ class DiffusionPipeline(ConfigMixin):
self,
save_directory: Union[str, os.PathLike],
safe_serialization: bool = False,
variant: Optional[str] = None,
):
"""
Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to
......@@ -205,6 +250,8 @@ class DiffusionPipeline(ConfigMixin):
Directory to which to save. Will be created if it doesn't exist.
safe_serialization (`bool`, *optional*, defaults to `False`):
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
variant (`str`, *optional*):
If specified, weights are saved in the format pytorch_model.<variant>.bin.
"""
self.save_config(save_directory)
......@@ -246,12 +293,15 @@ class DiffusionPipeline(ConfigMixin):
# Call the save method with the argument safe_serialization only if it's supported
save_method_signature = inspect.signature(save_method)
save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
save_method_accept_variant = "variant" in save_method_signature.parameters
save_kwargs = {}
if save_method_accept_safe:
save_method(
os.path.join(save_directory, pipeline_component_name), safe_serialization=safe_serialization
)
else:
save_method(os.path.join(save_directory, pipeline_component_name))
save_kwargs["safe_serialization"] = safe_serialization
if save_method_accept_variant:
save_kwargs["variant"] = variant
save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
def to(self, torch_device: Optional[Union[str, torch.device]] = None):
if torch_device is None:
......@@ -403,6 +453,9 @@ class DiffusionPipeline(ConfigMixin):
Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
specific pipeline class. The overwritten components are then directly passed to the pipelines
`__init__` method. See example below for more information.
variant (`str`, *optional*):
If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
ignored when using `from_flax`.
<Tip>
......@@ -454,6 +507,7 @@ class DiffusionPipeline(ConfigMixin):
device_map = kwargs.pop("device_map", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
return_cached_folder = kwargs.pop("return_cached_folder", False)
variant = kwargs.pop("variant", None)
# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
......@@ -468,28 +522,87 @@ class DiffusionPipeline(ConfigMixin):
use_auth_token=use_auth_token,
revision=revision,
)
# make sure we only download sub-folders and `diffusers` filenames
folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
allow_patterns = [os.path.join(k, "*") for k in folder_names]
allow_patterns += [
WEIGHTS_NAME,
SCHEDULER_CONFIG_NAME,
CONFIG_NAME,
ONNX_WEIGHTS_NAME,
cls.config_name,
]
# make sure we don't download flax weights
ignore_patterns = ["*.msgpack"]
if from_flax:
ignore_patterns = ["*.bin", "*.safetensors"]
# retrieve all folder_names that contain relevant files
folder_names = [k for k, v in config_dict.items() if isinstance(v, list)]
if not local_files_only:
info = model_info(
pretrained_model_name_or_path,
use_auth_token=use_auth_token,
revision=revision,
)
model_filenames, variant_filenames = variant_compatible_siblings(info, variant=variant)
model_folder_names = set([os.path.split(f)[0] for f in model_filenames])
if revision in DEPRECATED_REVISION_ARGS and version.parse(
version.parse(__version__).base_version
) >= version.parse("0.10.0"):
info = model_info(
pretrained_model_name_or_path,
use_auth_token=use_auth_token,
revision=None,
)
comp_model_filenames, _ = variant_compatible_siblings(info, variant=revision)
comp_model_filenames = [
".".join(f.split(".")[:1] + f.split(".")[2:]) for f in comp_model_filenames
]
if set(comp_model_filenames) == set(model_filenames):
warnings.warn(
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` even though you can load it via `variant=`{revision}`. Loading model variants via `revision='{variant}'` is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
FutureWarning,
)
else:
warnings.warn(
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have the required variant filenames in the 'main' branch. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {revision} files' so that the correct variant file can be added.",
FutureWarning,
)
# all filenames compatible with variant will be added
allow_patterns = list(model_filenames)
# allow all patterns from non-model folders
# this enables downloading schedulers, tokenizers, ...
allow_patterns += [os.path.join(k, "*") for k in folder_names if k not in model_folder_names]
# also allow downloading config.jsons with the model
allow_patterns += [os.path.join(k, "*.json") for k in model_folder_names]
allow_patterns += [
FLAX_WEIGHTS_NAME,
SCHEDULER_CONFIG_NAME,
CONFIG_NAME,
cls.config_name,
CUSTOM_PIPELINE_FILE_NAME,
]
if custom_pipeline is not None:
allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]
if from_flax:
ignore_patterns = ["*.bin", "*.safetensors", ".onnx"]
elif is_safetensors_available() and is_safetensors_compatible(model_filenames, variant=variant):
ignore_patterns = ["*.bin", "*.msgpack"]
safetensors_variant_filenames = set([f for f in variant_filenames if f.endswith(".safetensors")])
safetensors_model_filenames = set([f for f in model_filenames if f.endswith(".safetensors")])
if (
len(safetensors_variant_filenames) > 0
and safetensors_model_filenames != safetensors_variant_filenames
):
logger.warn(
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
)
else:
ignore_patterns = ["*.safetensors", "*.msgpack"]
bin_variant_filenames = set([f for f in variant_filenames if f.endswith(".bin")])
bin_model_filenames = set([f for f in model_filenames if f.endswith(".bin")])
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
logger.warn(
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
)
else:
# allow everything since it has to be downloaded anyways
ignore_patterns = allow_patterns = None
if cls != DiffusionPipeline:
requested_pipeline_class = cls.__name__
......@@ -501,21 +614,6 @@ class DiffusionPipeline(ConfigMixin):
user_agent = http_user_agent(user_agent)
if is_safetensors_available() and not local_files_only:
info = model_info(
pretrained_model_name_or_path,
use_auth_token=use_auth_token,
revision=revision,
)
if is_safetensors_compatible(info):
ignore_patterns.append("*.bin")
else:
# as a safety mechanism we also don't download safetensors if
# not all safetensors files are there
ignore_patterns.append("*.safetensors")
else:
ignore_patterns.append("*.safetensors")
# download all allow_patterns
cached_folder = snapshot_download(
pretrained_model_name_or_path,
......@@ -533,6 +631,16 @@ class DiffusionPipeline(ConfigMixin):
cached_folder = pretrained_model_name_or_path
config_dict = cls.load_config(cached_folder)
# retrieve which subfolders should load variants
model_variants = {}
if variant is not None:
for folder in os.listdir(cached_folder):
folder_path = os.path.join(cached_folder, folder)
is_folder = os.path.isdir(folder_path) and folder in config_dict
variant_exists = is_folder and any(path.split(".")[1] == variant for path in os.listdir(folder_path))
if variant_exists:
model_variants[folder] = variant
# 2. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it
if custom_pipeline is not None:
......@@ -717,10 +825,11 @@ class DiffusionPipeline(ConfigMixin):
loading_kwargs["sess_options"] = sess_options
is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin)
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
is_transformers_model = (
is_transformers_available()
and issubclass(class_obj, PreTrainedModel)
and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0")
and transformers_version >= version.parse("4.20.0")
)
# When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
......@@ -728,9 +837,23 @@ class DiffusionPipeline(ConfigMixin):
# This makes sure that the weights won't be initialized which significantly speeds up loading.
if is_diffusers_model or is_transformers_model:
loading_kwargs["device_map"] = device_map
loading_kwargs["variant"] = model_variants.pop(name, None)
if from_flax:
loading_kwargs["from_flax"] = True
# the following can be deleted once the minimum required `transformers` version
# is higher than 4.27
if (
is_transformers_model
and loading_kwargs["variant"] is not None
and transformers_version < version.parse("4.27.0")
):
raise ImportError(
f"When passing `variant='{variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0"
)
elif is_transformers_model and loading_kwargs["variant"] is None:
loading_kwargs.pop("variant")
# if `from_flax` and model is transformer model, can currently not load with `low_cpu_mem_usage`
if not (from_flax and is_transformers_model):
loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
......
......@@ -20,6 +20,7 @@ from packaging import version
from .. import __version__
from .constants import (
CONFIG_NAME,
DEPRECATED_REVISION_ARGS,
DIFFUSERS_CACHE,
DIFFUSERS_DYNAMIC_MODULE_NAME,
FLAX_WEIGHTS_NAME,
......
......@@ -30,3 +30,4 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
DIFFUSERS_CACHE = default_cache_path
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
......@@ -16,10 +16,12 @@
import inspect
import tempfile
import unittest
import unittest.mock as mock
from typing import Dict, List, Tuple
import numpy as np
import torch
from requests.exceptions import HTTPError
from diffusers.models import ModelMixin, UNet2DConditionModel
from diffusers.training_utils import EMAModel
......@@ -34,6 +36,30 @@ class ModelUtilsTest(unittest.TestCase):
# make sure that error message states what keys are missing
assert "conv_out.bias" in str(error_context.exception)
def test_cached_files_are_used_when_no_internet(self):
    """A cached UNet must be loadable with `local_files_only=True` while every HTTP request fails."""
    repo_id = "hf-internal-testing/tiny-stable-diffusion-torch"

    # Regular download first, so that the files are present in the cache.
    orig_model = UNet2DConditionModel.from_pretrained(repo_id, subfolder="unet")

    # Mocked HTTP response emulating a server that is down (status 500).
    response_mock = mock.Mock()
    response_mock.status_code = 500
    response_mock.headers = {}
    response_mock.raise_for_status.side_effect = HTTPError
    response_mock.json.return_value = {}

    # With `requests.request` patched, the second load can only be served from the cache.
    with mock.patch("requests.request", return_value=response_mock):
        model = UNet2DConditionModel.from_pretrained(repo_id, subfolder="unet", local_files_only=True)

    for p1, p2 in zip(orig_model.parameters(), model.parameters()):
        assert not (p1.data.ne(p2.data).sum() > 0), "Parameters not the same!"
class ModelTesterMixin:
def test_from_save_pretrained(self):
......@@ -66,6 +92,44 @@ class ModelTesterMixin:
max_diff = (image - new_image).abs().sum().item()
self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
def test_from_save_pretrained_variant(self):
    """Save a model with `variant="fp16"`, reload it with the same variant and compare forward passes."""
    init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

    model = self.model_class(**init_dict)
    model.to(torch_device)
    model.eval()

    with tempfile.TemporaryDirectory() as save_dir:
        model.save_pretrained(save_dir, variant="fp16")
        reloaded = self.model_class.from_pretrained(save_dir, variant="fp16")

        # Only the variant weights were written, so loading without `variant` has to fail.
        with self.assertRaises(OSError) as raised:
            self.model_class.from_pretrained(save_dir)

        # The error message must name the missing non-variant weights file.
        assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(raised.exception)

        reloaded.to(torch_device)

    with torch.no_grad():
        # Warmup pass when using mps (see #372)
        if torch_device == "mps" and isinstance(model, ModelMixin):
            _ = model(**self.dummy_input)
            _ = reloaded(**self.dummy_input)

        outputs = []
        for net in (model, reloaded):
            out = net(**inputs_dict)
            if isinstance(out, dict):
                out = out.sample
            outputs.append(out)

    image, new_image = outputs
    max_diff = (image - new_image).abs().sum().item()
    self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
def test_from_save_pretrained_dtype(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
......
......@@ -21,6 +21,7 @@ import shutil
import sys
import tempfile
import unittest
import unittest.mock as mock
import numpy as np
import PIL
......@@ -28,6 +29,7 @@ import safetensors.torch
import torch
from parameterized import parameterized
from PIL import Image
from requests.exceptions import HTTPError
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import (
......@@ -166,6 +168,155 @@ class DownloadTests(unittest.TestCase):
assert np.max(np.abs(out - out_2)) < 1e-3
def test_cached_files_are_used_when_no_internet(self):
    """A cached pipeline must be loadable with `local_files_only=True` while every HTTP request fails."""
    repo_id = "hf-internal-testing/tiny-stable-diffusion-torch"

    # Regular download first, so that the files are present in the cache.
    orig_pipe = StableDiffusionPipeline.from_pretrained(repo_id, safety_checker=None)
    orig_comps = {name: comp for name, comp in orig_pipe.components.items() if hasattr(comp, "parameters")}

    # Mocked HTTP response emulating a server that is down (status 500).
    response_mock = mock.Mock()
    response_mock.status_code = 500
    response_mock.headers = {}
    response_mock.raise_for_status.side_effect = HTTPError
    response_mock.json.return_value = {}

    # With `requests.request` patched, the second load can only be served from the cache.
    with mock.patch("requests.request", return_value=response_mock):
        pipe = StableDiffusionPipeline.from_pretrained(repo_id, safety_checker=None, local_files_only=True)
        comps = {name: comp for name, comp in pipe.components.items() if hasattr(comp, "parameters")}

    for m1, m2 in zip(orig_comps.values(), comps.values()):
        for p1, p2 in zip(m1.parameters(), m2.parameters()):
            assert not (p1.data.ne(p2.data).sum() > 0), "Parameters not the same!"
def test_download_from_variant_folder(self):
    """Without `variant`, only non-variant weights of the preferred format are downloaded."""
    import diffusers

    import_utils = diffusers.utils.import_utils
    previous_flag = import_utils._safetensors_available
    try:
        # Run once pretending safetensors is missing and once pretending it is installed.
        for safe_avail in [False, True]:
            import_utils._safetensors_available = safe_avail

            other_format = ".bin" if safe_avail else ".safetensors"
            with tempfile.TemporaryDirectory() as tmpdirname:
                StableDiffusionPipeline.from_pretrained(
                    "hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname
                )
                all_root_files = [
                    t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))
                ]
                files = [item for sublist in all_root_files for item in sublist]

                # None of the downloaded files should be a variant file even if we have some here:
                # https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
                assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
                assert not any(f.endswith(other_format) for f in files)
                # no variants
                assert not any(len(f.split(".")) == 3 for f in files)
    finally:
        # Restore the global flag even when an assertion fails, so other tests are not affected.
        import_utils._safetensors_available = previous_flag
def test_download_variant_all(self):
    """With `variant="fp16"`, every downloaded checkpoint is the fp16 file of the preferred format."""
    import diffusers

    import_utils = diffusers.utils.import_utils
    previous_flag = import_utils._safetensors_available
    variant = "fp16"
    try:
        # Run once pretending safetensors is missing and once pretending it is installed.
        for safe_avail in [False, True]:
            import_utils._safetensors_available = safe_avail

            other_format = ".bin" if safe_avail else ".safetensors"
            this_format = ".safetensors" if safe_avail else ".bin"

            with tempfile.TemporaryDirectory() as tmpdirname:
                StableDiffusionPipeline.from_pretrained(
                    "hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname, variant=variant
                )
                all_root_files = [
                    t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))
                ]
                files = [item for sublist in all_root_files for item in sublist]

                # None of the downloaded files should be a non-variant file even if we have some here:
                # https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
                assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
                # unet, vae, text_encoder, safety_checker
                assert len([f for f in files if f.endswith(f"{variant}{this_format}")]) == 4
                # all checkpoints should have variant ending
                assert not any(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files)
                assert not any(f.endswith(other_format) for f in files)
    finally:
        # Restore the global flag even when an assertion fails, so other tests are not affected.
        import_utils._safetensors_available = previous_flag
def test_download_variant_partly(self):
    """With `variant="no_ema"`, only the unet uses the variant file; the other models use plain weights."""
    import diffusers

    import_utils = diffusers.utils.import_utils
    previous_flag = import_utils._safetensors_available
    variant = "no_ema"
    try:
        # Run once pretending safetensors is missing and once pretending it is installed.
        for safe_avail in [False, True]:
            import_utils._safetensors_available = safe_avail

            other_format = ".bin" if safe_avail else ".safetensors"
            this_format = ".safetensors" if safe_avail else ".bin"

            with tempfile.TemporaryDirectory() as tmpdirname:
                StableDiffusionPipeline.from_pretrained(
                    "hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname, variant=variant
                )
                snapshots = os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots")
                all_root_files = [t[-1] for t in os.walk(snapshots)]
                files = [item for sublist in all_root_files for item in sublist]

                unet_files = os.listdir(os.path.join(snapshots, os.listdir(snapshots)[0], "unet"))

                # Some of the downloaded files should be a non-variant file, check:
                # https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
                assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
                # only unet has "no_ema" variant
                assert f"diffusion_pytorch_model.{variant}{this_format}" in unet_files
                assert len([f for f in files if f.endswith(f"{variant}{this_format}")]) == 1
                # vae, safety_checker and text_encoder should have no variant
                assert sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3
                assert not any(f.endswith(other_format) for f in files)
    finally:
        # Restore the global flag even when an assertion fails, so other tests are not affected.
        import_utils._safetensors_available = previous_flag
def test_download_broken_variant(self):
    """Loading fails when a required weight file is missing and succeeds for the complete fp16 variant."""
    import diffusers

    import_utils = diffusers.utils.import_utils
    previous_flag = import_utils._safetensors_available
    try:
        # Run once pretending safetensors is missing and once pretending it is installed.
        for safe_avail in [False, True]:
            import_utils._safetensors_available = safe_avail

            # text encoder is missing no variant and "no_ema" variant weights, so the following can't work
            for variant in [None, "no_ema"]:
                with self.assertRaises(OSError) as error_context:
                    with tempfile.TemporaryDirectory() as tmpdirname:
                        StableDiffusionPipeline.from_pretrained(
                            "hf-internal-testing/stable-diffusion-broken-variants",
                            cache_dir=tmpdirname,
                            variant=variant,
                        )

                assert "Error no file name" in str(error_context.exception)

            # text encoder has fp16 variants so we can load it
            with tempfile.TemporaryDirectory() as tmpdirname:
                pipe = StableDiffusionPipeline.from_pretrained(
                    "hf-internal-testing/stable-diffusion-broken-variants", cache_dir=tmpdirname, variant="fp16"
                )
                assert pipe is not None

                snapshots = os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots")
                all_root_files = [t[-1] for t in os.walk(snapshots)]
                files = [item for sublist in all_root_files for item in sublist]

                # None of the downloaded files should be a non-variant file even if we have some here:
                # https://huggingface.co/hf-internal-testing/stable-diffusion-broken-variants/tree/main/unet
                assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
    finally:
        # Restore the global flag even when an assertion fails, so other tests are not affected.
        import_utils._safetensors_available = previous_flag
class CustomPipelineTests(unittest.TestCase):
def test_load_custom_pipeline(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment