chenpangpang/transformers

Unverified commit 567d9c06, authored May 31, 2022 by Sylvain Gugger, committed by GitHub on May 31, 2022
Disk offload fix (#17428)
* Fix offload to disk for big models
* Add test
* Fix test for other models
Parent: 975dd2bb
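For context, the code path this commit fixes is the one exercised when a checkpoint is loaded with part of its weights offloaded to disk. A minimal usage sketch, assuming accelerate is installed; the checkpoint id and folder name below are placeholders, not taken from this commit:

from transformers import AutoModelForCausalLM

# device_map="auto" lets accelerate assign each module to a GPU, the CPU, or "disk";
# offload_folder is where the weights mapped to "disk" are written.
model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-large-checkpoint",  # placeholder model id
    device_map="auto",
    offload_folder="offload",          # required whenever any weight ends up on "disk"
)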
Showing 2 changed files with 43 additions and 1 deletion:

src/transformers/modeling_utils.py: +7, -1
tests/test_modeling_common.py: +36, -0
src/transformers/modeling_utils.py
@@ -597,11 +597,12 @@ def _load_state_dict_into_meta_model(
             raise ValueError(f"{param_name} doesn't have any device set.")
 
         param_device = device_map[module_name]
-        set_module_tensor_to_device(model, param_name, param_device, value=param)
         if param_device == "disk":
             offload_index = offload_weight(param, param_name, offload_folder, offload_index)
         elif param_device == "cpu" and state_dict_index is not None:
             state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
+        else:
+            set_module_tensor_to_device(model, param_name, param_device, value=param)
 
     return error_msgs, offload_index, state_dict_index
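offload_weight and set_module_tensor_to_device come from accelerate. As a rough illustration only (this is not accelerate's implementation), the "disk" branch amounts to writing the tensor to a file under offload_folder and recording its dtype and shape in the index so it can be memory-mapped back later:

import os
import torch

def sketch_offload_weight(weight, weight_name, offload_folder, index):
    # Hypothetical helper mirroring the bookkeeping the "disk" branch relies on.
    array = weight.detach().cpu().numpy()
    os.makedirs(offload_folder, exist_ok=True)
    array.tofile(os.path.join(offload_folder, f"{weight_name}.dat"))
    index = {} if index is None else index
    index[weight_name] = {"dtype": str(array.dtype), "shape": list(array.shape)}
    return index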
@@ -2216,6 +2217,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         offload_state_dict=False,
         dtype=None,
     ):
+        if device_map is not None and "disk" in device_map.values() and offload_folder is None:
+            raise ValueError(
+                "The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder` for"
+                " them."
+            )
         # Retrieve missing & unexpected_keys
         model_state_dict = model.state_dict()
         expected_keys = list(model_state_dict.keys())
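In practice the new guard surfaces as an early, explicit error when a checkpoint needs disk offload but no folder was given. A hedged sketch (the checkpoint id and memory limits are placeholders; the limits just need to be small enough that some weights land on "disk"):

from transformers import AutoModel

max_memory = {0: "200MB", "cpu": "200MB"}  # deliberately tiny so some modules go to "disk"

try:
    model = AutoModel.from_pretrained("some-org/some-checkpoint", device_map="auto", max_memory=max_memory)
except ValueError as err:
    print(err)  # "The current `device_map` had weights offloaded to the disk. ..."

model = AutoModel.from_pretrained(
    "some-org/some-checkpoint", device_map="auto", max_memory=max_memory, offload_folder="offload"
)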
tests/test_modeling_common.py
@@ -2214,6 +2214,42 @@ class ModelTesterMixin:
             else:
                 self.assertEqual(param.device, torch.device(param_device))
 
+    @require_accelerate
+    @require_torch_gpu
+    def test_disk_offload(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        if isinstance(getattr(config, "num_hidden_layers", None), int) and config.num_hidden_layers < 5:
+            config.num_hidden_layers = 5
+
+        for model_class in self.all_model_classes:
+            if model_class._no_split_modules is None:
+                continue
+            inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+            model = model_class(config).eval()
+            model = model.to(torch_device)
+            base_output = model(**inputs_dict)
+
+            model_size = compute_module_sizes(model)[""]
+            # We test several splits of sizes to make sure it works.
+            max_size = int(0.4 * model_size)
+            with tempfile.TemporaryDirectory() as tmp_dir:
+                model.cpu().save_pretrained(tmp_dir)
+
+                max_memory = {0: max_size, "cpu": max_size}
+                with self.assertRaises(ValueError):
+                    # This errors out cause it's missing an offload folder
+                    new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
+
+                new_model = model_class.from_pretrained(
+                    tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir
+                )
+
+                self.check_device_map_is_respected(new_model, new_model.hf_device_map)
+                new_output = new_model(**inputs_dict)
+
+                self.assertTrue(torch.allclose(base_output[0], new_output[0]))
+
     @require_accelerate
     @require_torch_gpu
     def test_cpu_offload(self):
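The new test is gated by require_accelerate and require_torch_gpu, so it only runs on a CUDA machine with accelerate installed. One way to select it from a Python session (a sketch; how the ModelTesterMixin is collected depends on the per-model test suites in this repo):

import pytest

# -k filters on the test name; -x stops at the first failure.
pytest.main(["-k", "test_disk_offload", "-x", "tests/"])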