Unverified Commit 7d887118 authored by Sayak Paul, committed by GitHub

[Core] support saving and loading of sharded checkpoints (#7830)



* feat: support saving a model in sharded checkpoints.

* feat: make loading of sharded checkpoints work.

* add tests

* cleanse the loading logic a bit more.

* more resilience while loading from the Hub.

* parallelize shard downloads by using snapshot_download()

* default to a shard size.

* more fix

* Empty-Commit

* debug

* fix

* quality

* more debugging

* fix more

* initial comments from Benjamin

* move certain methods to loading_utils

* add test to check if the correct number of shards are present.

* add a test to check if loading of sharded checkpoints from the Hub is okay

* clarify the unit when passed as an int.

* use hf_hub for sharding.

* remove unnecessary code

* remove unnecessary function

* lucain's comments.

* fixes

* address high-level comments.

* fix test

* subfolder shenanigans.

* Update src/diffusers/utils/hub_utils.py
Co-authored-by: Lucain <lucainp@gmail.com>

* Apply suggestions from code review
Co-authored-by: Lucain <lucainp@gmail.com>

* remove _huggingface_hub_version as not needed.

* address more feedback.

* add a test for local_files_only=True

* need hf hub to be at least 0.23.2

* style

* final comment.

* clean up subfolder.

* deal with suffixes in code.

* _add_variant default.

* use weights_name_pattern

* remove add_suffix_keyword

* clean up downloading of sharded ckpts.

* don't return something special when using index.json

* fix more

* don't use bare except

* remove comments and catch the errors better

* fix a couple of things when using is_file()

* empty

---------
Co-authored-by: Lucain <lucainp@gmail.com>
parent b63c9568
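
For orientation before the diff, this is roughly how the new behavior is exercised from user code. The repo id comes from the new tests added in this PR; the local directory name and shard size are illustrative only, and `device_map="auto"` assumes `accelerate` is installed.

```python
from diffusers import UNet2DConditionModel

# Dummy sharded repo used by the new tests in this PR.
model = UNet2DConditionModel.from_pretrained(
    "hf-internal-testing/unet2d-sharded-dummy", device_map="auto"
)

# New in this PR: `max_shard_size` makes `save_pretrained` write one or more
# diffusion_pytorch_model-XXXXX-of-XXXXX.safetensors shards plus a
# diffusion_pytorch_model.safetensors.index.json index when the weights exceed the limit.
model.cpu().save_pretrained("./sharded-unet", max_shard_size="1GB")

# `from_pretrained` now detects the index file and loads every shard transparently.
reloaded = UNet2DConditionModel.from_pretrained("./sharded-unet", device_map="auto")
```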
@@ -101,7 +101,7 @@ _deps = [
     "filelock",
     "flax>=0.4.1",
     "hf-doc-builder>=0.3.0",
-    "huggingface-hub>=0.20.2",
+    "huggingface-hub>=0.23.2",
     "requests-mock==1.10.0",
     "importlib_metadata",
     "invisible-watermark>=0.2.0",
......
@@ -9,7 +9,7 @@ deps = {
     "filelock": "filelock",
     "flax": "flax>=0.4.1",
     "hf-doc-builder": "hf-doc-builder>=0.3.0",
-    "huggingface-hub": "huggingface-hub>=0.20.2",
+    "huggingface-hub": "huggingface-hub>=0.23.2",
     "requests-mock": "requests-mock==1.10.0",
     "importlib_metadata": "importlib_metadata",
     "invisible-watermark": "invisible-watermark>=0.2.0",
......
@@ -18,13 +18,19 @@ import importlib
 import inspect
 import os
 from collections import OrderedDict
+from pathlib import Path
 from typing import List, Optional, Union
 
 import safetensors
 import torch
+from huggingface_hub.utils import EntryNotFoundError
 
 from ..utils import (
+    SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_FILE_EXTENSION,
+    WEIGHTS_INDEX_NAME,
+    _add_variant,
+    _get_model_file,
     is_accelerate_available,
     is_torch_version,
     logging,
@@ -175,3 +181,52 @@ def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
     load(model_to_load)
 
     return error_msgs
+
+
+def _fetch_index_file(
+    is_local,
+    pretrained_model_name_or_path,
+    subfolder,
+    use_safetensors,
+    cache_dir,
+    variant,
+    force_download,
+    resume_download,
+    proxies,
+    local_files_only,
+    token,
+    revision,
+    user_agent,
+    commit_hash,
+):
+    if is_local:
+        index_file = Path(
+            pretrained_model_name_or_path,
+            subfolder or "",
+            _add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant),
+        )
+    else:
+        index_file_in_repo = Path(
+            subfolder or "",
+            _add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant),
+        ).as_posix()
+        try:
+            index_file = _get_model_file(
+                pretrained_model_name_or_path,
+                weights_name=index_file_in_repo,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                token=token,
+                revision=revision,
+                subfolder=subfolder,
+                user_agent=user_agent,
+                commit_hash=commit_hash,
+            )
+            index_file = Path(index_file)
+        except (EntryNotFoundError, EnvironmentError):
+            index_file = None
+
+    return index_file
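
For reference, the index file that `_fetch_index_file` looks for is a small JSON document. A sketch of its shape is below; the field names follow the save logic further down in this diff, while the tensor names, shard names, and size are invented for illustration.

```python
import json

# Illustrative contents of diffusion_pytorch_model.safetensors.index.json.
index = {
    "metadata": {"total_size": 10994354176},
    "weight_map": {
        "conv_in.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
        "conv_out.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
    },
}

# Recovering the shard list from the index, as _get_checkpoint_shard_files does later in this diff.
shard_files = sorted(set(index["weight_map"].values()))
print(json.dumps(shard_files, indent=2))
```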
@@ -16,6 +16,7 @@
 import inspect
 import itertools
+import json
 import os
 import re
 from collections import OrderedDict
@@ -25,7 +26,7 @@ from typing import Any, Callable, List, Optional, Tuple, Union
 
 import safetensors
 import torch
-from huggingface_hub import create_repo
+from huggingface_hub import create_repo, split_torch_state_dict_into_shards
 from huggingface_hub.utils import validate_hf_hub_args
 from torch import Tensor, nn
@@ -33,9 +34,12 @@ from .. import __version__
 from ..utils import (
     CONFIG_NAME,
     FLAX_WEIGHTS_NAME,
+    SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_WEIGHTS_NAME,
+    WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
     _add_variant,
+    _get_checkpoint_shard_files,
     _get_model_file,
     deprecate,
     is_accelerate_available,
@@ -49,6 +53,7 @@ from ..utils.hub_utils import (
 )
 from .model_loading_utils import (
     _determine_device_map,
+    _fetch_index_file,
     _load_state_dict_into_model,
     load_model_dict_into_meta,
     load_state_dict,
@@ -57,6 +62,8 @@ from .model_loading_utils import (
 logger = logging.get_logger(__name__)
 
+_REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}")
+
 if is_torch_version(">=", "1.9.0"):
     _LOW_CPU_MEM_USAGE_DEFAULT = True
@@ -263,6 +270,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         save_function: Optional[Callable] = None,
         safe_serialization: bool = True,
         variant: Optional[str] = None,
+        max_shard_size: Union[int, str] = "5GB",
         push_to_hub: bool = False,
         **kwargs,
     ):
@@ -285,6 +293,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
             variant (`str`, *optional*):
                 If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
+            max_shard_size (`int` or `str`, defaults to `"5GB"`):
+                The maximum size for a checkpoint before it is sharded. Each shard will then be smaller than this
+                size. If expressed as a string, it needs to be digits followed by a unit (like `"5GB"`). If
+                expressed as an integer, the unit is bytes.
             push_to_hub (`bool`, *optional*, defaults to `False`):
                 Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
                 repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
@@ -296,6 +308,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
 
+        weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
+        weights_name = _add_variant(weights_name, variant)
+        weight_name_split = weights_name.split(".")
+        if len(weight_name_split) in [2, 3]:
+            weights_name_pattern = weight_name_split[0] + "{suffix}." + ".".join(weight_name_split[1:])
+        else:
+            raise ValueError(f"Invalid {weights_name} provided.")
+
         os.makedirs(save_directory, exist_ok=True)
 
         if push_to_hub:
@@ -317,18 +337,58 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         # Save the model
         state_dict = model_to_save.state_dict()
 
-        weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
-        weights_name = _add_variant(weights_name, variant)
-
         # Save the model
-        if safe_serialization:
-            safetensors.torch.save_file(
-                state_dict, Path(save_directory, weights_name).as_posix(), metadata={"format": "pt"}
-            )
-        else:
-            torch.save(state_dict, Path(save_directory, weights_name).as_posix())
-
-        logger.info(f"Model weights saved in {Path(save_directory, weights_name).as_posix()}")
+        state_dict_split = split_torch_state_dict_into_shards(
+            state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
+        )
+
+        # Clean the folder from a previous save
+        if is_main_process:
+            for filename in os.listdir(save_directory):
+                if filename in state_dict_split.filename_to_tensors.keys():
+                    continue
+                full_filename = os.path.join(save_directory, filename)
+                if not os.path.isfile(full_filename):
+                    continue
+                weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
+                weights_without_ext = weights_without_ext.replace("{suffix}", "")
+                filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
+                # make sure that the file to be deleted matches the format of a sharded file, e.g. pytorch_model-00001-of-00005
+                if (
+                    filename.startswith(weights_without_ext)
+                    and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
+                ):
+                    os.remove(full_filename)
+
+        for filename, tensors in state_dict_split.filename_to_tensors.items():
+            shard = {tensor: state_dict[tensor] for tensor in tensors}
+            filepath = os.path.join(save_directory, filename)
+            if safe_serialization:
+                # At some point we will need to deal better with save_function (used for TPU and other distributed
+                # joyfulness), but for now this is enough.
+                safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
+            else:
+                torch.save(shard, filepath)
+
+        if state_dict_split.is_sharded:
+            index = {
+                "metadata": state_dict_split.metadata,
+                "weight_map": state_dict_split.tensor_to_filename,
+            }
+            save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
+            save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
+            # Save the index as well
+            with open(save_index_file, "w", encoding="utf-8") as f:
+                content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+                f.write(content)
+            logger.info(
+                f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
+                f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameter has been saved in the "
+                f"index located at {save_index_file}."
+            )
+        else:
+            path_to_weights = os.path.join(save_directory, weights_name)
+            logger.info(f"Model weights saved in {path_to_weights}")
 
         if push_to_hub:
             # Create a new empty model card and eventually tag it
@@ -566,6 +626,32 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             **kwargs,
         )
 
+        # Determine if we're loading from a directory of sharded checkpoints.
+        is_sharded = False
+        index_file = None
+        is_local = os.path.isdir(pretrained_model_name_or_path)
+        index_file = _fetch_index_file(
+            is_local=is_local,
+            pretrained_model_name_or_path=pretrained_model_name_or_path,
+            subfolder=subfolder or "",
+            use_safetensors=use_safetensors,
+            cache_dir=cache_dir,
+            variant=variant,
+            force_download=force_download,
+            resume_download=resume_download,
+            proxies=proxies,
+            local_files_only=local_files_only,
+            token=token,
+            revision=revision,
+            user_agent=user_agent,
+            commit_hash=commit_hash,
+        )
+        if index_file is not None and index_file.is_file():
+            is_sharded = True
+
+        if is_sharded and from_flax:
+            raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.")
+
         # load model
         model_file = None
         if from_flax:
@@ -590,7 +676,21 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             model = load_flax_checkpoint_in_pytorch_model(model, model_file)
         else:
-            if use_safetensors:
+            if is_sharded:
+                sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files(
+                    pretrained_model_name_or_path,
+                    index_file,
+                    cache_dir=cache_dir,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    local_files_only=local_files_only,
+                    token=token,
+                    user_agent=user_agent,
+                    revision=revision,
+                    subfolder=subfolder or "",
+                )
+
+            elif use_safetensors and not is_sharded:
                 try:
                     model_file = _get_model_file(
                         pretrained_model_name_or_path,
@@ -606,11 +706,16 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                         user_agent=user_agent,
                         commit_hash=commit_hash,
                     )
+
                 except IOError as e:
+                    logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
                     if not allow_pickle:
-                        raise e
-                    pass
-            if model_file is None:
+                        raise
+                    logger.warning(
+                        "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
+                    )
+
+            if model_file is None and not is_sharded:
                 model_file = _get_model_file(
                     pretrained_model_name_or_path,
                     weights_name=_add_variant(WEIGHTS_NAME, variant),
@@ -632,7 +737,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             model = cls.from_config(config, **unused_kwargs)
 
             # if device_map is None, load the state dict and move the params from meta device to the cpu
-            if device_map is None:
+            if device_map is None and not is_sharded:
                 param_device = "cpu"
                 state_dict = load_state_dict(model_file, variant=variant)
                 model._convert_deprecated_attention_blocks(state_dict)
try: try:
accelerate.load_checkpoint_and_dispatch( accelerate.load_checkpoint_and_dispatch(
model, model,
model_file, model_file if not is_sharded else sharded_ckpt_cached_folder,
device_map, device_map,
max_memory=max_memory, max_memory=max_memory,
offload_folder=offload_folder, offload_folder=offload_folder,
......
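
The sharding in `save_pretrained` above is delegated to `huggingface_hub.split_torch_state_dict_into_shards`, which is also why the minimum `huggingface-hub` version is bumped to 0.23.2. Below is a small standalone sketch of how that helper behaves with the filename pattern built in `save_pretrained`; the tensor names and sizes are made up for illustration.

```python
import torch
from huggingface_hub import split_torch_state_dict_into_shards

# Toy state dict: eight ~4 MB tensors, so a 10 MB limit forces several shards.
state_dict = {f"block_{i}.weight": torch.zeros(1000, 1000) for i in range(8)}

state_dict_split = split_torch_state_dict_into_shards(
    state_dict,
    max_shard_size="10MB",
    filename_pattern="diffusion_pytorch_model{suffix}.safetensors",
)

print(state_dict_split.is_sharded)                  # True: more than one file is needed
print(list(state_dict_split.filename_to_tensors))   # shard names, e.g. ...-00001-of-00004.safetensors
print(state_dict_split.metadata)                    # e.g. {"total_size": 32000000}
```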
@@ -28,9 +28,11 @@ from .constants import (
     MIN_PEFT_VERSION,
     ONNX_EXTERNAL_WEIGHTS_NAME,
     ONNX_WEIGHTS_NAME,
+    SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_FILE_EXTENSION,
     SAFETENSORS_WEIGHTS_NAME,
     USE_PEFT_BACKEND,
+    WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
 )
 from .deprecation_utils import deprecate
@@ -40,6 +42,7 @@ from .export_utils import export_to_gif, export_to_obj, export_to_ply, export_to_video
 from .hub_utils import (
     PushToHubMixin,
     _add_variant,
+    _get_checkpoint_shard_files,
     _get_model_file,
     extract_commit_hash,
     http_user_agent,
......
@@ -28,9 +28,11 @@ _CHECK_PEFT = os.environ.get("_CHECK_PEFT", "1") in ENV_VARS_TRUE_VALUES
 
 CONFIG_NAME = "config.json"
 WEIGHTS_NAME = "diffusion_pytorch_model.bin"
+WEIGHTS_INDEX_NAME = "diffusion_pytorch_model.bin.index.json"
 FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
 ONNX_WEIGHTS_NAME = "model.onnx"
 SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
+SAFE_WEIGHTS_INDEX_NAME = "diffusion_pytorch_model.safetensors.index.json"
 SAFETENSORS_FILE_EXTENSION = "safetensors"
 ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
 HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
......
@@ -14,6 +14,7 @@
 # limitations under the License.
 
+import json
 import os
 import re
 import sys
@@ -29,6 +30,8 @@ from huggingface_hub import (
     ModelCardData,
     create_repo,
     hf_hub_download,
+    model_info,
+    snapshot_download,
     upload_folder,
 )
 from huggingface_hub.constants import HF_HUB_CACHE, HF_HUB_DISABLE_TELEMETRY, HF_HUB_OFFLINE
@@ -393,6 +396,109 @@ def _get_model_file(
     )
 
 
+# Adapted from
+# https://github.com/huggingface/transformers/blob/1360801a69c0b169e3efdbb0cd05d9a0e72bfb70/src/transformers/utils/hub.py#L976
+# Differences are in parallelization of shard downloads and checking if shards are present.
+def _check_if_shards_exist_locally(local_dir, subfolder, original_shard_filenames):
+    shards_path = os.path.join(local_dir, subfolder)
+    shard_filenames = [os.path.join(shards_path, f) for f in original_shard_filenames]
+    for shard_file in shard_filenames:
+        if not os.path.exists(shard_file):
+            raise ValueError(
+                f"{shards_path} does not appear to have a file named {shard_file} which is "
+                "required according to the checkpoint index."
+            )
+
+
+def _get_checkpoint_shard_files(
+    pretrained_model_name_or_path,
+    index_filename,
+    cache_dir=None,
+    proxies=None,
+    resume_download=False,
+    local_files_only=False,
+    token=None,
+    user_agent=None,
+    revision=None,
+    subfolder="",
+):
+    """
+    For a given model:
+
+    - download and cache all the shards of a sharded checkpoint if `pretrained_model_name_or_path` is a model ID on
+      the Hub
+    - return the list of paths to all the shards, as well as some metadata.
+
+    For a description of each argument, see [`PreTrainedModel.from_pretrained`]. `index_filename` is the full path to
+    the index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub).
+    """
+    if not os.path.isfile(index_filename):
+        raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.")
+
+    with open(index_filename, "r") as f:
+        index = json.loads(f.read())
+
+    original_shard_filenames = sorted(set(index["weight_map"].values()))
+    sharded_metadata = index["metadata"]
+    sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())
+    sharded_metadata["weight_map"] = index["weight_map"].copy()
+    shards_path = os.path.join(pretrained_model_name_or_path, subfolder)
+
+    # First, let's deal with a local folder.
+    if os.path.isdir(pretrained_model_name_or_path):
+        _check_if_shards_exist_locally(
+            pretrained_model_name_or_path, subfolder=subfolder, original_shard_filenames=original_shard_filenames
+        )
+        return pretrained_model_name_or_path, sharded_metadata
+
+    # At this stage pretrained_model_name_or_path is a model identifier on the Hub
+    allow_patterns = original_shard_filenames
+    ignore_patterns = ["*.json", "*.md"]
+    if not local_files_only:
+        # The `model_info` call must be guarded with the above condition.
+        model_files_info = model_info(pretrained_model_name_or_path)
+        for shard_file in original_shard_filenames:
+            shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
+            if not shard_file_present:
+                raise EnvironmentError(
+                    f"{shards_path} does not appear to have a file named {shard_file} which is "
+                    "required according to the checkpoint index."
+                )
+
+        try:
+            # Load from URL
+            cached_folder = snapshot_download(
+                pretrained_model_name_or_path,
+                cache_dir=cache_dir,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                token=token,
+                revision=revision,
+                allow_patterns=allow_patterns,
+                ignore_patterns=ignore_patterns,
+                user_agent=user_agent,
+            )
+
+        # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
+        # we don't have to catch them here. We have also dealt with EntryNotFoundError.
+        except HTTPError as e:
+            raise EnvironmentError(
+                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
+                " again after checking your internet connection."
+            ) from e
+
+    # If `local_files_only=True`, `cached_folder` may not contain all the shard files.
+    if local_files_only:
+        _check_if_shards_exist_locally(
+            local_dir=cache_dir, subfolder=subfolder, original_shard_filenames=original_shard_filenames
+        )
+
+    return cached_folder, sharded_metadata
+
+
 class PushToHubMixin:
     """
     A Mixin to push a model, scheduler, or pipeline to the Hugging Face Hub.
......
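
For context, the Hub download path in `_get_checkpoint_shard_files` ultimately reduces to a filtered `snapshot_download` call that fetches only the shard files named in the index. A hedged sketch follows; the repo id is the dummy one from the new tests, and the shard names are illustrative.

```python
from huggingface_hub import snapshot_download

# Only the shard files listed in the index are downloaded; stray JSON/markdown files are skipped.
cached_folder = snapshot_download(
    "hf-internal-testing/unet2d-sharded-dummy",
    allow_patterns=[
        "diffusion_pytorch_model-00001-of-00002.safetensors",
        "diffusion_pytorch_model-00002-of-00002.safetensors",
    ],
    ignore_patterns=["*.json", "*.md"],
)
print(cached_folder)  # local snapshot directory that now contains the shards
```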
@@ -131,7 +131,6 @@ try:
 except importlib_metadata.PackageNotFoundError:
     _unidecode_available = False
 
-_onnxruntime_version = "N/A"
 _onnx_available = importlib.util.find_spec("onnxruntime") is not None
 
 if _onnx_available:
......
@@ -14,6 +14,8 @@
 # limitations under the License.
 
 import inspect
+import json
+import os
 import tempfile
 import traceback
 import unittest
@@ -37,7 +39,7 @@ from diffusers.models.attention_processor import (
     XFormersAttnProcessor,
 )
 from diffusers.training_utils import EMAModel
-from diffusers.utils import is_torch_npu_available, is_xformers_available, logging
+from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, is_torch_npu_available, is_xformers_available, logging
 from diffusers.utils.testing_utils import (
     CaptureLogger,
     get_python_version,
@@ -129,7 +131,9 @@ class ModelUtilsTest(unittest.TestCase):
            )
 
        download_requests = [r.method for r in m.request_history]
-        assert download_requests.count("HEAD") == 2, "2 HEAD requests one for config, one for model"
+        assert (
+            download_requests.count("HEAD") == 3
+        ), "3 HEAD requests one for config, one for model, and one for shard index file."
        assert download_requests.count("GET") == 2, "2 GET requests one for config, one for model"
 
        with requests_mock.mock(real_http=True) as m:
@@ -142,8 +146,8 @@ class ModelUtilsTest(unittest.TestCase):
            cache_requests = [r.method for r in m.request_history]
            assert (
-                "HEAD" == cache_requests[0] and len(cache_requests) == 1
-            ), "We should call only `model_info` to check for _commit hash and `send_telemetry`"
+                "HEAD" == cache_requests[0] and len(cache_requests) == 2
+            ), "We should call only `model_info` to check for the commit hash and to know if a shard index is present."
 
    def test_weight_overwrite(self):
        with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context:
@@ -866,6 +870,41 @@ class ModelTesterMixin:
        self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
 
+    @require_torch_gpu
+    def test_sharded_checkpoints(self):
+        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**config).eval()
+        if model._no_split_modules is None:
+            return
+        model = model.to(torch_device)
+
+        torch.manual_seed(0)
+        base_output = model(**inputs_dict)
+
+        model_size = compute_module_sizes(model)[""]
+        max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small.
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
+            self.assertTrue(os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
+
+            # Now check if the right number of shards exists. First, let's get the number of shards.
+            # Since this number can be dependent on the model being tested, it's important that we calculate it
+            # instead of hardcoding it.
+            with open(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)) as f:
+                weight_map_dict = json.load(f)["weight_map"]
+                first_key = list(weight_map_dict.keys())[0]
+                weight_loc = weight_map_dict[first_key]  # e.g., diffusion_pytorch_model-00001-of-00002.safetensors
+                expected_num_shards = int(weight_loc.split("-")[-1].split(".")[0])
+
+            actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
+            self.assertTrue(actual_num_shards == expected_num_shards)
+
+            new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto")
+
+            torch.manual_seed(0)
+            new_output = new_model(**inputs_dict)
+
+            self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+
 
 @is_staging_test
 class ModelPushToHubTester(unittest.TestCase):
......
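
The shard-count check in `test_sharded_checkpoints` relies on the `-NNNNN-of-NNNNN` suffix encoded in each shard name; in isolation, the parsing looks like this (the filename is invented for illustration).

```python
# Parse the total shard count out of a shard filename, as the new test does.
weight_loc = "diffusion_pytorch_model-00001-of-00002.safetensors"
expected_num_shards = int(weight_loc.split("-")[-1].split(".")[0])
assert expected_num_shards == 2
```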
@@ -21,6 +21,7 @@ import unittest
 from collections import OrderedDict
 
 import torch
+from huggingface_hub import snapshot_download
 from parameterized import parameterized
 from pytest import mark
@@ -1034,6 +1035,25 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         assert sample2.allclose(sample5, atol=1e-4, rtol=1e-4)
         assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)
 
+    @require_torch_gpu
+    def test_load_sharded_checkpoint_from_hub(self):
+        _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        loaded_model = self.model_class.from_pretrained("hf-internal-testing/unet2d-sharded-dummy", device_map="auto")
+        new_output = loaded_model(**inputs_dict)
+
+        assert loaded_model
+        assert new_output.sample.shape == (4, 4, 16, 16)
+
+    @require_torch_gpu
+    def test_load_sharded_checkpoint_from_hub_local(self):
+        _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
+        loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True, device_map="auto")
+        new_output = loaded_model(**inputs_dict)
+
+        assert loaded_model
+        assert new_output.sample.shape == (4, 4, 16, 16)
+
     @require_peft_backend
     def test_lora(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
......
@@ -29,6 +29,7 @@ import PIL.Image
 import requests_mock
 import safetensors.torch
 import torch
+import torch.nn as nn
 from parameterized import parameterized
 from PIL import Image
 from requests.exceptions import HTTPError
@@ -135,6 +136,7 @@ def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
 class CustomEncoder(ModelMixin, ConfigMixin):
     def __init__(self):
         super().__init__()
+        self.linear = nn.Linear(3, 3)
 
 
 class CustomPipeline(DiffusionPipeline):
......