Unverified Commit a8fa2f91 authored by Sylvain Gugger, committed by GitHub

Make Trainer compatible with sharded checkpoints (#17053)

* Make Trainer compatible with sharded checkpoints

* Add doc
parent 19420fd9
......@@ -61,6 +61,8 @@
title: Export 🤗 Transformers models
- local: performance
title: 'Performance and Scalability: How To Fit a Bigger Model and Train It Faster'
- local: big_models
title: Instantiating a big model
- local: parallelism
title: Model Parallelism
- local: benchmarks
......
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Instantiating a big model
When you want to use a very big pretrained model, one challenge is to minimize the use of RAM. The usual workflow
from PyTorch is:
1. Create your model with random weights.
2. Load your pretrained weights.
3. Put those pretrained weights in your random model.
Steps 1 and 2 both require a full version of the model in memory, which is not a problem in most cases, but if your model starts weighing several gigabytes, those two copies can make you run out of RAM. Even worse, if you are using `torch.distributed` to launch a distributed training, each process will load the pretrained model and store these two copies in RAM.
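As a rough illustration (the model class and checkpoint path below are placeholders, not part of Transformers), the usual PyTorch pattern looks like this, with both copies alive in RAM at the same time:
```py
import torch
from torch import nn


class MyBigModel(nn.Module):
    # Placeholder architecture standing in for a real pretrained model.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1024, 1024)


# 1. Create the model with random weights (first full copy in RAM).
model = MyBigModel()
# 2. Load the pretrained weights (second full copy in RAM).
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
# 3. Put those pretrained weights in the random model.
model.load_state_dict(state_dict)
```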
<Tip>
Note that the randomly created model is initialized with "empty" tensors, which take up space in memory without filling it (thus the random values are whatever happened to be in that chunk of memory at the time). The random initialization following the appropriate distribution for the kind of model/parameters instantiated (like a normal distribution, for instance) is only performed after step 3, on the uninitialized weights, to be as fast as possible!
</Tip>
In this guide, we explore the solutions Transformers offers to deal with this issue. Note that this is an area of active development, so the APIs explained here may change slightly in the future.
## Sharded checkpoints
Since version 4.18.0, model checkpoints that end up taking more than 10GB of space are automatically sharded into smaller pieces. Instead of having one single checkpoint when you do `model.save_pretrained(save_dir)`, you will end up with several partial checkpoints (each of which is smaller than 10GB) and an index that maps parameter names to the files they are stored in.
You can control the maximum size before sharding with the `max_shard_size` parameter, so for the sake of an example, we'll use a normal-size model with a small shard size: let's take a traditional BERT model.
```py
from transformers import AutoModel
model = AutoModel.from_pretrained("bert-base-cased")
```
If you save it using [`~PreTrainedModel.save_pretrained`], you will get a new folder with two files: the config of the model and its weights:
```py
>>> import os
>>> import tempfile
>>> with tempfile.TemporaryDirectory() as tmp_dir:
... model.save_pretrained(tmp_dir)
... print(sorted(os.listdir(tmp_dir)))
['config.json', 'pytorch_model.bin']
```
Now let's use a maximum shard size of 200MB:
```py
>>> with tempfile.TemporaryDirectory() as tmp_dir:
... model.save_pretrained(tmp_dir, max_shard_size="200MB")
... print(sorted(os.listdir(tmp_dir)))
['config.json', 'pytorch_model-00001-of-00003.bin', 'pytorch_model-00002-of-00003.bin', 'pytorch_model-00003-of-00003.bin', 'pytorch_model.bin.index.json']
```
On top of the configuration of the model, we see three different weight files and an `index.json` file, which is our index. A checkpoint like this can be fully reloaded using the [`~PreTrainedModel.from_pretrained`] method:
```py
>>> with tempfile.TemporaryDirectory() as tmp_dir:
... model.save_pretrained(tmp_dir, max_shard_size="200MB")
... new_model = AutoModel.from_pretrained(tmp_dir)
```
The main advantage of doing this for big models is that during step 2 of the workflow shown above, each shard of the checkpoint is loaded after the previous one, capping the memory usage in RAM to the model size plus the size of the biggest shard.
Behind the scenes, the index file is used to determine which keys are in the checkpoint, and where the corresponding weights are stored. We can load that index like any JSON file and get a dictionary:
```py
>>> import json
>>> with tempfile.TemporaryDirectory() as tmp_dir:
... model.save_pretrained(tmp_dir, max_shard_size="200MB")
... with open(os.path.join(tmp_dir, "pytorch_model.bin.index.json"), "r") as f:
... index = json.load(f)
>>> print(index.keys())
dict_keys(['metadata', 'weight_map'])
```
The metadata just consists of the total size of the model for now. We plan to add other information in the future:
```py
>>> index["metadata"]
{'total_size': 433245184}
```
The weights map is the main part of this index, which maps each parameter name (as usually found in a PyTorch model `state_dict`) to the file it's stored in:
```py
>>> index["weight_map"]
{'embeddings.LayerNorm.bias': 'pytorch_model-00001-of-00003.bin',
'embeddings.LayerNorm.weight': 'pytorch_model-00001-of-00003.bin',
...
```
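For instance, here is a small sketch (not part of the Transformers API) that uses the weight map to locate and load only the shard containing a given parameter; it reuses `model`, `os`, `json`, and `tempfile` from above and additionally needs `torch`:
```py
>>> import torch

>>> with tempfile.TemporaryDirectory() as tmp_dir:
...     model.save_pretrained(tmp_dir, max_shard_size="200MB")
...     with open(os.path.join(tmp_dir, "pytorch_model.bin.index.json"), "r") as f:
...         index = json.load(f)
...     # Locate the shard holding this parameter, then load only that file.
...     shard_file = index["weight_map"]["embeddings.LayerNorm.weight"]
...     shard = torch.load(os.path.join(tmp_dir, shard_file), map_location="cpu")
...     print("embeddings.LayerNorm.weight" in shard)
True
```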
If you want to directly load such a sharded checkpoint inside a model without using [`~PreTrainedModel.from_pretrained`] (as you would do with `model.load_state_dict()` for a full checkpoint), you should use [`~modeling_utils.load_sharded_checkpoint`]:
```py
>>> from transformers.modeling_utils import load_sharded_checkpoint
>>> with tempfile.TemporaryDirectory() as tmp_dir:
... model.save_pretrained(tmp_dir, max_shard_size="200MB")
... load_sharded_checkpoint(model, tmp_dir)
```
## Low memory loading
Sharded checkpoints reduce the memory usage during step 2 of the workflow mentioned above, but when loading a pretrained model, why keep the random weights in memory? The option `low_cpu_mem_usage` will destroy the weights of the randomly initialized model, then progressively load the weights inside, then perform a random initialization for potential missing weights (if you are loading a model with a newly initialized head for a fine-tuning task, for instance).
It's very easy to use: just add `low_cpu_mem_usage=True` to your call to [`~PreTrainedModel.from_pretrained`]:
```py
from transformers import AutoModel
model = AutoModel.from_pretrained("bert-base-cased", low_cpu_mem_usage=True)
```
This can be used in conjunction with a sharded checkpoint.
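As a quick sketch combining the two (reusing the BERT `model` and the temporary-directory pattern from the previous section), you can save a sharded checkpoint and reload it with `low_cpu_mem_usage=True`:
```py
>>> with tempfile.TemporaryDirectory() as tmp_dir:
...     model.save_pretrained(tmp_dir, max_shard_size="200MB")
...     new_model = AutoModel.from_pretrained(tmp_dir, low_cpu_mem_usage=True)
```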
......@@ -89,3 +89,7 @@ Due to Pytorch design, this functionality is only available for floating dtypes.
## Pushing to the Hub
[[autodoc]] utils.PushToHubMixin
## Sharded checkpoints
[[autodoc]] modeling_utils.load_sharded_checkpoint
......@@ -327,6 +327,63 @@ def get_checkpoint_shard_files(
return cached_filenames, sharded_metadata
def load_sharded_checkpoint(model, folder, strict=True):
"""
This is the same as
[`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict)
but for a sharded checkpoint.
This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
loaded in the model.
Args:
model (`torch.nn.Module`): The model in which to load the checkpoint.
folder (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint.
strict (`bool`, *optional*, defaults to `True`):
Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
Returns:
`NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields
- `missing_keys` is a list of str containing the missing keys
- `unexpected_keys` is a list of str containing the unexpected keys
"""
# Load the index
index_file = os.path.join(folder, WEIGHTS_INDEX_NAME)
if not os.path.isfile(index_file):
raise ValueError(f"Can't find a checkpoint index ({WEIGHTS_INDEX_NAME}) in {folder}.")
with open(index_file, "r", encoding="utf-8") as f:
index = json.load(f)
shard_files = list(set(index["weight_map"].values()))
# If strict=True, error before loading any of the state dicts.
loaded_keys = index["weight_map"].keys()
model_keys = model.state_dict().keys()
missing_keys = [key for key in model_keys if key not in loaded_keys]
unexpected_keys = [key for key in loaded_keys if key not in model_keys]
if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0):
error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
if len(missing_keys) > 0:
str_missing_keys = ",".join([f'"{k}"' for k in missing_keys])
error_message += f"\nMissing key(s): {str_missing_keys}."
if len(unexpected_keys) > 0:
str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys])
error_message += f"\nMissing key(s): {str_unexpected_keys}."
raise RuntimeError(error_message)
for shard_file in shard_files:
state_dict = torch.load(os.path.join(folder, shard_file))
model.load_state_dict(state_dict, strict=False)
# Make sure memory is freed before we load the next state dict.
del state_dict
gc.collect()
# Return the same thing as PyTorch load_state_dict function.
return torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
"""
Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
......
......@@ -66,7 +66,7 @@ from .debug_utils import DebugOption, DebugUnderflowOverflow
from .deepspeed import deepspeed_init, deepspeed_reinit, is_deepspeed_zero3_enabled
from .dependency_versions_check import dep_version_check
from .modelcard import TrainingSummary
from .modeling_utils import PreTrainedModel, unwrap_model
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
from .optimization import Adafactor, get_scheduler
from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_callback import (
......@@ -122,6 +122,7 @@ from .trainer_utils import (
from .training_args import OptimizerNames, ParallelMode, TrainingArguments
from .utils import (
CONFIG_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
find_labels,
get_full_repo_name,
......@@ -1559,7 +1560,9 @@ class Trainer:
return TrainOutput(self.state.global_step, train_loss, metrics)
def _load_from_checkpoint(self, resume_from_checkpoint):
if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)) and not os.path.isfile(
os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
):
raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
logger.info(f"Loading model from {resume_from_checkpoint}).")
......@@ -1577,14 +1580,19 @@ class Trainer:
if self.args.deepspeed:
# will be resumed in deepspeed_init
pass
else:
elif os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
# We load the model state dict on the CPU to avoid an OOM error.
state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
# If the model is on the GPU, it still works!
self._load_state_dict_in_model(state_dict)
load_result = self.model.load_state_dict(state_dict, strict=False)
self._issue_warnings_after_load(load_result)
# release memory
del state_dict
else:
# We load the sharded checkpoint
load_result = load_sharded_checkpoint(self.model, resume_from_checkpoint, strict=False)
self._issue_warnings_after_load(load_result)
def _load_best_model(self):
logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
......@@ -1606,15 +1614,19 @@ class Trainer:
# We load the model state dict on the CPU to avoid an OOM error.
state_dict = torch.load(best_model_path, map_location="cpu")
# If the model is on the GPU, it still works!
self._load_state_dict_in_model(state_dict)
load_result = self.model.load_state_dict(state_dict, strict=False)
self._issue_warnings_after_load(load_result)
elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
# Best model is a sharded checkpoint
load_result = load_sharded_checkpoint(self.model, self.state.best_model_checkpoint, strict=False)
self._issue_warnings_after_load(load_result)
else:
logger.warning(
f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
"on multiple nodes, you should activate `--save_on_each_node`."
)
def _load_state_dict_in_model(self, state_dict):
load_result = self.model.load_state_dict(state_dict, strict=False)
def _issue_warnings_after_load(self, load_result):
if len(load_result.missing_keys) != 0:
if self.model._keys_to_ignore_on_save is not None and set(load_result.missing_keys) == set(
......
......@@ -15,6 +15,7 @@
import dataclasses
import gc
import json
import math
import os
import random
......@@ -65,7 +66,7 @@ from transformers.testing_utils import (
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.training_args import OptimizerNames
from transformers.utils import WEIGHTS_NAME, is_apex_available, is_bitsandbytes_available
from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME, is_apex_available, is_bitsandbytes_available
from transformers.utils.hp_naming import TrialShortNamer
......@@ -376,6 +377,25 @@ class TrainerIntegrationCommon:
_ = log1.pop(key, None)
self.assertEqual(log, log1)
def convert_to_sharded_checkpoint(self, folder):
# Converts a checkpoint of a regression model to a sharded checkpoint.
state_dict = torch.load(os.path.join(folder, WEIGHTS_NAME))
os.remove(os.path.join(folder, WEIGHTS_NAME))
keys = list(state_dict.keys())
shard_files = [
WEIGHTS_NAME.replace(".bin", f"-{idx+1:05d}-of-{len(keys):05d}.bin") for idx in range(len(keys))
]
index = {"metadata": {}, "weight_map": {key: shard_files[i] for i, key in enumerate(keys)}}
save_index_file = os.path.join(folder, WEIGHTS_INDEX_NAME)
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
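# Save each parameter in its own shard file, as listed in the index built above.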
for param_name, shard_file in zip(keys, shard_files):
torch.save({param_name: state_dict[param_name]}, os.path.join(folder, shard_file))
@require_torch
@require_sentencepiece
......@@ -1038,6 +1058,31 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
trainer.train(resume_from_checkpoint=False)
@require_torch_up_to_2_gpus
def test_resume_training_with_shard_checkpoint(self):
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
# save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
# won't be the same since the training dataloader is shuffled).
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
trainer.train()
(a, b) = trainer.model.a.item(), trainer.model.b.item()
state = dataclasses.asdict(trainer.state)
checkpoint = os.path.join(tmpdir, "checkpoint-5")
self.convert_to_sharded_checkpoint(checkpoint)
# Reinitialize trainer
trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
trainer.train(resume_from_checkpoint=checkpoint)
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
state1 = dataclasses.asdict(trainer.state)
self.assertEqual(a, a1)
self.assertEqual(b, b1)
self.check_trainer_state_are_the_same(state, state1)
@require_torch_up_to_2_gpus
def test_resume_training_with_gradient_accumulation(self):
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
......