chenpangpang / transformers · Commits

Commit a0042379 (unverified)
Authored Jul 27, 2023 by Sourab Mangrulkar; committed by GitHub on Jul 27, 2023
fix deepspeed load best model at end when the model gets sharded (#25057)
Parent: 1689aea7

Showing 1 changed file with 53 additions and 54 deletions.

src/transformers/trainer.py (+53, -54)
@@ -2093,71 +2093,70 @@ class Trainer:
         best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)
 
         model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
-        if (
+        if self.is_deepspeed_enabled:
+            deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint)
+        elif (
             os.path.exists(best_model_path)
             or os.path.exists(best_safe_model_path)
             or os.path.exists(best_adapter_model_path)
             or os.path.exists(best_safe_adapter_model_path)
         ):
-            if self.is_deepspeed_enabled:
-                deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint)
-            else:
-                has_been_loaded = True
-                if is_sagemaker_mp_enabled():
-                    if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
-                        # If the 'user_content.pt' file exists, load with the new smp api.
-                        # Checkpoint must have been saved with the new smp api.
-                        smp.resume_from_checkpoint(
-                            path=self.state.best_model_checkpoint,
-                            tag=WEIGHTS_NAME,
-                            partial=False,
-                            load_optimizer=False,
-                        )
-                    else:
-                        # If the 'user_content.pt' file does NOT exist, load with the old smp api.
-                        # Checkpoint must have been saved with the old smp api.
-                        if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
-                            state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
-                        else:
-                            state_dict = torch.load(best_model_path, map_location="cpu")
-
-                        state_dict["_smp_is_partial"] = False
-                        load_result = model.load_state_dict(state_dict, strict=True)
-                elif self.is_fsdp_enabled:
-                    load_result = load_fsdp_model(
-                        self.accelerator.state.fsdp_plugin, self.accelerator, model, self.state.best_model_checkpoint
-                    )
-                else:
-                    if is_peft_available() and isinstance(model, PeftModel):
-                        # If train a model using PEFT & LoRA, assume that adapter have been saved properly.
-                        if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
-                            if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):
-                                model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
-                                # Load_adapter has no return value present, modify it when appropriate.
-                                from torch.nn.modules.module import _IncompatibleKeys
-
-                                load_result = _IncompatibleKeys([], [])
-                            else:
-                                logger.warning(
-                                    "The intermediate checkpoints of PEFT may not be saved correctly, "
-                                    f"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. "
-                                    "Check some examples here: https://github.com/huggingface/peft/issues/96"
-                                )
-                                has_been_loaded = False
-                        else:
-                            logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")
-                            has_been_loaded = False
-                    else:
-                        # We load the model state dict on the CPU to avoid an OOM error.
-                        if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
-                            state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
-                        else:
-                            state_dict = torch.load(best_model_path, map_location="cpu")
-
-                        # If the model is on the GPU, it still works!
-                        # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
-                        # which takes *args instead of **kwargs
-                        load_result = model.load_state_dict(state_dict, False)
-                if not is_sagemaker_mp_enabled() and has_been_loaded:
-                    self._issue_warnings_after_load(load_result)
+            has_been_loaded = True
+            if is_sagemaker_mp_enabled():
+                if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
+                    # If the 'user_content.pt' file exists, load with the new smp api.
+                    # Checkpoint must have been saved with the new smp api.
+                    smp.resume_from_checkpoint(
+                        path=self.state.best_model_checkpoint,
+                        tag=WEIGHTS_NAME,
+                        partial=False,
+                        load_optimizer=False,
+                    )
+                else:
+                    # If the 'user_content.pt' file does NOT exist, load with the old smp api.
+                    # Checkpoint must have been saved with the old smp api.
+                    if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
+                        state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
+                    else:
+                        state_dict = torch.load(best_model_path, map_location="cpu")
+
+                    state_dict["_smp_is_partial"] = False
+                    load_result = model.load_state_dict(state_dict, strict=True)
+            elif self.is_fsdp_enabled:
+                load_result = load_fsdp_model(
+                    self.accelerator.state.fsdp_plugin, self.accelerator, model, self.state.best_model_checkpoint
+                )
+            else:
+                if is_peft_available() and isinstance(model, PeftModel):
+                    # If train a model using PEFT & LoRA, assume that adapter have been saved properly.
+                    if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
+                        if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):
+                            model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
+                            # Load_adapter has no return value present, modify it when appropriate.
+                            from torch.nn.modules.module import _IncompatibleKeys
+
+                            load_result = _IncompatibleKeys([], [])
+                        else:
+                            logger.warning(
+                                "The intermediate checkpoints of PEFT may not be saved correctly, "
+                                f"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. "
+                                "Check some examples here: https://github.com/huggingface/peft/issues/96"
+                            )
+                            has_been_loaded = False
+                    else:
+                        logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")
+                        has_been_loaded = False
+                else:
+                    # We load the model state dict on the CPU to avoid an OOM error.
+                    if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
+                        state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
+                    else:
+                        state_dict = torch.load(best_model_path, map_location="cpu")
+
+                    # If the model is on the GPU, it still works!
+                    # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
+                    # which takes *args instead of **kwargs
+                    load_result = model.load_state_dict(state_dict, False)
+            if not is_sagemaker_mp_enabled() and has_been_loaded:
+                self._issue_warnings_after_load(load_result)
         elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
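Note on the change: the hunk above moves the self.is_deepspeed_enabled branch ahead of the os.path.exists(...) checks. When DeepSpeed saves a sharded checkpoint (for example under ZeRO stage 3), the best-checkpoint folder may contain no consolidated weights file such as pytorch_model.bin, so with the old ordering the best model could be skipped entirely. Below is a minimal, self-contained sketch of that control flow; it is not the Trainer code, and the helper name choose_best_model_load_strategy and its string return values are hypothetical, used only to illustrate the reordering.

# Minimal sketch (hypothetical helper, not the Trainer API) of the reordered check.
import os

WEIGHTS_NAME = "pytorch_model.bin"  # name of a consolidated single-file checkpoint

def choose_best_model_load_strategy(checkpoint_dir: str, deepspeed_enabled: bool) -> str:
    """Decide how the best checkpoint should be loaded at the end of training."""
    best_model_path = os.path.join(checkpoint_dir, WEIGHTS_NAME)
    if deepspeed_enabled:
        # New ordering: DeepSpeed loads its own checkpoint first. A sharded checkpoint
        # folder may contain no consolidated weights file, so an os.path.exists() guard
        # placed before this branch (the old ordering) could skip loading entirely.
        return "deepspeed_load_checkpoint"
    if os.path.exists(best_model_path):
        # Plain torch.load path for a consolidated state dict.
        return "torch_load_state_dict"
    return "nothing_to_load"

# With DeepSpeed enabled, the strategy no longer depends on which files exist on disk.
print(choose_best_model_load_strategy("/tmp/checkpoint-500", deepspeed_enabled=True))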