"vscode:/vscode.git/clone" did not exist on "5bcc463d05955ba7a11238450d039978a2d67387"
Unverified commit d761b58b, authored by Patrick von Platen, committed by GitHub
Browse files

[From pretrained] Speed-up loading from cache (#2515)



* [From pretrained] Speed-up loading from cache

* up

* Fix more

* fix one more bug

* make style

* bigger refactor

* factor out function

* Improve more

* better

* deprecate return cache folder

* clean up

* improve tests

* up

* upload

* add nice tests

* simplify

* finish

* correct

* fix version

* rename

* Apply suggestions from code review
Co-authored-by: Lucain <lucainp@gmail.com>

* rename

* correct doc string

* correct more

* Apply suggestions from code review
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* apply code suggestions

* finish

---------
Co-authored-by: Lucain <lucainp@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
parent 7fe638c5
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
import argparse import argparse
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import load_pipeline_from_original_stable_diffusion_ckpt from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
if __name__ == "__main__": if __name__ == "__main__":
...@@ -125,7 +125,7 @@ if __name__ == "__main__": ...@@ -125,7 +125,7 @@ if __name__ == "__main__":
) )
args = parser.parse_args() args = parser.parse_args()
pipe = load_pipeline_from_original_stable_diffusion_ckpt( pipe = download_from_original_stable_diffusion_ckpt(
checkpoint_path=args.checkpoint_path, checkpoint_path=args.checkpoint_path,
original_config_file=args.original_config_file, original_config_file=args.original_config_file,
image_size=args.image_size, image_size=args.image_size,
......
...@@ -86,7 +86,8 @@ _deps = [ ...@@ -86,7 +86,8 @@ _deps = [
"filelock", "filelock",
"flax>=0.4.1", "flax>=0.4.1",
"hf-doc-builder>=0.3.0", "hf-doc-builder>=0.3.0",
"huggingface-hub>=0.10.0", "huggingface-hub>=0.13.0",
"requests-mock==1.10.0",
"importlib_metadata", "importlib_metadata",
"isort>=5.5.4", "isort>=5.5.4",
"jax>=0.2.8,!=0.3.2", "jax>=0.2.8,!=0.3.2",
...@@ -192,6 +193,7 @@ extras["test"] = deps_list( ...@@ -192,6 +193,7 @@ extras["test"] = deps_list(
"pytest", "pytest",
"pytest-timeout", "pytest-timeout",
"pytest-xdist", "pytest-xdist",
"requests-mock",
"safetensors", "safetensors",
"sentencepiece", "sentencepiece",
"scipy", "scipy",
......
...@@ -31,7 +31,15 @@ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, R ...@@ -31,7 +31,15 @@ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, R
from requests import HTTPError from requests import HTTPError
from . import __version__ from . import __version__
from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, DummyObject, deprecate, logging from .utils import (
DIFFUSERS_CACHE,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
DummyObject,
deprecate,
extract_commit_hash,
http_user_agent,
logging,
)
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -231,7 +239,11 @@ class ConfigMixin: ...@@ -231,7 +239,11 @@ class ConfigMixin:
@classmethod @classmethod
def load_config( def load_config(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs cls,
pretrained_model_name_or_path: Union[str, os.PathLike],
return_unused_kwargs=False,
return_commit_hash=False,
**kwargs,
) -> Tuple[Dict[str, Any], Dict[str, Any]]: ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
r""" r"""
Instantiate a Python class from a config dictionary Instantiate a Python class from a config dictionary
...@@ -271,6 +283,10 @@ class ConfigMixin: ...@@ -271,6 +283,10 @@ class ConfigMixin:
subfolder (`str`, *optional*, defaults to `""`): subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo (either remote in In case the relevant files are located inside a subfolder of the model repo (either remote in
huggingface.co or downloaded locally), you can specify the folder name here. huggingface.co or downloaded locally), you can specify the folder name here.
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
Whether unused keyword arguments of the config shall be returned.
return_commit_hash (`bool`, *optional*, defaults to `False`):
Whether the commit_hash of the loaded configuration shall be returned.
<Tip> <Tip>
...@@ -295,8 +311,10 @@ class ConfigMixin: ...@@ -295,8 +311,10 @@ class ConfigMixin:
revision = kwargs.pop("revision", None) revision = kwargs.pop("revision", None)
_ = kwargs.pop("mirror", None) _ = kwargs.pop("mirror", None)
subfolder = kwargs.pop("subfolder", None) subfolder = kwargs.pop("subfolder", None)
user_agent = kwargs.pop("user_agent", {})
user_agent = {"file_type": "config"} user_agent = {**user_agent, "file_type": "config"}
user_agent = http_user_agent(user_agent)
pretrained_model_name_or_path = str(pretrained_model_name_or_path) pretrained_model_name_or_path = str(pretrained_model_name_or_path)
...@@ -336,7 +354,6 @@ class ConfigMixin: ...@@ -336,7 +354,6 @@ class ConfigMixin:
subfolder=subfolder, subfolder=subfolder,
revision=revision, revision=revision,
) )
except RepositoryNotFoundError: except RepositoryNotFoundError:
raise EnvironmentError( raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier" f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
...@@ -378,13 +395,23 @@ class ConfigMixin: ...@@ -378,13 +395,23 @@ class ConfigMixin:
try: try:
# Load config dict # Load config dict
config_dict = cls._dict_from_json_file(config_file) config_dict = cls._dict_from_json_file(config_file)
commit_hash = extract_commit_hash(config_file)
except (json.JSONDecodeError, UnicodeDecodeError): except (json.JSONDecodeError, UnicodeDecodeError):
raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.") raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
if not (return_unused_kwargs or return_commit_hash):
return config_dict
outputs = (config_dict,)
if return_unused_kwargs: if return_unused_kwargs:
return config_dict, kwargs outputs += (kwargs,)
return config_dict if return_commit_hash:
outputs += (commit_hash,)
return outputs
@staticmethod @staticmethod
def _get_init_keys(cls): def _get_init_keys(cls):
......
...@@ -10,7 +10,8 @@ deps = { ...@@ -10,7 +10,8 @@ deps = {
"filelock": "filelock", "filelock": "filelock",
"flax": "flax>=0.4.1", "flax": "flax>=0.4.1",
"hf-doc-builder": "hf-doc-builder>=0.3.0", "hf-doc-builder": "hf-doc-builder>=0.3.0",
"huggingface-hub": "huggingface-hub>=0.10.0", "huggingface-hub": "huggingface-hub>=0.13.0",
"requests-mock": "requests-mock==1.10.0",
"importlib_metadata": "importlib_metadata", "importlib_metadata": "importlib_metadata",
"isort": "isort>=5.5.4", "isort": "isort>=5.5.4",
"jax": "jax>=0.2.8,!=0.3.2", "jax": "jax>=0.2.8,!=0.3.2",
......
...@@ -458,24 +458,21 @@ class ModelMixin(torch.nn.Module): ...@@ -458,24 +458,21 @@ class ModelMixin(torch.nn.Module):
" dispatching. Please make sure to set `low_cpu_mem_usage=True`." " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
) )
# Load config if we don't provide a configuration
config_path = pretrained_model_name_or_path
user_agent = { user_agent = {
"diffusers": __version__, "diffusers": __version__,
"file_type": "model", "file_type": "model",
"framework": "pytorch", "framework": "pytorch",
} }
# Load config if we don't provide a configuration # load config
config_path = pretrained_model_name_or_path config, unused_kwargs, commit_hash = cls.load_config(
config_path,
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# Load model
model_file = None
if from_flax:
model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=FLAX_WEIGHTS_NAME,
cache_dir=cache_dir, cache_dir=cache_dir,
return_unused_kwargs=True,
return_commit_hash=True,
force_download=force_download, force_download=force_download,
resume_download=resume_download, resume_download=resume_download,
proxies=proxies, proxies=proxies,
...@@ -483,12 +480,18 @@ class ModelMixin(torch.nn.Module): ...@@ -483,12 +480,18 @@ class ModelMixin(torch.nn.Module):
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
revision=revision, revision=revision,
subfolder=subfolder, subfolder=subfolder,
device_map=device_map,
user_agent=user_agent, user_agent=user_agent,
**kwargs,
) )
config, unused_kwargs = cls.load_config(
config_path, # load model
model_file = None
if from_flax:
model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=FLAX_WEIGHTS_NAME,
cache_dir=cache_dir, cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download, force_download=force_download,
resume_download=resume_download, resume_download=resume_download,
proxies=proxies, proxies=proxies,
...@@ -496,8 +499,8 @@ class ModelMixin(torch.nn.Module): ...@@ -496,8 +499,8 @@ class ModelMixin(torch.nn.Module):
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
revision=revision, revision=revision,
subfolder=subfolder, subfolder=subfolder,
device_map=device_map, user_agent=user_agent,
**kwargs, commit_hash=commit_hash,
) )
model = cls.from_config(config, **unused_kwargs) model = cls.from_config(config, **unused_kwargs)
...@@ -520,6 +523,7 @@ class ModelMixin(torch.nn.Module): ...@@ -520,6 +523,7 @@ class ModelMixin(torch.nn.Module):
revision=revision, revision=revision,
subfolder=subfolder, subfolder=subfolder,
user_agent=user_agent, user_agent=user_agent,
commit_hash=commit_hash,
) )
except: # noqa: E722 except: # noqa: E722
pass pass
...@@ -536,25 +540,12 @@ class ModelMixin(torch.nn.Module): ...@@ -536,25 +540,12 @@ class ModelMixin(torch.nn.Module):
revision=revision, revision=revision,
subfolder=subfolder, subfolder=subfolder,
user_agent=user_agent, user_agent=user_agent,
commit_hash=commit_hash,
) )
if low_cpu_mem_usage: if low_cpu_mem_usage:
# Instantiate model with empty weights # Instantiate model with empty weights
with accelerate.init_empty_weights(): with accelerate.init_empty_weights():
config, unused_kwargs = cls.load_config(
config_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
device_map=device_map,
**kwargs,
)
model = cls.from_config(config, **unused_kwargs) model = cls.from_config(config, **unused_kwargs)
# if device_map is None, load the state dict and move the params from meta device to the cpu # if device_map is None, load the state dict and move the params from meta device to the cpu
...@@ -593,20 +584,6 @@ class ModelMixin(torch.nn.Module): ...@@ -593,20 +584,6 @@ class ModelMixin(torch.nn.Module):
"error_msgs": [], "error_msgs": [],
} }
else: else:
config, unused_kwargs = cls.load_config(
config_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
device_map=device_map,
**kwargs,
)
model = cls.from_config(config, **unused_kwargs) model = cls.from_config(config, **unused_kwargs)
state_dict = load_state_dict(model_file, variant=variant) state_dict = load_state_dict(model_file, variant=variant)
...@@ -803,6 +780,7 @@ def _get_model_file( ...@@ -803,6 +780,7 @@ def _get_model_file(
use_auth_token, use_auth_token,
user_agent, user_agent,
revision, revision,
commit_hash=None,
): ):
pretrained_model_name_or_path = str(pretrained_model_name_or_path) pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isfile(pretrained_model_name_or_path): if os.path.isfile(pretrained_model_name_or_path):
...@@ -840,7 +818,7 @@ def _get_model_file( ...@@ -840,7 +818,7 @@ def _get_model_file(
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
user_agent=user_agent, user_agent=user_agent,
subfolder=subfolder, subfolder=subfolder,
revision=revision, revision=revision or commit_hash,
) )
warnings.warn( warnings.warn(
f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.", f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
...@@ -865,7 +843,7 @@ def _get_model_file( ...@@ -865,7 +843,7 @@ def _get_model_file(
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
user_agent=user_agent, user_agent=user_agent,
subfolder=subfolder, subfolder=subfolder,
revision=revision, revision=revision or commit_hash,
) )
return model_file return model_file
......
This diff is collapsed.
...@@ -954,7 +954,7 @@ def stable_unclip_image_noising_components( ...@@ -954,7 +954,7 @@ def stable_unclip_image_noising_components(
return image_normalizer, image_noising_scheduler return image_normalizer, image_noising_scheduler
def load_pipeline_from_original_stable_diffusion_ckpt( def download_from_original_stable_diffusion_ckpt(
checkpoint_path: str, checkpoint_path: str,
original_config_file: str = None, original_config_file: str = None,
image_size: int = 512, image_size: int = 512,
......
...@@ -136,10 +136,11 @@ class SchedulerMixin: ...@@ -136,10 +136,11 @@ class SchedulerMixin:
</Tip> </Tip>
""" """
config, kwargs = cls.load_config( config, kwargs, commit_hash = cls.load_config(
pretrained_model_name_or_path=pretrained_model_name_or_path, pretrained_model_name_or_path=pretrained_model_name_or_path,
subfolder=subfolder, subfolder=subfolder,
return_unused_kwargs=True, return_unused_kwargs=True,
return_commit_hash=True,
**kwargs, **kwargs,
) )
return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs) return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)
......
...@@ -35,7 +35,11 @@ from .constants import ( ...@@ -35,7 +35,11 @@ from .constants import (
from .deprecation_utils import deprecate from .deprecation_utils import deprecate
from .doc_utils import replace_example_docstring from .doc_utils import replace_example_docstring
from .dynamic_modules_utils import get_class_from_dynamic_module from .dynamic_modules_utils import get_class_from_dynamic_module
from .hub_utils import HF_HUB_OFFLINE, http_user_agent from .hub_utils import (
HF_HUB_OFFLINE,
extract_commit_hash,
http_user_agent,
)
from .import_utils import ( from .import_utils import (
ENV_VARS_TRUE_AND_AUTO_VALUES, ENV_VARS_TRUE_AND_AUTO_VALUES,
ENV_VARS_TRUE_VALUES, ENV_VARS_TRUE_VALUES,
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import os import os
import re
import sys import sys
import traceback import traceback
from pathlib import Path from pathlib import Path
...@@ -22,6 +23,7 @@ from typing import Dict, Optional, Union ...@@ -22,6 +23,7 @@ from typing import Dict, Optional, Union
from uuid import uuid4 from uuid import uuid4
from huggingface_hub import HfFolder, ModelCard, ModelCardData, whoami from huggingface_hub import HfFolder, ModelCard, ModelCardData, whoami
from huggingface_hub.file_download import REGEX_COMMIT_HASH
from huggingface_hub.utils import is_jinja_available from huggingface_hub.utils import is_jinja_available
from .. import __version__ from .. import __version__
...@@ -132,6 +134,20 @@ def create_model_card(args, model_name): ...@@ -132,6 +134,20 @@ def create_model_card(args, model_name):
model_card.save(card_path) model_card.save(card_path)
def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str] = None) -> Optional[str]:
    """
    Extract the commit hash from the resolved path of a cached file.

    Args:
        resolved_file (`str`, *optional*):
            Path of a resolved file. The hash is read from the path component that directly follows a
            `snapshots/` component.
        commit_hash (`str`, *optional*):
            An already known commit hash. When it is not `None` it is returned unchanged and `resolved_file` is
            not inspected.

    Returns:
        `str` or `None`: `commit_hash` if it was passed (or if `resolved_file` is `None`); otherwise the component
        found after `snapshots/` provided it matches `huggingface_hub`'s `REGEX_COMMIT_HASH`; `None` in every
        other case.
    """
    # Nothing to parse, or the caller already knows the hash: hand back what we were given.
    if resolved_file is None or commit_hash is not None:
        return commit_hash

    # `as_posix()` renders the path with forward slashes, which is what the pattern below expects.
    resolved_file = str(Path(resolved_file).as_posix())
    search = re.search(r"snapshots/([^/]+)/", resolved_file)
    if search is None:
        return None

    commit_hash = search.group(1)
    # Only report the component when it has the shape of a commit hash.
    return commit_hash if REGEX_COMMIT_HASH.match(commit_hash) else None
# Old default cache path, potentially to be migrated. # Old default cache path, potentially to be migrated.
# This logic was more or less taken from `transformers`, with the following differences: # This logic was more or less taken from `transformers`, with the following differences:
# - Diffusers doesn't use custom environment variables to specify the cache path. # - Diffusers doesn't use custom environment variables to specify the cache path.
...@@ -150,7 +166,7 @@ def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str] ...@@ -150,7 +166,7 @@ def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str]
old_cache_dir = Path(old_cache_dir).expanduser() old_cache_dir = Path(old_cache_dir).expanduser()
new_cache_dir = Path(new_cache_dir).expanduser() new_cache_dir = Path(new_cache_dir).expanduser()
for old_blob_path in old_cache_dir.glob("**/blobs/*"): # move file blob by blob for old_blob_path in old_cache_dir.glob("**/blobs/*"):
if old_blob_path.is_file() and not old_blob_path.is_symlink(): if old_blob_path.is_file() and not old_blob_path.is_symlink():
new_blob_path = new_cache_dir / old_blob_path.relative_to(old_cache_dir) new_blob_path = new_cache_dir / old_blob_path.relative_to(old_cache_dir)
new_blob_path.parent.mkdir(parents=True, exist_ok=True) new_blob_path.parent.mkdir(parents=True, exist_ok=True)
......
...@@ -20,6 +20,7 @@ import unittest.mock as mock ...@@ -20,6 +20,7 @@ import unittest.mock as mock
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
import numpy as np import numpy as np
import requests_mock
import torch import torch
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
...@@ -29,6 +30,13 @@ from diffusers.utils import torch_device ...@@ -29,6 +30,13 @@ from diffusers.utils import torch_device
class ModelUtilsTest(unittest.TestCase): class ModelUtilsTest(unittest.TestCase):
def tearDown(self):
    """Reset the safetensors availability flag after every test of this class."""
    super().tearDown()

    import diffusers

    # `test_one_request_upon_cached` sets this module-level flag to False; putting it back here
    # keeps the override from leaking into later tests when that test fails midway.
    diffusers.utils.import_utils._safetensors_available = True
def test_accelerate_loading_error_message(self): def test_accelerate_loading_error_message(self):
with self.assertRaises(ValueError) as error_context: with self.assertRaises(ValueError) as error_context:
UNet2DConditionModel.from_pretrained("hf-internal-testing/stable-diffusion-broken", subfolder="unet") UNet2DConditionModel.from_pretrained("hf-internal-testing/stable-diffusion-broken", subfolder="unet")
...@@ -60,6 +68,37 @@ class ModelUtilsTest(unittest.TestCase): ...@@ -60,6 +68,37 @@ class ModelUtilsTest(unittest.TestCase):
if p1.data.ne(p2.data).sum() > 0: if p1.data.ne(p2.data).sum() > 0:
assert False, "Parameters not the same!" assert False, "Parameters not the same!"
def test_one_request_upon_cached(self):
    """Loading a model a second time from the same cache dir must issue exactly one HEAD request."""
    # TODO: For some reason this test fails on MPS where no HEAD call is made.
    if torch_device == "mps":
        # Report the test as skipped instead of silently passing without running it.
        self.skipTest("Request counting does not work on MPS (no HEAD call is made).")

    import diffusers

    # Pretend safetensors is not installed for the duration of the test
    # (presumably so the request counts below are not affected by safetensors files - confirm).
    diffusers.utils.import_utils._safetensors_available = False
    try:
        with tempfile.TemporaryDirectory() as tmpdirname:
            # First load: nothing is cached yet, so config and weights are both downloaded.
            with requests_mock.mock(real_http=True) as m:
                UNet2DConditionModel.from_pretrained(
                    "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet", cache_dir=tmpdirname
                )

            download_requests = [r.method for r in m.request_history]
            assert download_requests.count("HEAD") == 2, "2 HEAD requests one for config, one for model"
            assert download_requests.count("GET") == 2, "2 GET requests one for config, one for model"

            # Second load: same cache dir, so only a single HEAD request is expected.
            with requests_mock.mock(real_http=True) as m:
                UNet2DConditionModel.from_pretrained(
                    "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet", cache_dir=tmpdirname
                )

            cache_requests = [r.method for r in m.request_history]
            assert (
                "HEAD" == cache_requests[0] and len(cache_requests) == 1
            ), "We should call only `model_info` to check for _commit hash and `send_telemetry`"
    finally:
        # Always restore the flag, even when one of the assertions above fails.
        diffusers.utils.import_utils._safetensors_available = True
class ModelTesterMixin: class ModelTesterMixin:
def test_from_save_pretrained(self): def test_from_save_pretrained(self):
......
...@@ -25,6 +25,7 @@ import unittest.mock as mock ...@@ -25,6 +25,7 @@ import unittest.mock as mock
import numpy as np import numpy as np
import PIL import PIL
import requests_mock
import safetensors.torch import safetensors.torch
import torch import torch
from parameterized import parameterized from parameterized import parameterized
...@@ -61,14 +62,44 @@ torch.backends.cuda.matmul.allow_tf32 = False ...@@ -61,14 +62,44 @@ torch.backends.cuda.matmul.allow_tf32 = False
class DownloadTests(unittest.TestCase): class DownloadTests(unittest.TestCase):
def test_one_request_upon_cached(self):
    """Downloading a pipeline a second time into the same cache dir must issue only two HTTP requests."""
    # TODO: For some reason this test fails on MPS where no HEAD call is made.
    if torch_device == "mps":
        # Report the test as skipped instead of silently passing without running it.
        self.skipTest("Request counting does not work on MPS (no HEAD call is made).")

    with tempfile.TemporaryDirectory() as tmpdirname:
        # First download: the cache dir is empty, so every file has to be fetched.
        with requests_mock.mock(real_http=True) as m:
            DiffusionPipeline.download(
                "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
            )

        download_requests = [r.method for r in m.request_history]
        assert download_requests.count("HEAD") == 16, "15 calls to files + send_telemetry"
        assert download_requests.count("GET") == 17, "15 calls to files + model_info + model_index.json"
        assert (
            len(download_requests) == 33
        ), "2 calls per file (15 files) + send_telemetry, model_info and model_index.json"

        # Second download: same cache dir, so only one HEAD and one GET request are expected.
        with requests_mock.mock(real_http=True) as m:
            DiffusionPipeline.download(
                "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
            )

        cache_requests = [r.method for r in m.request_history]
        assert cache_requests.count("HEAD") == 1, "send_telemetry is only HEAD"
        assert cache_requests.count("GET") == 1, "model info is only GET"
        assert (
            len(cache_requests) == 2
        ), "We should call only `model_info` to check for _commit hash and `send_telemetry`"
def test_download_only_pytorch(self): def test_download_only_pytorch(self):
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
# pipeline has Flax weights # pipeline has Flax weights
_ = DiffusionPipeline.from_pretrained( tmpdirname = DiffusionPipeline.download(
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
) )
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))] all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
files = [item for sublist in all_root_files for item in sublist] files = [item for sublist in all_root_files for item in sublist]
# None of the downloaded files should be a flax file even if we have some here: # None of the downloaded files should be a flax file even if we have some here:
...@@ -101,13 +132,13 @@ class DownloadTests(unittest.TestCase): ...@@ -101,13 +132,13 @@ class DownloadTests(unittest.TestCase):
def test_download_safetensors(self): def test_download_safetensors(self):
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
# pipeline has Flax weights # pipeline has Flax weights
_ = DiffusionPipeline.from_pretrained( tmpdirname = DiffusionPipeline.download(
"hf-internal-testing/tiny-stable-diffusion-pipe-safetensors", "hf-internal-testing/tiny-stable-diffusion-pipe-safetensors",
safety_checker=None, safety_checker=None,
cache_dir=tmpdirname, cache_dir=tmpdirname,
) )
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))] all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
files = [item for sublist in all_root_files for item in sublist] files = [item for sublist in all_root_files for item in sublist]
# None of the downloaded files should be a pytorch file even if we have some here: # None of the downloaded files should be a pytorch file even if we have some here:
...@@ -204,12 +235,10 @@ class DownloadTests(unittest.TestCase): ...@@ -204,12 +235,10 @@ class DownloadTests(unittest.TestCase):
other_format = ".bin" if safe_avail else ".safetensors" other_format = ".bin" if safe_avail else ".safetensors"
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionPipeline.from_pretrained( tmpdirname = StableDiffusionPipeline.download(
"hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname "hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname
) )
all_root_files = [ all_root_files = [t[-1] for t in os.walk(tmpdirname)]
t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))
]
files = [item for sublist in all_root_files for item in sublist] files = [item for sublist in all_root_files for item in sublist]
# None of the downloaded files should be a variant file even if we have some here: # None of the downloaded files should be a variant file even if we have some here:
...@@ -232,12 +261,10 @@ class DownloadTests(unittest.TestCase): ...@@ -232,12 +261,10 @@ class DownloadTests(unittest.TestCase):
variant = "fp16" variant = "fp16"
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionPipeline.from_pretrained( tmpdirname = StableDiffusionPipeline.download(
"hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname, variant=variant "hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname, variant=variant
) )
all_root_files = [ all_root_files = [t[-1] for t in os.walk(tmpdirname)]
t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))
]
files = [item for sublist in all_root_files for item in sublist] files = [item for sublist in all_root_files for item in sublist]
# None of the downloaded files should be a non-variant file even if we have some here: # None of the downloaded files should be a non-variant file even if we have some here:
...@@ -262,14 +289,13 @@ class DownloadTests(unittest.TestCase): ...@@ -262,14 +289,13 @@ class DownloadTests(unittest.TestCase):
variant = "no_ema" variant = "no_ema"
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionPipeline.from_pretrained( tmpdirname = StableDiffusionPipeline.download(
"hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname, variant=variant "hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname, variant=variant
) )
snapshots = os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots") all_root_files = [t[-1] for t in os.walk(tmpdirname)]
all_root_files = [t[-1] for t in os.walk(snapshots)]
files = [item for sublist in all_root_files for item in sublist] files = [item for sublist in all_root_files for item in sublist]
unet_files = os.listdir(os.path.join(snapshots, os.listdir(snapshots)[0], "unet")) unet_files = os.listdir(os.path.join(tmpdirname, "unet"))
# Some of the downloaded files should be a non-variant file, check: # Some of the downloaded files should be a non-variant file, check:
# https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet # https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
...@@ -292,7 +318,7 @@ class DownloadTests(unittest.TestCase): ...@@ -292,7 +318,7 @@ class DownloadTests(unittest.TestCase):
for variant in [None, "no_ema"]: for variant in [None, "no_ema"]:
with self.assertRaises(OSError) as error_context: with self.assertRaises(OSError) as error_context:
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionPipeline.from_pretrained( tmpdirname = StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/stable-diffusion-broken-variants", "hf-internal-testing/stable-diffusion-broken-variants",
cache_dir=tmpdirname, cache_dir=tmpdirname,
variant=variant, variant=variant,
...@@ -302,13 +328,11 @@ class DownloadTests(unittest.TestCase): ...@@ -302,13 +328,11 @@ class DownloadTests(unittest.TestCase):
# text encoder has fp16 variants so we can load it # text encoder has fp16 variants so we can load it
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
pipe = StableDiffusionPipeline.from_pretrained( tmpdirname = StableDiffusionPipeline.download(
"hf-internal-testing/stable-diffusion-broken-variants", cache_dir=tmpdirname, variant="fp16" "hf-internal-testing/stable-diffusion-broken-variants", cache_dir=tmpdirname, variant="fp16"
) )
assert pipe is not None
snapshots = os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots") all_root_files = [t[-1] for t in os.walk(tmpdirname)]
all_root_files = [t[-1] for t in os.walk(snapshots)]
files = [item for sublist in all_root_files for item in sublist] files = [item for sublist in all_root_files for item in sublist]
# None of the downloaded files should be a non-variant file even if we have some here: # None of the downloaded files should be a non-variant file even if we have some here:
...@@ -395,7 +419,7 @@ class CustomPipelineTests(unittest.TestCase): ...@@ -395,7 +419,7 @@ class CustomPipelineTests(unittest.TestCase):
@slow @slow
@require_torch_gpu @require_torch_gpu
def test_load_pipeline_from_git(self): def test_download_from_git(self):
clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K" clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id) feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment