chenpangpang / transformers · Commits

Commit 701298d2 (unverified)
Authored Jan 10, 2024 by Weiming Zhao; committed via GitHub on Jan 10, 2024
Parent commit: 0f2f0c63

Use mmap option to load_state_dict (#28331)
Showing 2 changed files with 60 additions and 3 deletions:

  src/transformers/modeling_utils.py   +11 −2
  tests/test_modeling_common.py        +49 −1
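For context, a minimal sketch (not part of the diff) of what the new option buys: on PyTorch 2.1+, `torch.load(..., mmap=True)` memory-maps the checkpoint so tensor storages are paged in from disk on demand instead of being copied into RAM up front. The checkpoint path below is hypothetical.

```python
# Minimal sketch, assuming PyTorch >= 2.1 and a zipfile-format checkpoint
# at a hypothetical path. With mmap=True, torch.load maps tensor storages
# from disk rather than reading the whole file into memory immediately.
import torch

state_dict = torch.load("pytorch_model.bin", map_location="cpu", mmap=True)
```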
src/transformers/modeling_utils.py

```diff
@@ -30,6 +30,7 @@ from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import partial, wraps
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from zipfile import is_zipfile
 
 import torch
 from packaging import version
@@ -516,8 +517,16 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
             map_location = "meta"
         else:
             map_location = "cpu"
-
-        return torch.load(checkpoint_file, map_location=map_location, weights_only=True)
+        extra_args = {}
+        # mmap can only be used with files serialized with zipfile-based format.
+        if (
+            isinstance(checkpoint_file, str)
+            and map_location != "meta"
+            and version.parse(torch.__version__) >= version.parse("2.1.0")
+            and is_zipfile(checkpoint_file)
+        ):
+            extra_args = {"mmap": True}
+        return torch.load(checkpoint_file, map_location=map_location, weights_only=True, **extra_args)
     except Exception as e:
         try:
             with open(checkpoint_file) as f:
```
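The `is_zipfile` guard exists because `mmap` only works with the zipfile-based serialization format that `torch.save` has used by default since PyTorch 1.6; legacy-format files have to be loaded without it. A self-contained sketch of that distinction (the module and paths are throwaway names for illustration):

```python
# Sketch: mmap requires the zipfile-based format; legacy-format files fail
# the is_zipfile check and must be loaded without the mmap argument.
import os
import tempfile
from zipfile import is_zipfile

import torch

linear = torch.nn.Linear(2, 2)
with tempfile.TemporaryDirectory() as tmpdir:
    new_path = os.path.join(tmpdir, "zipfile_format.bin")
    legacy_path = os.path.join(tmpdir, "legacy_format.bin")
    torch.save(linear.state_dict(), new_path, _use_new_zipfile_serialization=True)
    torch.save(linear.state_dict(), legacy_path, _use_new_zipfile_serialization=False)
    print(is_zipfile(new_path))     # True  -> safe to pass mmap=True
    print(is_zipfile(legacy_path))  # False -> load without mmap
```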
tests/test_modeling_common.py

```diff
@@ -101,7 +101,7 @@ if is_torch_available():
 from torch import nn
 
 from transformers import MODEL_MAPPING, AdaptiveEmbedding
-from transformers.modeling_utils import no_init_weights
+from transformers.modeling_utils import load_state_dict, no_init_weights
 from transformers.pytorch_utils import id_tensor_storage
@@ -536,6 +536,54 @@ class ModelTesterMixin:
                     ).item()
                 self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
 
+    def test_torch_save_load(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        if config.__class__ not in MODEL_MAPPING:
+            return
+        base_class = MODEL_MAPPING[config.__class__]
+
+        if isinstance(base_class, tuple):
+            base_class = base_class[0]
+
+        for model_class in self.all_model_classes:
+            if model_class == base_class:
+                continue
+
+            # make a copy of model class to not break future tests
+            # from https://stackoverflow.com/questions/9541025/how-to-copy-a-python-class
+            class CopyClass(base_class):
+                pass
+
+            base_class_copy = CopyClass
+
+            # make sure that all keys are expected for test
+            base_class_copy._keys_to_ignore_on_load_missing = []
+
+            # make init deterministic, but make sure that
+            # non-initialized weights throw errors nevertheless
+            base_class_copy._init_weights = _mock_init_weights
+            base_class_copy.init_weights = _mock_all_init_weights
+
+            model = model_class(config)
+            state_dict = model.state_dict()
+
+            def check_equal(loaded):
+                for key in state_dict.keys():
+                    max_diff = torch.max(
+                        state_dict[key] ^ loaded[key]
+                        if isinstance(state_dict[key], torch.BoolTensor)
+                        else torch.abs(state_dict[key] - loaded[key])
+                    ).item()
+                    self.assertLessEqual(max_diff, 1e-6, msg=f"{key} not identical")
+
+            # check that certain keys didn't get saved with the model
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                pt_checkpoint_path = os.path.join(tmpdirname, "pytorch_model.bin")
+                torch.save(state_dict, pt_checkpoint_path, _use_new_zipfile_serialization=True)
+                check_equal(load_state_dict(pt_checkpoint_path))
+                torch.save(state_dict, pt_checkpoint_path, _use_new_zipfile_serialization=False)
+                check_equal(load_state_dict(pt_checkpoint_path))
+
     def test_initialization(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
```
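The new test round-trips a model's state dict through both serialization formats and asserts the loaded tensors match, so the mmap path and the fallback path are both covered. From calling code the mmap decision is invisible; a usage sketch, assuming a transformers version that includes this commit and a hypothetical checkpoint path:

```python
# Usage sketch: load_state_dict picks mmap=True automatically when the file
# is zipfile-serialized and torch >= 2.1 is installed; otherwise it falls
# back to a plain torch.load. The path is hypothetical.
from transformers.modeling_utils import load_state_dict

state_dict = load_state_dict("pytorch_model.bin")
```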