Unverified Commit fbff43ac authored by Marc Sun, committed by GitHub

[FEAT] DDUF format (#10037)



* load and save dduf archive

* style

* switch to zip uncompressed

* updates

* Update src/diffusers/pipelines/pipeline_utils.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/pipelines/pipeline_utils.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* first draft

* remove print

* switch to dduf_file for consistency

* switch to huggingface hub api

* fix log

* add a basic test

* Update src/diffusers/configuration_utils.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/pipelines/pipeline_utils.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/pipelines/pipeline_utils.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* fix

* fix variant

* change saving logic

* DDUF - Load transformers components manually (#10171)

* update hfh version

* Load transformers components manually

* load encoder from_pretrained with state_dict

* working version with transformers and tokenizer !

* add generation_config case

* fix tests

* remove saving for now

* typing

* need next version of transformers

* Update src/diffusers/configuration_utils.py
Co-authored-by: Lucain <lucain@huggingface.co>

* check path correctly

* Apply suggestions from code review
Co-authored-by: Lucain <lucain@huggingface.co>

* update

* typing

* remove check for subfolder

* quality

* revert setup changes

* oops

* more readable condition

* add loading from the hub test

* add basic docs.

* Apply suggestions from code review
Co-authored-by: Lucain <lucain@huggingface.co>

* add example

* add

* make functions private

* Apply suggestions from code review
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* minor.

* fixes

* fix

* change the precedence of parameterized.

* error out when custom pipeline is passed with dduf_file.

* updates

* fix

* updates

* fixes

* updates

* fix xfail condition.

* fix xfail

* fixes

* sharded checkpoint compat

* add test for sharded checkpoint

* add suggestions

* Update src/diffusers/models/model_loading_utils.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>

* from suggestions

* add class attributes to flag dduf tests

* last one

* fix logic

* remove comment

* revert changes

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Lucain <lucain@huggingface.co>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
parent 3279751b
@@ -240,6 +240,46 @@ Benefits of using a single-file layout include:
1. Easy compatibility with diffusion interfaces such as [ComfyUI](https://github.com/comfyanonymous/ComfyUI) or [Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) which commonly use a single-file layout.
2. Easier to manage (download and share) a single file.
### DDUF
> [!WARNING]
> DDUF is an experimental file format and APIs related to it can change in the future.
DDUF (**D**DUF **D**iffusion **U**nified **F**ormat) is a file format designed to make storing, distributing, and using diffusion models much easier. Built on the ZIP file format, DDUF offers a standardized, efficient, and flexible way to package all parts of a diffusion model into a single, easy-to-manage file. It provides a balance between the Diffusers multi-folder format and the widely popular single-file format.
Learn more details about DDUF on the Hugging Face Hub [documentation](https://huggingface.co/docs/hub/dduf).
Pass a checkpoint to the `dduf_file` parameter to load it in [`DiffusionPipeline`].
```py
from diffusers import DiffusionPipeline
import torch
pipe = DiffusionPipeline.from_pretrained(
"DDUF/FLUX.1-dev-DDUF", dduf_file="FLUX.1-dev.dduf", torch_dtype=torch.bfloat16
).to("cuda")
image = pipe(
"photo a cat holding a sign that says Diffusers", num_inference_steps=50, guidance_scale=3.5
).images[0]
image.save("cat.png")
```
To save a pipeline as a `.dduf` checkpoint, use the [`~huggingface_hub.export_folder_as_dduf`] utility, which takes care of all the necessary file-level validations.
```py
from huggingface_hub import export_folder_as_dduf
from diffusers import DiffusionPipeline
import torch
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
save_folder = "flux-dev"
pipe.save_pretrained("flux-dev")
export_folder_as_dduf("flux-dev.dduf", folder_path=save_folder)
> [!TIP]
> Packaging and loading quantized checkpoints in the DDUF format is supported as long as they respect the multi-folder structure.
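If you just want to peek inside a DDUF archive, the `read_dduf_file` helper from `huggingface_hub` returns a mapping of entry names to `DDUFEntry` objects without unpacking anything. A minimal sketch, assuming the `FLUX.1-dev.dduf` file from the example above:
```py
from huggingface_hub import read_dduf_file

# Map of entry name -> DDUFEntry for every file stored in the archive
entries = read_dduf_file("FLUX.1-dev.dduf")
print(sorted(entries)[:5])

# Text entries such as configs can be read in place
print(entries["model_index.json"].read_text())
```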
## Convert layout and files
Diffusers provides many scripts and methods to convert storage layouts and file formats to enable broader support across the diffusion ecosystem.
......
@@ -101,7 +101,7 @@ _deps = [
    "filelock",
    "flax>=0.4.1",
    "hf-doc-builder>=0.3.0",
-    "huggingface-hub>=0.23.2",
    "huggingface-hub>=0.27.0",
    "requests-mock==1.10.0",
    "importlib_metadata",
    "invisible-watermark>=0.2.0",
......
@@ -24,10 +24,10 @@ import os
import re
from collections import OrderedDict
from pathlib import Path
-from typing import Any, Dict, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
-from huggingface_hub import create_repo, hf_hub_download
from huggingface_hub import DDUFEntry, create_repo, hf_hub_download
from huggingface_hub.utils import (
    EntryNotFoundError,
    RepositoryNotFoundError,
@@ -347,6 +347,7 @@ class ConfigMixin:
        _ = kwargs.pop("mirror", None)
        subfolder = kwargs.pop("subfolder", None)
        user_agent = kwargs.pop("user_agent", {})
        dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)

        user_agent = {**user_agent, "file_type": "config"}
        user_agent = http_user_agent(user_agent)
@@ -358,8 +359,15 @@ class ConfigMixin:
                "`self.config_name` is not defined. Note that one should not load a config from "
                "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
            )

-        # Custom path for now
-        if os.path.isfile(pretrained_model_name_or_path):
        if dduf_entries:
            if subfolder is not None:
                raise ValueError(
                    "DDUF file only allows for 1 level of directory (e.g. transformer/model1/model.safetensors is not allowed). "
                    "Please check the DDUF structure"
                )
            config_file = cls._get_config_file_from_dduf(pretrained_model_name_or_path, dduf_entries)
        elif os.path.isfile(pretrained_model_name_or_path):
            config_file = pretrained_model_name_or_path
        elif os.path.isdir(pretrained_model_name_or_path):
            if subfolder is not None and os.path.isfile(
@@ -426,10 +434,8 @@ class ConfigMixin:
                f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
                f"containing a {cls.config_name} file"
            )

        try:
-            # Load config dict
-            config_dict = cls._dict_from_json_file(config_file)
            config_dict = cls._dict_from_json_file(config_file, dduf_entries=dduf_entries)
            commit_hash = extract_commit_hash(config_file)
        except (json.JSONDecodeError, UnicodeDecodeError):
@@ -552,9 +558,14 @@ class ConfigMixin:
        return init_dict, unused_kwargs, hidden_config_dict

    @classmethod
-    def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
-        with open(json_file, "r", encoding="utf-8") as reader:
-            text = reader.read()
    def _dict_from_json_file(
        cls, json_file: Union[str, os.PathLike], dduf_entries: Optional[Dict[str, DDUFEntry]] = None
    ):
        if dduf_entries:
            text = dduf_entries[json_file].read_text()
        else:
            with open(json_file, "r", encoding="utf-8") as reader:
                text = reader.read()
        return json.loads(text)

    def __repr__(self):
@@ -616,6 +627,20 @@ class ConfigMixin:
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(self.to_json_string())

    @classmethod
    def _get_config_file_from_dduf(cls, pretrained_model_name_or_path: str, dduf_entries: Dict[str, DDUFEntry]):
        # paths inside a DDUF file must always use "/"
        config_file = (
            cls.config_name
            if pretrained_model_name_or_path == ""
            else "/".join([pretrained_model_name_or_path, cls.config_name])
        )
        if config_file not in dduf_entries:
            raise ValueError(
                f"Could not find the file {config_file} in the DDUF archive. The only available files are {dduf_entries.keys()}"
            )
        return config_file


def register_to_config(init):
    r"""
......
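For reference, the lookup that `_get_config_file_from_dduf` and `_dict_from_json_file` implement can be reproduced standalone with `huggingface_hub`; a minimal sketch, assuming a local archive named `FLUX.1-dev.dduf` with a `transformer` component:
```py
import json

from huggingface_hub import read_dduf_file

entries = read_dduf_file("FLUX.1-dev.dduf")  # hypothetical local archive

# Entry names always use "/" and are at most one directory deep,
# so a component config lives at "<component>/config.json"
config = json.loads(entries["transformer/config.json"].read_text())
print(config.get("_class_name"))
```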
@@ -9,7 +9,7 @@ deps = {
    "filelock": "filelock",
    "flax": "flax>=0.4.1",
    "hf-doc-builder": "hf-doc-builder>=0.3.0",
-    "huggingface-hub": "huggingface-hub>=0.23.2",
    "huggingface-hub": "huggingface-hub>=0.27.0",
    "requests-mock": "requests-mock==1.10.0",
    "importlib_metadata": "importlib_metadata",
    "invisible-watermark": "invisible-watermark>=0.2.0",
......
@@ -20,10 +20,11 @@ import os
from array import array
from collections import OrderedDict
from pathlib import Path
-from typing import List, Optional, Union
from typing import Dict, List, Optional, Union

import safetensors
import torch
from huggingface_hub import DDUFEntry
from huggingface_hub.utils import EntryNotFoundError

from ..utils import (
@@ -132,7 +133,10 @@ def _fetch_remapped_cls_from_config(config, old_class):

def load_state_dict(
-    checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None, disable_mmap: bool = False
    checkpoint_file: Union[str, os.PathLike],
    variant: Optional[str] = None,
    dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
    disable_mmap: bool = False,
):
    """
    Reads a checkpoint file, returning properly formatted errors if they arise.
@@ -144,6 +148,10 @@ def load_state_dict(
    try:
        file_extension = os.path.basename(checkpoint_file).split(".")[-1]
        if file_extension == SAFETENSORS_FILE_EXTENSION:
            if dduf_entries:
                # tensors are loaded on cpu
                with dduf_entries[checkpoint_file].as_mmap() as mm:
                    return safetensors.torch.load(mm)
            if disable_mmap:
                return safetensors.torch.load(open(checkpoint_file, "rb").read())
            else:
@@ -284,6 +292,7 @@ def _fetch_index_file(
    revision,
    user_agent,
    commit_hash,
    dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
):
    if is_local:
        index_file = Path(
@@ -309,8 +318,10 @@ def _fetch_index_file(
                subfolder=None,
                user_agent=user_agent,
                commit_hash=commit_hash,
                dduf_entries=dduf_entries,
            )
-            index_file = Path(index_file)
            if not dduf_entries:
                index_file = Path(index_file)
        except (EntryNotFoundError, EnvironmentError):
            index_file = None
@@ -319,7 +330,9 @@ def _fetch_index_file(

# Adapted from
# https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64
-def _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata):
def _merge_sharded_checkpoints(
    sharded_ckpt_cached_folder, sharded_metadata, dduf_entries: Optional[Dict[str, DDUFEntry]] = None
):
    weight_map = sharded_metadata.get("weight_map", None)
    if weight_map is None:
        raise KeyError("'weight_map' key not found in the shard index file.")
@@ -332,14 +345,23 @@ def _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata):
    # Load tensors from each unique file
    for file_name in files_to_load:
        part_file_path = os.path.join(sharded_ckpt_cached_folder, file_name)
-        if not os.path.exists(part_file_path):
-            raise FileNotFoundError(f"Part file {file_name} not found.")
        if dduf_entries:
            if part_file_path not in dduf_entries:
                raise FileNotFoundError(f"Part file {file_name} not found.")
        else:
            if not os.path.exists(part_file_path):
                raise FileNotFoundError(f"Part file {file_name} not found.")

        if is_safetensors:
-            with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f:
-                for tensor_key in f.keys():
-                    if tensor_key in weight_map:
-                        merged_state_dict[tensor_key] = f.get_tensor(tensor_key)
            if dduf_entries:
                with dduf_entries[part_file_path].as_mmap() as mm:
                    tensors = safetensors.torch.load(mm)
                    merged_state_dict.update(tensors)
            else:
                with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f:
                    for tensor_key in f.keys():
                        if tensor_key in weight_map:
                            merged_state_dict[tensor_key] = f.get_tensor(tensor_key)
        else:
            merged_state_dict.update(torch.load(part_file_path, weights_only=True, map_location="cpu"))
@@ -360,6 +382,7 @@ def _fetch_index_file_legacy(
    revision,
    user_agent,
    commit_hash,
    dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
):
    if is_local:
        index_file = Path(
@@ -400,6 +423,7 @@ def _fetch_index_file_legacy(
            subfolder=None,
            user_agent=user_agent,
            commit_hash=commit_hash,
            dduf_entries=dduf_entries,
        )
        index_file = Path(index_file)
        deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."
......
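The DDUF branches above all hinge on the same trick: a `DDUFEntry` inside the uncompressed ZIP can be memory-mapped in place, so safetensors weights are deserialized without extracting the archive. A standalone sketch of that pattern (archive and entry names are assumptions):
```py
import safetensors.torch
from huggingface_hub import read_dduf_file

entries = read_dduf_file("FLUX.1-dev.dduf")  # hypothetical local archive

# Memory-map one safetensors entry and load it; tensors end up on CPU
with entries["transformer/diffusion_pytorch_model.safetensors"].as_mmap() as mm:
    state_dict = safetensors.torch.load(mm)
print(f"{len(state_dict)} tensors loaded")
```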
@@ -23,11 +23,11 @@ import re
from collections import OrderedDict
from functools import partial, wraps
from pathlib import Path
-from typing import Any, Callable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import safetensors
import torch
-from huggingface_hub import create_repo, split_torch_state_dict_into_shards
from huggingface_hub import DDUFEntry, create_repo, split_torch_state_dict_into_shards
from huggingface_hub.utils import validate_hf_hub_args
from torch import Tensor, nn
@@ -607,6 +607,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
        variant = kwargs.pop("variant", None)
        use_safetensors = kwargs.pop("use_safetensors", None)
        quantization_config = kwargs.pop("quantization_config", None)
        dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
        disable_mmap = kwargs.pop("disable_mmap", False)

        allow_pickle = False
@@ -700,6 +701,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            dduf_entries=dduf_entries,
            **kwargs,
        )
        # no in-place modification of the original config.
@@ -776,13 +778,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                "revision": revision,
                "user_agent": user_agent,
                "commit_hash": commit_hash,
                "dduf_entries": dduf_entries,
            }
            index_file = _fetch_index_file(**index_file_kwargs)
            # In case the index file was not found we still have to consider the legacy format.
            # this becomes applicable when the variant is not None.
            if variant is not None and (index_file is None or not os.path.exists(index_file)):
                index_file = _fetch_index_file_legacy(**index_file_kwargs)
-            if index_file is not None and index_file.is_file():
            if index_file is not None and (dduf_entries or index_file.is_file()):
                is_sharded = True

        if is_sharded and from_flax:
@@ -811,6 +814,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
            model = load_flax_checkpoint_in_pytorch_model(model, model_file)
        else:
            # in the case it is sharded, we already have the index
            if is_sharded:
                sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files(
                    pretrained_model_name_or_path,
@@ -822,10 +826,13 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                    user_agent=user_agent,
                    revision=revision,
                    subfolder=subfolder or "",
                    dduf_entries=dduf_entries,
                )
                # TODO: https://github.com/huggingface/diffusers/issues/10013
-                if hf_quantizer is not None:
-                    model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
                if hf_quantizer is not None or dduf_entries:
                    model_file = _merge_sharded_checkpoints(
                        sharded_ckpt_cached_folder, sharded_metadata, dduf_entries=dduf_entries
                    )
                    logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
                    is_sharded = False
@@ -843,6 +850,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                        subfolder=subfolder,
                        user_agent=user_agent,
                        commit_hash=commit_hash,
                        dduf_entries=dduf_entries,
                    )

                except IOError as e:
@@ -866,6 +874,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                    subfolder=subfolder,
                    user_agent=user_agent,
                    commit_hash=commit_hash,
                    dduf_entries=dduf_entries,
                )

        if low_cpu_mem_usage:
@@ -887,7 +896,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                # TODO (sayakpaul, SunMarc): remove this after model loading refactor
                else:
                    param_device = torch.device(torch.cuda.current_device())
-                state_dict = load_state_dict(model_file, variant=variant, disable_mmap=disable_mmap)
                state_dict = load_state_dict(
                    model_file, variant=variant, dduf_entries=dduf_entries, disable_mmap=disable_mmap
                )
                model._convert_deprecated_attention_blocks(state_dict)

                # move the params from meta device to cpu
@@ -983,7 +994,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
            else:
                model = cls.from_config(config, **unused_kwargs)

-                state_dict = load_state_dict(model_file, variant=variant, disable_mmap=disable_mmap)
                state_dict = load_state_dict(
                    model_file, variant=variant, dduf_entries=dduf_entries, disable_mmap=disable_mmap
                )
                model._convert_deprecated_attention_blocks(state_dict)

                model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
......
@@ -12,19 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import os
import re
import warnings
from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union

import requests
import torch
-from huggingface_hub import ModelCard, model_info
from huggingface_hub import DDUFEntry, ModelCard, model_info, snapshot_download
-from huggingface_hub.utils import validate_hf_hub_args
from huggingface_hub.utils import OfflineModeIsEnabled, validate_hf_hub_args
from packaging import version
from requests.exceptions import HTTPError

from .. import __version__
from ..utils import (
@@ -38,14 +38,16 @@ from ..utils import (
    is_accelerate_available,
    is_peft_available,
    is_transformers_available,
    is_transformers_version,
    logging,
)
from ..utils.torch_utils import is_compiled_module
from .transformers_loading_utils import _load_tokenizer_from_dduf, _load_transformers_model_from_dduf


if is_transformers_available():
    import transformers
-    from transformers import PreTrainedModel
    from transformers import PreTrainedModel, PreTrainedTokenizerBase
    from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
    from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
    from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
@@ -627,6 +629,7 @@ def load_sub_model(
    low_cpu_mem_usage: bool,
    cached_folder: Union[str, os.PathLike],
    use_safetensors: bool,
    dduf_entries: Optional[Dict[str, DDUFEntry]],
):
    """Helper method to load the module `name` from `library_name` and `class_name`"""
@@ -663,7 +666,7 @@ def load_sub_model(
            f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
        )
-    load_method = getattr(class_obj, load_method_name)
    load_method = _get_load_method(class_obj, load_method_name, is_dduf=dduf_entries is not None)

    # add kwargs to loading method
    diffusers_module = importlib.import_module(__name__.split(".")[0])
@@ -721,7 +724,10 @@ def load_sub_model(
        loading_kwargs["low_cpu_mem_usage"] = False

    # check if the module is in a subdirectory
-    if os.path.isdir(os.path.join(cached_folder, name)):
    if dduf_entries:
        loading_kwargs["dduf_entries"] = dduf_entries
        loaded_sub_model = load_method(name, **loading_kwargs)
    elif os.path.isdir(os.path.join(cached_folder, name)):
        loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
    else:
        # else load from the root directory
@@ -746,6 +752,22 @@ def load_sub_model(
    return loaded_sub_model


def _get_load_method(class_obj: object, load_method_name: str, is_dduf: bool) -> Callable:
    """
    Return the method to load the sub model.

    In practice, this method will return the `"from_pretrained"` (or `load_method_name`) method of the class object
    except if loading from a DDUF checkpoint. In that case, transformers models and tokenizers have a specific loading
    method that we need to use.
    """
    if is_dduf:
        if issubclass(class_obj, PreTrainedTokenizerBase):
            return lambda *args, **kwargs: _load_tokenizer_from_dduf(class_obj, *args, **kwargs)
        if issubclass(class_obj, PreTrainedModel):
            return lambda *args, **kwargs: _load_transformers_model_from_dduf(class_obj, *args, **kwargs)
    return getattr(class_obj, load_method_name)


def _fetch_class_library_tuple(module):
    # import it here to avoid circular import
    diffusers_module = importlib.import_module(__name__.split(".")[0])
@@ -968,3 +990,70 @@ def _get_ignore_patterns(
        )

    return ignore_patterns


def _download_dduf_file(
    pretrained_model_name: str,
    dduf_file: str,
    pipeline_class_name: str,
    cache_dir: str,
    proxies: str,
    local_files_only: bool,
    token: str,
    revision: str,
):
    model_info_call_error = None
    if not local_files_only:
        try:
            info = model_info(pretrained_model_name, token=token, revision=revision)
        except (HTTPError, OfflineModeIsEnabled, requests.ConnectionError) as e:
            logger.warning(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.")
            local_files_only = True
            model_info_call_error = e  # save error to reraise it if model is not cached locally

    if (
        not local_files_only
        and dduf_file is not None
        and dduf_file not in (sibling.rfilename for sibling in info.siblings)
    ):
        raise ValueError(f"Requested {dduf_file} file is not available in {pretrained_model_name}.")

    try:
        user_agent = {"pipeline_class": pipeline_class_name, "dduf": True}
        cached_folder = snapshot_download(
            pretrained_model_name,
            cache_dir=cache_dir,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            allow_patterns=[dduf_file],
            user_agent=user_agent,
        )
        return cached_folder
    except FileNotFoundError:
        # Means we tried to load the pipeline with `local_files_only=True` but the files have not been found in the local cache.
        # This can happen in two cases:
        # 1. If the user passed `local_files_only=True` => we raise the error directly
        # 2. If we forced `local_files_only=True` when `model_info` failed => we raise the initial error
        if model_info_call_error is None:
            # 1. user passed `local_files_only=True`
            raise
        else:
            # 2. we forced `local_files_only=True` when `model_info` failed
            raise EnvironmentError(
                f"Cannot load model {pretrained_model_name}: model is not cached locally and an error occurred"
                " while trying to fetch metadata from the Hub. Please check out the root cause in the stacktrace"
                " above."
            ) from model_info_call_error


def _maybe_raise_error_for_incorrect_transformers(config_dict):
    has_transformers_component = False
    for k in config_dict:
        if isinstance(config_dict[k], list):
            has_transformers_component = config_dict[k][0] == "transformers"
            if has_transformers_component:
                break
    if has_transformers_component and not is_transformers_version(">", "4.47.1"):
        raise ValueError("Please upgrade your `transformers` installation to the latest version to use DDUF.")
@@ -29,10 +29,12 @@ import PIL.Image
import requests
import torch
from huggingface_hub import (
    DDUFEntry,
    ModelCard,
    create_repo,
    hf_hub_download,
    model_info,
    read_dduf_file,
    snapshot_download,
)
from huggingface_hub.utils import OfflineModeIsEnabled, validate_hf_hub_args
@@ -72,6 +74,7 @@ from .pipeline_loading_utils import (
    CONNECTED_PIPES_KEYS,
    CUSTOM_PIPELINE_FILE_NAME,
    LOADABLE_CLASSES,
    _download_dduf_file,
    _fetch_class_library_tuple,
    _get_custom_components_and_folders,
    _get_custom_pipeline_class,
@@ -79,6 +82,7 @@ from .pipeline_loading_utils import (
    _get_ignore_patterns,
    _get_pipeline_class,
    _identify_model_variants,
    _maybe_raise_error_for_incorrect_transformers,
    _maybe_raise_warning_for_inpainting,
    _resolve_custom_pipeline_and_cls,
    _unwrap_model,
@@ -218,6 +222,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).

            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
        """
@@ -531,6 +536,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                    - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
                      saved using
                      [`~DiffusionPipeline.save_pretrained`].
                    - A path to a *directory* (for example `./my_pipeline_directory/`) containing a DDUF file
            torch_dtype (`str` or `torch.dtype`, *optional*):
                Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
                dtype is automatically derived from the model's weights.
@@ -625,6 +631,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
            variant (`str`, *optional*):
                Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
                loading `from_flax`.
            dduf_file (`str`, *optional*):
                Load weights from the specified DDUF file.

        <Tip>
@@ -674,6 +682,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
        offload_state_dict = kwargs.pop("offload_state_dict", False)
        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
        variant = kwargs.pop("variant", None)
        dduf_file = kwargs.pop("dduf_file", None)
        use_safetensors = kwargs.pop("use_safetensors", None)
        use_onnx = kwargs.pop("use_onnx", None)
        load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
@@ -722,6 +731,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
            )

        if dduf_file:
            if custom_pipeline:
                raise NotImplementedError("Custom pipelines are not supported with DDUF at the moment.")
            if load_connected_pipeline:
                raise NotImplementedError("Connected pipelines are not supported with DDUF at the moment.")

        # 1. Download the checkpoints and configs
        # use snapshot download here to get it working from from_pretrained
        if not os.path.isdir(pretrained_model_name_or_path):
@@ -744,6 +759,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                custom_pipeline=custom_pipeline,
                custom_revision=custom_revision,
                variant=variant,
                dduf_file=dduf_file,
                load_connected_pipeline=load_connected_pipeline,
                **kwargs,
            )
@@ -765,7 +781,17 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
            )
            logger.warning(warn_msg)

-        config_dict = cls.load_config(cached_folder)
        dduf_entries = None
        if dduf_file:
            dduf_file_path = os.path.join(cached_folder, dduf_file)
            dduf_entries = read_dduf_file(dduf_file_path)
            # The reader already contains all the files needed, no need to check it again
            cached_folder = ""

        config_dict = cls.load_config(cached_folder, dduf_entries=dduf_entries)

        if dduf_file:
            _maybe_raise_error_for_incorrect_transformers(config_dict)

        # pop out "_ignore_files" as it is only needed for download
        config_dict.pop("_ignore_files", None)
@@ -943,6 +969,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                    low_cpu_mem_usage=low_cpu_mem_usage,
                    cached_folder=cached_folder,
                    use_safetensors=use_safetensors,
                    dduf_entries=dduf_entries,
                )
                logger.info(
                    f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
@@ -1256,6 +1283,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
            variant (`str`, *optional*):
                Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
                loading `from_flax`.
            dduf_file (`str`, *optional*):
                Load weights from the specified DDUF file.
            use_safetensors (`bool`, *optional*, defaults to `None`):
                If set to `None`, the safetensors weights are downloaded if they're available **and** if the
                safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
@@ -1296,6 +1325,23 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
        use_onnx = kwargs.pop("use_onnx", None)
        load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
        trust_remote_code = kwargs.pop("trust_remote_code", False)
        dduf_file: Optional[str] = kwargs.pop("dduf_file", None)

        if dduf_file:
            if custom_pipeline:
                raise NotImplementedError("Custom pipelines are not supported with DDUF at the moment.")
            if load_connected_pipeline:
                raise NotImplementedError("Connected pipelines are not supported with DDUF at the moment.")
            return _download_dduf_file(
                pretrained_model_name=pretrained_model_name,
                dduf_file=dduf_file,
                pipeline_class_name=cls.__name__,
                cache_dir=cache_dir,
                proxies=proxies,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
            )

        allow_pickle = False
        if use_safetensors is None:
@@ -1375,7 +1421,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else []
                # also allow downloading config.json files with the model
                allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]
-
                allow_patterns += [
                    SCHEDULER_CONFIG_NAME,
                    CONFIG_NAME,
@@ -1471,7 +1516,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                    user_agent=user_agent,
                )

-            # retrieve pipeline class from local file
            cls_name = cls.load_config(os.path.join(cached_folder, "model_index.json")).get("_class_name", None)
            cls_name = cls_name[4:] if isinstance(cls_name, str) and cls_name.startswith("Flax") else cls_name
......
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import os
import tempfile
from typing import TYPE_CHECKING, Dict

from huggingface_hub import DDUFEntry
from tqdm import tqdm

from ..utils import is_safetensors_available, is_transformers_available, is_transformers_version


if TYPE_CHECKING:
    from transformers import PreTrainedModel, PreTrainedTokenizer

if is_transformers_available():
    from transformers import PreTrainedModel, PreTrainedTokenizer

if is_safetensors_available():
    import safetensors.torch


def _load_tokenizer_from_dduf(
    cls: "PreTrainedTokenizer", name: str, dduf_entries: Dict[str, DDUFEntry], **kwargs
) -> "PreTrainedTokenizer":
    """
    Load a tokenizer from a DDUF archive.

    In practice, `transformers` does not provide a way to load a tokenizer from a DDUF archive. This function works
    around that by extracting the tokenizer files from the DDUF archive and loading the tokenizer from the extracted
    files. There is an extra cost of extracting the files, but of limited impact as the tokenizer files are usually
    small-ish.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        for entry_name, entry in dduf_entries.items():
            if entry_name.startswith(name + "/"):
                tmp_entry_path = os.path.join(tmp_dir, *entry_name.split("/"))
                # need to create intermediary directories if they don't exist
                os.makedirs(os.path.dirname(tmp_entry_path), exist_ok=True)
                with open(tmp_entry_path, "wb") as f:
                    with entry.as_mmap() as mm:
                        f.write(mm)
        return cls.from_pretrained(os.path.dirname(tmp_entry_path), **kwargs)


def _load_transformers_model_from_dduf(
    cls: "PreTrainedModel", name: str, dduf_entries: Dict[str, DDUFEntry], **kwargs
) -> "PreTrainedModel":
    """
    Load a transformers model from a DDUF archive.

    In practice, `transformers` does not provide a way to load a model from a DDUF archive. This function works around
    that by instantiating the model from its config file and loading the weights from the DDUF archive directly.
    """
    config_file = dduf_entries.get(f"{name}/config.json")
    if config_file is None:
        raise EnvironmentError(
            f"Could not find a config.json file for component {name} in DDUF file (contains {dduf_entries.keys()})."
        )
    generation_config = dduf_entries.get(f"{name}/generation_config.json", None)

    weight_files = [
        entry
        for entry_name, entry in dduf_entries.items()
        if entry_name.startswith(f"{name}/") and entry_name.endswith(".safetensors")
    ]
    if not weight_files:
        raise EnvironmentError(
            f"Could not find any weight file for component {name} in DDUF file (contains {dduf_entries.keys()})."
        )
    if not is_safetensors_available():
        raise EnvironmentError(
            "Safetensors is not available, cannot load model from DDUF. Please `pip install safetensors`."
        )
    if is_transformers_version("<", "4.47.0"):
        raise ImportError(
            "You need to install `transformers>4.47.0` in order to load a transformers model from a DDUF file. "
            "You can install it with: `pip install --upgrade transformers`"
        )

    with tempfile.TemporaryDirectory() as tmp_dir:
        from transformers import AutoConfig, GenerationConfig

        tmp_config_file = os.path.join(tmp_dir, "config.json")
        with open(tmp_config_file, "w") as f:
            f.write(config_file.read_text())
        config = AutoConfig.from_pretrained(tmp_config_file)

        if generation_config is not None:
            tmp_generation_config_file = os.path.join(tmp_dir, "generation_config.json")
            with open(tmp_generation_config_file, "w") as f:
                f.write(generation_config.read_text())
            generation_config = GenerationConfig.from_pretrained(tmp_generation_config_file)

        state_dict = {}
        with contextlib.ExitStack() as stack:
            for entry in tqdm(weight_files, desc="Loading state_dict"):  # Loop over safetensors files
                # Memory-map the safetensors file
                mmap = stack.enter_context(entry.as_mmap())
                # Load tensors from the memory-mapped file
                tensors = safetensors.torch.load(mmap)
                # Update the state dictionary with tensors
                state_dict.update(tensors)

        return cls.from_pretrained(
            pretrained_model_name_or_path=None,
            config=config,
            generation_config=generation_config,
            state_dict=state_dict,
            **kwargs,
        )
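To see the tokenizer workaround in isolation: every entry under the component's prefix is written out to a temporary directory, and the regular `from_pretrained` is pointed at it. A standalone sketch under the assumption that the archive has a `tokenizer` component:
```py
import os
import tempfile

from huggingface_hub import read_dduf_file
from transformers import AutoTokenizer

entries = read_dduf_file("FLUX.1-dev.dduf")  # hypothetical local archive

with tempfile.TemporaryDirectory() as tmp_dir:
    # Extract every "tokenizer/..." entry from the archive
    for entry_name, entry in entries.items():
        if entry_name.startswith("tokenizer/"):
            path = os.path.join(tmp_dir, *entry_name.split("/"))
            os.makedirs(os.path.dirname(path), exist_ok=True)
            with open(path, "wb") as f, entry.as_mmap() as mm:
                f.write(mm)
    tokenizer = AutoTokenizer.from_pretrained(os.path.join(tmp_dir, "tokenizer"))
```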
@@ -70,6 +70,7 @@ from .import_utils import (
    is_gguf_available,
    is_gguf_version,
    is_google_colab,
    is_hf_hub_version,
    is_inflect_available,
    is_invisible_watermark_available,
    is_k_diffusion_available,
......
@@ -26,6 +26,7 @@ from typing import Dict, List, Optional, Union
from uuid import uuid4

from huggingface_hub import (
    DDUFEntry,
    ModelCard,
    ModelCardData,
    create_repo,
@@ -291,9 +292,26 @@ def _get_model_file(
    user_agent: Optional[Union[Dict, str]] = None,
    revision: Optional[str] = None,
    commit_hash: Optional[str] = None,
    dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
):
    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
-    if os.path.isfile(pretrained_model_name_or_path):
    if dduf_entries:
        if subfolder is not None:
            raise ValueError(
                "DDUF file only allows for 1 level of directory (e.g. transformer/model1/model.safetensors is not allowed). "
                "Please check the DDUF structure"
            )
        model_file = (
            weights_name
            if pretrained_model_name_or_path == ""
            else "/".join([pretrained_model_name_or_path, weights_name])
        )
        if model_file in dduf_entries:
            return model_file
        else:
            raise EnvironmentError(f"No file named {weights_name} found in archive {dduf_entries.keys()}.")
    elif os.path.isfile(pretrained_model_name_or_path):
        return pretrained_model_name_or_path
    elif os.path.isdir(pretrained_model_name_or_path):
        if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
@@ -419,6 +437,7 @@ def _get_checkpoint_shard_files(
    user_agent=None,
    revision=None,
    subfolder="",
    dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
):
    """
    For a given model:
@@ -430,11 +449,18 @@ def _get_checkpoint_shard_files(
    For the description of each arg, see [`PreTrainedModel.from_pretrained`]. `index_filename` is the full path to the
    index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub).
    """
-    if not os.path.isfile(index_filename):
-        raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.")
    if dduf_entries:
        if index_filename not in dduf_entries:
            raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.")
    else:
        if not os.path.isfile(index_filename):
            raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.")

-    with open(index_filename, "r") as f:
-        index = json.loads(f.read())
    if dduf_entries:
        index = json.loads(dduf_entries[index_filename].read_text())
    else:
        with open(index_filename, "r") as f:
            index = json.loads(f.read())

    original_shard_filenames = sorted(set(index["weight_map"].values()))
    sharded_metadata = index["metadata"]
@@ -448,6 +474,8 @@ def _get_checkpoint_shard_files(
            pretrained_model_name_or_path, subfolder=subfolder, original_shard_filenames=original_shard_filenames
        )
        return shards_path, sharded_metadata
    elif dduf_entries:
        return shards_path, sharded_metadata

    # At this stage pretrained_model_name_or_path is a model identifier on the Hub
    allow_patterns = original_shard_filenames
......
@@ -115,6 +115,13 @@
except importlib_metadata.PackageNotFoundError:
    _transformers_available = False

_hf_hub_available = importlib.util.find_spec("huggingface_hub") is not None
try:
    _hf_hub_version = importlib_metadata.version("huggingface_hub")
    logger.debug(f"Successfully imported huggingface_hub version {_hf_hub_version}")
except importlib_metadata.PackageNotFoundError:
    _hf_hub_available = False

_inflect_available = importlib.util.find_spec("inflect") is not None
try:
@@ -767,6 +774,21 @@ def is_transformers_version(operation: str, version: str):
    return compare_versions(parse(_transformers_version), operation, version)


def is_hf_hub_version(operation: str, version: str):
    """
    Compares the current Hugging Face Hub version to a given reference with an operation.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A version string
    """
    if not _hf_hub_available:
        return False
    return compare_versions(parse(_hf_hub_version), operation, version)


def is_accelerate_version(operation: str, version: str):
    """
    Compares the current Accelerate version to a given reference with an operation.
......
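Like the other `is_*_version` helpers, the new guard can gate DDUF-only code paths at runtime; a hypothetical usage sketch:
```py
from diffusers.utils import is_hf_hub_version

# Reading DDUF archives needs a recent huggingface_hub release
if is_hf_hub_version("<", "0.27.0"):
    raise ImportError("Please upgrade huggingface_hub to load DDUF checkpoints.")
```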
@@ -478,6 +478,18 @@ def require_bitsandbytes_version_greater(bnb_version):
    return decorator


def require_hf_hub_version_greater(hf_hub_version):
    def decorator(test_case):
        correct_hf_hub_version = version.parse(
            version.parse(importlib.metadata.version("huggingface_hub")).base_version
        ) > version.parse(hf_hub_version)
        return unittest.skipUnless(
            correct_hf_hub_version, f"Test requires huggingface_hub with the version greater than {hf_hub_version}."
        )(test_case)

    return decorator


def require_gguf_version_greater_or_equal(gguf_version):
    def decorator(test_case):
        correct_gguf_version = is_gguf_available() and version.parse(
......
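A hypothetical test showing how the decorator is meant to be applied:
```py
import unittest

from diffusers.utils.testing_utils import require_hf_hub_version_greater


class DDUFGuardTests(unittest.TestCase):
    @require_hf_hub_version_greater("0.26.5")
    def test_dduf_reader_is_importable(self):
        # Skipped automatically on older huggingface_hub versions
        from huggingface_hub import read_dduf_file  # noqa: F401
```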
@@ -14,6 +14,8 @@
import gc
import inspect
import os
import tempfile
import unittest

import numpy as np
@@ -24,7 +26,9 @@ from diffusers import AllegroPipeline, AllegroTransformer3DModel, AutoencoderKLA
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    numpy_cosine_similarity_distance,
    require_hf_hub_version_greater,
    require_torch_gpu,
    require_transformers_version_greater,
    slow,
    torch_device,
)
@@ -297,6 +301,35 @@ class AllegroPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
            "VAE tiling should not affect the inference results",
        )

    @require_hf_hub_version_greater("0.26.5")
    @require_transformers_version_greater("4.47.1")
    def test_save_load_dduf(self):
        # reimplement because it needs `enable_tiling()` on the loaded pipe.
        from huggingface_hub import export_folder_as_dduf

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device="cpu")
        inputs.pop("generator")
        inputs["generator"] = torch.manual_seed(0)

        pipeline_out = pipe(**inputs)[0].cpu()

        with tempfile.TemporaryDirectory() as tmpdir:
            dduf_filename = os.path.join(tmpdir, f"{pipe.__class__.__name__.lower()}.dduf")
            pipe.save_pretrained(tmpdir, safe_serialization=True)
            export_folder_as_dduf(dduf_filename, folder_path=tmpdir)
            loaded_pipe = self.pipeline_class.from_pretrained(tmpdir, dduf_file=dduf_filename).to(torch_device)

        loaded_pipe.vae.enable_tiling()
        inputs["generator"] = torch.manual_seed(0)
        loaded_pipeline_out = loaded_pipe(**inputs)[0].cpu()

        assert np.allclose(pipeline_out, loaded_pipeline_out)

    @slow
    @require_torch_gpu
......
@@ -63,6 +63,8 @@ class AudioLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
        ]
    )

    supports_dduf = False

    def get_dummy_components(self):
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
......
@@ -70,6 +70,8 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
        ]
    )

    supports_dduf = False

    def get_dummy_components(self):
        torch.manual_seed(0)
        unet = AudioLDM2UNet2DConditionModel(
......
@@ -60,6 +60,8 @@ class BlipDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
        "prompt_reps",
    ]

    supports_dduf = False

    def get_dummy_components(self):
        torch.manual_seed(0)
        text_encoder_config = CLIPTextConfig(
......
@@ -291,6 +291,8 @@ class StableDiffusionMultiControlNetPipelineFastTests(
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    image_params = frozenset([])  # TO_DO: add image_params once refactored VaeImageProcessor.preprocess

    supports_dduf = False

    def get_dummy_components(self):
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
@@ -523,6 +525,8 @@ class StableDiffusionMultiControlNetOneModelPipelineFastTests(
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    image_params = frozenset([])  # TO_DO: add image_params once refactored VaeImageProcessor.preprocess

    supports_dduf = False

    def get_dummy_components(self):
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
......
@@ -68,6 +68,8 @@ class BlipDiffusionControlNetPipelineFastTests(PipelineTesterMixin, unittest.Tes
        "prompt_reps",
    ]

    supports_dduf = False

    def get_dummy_components(self):
        torch.manual_seed(0)
        text_encoder_config = CLIPTextConfig(
......
@@ -198,6 +198,8 @@ class StableDiffusionMultiControlNetPipelineFastTests(
    batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
    image_params = frozenset([])  # TO_DO: add image_params once refactored VaeImageProcessor.preprocess

    supports_dduf = False

    def get_dummy_components(self):
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
......