chenpangpang / transformers / Commits

Unverified commit 8c5c180d, authored Jul 05, 2024 by Marc Sun, committed by GitHub on Jul 05, 2024
Parent: eaa5f414

Fix serialization for offloaded model (#31727)

* Fix serialization
* style
* add test
Showing 2 changed files with 29 additions and 12 deletions (+29 −12):

  src/transformers/modeling_utils.py    +9  −9
  tests/utils/test_modeling_utils.py    +20 −3
src/transformers/modeling_utils.py

@@ -2518,9 +2518,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         # Save the model
         if state_dict is None:
-            # if any model parameters are offloaded to the disk, make module map
-            if hasattr(self, "hf_device_map") and (
-                "cpu" in self.hf_device_map.values() or "disk" in self.hf_device_map.values()
+            # if any model parameters are offloaded, make module map
+            if (
+                hasattr(self, "hf_device_map")
+                and len(set(self.hf_device_map.values())) > 1
+                and ("cpu" in self.hf_device_map.values() or "disk" in self.hf_device_map.values())
             ):
                 warnings.warn(
                     "Attempting to save a model with offloaded modules. Ensure that unallocated cpu memory exceeds the `shard_size` (5GB default)"
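This first hunk tightens the warning guard: with a device map like `device_map="cpu"`, every weight shares one device, so nothing is actually offloaded and the warning (and the module-map bookkeeping behind it) should not trigger. A minimal sketch of the new condition, with hypothetical `hf_device_map` values:

    # Sketch of the tightened guard; the device maps below are made-up examples.
    def should_build_module_map(hf_device_map):
        # Warn/onload only when weights span more than one device AND some
        # of them live on cpu or disk (i.e., are genuinely offloaded).
        return len(set(hf_device_map.values())) > 1 and (
            "cpu" in hf_device_map.values() or "disk" in hf_device_map.values()
        )

    all_cpu = {"transformer": "cpu", "lm_head": "cpu"}
    offloaded = {"transformer.h.0": 0, "transformer.h.1": "cpu", "lm_head": "disk"}

    print(should_build_module_map(all_cpu))    # False: one device, nothing offloaded
    print(should_build_module_map(offloaded))  # True: mixed GPU/CPU/disk placement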
@@ -2532,7 +2534,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                     for key in module_state_dict:
                         module_map[name + f".{key}"] = module
-
             state_dict = model_to_save.state_dict()

         # Translate state_dict from smp to hf if saving with smp >= 1.10
@@ -2655,7 +2656,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 and reg.fullmatch(filename_no_suffix) is not None
             ):
                 os.remove(full_filename)
-
         # Save the model
         for shard_file, tensors in state_dict_split.filename_to_tensors.items():
             shard = {tensor: state_dict[tensor] for tensor in tensors}
@@ -2667,15 +2667,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                         f"Please upgrade accelerate with `pip install -U accelerate`"
                     )
                 # init state_dict for this shard
-                state_dict = {name: "" for name in shard}
+                shard_state_dict = {name: "" for name in shard}
                 for module_name in shard:
                     module = module_map[module_name]
                     # update state dict with onloaded parameters
-                    state_dict = get_state_dict_from_offload(module, module_name, state_dict)
+                    shard_state_dict = get_state_dict_from_offload(module, module_name, shard_state_dict)

                 # assign shard to be the completed state dict
-                shard = state_dict
-                del state_dict
+                shard = shard_state_dict
+                del shard_state_dict
                 gc.collect()

             if safe_serialization:
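The core bug is a name collision: inside the per-shard loop, the onload buffer was also called `state_dict`, rebinding (and then `del`-eting) the full model state dict that later iterations index into with `state_dict[tensor]`. Renaming the buffer to `shard_state_dict` keeps the two apart. A minimal sketch of the failure mode, using plain dicts rather than the transformers internals:

    def save_shards_buggy(state_dict, shards):
        """Mimics the old loop: the per-shard buffer reuses the name `state_dict`."""
        saved = []
        for tensors in shards:
            shard = {name: state_dict[name] for name in tensors}
            # BUG: rebinding `state_dict` to the shard-sized buffer discards the
            # full dict, so the next iteration's lookup fails
            state_dict = {name: shard[name] for name in shard}
            saved.append(state_dict)
        return saved

    def save_shards_fixed(state_dict, shards):
        """Mimics the fix: a distinct `shard_state_dict` leaves `state_dict` intact."""
        saved = []
        for tensors in shards:
            shard = {name: state_dict[name] for name in tensors}
            shard_state_dict = {name: shard[name] for name in shard}
            saved.append(shard_state_dict)
        return saved

    weights = {"wte": 1, "wpe": 2, "lm_head": 3}  # stand-ins for tensors
    shards = [["wte", "wpe"], ["lm_head"]]

    print(save_shards_fixed(weights, shards))   # [{'wte': 1, 'wpe': 2}, {'lm_head': 3}]
    # save_shards_buggy(weights, shards) raises KeyError: 'lm_head' on the second shard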
tests/utils/test_modeling_utils.py

@@ -1065,6 +1065,23 @@ class ModelUtilsTest(TestCasePlus):
         # This check we did call the fake head request
         mock_head.assert_called()

+    @require_accelerate
+    @mark.accelerate_tests
+    def test_save_model_with_device_map_cpu(self):
+        model_id = "hf-internal-testing/tiny-random-gpt2"
+        inputs = torch.tensor([[1, 2, 3]])
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu")
+            output = model(inputs)[0]
+
+            model.save_pretrained(tmp_dir, max_shard_size="200KB")
+            # model is 1.6MB, max shard size is allocated to cpu by default
+            saved_model = AutoModelForCausalLM.from_pretrained(tmp_dir, device_map="cpu")
+            saved_model_output = saved_model(inputs)[0]
+
+        self.assertTrue(torch.allclose(output, saved_model_output))
+
     @require_accelerate
     @mark.accelerate_tests
     @require_torch_accelerator
@@ -1083,9 +1100,9 @@ class ModelUtilsTest(TestCasePlus):
         # check_models_equal requires onloaded tensors
         model_id = "hf-internal-testing/tiny-random-gpt2"
-        onloaded_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu")
+        onloaded_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu").to(f"{torch_device}:0")
         inputs = torch.tensor([[1, 2, 3]]).to(f"{torch_device}:0")
-        cpu_output = onloaded_model(inputs)[0]
+        output = onloaded_model(inputs)[0]

         with tempfile.TemporaryDirectory() as tmp_dir:
             offload_folder = os.path.join(tmp_dir, "offload")
@@ -1099,7 +1116,7 @@ class ModelUtilsTest(TestCasePlus):
             saved_model = AutoModelForCausalLM.from_pretrained(tmp_dir, device_map=device_map)
             postsaved_output = saved_model(inputs)[0]

-        self.assertTrue(torch.allclose(cpu_output, presaved_output, atol=1e-4))
+        self.assertTrue(torch.allclose(output, presaved_output, atol=1e-4))
         self.assertTrue(torch.allclose(presaved_output, postsaved_output))

     @require_safetensors
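Outside the test harness, the regression is easy to check by hand. Below is a standalone sketch of the same CPU round-trip the new test performs; it assumes `transformers`, `torch`, and `accelerate` are installed, and the tiny model id comes straight from the test:

    import tempfile

    import torch
    from transformers import AutoModelForCausalLM

    model_id = "hf-internal-testing/tiny-random-gpt2"  # tiny test model, ~1.6MB
    inputs = torch.tensor([[1, 2, 3]])

    with tempfile.TemporaryDirectory() as tmp_dir:
        model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu")
        before = model(inputs)[0]

        # A 200KB cap forces a multi-shard checkpoint. Under the old guard, a
        # plain device_map="cpu" load was treated as offloaded and the sharded
        # save could fail; the new len(set(...)) > 1 check skips the onload path.
        model.save_pretrained(tmp_dir, max_shard_size="200KB")

        reloaded = AutoModelForCausalLM.from_pretrained(tmp_dir, device_map="cpu")
        after = reloaded(inputs)[0]

    assert torch.allclose(before, after)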