Unverified commit 54a2361a, authored by JB (Don) and committed by GitHub
Browse files

Adding _tie_weights() to prediction heads to support low_cpu_mem_usage=True (#29024)

* Adding _tie_weights() to prediction heads to support low_cpu_mem_usage=True

* Testing for the non-safe-tensors case, since the default is safe-tensors already

* Running fixup/fix-copies

* Adding accelerate annotations to tests
parent ce47582d
...@@ -169,6 +169,18 @@ class TimmBackboneModelTest(ModelTesterMixin, BackboneTesterMixin, PipelineTeste ...@@ -169,6 +169,18 @@ class TimmBackboneModelTest(ModelTesterMixin, BackboneTesterMixin, PipelineTeste
def test_save_load(self): def test_save_load(self):
pass pass
@unittest.skip("No support for low_cpu_mem_usage=True.")
def test_save_load_low_cpu_mem_usage(self):
pass
@unittest.skip("No support for low_cpu_mem_usage=True.")
def test_save_load_low_cpu_mem_usage_checkpoints(self):
pass
@unittest.skip("No support for low_cpu_mem_usage=True.")
def test_save_load_low_cpu_mem_usage_no_safetensors(self):
pass
@unittest.skip("model weights aren't tied in TimmBackbone.") @unittest.skip("model weights aren't tied in TimmBackbone.")
def test_tie_model_weights(self): def test_tie_model_weights(self):
pass pass
......
...@@ -437,6 +437,88 @@ class ModelTesterMixin: ...@@ -437,6 +437,88 @@ class ModelTesterMixin:
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item() max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
@slow
@require_accelerate
@mark.accelerate_tests
def test_save_load_low_cpu_mem_usage(self):
    """Save each model class with ``save_pretrained`` and validate low-memory reloading.

    Every class in ``self.all_model_classes`` is instantiated from the common
    test config, written to a temporary directory with the default
    ``save_pretrained`` arguments, and then handed to
    ``_check_save_load_low_cpu_mem_usage``.
    """
    # Only the config is needed here; the prepared inputs are discarded.
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()
    with tempfile.TemporaryDirectory() as tmp_dir:
        for model_class in self.all_model_classes:
            model_class(config).save_pretrained(tmp_dir)
            self._check_save_load_low_cpu_mem_usage(model_class, tmp_dir)
@slow
@require_accelerate
@mark.accelerate_tests
def test_save_load_low_cpu_mem_usage_checkpoints(self):
    """Validate low-memory reloading from a raw ``torch.save`` checkpoint.

    Instead of ``model.save_pretrained``, the config is saved on its own and
    the state dict is written directly to ``pytorch_model.bin`` before
    ``_check_save_load_low_cpu_mem_usage`` is run on the directory.
    """
    # Only the config is needed here; the prepared inputs are discarded.
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()
    with tempfile.TemporaryDirectory() as tmp_dir:
        weights_file = os.path.join(tmp_dir, "pytorch_model.bin")
        for model_class in self.all_model_classes:
            model = model_class(config)
            model.config.save_pretrained(tmp_dir)
            torch.save(model.state_dict(), weights_file)
            self._check_save_load_low_cpu_mem_usage(model_class, tmp_dir)
@slow
@require_accelerate
@mark.accelerate_tests
def test_save_load_low_cpu_mem_usage_no_safetensors(self):
    """Validate low-memory reloading of a checkpoint saved with ``safe_serialization=False``.

    A fresh config is prepared for each class in ``self.all_model_classes``;
    the model is saved with ``safe_serialization=False`` and the directory is
    passed to ``_check_save_load_low_cpu_mem_usage``.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        for model_class in self.all_model_classes:
            # The prepared inputs are not used by this test.
            config, _ = self.model_tester.prepare_config_and_inputs_for_common()
            model_class(config).save_pretrained(tmp_dir, safe_serialization=False)
            self._check_save_load_low_cpu_mem_usage(model_class, tmp_dir)
def _check_save_load_low_cpu_mem_usage(self, model_class, saved_model_path):
    """Assert that loading with ``low_cpu_mem_usage=True`` matches a regular load.

    Loads the checkpoint at ``saved_model_path`` twice, once with
    ``low_cpu_mem_usage=True`` and once with default arguments, and checks:

    * no keys are reported missing by the low-memory load,
    * no parameter is left on the ``meta`` device,
    * parameters are element-wise equal between the two models,
    * both models expose the same state-dict keys,
    * both models report the same groups of shared tensors.

    Args:
        model_class: Model class providing ``from_pretrained``.
        saved_model_path: Directory containing the saved checkpoint.
    """

    def shared_tensor_groups(model):
        # Bucket state-dict names by ``id_tensor_storage(tensor)``; a bucket with
        # more than one name is a group of shared tensors.
        names_by_storage = collections.defaultdict(list)
        for name, tensor in model.state_dict().items():
            names_by_storage[id_tensor_storage(tensor)].append(name)
        return [names for names in names_by_storage.values() if len(names) > 1]

    # Load the low usage and the normal models.
    model_low_usage, loading_info = model_class.from_pretrained(
        saved_model_path,
        low_cpu_mem_usage=True,
        output_loading_info=True,
    )
    model_non_low_usage = model_class.from_pretrained(saved_model_path)

    # Check that there were no missing keys.
    self.assertEqual(loading_info["missing_keys"], [])

    # The low_cpu_mem_usage=True causes the model params to be initialized with device=meta, and then
    # subsequently loaded with the correct values and onto the correct device. We check if there are any
    # remaining params that were not properly loaded.
    meta_device = torch.device("meta")
    for name, param in model_low_usage.named_parameters():
        self.assertNotEqual(
            param.device,
            meta_device,
            f"Parameter '{name}' has not been properly loaded and has device=meta.",
        )

    # Tests moving the model to a device other than meta. The reference model is moved as well so
    # the element-wise comparison below never mixes tensors that live on different devices.
    model_low_usage.to(torch_device)
    model_non_low_usage.to(torch_device)

    # Check that the parameters are equal. `assertEqual` is used because the `assertEquals` alias
    # was removed in Python 3.12; `.item()` turns the mismatch count into a plain Python number.
    for (name, p1), p2 in zip(model_low_usage.named_parameters(), model_non_low_usage.parameters()):
        self.assertEqual(
            p1.data.ne(p2.data).sum().item(),
            0,
            f"Parameter '{name}' differs between the low_cpu_mem_usage load and the regular load.",
        )

    # Check that the state dict keys are equal.
    self.assertEqual(set(model_low_usage.state_dict().keys()), set(model_non_low_usage.state_dict().keys()))

    # Check that the shared tensors are equal.
    self.assertEqual(shared_tensor_groups(model_low_usage), shared_tensor_groups(model_non_low_usage))
def test_fast_init_context_manager(self): def test_fast_init_context_manager(self):
# 1. Create a dummy class. Should have buffers as well? To make sure we test __init__ # 1. Create a dummy class. Should have buffers as well? To make sure we test __init__
class MyClass(PreTrainedModel): class MyClass(PreTrainedModel):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment