"tests/ctrl/test_tokenization_ctrl.py" did not exist on "bc1715c1e0872187a8b76d2f258d43815dcf6067"
Unverified Commit 56f50590 authored by Sylvain Gugger, committed by GitHub

Use Accelerate in `from_pretrained` for big model inference (#17341)



* Initial work

* More or less finished with first draft

* Update src/transformers/modeling_utils.py
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Update src/transformers/modeling_utils.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Fix randomly initialized weights

* Update src/transformers/modeling_utils.py
Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>

* Address review comments

* Rename DeepSpeed folder to temporarily fix the test issue?

* Revert to try if Accelerate fix works

* Use latest Accelerate release

* Quality and fixes

* Style

* Quality

* Add doc

* Test + fix

* More blocks
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>
parent 2e7e4280
......@@ -38,6 +38,75 @@ for text generation, [`~generation_utils.GenerationMixin`] (for the PyTorch mode
<a id='from_pretrained-torch-dtype'></a>
### Large model loading
In Transformers 4.20.0, the [`~PreTrainedModel.from_pretrained`] method has been reworked to accommodate large models using [Accelerate](https://huggingface.co/docs/accelerate/big_modeling). This requires Accelerate >= 0.9.0 and PyTorch >= 1.9.0. Instead of creating the full model, then loading the pretrained weights inside it (which takes twice the size of the model in RAM, one for the randomly initialized model, one for the weights), there is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded.
This option can be activated with `low_cpu_mem_usage=True`. The model is first created on the Meta device (with empty weights) and the state dict is then loaded inside it (shard by shard in the case of a sharded checkpoint). This way the maximum RAM used is the full size of the model only.
```py
from transformers import AutoModelForSeq2SeqLM
t0pp = AutoModelForSeq2SeqLM.from_pretrained("bigscience/T0pp", low_cpu_mem_usage=True)
```
Moreover, you can directly place the model on different devices if it doesn't fully fit in RAM (this only works for inference for now). With `device_map="auto"`, Accelerate will determine where to put each layer to maximize the use of your fastest devices (GPUs) and offload the rest to the CPU, or even to the hard drive if you don't have enough GPU RAM (or CPU RAM). Even if the model is split across several devices, it will run as you would normally expect.
When passing a `device_map`, `low_cpu_mem_usage` is automatically set to `True`, so you don't need to specify it:
```py
from transformers import AutoModelForSeq2SeqLM
t0pp = AutoModelForSeq2SeqLM.from_pretrained("bigscience/T0pp", device_map="auto")
```
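If the model doesn't fit on your GPUs and CPU RAM combined, the layers assigned to `"disk"` need a folder to be offloaded to. Here is a minimal sketch (the folder name `"offload"` is just an example); `offload_state_dict=True` additionally spills the CPU part of the state dict to disk temporarily while loading:
```py
from transformers import AutoModelForSeq2SeqLM

t0pp = AutoModelForSeq2SeqLM.from_pretrained(
    "bigscience/T0pp", device_map="auto", offload_folder="offload", offload_state_dict=True
)
```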
You can inspect how the model was split across devices by looking at its `hf_device_map` attribute:
```py
t0pp.hf_device_map
```
```python out
{'shared': 0,
'decoder.embed_tokens': 0,
'encoder': 0,
'decoder.block.0': 0,
'decoder.block.1': 1,
'decoder.block.2': 1,
'decoder.block.3': 1,
'decoder.block.4': 1,
'decoder.block.5': 1,
'decoder.block.6': 1,
'decoder.block.7': 1,
'decoder.block.8': 1,
'decoder.block.9': 1,
'decoder.block.10': 1,
'decoder.block.11': 1,
'decoder.block.12': 1,
'decoder.block.13': 1,
'decoder.block.14': 1,
'decoder.block.15': 1,
'decoder.block.16': 1,
'decoder.block.17': 1,
'decoder.block.18': 1,
'decoder.block.19': 1,
'decoder.block.20': 1,
'decoder.block.21': 1,
'decoder.block.22': 'cpu',
'decoder.block.23': 'cpu',
'decoder.final_layer_norm': 'cpu',
'decoder.dropout': 'cpu',
'lm_head': 'cpu'}
```
You can also write your own device map following the same format (a dictionary mapping layer names to devices). It should map all parameters of the model to a given device, but you don't have to detail where all the submodules of one layer go if that layer is entirely on the same device. For instance, the following device map would work properly for T0pp (as long as you have the GPU memory):
```python
device_map = {"shared": 0, "encoder": 0, "decoder": 1, "lm_head": 1}
```
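Such a hand-written map is passed directly to [`~PreTrainedModel.from_pretrained`]; as a quick sketch (the 0/1 split above is only an illustration):
```py
from transformers import AutoModelForSeq2SeqLM

device_map = {"shared": 0, "encoder": 0, "decoder": 1, "lm_head": 1}
t0pp = AutoModelForSeq2SeqLM.from_pretrained("bigscience/T0pp", device_map=device_map)
```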
Another way to minimize the memory impact of your model is to instantiate it at a lower precision dtype (like `torch.float16`).
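Both options can be combined; for instance, this sketch loads the checkpoint in half precision while letting Accelerate place the layers (the next section covers `torch_dtype` in more detail):
```py
import torch

from transformers import AutoModelForSeq2SeqLM

t0pp = AutoModelForSeq2SeqLM.from_pretrained("bigscience/T0pp", device_map="auto", torch_dtype=torch.float16)
```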
### Model Instantiation dtype
Under PyTorch, a model normally gets instantiated in `torch.float32` format. This can be an issue if one tries to
......
......@@ -97,7 +97,7 @@ if stale_egg_info.exists():
# 2. once modified, run: `make deps_table_update` to update src/transformers/dependency_versions_table.py
_deps = [
"Pillow",
"accelerate>=0.7.1",
"accelerate>=0.9.0",
"black~=22.0,>=22.3",
"codecarbon==1.2.0",
"cookiecutter==1.7.3",
......
......@@ -3,7 +3,7 @@
# 2. run `make deps_table_update``
deps = {
"Pillow": "Pillow",
"accelerate": "accelerate>=0.7.1",
"accelerate": "accelerate>=0.9.0",
"black": "black~=22.0,>=22.3",
"codecarbon": "codecarbon==1.2.0",
"cookiecutter": "cookiecutter==1.7.3",
......
......@@ -54,6 +54,7 @@ from .utils import (
TF_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
ContextManagers,
EntryNotFoundError,
ModelOutput,
PushToHubMixin,
......@@ -62,6 +63,7 @@ from .utils import (
cached_path,
has_file,
hf_bucket_url,
is_accelerate_available,
is_offline_mode,
is_remote_url,
logging,
......@@ -70,6 +72,15 @@ from .utils import (
from .utils.versions import require_version_core
if is_accelerate_available():
from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
from accelerate.utils import (
load_offloaded_weights,
offload_weight,
save_offload_index,
set_module_tensor_to_device,
)
logger = logging.get_logger(__name__)
......@@ -514,7 +525,19 @@ def _move_model_to_meta(model, loaded_state_dict_keys, start_prefix):
setattr(submodule, param_name, new_val)
def _load_state_dict_into_meta_model(model, state_dict, loaded_state_dict_keys, start_prefix):
def _load_state_dict_into_meta_model(
model,
state_dict,
loaded_state_dict_keys, # left for now but could be removed, see below
start_prefix,
expected_keys,
device_map=None,
offload_folder=None,
offload_index=None,
state_dict_folder=None,
state_dict_index=None,
dtype=None,
):
"""
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
params on a `meta` device. It replaces the model params with the data from the `state_dict`, while moving the
......@@ -532,23 +555,55 @@ def _load_state_dict_into_meta_model(model, state_dict, loaded_state_dict_keys,
# - Is there a situation where some keys aren't in `loaded_state_dict_keys` and in which case
# they won't get loaded.
if is_deepspeed_zero3_enabled():
raise ValueError("low_cpu_mem_usage arg cannot currently be used with DeepSpeed ZeRO-3")
error_msgs = []
# materialize state_dict entries one by one on CPU
for k in loaded_state_dict_keys:
if k in state_dict:
submodule, param_name = find_submodule_and_param_name(model, k, start_prefix)
if submodule is not None:
param_dtype = getattr(submodule, param_name).dtype
new_val = state_dict[k].to(param_dtype)
if isinstance(getattr(submodule, param_name), torch.nn.Parameter):
new_val = torch.nn.Parameter(new_val)
setattr(submodule, param_name, new_val)
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if "gamma" in key:
new_key = key.replace("gamma", "weight")
if "beta" in key:
new_key = key.replace("beta", "bias")
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
return error_msgs
for param_name, param in state_dict.items():
# First part of the test is always true as loaded_state_dict_keys always contains state_dict keys.
if param_name not in loaded_state_dict_keys or param_name not in expected_keys:
print(param_name)
continue
if param_name.startswith(start_prefix):
param_name = param_name[len(start_prefix) :]
module_name = param_name
# We convert floating dtypes to the `dtype` passed.
if dtype is not None and not str(param.dtype).startswith("torch.int"):
param = param.to(dtype)
if device_map is None:
param_device = "cpu"
else:
# find next higher level module that is defined in device_map:
# bert.lm_head.weight -> bert.lm_head -> bert -> ''
while len(module_name) > 0 and module_name not in device_map:
module_name = ".".join(module_name.split(".")[:-1])
if module_name == "" and "" not in device_map:
# TODO: group all errors and raise at the end.
raise ValueError(f"{param_name} doesn't have any device set.")
param_device = device_map[module_name]
set_module_tensor_to_device(model, param_name, param_device, value=param)
if param_device == "disk":
offload_index = offload_weight(param, param_name, offload_folder, offload_index)
elif param_device == "cpu" and state_dict_index is not None:
state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
return error_msgs, offload_index, state_dict_index
class ModuleUtilsMixin:
......@@ -870,6 +925,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
base_model_prefix = ""
main_input_name = "input_ids"
_auto_class = None
_no_split_modules = None
# a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
# keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.
......@@ -1664,12 +1720,26 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
</Tip>
low_cpu_mem_usage(`bool`, *optional*, defaults to `False`):
> Parameters for big model inference
low_cpu_mem_usage(`bool`, *optional*):
Tries not to use more than 1x the model size in CPU memory (including peak memory) while loading the model.
This is an experimental feature and subject to change at any moment.
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
will be automatically derived from the model's weights.
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
A map that specifies where each submodule should go. It doesn't need to be refined down to each
parameter/buffer name; once a given module name is in the map, every submodule of it will be sent to the
same device.
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`.
offload_folder (`str` or `os.PathLike`, *optional*):
If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
offload_state_dict (`bool`, *optional*, defaults to `False`):
If `True`, will temporarily offload the CPU state dict to the hard drive to avoid running out of CPU
RAM if the weight of the CPU state dict plus the biggest shard of the checkpoint does not fit.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
`output_attentions=True`). Behaves differently depending on whether a `config` is provided or
......@@ -1750,7 +1820,29 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
from_auto_class = kwargs.pop("_from_auto", False)
_fast_init = kwargs.pop("_fast_init", True)
torch_dtype = kwargs.pop("torch_dtype", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", None)
device_map = kwargs.pop("device_map", None)
offload_folder = kwargs.pop("offload_folder", None)
offload_state_dict = kwargs.pop("offload_state_dict", False)
if device_map is not None:
if low_cpu_mem_usage is None:
low_cpu_mem_usage = True
elif not low_cpu_mem_usage:
raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`")
if low_cpu_mem_usage:
# low_cpu_mem_usage requires PyTorch >= 1.9 to have the meta device.
require_version_core("torch>=1.9")
if is_deepspeed_zero3_enabled():
raise ValueError(
"DeepSpeed Zero-3 is not compatible with `low_cpu_mem_usage=True` or with passing a `device_map`."
)
elif not is_accelerate_available():
raise ImportError(
"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install accelerate`"
)
from_pt = not (from_tf | from_flax)
......@@ -1845,10 +1937,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
filename = WEIGHTS_NAME
archive_file = hf_bucket_url(
pretrained_model_name_or_path,
filename=filename,
revision=revision,
mirror=mirror,
pretrained_model_name_or_path, filename=filename, revision=revision, mirror=mirror
)
try:
......@@ -2013,18 +2102,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
config.name_or_path = pretrained_model_name_or_path
# Instantiate model.
init_contexts = [no_init_weights(_enable=_fast_init)]
if is_deepspeed_zero3_enabled():
import deepspeed
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
# this immediately partitions the model across all gpus, to avoid the overhead in time
# and memory copying it on CPU or each GPU first
with deepspeed.zero.Init(config_dict_or_path=deepspeed_config()):
with no_init_weights(_enable=_fast_init):
model = cls(config, *model_args, **model_kwargs)
else:
with no_init_weights(_enable=_fast_init):
model = cls(config, *model_args, **model_kwargs)
init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config())] + init_contexts
elif low_cpu_mem_usage:
init_contexts.append(init_empty_weights())
with ContextManagers(init_contexts):
model = cls(config, *model_args, **model_kwargs)
if device_map == "auto":
if model._no_split_modules is None:
raise ValueError(f"{model.__class__.__name__} does not support `device_map='auto'` yet.")
no_split_modules = model._no_split_modules
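# Let Accelerate compute a device map that fills up the GPUs first, without splitting the listed block classes across devices.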
device_map = infer_auto_device_map(model, no_split_module_classes=no_split_modules, dtype=torch_dtype)
if from_tf:
if resolved_archive_file.endswith(".index"):
......@@ -2071,6 +2166,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
sharded_metadata=sharded_metadata,
_fast_init=_fast_init,
low_cpu_mem_usage=low_cpu_mem_usage,
device_map=device_map,
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
)
# make sure token embedding weights are still tied if needed
......@@ -2079,6 +2178,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Set model in evaluation mode to deactivate DropOut modules by default
model.eval()
# Dispatch model with hooks on all devices if necessary
if device_map is not None:
dispatch_model(model, device_map=device_map, offload_dir=offload_folder)
if output_loading_info:
loading_info = {
"missing_keys": missing_keys,
......@@ -2102,6 +2205,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
sharded_metadata=None,
_fast_init=True,
low_cpu_mem_usage=False,
device_map=None,
offload_folder=None,
offload_state_dict=False,
dtype=None,
):
# Retrieve missing & unexpected_keys
model_state_dict = model.state_dict()
......@@ -2149,8 +2256,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
for pat in cls._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
# retrieve weights on meta device and put them back on CPU.
# This is not ideal in terms of memory, but if we don't do it now, we can't initialize them in the next step
if low_cpu_mem_usage:
for key in missing_keys:
if key.startswith(prefix):
key = ".".join(key.split(".")[1:])
param = model_state_dict[key]
if param.device == torch.device("meta"):
set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size()))
# retrieve uninitialized modules and initialize before maybe overriding that with the pretrained weights.
if _fast_init:
# retrieve uninitialized modules and initialize
uninitialized_modules = model.retrieve_modules_from_names(
missing_keys, add_prefix=add_prefix_to_model, remove_prefix=remove_prefix_from_model
)
......@@ -2169,6 +2286,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"The state dictionary of the model you are trying to load is corrupted. Are you sure it was "
"properly saved?"
)
if device_map is not None:
device_map = {k.replace(f"{cls.base_model_prefix}.", ""): v for k, v in device_map.items()}
def _find_mismatched_keys(
state_dict,
......@@ -2199,10 +2318,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
del state_dict[checkpoint_key]
return mismatched_keys
if low_cpu_mem_usage:
model_state_dict = None # free references to model's params to allow memory freeing
_move_model_to_meta(model, loaded_keys, start_prefix)
if state_dict is not None:
# Whole checkpoint
mismatched_keys = _find_mismatched_keys(
......@@ -2223,12 +2338,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
error_msgs = []
mismatched_keys = []
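# Keep an index of the weights offloaded to disk so Accelerate can reload them lazily at inference time.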
offload_index = {} if device_map is not None and "disk" in device_map.values() else None
if offload_state_dict:
state_dict_folder = tempfile.mkdtemp()
state_dict_index = {}
else:
state_dict_folder = None
state_dict_index = None
for shard_file in resolved_archive_file:
state_dict = load_state_dict(shard_file)
if low_cpu_mem_usage:
model_state_dict = model.state_dict()
# Mismatched keys contain tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
# matching the weights in the model.
mismatched_keys += _find_mismatched_keys(
......@@ -2241,9 +2361,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
if low_cpu_mem_usage:
error_msgs += _load_state_dict_into_meta_model(
model_to_load, state_dict, loaded_keys, start_prefix
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
model_to_load,
state_dict,
loaded_keys,
start_prefix,
expected_keys,
device_map=device_map,
offload_folder=offload_folder,
offload_index=offload_index,
state_dict_folder=state_dict_folder,
state_dict_index=state_dict_index,
dtype=dtype,
)
error_msgs += new_error_msgs
else:
error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
......@@ -2251,6 +2382,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
del state_dict
gc.collect()
save_offload_index(offload_index, offload_folder)
if offload_state_dict:
# Load back temporarily offloaded state dict
load_offloaded_weights(model, state_dict_index, state_dict_folder)
shutil.rmtree(state_dict_folder)
if len(error_msgs) > 0:
error_msg = "\n\t".join(error_msgs)
if "size mismatch" in error_msg:
......
......@@ -448,6 +448,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
base_model_prefix = "transformer"
is_parallelizable = True
supports_gradient_checkpointing = True
_no_split_modules = ["GPT2Block"]
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
......
......@@ -358,6 +358,7 @@ class GPTNeoPreTrainedModel(PreTrainedModel):
load_tf_weights = load_tf_weights_in_gpt_neo
base_model_prefix = "transformer"
supports_gradient_checkpointing = True
_no_split_modules = ["GPTNeoBlock"]
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
......
......@@ -334,6 +334,7 @@ class GPTJPreTrainedModel(PreTrainedModel):
base_model_prefix = "transformer"
is_parallelizable = True
supports_gradient_checkpointing = True
_no_split_modules = ["GPTJBlock"]
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
......
......@@ -747,6 +747,7 @@ class T5PreTrainedModel(PreTrainedModel):
base_model_prefix = "transformer"
is_parallelizable = True
supports_gradient_checkpointing = True
_no_split_modules = ["T5Block"]
@property
def dummy_inputs(self):
......
......@@ -94,6 +94,8 @@ if is_torch_available():
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
MODEL_MAPPING,
AdaptiveEmbedding,
AutoModelForCausalLM,
AutoTokenizer,
BertConfig,
BertModel,
PreTrainedModel,
......@@ -2595,6 +2597,22 @@ class ModelUtilsTest(TestCasePlus):
# functionality to load models directly on gpu, this test can be rewritten to use torch's
# cuda memory tracking and then we should be able to do a much more precise test.
@require_torch_multi_gpu
@slow
def test_model_parallelism_gpt2(self):
device_map = {"transformer.wte": 0, "transformer.wpe": 0, "lm_head": 0, "transformer.ln_f": 1}
for i in range(12):
device_map[f"transformer.h.{i}"] = 0 if i <= 5 else 1
model = AutoModelForCausalLM.from_pretrained("gpt2", device_map=device_map)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer("Hello, my name is", return_tensors="pt")
output = model.generate(inputs["input_ids"].to(0))
text_output = tokenizer.decode(output[0].tolist())
self.assertEqual(text_output, "Hello, my name is John. I'm a writer, and I'm a writer. I'm")
def test_cached_files_are_used_when_internet_is_down(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
......