Unverified Commit 52c05bd4 authored by Daniel Hipke's avatar Daniel Hipke Committed by GitHub
Browse files

Add a `disable_mmap` option to the `from_single_file` loader to improve load...


Add a `disable_mmap` option to the `from_single_file` loader to improve load performance on network mounts (#10305)

* Add no_mmap arg.

* Fix arg parsing.

* Update another method to force no mmap.

* logging

* logging2

* propagate no_mmap

* logging3

* propagate no_mmap

* logging4

* fix open call

* clean up logging

* cleanup

* fix missing arg

* update logging and comments

* Rename to disable_mmap and update other references.

* [Docs] Update ltx_video.md to remove generator from `from_pretrained()` (#10316)

Update ltx_video.md to remove generator from `from_pretrained()`

* docs: fix a mistake in docstring (#10319)

Update pipeline_hunyuan_video.py

docs: fix a mistake

* [BUG FIX] [Stable Audio Pipeline] Resolve torch.Tensor.new_zeros() TypeError in function prepare_latents caused by audio_vae_length (#10306)

[BUG FIX] [Stable Audio Pipeline] TypeError: new_zeros(): argument 'size' failed to unpack the object at pos 3 with error "type must be tuple of ints, but got float"

torch.Tensor.new_zeros() takes a single argument size (int...) – a list, tuple, or torch.Size of integers defining the shape of the output tensor.

in function prepare_latents:
audio_vae_length = self.transformer.config.sample_size * self.vae.hop_length
audio_shape = (batch_size // num_waveforms_per_prompt, audio_channels, audio_vae_length)
...
audio = initial_audio_waveforms.new_zeros(audio_shape)

audio_vae_length evaluates to float because self.transformer.config.sample_size returns a float
Co-authored-by: hlky <hlky@hlky.ac>

* [docs] Fix quantization links (#10323)

Update overview.md

* [Sana]add 2K related model for Sana (#10322)

add 2K related model for Sana

* Update src/diffusers/loaders/single_file_model.py
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* Update src/diffusers/loaders/single_file.py
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* make style

---------
Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Leojc <liao_junchao@outlook.com>
Co-authored-by: Aditya Raj <syntaxticsugr@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: Junsong Chen <cjs1020440147@icloud.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
parent a6f043a8
...@@ -60,6 +60,7 @@ def load_single_file_sub_model( ...@@ -60,6 +60,7 @@ def load_single_file_sub_model(
local_files_only=False, local_files_only=False,
torch_dtype=None, torch_dtype=None,
is_legacy_loading=False, is_legacy_loading=False,
disable_mmap=False,
**kwargs, **kwargs,
): ):
if is_pipeline_module: if is_pipeline_module:
...@@ -106,6 +107,7 @@ def load_single_file_sub_model( ...@@ -106,6 +107,7 @@ def load_single_file_sub_model(
subfolder=name, subfolder=name,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
local_files_only=local_files_only, local_files_only=local_files_only,
disable_mmap=disable_mmap,
**kwargs, **kwargs,
) )
...@@ -308,6 +310,9 @@ class FromSingleFileMixin: ...@@ -308,6 +310,9 @@ class FromSingleFileMixin:
hosted on the Hub. hosted on the Hub.
- A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline - A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline
component configs in Diffusers format. component configs in Diffusers format.
disable_mmap ('bool', *optional*, defaults to 'False'):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive.
kwargs (remaining dictionary of keyword arguments, *optional*): kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
class). The overwritten components are passed directly to the pipelines `__init__` method. See example class). The overwritten components are passed directly to the pipelines `__init__` method. See example
...@@ -355,6 +360,7 @@ class FromSingleFileMixin: ...@@ -355,6 +360,7 @@ class FromSingleFileMixin:
local_files_only = kwargs.pop("local_files_only", False) local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None) revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None) torch_dtype = kwargs.pop("torch_dtype", None)
disable_mmap = kwargs.pop("disable_mmap", False)
is_legacy_loading = False is_legacy_loading = False
...@@ -383,6 +389,7 @@ class FromSingleFileMixin: ...@@ -383,6 +389,7 @@ class FromSingleFileMixin:
cache_dir=cache_dir, cache_dir=cache_dir,
local_files_only=local_files_only, local_files_only=local_files_only,
revision=revision, revision=revision,
disable_mmap=disable_mmap,
) )
if config is None: if config is None:
...@@ -504,6 +511,7 @@ class FromSingleFileMixin: ...@@ -504,6 +511,7 @@ class FromSingleFileMixin:
original_config=original_config, original_config=original_config,
local_files_only=local_files_only, local_files_only=local_files_only,
is_legacy_loading=is_legacy_loading, is_legacy_loading=is_legacy_loading,
disable_mmap=disable_mmap,
**kwargs, **kwargs,
) )
except SingleFileComponentError as e: except SingleFileComponentError as e:
......
...@@ -187,6 +187,9 @@ class FromOriginalModelMixin: ...@@ -187,6 +187,9 @@ class FromOriginalModelMixin:
revision (`str`, *optional*, defaults to `"main"`): revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git. allowed by Git.
disable_mmap ('bool', *optional*, defaults to 'False'):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
kwargs (remaining dictionary of keyword arguments, *optional*): kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (for example the pipeline components of the Can be used to overwrite load and saveable variables (for example the pipeline components of the
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__` specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
...@@ -234,6 +237,7 @@ class FromOriginalModelMixin: ...@@ -234,6 +237,7 @@ class FromOriginalModelMixin:
torch_dtype = kwargs.pop("torch_dtype", None) torch_dtype = kwargs.pop("torch_dtype", None)
quantization_config = kwargs.pop("quantization_config", None) quantization_config = kwargs.pop("quantization_config", None)
device = kwargs.pop("device", None) device = kwargs.pop("device", None)
disable_mmap = kwargs.pop("disable_mmap", False)
if isinstance(pretrained_model_link_or_path_or_dict, dict): if isinstance(pretrained_model_link_or_path_or_dict, dict):
checkpoint = pretrained_model_link_or_path_or_dict checkpoint = pretrained_model_link_or_path_or_dict
...@@ -246,6 +250,7 @@ class FromOriginalModelMixin: ...@@ -246,6 +250,7 @@ class FromOriginalModelMixin:
cache_dir=cache_dir, cache_dir=cache_dir,
local_files_only=local_files_only, local_files_only=local_files_only,
revision=revision, revision=revision,
disable_mmap=disable_mmap,
) )
if quantization_config is not None: if quantization_config is not None:
hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config) hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
......
...@@ -387,6 +387,7 @@ def load_single_file_checkpoint( ...@@ -387,6 +387,7 @@ def load_single_file_checkpoint(
cache_dir=None, cache_dir=None,
local_files_only=None, local_files_only=None,
revision=None, revision=None,
disable_mmap=False,
): ):
if os.path.isfile(pretrained_model_link_or_path): if os.path.isfile(pretrained_model_link_or_path):
pretrained_model_link_or_path = pretrained_model_link_or_path pretrained_model_link_or_path = pretrained_model_link_or_path
...@@ -404,7 +405,7 @@ def load_single_file_checkpoint( ...@@ -404,7 +405,7 @@ def load_single_file_checkpoint(
revision=revision, revision=revision,
) )
checkpoint = load_state_dict(pretrained_model_link_or_path) checkpoint = load_state_dict(pretrained_model_link_or_path, disable_mmap=disable_mmap)
# some checkpoints contain the model state dict under a "state_dict" key # some checkpoints contain the model state dict under a "state_dict" key
while "state_dict" in checkpoint: while "state_dict" in checkpoint:
......
...@@ -131,7 +131,9 @@ def _fetch_remapped_cls_from_config(config, old_class): ...@@ -131,7 +131,9 @@ def _fetch_remapped_cls_from_config(config, old_class):
return old_class return old_class
def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None): def load_state_dict(
checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None, disable_mmap: bool = False
):
""" """
Reads a checkpoint file, returning properly formatted errors if they arise. Reads a checkpoint file, returning properly formatted errors if they arise.
""" """
...@@ -142,7 +144,10 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[ ...@@ -142,7 +144,10 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
try: try:
file_extension = os.path.basename(checkpoint_file).split(".")[-1] file_extension = os.path.basename(checkpoint_file).split(".")[-1]
if file_extension == SAFETENSORS_FILE_EXTENSION: if file_extension == SAFETENSORS_FILE_EXTENSION:
return safetensors.torch.load_file(checkpoint_file, device="cpu") if disable_mmap:
return safetensors.torch.load(open(checkpoint_file, "rb").read())
else:
return safetensors.torch.load_file(checkpoint_file, device="cpu")
elif file_extension == GGUF_FILE_EXTENSION: elif file_extension == GGUF_FILE_EXTENSION:
return load_gguf_checkpoint(checkpoint_file) return load_gguf_checkpoint(checkpoint_file)
else: else:
......
...@@ -559,6 +559,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): ...@@ -559,6 +559,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
`safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors` `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
weights. If set to `False`, `safetensors` weights are not loaded. weights. If set to `False`, `safetensors` weights are not loaded.
disable_mmap ('bool', *optional*, defaults to 'False'):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
<Tip> <Tip>
...@@ -604,6 +607,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): ...@@ -604,6 +607,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
variant = kwargs.pop("variant", None) variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None) use_safetensors = kwargs.pop("use_safetensors", None)
quantization_config = kwargs.pop("quantization_config", None) quantization_config = kwargs.pop("quantization_config", None)
disable_mmap = kwargs.pop("disable_mmap", False)
allow_pickle = False allow_pickle = False
if use_safetensors is None: if use_safetensors is None:
...@@ -883,7 +887,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): ...@@ -883,7 +887,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
# TODO (sayakpaul, SunMarc): remove this after model loading refactor # TODO (sayakpaul, SunMarc): remove this after model loading refactor
else: else:
param_device = torch.device(torch.cuda.current_device()) param_device = torch.device(torch.cuda.current_device())
state_dict = load_state_dict(model_file, variant=variant) state_dict = load_state_dict(model_file, variant=variant, disable_mmap=disable_mmap)
model._convert_deprecated_attention_blocks(state_dict) model._convert_deprecated_attention_blocks(state_dict)
# move the params from meta device to cpu # move the params from meta device to cpu
...@@ -979,7 +983,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): ...@@ -979,7 +983,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
else: else:
model = cls.from_config(config, **unused_kwargs) model = cls.from_config(config, **unused_kwargs)
state_dict = load_state_dict(model_file, variant=variant) state_dict = load_state_dict(model_file, variant=variant, disable_mmap=disable_mmap)
model._convert_deprecated_attention_blocks(state_dict) model._convert_deprecated_attention_blocks(state_dict)
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment