chenpangpang / transformers / Commits

Commit 1ac599d9 (unverified)
Authored Nov 15, 2023 by Marc Sun; committed by GitHub on Nov 15, 2023.

Fix disk offload for loading a derived model checkpoint into the base model (#27253)

* fix
* style
* add test
Parent: b71c38a0

Showing 2 changed files with 54 additions and 8 deletions (+54, -8):
* src/transformers/modeling_utils.py (+14, -8)
* tests/test_modeling_utils.py (+40, -0)
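For context, a minimal sketch of the scenario this commit fixes, mirroring the new test below. A checkpoint saved from a *derived* model (a causal-LM head on GPT-2) is loaded into the *base* model with some submodules offloaded to disk. Before this fix, the checkpoint keys kept the derived model's `transformer.` prefix while the device map used base-model names, so the disk-offload bookkeeping looked up the wrong keys. The cpu/disk-only device map here is an assumption so the sketch runs without a GPU; the new test itself targets an accelerator.

```python
import os
import tempfile

from transformers import AutoModel, AutoModelForCausalLM

# Save a checkpoint from the derived (causal-LM) model.
derived_model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
with tempfile.TemporaryDirectory() as tmp_dir:
    derived_model.save_pretrained(tmp_dir, use_safetensors=True)
    # Loading that checkpoint into the base model with disk offload
    # exercises the code path fixed by this commit.
    base_model = AutoModel.from_pretrained(
        tmp_dir,
        device_map={"wte": "cpu", "wpe": "cpu", "h": "disk", "ln_f": "cpu"},
        offload_folder=os.path.join(tmp_dir, "offload"),
    )
```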
src/transformers/modeling_utils.py
```diff
@@ -3793,8 +3793,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
         else:
             folder = None
         if device_map is not None and is_safetensors:
-            param_device_map = expand_device_map(device_map, original_loaded_keys)
+            param_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix)
 
             str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
             if sharded_metadata is None:
                 archive_file = (
```
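A note on `start_prefix` (set elsewhere in `_load_pretrained_model`, not shown in this hunk): it is the base model's prefix plus a dot, e.g. `cls.base_model_prefix + "."`, when a checkpoint saved from a derived model is loaded into the base model; for the GPT-2 checkpoint in the new test that is `"transformer."`. This first change threads the prefix into `expand_device_map` so the expanded parameter-to-device mapping is keyed on base-model parameter names.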
```diff
@@ -3806,9 +3805,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
             else:
                 weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()}
             offload_index = {
-                p: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
+                p[len(start_prefix) :]: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
                 for p, f in weight_map.items()
-                if param_device_map[p] == "disk"
+                if p.startswith(start_prefix) and param_device_map[p[len(start_prefix) :]] == "disk"
             }
 
         if state_dict is not None:
```
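As a toy illustration (hypothetical names, not part of the diff) of what the rewritten comprehension yields when a derived GPT-2 checkpoint is loaded into the base model:

```python
# The safetensors weight map keeps the derived model's prefix, while
# param_device_map (built by the updated expand_device_map) does not.
start_prefix = "transformer."
str_dtype = "float32"
weight_map = {"transformer.h.3.attn.c_attn.weight": "model.safetensors"}
param_device_map = {"h.3.attn.c_attn.weight": "disk"}

offload_index = {
    p[len(start_prefix) :]: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
    for p, f in weight_map.items()
    if p.startswith(start_prefix) and param_device_map[p[len(start_prefix) :]] == "disk"
}
# offload_index == {
#     "h.3.attn.c_attn.weight": {
#         "safetensors_file": "model.safetensors",
#         "weight_name": "transformer.h.3.attn.c_attn.weight",
#         "dtype": "float32",
#     }
# }
```

The index is keyed on base-model names (what the model looks up at run time), while `weight_name` keeps the prefixed name actually stored in the safetensors file.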
```diff
@@ -3842,7 +3841,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
         state_dict_index = None
         if is_sharded_safetensors:
-            disk_only_shard_files = get_disk_only_shard_files(device_map, sharded_metadata=sharded_metadata)
+            disk_only_shard_files = get_disk_only_shard_files(
+                device_map, sharded_metadata=sharded_metadata, start_prefix=start_prefix
+            )
             disk_only_shard_files = [os.path.join(folder, f) for f in disk_only_shard_files]
         else:
             disk_only_shard_files = []
```
```diff
@@ -4576,11 +4577,12 @@ def unwrap_model(model: nn.Module) -> nn.Module:
     return model
 
 
-def expand_device_map(device_map, param_names):
+def expand_device_map(device_map, param_names, start_prefix):
     """
     Expand a device map to return the correspondance parameter name to device.
     """
     new_device_map = {}
+    param_names = [p[len(start_prefix) :] for p in param_names if p.startswith(start_prefix)]
     for module, device in device_map.items():
         new_device_map.update(
             {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
         )
```
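For intuition, a standalone run of the updated helper (body copied from the diff above; the example parameter names are assumptions based on the tiny GPT-2 model used in the new test):

```python
def expand_device_map(device_map, param_names, start_prefix):
    """Expand a device map to return the correspondance parameter name to device."""
    new_device_map = {}
    # New in this commit: strip the derived model's prefix so checkpoint keys
    # line up with the base model's device_map entries.
    param_names = [p[len(start_prefix) :] for p in param_names if p.startswith(start_prefix)]
    for module, device in device_map.items():
        new_device_map.update(
            {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
        )
    return new_device_map


# Checkpoint keys saved from the derived model keep the "transformer." prefix;
# "lm_head.weight" does not start with the prefix and is dropped.
keys = ["transformer.wte.weight", "transformer.h.3.attn.c_attn.weight", "lm_head.weight"]
print(expand_device_map({"wte": 0, "h.3": "disk"}, keys, "transformer."))
# -> {'wte.weight': 0, 'h.3.attn.c_attn.weight': 'disk'}
```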
```diff
@@ -4588,12 +4590,16 @@ def expand_device_map(device_map, param_names):
     return new_device_map
 
 
-def get_disk_only_shard_files(device_map, sharded_metadata):
+def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix):
     """
     Returns the list of shard files containing only weights offloaded to disk.
     """
+    weight_map = {
+        p[len(start_prefix) :]: v for p, v in sharded_metadata["weight_map"].items() if p.startswith(start_prefix)
+    }
     files_content = collections.defaultdict(list)
-    for weight_name, filename in sharded_metadata["weight_map"].items():
+    for weight_name, filename in weight_map.items():
         while len(weight_name) > 0 and weight_name not in device_map:
             weight_name = ".".join(weight_name.split(".")[:-1])
         files_content[filename].append(device_map[weight_name])
```
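Likewise, a self-contained sketch of the updated `get_disk_only_shard_files` with toy sharded metadata; the final filtering line sits below the hunk and is an assumption consistent with the docstring:

```python
import collections


def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix):
    """Returns the list of shard files containing only weights offloaded to disk."""
    # New in this commit: re-key the weight map on base-model parameter names.
    weight_map = {
        p[len(start_prefix) :]: v
        for p, v in sharded_metadata["weight_map"].items()
        if p.startswith(start_prefix)
    }
    files_content = collections.defaultdict(list)
    for weight_name, filename in weight_map.items():
        # Walk up the module path until a device_map entry matches.
        while len(weight_name) > 0 and weight_name not in device_map:
            weight_name = ".".join(weight_name.split(".")[:-1])
        files_content[filename].append(device_map[weight_name])
    # Assumed continuation (context below the hunk): keep only the shards
    # whose every weight lives on disk.
    return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}]


metadata = {
    "weight_map": {
        "transformer.wte.weight": "model-00001-of-00002.safetensors",
        "transformer.h.3.mlp.c_fc.weight": "model-00002-of-00002.safetensors",
        "transformer.h.4.mlp.c_fc.weight": "model-00002-of-00002.safetensors",
    }
}
device_map = {"wte": 0, "h.3": "disk", "h.4": "disk"}
print(get_disk_only_shard_files(device_map, metadata, "transformer."))
# -> ['model-00002-of-00002.safetensors']
```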
tests/test_modeling_utils.py
```diff
@@ -750,6 +750,46 @@ class ModelUtilsTest(TestCasePlus):
         self.assertTrue(torch.allclose(outputs1.logits.cpu(), outputs2.logits.cpu()))
 
+    @require_accelerate
+    @mark.accelerate_tests
+    @require_torch_accelerator
+    def test_from_pretrained_disk_offload_derived_to_base_model(self):
+        derived_model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        device_map = {
+            "wte": 0,
+            "wpe": 0,
+            "h.0": "cpu",
+            "h.1": "cpu",
+            "h.2": "cpu",
+            "h.3": "disk",
+            "h.4": "disk",
+            "ln_f": 0,
+        }
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            inputs = torch.tensor([[1, 2, 3]]).to(0)
+            derived_model.save_pretrained(tmp_dir, use_safetensors=True)
+            base_model = AutoModel.from_pretrained(tmp_dir)
+            outputs1 = base_model.to(0)(inputs)
+
+            # with disk offload
+            offload_folder = os.path.join(tmp_dir, "offload")
+            base_model_with_offload = AutoModel.from_pretrained(
+                tmp_dir, device_map=device_map, offload_folder=offload_folder
+            )
+            outputs2 = base_model_with_offload(inputs)
+            self.assertTrue(torch.allclose(outputs1[0].cpu(), outputs2[0].cpu()))
+
+            # With state dict temp offload
+            new_model_with_offload = AutoModel.from_pretrained(
+                tmp_dir,
+                device_map=device_map,
+                offload_folder=offload_folder,
+                offload_state_dict=True,
+            )
+            outputs2 = new_model_with_offload(inputs)
+            self.assertTrue(torch.allclose(outputs1[0].cpu(), outputs2[0].cpu()))
+
     def test_cached_files_are_used_when_internet_is_down(self):
         # A mock response for an HTTP head request to emulate server down
         response_mock = mock.Mock()
```
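Assuming a development checkout of transformers with accelerate installed and an accelerator available (the test is gated by `@require_accelerate` and `@require_torch_accelerator`), the new test can be selected with, for example, `pytest tests/test_modeling_utils.py -k test_from_pretrained_disk_offload_derived_to_base_model`.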