Unverified Commit baa9b582 authored by Sayak Paul, committed by GitHub

[core] parallel loading of shards (#12028)



* checking.

* checking

* checking

* up

* up

* up

* Apply suggestions from code review
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* up

* up

* fix

* review feedback.

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
parent da096a49
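Usage note: a minimal sketch of how the feature added in this commit is opted into, based on the `HF_ENABLE_PARALLEL_LOADING` flag and the `low_cpu_mem_usage` requirement introduced below. The model class and checkpoint id are illustrative placeholders, not part of the diff; any sharded diffusers-format checkpoint works, and `low_cpu_mem_usage` defaults to `True` when `accelerate` is installed.

import os

import torch

# Opt in before calling from_pretrained; with the flag set but
# low_cpu_mem_usage=False, from_pretrained raises NotImplementedError.
os.environ["HF_ENABLE_PARALLEL_LOADING"] = "yes"

from diffusers import FluxTransformer2DModel

model = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # placeholder: any sharded checkpoint repo
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)

Shards are then read by up to `DEFAULT_HF_PARALLEL_LOADING_WORKERS` (8) threads, capped at the number of shard files.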
@@ -62,7 +62,7 @@ logger = logging.get_logger(__name__)
 if is_accelerate_available():
     from accelerate import dispatch_model, init_empty_weights
 
-    from ..models.modeling_utils import load_model_dict_into_meta
+    from ..models.model_loading_utils import load_model_dict_into_meta
 
 if is_torch_version(">=", "1.9.0") and is_accelerate_available():
     _LOW_CPU_MEM_USAGE_DEFAULT = True
......
@@ -55,7 +55,7 @@ if is_transformers_available():
 if is_accelerate_available():
     from accelerate import init_empty_weights
 
-    from ..models.modeling_utils import load_model_dict_into_meta
+    from ..models.model_loading_utils import load_model_dict_into_meta
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
......
@@ -17,7 +17,8 @@ from ..models.embeddings import (
     ImageProjection,
     MultiIPAdapterImageProjection,
 )
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
+from ..models.model_loading_utils import load_model_dict_into_meta
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
 from ..utils import is_accelerate_available, is_torch_version, logging
 from ..utils.torch_utils import empty_device_cache
......
@@ -16,7 +16,8 @@ from typing import Dict
 from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
 from ..models.embeddings import IPAdapterTimeImageProjection
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
+from ..models.model_loading_utils import load_model_dict_into_meta
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
 from ..utils import is_accelerate_available, is_torch_version, logging
 from ..utils.torch_utils import empty_device_cache
......
@@ -30,7 +30,8 @@ from ..models.embeddings import (
     IPAdapterPlusImageProjection,
     MultiIPAdapterImageProjection,
 )
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
+from ..models.model_loading_utils import load_model_dict_into_meta
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
 from ..utils import (
     USE_PEFT_BACKEND,
     _get_model_file,
......
@@ -14,12 +14,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import functools
 import importlib
 import inspect
 import math
 import os
 from array import array
 from collections import OrderedDict, defaultdict
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from pathlib import Path
 from typing import Dict, List, Optional, Union
 from zipfile import is_zipfile
@@ -31,6 +33,7 @@ from huggingface_hub.utils import EntryNotFoundError
 from ..quantizers import DiffusersQuantizer
 from ..utils import (
+    DEFAULT_HF_PARALLEL_LOADING_WORKERS,
     GGUF_FILE_EXTENSION,
     SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_FILE_EXTENSION,
@@ -310,6 +313,161 @@ def load_model_dict_into_meta(
     return offload_index, state_dict_index
 
 
+def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
+    """
+    Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights) by first
+    checking if the model explicitly disables it, then by ensuring that the state dict keys are a subset of the model's
+    parameters.
+    """
+    if model_to_load.device.type == "meta":
+        return False
+
+    if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
+        return False
+
+    # Some models explicitly do not support param buffer assignment
+    if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
+        logger.debug(
+            f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
+        )
+        return False
+
+    # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
+    first_key = next(iter(model_to_load.state_dict().keys()))
+    if start_prefix + first_key in state_dict:
+        return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype
+
+    return False
+
+
+def _load_shard_file(
+    shard_file,
+    model,
+    model_state_dict,
+    device_map=None,
+    dtype=None,
+    hf_quantizer=None,
+    keep_in_fp32_modules=None,
+    dduf_entries=None,
+    loaded_keys=None,
+    unexpected_keys=None,
+    offload_index=None,
+    offload_folder=None,
+    state_dict_index=None,
+    state_dict_folder=None,
+    ignore_mismatched_sizes=False,
+    low_cpu_mem_usage=False,
+):
+    state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
+    mismatched_keys = _find_mismatched_keys(
+        state_dict,
+        model_state_dict,
+        loaded_keys,
+        ignore_mismatched_sizes,
+    )
+    error_msgs = []
+    if low_cpu_mem_usage:
+        offload_index, state_dict_index = load_model_dict_into_meta(
+            model,
+            state_dict,
+            device_map=device_map,
+            dtype=dtype,
+            hf_quantizer=hf_quantizer,
+            keep_in_fp32_modules=keep_in_fp32_modules,
+            unexpected_keys=unexpected_keys,
+            offload_folder=offload_folder,
+            offload_index=offload_index,
+            state_dict_index=state_dict_index,
+            state_dict_folder=state_dict_folder,
+        )
+    else:
+        assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
+        error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
+    return offload_index, state_dict_index, mismatched_keys, error_msgs
+
+
+def _load_shard_files_with_threadpool(
+    shard_files,
+    model,
+    model_state_dict,
+    device_map=None,
+    dtype=None,
+    hf_quantizer=None,
+    keep_in_fp32_modules=None,
+    dduf_entries=None,
+    loaded_keys=None,
+    unexpected_keys=None,
+    offload_index=None,
+    offload_folder=None,
+    state_dict_index=None,
+    state_dict_folder=None,
+    ignore_mismatched_sizes=False,
+    low_cpu_mem_usage=False,
+):
+    # Do not spawn anymore workers than you need
+    num_workers = min(len(shard_files), DEFAULT_HF_PARALLEL_LOADING_WORKERS)
+
+    logger.info(f"Loading model weights in parallel with {num_workers} workers...")
+
+    error_msgs = []
+    mismatched_keys = []
+
+    load_one = functools.partial(
+        _load_shard_file,
+        model=model,
+        model_state_dict=model_state_dict,
+        device_map=device_map,
+        dtype=dtype,
+        hf_quantizer=hf_quantizer,
+        keep_in_fp32_modules=keep_in_fp32_modules,
+        dduf_entries=dduf_entries,
+        loaded_keys=loaded_keys,
+        unexpected_keys=unexpected_keys,
+        offload_index=offload_index,
+        offload_folder=offload_folder,
+        state_dict_index=state_dict_index,
+        state_dict_folder=state_dict_folder,
+        ignore_mismatched_sizes=ignore_mismatched_sizes,
+        low_cpu_mem_usage=low_cpu_mem_usage,
+    )
+
+    with ThreadPoolExecutor(max_workers=num_workers) as executor:
+        with logging.tqdm(total=len(shard_files), desc="Loading checkpoint shards") as pbar:
+            futures = [executor.submit(load_one, shard_file) for shard_file in shard_files]
+            for future in as_completed(futures):
+                result = future.result()
+                offload_index, state_dict_index, _mismatched_keys, _error_msgs = result
+                error_msgs += _error_msgs
+                mismatched_keys += _mismatched_keys
+                pbar.update(1)
+
+    return offload_index, state_dict_index, mismatched_keys, error_msgs
+
+
+def _find_mismatched_keys(
+    state_dict,
+    model_state_dict,
+    loaded_keys,
+    ignore_mismatched_sizes,
+):
+    mismatched_keys = []
+    if ignore_mismatched_sizes:
+        for checkpoint_key in loaded_keys:
+            model_key = checkpoint_key
+            # If the checkpoint is sharded, we may not have the key here.
+            if checkpoint_key not in state_dict:
+                continue
+
+            if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
+                mismatched_keys.append(
+                    (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+                )
+                del state_dict[checkpoint_key]
+    return mismatched_keys
+
+
 def _load_state_dict_into_model(
     model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
 ) -> List[str]:
......
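For readers skimming the hunk above: the new parallel loader boils down to a small fan-out pattern — freeze the shared keyword arguments with `functools.partial`, submit one job per shard to a `ThreadPoolExecutor`, and merge results via `as_completed`. A self-contained sketch of that pattern, with a made-up `load_one_shard` standing in for `_load_shard_file`:

import functools
from concurrent.futures import ThreadPoolExecutor, as_completed


def load_one_shard(shard_file, target, strict=False):
    # Stand-in for `_load_shard_file`: read one shard, merge it into `target`,
    # and return any per-shard error messages (none in this toy version).
    target[f"{shard_file}:dummy_key"] = 0  # placeholder for real tensor loading
    return []


def load_all_shards(shard_files, target, max_workers=8):
    # Freeze the arguments shared by every shard, mirroring the diff's use of
    # functools.partial(...) before jobs are submitted to the pool.
    load_one = functools.partial(load_one_shard, target=target, strict=False)
    num_workers = min(len(shard_files), max_workers)  # never more workers than shards
    error_msgs = []
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(load_one, shard) for shard in shard_files]
        for future in as_completed(futures):  # collect results as shards finish
            error_msgs += future.result()
    return error_msgs


state = {}
print(load_all_shards(["shard-00001", "shard-00002", "shard-00003"], state))
print(sorted(state))  # one merged entry per "shard"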
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import copy
+import functools
 import inspect
 import itertools
 import json
@@ -41,7 +42,9 @@ from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer
 from ..quantizers.quantization_config import QuantizationMethod
 from ..utils import (
     CONFIG_NAME,
+    ENV_VARS_TRUE_VALUES,
     FLAX_WEIGHTS_NAME,
+    HF_PARALLEL_LOADING_FLAG,
     SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_WEIGHTS_NAME,
     WEIGHTS_INDEX_NAME,
@@ -69,9 +72,8 @@ from .model_loading_utils import (
     _expand_device_map,
     _fetch_index_file,
     _fetch_index_file_legacy,
-    _find_mismatched_keys,
-    _load_state_dict_into_model,
-    load_model_dict_into_meta,
+    _load_shard_file,
+    _load_shard_files_with_threadpool,
     load_state_dict,
 )
@@ -208,34 +210,6 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
     return last_tuple[1].dtype
 
 
-def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
-    """
-    Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights) by first
-    checking if the model explicitly disables it, then by ensuring that the state dict keys are a subset of the model's
-    parameters.
-    """
-    if model_to_load.device.type == "meta":
-        return False
-
-    if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
-        return False
-
-    # Some models explicitly do not support param buffer assignment
-    if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
-        logger.debug(
-            f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
-        )
-        return False
-
-    # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
-    first_key = next(iter(model_to_load.state_dict().keys()))
-    if start_prefix + first_key in state_dict:
-        return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype
-
-    return False
-
-
 @contextmanager
 def no_init_weights():
     """
@@ -988,6 +962,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
         disable_mmap = kwargs.pop("disable_mmap", False)
 
+        is_parallel_loading_enabled = os.environ.get(HF_PARALLEL_LOADING_FLAG, "").upper() in ENV_VARS_TRUE_VALUES
+        if is_parallel_loading_enabled and not low_cpu_mem_usage:
+            raise NotImplementedError("Parallel loading is not supported when not using `low_cpu_mem_usage`.")
+
         if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
             torch_dtype = torch.float32
             logger.warning(
@@ -1323,6 +1301,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             hf_quantizer=hf_quantizer,
             keep_in_fp32_modules=keep_in_fp32_modules,
             dduf_entries=dduf_entries,
+            is_parallel_loading_enabled=is_parallel_loading_enabled,
         )
         loading_info = {
             "missing_keys": missing_keys,
@@ -1518,6 +1497,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         offload_state_dict: Optional[bool] = None,
         offload_folder: Optional[Union[str, os.PathLike]] = None,
         dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
+        is_parallel_loading_enabled: Optional[bool] = False,
     ):
         model_state_dict = model.state_dict()
         expected_keys = list(model_state_dict.keys())
@@ -1531,6 +1511,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             for pat in cls._keys_to_ignore_on_load_unexpected:
                 unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
 
+        mismatched_keys = []
+        error_msgs = []
+
         # Deal with offload
         if device_map is not None and "disk" in device_map.values():
             if offload_folder is None:
@@ -1566,37 +1549,39 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             # if state dict is not None, it means that we don't need to read the files from resolved_model_file also
             resolved_model_file = [state_dict]
 
-        if len(resolved_model_file) > 1:
-            resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
-
-        mismatched_keys = []
-        assign_to_params_buffers = None
-        error_msgs = []
-
-        for shard_file in resolved_model_file:
-            state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
-            mismatched_keys += _find_mismatched_keys(
-                state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes
-            )
-
-            if low_cpu_mem_usage:
-                offload_index, state_dict_index = load_model_dict_into_meta(
-                    model,
-                    state_dict,
-                    device_map=device_map,
-                    dtype=dtype,
-                    hf_quantizer=hf_quantizer,
-                    keep_in_fp32_modules=keep_in_fp32_modules,
-                    unexpected_keys=unexpected_keys,
-                    offload_folder=offload_folder,
-                    offload_index=offload_index,
-                    state_dict_index=state_dict_index,
-                    state_dict_folder=state_dict_folder,
-                )
-            else:
-                if assign_to_params_buffers is None:
-                    assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
-
-                error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
+        # Prepare the loading function sharing the attributes shared between them.
+        load_fn = functools.partial(
+            _load_shard_files_with_threadpool if is_parallel_loading_enabled else _load_shard_file,
+            model=model,
+            model_state_dict=model_state_dict,
+            device_map=device_map,
+            dtype=dtype,
+            hf_quantizer=hf_quantizer,
+            keep_in_fp32_modules=keep_in_fp32_modules,
+            dduf_entries=dduf_entries,
+            loaded_keys=loaded_keys,
+            unexpected_keys=unexpected_keys,
+            offload_index=offload_index,
+            offload_folder=offload_folder,
+            state_dict_index=state_dict_index,
+            state_dict_folder=state_dict_folder,
+            ignore_mismatched_sizes=ignore_mismatched_sizes,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
+
+        if is_parallel_loading_enabled:
+            offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_fn(resolved_model_file)
+            error_msgs += _error_msgs
+            mismatched_keys += _mismatched_keys
+        else:
+            shard_files = resolved_model_file
+            if len(resolved_model_file) > 1:
+                shard_files = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
+
+            for shard_file in shard_files:
+                offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_fn(shard_file)
+                error_msgs += _error_msgs
+                mismatched_keys += _mismatched_keys
 
         empty_device_cache()
......
@@ -20,11 +20,13 @@ from packaging import version
 from .. import __version__
 from .constants import (
     CONFIG_NAME,
+    DEFAULT_HF_PARALLEL_LOADING_WORKERS,
     DEPRECATED_REVISION_ARGS,
     DIFFUSERS_DYNAMIC_MODULE_NAME,
     FLAX_WEIGHTS_NAME,
     GGUF_FILE_EXTENSION,
     HF_MODULES_CACHE,
+    HF_PARALLEL_LOADING_FLAG,
     HUGGINGFACE_CO_RESOLVE_ENDPOINT,
     MIN_PEFT_VERSION,
     ONNX_EXTERNAL_WEIGHTS_NAME,
......
@@ -43,6 +43,8 @@ DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
 DIFFUSERS_REQUEST_TIMEOUT = 60
 DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native")
 DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES
+DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8
+HF_PARALLEL_LOADING_FLAG = "HF_ENABLE_PARALLEL_LOADING"
 
 # Below should be `True` if the current version of `peft` and `transformers` are compatible with
 # PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
......
@@ -1428,6 +1428,41 @@ class ModelTesterMixin:
         self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
 
+    @require_torch_accelerator
+    def test_sharded_checkpoints_with_parallel_loading(self):
+        torch.manual_seed(0)
+        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**config).eval()
+        model = model.to(torch_device)
+
+        base_output = model(**inputs_dict)
+
+        model_size = compute_module_persistent_sizes(model)[""]
+        max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small.
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
+            self.assertTrue(os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
+
+            # Now check if the right number of shards exists. First, let's get the number of shards.
+            # Since this number can be dependent on the model being tested, it's important that we calculate it
+            # instead of hardcoding it.
+            expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))
+            actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
+            self.assertTrue(actual_num_shards == expected_num_shards)
+
+            # Load with parallel loading
+            os.environ["HF_ENABLE_PARALLEL_LOADING"] = "yes"
+            new_model = self.model_class.from_pretrained(tmp_dir).eval()
+            new_model = new_model.to(torch_device)
+
+            torch.manual_seed(0)
+            if "generator" in inputs_dict:
+                _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+            new_output = new_model(**inputs_dict)
+            self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+
+        # set to no.
+        os.environ["HF_ENABLE_PARALLEL_LOADING"] = "no"
+
     @require_torch_accelerator
     def test_sharded_checkpoints_device_map(self):
         if self.model_class._no_split_modules is None:
......