Unverified commit e0dfd7bc, authored by Zach Mueller and committed by GitHub

Speedup model init on CPU (by 10x+ for llama-3-8B as one example) (#31771)



* 1,100%!

* Clean

* Don't touch DS

* Experiment with dtype allocation

* skip test_load_save_without_tied_weights test

* A little faster

* Include proper upscaling?

* Fixup tests

* Potentially skip?

* Let's see if this fixes git history

* Maintain new dtype

* Fin

* Rm hook idea for now

* New approach, see what breaks

* stage

* Clean

* Stash

* Should be fin now, just need to mark failing models

* Clean up

* Simplify

* Deal with weird models

* Enc/Dec

* Skip w/ reason

* Adjust test

* Fix test

* one more test

* Keep experimenting

* Fix ref

* TO REMOVE: testing feedback CI

* Right push

* Update tests/utils/test_modeling_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* disable

* Add new func

* Test nits from Amy

* Update src/transformers/modeling_utils.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Adjust comment

* Adjust comment on skip

* make private

* Fin

* Should be a not flag

* Clarify and rename test

---------
Co-authored-by: Marc Sun <marc@huggingface.co>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 03a3becc
@@ -40,6 +40,10 @@ for text generation, [`~generation.GenerationMixin`] (for the PyTorch models),
     - push_to_hub
     - all
 
+Custom models should also include a `_supports_param_buffer_assignment` flag, which determines whether super-fast
+init can be applied to that particular model. A sign that your model needs this is if
+`test_save_and_load_from_pretrained` fails; if so, set the flag to `False`.
+
 ## ModuleUtilsMixin
 
 [[autodoc]] modeling_utils.ModuleUtilsMixin
......
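To make the new flag concrete, here is a minimal sketch (not part of this commit; the class, config, and sizes are made up for illustration) of a custom model that opts out of the fast loading path:

```python
import torch.nn as nn

from transformers import PretrainedConfig, PreTrainedModel


class MyCustomModel(PreTrainedModel):
    config_class = PretrainedConfig
    # Opt out of super-fast init: weights for this model are always copied into freshly
    # initialized parameters instead of being assigned directly from the checkpoint.
    _supports_param_buffer_assignment = False

    def __init__(self, config):
        super().__init__(config)
        self.proj = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, hidden_states):
        return self.proj(hidden_states)


# PretrainedConfig stores arbitrary kwargs as attributes, so this is enough to instantiate the sketch.
model = MyCustomModel(PretrainedConfig(hidden_size=16))
```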
@@ -338,6 +338,32 @@ def dtype_byte_size(dtype):
     return bit_size // 8
 
 
+def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
+    """
+    Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights) by first
+    checking if the model explicitly disables it, then by making sure the incoming `state_dict` and the model
+    share the same dtype.
+    """
+    if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
+        return False
+
+    # Some models explicitly do not support param buffer assignment
+    if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
+        logger.debug(
+            f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
+        )
+        return False
+
+    # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
+    first_key = list(model_to_load.state_dict().keys())[0]
+    if start_prefix + first_key in state_dict:
+        return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype
+
+    # For cases when the `state_dict` doesn't contain real weights to the model (`test_model_weights_reload_no_missing_tied_weights`)
+    return False
+
+
 def shard_checkpoint(
     state_dict: Dict[str, torch.Tensor], max_shard_size: Union[int, str] = "10GB", weights_name: str = WEIGHTS_NAME
 ):
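As a rough usage sketch of the helper above (the checkpoint name is only an example, and the helper is an internal utility rather than public API), a dtype mismatch between the instantiated model and the checkpoint sends loading down the slower copy path:

```python
import torch

from transformers import AutoModel
from transformers.modeling_utils import check_support_param_buffer_assignment

# Instantiate in fp16, then simulate an fp32 checkpoint holding the same weights.
model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert", torch_dtype=torch.float16)
fp32_state_dict = {k: v.float() for k, v in model.state_dict().items()}
fp16_state_dict = model.state_dict()

print(check_support_param_buffer_assignment(model, fp32_state_dict))  # False: dtypes differ, copy path
print(check_support_param_buffer_assignment(model, fp16_state_dict))  # True: direct assignment is safe
```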
@@ -657,7 +683,7 @@ def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]
     return shared_tensors, identical
 
 
-def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
+def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False):
     # Convert old format to new format if needed from a PyTorch state_dict
     old_keys = []
     new_keys = []
@@ -685,8 +711,10 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
     # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
     # so we need to apply the function recursively.
-    def load(module: nn.Module, state_dict, prefix=""):
+    def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=False):
         local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
+        local_metadata["assign_to_params_buffers"] = assign_to_params_buffers
+
         args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
         # Parameters of module and children will start with prefix. We can exit early if there are none in this
         # state_dict
@@ -710,9 +738,9 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
         for name, child in module._modules.items():
             if child is not None:
-                load(child, state_dict, prefix + name + ".")
+                load(child, state_dict, prefix + name + ".", assign_to_params_buffers)
 
-    load(model_to_load, state_dict, prefix=start_prefix)
+    load(model_to_load, state_dict, prefix=start_prefix, assign_to_params_buffers=assign_to_params_buffers)
     # Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
     # it's safe to delete it.
     del state_dict
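For context on what the `assign_to_params_buffers` entry in `local_metadata` does: PyTorch's `nn.Module._load_from_state_dict` reads this key and, when set, assigns the checkpoint tensors to the module instead of copying them into pre-allocated parameters. The standalone sketch below (assuming PyTorch >= 2.1, not part of the diff) shows the same effect through the public `assign=True` argument:

```python
import torch
import torch.nn as nn

# Create a module without allocating real storage for its parameters.
with torch.device("meta"):
    layer = nn.Linear(4096, 4096)

checkpoint = {"weight": torch.randn(4096, 4096), "bias": torch.randn(4096)}

# `assign=True` makes the checkpoint tensors *become* the parameters, skipping the
# allocate-then-copy step that dominates CPU init time for large models.
layer.load_state_dict(checkpoint, assign=True)
print(layer.weight.device, layer.weight.dtype)  # cpu torch.float32
```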
@@ -2852,6 +2880,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
         weights are discarded.
 
+        If the model weights are in the same precision as the base model (and the model architecture supports it),
+        weights will be lazily loaded in using the `meta` device and brought into memory once an input is passed
+        through that layer, regardless of `low_cpu_mem_usage`.
+
         Parameters:
             pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                 Can be either:
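Tying the docstring note above back to the PR title, the sketch below loads a large checkpoint in its native precision, which is the case the direct-assignment path accelerates (the model name is only an example, the checkpoint is gated, and any timing you observe is illustrative):

```python
import time

import torch
from transformers import AutoModelForCausalLM

start = time.perf_counter()
# The Llama 3 8B checkpoint is stored in bfloat16; requesting the same dtype lets the
# loader assign the checkpoint tensors directly instead of upcasting and copying them.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B", torch_dtype=torch.bfloat16)
print(f"CPU init took {time.perf_counter() - start:.1f}s")
```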
@@ -2952,7 +2984,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             low_cpu_mem_usage(`bool`, *optional*):
                 Tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+                Generally should be combined with a `device_map` (such as `"auto"`) for best results.
                 This is an experimental feature and a subject to change at any moment.
+                <Tip>
+                If the model weights are in the same precision as the loaded model, `low_cpu_mem_usage` (without a
+                `device_map`) is redundant and will not provide any benefit with regard to CPU memory usage. However,
+                it should still be enabled if you are passing in a `device_map`.
+                </Tip>
             torch_dtype (`str` or `torch.dtype`, *optional*):
                 Override the default `torch.dtype` and load the model under a specific `dtype`. The different options
                 are:
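A short sketch of that recommendation (model name illustrative; `device_map` requires the `accelerate` package):

```python
from transformers import AutoModelForCausalLM

# On its own, `low_cpu_mem_usage=True` no longer reduces peak CPU memory when the dtypes
# already match, but it remains the right companion to a `device_map`, which dispatches
# weights across the available devices as they are loaded.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    low_cpu_mem_usage=True,
    device_map="auto",
)
```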
@@ -4018,6 +4056,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         missing_keys = sorted(set(expected_keys) - set(loaded_keys))
         unexpected_keys = set(loaded_keys) - set(expected_keys)
         # Remove nonpersistent buffers from unexpected keys: they are not in the state dict but will be in the model
         # buffers
         model_buffers = {n for n, _ in model.named_buffers()}
@@ -4252,7 +4291,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 )
             else:
                 # Sharded checkpoint or whole but low_cpu_mem_usage==True
-                error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
+                assign_to_params_buffers = check_support_param_buffer_assignment(
+                    model_to_load, state_dict, start_prefix
+                )
+                error_msgs = _load_state_dict_into_model(
+                    model_to_load, state_dict, start_prefix, assign_to_params_buffers
+                )
         else:
             # This should always be a list but, just to be sure.
@@ -4280,6 +4324,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             if len(resolved_archive_file) > 1:
                 resolved_archive_file = logging.tqdm(resolved_archive_file, desc="Loading checkpoint shards")
 
+            assign_to_params_buffers = None
             for shard_file in resolved_archive_file:
                 # Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload.
                 if shard_file in disk_only_shard_files:
@@ -4323,7 +4368,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                         )
                         error_msgs += new_error_msgs
                     else:
-                        error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
+                        # Sharded checkpoint or whole but low_cpu_mem_usage==True
+                        if assign_to_params_buffers is None:
+                            assign_to_params_buffers = check_support_param_buffer_assignment(
+                                model_to_load, state_dict, start_prefix
+                            )
+                        error_msgs += _load_state_dict_into_model(
+                            model_to_load, state_dict, start_prefix, assign_to_params_buffers
+                        )
 
                 # force memory release
                 del state_dict
......
@@ -178,6 +178,7 @@ class EncoderDecoderModel(PreTrainedModel):
     base_model_prefix = "encoder_decoder"
     main_input_name = "input_ids"
     supports_gradient_checkpointing = True
+    _supports_param_buffer_assignment = False
 
     def __init__(
         self,
......
@@ -773,6 +773,7 @@ class LxmertPreTrainedModel(PreTrainedModel):
     config_class = LxmertConfig
     load_tf_weights = load_tf_weights_in_lxmert
     base_model_prefix = "lxmert"
+    _supports_param_buffer_assignment = False
 
     def _init_weights(self, module):
         """Initialize the weights"""
......
@@ -159,6 +159,7 @@ class VisionEncoderDecoderModel(PreTrainedModel):
     base_model_prefix = "vision_encoder_decoder"
     main_input_name = "pixel_values"
     supports_gradient_checkpointing = True
+    _supports_param_buffer_assignment = False
 
     def __init__(
         self,
......
@@ -512,6 +512,12 @@ class BartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
         model.generate(input_ids, attention_mask=attention_mask)
         model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
 
+    @unittest.skip(
+        reason="This architecture has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
+    )
+    def test_load_save_without_tied_weights(self):
+        pass
+
 
 def assert_tensors_close(a, b, atol=1e-12, prefix=""):
     """If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
......
@@ -476,6 +476,12 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
         self.assertTrue(torch.allclose(outputs1, outputs2, atol=1e-5))
 
+    @unittest.skip(
+        reason="This architecture has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
+    )
+    def test_load_save_without_tied_weights(self):
+        pass
+
 
 @require_torch
 @require_sentencepiece
......
@@ -758,6 +758,12 @@ class LongT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
             [encoder_expected_shape] * len(attentions),
         )
 
+    @unittest.skip(
+        reason="This architecture has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
+    )
+    def test_load_save_without_tied_weights(self):
+        pass
+
 
 @require_torch
 class LongT5TGlobalModelTest(LongT5ModelTest):
@@ -1097,6 +1103,12 @@ class LongT5EncoderOnlyModelTest(ModelTesterMixin, unittest.TestCase):
             [self.model_tester.num_attention_heads, block_len, 3 * block_len],
         )
 
+    @unittest.skip(
+        reason="This architecture has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
+    )
+    def test_load_save_without_tied_weights(self):
+        pass
+
 
 class LongT5EncoderOnlyTGlobalModelTest(LongT5EncoderOnlyModelTest):
     def setUp(self):
......
@@ -778,6 +778,12 @@ class LxmertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     def test_save_load_low_cpu_mem_usage_no_safetensors(self):
         pass
 
+    @unittest.skip(
+        reason="This architecture has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
+    )
+    def test_load_save_without_tied_weights(self):
+        pass
+
 
 @require_torch
 class LxmertModelIntegrationTest(unittest.TestCase):
......
@@ -331,6 +331,12 @@ class M2M100ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
         model.generate(input_ids, attention_mask=attention_mask)
         model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
 
+    @unittest.skip(
+        reason="This architecture has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
+    )
+    def test_load_save_without_tied_weights(self):
+        pass
+
 
 def _long_tensor(tok_lst):
     return torch.tensor(tok_lst, dtype=torch.long, device=torch_device)
......
@@ -369,6 +369,12 @@ class MBartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
             2,
         )
 
+    @unittest.skip(
+        reason="This architecture has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
+    )
+    def test_load_save_without_tied_weights(self):
+        pass
+
 
 def assert_tensors_close(a, b, atol=1e-12, prefix=""):
     """If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
......
@@ -346,6 +346,12 @@ class NllbMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
         self.assertIsNotNone(model(**input_dict)["encoder_router_logits"][1])
         self.assertIsNotNone(model(**input_dict)["decoder_router_logits"][0])
 
+    @unittest.skip(
+        reason="This architecture has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
+    )
+    def test_load_save_without_tied_weights(self):
+        pass
+
 
 @require_torch
 @require_sentencepiece
......
@@ -323,6 +323,12 @@ class PLBartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
     def test_sample_generate(self):
         pass
 
+    @unittest.skip(
+        reason="This architecture has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
+    )
+    def test_load_save_without_tied_weights(self):
+        pass
+
 
 def assert_tensors_close(a, b, atol=1e-12, prefix=""):
     """If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
......
@@ -506,6 +506,12 @@ class SeamlessM4TModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase):
     def test_training_gradient_checkpointing_use_reentrant_false(self):
         pass
 
+    @unittest.skip(
+        reason="This architecture has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
+    )
+    def test_load_save_without_tied_weights(self):
+        pass
+
     def test_attention_outputs(self):
         # expected length is subsampled so need to change a bit this test
         if not self.has_attentions:
@@ -758,6 +764,12 @@ class SeamlessM4TModelWithTextInputTest(
     def test_retain_grad_hidden_states_attentions(self):
         pass
 
+    @unittest.skip(
+        reason="This architecture has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
+    )
+    def test_load_save_without_tied_weights(self):
+        pass
+
 
 @require_torch
 class SeamlessM4TGenerationTest(unittest.TestCase):
......
@@ -522,6 +522,12 @@ class SeamlessM4Tv2ModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase)
     def test_training_gradient_checkpointing_use_reentrant_false(self):
         pass
 
+    @unittest.skip(
+        reason="This architecture has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
+    )
+    def test_load_save_without_tied_weights(self):
+        pass
+
     def test_attention_outputs(self):
         # expected length is subsampled so need to change a bit this test
         if not self.has_attentions:
@@ -748,6 +754,12 @@ class SeamlessM4Tv2ModelWithTextInputTest(ModelTesterMixin, GenerationTesterMixi
     def test_training_gradient_checkpointing_use_reentrant_false(self):
         pass
 
+    @unittest.skip(
+        reason="This architecture has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
+    )
+    def test_load_save_without_tied_weights(self):
+        pass
+
 
 @require_torch
 class SeamlessM4Tv2GenerationTest(unittest.TestCase):
......
@@ -720,6 +720,12 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel
             attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
             self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)
 
+    @unittest.skip(
+        reason="This architecture has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
+    )
+    def test_load_save_without_tied_weights(self):
+        pass
+
 
 class SwitchTransformersEncoderOnlyModelTester:
     def __init__(
@@ -843,6 +849,12 @@ class SwitchTransformersEncoderOnlyModelTest(ModelTesterMixin, unittest.TestCase
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs)
 
+    @unittest.skip(
+        reason="This architecture has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
+    )
+    def test_load_save_without_tied_weights(self):
+        pass
+
 
 def use_task_specific_params(model, task):
     model.config.update(model.config.task_specific_params[task])
......
@@ -20,6 +20,7 @@ import os.path
 import sys
 import tempfile
 import threading
+import time
 import unittest
 import unittest.mock as mock
 import uuid
@@ -894,32 +895,42 @@ class ModelUtilsTest(TestCasePlus):
     @require_usr_bin_time
     @require_accelerate
     @mark.accelerate_tests
-    def test_from_pretrained_low_cpu_mem_usage_measured(self):
-        # test that `from_pretrained(..., low_cpu_mem_usage=True)` uses less cpu memory than default
+    def test_from_pretrained_low_cpu_mem_usage_slower(self):
+        # Before this would test that `from_pretrained(..., low_cpu_mem_usage=True)` uses less cpu memory than default
+        # Now though the memory is the same, we simply test that loading with `low_cpu_mem_usage` winds up being *slower*
+        # (mostly from extra logic needed)
 
-        mname = "google-bert/bert-base-cased"
+        mname = "hf-internal-testing/tiny-random-bert"
 
         preamble = "from transformers import AutoModel"
         one_liner_str = f'{preamble}; AutoModel.from_pretrained("{mname}", low_cpu_mem_usage=False)'
+        start_time = time.time()
+        # Save this output as `max_rss_normal` if testing memory results
         max_rss_normal = self.python_one_liner_max_rss(one_liner_str)
+        end_time = time.time()
+        elapsed_time_normal = end_time - start_time
         # print(f"{max_rss_normal=}")
 
         one_liner_str = f'{preamble}; AutoModel.from_pretrained("{mname}", low_cpu_mem_usage=True)'
+        start_time = time.time()
+        # Save this output as `max_rss_low_mem` if testing memory results
         max_rss_low_mem = self.python_one_liner_max_rss(one_liner_str)
-        # print(f"{max_rss_low_mem=}")
+        end_time = time.time()
+        elapsed_time_low_mem = end_time - start_time
 
-        diff_bytes = max_rss_normal - max_rss_low_mem
-        diff_percent = diff_bytes / max_rss_low_mem
-        # print(f"{diff_bytes=}, {diff_percent=}")
-
-        # ideally we would compare that the diff is close to ~1x checkpoint size in bytes, but
-        # measuring cpu memory on linux is very tricky and inconsistent, so instead let's check that
-        # it's at least 15% less cpu memory consumed
+        # Should be within 2MBs of each other (overhead)
+        self.assertAlmostEqual(
+            max_rss_normal / 1024 / 1024,
+            max_rss_low_mem / 1024 / 1024,
+            delta=2,
+            msg="using `low_cpu_mem_usage` should incur the same memory usage in both cases.",
+        )
 
         self.assertGreater(
-            diff_percent,
-            0.15,
-            "should use less CPU memory for low_cpu_mem_usage=True, "
-            f"but got max_rss_normal={max_rss_normal} and max_rss_low_mem={max_rss_low_mem}",
+            elapsed_time_low_mem,
+            elapsed_time_normal,
+            "using `low_cpu_mem_usage` should be slower due to extra logic, "
+            f"but got elapsed_time_normal={elapsed_time_normal} and elapsed_time_low_mem={elapsed_time_low_mem}",
         )
 
         # if you want to compare things manually, let's first look at the size of the model in bytes
......
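To reproduce the updated test's observation outside the test suite, a standalone sketch (the tiny checkpoint mirrors the one used in the test; absolute timings will vary by machine):

```python
import time

from transformers import AutoModel

for low_cpu_mem_usage in (False, True):
    start = time.perf_counter()
    AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert", low_cpu_mem_usage=low_cpu_mem_usage)
    elapsed = time.perf_counter() - start
    # The low_cpu_mem_usage=True pass is expected to be slightly slower due to the extra bookkeeping.
    print(f"low_cpu_mem_usage={low_cpu_mem_usage}: {elapsed:.3f}s")
```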