chenpangpang / transformers · Commit a0042379 (Unverified)
Authored by Sourab Mangrulkar on Jul 27, 2023; committed by GitHub on Jul 27, 2023
Parent: 1689aea7

fix deepspeed load best model at end when the model gets sharded (#25057)
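Why this fix: before this commit, `_load_best_model` only reached the DeepSpeed branch after confirming that one of the single-file weight files (`pytorch_model.bin`, its safetensors variant, or a PEFT adapter file) existed in the best checkpoint directory. When a large model is saved as a sharded checkpoint, none of those single files exist (the weights live in shards plus a `pytorch_model.bin.index.json` index), so the old code fell through to the plain sharded-loading branch and `deepspeed_load_checkpoint` was never called. The diff below moves the `self.is_deepspeed_enabled` check in front of the file-existence test.

The following is a minimal sketch of that branch selection, not code from the commit: `chosen_branch` is a hypothetical helper, and it ignores the safetensors and adapter paths that the real condition also checks.

import os

# Checkpoint file names as used by Trainer at this commit.
WEIGHTS_NAME = "pytorch_model.bin"
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"  # only written for sharded saves

def chosen_branch(checkpoint_dir: str, deepspeed_enabled: bool, fixed: bool) -> str:
    # Hypothetical helper: returns which loading path _load_best_model would take.
    single_file = os.path.exists(os.path.join(checkpoint_dir, WEIGHTS_NAME))
    sharded = os.path.exists(os.path.join(checkpoint_dir, WEIGHTS_INDEX_NAME))
    if fixed and deepspeed_enabled:
        # After the fix: DeepSpeed is checked before any file-existence test.
        return "deepspeed_load_checkpoint"
    if single_file:
        # Before the fix: DeepSpeed loading was nested under this check, so it was
        # reachable only when a single-file checkpoint existed.
        return "deepspeed_load_checkpoint" if deepspeed_enabled else "model.load_state_dict"
    if sharded:
        # A sharded save has no pytorch_model.bin, so the old code ended up here
        # even with DeepSpeed enabled.
        return "load_sharded_checkpoint"
    return "no weights found"

For a sharded checkpoint directory, chosen_branch(d, deepspeed_enabled=True, fixed=False) returns "load_sharded_checkpoint", while fixed=True returns "deepspeed_load_checkpoint"; that switch is the behavior change this commit makes.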
Showing 1 changed file with 53 additions and 54 deletions:

src/transformers/trainer.py (+53, -54)
src/transformers/trainer.py

@@ -2093,71 +2093,70 @@ class Trainer:
         best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)

         model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
-        if (
+        if self.is_deepspeed_enabled:
+            deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint)
+        elif (
             os.path.exists(best_model_path)
             or os.path.exists(best_safe_model_path)
             or os.path.exists(best_adapter_model_path)
             or os.path.exists(best_safe_adapter_model_path)
         ):
-            if self.is_deepspeed_enabled:
-                deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint)
-            else:
-                has_been_loaded = True
-                if is_sagemaker_mp_enabled():
-                    if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
-                        # If the 'user_content.pt' file exists, load with the new smp api.
-                        # Checkpoint must have been saved with the new smp api.
-                        smp.resume_from_checkpoint(
-                            path=self.state.best_model_checkpoint,
-                            tag=WEIGHTS_NAME,
-                            partial=False,
-                            load_optimizer=False,
-                        )
-                    else:
-                        # If the 'user_content.pt' file does NOT exist, load with the old smp api.
-                        # Checkpoint must have been saved with the old smp api.
-                        if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
-                            state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
-                        else:
-                            state_dict = torch.load(best_model_path, map_location="cpu")
-
-                        state_dict["_smp_is_partial"] = False
-                        load_result = model.load_state_dict(state_dict, strict=True)
-                elif self.is_fsdp_enabled:
-                    load_result = load_fsdp_model(
-                        self.accelerator.state.fsdp_plugin, self.accelerator, model, self.state.best_model_checkpoint
-                    )
-                else:
-                    if is_peft_available() and isinstance(model, PeftModel):
-                        # If train a model using PEFT & LoRA, assume that adapter have been saved properly.
-                        if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
-                            if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):
-                                model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
-                                # Load_adapter has no return value present, modify it when appropriate.
-                                from torch.nn.modules.module import _IncompatibleKeys
-
-                                load_result = _IncompatibleKeys([], [])
-                            else:
-                                logger.warning(
-                                    "The intermediate checkpoints of PEFT may not be saved correctly, "
-                                    f"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. "
-                                    "Check some examples here: https://github.com/huggingface/peft/issues/96"
-                                )
-                                has_been_loaded = False
-                        else:
-                            logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")
-                            has_been_loaded = False
-                    else:
-                        # We load the model state dict on the CPU to avoid an OOM error.
-                        if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
-                            state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
-                        else:
-                            state_dict = torch.load(best_model_path, map_location="cpu")
-
-                        # If the model is on the GPU, it still works!
-                        # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
-                        # which takes *args instead of **kwargs
-                        load_result = model.load_state_dict(state_dict, False)
-                if not is_sagemaker_mp_enabled() and has_been_loaded:
-                    self._issue_warnings_after_load(load_result)
+            has_been_loaded = True
+            if is_sagemaker_mp_enabled():
+                if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
+                    # If the 'user_content.pt' file exists, load with the new smp api.
+                    # Checkpoint must have been saved with the new smp api.
+                    smp.resume_from_checkpoint(
+                        path=self.state.best_model_checkpoint,
+                        tag=WEIGHTS_NAME,
+                        partial=False,
+                        load_optimizer=False,
+                    )
+                else:
+                    # If the 'user_content.pt' file does NOT exist, load with the old smp api.
+                    # Checkpoint must have been saved with the old smp api.
+                    if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
+                        state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
+                    else:
+                        state_dict = torch.load(best_model_path, map_location="cpu")
+
+                    state_dict["_smp_is_partial"] = False
+                    load_result = model.load_state_dict(state_dict, strict=True)
+            elif self.is_fsdp_enabled:
+                load_result = load_fsdp_model(
+                    self.accelerator.state.fsdp_plugin, self.accelerator, model, self.state.best_model_checkpoint
+                )
+            else:
+                if is_peft_available() and isinstance(model, PeftModel):
+                    # If train a model using PEFT & LoRA, assume that adapter have been saved properly.
+                    if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
+                        if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):
+                            model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
+                            # Load_adapter has no return value present, modify it when appropriate.
+                            from torch.nn.modules.module import _IncompatibleKeys
+
+                            load_result = _IncompatibleKeys([], [])
+                        else:
+                            logger.warning(
+                                "The intermediate checkpoints of PEFT may not be saved correctly, "
+                                f"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. "
+                                "Check some examples here: https://github.com/huggingface/peft/issues/96"
+                            )
+                            has_been_loaded = False
+                    else:
+                        logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")
+                        has_been_loaded = False
+                else:
+                    # We load the model state dict on the CPU to avoid an OOM error.
+                    if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
+                        state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
+                    else:
+                        state_dict = torch.load(best_model_path, map_location="cpu")
+
+                    # If the model is on the GPU, it still works!
+                    # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
+                    # which takes *args instead of **kwargs
+                    load_result = model.load_state_dict(state_dict, False)
+            if not is_sagemaker_mp_enabled() and has_been_loaded:
+                self._issue_warnings_after_load(load_result)
         elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
...
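For context, `_load_best_model` runs at the end of `Trainer.train()` when best-model loading is enabled. A minimal usage sketch under assumed placeholders (`my_model`, `train_ds`, `eval_ds`, and `ds_config.json` are illustrative names, not values from this commit):

from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="out",
    deepspeed="ds_config.json",     # path to a DeepSpeed config; sets is_deepspeed_enabled
    load_best_model_at_end=True,    # triggers _load_best_model after training
    evaluation_strategy="steps",    # eval and save strategies must match for best-model tracking
    save_strategy="steps",
    eval_steps=500,
    save_steps=500,
)
trainer = Trainer(model=my_model, args=args, train_dataset=train_ds, eval_dataset=eval_ds)
trainer.train()  # with this fix, a sharded best checkpoint is restored via deepspeed_load_checkpoint

Before this commit, the same setup with a sharded best checkpoint would have fallen through to the WEIGHTS_INDEX_NAME branch and loaded via load_sharded_checkpoint instead of DeepSpeed's own checkpoint loader.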