Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
df8e6804
Unverified
Commit
df8e6804
authored
Jun 22, 2022
by
Sylvain Gugger
Committed by
GitHub
Jun 22, 2022
Browse files
Offload fixes (#17810)
* Offload fixes * Add a test
parent
0d0c392c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
58 additions
and
5 deletions
+58
-5
src/transformers/modeling_utils.py
src/transformers/modeling_utils.py
+16
-5
tests/test_modeling_common.py
tests/test_modeling_common.py
+42
-0
No files found.
src/transformers/modeling_utils.py
View file @
df8e6804
...
@@ -2166,11 +2166,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
...
@@ -2166,11 +2166,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
offload_state_dict
=
False
,
offload_state_dict
=
False
,
dtype
=
None
,
dtype
=
None
,
):
):
if
device_map
is
not
None
and
"disk"
in
device_map
.
values
()
and
offload_folder
is
None
:
if
device_map
is
not
None
and
"disk"
in
device_map
.
values
():
raise
ValueError
(
if
offload_folder
is
None
:
"The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder` for"
raise
ValueError
(
" them."
"The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
)
" for them."
)
os
.
makedirs
(
offload_folder
,
exist_ok
=
True
)
# Retrieve missing & unexpected_keys
# Retrieve missing & unexpected_keys
model_state_dict
=
model
.
state_dict
()
model_state_dict
=
model
.
state_dict
()
expected_keys
=
list
(
model_state_dict
.
keys
())
expected_keys
=
list
(
model_state_dict
.
keys
())
...
@@ -2344,6 +2346,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
...
@@ -2344,6 +2346,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
gc
.
collect
()
gc
.
collect
()
if
offload_index
is
not
None
and
len
(
offload_index
)
>
0
:
if
offload_index
is
not
None
and
len
(
offload_index
)
>
0
:
if
model
!=
model_to_load
:
# We need to add the prefix of the base model
prefix
=
cls
.
base_model_prefix
for
weight_name
in
offload_index
:
shutil
.
move
(
os
.
path
.
join
(
offload_folder
,
f
"
{
weight_name
}
.dat"
),
os
.
path
.
join
(
offload_folder
,
f
"
{
prefix
}
.
{
weight_name
}
.dat"
),
)
offload_index
=
{
f
"
{
prefix
}
.
{
key
}
"
:
value
for
key
,
value
in
offload_index
.
items
()}
save_offload_index
(
offload_index
,
offload_folder
)
save_offload_index
(
offload_index
,
offload_folder
)
if
offload_state_dict
:
if
offload_state_dict
:
...
...
tests/test_modeling_common.py
View file @
df8e6804
...
@@ -2811,6 +2811,48 @@ class ModelUtilsTest(TestCasePlus):
...
@@ -2811,6 +2811,48 @@ class ModelUtilsTest(TestCasePlus):
text_output
=
tokenizer
.
decode
(
output
[
0
].
tolist
())
text_output
=
tokenizer
.
decode
(
output
[
0
].
tolist
())
self
.
assertEqual
(
text_output
,
"Hello, my name is John. I'm a writer, and I'm a writer. I'm"
)
self
.
assertEqual
(
text_output
,
"Hello, my name is John. I'm a writer, and I'm a writer. I'm"
)
@require_accelerate
@require_torch_gpu
def test_from_pretrained_disk_offload_task_model(self):
    """Check that a task model loaded with a device_map that offloads some
    layers to disk produces the same logits as the fully-on-GPU model.

    The base model is saved with ``AutoModel`` and reloaded as
    ``AutoModelForCausalLM`` on purpose: the reload has to re-prefix the
    offloaded weight files with the base-model prefix, which is exactly the
    code path fixed in this commit. Verified both without and with
    ``offload_state_dict=True`` (temporary CPU state-dict offload).
    """
    model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-gpt2")
    # Spread the tiny GPT-2 across GPU 0, CPU, and disk so every offload
    # branch in from_pretrained is exercised.
    device_map = {
        "transformer.wte": 0,
        "transformer.wpe": 0,
        "transformer.h.0": "cpu",
        "transformer.h.1": "cpu",
        "transformer.h.2": "cpu",
        "transformer.h.3": "disk",
        "transformer.h.4": "disk",
        "transformer.ln_f": 0,
        "lm_head": 0,
    }
    with tempfile.TemporaryDirectory() as tmp_dir:
        inputs = torch.tensor([[1, 2, 3]]).to(0)

        model.save_pretrained(tmp_dir)
        # Reference run: whole model on GPU 0, no offload involved.
        new_model = AutoModelForCausalLM.from_pretrained(tmp_dir).to(0)
        # (fix) the model is already on device 0; the original called
        # ``new_model.to(0)`` a second time here, which was redundant.
        outputs1 = new_model(inputs)

        offload_folder = os.path.join(tmp_dir, "offload")
        new_model_with_offload = AutoModelForCausalLM.from_pretrained(
            tmp_dir, device_map=device_map, offload_folder=offload_folder
        )
        outputs2 = new_model_with_offload(inputs)

        self.assertTrue(torch.allclose(outputs1.logits.cpu(), outputs2.logits.cpu()))

        # Same check with the state dict temporarily offloaded to disk
        # while loading (offload_state_dict=True).
        offload_folder = os.path.join(tmp_dir, "offload")
        new_model_with_offload = AutoModelForCausalLM.from_pretrained(
            tmp_dir,
            device_map=device_map,
            offload_folder=offload_folder,
            offload_state_dict=True,
        )
        outputs2 = new_model_with_offload(inputs)

        self.assertTrue(torch.allclose(outputs1.logits.cpu(), outputs2.logits.cpu()))
def
test_cached_files_are_used_when_internet_is_down
(
self
):
def
test_cached_files_are_used_when_internet_is_down
(
self
):
# A mock response for an HTTP head request to emulate server down
# A mock response for an HTTP head request to emulate server down
response_mock
=
mock
.
Mock
()
response_mock
=
mock
.
Mock
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment