Unverified commit d761b58b, authored by Patrick von Platen and committed by GitHub
Browse files

[From pretrained] Speed-up loading from cache (#2515)



* [From pretrained] Speed-up loading from cache

* up

* Fix more

* fix one more bug

* make style

* bigger refactor

* factor out function

* Improve more

* better

* deprecate return cache folder

* clean up

* improve tests

* up

* upload

* add nice tests

* simplify

* finish

* correct

* fix version

* rename

* Apply suggestions from code review
Co-authored-by: Lucain <lucainp@gmail.com>

* rename

* correct doc string

* correct more

* Apply suggestions from code review
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* apply code suggestions

* finish

---------
Co-authored-by: Lucain <lucainp@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
parent 7fe638c5
......@@ -16,7 +16,7 @@
import argparse
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import load_pipeline_from_original_stable_diffusion_ckpt
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
if __name__ == "__main__":
......@@ -125,7 +125,7 @@ if __name__ == "__main__":
)
args = parser.parse_args()
pipe = load_pipeline_from_original_stable_diffusion_ckpt(
pipe = download_from_original_stable_diffusion_ckpt(
checkpoint_path=args.checkpoint_path,
original_config_file=args.original_config_file,
image_size=args.image_size,
......
......@@ -86,7 +86,8 @@ _deps = [
"filelock",
"flax>=0.4.1",
"hf-doc-builder>=0.3.0",
"huggingface-hub>=0.10.0",
"huggingface-hub>=0.13.0",
"requests-mock==1.10.0",
"importlib_metadata",
"isort>=5.5.4",
"jax>=0.2.8,!=0.3.2",
......@@ -192,6 +193,7 @@ extras["test"] = deps_list(
"pytest",
"pytest-timeout",
"pytest-xdist",
"requests-mock",
"safetensors",
"sentencepiece",
"scipy",
......
......@@ -31,7 +31,15 @@ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, R
from requests import HTTPError
from . import __version__
from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, DummyObject, deprecate, logging
from .utils import (
DIFFUSERS_CACHE,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
DummyObject,
deprecate,
extract_commit_hash,
http_user_agent,
logging,
)
logger = logging.get_logger(__name__)
......@@ -231,7 +239,11 @@ class ConfigMixin:
@classmethod
def load_config(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs
cls,
pretrained_model_name_or_path: Union[str, os.PathLike],
return_unused_kwargs=False,
return_commit_hash=False,
**kwargs,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
r"""
Instantiate a Python class from a config dictionary
......@@ -271,6 +283,10 @@ class ConfigMixin:
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo (either remote in
huggingface.co or downloaded locally), you can specify the folder name here.
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
Whether unused keyword arguments of the config shall be returned.
return_commit_hash (`bool`, *optional*, defaults to `False`):
Whether the commit_hash of the loaded configuration shall be returned.
<Tip>
......@@ -295,8 +311,10 @@ class ConfigMixin:
revision = kwargs.pop("revision", None)
_ = kwargs.pop("mirror", None)
subfolder = kwargs.pop("subfolder", None)
user_agent = kwargs.pop("user_agent", {})
user_agent = {"file_type": "config"}
user_agent = {**user_agent, "file_type": "config"}
user_agent = http_user_agent(user_agent)
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
......@@ -336,7 +354,6 @@ class ConfigMixin:
subfolder=subfolder,
revision=revision,
)
except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
......@@ -378,13 +395,23 @@ class ConfigMixin:
try:
# Load config dict
config_dict = cls._dict_from_json_file(config_file)
commit_hash = extract_commit_hash(config_file)
except (json.JSONDecodeError, UnicodeDecodeError):
raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
if not (return_unused_kwargs or return_commit_hash):
return config_dict
outputs = (config_dict,)
if return_unused_kwargs:
return config_dict, kwargs
outputs += (kwargs,)
if return_commit_hash:
outputs += (commit_hash,)
return config_dict
return outputs
@staticmethod
def _get_init_keys(cls):
......
......@@ -10,7 +10,8 @@ deps = {
"filelock": "filelock",
"flax": "flax>=0.4.1",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
"huggingface-hub": "huggingface-hub>=0.10.0",
"huggingface-hub": "huggingface-hub>=0.13.0",
"requests-mock": "requests-mock==1.10.0",
"importlib_metadata": "importlib_metadata",
"isort": "isort>=5.5.4",
"jax": "jax>=0.2.8,!=0.3.2",
......
......@@ -458,18 +458,34 @@ class ModelMixin(torch.nn.Module):
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
)
# Load config if we don't provide a configuration
config_path = pretrained_model_name_or_path
user_agent = {
"diffusers": __version__,
"file_type": "model",
"framework": "pytorch",
}
# Load config if we don't provide a configuration
config_path = pretrained_model_name_or_path
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# Load model
# load config
config, unused_kwargs, commit_hash = cls.load_config(
config_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
return_commit_hash=True,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
device_map=device_map,
user_agent=user_agent,
**kwargs,
)
# load model
model_file = None
if from_flax:
model_file = _get_model_file(
......@@ -484,20 +500,7 @@ class ModelMixin(torch.nn.Module):
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
config, unused_kwargs = cls.load_config(
config_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
device_map=device_map,
**kwargs,
commit_hash=commit_hash,
)
model = cls.from_config(config, **unused_kwargs)
......@@ -520,6 +523,7 @@ class ModelMixin(torch.nn.Module):
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
commit_hash=commit_hash,
)
except: # noqa: E722
pass
......@@ -536,25 +540,12 @@ class ModelMixin(torch.nn.Module):
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
commit_hash=commit_hash,
)
if low_cpu_mem_usage:
# Instantiate model with empty weights
with accelerate.init_empty_weights():
config, unused_kwargs = cls.load_config(
config_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
device_map=device_map,
**kwargs,
)
model = cls.from_config(config, **unused_kwargs)
# if device_map is None, load the state dict and move the params from meta device to the cpu
......@@ -593,20 +584,6 @@ class ModelMixin(torch.nn.Module):
"error_msgs": [],
}
else:
config, unused_kwargs = cls.load_config(
config_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
device_map=device_map,
**kwargs,
)
model = cls.from_config(config, **unused_kwargs)
state_dict = load_state_dict(model_file, variant=variant)
......@@ -803,6 +780,7 @@ def _get_model_file(
use_auth_token,
user_agent,
revision,
commit_hash=None,
):
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isfile(pretrained_model_name_or_path):
......@@ -840,7 +818,7 @@ def _get_model_file(
use_auth_token=use_auth_token,
user_agent=user_agent,
subfolder=subfolder,
revision=revision,
revision=revision or commit_hash,
)
warnings.warn(
f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
......@@ -865,7 +843,7 @@ def _get_model_file(
use_auth_token=use_auth_token,
user_agent=user_agent,
subfolder=subfolder,
revision=revision,
revision=revision or commit_hash,
)
return model_file
......
This diff is collapsed.
......@@ -954,7 +954,7 @@ def stable_unclip_image_noising_components(
return image_normalizer, image_noising_scheduler
def load_pipeline_from_original_stable_diffusion_ckpt(
def download_from_original_stable_diffusion_ckpt(
checkpoint_path: str,
original_config_file: str = None,
image_size: int = 512,
......
......@@ -136,10 +136,11 @@ class SchedulerMixin:
</Tip>
"""
config, kwargs = cls.load_config(
config, kwargs, commit_hash = cls.load_config(
pretrained_model_name_or_path=pretrained_model_name_or_path,
subfolder=subfolder,
return_unused_kwargs=True,
return_commit_hash=True,
**kwargs,
)
return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)
......
......@@ -35,7 +35,11 @@ from .constants import (
from .deprecation_utils import deprecate
from .doc_utils import replace_example_docstring
from .dynamic_modules_utils import get_class_from_dynamic_module
from .hub_utils import HF_HUB_OFFLINE, http_user_agent
from .hub_utils import (
HF_HUB_OFFLINE,
extract_commit_hash,
http_user_agent,
)
from .import_utils import (
ENV_VARS_TRUE_AND_AUTO_VALUES,
ENV_VARS_TRUE_VALUES,
......
......@@ -15,6 +15,7 @@
import os
import re
import sys
import traceback
from pathlib import Path
......@@ -22,6 +23,7 @@ from typing import Dict, Optional, Union
from uuid import uuid4
from huggingface_hub import HfFolder, ModelCard, ModelCardData, whoami
from huggingface_hub.file_download import REGEX_COMMIT_HASH
from huggingface_hub.utils import is_jinja_available
from .. import __version__
......@@ -132,6 +134,20 @@ def create_model_card(args, model_name):
model_card.save(card_path)
def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str] = None):
    """
    Return the commit hash embedded in the path of a resolved cache file.

    Args:
        resolved_file (`str`, *optional*):
            Path of a resolved file. It is searched for a `snapshots/<name>/` segment after being normalised
            to forward slashes.
        commit_hash (`str`, *optional*):
            An already known commit hash. When given, it is returned unchanged and `resolved_file` is not
            inspected.

    Returns:
        `str` or `None`: `commit_hash` if it was provided; otherwise the directory name that follows
        `snapshots/` in `resolved_file`, provided it matches `REGEX_COMMIT_HASH`; otherwise `None`.
    """
    # An explicitly supplied hash always wins over anything parsed from the path.
    if commit_hash is not None:
        return commit_hash
    if resolved_file is None:
        return None

    # Normalise separators so the same pattern works for Windows-style paths.
    posix_path = Path(resolved_file).as_posix()
    match = re.search(r"snapshots/([^/]+)/", posix_path)
    if match is None:
        return None

    # Only accept the captured directory name if it has the shape of a commit hash.
    candidate = match.group(1)
    if REGEX_COMMIT_HASH.match(candidate):
        return candidate
    return None
# Old default cache path, potentially to be migrated.
# This logic was more or less taken from `transformers`, with the following differences:
# - Diffusers doesn't use custom environment variables to specify the cache path.
......@@ -150,7 +166,7 @@ def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str]
old_cache_dir = Path(old_cache_dir).expanduser()
new_cache_dir = Path(new_cache_dir).expanduser()
for old_blob_path in old_cache_dir.glob("**/blobs/*"): # move file blob by blob
for old_blob_path in old_cache_dir.glob("**/blobs/*"):
if old_blob_path.is_file() and not old_blob_path.is_symlink():
new_blob_path = new_cache_dir / old_blob_path.relative_to(old_cache_dir)
new_blob_path.parent.mkdir(parents=True, exist_ok=True)
......
......@@ -20,6 +20,7 @@ import unittest.mock as mock
from typing import Dict, List, Tuple
import numpy as np
import requests_mock
import torch
from requests.exceptions import HTTPError
......@@ -29,6 +30,13 @@ from diffusers.utils import torch_device
class ModelUtilsTest(unittest.TestCase):
def tearDown(self):
super().tearDown()
import diffusers
diffusers.utils.import_utils._safetensors_available = True
def test_accelerate_loading_error_message(self):
with self.assertRaises(ValueError) as error_context:
UNet2DConditionModel.from_pretrained("hf-internal-testing/stable-diffusion-broken", subfolder="unet")
......@@ -60,6 +68,37 @@ class ModelUtilsTest(unittest.TestCase):
if p1.data.ne(p2.data).sum() > 0:
assert False, "Parameters not the same!"
def test_one_request_upon_cached(self):
# TODO: For some reason this test fails on MPS where no HEAD call is made.
if torch_device == "mps":
return
import diffusers
diffusers.utils.import_utils._safetensors_available = False
with tempfile.TemporaryDirectory() as tmpdirname:
with requests_mock.mock(real_http=True) as m:
UNet2DConditionModel.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet", cache_dir=tmpdirname
)
download_requests = [r.method for r in m.request_history]
assert download_requests.count("HEAD") == 2, "2 HEAD requests one for config, one for model"
assert download_requests.count("GET") == 2, "2 GET requests one for config, one for model"
with requests_mock.mock(real_http=True) as m:
UNet2DConditionModel.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet", cache_dir=tmpdirname
)
cache_requests = [r.method for r in m.request_history]
assert (
"HEAD" == cache_requests[0] and len(cache_requests) == 1
), "We should call only `model_info` to check for _commit hash and `send_telemetry`"
diffusers.utils.import_utils._safetensors_available = True
class ModelTesterMixin:
def test_from_save_pretrained(self):
......
......@@ -25,6 +25,7 @@ import unittest.mock as mock
import numpy as np
import PIL
import requests_mock
import safetensors.torch
import torch
from parameterized import parameterized
......@@ -61,14 +62,44 @@ torch.backends.cuda.matmul.allow_tf32 = False
class DownloadTests(unittest.TestCase):
def test_one_request_upon_cached(self):
# TODO: For some reason this test fails on MPS where no HEAD call is made.
if torch_device == "mps":
return
with tempfile.TemporaryDirectory() as tmpdirname:
with requests_mock.mock(real_http=True) as m:
DiffusionPipeline.download(
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
)
download_requests = [r.method for r in m.request_history]
assert download_requests.count("HEAD") == 16, "15 calls to files + send_telemetry"
assert download_requests.count("GET") == 17, "15 calls to files + model_info + model_index.json"
assert (
len(download_requests) == 33
), "2 calls per file (15 files) + send_telemetry, model_info and model_index.json"
with requests_mock.mock(real_http=True) as m:
DiffusionPipeline.download(
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
)
cache_requests = [r.method for r in m.request_history]
assert cache_requests.count("HEAD") == 1, "send_telemetry is only HEAD"
assert cache_requests.count("GET") == 1, "model info is only GET"
assert (
len(cache_requests) == 2
), "We should call only `model_info` to check for _commit hash and `send_telemetry`"
def test_download_only_pytorch(self):
with tempfile.TemporaryDirectory() as tmpdirname:
# pipeline has Flax weights
_ = DiffusionPipeline.from_pretrained(
tmpdirname = DiffusionPipeline.download(
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
)
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))]
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
files = [item for sublist in all_root_files for item in sublist]
# None of the downloaded files should be a flax file even if we have some here:
......@@ -101,13 +132,13 @@ class DownloadTests(unittest.TestCase):
def test_download_safetensors(self):
with tempfile.TemporaryDirectory() as tmpdirname:
# pipeline has Flax weights
_ = DiffusionPipeline.from_pretrained(
tmpdirname = DiffusionPipeline.download(
"hf-internal-testing/tiny-stable-diffusion-pipe-safetensors",
safety_checker=None,
cache_dir=tmpdirname,
)
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))]
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
files = [item for sublist in all_root_files for item in sublist]
# None of the downloaded files should be a pytorch file even if we have some here:
......@@ -204,12 +235,10 @@ class DownloadTests(unittest.TestCase):
other_format = ".bin" if safe_avail else ".safetensors"
with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionPipeline.from_pretrained(
tmpdirname = StableDiffusionPipeline.download(
"hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname
)
all_root_files = [
t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))
]
all_root_files = [t[-1] for t in os.walk(tmpdirname)]
files = [item for sublist in all_root_files for item in sublist]
# None of the downloaded files should be a variant file even if we have some here:
......@@ -232,12 +261,10 @@ class DownloadTests(unittest.TestCase):
variant = "fp16"
with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionPipeline.from_pretrained(
tmpdirname = StableDiffusionPipeline.download(
"hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname, variant=variant
)
all_root_files = [
t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))
]
all_root_files = [t[-1] for t in os.walk(tmpdirname)]
files = [item for sublist in all_root_files for item in sublist]
# None of the downloaded files should be a non-variant file even if we have some here:
......@@ -262,14 +289,13 @@ class DownloadTests(unittest.TestCase):
variant = "no_ema"
with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionPipeline.from_pretrained(
tmpdirname = StableDiffusionPipeline.download(
"hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname, variant=variant
)
snapshots = os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots")
all_root_files = [t[-1] for t in os.walk(snapshots)]
all_root_files = [t[-1] for t in os.walk(tmpdirname)]
files = [item for sublist in all_root_files for item in sublist]
unet_files = os.listdir(os.path.join(snapshots, os.listdir(snapshots)[0], "unet"))
unet_files = os.listdir(os.path.join(tmpdirname, "unet"))
# Some of the downloaded files should be a non-variant file, check:
# https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
......@@ -292,7 +318,7 @@ class DownloadTests(unittest.TestCase):
for variant in [None, "no_ema"]:
with self.assertRaises(OSError) as error_context:
with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionPipeline.from_pretrained(
tmpdirname = StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/stable-diffusion-broken-variants",
cache_dir=tmpdirname,
variant=variant,
......@@ -302,13 +328,11 @@ class DownloadTests(unittest.TestCase):
# text encoder has fp16 variants so we can load it
with tempfile.TemporaryDirectory() as tmpdirname:
pipe = StableDiffusionPipeline.from_pretrained(
tmpdirname = StableDiffusionPipeline.download(
"hf-internal-testing/stable-diffusion-broken-variants", cache_dir=tmpdirname, variant="fp16"
)
assert pipe is not None
snapshots = os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots")
all_root_files = [t[-1] for t in os.walk(snapshots)]
all_root_files = [t[-1] for t in os.walk(tmpdirname)]
files = [item for sublist in all_root_files for item in sublist]
# None of the downloaded files should be a non-variant file even if we have some here:
......@@ -395,7 +419,7 @@ class CustomPipelineTests(unittest.TestCase):
@slow
@require_torch_gpu
def test_load_pipeline_from_git(self):
def test_download_from_git(self):
clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment