Unverified commit 701298d2, authored by Weiming Zhao, committed by GitHub

Use mmap option to load_state_dict (#28331)

parent 0f2f0c63
@@ -30,6 +30,7 @@ from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import partial, wraps
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from zipfile import is_zipfile
 
 import torch
 from packaging import version
@@ -516,8 +517,16 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
             map_location = "meta"
         else:
             map_location = "cpu"
-        return torch.load(checkpoint_file, map_location=map_location, weights_only=True)
+        extra_args = {}
+        # mmap can only be used with files serialized with the zipfile-based format.
+        if (
+            isinstance(checkpoint_file, str)
+            and map_location != "meta"
+            and version.parse(torch.__version__) >= version.parse("2.1.0")
+            and is_zipfile(checkpoint_file)
+        ):
+            extra_args = {"mmap": True}
+        return torch.load(checkpoint_file, map_location=map_location, weights_only=True, **extra_args)
     except Exception as e:
         try:
             with open(checkpoint_file) as f:
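For context, what the new branch buys: with mmap=True, torch.load memory-maps the checkpoint so tensor data is paged in on demand rather than copied into RAM up front. A minimal standalone sketch of the same guard, assuming a hypothetical checkpoint path (the mmap argument requires PyTorch >= 2.1 and only applies to zipfile-based checkpoints, which torch.save has produced by default since PyTorch 1.6):

from zipfile import is_zipfile

import torch
from packaging import version

checkpoint_file = "pytorch_model.bin"  # hypothetical path, for illustration only

extra_args = {}
# Memory-mapping defers reading tensor data until it is touched, lowering peak host RAM.
if version.parse(torch.__version__) >= version.parse("2.1.0") and is_zipfile(checkpoint_file):
    extra_args = {"mmap": True}

state_dict = torch.load(checkpoint_file, map_location="cpu", weights_only=True, **extra_args)

Legacy (non-zipfile) checkpoints simply fall through to a plain torch.load, which is why the diff guards on is_zipfile instead of passing mmap=True unconditionally.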
@@ -101,7 +101,7 @@ if is_torch_available():
     from torch import nn
 
     from transformers import MODEL_MAPPING, AdaptiveEmbedding
-    from transformers.modeling_utils import no_init_weights
+    from transformers.modeling_utils import load_state_dict, no_init_weights
     from transformers.pytorch_utils import id_tensor_storage
@@ -536,6 +536,54 @@ class ModelTesterMixin:
                 ).item()
                 self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
 
+    def test_torch_save_load(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        if config.__class__ not in MODEL_MAPPING:
+            return
+
+        base_class = MODEL_MAPPING[config.__class__]
+
+        if isinstance(base_class, tuple):
+            base_class = base_class[0]
+
+        for model_class in self.all_model_classes:
+            if model_class == base_class:
+                continue
+
+            # make a copy of model class to not break future tests
+            # from https://stackoverflow.com/questions/9541025/how-to-copy-a-python-class
+            class CopyClass(base_class):
+                pass
+
+            base_class_copy = CopyClass
+
+            # make sure that all keys are expected for test
+            base_class_copy._keys_to_ignore_on_load_missing = []
+
+            # make init deterministic, but make sure that
+            # non-initialized weights throw errors nevertheless
+            base_class_copy._init_weights = _mock_init_weights
+            base_class_copy.init_weights = _mock_all_init_weights
+
+            model = model_class(config)
+            state_dict = model.state_dict()
+
+            def check_equal(loaded):
+                for key in state_dict.keys():
+                    max_diff = torch.max(
+                        state_dict[key] ^ loaded[key]
+                        if isinstance(state_dict[key], torch.BoolTensor)
+                        else torch.abs(state_dict[key] - loaded[key])
+                    ).item()
+                    self.assertLessEqual(max_diff, 1e-6, msg=f"{key} not identical")
+
+            # check that both serialization formats round-trip through load_state_dict
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                pt_checkpoint_path = os.path.join(tmpdirname, "pytorch_model.bin")
+                torch.save(state_dict, pt_checkpoint_path, _use_new_zipfile_serialization=True)
+                check_equal(load_state_dict(pt_checkpoint_path))
+                torch.save(state_dict, pt_checkpoint_path, _use_new_zipfile_serialization=False)
+                check_equal(load_state_dict(pt_checkpoint_path))
+
     def test_initialization(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
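As a usage note, the test saves the same state dict in both formats that load_state_dict must handle; is_zipfile is what separates them. A small sketch of that distinction, with an illustrative state dict:

import os
import tempfile
from zipfile import is_zipfile

import torch

state_dict = {"weight": torch.ones(2, 2)}

with tempfile.TemporaryDirectory() as tmpdirname:
    path = os.path.join(tmpdirname, "pytorch_model.bin")

    # Default since PyTorch 1.6: a zipfile-based container, eligible for mmap loading.
    torch.save(state_dict, path, _use_new_zipfile_serialization=True)
    print(is_zipfile(path))  # True

    # Legacy pickle stream: is_zipfile returns False, so the mmap branch is skipped.
    torch.save(state_dict, path, _use_new_zipfile_serialization=False)
    print(is_zipfile(path))  # False

The check_equal helper in the test compares boolean tensors with XOR (any True bit in a ^ b marks a mismatch) and all other tensors by maximum absolute difference.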