Unverified commit c9035275 authored by Aryan, committed by GitHub

Speedup model loading by 4-5x (#11904)



* update

* update

* update

* pin accelerate version

* add comment explanations

* update docstring

* make style

* non_blocking does not matter for dtype cast

* _empty_cache -> clear_cache

* update

* Update src/diffusers/models/model_loading_utils.py
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/diffusers/models/model_loading_utils.py

---------
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
parent 7a935a0b
@@ -24,6 +24,7 @@ from typing_extensions import Self
from .. import __version__
from ..quantizers import DiffusersAutoQuantizer
from ..utils import deprecate, is_accelerate_available, logging
from ..utils.torch_utils import device_synchronize, empty_device_cache
from .single_file_utils import (
SingleFileComponentError,
convert_animatediff_checkpoint_to_diffusers,
@@ -430,6 +431,10 @@ class FromOriginalModelMixin:
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,
)
# Ensure tensors are correctly placed on the device by synchronizing before returning control to the user. This is
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
empty_device_cache()
device_synchronize()
else:
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
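The empty_device_cache() / device_synchronize() pair added above is what makes the non_blocking transfers safe: the copies may still be in flight when load_model_dict_into_meta returns, so the loader synchronizes once before handing the model back. A minimal sketch of the underlying behaviour, assuming a CUDA device (the tensors below are illustrative and not part of this commit):

import torch

src = torch.randn(1024, 1024, pin_memory=True)  # pinned host memory enables a truly asynchronous copy
dst = src.to("cuda", non_blocking=True)          # copy is enqueued on the current stream and may still be in flight here
torch.cuda.synchronize()                         # single synchronization point; after this, dst is guaranteed to be populated
assert torch.equal(dst.cpu(), src)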
@@ -46,6 +46,7 @@ from ..utils import (
)
from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT
from ..utils.hub_utils import _get_model_file
from ..utils.torch_utils import device_synchronize, empty_device_cache
if is_transformers_available():
@@ -1689,6 +1690,10 @@ def create_diffusers_clip_model_from_ldm(
if is_accelerate_available():
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
# Ensure tensors are correctly placed on the device by synchronizing before returning control to the user. This is
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
empty_device_cache()
device_synchronize()
else:
model.load_state_dict(diffusers_format_checkpoint, strict=False)
@@ -2148,6 +2153,10 @@ def create_diffusers_t5_model_from_checkpoint(
if is_accelerate_available():
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
# Ensure tensors are correctly placed on the device by synchronizing before returning control to the user. This is
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
empty_device_cache()
device_synchronize()
else:
model.load_state_dict(diffusers_format_checkpoint)
@@ -18,11 +18,8 @@ from ..models.embeddings import (
MultiIPAdapterImageProjection,
)
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import (
is_accelerate_available,
is_torch_version,
logging,
)
from ..utils import is_accelerate_available, is_torch_version, logging
from ..utils.torch_utils import device_synchronize, empty_device_cache
if is_accelerate_available():
@@ -84,6 +81,8 @@ class FluxTransformer2DLoadersMixin:
else:
device_map = {"": self.device}
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
empty_device_cache()
device_synchronize()
return image_projection
@@ -158,6 +157,9 @@ class FluxTransformer2DLoadersMixin:
key_id += 1
empty_device_cache()
device_synchronize()
return attn_procs
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
@@ -18,6 +18,7 @@ from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
from ..models.embeddings import IPAdapterTimeImageProjection
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import is_accelerate_available, is_torch_version, logging
from ..utils.torch_utils import device_synchronize, empty_device_cache
logger = logging.get_logger(__name__)
@@ -80,6 +81,9 @@ class SD3Transformer2DLoadersMixin:
attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype
)
empty_device_cache()
device_synchronize()
return attn_procs
def _convert_ip_adapter_image_proj_to_diffusers(
@@ -147,6 +151,8 @@ class SD3Transformer2DLoadersMixin:
else:
device_map = {"": self.device}
load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
empty_device_cache()
device_synchronize()
return image_proj
@@ -43,6 +43,7 @@ from ..utils import (
is_torch_version,
logging,
)
from ..utils.torch_utils import device_synchronize, empty_device_cache
from .lora_base import _func_optionally_disable_offloading
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
from .utils import AttnProcsLayers
@@ -753,6 +754,8 @@ class UNet2DConditionLoadersMixin:
else:
device_map = {"": self.device}
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
empty_device_cache()
device_synchronize()
return image_projection
@@ -850,6 +853,9 @@ class UNet2DConditionLoadersMixin:
key_id += 2
empty_device_cache()
device_synchronize()
return attn_procs
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
@@ -16,9 +16,10 @@
import importlib
import inspect
import math
import os
from array import array
from collections import OrderedDict
from collections import OrderedDict, defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Union
from zipfile import is_zipfile
@@ -38,6 +39,7 @@ from ..utils import (
_get_model_file,
deprecate,
is_accelerate_available,
is_accelerate_version,
is_gguf_available,
is_torch_available,
is_torch_version,
@@ -252,6 +254,10 @@ def load_model_dict_into_meta(
param = param.to(dtype)
set_module_kwargs["dtype"] = dtype
if is_accelerate_version(">", "1.8.1"):
set_module_kwargs["non_blocking"] = True
set_module_kwargs["clear_cache"] = False
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
# uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
@@ -520,3 +526,60 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights
return parsed_parameters
def _find_mismatched_keys(state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes):
mismatched_keys = []
if not ignore_mismatched_sizes:
return mismatched_keys
for checkpoint_key in loaded_keys:
model_key = checkpoint_key
# If the checkpoint is sharded, we may not have the key here.
if checkpoint_key not in state_dict:
continue
if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
mismatched_keys.append(
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
)
del state_dict[checkpoint_key]
return mismatched_keys
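An illustrative call of the helper above (toy module and tensors, not part of this commit): with ignore_mismatched_sizes=True, a checkpoint entry whose shape disagrees with the model is recorded and removed from the state dict before loading.

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
ckpt = {"weight": torch.zeros(3, 4), "bias": torch.zeros(2)}  # "weight" has the wrong shape for the model
mismatched = _find_mismatched_keys(ckpt, model.state_dict(), list(ckpt), ignore_mismatched_sizes=True)
print(mismatched)  # [('weight', torch.Size([3, 4]), torch.Size([2, 4]))]; "weight" is now gone from ckpt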
def _expand_device_map(device_map, param_names):
"""
Expand a device map to return the mapping from parameter name to device.
"""
new_device_map = {}
for module, device in device_map.items():
new_device_map.update(
{p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
)
return new_device_map
# Adapted from: https://github.com/huggingface/transformers/blob/0687d481e2c71544501ef9cb3eef795a6e79b1de/src/transformers/modeling_utils.py#L5859
def _caching_allocator_warmup(model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype) -> None:
"""
This function warms up the caching allocator based on the size of the model tensors that will reside on each
device. It allows making one large allocation up front instead of many smaller ones later while loading the
model, which is the actual loading-speed bottleneck. Calling this function can cut model loading time by a very
large margin.
"""
# Remove disk and cpu devices, and cast to proper torch.device
accelerator_device_map = {
param: torch.device(device)
for param, device in expanded_device_map.items()
if str(device) not in ["cpu", "disk"]
}
parameter_count = defaultdict(lambda: 0)
for param_name, device in accelerator_device_map.items():
try:
param = model.get_parameter(param_name)
except AttributeError:
param = model.get_buffer(param_name)
parameter_count[device] += math.prod(param.shape)
# This will kick off the caching allocator to avoid having to Malloc afterwards
for device, param_count in parameter_count.items():
_ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False)
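A hedged usage sketch of the two helpers above (toy model, single CUDA device; the variable names are illustrative and not part of this commit):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024))
device_map = {"": "cuda:0"}  # place every parameter on one GPU
param_names = [name for name, _ in model.named_parameters()]
expanded = _expand_device_map(device_map, param_names)     # {"0.weight": "cuda:0", "0.bias": "cuda:0", ...}
_caching_allocator_warmup(model, expanded, torch.float16)  # one large torch.empty pre-grows the allocator pool

The temporary tensor allocated by the warmup is released immediately, but the caching allocator keeps the freed block, so the subsequent per-parameter allocations are served from that pool instead of each triggering a device malloc.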
@@ -62,10 +62,14 @@ from ..utils.hub_utils import (
load_or_create_model_card,
populate_model_card,
)
from ..utils.torch_utils import device_synchronize, empty_device_cache
from .model_loading_utils import (
_caching_allocator_warmup,
_determine_device_map,
_expand_device_map,
_fetch_index_file,
_fetch_index_file_legacy,
_find_mismatched_keys,
_load_state_dict_into_model,
load_model_dict_into_meta,
load_state_dict,
@@ -1469,11 +1473,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
for pat in cls._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
mismatched_keys = []
assign_to_params_buffers = None
error_msgs = []
# Deal with offload
if device_map is not None and "disk" in device_map.values():
if offload_folder is None:
@@ -1482,18 +1481,27 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
" for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
" offers the weights in this format."
)
if offload_folder is not None:
else:
os.makedirs(offload_folder, exist_ok=True)
if offload_state_dict is None:
offload_state_dict = True
# If a device map has been used, we can speed up the load time by warming up the device caching allocator.
# Without the warmup, each tensor allocation on the device makes its own call into the allocator (effectively,
# a lot of individual calls to device malloc). Instead, we can preallocate the memory required by the tensors
# using their expected shapes, without initializing the memory (empty data). When the actual device allocations
# happen, the allocator already has a pool of unused device memory that it can re-use, making model loading
# faster.
# TODO: add support for warmup with hf_quantizer
if device_map is not None and hf_quantizer is None:
expanded_device_map = _expand_device_map(device_map, expected_keys)
_caching_allocator_warmup(model, expanded_device_map, dtype)
offload_index = {} if device_map is not None and "disk" in device_map.values() else None
state_dict_folder, state_dict_index = None, None
if offload_state_dict:
state_dict_folder = tempfile.mkdtemp()
state_dict_index = {}
else:
state_dict_folder = None
state_dict_index = None
if state_dict is not None:
# load_state_dict will manage the case where we pass a dict instead of a file
@@ -1503,38 +1511,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
if len(resolved_model_file) > 1:
resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
mismatched_keys = []
assign_to_params_buffers = None
error_msgs = []
for shard_file in resolved_model_file:
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
def _find_mismatched_keys(
state_dict,
model_state_dict,
loaded_keys,
ignore_mismatched_sizes,
):
mismatched_keys = []
if ignore_mismatched_sizes:
for checkpoint_key in loaded_keys:
model_key = checkpoint_key
# If the checkpoint is sharded, we may not have the key here.
if checkpoint_key not in state_dict:
continue
if (
model_key in model_state_dict
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
):
mismatched_keys.append(
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
)
del state_dict[checkpoint_key]
return mismatched_keys
mismatched_keys += _find_mismatched_keys(
state_dict,
model_state_dict,
loaded_keys,
ignore_mismatched_sizes,
state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes
)
if low_cpu_mem_usage:
@@ -1554,9 +1538,13 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
else:
if assign_to_params_buffers is None:
assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
# Ensure tensors are correctly placed on the device by synchronizing before returning control to the user. This is
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
empty_device_cache()
device_synchronize()
if offload_index is not None and len(offload_index) > 0:
save_offload_index(offload_index, offload_folder)
offload_index = None
@@ -184,5 +184,14 @@ def get_device():
def empty_device_cache(device_type: Optional[str] = None):
if device_type is None:
device_type = get_device()
if device_type in ["cpu"]:
return
device_mod = getattr(torch, device_type, torch.cuda)
device_mod.empty_cache()
def device_synchronize(device_type: Optional[str] = None):
if device_type is None:
device_type = get_device()
device_mod = getattr(torch, device_type, torch.cuda)
device_mod.synchronize()
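Both helpers dispatch on the detected device type via getattr, falling back to torch.cuda for unknown types. A quick illustrative check (the device strings are assumptions; which backend modules exist depends on the local PyTorch build):

import torch

for device_type in ("cuda", "mps", "xpu"):
    device_mod = getattr(torch, device_type, torch.cuda)  # unknown or missing backends resolve to torch.cuda
    print(device_type, "->", device_mod.__name__)  # e.g. "cuda -> torch.cuda"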