"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "38f7461df3fe51308a62a81e4a0e7770a38d7125"
Unverified commit e0bc4c73, authored by Sylvain Gugger, committed by GitHub
Browse files

Add balanced strategies for device_map in from_pretrained (#18349)



* Add balanced strategies for device_map in from_pretrained

* Add safeguards for Accelerate version

* Update src/transformers/modeling_utils.py
Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>

* Style
Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>
parent 39e76d76
@@ -76,6 +76,7 @@ from .utils.versions import require_version_core

 if is_accelerate_available():
+    from accelerate import __version__ as accelerate_version
     from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
     from accelerate.utils import (
         load_offloaded_weights,
@@ -84,6 +85,11 @@ if is_accelerate_available():
         set_module_tensor_to_device,
     )

+    if version.parse(accelerate_version) > version.parse("0.11.0"):
+        from accelerate.utils import get_balanced_memory
+    else:
+        get_balanced_memory = None

 logger = logging.get_logger(__name__)
@@ -1697,7 +1703,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                 parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
                 same device.

-                To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`.
+                To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
+                more information about each option see [designing a device
+                map](https://hf.co/docs/accelerate/main/big_modeling#designing-a-device-map).
             max_memory (`Dict`, *optional*):
                 A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
                 GPU and the available CPU RAM if unset.
@@ -2105,10 +2113,25 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         with ContextManagers(init_contexts):
             model = cls(config, *model_args, **model_kwargs)

-        if device_map == "auto":
+        if isinstance(device_map, str):
             if model._no_split_modules is None:
-                raise ValueError(f"{model.__class__.__name__} does not support `device_map='auto'` yet.")
+                raise ValueError(f"{model.__class__.__name__} does not support `device_map='{device_map}'` yet.")
             no_split_modules = model._no_split_modules
+            if device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
+                raise ValueError(
+                    "If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or "
+                    "'sequential'."
+                )
+            elif device_map in ["balanced", "balanced_low_0"] and get_balanced_memory is None:
+                raise ValueError(f"`device_map={device_map}` requires a source install of Accelerate.")
+
+            if device_map != "sequential" and get_balanced_memory is not None:
+                max_memory = get_balanced_memory(
+                    model,
+                    max_memory=max_memory,
+                    no_split_module_classes=no_split_modules,
+                    dtype=torch_dtype,
+                    low_zero=(device_map == "balanced_low_0"),
+                )
             # Make sure tied weights are tied before creating the device map.
             model.tie_weights()
             device_map = infer_auto_device_map(
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.