Unverified Commit 8c5c180d authored by Marc Sun, committed by GitHub

Fix serialization for offloaded model (#31727)

* Fix serialization

* style

* add test
parent eaa5f414
@@ -2518,9 +2518,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         # Save the model
         if state_dict is None:
-            # if any model parameters are offloaded to the disk, make module map
-            if hasattr(self, "hf_device_map") and (
-                "cpu" in self.hf_device_map.values() or "disk" in self.hf_device_map.values()
+            # if any model parameters are offloaded, make module map
+            if (
+                hasattr(self, "hf_device_map")
+                and len(set(self.hf_device_map.values())) > 1
+                and ("cpu" in self.hf_device_map.values() or "disk" in self.hf_device_map.values())
             ):
                 warnings.warn(
                     "Attempting to save a model with offloaded modules. Ensure that unallocated cpu memory exceeds the `shard_size` (5GB default)"
@@ -2532,7 +2534,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                     for key in module_state_dict:
                         module_map[name + f".{key}"] = module
             state_dict = model_to_save.state_dict()
-
         # Translate state_dict from smp to hf if saving with smp >= 1.10
@@ -2655,7 +2656,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 and reg.fullmatch(filename_no_suffix) is not None
             ):
                 os.remove(full_filename)
-
         # Save the model
         for shard_file, tensors in state_dict_split.filename_to_tensors.items():
             shard = {tensor: state_dict[tensor] for tensor in tensors}
@@ -2667,15 +2667,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                         f"Please upgrade accelerate with `pip install -U accelerate`"
                     )
                 # init state_dict for this shard
-                state_dict = {name: "" for name in shard}
+                shard_state_dict = {name: "" for name in shard}
                 for module_name in shard:
                     module = module_map[module_name]
                     # update state dict with onloaded parameters
-                    state_dict = get_state_dict_from_offload(module, module_name, state_dict)
+                    shard_state_dict = get_state_dict_from_offload(module, module_name, shard_state_dict)

                 # assign shard to be the completed state dict
-                shard = state_dict
-                del state_dict
+                shard = shard_state_dict
+                del shard_state_dict
                 gc.collect()

             if safe_serialization:
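The rename from state_dict to shard_state_dict is the heart of the fix: the per-shard loop was rebinding and then deleting the name state_dict, which destroyed the full state dict that later shards still needed. A toy reproduction of that rebinding bug (simplified names and data, not the real save loop):

state_dict = {"a": 1, "b": 2, "c": 3}  # stands in for the full model state dict
shards = [["a", "b"], ["c"]]           # two shards to write

for tensors in shards:
    shard = {name: state_dict[name] for name in tensors}  # reads the outer dict
    # Old code did: state_dict = {name: "" for name in shard}; ...; del state_dict
    # That rebinding deleted the full dict, so building the *second* shard
    # raised NameError. A distinct name leaves the outer dict intact:
    shard_state_dict = {name: "" for name in shard}
    shard = shard_state_dict
    del shard_state_dict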
@@ -1065,6 +1065,23 @@ class ModelUtilsTest(TestCasePlus):
         # This check we did call the fake head request
         mock_head.assert_called()

+    @require_accelerate
+    @mark.accelerate_tests
+    def test_save_model_with_device_map_cpu(self):
+        model_id = "hf-internal-testing/tiny-random-gpt2"
+        inputs = torch.tensor([[1, 2, 3]])
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu")
+            output = model(inputs)[0]
+            model.save_pretrained(
+                tmp_dir, max_shard_size="200KB"
+            )  # model is 1.6MB, max shard size is allocated to cpu by default
+            saved_model = AutoModelForCausalLM.from_pretrained(tmp_dir, device_map="cpu")
+            saved_model_output = saved_model(inputs)[0]
+            self.assertTrue(torch.allclose(output, saved_model_output))
+
     @require_accelerate
     @mark.accelerate_tests
     @require_torch_accelerator
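The new test pins down the single-device case. A quick way to see why it previously hit the offload path (the printed shape of hf_device_map is my assumption for a string device_map; verify on your install):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-gpt2", device_map="cpu"
)
print(model.hf_device_map)  # expected: {'': 'cpu'} -> one distinct device,
                            # so save_pretrained now skips the offload path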
@@ -1083,9 +1100,9 @@ class ModelUtilsTest(TestCasePlus):
         # check_models_equal requires onloaded tensors
         model_id = "hf-internal-testing/tiny-random-gpt2"
-        onloaded_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu")
+        onloaded_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu").to(f"{torch_device}:0")
         inputs = torch.tensor([[1, 2, 3]]).to(f"{torch_device}:0")
-        cpu_output = onloaded_model(inputs)[0]
+        output = onloaded_model(inputs)[0]

         with tempfile.TemporaryDirectory() as tmp_dir:
             offload_folder = os.path.join(tmp_dir, "offload")
@@ -1099,7 +1116,7 @@ class ModelUtilsTest(TestCasePlus):
             saved_model = AutoModelForCausalLM.from_pretrained(tmp_dir, device_map=device_map)
             postsaved_output = saved_model(inputs)[0]

-            self.assertTrue(torch.allclose(cpu_output, presaved_output, atol=1e-4))
+            self.assertTrue(torch.allclose(output, presaved_output, atol=1e-4))
             self.assertTrue(torch.allclose(presaved_output, postsaved_output))

     @require_safetensors
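For completeness, a sketch of the disk-offload round trip the updated test exercises. The device_map split below is illustrative (the real test builds its own map), and it assumes the tiny GPT-2's module names; any small causal LM with an adjusted map would do:

import os
import tempfile
import torch
from transformers import AutoModelForCausalLM

model_id = "hf-internal-testing/tiny-random-gpt2"
inputs = torch.tensor([[1, 2, 3]])

with tempfile.TemporaryDirectory() as tmp_dir:
    offload_folder = os.path.join(tmp_dir, "offload")
    # Illustrative map: modules sent to "disk" land in offload_folder, so
    # hf_device_map contains "disk" and saving takes the module-map path.
    device_map = {
        "transformer.wte": "cpu",
        "transformer.wpe": "cpu",
        "transformer.h": "disk",
        "transformer.ln_f": "cpu",
        "lm_head": "cpu",
    }
    model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map=device_map, offload_folder=offload_folder
    )
    presaved = model(inputs)[0]
    model.save_pretrained(tmp_dir, max_shard_size="200KB")
    reloaded = AutoModelForCausalLM.from_pretrained(
        tmp_dir, device_map=device_map, offload_folder=offload_folder
    )
    postsaved = reloaded(inputs)[0]
    assert torch.allclose(presaved, postsaved, atol=1e-4)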