"...composable_kernel_rocm.git" did not exist on "72fd9184c89bc95d15b3b7f781dce7f51783e4b2"
Unverified Commit 096f3046 authored by Steven Liu's avatar Steven Liu Committed by GitHub
Browse files

[docs] Big model loading (#29920)

* update

* feedback
parent c9f6e5e3
...@@ -172,7 +172,7 @@ ...@@ -172,7 +172,7 @@
title: GPU inference title: GPU inference
title: Optimizing inference title: Optimizing inference
- local: big_models - local: big_models
title: Instantiating a big model title: Instantiate a big model
- local: debugging - local: debugging
title: Debugging title: Debugging
- local: tf_xla - local: tf_xla
......
...@@ -14,110 +14,202 @@ rendered properly in your Markdown viewer. ...@@ -14,110 +14,202 @@ rendered properly in your Markdown viewer.
--> -->
# Instantiating a big model # Instantiate a big model
When you want to use a very big pretrained model, one challenge is to minimize the use of the RAM. The usual workflow A barrier to accessing very large pretrained models is the amount of memory required. When loading a pretrained PyTorch model, you usually:
from PyTorch is:
1. Create your model with random weights. 1. Create a model with random weights.
2. Load your pretrained weights. 2. Load your pretrained weights.
3. Put those pretrained weights in your random model. 3. Put those pretrained weights in the model.
Step 1 and 2 both require a full version of the model in memory, which is not a problem in most cases, but if your model starts weighing several GigaBytes, those two copies can make you get out of RAM. Even worse, if you are using `torch.distributed` to launch a distributed training, each process will load the pretrained model and store these two copies in RAM. The first two steps both require a full version of the model in memory and if the model weighs several GBs, you may not have enough memory for two copies of it. This problem is amplified in distributed training environments because each process loads a pretrained model and stores two copies in memory.
<Tip> > [!TIP]
> The randomly created model is initialized with "empty" tensors, which take space in memory without filling it. The random values are whatever was in this chunk of memory at the time. To improve loading speed, the [`_fast_init`](https://github.com/huggingface/transformers/blob/c9f6e5e35156e068b227dd9b15521767f6afd4d2/src/transformers/modeling_utils.py#L2710) parameter is set to `True` by default to skip the random initialization for all weights that are correctly loaded.
Note that the randomly created model is initialized with "empty" tensors, which take the space in memory without filling it (thus the random values are whatever was in this chunk of memory at a given time). The random initialization following the appropriate distribution for the kind of model/parameters instantiated (like a normal distribution for instance) is only performed after step 3 on the non-initialized weights, to be as fast as possible! This guide will show you how Transformers can help you load large pretrained models despite their memory requirements.
</Tip>
In this guide, we explore the solutions Transformers offer to deal with this issue. Note that this is an area of active development, so the APIs explained here may change slightly in the future.
## Sharded checkpoints ## Sharded checkpoints
Since version 4.18.0, model checkpoints that end up taking more than 10GB of space are automatically sharded in smaller pieces. In terms of having one single checkpoint when you do `model.save_pretrained(save_dir)`, you will end up with several partial checkpoints (each of which being of size < 10GB) and an index that maps parameter names to the files they are stored in. From Transformers v4.18.0, a checkpoint larger than 10GB is automatically sharded by the [`~PreTrainedModel.save_pretrained`] method. It is split into several smaller partial checkpoints and creates an index file that maps parameter names to the files they're stored in.
You can control the maximum size before sharding with the `max_shard_size` parameter, so for the sake of an example, we'll use a normal-size models with a small shard size: let's take a traditional BERT model. The maximum shard size is controlled with the `max_shard_size` parameter, but by default it is 5GB, because it is easier to run on free-tier GPU instances without running out of memory.
```py For example, let's shard [BioMistral/BioMistral-7B](https://hf.co/BioMistral/BioMistral-7B).
from transformers import AutoModel
model = AutoModel.from_pretrained("google-bert/bert-base-cased")
```
If you save it using [`~PreTrainedModel.save_pretrained`], you will get a new folder with two files: the config of the model and its weights:
```py ```py
>>> import os
>>> import tempfile
>>> with tempfile.TemporaryDirectory() as tmp_dir: >>> with tempfile.TemporaryDirectory() as tmp_dir:
... model.save_pretrained(tmp_dir) ... model.save_pretrained(tmp_dir, max_shard_size="5GB")
... print(sorted(os.listdir(tmp_dir))) ... print(sorted(os.listdir(tmp_dir)))
['config.json', 'pytorch_model.bin'] ['config.json', 'generation_config.json', 'model-00001-of-00006.safetensors', 'model-00002-of-00006.safetensors', 'model-00003-of-00006.safetensors', 'model-00004-of-00006.safetensors', 'model-00005-of-00006.safetensors', 'model-00006-of-00006.safetensors', 'model.safetensors.index.json']
``` ```
Now let's use a maximum shard size of 200MB: The sharded checkpoint is reloaded with the [`~PreTrainedModel.from_pretrained`] method.
```py ```py
>>> with tempfile.TemporaryDirectory() as tmp_dir: >>> with tempfile.TemporaryDirectory() as tmp_dir:
... model.save_pretrained(tmp_dir, max_shard_size="200MB") ... model.save_pretrained(tmp_dir, max_shard_size="5GB")
... print(sorted(os.listdir(tmp_dir))) ... new_model = AutoModel.from_pretrained(tmp_dir)
['config.json', 'pytorch_model-00001-of-00003.bin', 'pytorch_model-00002-of-00003.bin', 'pytorch_model-00003-of-00003.bin', 'pytorch_model.bin.index.json']
``` ```
On top of the configuration of the model, we see three different weights files, and an `index.json` file which is our index. A checkpoint like this can be fully reloaded using the [`~PreTrainedModel.from_pretrained`] method: The main advantage of sharded checkpoints for big models is that each shard is loaded after the previous one, which caps the memory usage to only the model size and the largest shard size.
You could also directly load a sharded checkpoint inside a model without the [`~PreTrainedModel.from_pretrained`] method (similar to PyTorch's `load_state_dict()` method for a full checkpoint). In this case, use the [`~modeling_utils.load_sharded_checkpoint`] method.
```py ```py
>>> from transformers.modeling_utils import load_sharded_checkpoint
>>> with tempfile.TemporaryDirectory() as tmp_dir: >>> with tempfile.TemporaryDirectory() as tmp_dir:
... model.save_pretrained(tmp_dir, max_shard_size="200MB") ... model.save_pretrained(tmp_dir, max_shard_size="5GB")
... new_model = AutoModel.from_pretrained(tmp_dir) ... load_sharded_checkpoint(model, tmp_dir)
``` ```
The main advantage of doing this for big models is that during step 2 of the workflow shown above, each shard of the checkpoint is loaded after the previous one, capping the memory usage in RAM to the model size plus the size of the biggest shard. ### Shard metadata
Behind the scenes, the index file is used to determine which keys are in the checkpoint, and where the corresponding weights are stored. We can load that index like any json and get a dictionary: The index file determines which keys are in the checkpoint and where the corresponding weights are stored. This file is loaded like any other JSON file and you can get a dictionary from it.
```py ```py
>>> import json >>> import json
>>> with tempfile.TemporaryDirectory() as tmp_dir: >>> with tempfile.TemporaryDirectory() as tmp_dir:
... model.save_pretrained(tmp_dir, max_shard_size="200MB") ... model.save_pretrained(tmp_dir, max_shard_size="5GB")
... with open(os.path.join(tmp_dir, "pytorch_model.bin.index.json"), "r") as f: ... with open(os.path.join(tmp_dir, "model.safetensors.index.json"), "r") as f:
... index = json.load(f) ... index = json.load(f)
>>> print(index.keys()) >>> print(index.keys())
dict_keys(['metadata', 'weight_map']) dict_keys(['metadata', 'weight_map'])
``` ```
The metadata just consists of the total size of the model for now. We plan to add other information in the future: The `metadata` key provides the total model size.
```py ```py
>>> index["metadata"] >>> index["metadata"]
{'total_size': 433245184} {'total_size': 28966928384}
``` ```
The weights map is the main part of this index, which maps each parameter name (as usually found in a PyTorch model `state_dict`) to the file it's stored in: The `weight_map` key maps each parameter name (typically `state_dict` in a PyTorch model) to the shard it's stored in.
```py ```py
>>> index["weight_map"] >>> index["weight_map"]
{'embeddings.LayerNorm.bias': 'pytorch_model-00001-of-00003.bin', {'lm_head.weight': 'model-00006-of-00006.safetensors',
'embeddings.LayerNorm.weight': 'pytorch_model-00001-of-00003.bin', 'model.embed_tokens.weight': 'model-00001-of-00006.safetensors',
'model.layers.0.input_layernorm.weight': 'model-00001-of-00006.safetensors',
'model.layers.0.mlp.down_proj.weight': 'model-00001-of-00006.safetensors',
... ...
}
``` ```
If you want to directly load such a sharded checkpoint inside a model without using [`~PreTrainedModel.from_pretrained`] (like you would do `model.load_state_dict()` for a full checkpoint) you should use [`~modeling_utils.load_sharded_checkpoint`]: ## Accelerate's Big Model Inference
> [!TIP]
> Make sure you have Accelerate v0.9.0 or later and PyTorch v1.9.0 or later installed.
From Transformers v4.20.0, the [`~PreTrainedModel.from_pretrained`] method is supercharged with Accelerate's [Big Model Inference](https://hf.co/docs/accelerate/usage_guides/big_modeling) feature to efficiently handle really big models! Big Model Inference creates a *model skeleton* on PyTorch's [**meta**](https://pytorch.org/docs/main/meta.html) device. The randomly initialized parameters are only created when the pretrained weights are loaded. This way, you aren't keeping two copies of the model in memory at the same time (one for the randomly initialized model and one for the pretrained weights), and the maximum memory consumed is only the full model size.
To enable Big Model Inference in Transformers, set `low_cpu_mem_usage=True` in the [`~PreTrainedModel.from_pretrained`] method.
```py ```py
>>> from transformers.modeling_utils import load_sharded_checkpoint from transformers import AutoModelForCausalLM
>>> with tempfile.TemporaryDirectory() as tmp_dir: gemma = AutoModelForCausalLM.from_pretrained("google/gemma-7b", low_cpu_mem_usage=True)
... model.save_pretrained(tmp_dir, max_shard_size="200MB") ```
... load_sharded_checkpoint(model, tmp_dir)
Accelerate automatically dispatches the model weights across all available devices, starting with the fastest device (GPU) first and then offloading to the slower devices (CPU and even hard drive). This is enabled by setting `device_map="auto"` in the [`~PreTrainedModel.from_pretrained`] method. When you pass the `device_map` parameter, `low_cpu_mem_usage` is automatically set to `True` so you don't need to specify it.
```py
from transformers import AutoModelForCausalLM
# these loading methods are equivalent
gemma = AutoModelForCausalLM.from_pretrained("google/gemma-7b", device_map="auto")
gemma = AutoModelForCausalLM.from_pretrained("google/gemma-7b", device_map="auto", low_cpu_mem_usage=True)
``` ```
## Low memory loading You can also write your own `device_map` by mapping each layer to a device. It should map all model parameters to a device, but you don't have to detail where all the submodules of a layer go if the entire layer is on the same device.
Sharded checkpoints reduce the memory usage during step 2 of the workflow mentioned above, but in order to use that model in a low memory setting, we recommend leveraging our tools based on the Accelerate library. ```python
device_map = {"model.layers.1": 0, "model.layers.14": 1, "model.layers.31": "cpu", "lm_head": "disk"}
```
Access `hf_device_map` attribute to see how Accelerate split the model across devices.
```py
gemma.hf_device_map
```
```python out
{'model.embed_tokens': 0,
'model.layers.0': 0,
'model.layers.1': 0,
'model.layers.2': 0,
'model.layers.3': 0,
'model.layers.4': 0,
'model.layers.5': 0,
'model.layers.6': 0,
'model.layers.7': 0,
'model.layers.8': 0,
'model.layers.9': 0,
'model.layers.10': 0,
'model.layers.11': 0,
'model.layers.12': 0,
'model.layers.13': 0,
'model.layers.14': 'cpu',
'model.layers.15': 'cpu',
'model.layers.16': 'cpu',
'model.layers.17': 'cpu',
'model.layers.18': 'cpu',
'model.layers.19': 'cpu',
'model.layers.20': 'cpu',
'model.layers.21': 'cpu',
'model.layers.22': 'cpu',
'model.layers.23': 'cpu',
'model.layers.24': 'cpu',
'model.layers.25': 'cpu',
'model.layers.26': 'cpu',
'model.layers.27': 'cpu',
'model.layers.28': 'cpu',
'model.layers.29': 'cpu',
'model.layers.30': 'cpu',
'model.layers.31': 'cpu',
'model.norm': 'cpu',
'lm_head': 'cpu'}
```
Please read the following guide for more information: [Large model loading using Accelerate](./main_classes/model#large-model-loading) ## Model data type
PyTorch model weights are normally instantiated as torch.float32 and it can be an issue if you try to load a model as a different data type. For example, you'd need twice as much memory to load the weights in torch.float32 and then again to load them in your desired data type, like torch.float16.
> [!WARNING]
> Due to how PyTorch is designed, the `torch_dtype` parameter only supports floating data types.
To avoid wasting memory like this, explicitly set the `torch_dtype` parameter to the desired data type or set `torch_dtype="auto"` to load the weights with the most optimal memory pattern (the data type is automatically derived from the model weights).
<hfoptions id="dtype">
<hfoption id="specific dtype">
```py
from transformers import AutoModelForCausalLM
gemma = AutoModelForCausalLM.from_pretrained("google/gemma-7b", torch_dtype=torch.float16)
```
</hfoption>
<hfoption id="auto dtype">
```py
from transformers import AutoModelForCausalLM
gemma = AutoModelForCausalLM.from_pretrained("google/gemma-7b", torch_dtype="auto")
```
</hfoption>
</hfoptions>
You can also set the data type to use for models instantiated from scratch.
```python
import torch
from transformers import AutoConfig, AutoModel
my_config = AutoConfig.from_pretrained("google/gemma-2b", torch_dtype=torch.float16)
model = AutoModel.from_config(my_config)
```
...@@ -40,104 +40,6 @@ for text generation, [`~generation.GenerationMixin`] (for the PyTorch models), ...@@ -40,104 +40,6 @@ for text generation, [`~generation.GenerationMixin`] (for the PyTorch models),
- push_to_hub - push_to_hub
- all - all
<a id='from_pretrained-torch-dtype'></a>
### Large model loading
In Transformers 4.20.0, the [`~PreTrainedModel.from_pretrained`] method has been reworked to accommodate large models using [Accelerate](https://huggingface.co/docs/accelerate/big_modeling). This requires Accelerate >= 0.9.0 and PyTorch >= 1.9.0. Instead of creating the full model, then loading the pretrained weights inside it (which takes twice the size of the model in RAM, one for the randomly initialized model, one for the weights), there is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded.
This option can be activated with `low_cpu_mem_usage=True`. The model is first created on the Meta device (with empty weights) and the state dict is then loaded inside it (shard by shard in the case of a sharded checkpoint). This way the maximum RAM used is the full size of the model only.
```py
from transformers import AutoModelForSeq2SeqLM
t0pp = AutoModelForSeq2SeqLM.from_pretrained("bigscience/T0pp", low_cpu_mem_usage=True)
```
Moreover, you can directly place the model on different devices if it doesn't fully fit in RAM (only works for inference for now). With `device_map="auto"`, Accelerate will determine where to put each layer to maximize the use of your fastest devices (GPUs) and offload the rest on the CPU, or even the hard drive if you don't have enough GPU RAM (or CPU RAM). Even if the model is split across several devices, it will run as you would normally expect.
When passing a `device_map`, `low_cpu_mem_usage` is automatically set to `True`, so you don't need to specify it:
```py
from transformers import AutoModelForSeq2SeqLM
t0pp = AutoModelForSeq2SeqLM.from_pretrained("bigscience/T0pp", device_map="auto")
```
You can inspect how the model was split across devices by looking at its `hf_device_map` attribute:
```py
t0pp.hf_device_map
```
```python out
{'shared': 0,
'decoder.embed_tokens': 0,
'encoder': 0,
'decoder.block.0': 0,
'decoder.block.1': 1,
'decoder.block.2': 1,
'decoder.block.3': 1,
'decoder.block.4': 1,
'decoder.block.5': 1,
'decoder.block.6': 1,
'decoder.block.7': 1,
'decoder.block.8': 1,
'decoder.block.9': 1,
'decoder.block.10': 1,
'decoder.block.11': 1,
'decoder.block.12': 1,
'decoder.block.13': 1,
'decoder.block.14': 1,
'decoder.block.15': 1,
'decoder.block.16': 1,
'decoder.block.17': 1,
'decoder.block.18': 1,
'decoder.block.19': 1,
'decoder.block.20': 1,
'decoder.block.21': 1,
'decoder.block.22': 'cpu',
'decoder.block.23': 'cpu',
'decoder.final_layer_norm': 'cpu',
'decoder.dropout': 'cpu',
'lm_head': 'cpu'}
```
You can also write your own device map following the same format (a dictionary layer name to device). It should map all parameters of the model to a given device, but you don't have to detail where all the submodules of one layer go if that layer is entirely on the same device. For instance, the following device map would work properly for T0pp (as long as you have the GPU memory):
```python
device_map = {"shared": 0, "encoder": 0, "decoder": 1, "lm_head": 1}
```
Another way to minimize the memory impact of your model is to instantiate it at a lower precision dtype (like `torch.float16`) or use direct quantization techniques as described below.
### Model Instantiation dtype
Under Pytorch a model normally gets instantiated with `torch.float32` format. This can be an issue if one tries to
load a model whose weights are in fp16, since it'd require twice as much memory. To overcome this limitation, you can
either explicitly pass the desired `dtype` using `torch_dtype` argument:
```python
model = T5ForConditionalGeneration.from_pretrained("t5", torch_dtype=torch.float16)
```
or, if you want the model to always load in the most optimal memory pattern, you can use the special value `"auto"`,
and then `dtype` will be automatically derived from the model's weights:
```python
model = T5ForConditionalGeneration.from_pretrained("t5", torch_dtype="auto")
```
Models instantiated from scratch can also be told which `dtype` to use with:
```python
config = T5Config.from_pretrained("t5")
model = AutoModel.from_config(config)
```
Due to Pytorch design, this functionality is only available for floating dtypes.
## ModuleUtilsMixin ## ModuleUtilsMixin
[[autodoc]] modeling_utils.ModuleUtilsMixin [[autodoc]] modeling_utils.ModuleUtilsMixin
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment