Unverified Commit e0dfd7bc authored by Zach Mueller's avatar Zach Mueller Committed by GitHub
Browse files

Speedup model init on CPU (by 10x+ for llama-3-8B as one example) (#31771)



* 1,100%!

* Clean

* Don't touch DS

* Experiment with dtype allocation

* skip test_load_save_without_tied_weights test

* A little faster

* Include proper upscaling?

* Fixup tests

* Potentially skip?

* Let's see if this fixes git history

* Maintain new dtype

* Fin

* Rm hook idea for now

* New approach, see what breaks

* stage

* Clean

* Stash

* Should be fin now, just need to mark failing models

* Clean up

* Simplify

* Deal with weird models

* Enc/Dec

* Skip w/ reason

* Adjust test

* Fix test

* one more test

* Keep experimenting

* Fix ref

* TO REMOVE: testing feedback CI

* Right push

* Update tests/utils/test_modeling_utils.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* disable

* Add new func

* Test nits from Amy

* Update src/transformers/modeling_utils.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Adjust comment

* Adjust comment on skip

* make private

* Fin

* Should be a not flag

* Clarify and rename test

---------
Co-authored-by: default avatarMarc Sun <marc@huggingface.co>
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 03a3becc
......@@ -40,6 +40,10 @@ for text generation, [`~generation.GenerationMixin`] (for the PyTorch models),
- push_to_hub
- all
Custom models should also include a `_supports_assign_param_buffer`, which determines if superfast init can apply
on the particular model. Signs that your model needs this are if `test_save_and_load_from_pretrained` fails. If so,
set this to `False`.
## ModuleUtilsMixin
[[autodoc]] modeling_utils.ModuleUtilsMixin
......
......@@ -338,6 +338,32 @@ def dtype_byte_size(dtype):
return bit_size // 8
def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
"""
Checks if `model_to_load` supports param buffer assignment (such
as when loading in empty weights) by first checking
if the model explicitly disables it, then by ensuring that the state dict keys
are a subset of the model's parameters.
"""
if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
return False
# Some models explicitly do not support param buffer assignment
if not getattr(model_to_load, "_supports_param_buffer_assignment", False):
logger.debug(
f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
)
return False
# If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
first_key = list(model_to_load.state_dict().keys())[0]
if start_prefix + first_key in state_dict:
return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype
# For cases when the `state_dict` doesn't contain real weights to the model (`test_model_weights_reload_no_missing_tied_weights`)
return False
def shard_checkpoint(
state_dict: Dict[str, torch.Tensor], max_shard_size: Union[int, str] = "10GB", weights_name: str = WEIGHTS_NAME
):
......@@ -657,7 +683,7 @@ def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]
return shared_tensors, identical
def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False):
# Convert old format to new format if needed from a PyTorch state_dict
old_keys = []
new_keys = []
......@@ -685,8 +711,10 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
# so we need to apply the function recursively.
def load(module: nn.Module, state_dict, prefix=""):
def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=False):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
local_metadata["assign_to_params_buffers"] = assign_to_params_buffers
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
# Parameters of module and children will start with prefix. We can exit early if there are none in this
# state_dict
......@@ -710,9 +738,9 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
for name, child in module._modules.items():
if child is not None:
load(child, state_dict, prefix + name + ".")
load(child, state_dict, prefix + name + ".", assign_to_params_buffers)
load(model_to_load, state_dict, prefix=start_prefix)
load(model_to_load, state_dict, prefix=start_prefix, assign_to_params_buffers=assign_to_params_buffers)
# Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
# it's safe to delete it.
del state_dict
......@@ -2852,6 +2880,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
weights are discarded.
If model weights are the same precision as the base model (and is a supported model), weights will be lazily loaded
in using the `meta` device and brought into memory once an input is passed through that layer regardless of
`low_cpu_mem_usage`.
Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
Can be either:
......@@ -2952,7 +2984,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
low_cpu_mem_usage(`bool`, *optional*):
Tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Generally should be combined with a `device_map` (such as `"auto"`) for best results.
This is an experimental feature and a subject to change at any moment.
</Tip>
If the model weights are in the same precision as the model loaded in, `low_cpu_mem_usage` (without
`device_map`) is redundant and will not provide any benefit in regards to CPU memory usage. However,
this should still be enabled if you are passing in a `device_map`.
</Tip>
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model under a specific `dtype`. The different options
are:
......@@ -4018,6 +4056,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
missing_keys = sorted(set(expected_keys) - set(loaded_keys))
unexpected_keys = set(loaded_keys) - set(expected_keys)
# Remove nonpersistent buffers from unexpected keys: they are not in the state dict but will be in the model
# buffers
model_buffers = {n for n, _ in model.named_buffers()}
......@@ -4252,7 +4291,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
else:
# Sharded checkpoint or whole but low_cpu_mem_usage==True
error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
assign_to_params_buffers = check_support_param_buffer_assignment(
model_to_load, state_dict, start_prefix
)
error_msgs = _load_state_dict_into_model(
model_to_load, state_dict, start_prefix, assign_to_params_buffers
)
else:
# This should always be a list but, just to be sure.
......@@ -4280,6 +4324,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if len(resolved_archive_file) > 1:
resolved_archive_file = logging.tqdm(resolved_archive_file, desc="Loading checkpoint shards")
assign_to_params_buffers = None
for shard_file in resolved_archive_file:
# Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload.
if shard_file in disk_only_shard_files:
......@@ -4323,7 +4368,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
error_msgs += new_error_msgs
else:
error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
# Sharded checkpoint or whole but low_cpu_mem_usage==True
if assign_to_params_buffers is None:
assign_to_params_buffers = check_support_param_buffer_assignment(
model_to_load, state_dict, start_prefix
)
error_msgs += _load_state_dict_into_model(
model_to_load, state_dict, start_prefix, assign_to_params_buffers
)
# force memory release
del state_dict
......
......@@ -178,6 +178,7 @@ class EncoderDecoderModel(PreTrainedModel):
base_model_prefix = "encoder_decoder"
main_input_name = "input_ids"
supports_gradient_checkpointing = True
_supports_param_buffer_assignment = False
def __init__(
self,
......
......@@ -773,6 +773,7 @@ class LxmertPreTrainedModel(PreTrainedModel):
config_class = LxmertConfig
load_tf_weights = load_tf_weights_in_lxmert
base_model_prefix = "lxmert"
_supports_param_buffer_assignment = False
def _init_weights(self, module):
"""Initialize the weights"""
......
......@@ -159,6 +159,7 @@ class VisionEncoderDecoderModel(PreTrainedModel):
base_model_prefix = "vision_encoder_decoder"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_supports_param_buffer_assignment = False
def __init__(
self,
......
......@@ -512,6 +512,12 @@ class BartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
model.generate(input_ids, attention_mask=attention_mask)
model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
@unittest.skip(
reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
)
def test_load_save_without_tied_weights(self):
pass
def assert_tensors_close(a, b, atol=1e-12, prefix=""):
"""If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
......
......@@ -476,6 +476,12 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
self.assertTrue(torch.allclose(outputs1, outputs2, atol=1e-5))
@unittest.skip(
reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
)
def test_load_save_without_tied_weights(self):
pass
@require_torch
@require_sentencepiece
......
......@@ -758,6 +758,12 @@ class LongT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
[encoder_expected_shape] * len(attentions),
)
@unittest.skip(
reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
)
def test_load_save_without_tied_weights(self):
pass
@require_torch
class LongT5TGlobalModelTest(LongT5ModelTest):
......@@ -1097,6 +1103,12 @@ class LongT5EncoderOnlyModelTest(ModelTesterMixin, unittest.TestCase):
[self.model_tester.num_attention_heads, block_len, 3 * block_len],
)
@unittest.skip(
reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
)
def test_load_save_without_tied_weights(self):
pass
class LongT5EncoderOnlyTGlobalModelTest(LongT5EncoderOnlyModelTest):
def setUp(self):
......
......@@ -778,6 +778,12 @@ class LxmertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
def test_save_load_low_cpu_mem_usage_no_safetensors(self):
pass
@unittest.skip(
reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
)
def test_load_save_without_tied_weights(self):
pass
@require_torch
class LxmertModelIntegrationTest(unittest.TestCase):
......
......@@ -331,6 +331,12 @@ class M2M100ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
model.generate(input_ids, attention_mask=attention_mask)
model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
@unittest.skip(
reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
)
def test_load_save_without_tied_weights(self):
pass
def _long_tensor(tok_lst):
return torch.tensor(tok_lst, dtype=torch.long, device=torch_device)
......
......@@ -369,6 +369,12 @@ class MBartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
2,
)
@unittest.skip(
reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
)
def test_load_save_without_tied_weights(self):
pass
def assert_tensors_close(a, b, atol=1e-12, prefix=""):
"""If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
......
......@@ -346,6 +346,12 @@ class NllbMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
self.assertIsNotNone(model(**input_dict)["encoder_router_logits"][1])
self.assertIsNotNone(model(**input_dict)["decoder_router_logits"][0])
@unittest.skip(
reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
)
def test_load_save_without_tied_weights(self):
pass
@require_torch
@require_sentencepiece
......
......@@ -323,6 +323,12 @@ class PLBartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
def test_sample_generate(self):
pass
@unittest.skip(
reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
)
def test_load_save_without_tied_weights(self):
pass
def assert_tensors_close(a, b, atol=1e-12, prefix=""):
"""If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
......
......@@ -506,6 +506,12 @@ class SeamlessM4TModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase):
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@unittest.skip(
reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
)
def test_load_save_without_tied_weights(self):
pass
def test_attention_outputs(self):
# expected length is subsampled so need to change a bit this test
if not self.has_attentions:
......@@ -758,6 +764,12 @@ class SeamlessM4TModelWithTextInputTest(
def test_retain_grad_hidden_states_attentions(self):
pass
@unittest.skip(
reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
)
def test_load_save_without_tied_weights(self):
pass
@require_torch
class SeamlessM4TGenerationTest(unittest.TestCase):
......
......@@ -522,6 +522,12 @@ class SeamlessM4Tv2ModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase)
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@unittest.skip(
reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
)
def test_load_save_without_tied_weights(self):
pass
def test_attention_outputs(self):
# expected length is subsampled so need to change a bit this test
if not self.has_attentions:
......@@ -748,6 +754,12 @@ class SeamlessM4Tv2ModelWithTextInputTest(ModelTesterMixin, GenerationTesterMixi
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@unittest.skip(
reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
)
def test_load_save_without_tied_weights(self):
pass
@require_torch
class SeamlessM4Tv2GenerationTest(unittest.TestCase):
......
......@@ -720,6 +720,12 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel
attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)
@unittest.skip(
reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
)
def test_load_save_without_tied_weights(self):
pass
class SwitchTransformersEncoderOnlyModelTester:
def __init__(
......@@ -843,6 +849,12 @@ class SwitchTransformersEncoderOnlyModelTest(ModelTesterMixin, unittest.TestCase
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs)
@unittest.skip(
reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
)
def test_load_save_without_tied_weights(self):
pass
def use_task_specific_params(model, task):
model.config.update(model.config.task_specific_params[task])
......
......@@ -20,6 +20,7 @@ import os.path
import sys
import tempfile
import threading
import time
import unittest
import unittest.mock as mock
import uuid
......@@ -894,32 +895,42 @@ class ModelUtilsTest(TestCasePlus):
@require_usr_bin_time
@require_accelerate
@mark.accelerate_tests
def test_from_pretrained_low_cpu_mem_usage_measured(self):
# test that `from_pretrained(..., low_cpu_mem_usage=True)` uses less cpu memory than default
def test_from_pretrained_low_cpu_mem_usage_slower(self):
# Before this would test that `from_pretrained(..., low_cpu_mem_usage=True)` uses less cpu memory than default
# Now though the memory is the same, we simply test that loading with `low_cpu_mem_usage` winds up being *slower*
# (mostly from extra logic needed)
mname = "google-bert/bert-base-cased"
mname = "hf-internal-testing/tiny-random-bert"
preamble = "from transformers import AutoModel"
one_liner_str = f'{preamble}; AutoModel.from_pretrained("{mname}", low_cpu_mem_usage=False)'
start_time = time.time()
# Save this output as `max_rss_normal` if testing memory results
max_rss_normal = self.python_one_liner_max_rss(one_liner_str)
end_time = time.time()
elapsed_time_normal = end_time - start_time
# print(f"{max_rss_normal=}")
one_liner_str = f'{preamble}; AutoModel.from_pretrained("{mname}", low_cpu_mem_usage=True)'
start_time = time.time()
# Save this output as `max_rss_low_mem` if testing memory results
max_rss_low_mem = self.python_one_liner_max_rss(one_liner_str)
# print(f"{max_rss_low_mem=}")
end_time = time.time()
elapsed_time_low_mem = end_time - start_time
diff_bytes = max_rss_normal - max_rss_low_mem
diff_percent = diff_bytes / max_rss_low_mem
# print(f"{diff_bytes=}, {diff_percent=}")
# ideally we would compare that the diff is close to ~1x checkpoint size in bytes, but
# measuring cpu memory on linux is very tricky and inconsistent, so instead let's check that
# it's at least 15% less cpu memory consumed
# Should be within 2MBs of each other (overhead)
self.assertAlmostEqual(
max_rss_normal / 1024 / 1024,
max_rss_low_mem / 1024 / 1024,
delta=2,
msg="using `low_cpu_mem_usage` should incur the same memory usage in both cases.",
)
self.assertGreater(
diff_percent,
0.15,
"should use less CPU memory for low_cpu_mem_usage=True, "
f"but got max_rss_normal={max_rss_normal} and max_rss_low_mem={max_rss_low_mem}",
elapsed_time_low_mem,
elapsed_time_normal,
"using `low_cpu_mem_usage` should be slower due to extra logic, "
f"but got elapsed_time_normal={elapsed_time_normal} and elapsed_time_low_mem={elapsed_time_low_mem}",
)
# if you want to compare things manually, let's first look at the size of the model in bytes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment