"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "3b9174f248a9c16e05a6701f1c5bdc28fc19f995"
Unverified Commit 5da33f87, authored by Stas Bekman, committed by GitHub

[modeling utils] revamp `from_pretrained(..., low_cpu_mem_usage=True)` + tests (#16657)

* add low_cpu_mem_usage tests

* wip: revamping

* wip

* install /usr/bin/time

* wip

* cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* fix assert

* put the wrapper back

* cleanup; switch to bert-base-cased

* Trigger CI

* Trigger CI
parent ce2fef2a
@@ -217,7 +217,7 @@ jobs:
keys:
- v0.4-torch-{{ checksum "setup.py" }}
- v0.4-{{ checksum "setup.py" }}
- run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev espeak-ng
- run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev espeak-ng time
- run: pip install --upgrade pip
- run: pip install .[sklearn,torch,testing,sentencepiece,torch-speech,vision,timm]
- run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.11.0+cpu.html
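The `time` package is installed here because the new memory tests shell out to GNU `/usr/bin/time` to measure peak RSS. As a rough standalone sketch of that mechanism (the real helper, `python_one_liner_max_rss`, is added to `TestCasePlus` further down; `max_rss_bytes` below is invented for illustration):

```python
# Rough sketch of measuring a subprocess's peak RSS with GNU time.
import subprocess

def max_rss_bytes(one_liner: str) -> int:
    # `/usr/bin/time -f %M` prints the max resident set size in KB on stderr
    result = subprocess.run(
        ["/usr/bin/time", "-f", "%M", "python", "-c", one_liner],
        capture_output=True, text=True, check=True,
    )
    return int(result.stderr.strip().splitlines()[-1]) * 1024

print(max_rss_bytes("import torch"))
```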
@@ -400,6 +400,95 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
return error_msgs
def find_submodule_and_param_name(model, long_key, start_prefix):
"""
A helper util to find the last sub-module and the param/buffer name. If `start_prefix` is supplied, it is removed
from the start of the key.
"""
if len(start_prefix) > 0 and long_key.startswith(start_prefix):
long_key = ".".join(long_key.split(".")[1:])
split_key = long_key.split(".")
submodule = model
while len(split_key) > 1:
if hasattr(submodule, split_key[0]):
submodule = getattr(submodule, split_key[0])
del split_key[0]
else:
submodule = None
break
if submodule == model:
submodule = None
return submodule, split_key[0]
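As a quick illustration, here is a hypothetical usage sketch with a toy module (the toy model below is invented for the example; only `find_submodule_and_param_name` comes from this diff):

```python
# Hypothetical usage sketch for find_submodule_and_param_name.
import torch.nn as nn

model = nn.Module()
model.pooler = nn.Module()
model.pooler.dense = nn.Linear(2, 2)

# the "bert." prefix is stripped because start_prefix is supplied
sub, name = find_submodule_and_param_name(model, "bert.pooler.dense.weight", "bert.")
assert sub is model.pooler.dense and name == "weight"
```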
def _move_model_to_meta(model, loaded_state_dict_keys, start_prefix):
"""
Moves the params/buffers named in `loaded_state_dict_keys` to the meta device, which frees up the memory taken
by those params.
`start_prefix` is used for models which insert their name into model keys, e.g. `bert` in
`bert.pooler.dense.weight`.
"""
# meta device was added in pt=1.9
require_version_core("torch>=1.9")
# dematerialize param storage for keys that are going to be replaced by state_dict, by
# putting those on the meta device
for k in loaded_state_dict_keys:
submodule, param_name = find_submodule_and_param_name(model, k, start_prefix)
if submodule is not None:
# selectively switch to the meta device only those params/buffers that will
# be replaced next from the state_dict. This is a roundabout way to do p.to_("meta"),
# since there is no in-place to_ for tensors.
new_val = getattr(submodule, param_name)
if isinstance(new_val, torch.nn.Parameter):
# isinstance returns False for Params on meta device, so switch after the check
new_val = torch.nn.Parameter(new_val.to("meta"))
else:
new_val = new_val.to("meta")
setattr(submodule, param_name, new_val)
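In isolation, the meta-device trick used above looks roughly like this standalone sketch (toy layer, requires torch>=1.9):

```python
# Standalone sketch of dematerializing a param by moving it to the meta device.
import torch.nn as nn

layer = nn.Linear(1024, 1024)
print(layer.weight.device)  # cpu - storage is allocated

# roundabout equivalent of an in-place .to_("meta"); re-wrap as Parameter
# after .to() so the module keeps registering it as a param
layer.weight = nn.Parameter(layer.weight.to("meta"))
print(layer.weight.device)  # meta - no storage backs the tensor anymore
```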
def _load_state_dict_into_meta_model(model, state_dict, loaded_state_dict_keys, start_prefix):
"""
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
params on a `meta` device. It replaces the model params with the data from the `state_dict`, while moving the
params back to the normal device, but only for `loaded_state_dict_keys`.
`start_prefix` is used for models which insert their name into model keys, e.g. `bert` in
`bert.pooler.dense.weight`
"""
# XXX: remaining features to implement to be fully compatible with _load_state_dict_into_model
# - deepspeed zero 3 support
# - need to copy metadata if any - see _load_state_dict_into_model
# - handling error_msgs - mimicking the error handling in module._load_from_state_dict()
# - is there a situation where some keys aren't in `loaded_state_dict_keys`, in which case
#   they won't get loaded?
if is_deepspeed_zero3_enabled():
raise ValueError("low_cpu_mem_usage arg cannot currently be used with DeepSpeed ZeRO-3")
error_msgs = []
# materialize state_dict entries one by one on CPU
for k in loaded_state_dict_keys:
if k in state_dict:
submodule, param_name = find_submodule_and_param_name(model, k, start_prefix)
if submodule is not None:
param_dtype = getattr(submodule, param_name).dtype
new_val = state_dict[k].to(param_dtype)
if isinstance(getattr(submodule, param_name), torch.nn.Parameter):
new_val = torch.nn.Parameter(new_val)
setattr(submodule, param_name, new_val)
return error_msgs
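The per-key replacement can likewise be pictured with a minimal standalone sketch (toy tensors, not the real loading path):

```python
# Minimal sketch of re-materializing a meta param from checkpoint data,
# mirroring the per-key loop in _load_state_dict_into_meta_model.
import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
state_dict = {"weight": torch.ones(4, 4)}

# dematerialize first, as _move_model_to_meta would
layer.weight = nn.Parameter(layer.weight.to("meta"))

# materialize from the checkpoint entry, preserving the param's dtype
new_val = state_dict["weight"].to(layer.weight.dtype)
layer.weight = nn.Parameter(new_val)
print(layer.weight.device, layer.weight.sum())  # cpu tensor(16., grad_fn=<SumBackward0>)
```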
class ModuleUtilsMixin:
"""
A few utilities for `torch.nn.Modules`, to be used as a mixin.
@@ -1529,7 +1618,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
>>> model = BertModel.from_pretrained("./tf_model/my_tf_checkpoint.ckpt.index", from_tf=True, config=config)
>>> # Loading from a Flax checkpoint file instead of a PyTorch model (slower)
>>> model = BertModel.from_pretrained("bert-base-uncased", from_flax=True)
```"""
```
* `low_cpu_mem_usage` algorithm:
This is an experimental function that loads the model using ~1x model size CPU memory.
Here is how it works:
1. save which state_dict keys we have
2. drop state_dict before the model is created, since the latter takes 1x model size CPU memory
3. after the model has been instantiated, switch to the meta device all params/buffers that
are going to be replaced from the loaded state_dict
4. load state_dict a 2nd time
5. replace the params/buffers from the state_dict
Currently, it can't handle DeepSpeed ZeRO stage 3 and it ignores loading errors (see the usage sketch after this hunk).
"""
config = kwargs.pop("config", None)
state_dict = kwargs.pop("state_dict", None)
cache_dir = kwargs.pop("cache_dir", None)
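From the user's side, the algorithm documented above collapses to a single flag:

```python
# End-user view of the low_cpu_mem_usage algorithm.
from transformers import AutoModel

# peak CPU memory stays around ~1x model size instead of ~2x
model = AutoModel.from_pretrained("bert-base-cased", low_cpu_mem_usage=True)
```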
@@ -1778,6 +1884,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if not is_sharded and state_dict is None:
# Time to load the checkpoint
state_dict = load_state_dict(resolved_archive_file)
# set dtype to instantiate the model under:
# 1. If torch_dtype is not None, we use that dtype
# 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
@@ -1801,13 +1908,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
dtype_orig = cls._set_default_torch_dtype(torch_dtype)
if is_sharded:
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
else:
loaded_state_dict_keys = [k for k in state_dict.keys()]
if low_cpu_mem_usage:
# save the keys
if is_sharded:
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
else:
loaded_state_dict_keys = [k for k in state_dict.keys()]
del state_dict # free CPU memory - will reload again later
state_dict = None
config.name_or_path = pretrained_model_name_or_path
@@ -1825,11 +1931,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
with no_init_weights(_enable=_fast_init):
model = cls(config, *model_args, **model_kwargs)
if from_pt:
# restore default dtype
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)
if from_tf:
if resolved_archive_file.endswith(".index"):
# Load from a TensorFlow 1.X checkpoint - provided by original authors
@@ -1859,18 +1960,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
raise
elif from_pt:
if low_cpu_mem_usage:
cls._load_pretrained_model_low_mem(model, loaded_state_dict_keys, resolved_archive_file)
else:
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
model,
state_dict,
resolved_archive_file,
pretrained_model_name_or_path,
ignore_mismatched_sizes=ignore_mismatched_sizes,
sharded_metadata=sharded_metadata,
_fast_init=_fast_init,
)
# restore default dtype
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
model,
state_dict,
loaded_state_dict_keys, # XXX: rename?
resolved_archive_file,
pretrained_model_name_or_path,
ignore_mismatched_sizes=ignore_mismatched_sizes,
sharded_metadata=sharded_metadata,
_fast_init=_fast_init,
low_cpu_mem_usage=low_cpu_mem_usage,
)
# make sure token embedding weights are still tied if needed
model.tie_weights()
@@ -1894,16 +1998,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
cls,
model,
state_dict,
loaded_keys,
resolved_archive_file,
pretrained_model_name_or_path,
ignore_mismatched_sizes=False,
sharded_metadata=None,
_fast_init=True,
low_cpu_mem_usage=False,
):
# Retrieve missing & unexpected_keys
model_state_dict = model.state_dict()
expected_keys = list(model_state_dict.keys())
loaded_keys = list(state_dict.keys()) if state_dict is not None else sharded_metadata["all_checkpoint_keys"]
prefix = model.base_model_prefix
def _fix_key(key):
@@ -1994,9 +2099,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
)
del state_dict[checkpoint_key]
return mismatched_keys
if low_cpu_mem_usage:
model_state_dict = None # free references to model's params to allow memory freeing
_move_model_to_meta(model, loaded_keys, start_prefix)
if state_dict is not None:
# Whole checkpoint
mismatched_keys = _find_mismatched_keys(
@@ -2009,7 +2117,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
else:
# Sharded checkpoint
# Sharded checkpoint or whole but low_cpu_mem_usage==True
# This should always be a list, but check just to be sure.
if not isinstance(resolved_archive_file, list):
resolved_archive_file = [resolved_archive_file]
@@ -2018,6 +2127,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
mismatched_keys = []
for shard_file in resolved_archive_file:
state_dict = load_state_dict(shard_file)
if low_cpu_mem_usage:
model_state_dict = model.state_dict()
# mismatched_keys contains tuples (key, shape1, shape2) of weights in the checkpoint that have a shape not
# matching the weights in the model.
mismatched_keys += _find_mismatched_keys(
@@ -2028,7 +2141,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
remove_prefix_from_model,
ignore_mismatched_sizes,
)
error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
if low_cpu_mem_usage:
error_msgs += _load_state_dict_into_meta_model(
model_to_load, state_dict, loaded_keys, start_prefix
)
else:
error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
if len(error_msgs) > 0:
error_msg = "\n\t".join(error_msgs)
@@ -2093,13 +2212,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
return retrieved_modules
@staticmethod
def _load_pretrained_model_low_mem(model, loaded_state_dict_keys, resolved_archive_file):
def _load_pretrained_model_low_mem(model, loaded_state_dict_keys, resolved_archive_file, start_prefix=""):
"""
This is an experimental function that loads the model using ~1.x model size CPU memory
Before it gets called we do:
Before you call it do:
1. save which state_dict keys we have
1. save which state_dict keys are available
2. drop state_dict before model is created, since the latter takes 1x model size memory
Here then we continue:
@@ -2110,58 +2229,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
Currently, it doesn't handle missing_keys, unexpected_keys, mismatched_keys. It can't handle deepspeed.
"""
require_version_core("torch>=1.9")
if is_deepspeed_zero3_enabled():
raise ValueError("low_cpu_mem_usage arg cannot be used with DeepSpeed ZeRO-3")
# a helper util to find the last sub-module and the param/buffer name
def find_submodule_and_param_name(model, long_key):
split_key = long_key.split(".")
submodule = model
while len(split_key) > 1:
if hasattr(submodule, split_key[0]):
submodule = getattr(submodule, split_key[0])
del split_key[0]
else:
submodule = None
break
return submodule, split_key[0]
# dematerialize param storage for keys that are going to be replaced by state_dict, by
# putting those on the meta device
for k in loaded_state_dict_keys:
submodule, param_name = find_submodule_and_param_name(model, k)
if submodule is not None:
# selectively switch to the meta device only those params/buffers that will
# be replaced next from the state_dict. This is a roundabout way to do p.to_("meta"),
# since there is no in-place to_ for tensors.
new_val = getattr(submodule, param_name)
if isinstance(new_val, torch.nn.Parameter):
# isinstance returns False for Params on meta device, so switch after the check
new_val = torch.nn.Parameter(new_val.to("meta"))
else:
new_val = new_val.to("meta")
setattr(submodule, param_name, new_val)
# only now can load state_dict(s)
if not isinstance(resolved_archive_file, list):
resolved_archive_file = [resolved_archive_file]
for archive_file in resolved_archive_file:
state_dict = torch.load(archive_file, map_location="cpu")
# materialize state_dict entries one by one on CPU
for k in loaded_state_dict_keys:
if k in state_dict:
submodule, param_name = find_submodule_and_param_name(model, k)
if submodule is not None:
param_dtype = getattr(submodule, param_name).dtype
new_val = state_dict[k].to(param_dtype)
if isinstance(getattr(submodule, param_name), torch.nn.Parameter):
new_val = torch.nn.Parameter(new_val)
setattr(submodule, param_name, new_val)
del state_dict
_move_model_to_meta(model, loaded_state_dict_keys, start_prefix)
state_dict = load_state_dict(resolved_archive_file)
error_msgs = _load_state_dict_into_meta_model(model, state_dict, loaded_state_dict_keys, start_prefix)
return error_msgs
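Stitched together, the calling protocol from the docstring looks roughly like the following hedged sketch; `archive_file` is a hypothetical local checkpoint path whose keys match the model, and `from_pretrained` normally drives all of these steps internally:

```python
# Hedged sketch of the _load_pretrained_model_low_mem calling protocol.
import torch
from transformers import BertConfig, BertModel

archive_file = "pytorch_model.bin"  # hypothetical local checkpoint path

state_dict = torch.load(archive_file, map_location="cpu")
loaded_state_dict_keys = list(state_dict.keys())  # 1. save which keys are available
del state_dict  # 2. drop state_dict before the model is created (frees 1x model size)

model = BertModel(BertConfig())  # instantiate; params get re-materialized by the loader
error_msgs = BertModel._load_pretrained_model_low_mem(model, loaded_state_dict_keys, archive_file)
```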
@classmethod
def register_for_auto_class(cls, auto_class="AutoModel"):
@@ -17,6 +17,7 @@ import inspect
import logging
import os
import re
import shlex
import shutil
import sys
import tempfile
@@ -667,6 +668,20 @@ def require_librosa(test_case):
return test_case
def cmd_exists(cmd):
return shutil.which(cmd) is not None
def require_usr_bin_time(test_case):
"""
Decorator marking a test that requires `/usr/bin/time`
"""
if not cmd_exists("/usr/bin/time"):
return unittest.skip("test requires /usr/bin/time")(test_case)
else:
return test_case
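A tiny hypothetical example of the decorator in use (`MyMemoryTest` is invented; `python_one_liner_max_rss` is the `TestCasePlus` helper added below):

```python
# Hypothetical test class showing require_usr_bin_time in action.
from transformers.testing_utils import TestCasePlus, require_usr_bin_time

class MyMemoryTest(TestCasePlus):
    @require_usr_bin_time  # skipped when GNU time isn't installed
    def test_peak_rss(self):
        max_rss = self.python_one_liner_max_rss("import transformers")
        self.assertGreater(max_rss, 0)
```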
def get_gpu_count():
"""
Return the number of available gpus (regardless of whether torch, tf or jax is used)
@@ -1178,6 +1193,39 @@ class TestCasePlus(unittest.TestCase):
return tmp_dir
def python_one_liner_max_rss(self, one_liner_str):
"""
Runs the passed Python one-liner (just the code) and returns the max CPU memory used to run the
program.
Args:
one_liner_str (`string`):
Python one-liner code that gets passed to `python -c`
Returns:
max CPU memory in bytes used to run the program. This value is likely to vary slightly from run to run.
Requirements:
this helper needs `/usr/bin/time` to be installed (`apt install time`)
Example:
```
one_liner_str = 'from transformers import AutoModel; AutoModel.from_pretrained("t5-large")'
max_rss = self.python_one_liner_max_rss(one_liner_str)
```
"""
if not cmd_exists("/usr/bin/time"):
raise ValueError("/usr/bin/time is required, install with `apt install time`")
cmd = shlex.split(f"/usr/bin/time -f %M python -c '{one_liner_str}'")
with CaptureStd() as cs:
execute_subprocess_async(cmd, env=self.get_env())
# returned data is in KB so convert to bytes
max_rss = int(cs.err.split("\n")[-2].replace("stderr: ", "")) * 1024
return max_rss
def tearDown(self):
# get_auto_remove_tmp_dir feature: remove registered temp dirs
@@ -52,6 +52,7 @@ from transformers.testing_utils import (
is_staging_test,
require_torch,
require_torch_multi_gpu,
require_usr_bin_time,
slow,
torch_device,
)
@@ -2489,6 +2490,56 @@ class ModelUtilsTest(TestCasePlus):
for p1, p2 in zip(model.parameters(), ref_model.parameters()):
self.assertTrue(torch.allclose(p1, p2))
def test_from_pretrained_low_cpu_mem_usage_functional(self):
# test that we can use `from_pretrained(..., low_cpu_mem_usage=True)` with normal and
# sharded models
mnames = [
"hf-internal-testing/tiny-random-bert-sharded",
"hf-internal-testing/tiny-random-bert",
]
for mname in mnames:
_ = BertModel.from_pretrained(mname, low_cpu_mem_usage=True)
@require_usr_bin_time
def test_from_pretrained_low_cpu_mem_usage_measured(self):
# test that `from_pretrained(..., low_cpu_mem_usage=True)` uses less cpu memory than default
mname = "bert-base-cased"
preamble = "from transformers import AutoModel"
one_liner_str = f'{preamble}; AutoModel.from_pretrained("{mname}", low_cpu_mem_usage=False)'
max_rss_normal = self.python_one_liner_max_rss(one_liner_str)
# print(f"{max_rss_normal=}")
one_liner_str = f'{preamble}; AutoModel.from_pretrained("{mname}", low_cpu_mem_usage=True)'
max_rss_low_mem = self.python_one_liner_max_rss(one_liner_str)
# print(f"{max_rss_low_mem=}")
diff_bytes = max_rss_normal - max_rss_low_mem
diff_percent = diff_bytes / max_rss_low_mem
# print(f"{diff_bytes=}, {diff_percent=}")
# ideally we would compare that the diff is close to ~1x checkpoint size in bytes, but
# measuring cpu memory on linux is very tricky and inconsistent, so instead let's check that
# it's at least 15% less cpu memory consumed
self.assertGreater(
diff_percent,
0.15,
"should use less CPU memory for low_cpu_mem_usage=True, "
f"but got max_rss_normal={max_rss_normal} and max_rss_low_mem={max_rss_low_mem}",
)
# if you want to compare things manually, let's first look at the size of the model in bytes
# model = BertModel.from_pretrained(mname, low_cpu_mem_usage=False)
# total_numel = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())
# total_bytes = total_numel * 4 # 420MB
# Now the diff_bytes should be very close to total_bytes, but the reports are inconsistent.
# The easiest way to test this is to switch the model and torch.load to do all the work on
# gpu - that way one can measure exactly the total and peak memory used. Perhaps once we add
# functionality to load models directly on gpu, this test can be rewritten to use torch's
# cuda memory tracking and then we should be able to do a much more precise test.
def test_cached_files_are_used_when_internet_is_down(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()