"docs/source/vscode:/vscode.git/clone" did not exist on "842e99f1b9ee2a0fa239997ef695c5ed0bd77195"
Unverified commit 1ac599d9, authored by Marc Sun and committed by GitHub

Fix disk offload when loading a derived model checkpoint into a base model (#27253)

* fix

* style

* add test
parent b71c38a0
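
For context, a minimal sketch of the scenario this commit fixes, assuming transformers with accelerate installed (the tiny Hub checkpoint is the one the new test below uses; the temporary paths and device assignments here are illustrative). A checkpoint saved from a derived model such as a causal-LM wrapper stores its weights under the base-model prefix (for GPT-2, "transformer."), so loading that checkpoint into the bare base model with a device_map that sends some blocks to "disk" previously failed: the prefixed checkpoint keys never matched the unprefixed device-map entries.

import tempfile

import torch
from transformers import AutoModel, AutoModelForCausalLM

# Weights saved from the derived model carry the base-model prefix,
# e.g. "transformer.h.3.attn.c_attn.weight".
derived = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")

with tempfile.TemporaryDirectory() as tmp_dir:
    derived.save_pretrained(tmp_dir, safe_serialization=True)

    # Load the same checkpoint into the bare base model, offloading two blocks
    # to disk ("cpu" is used for the rest so the sketch runs without a GPU).
    base = AutoModel.from_pretrained(
        tmp_dir,
        device_map={"wte": "cpu", "wpe": "cpu", "h.0": "cpu", "h.1": "cpu",
                    "h.2": "cpu", "h.3": "disk", "h.4": "disk", "ln_f": "cpu"},
        offload_folder=f"{tmp_dir}/offload",
    )
    print(base(torch.tensor([[1, 2, 3]])).last_hidden_state.shape)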
src/transformers/modeling_utils.py
@@ -3793,8 +3793,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         else:
             folder = None
         if device_map is not None and is_safetensors:
-            param_device_map = expand_device_map(device_map, original_loaded_keys)
+            param_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix)
             str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
             if sharded_metadata is None:
                 archive_file = (
@@ -3806,9 +3805,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             else:
                 weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()}
             offload_index = {
-                p: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
+                p[len(start_prefix) :]: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
                 for p, f in weight_map.items()
-                if param_device_map[p] == "disk"
+                if p.startswith(start_prefix) and param_device_map[p[len(start_prefix) :]] == "disk"
             }
         if state_dict is not None:
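
As a toy illustration of the rewritten offload_index (the file name and parameter name are made up): the dictionary key becomes the unprefixed base-model parameter name that the model will request at runtime, while "weight_name" keeps the prefixed name the tensor is actually stored under inside the safetensors file.

start_prefix = "transformer."
str_dtype = "float32"
weight_map = {"transformer.h.3.attn.weight": "/tmp/ckpt/model.safetensors"}
param_device_map = {"h.3.attn.weight": "disk"}

# Same comprehension as the new code above, on toy inputs.
offload_index = {
    p[len(start_prefix):]: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
    for p, f in weight_map.items()
    if p.startswith(start_prefix) and param_device_map[p[len(start_prefix):]] == "disk"
}

# Key: base-model parameter name; "weight_name": prefixed on-disk tensor name.
assert offload_index == {
    "h.3.attn.weight": {
        "safetensors_file": "/tmp/ckpt/model.safetensors",
        "weight_name": "transformer.h.3.attn.weight",
        "dtype": "float32",
    }
}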
@@ -3842,7 +3841,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             state_dict_index = None
         if is_sharded_safetensors:
-            disk_only_shard_files = get_disk_only_shard_files(device_map, sharded_metadata=sharded_metadata)
+            disk_only_shard_files = get_disk_only_shard_files(
+                device_map, sharded_metadata=sharded_metadata, start_prefix=start_prefix
+            )
             disk_only_shard_files = [os.path.join(folder, f) for f in disk_only_shard_files]
         else:
             disk_only_shard_files = []
@@ -4576,11 +4577,12 @@ def unwrap_model(model: nn.Module) -> nn.Module:
     return model


-def expand_device_map(device_map, param_names):
+def expand_device_map(device_map, param_names, start_prefix):
     """
     Expand a device map to return the correspondance parameter name to device.
     """
     new_device_map = {}
+    param_names = [p[len(start_prefix) :] for p in param_names if p.startswith(start_prefix)]
     for module, device in device_map.items():
         new_device_map.update(
             {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
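
To see what the new start_prefix argument buys, here is a self-contained sketch of the matching expand_device_map now performs; the parameter names and device map are toy values, not library internals.

# Checkpoint keys as saved by the derived model; the device map is written
# against the *base* model's module names.
loaded_keys = ["transformer.h.3.attn.weight", "transformer.ln_f.weight"]
start_prefix = "transformer."
device_map = {"h.3": "disk", "ln_f": 0}

# The normalization the patch adds: strip the prefix before matching modules.
param_names = [k[len(start_prefix):] for k in loaded_keys if k.startswith(start_prefix)]

# Same matching rule as the function body above.
new_device_map = {}
for module, device in device_map.items():
    new_device_map.update(
        {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
    )

assert new_device_map == {"h.3.attn.weight": "disk", "ln_f.weight": 0}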
@@ -4588,12 +4590,16 @@ def expand_device_map(device_map, param_names):
     return new_device_map


-def get_disk_only_shard_files(device_map, sharded_metadata):
+def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix):
     """
     Returns the list of shard files containing only weights offloaded to disk.
     """
+    weight_map = {
+        p[len(start_prefix) :]: v for p, v in sharded_metadata["weight_map"].items() if p.startswith(start_prefix)
+    }
     files_content = collections.defaultdict(list)
-    for weight_name, filename in sharded_metadata["weight_map"].items():
+    for weight_name, filename in weight_map.items():
         while len(weight_name) > 0 and weight_name not in device_map:
             weight_name = ".".join(weight_name.split(".")[:-1])
         files_content[filename].append(device_map[weight_name])
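
Likewise, a self-contained sketch of what the re-keyed weight_map does for get_disk_only_shard_files, with made-up shard and parameter names; the final filter shown here mirrors how the function (past the end of this hunk) keeps only shards whose weights all live on disk.

import collections

start_prefix = "transformer."
device_map = {"h.3": "disk", "ln_f": 0}
sharded_metadata = {
    "weight_map": {
        "transformer.ln_f.weight": "model-00001-of-00002.safetensors",
        "transformer.h.3.attn.weight": "model-00002-of-00002.safetensors",
    }
}

# Re-key the shard index to base-model names, as the patched function does.
weight_map = {
    p[len(start_prefix):]: f
    for p, f in sharded_metadata["weight_map"].items()
    if p.startswith(start_prefix)
}

# Walk each weight name up to the module listed in the device map, grouping
# the resulting devices per shard file.
files_content = collections.defaultdict(list)
for weight_name, filename in weight_map.items():
    while len(weight_name) > 0 and weight_name not in device_map:
        weight_name = ".".join(weight_name.split(".")[:-1])
    files_content[filename].append(device_map[weight_name])

# A shard can be skipped during the main loading pass only when every one of
# its weights is offloaded to disk.
disk_only = sorted(f for f, devices in files_content.items() if set(devices) == {"disk"})
assert disk_only == ["model-00002-of-00002.safetensors"]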
tests/test_modeling_utils.py
@@ -750,6 +750,46 @@ class ModelUtilsTest(TestCasePlus):
         self.assertTrue(torch.allclose(outputs1.logits.cpu(), outputs2.logits.cpu()))
 
+    @require_accelerate
+    @mark.accelerate_tests
+    @require_torch_accelerator
+    def test_from_pretrained_disk_offload_derived_to_base_model(self):
+        derived_model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        device_map = {
+            "wte": 0,
+            "wpe": 0,
+            "h.0": "cpu",
+            "h.1": "cpu",
+            "h.2": "cpu",
+            "h.3": "disk",
+            "h.4": "disk",
+            "ln_f": 0,
+        }
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            inputs = torch.tensor([[1, 2, 3]]).to(0)
+
+            derived_model.save_pretrained(tmp_dir, use_safetensors=True)
+            base_model = AutoModel.from_pretrained(tmp_dir)
+            outputs1 = base_model.to(0)(inputs)
+
+            # with disk offload
+            offload_folder = os.path.join(tmp_dir, "offload")
+            base_model_with_offload = AutoModel.from_pretrained(
+                tmp_dir, device_map=device_map, offload_folder=offload_folder
+            )
+            outputs2 = base_model_with_offload(inputs)
+            self.assertTrue(torch.allclose(outputs1[0].cpu(), outputs2[0].cpu()))
+
+            # With state dict temp offload
+            new_model_with_offload = AutoModel.from_pretrained(
+                tmp_dir,
+                device_map=device_map,
+                offload_folder=offload_folder,
+                offload_state_dict=True,
+            )
+            outputs2 = new_model_with_offload(inputs)
+            self.assertTrue(torch.allclose(outputs1[0].cpu(), outputs2[0].cpu()))
+
     def test_cached_files_are_used_when_internet_is_down(self):
         # A mock response for an HTTP head request to emulate server down
         response_mock = mock.Mock()