Unverified Commit 43efd7cb authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix balanced and auto device_map (#22271)

parent 89f0fda5
...@@ -2563,7 +2563,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -2563,7 +2563,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
elif device_map in ["balanced", "balanced_low_0"] and get_balanced_memory is None: elif device_map in ["balanced", "balanced_low_0"] and get_balanced_memory is None:
raise ValueError(f"`device_map={device_map}` requires a source install of Accelerate.") raise ValueError(f"`device_map={device_map}` requires a source install of Accelerate.")
kwargs = {"no_split_module_classes": no_split_modules, "max_memory": max_memory} kwargs = {"no_split_module_classes": no_split_modules}
if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters: if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters:
kwargs["special_dtypes"] = special_dtypes kwargs["special_dtypes"] = special_dtypes
elif len(special_dtypes) > 0: elif len(special_dtypes) > 0:
...@@ -2578,6 +2578,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -2578,6 +2578,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
low_zero=(device_map == "balanced_low_0"), low_zero=(device_map == "balanced_low_0"),
**kwargs, **kwargs,
) )
kwargs["max_memory"] = max_memory
# Make sure tied weights are tied before creating the device map. # Make sure tied weights are tied before creating the device map.
model.tie_weights() model.tie_weights()
device_map = infer_auto_device_map(model, dtype=torch_dtype if not load_in_8bit else torch.int8, **kwargs) device_map = infer_auto_device_map(model, dtype=torch_dtype if not load_in_8bit else torch.int8, **kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment