Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
adb0760b
Unverified
Commit
adb0760b
authored
May 04, 2023
by
Qingyang Wu
Committed by
GitHub
May 04, 2023
Browse files
fix resume fsdp (#23111)
* fix resume fsdp * fix rank 0 loading * fix style and quality
parent
3b74889e
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
24 additions
and
5 deletions
+24
-5
src/transformers/trainer.py
src/transformers/trainer.py
+24
-5
No files found.
src/transformers/trainer.py
View file @
adb0760b
...
@@ -2114,7 +2114,7 @@ class Trainer:
...
@@ -2114,7 +2114,7 @@ class Trainer:
safe_weights_index_file
=
os
.
path
.
join
(
resume_from_checkpoint
,
SAFE_WEIGHTS_INDEX_NAME
)
safe_weights_index_file
=
os
.
path
.
join
(
resume_from_checkpoint
,
SAFE_WEIGHTS_INDEX_NAME
)
if
not
any
(
if
not
any
(
[
os
.
path
.
isfile
(
f
)
for
f
in
[
weights_file
,
safe_weights_file
,
weights_index_file
,
safe_weights_index_file
]
]
os
.
path
.
isfile
(
f
)
for
f
in
[
weights_file
,
safe_weights_file
,
weights_index_file
,
safe_weights_index_file
]
):
):
raise
ValueError
(
f
"Can't find a valid checkpoint at
{
resume_from_checkpoint
}
"
)
raise
ValueError
(
f
"Can't find a valid checkpoint at
{
resume_from_checkpoint
}
"
)
...
@@ -2364,6 +2364,12 @@ class Trainer:
...
@@ -2364,6 +2364,12 @@ class Trainer:
if
self
.
sharded_ddp
==
ShardedDDPOption
.
SIMPLE
:
if
self
.
sharded_ddp
==
ShardedDDPOption
.
SIMPLE
:
self
.
optimizer
.
consolidate_state_dict
()
self
.
optimizer
.
consolidate_state_dict
()
if
self
.
fsdp
:
# FSDP has a different interface for saving optimizer states.
# Needs to be called on all ranks to gather all states.
# full_optim_state_dict will be deprecated after Pytorch 2.2!
full_osd
=
self
.
model
.
__class__
.
full_optim_state_dict
(
self
.
model
,
self
.
optimizer
)
if
is_torch_tpu_available
():
if
is_torch_tpu_available
():
xm
.
rendezvous
(
"saving_optimizer_states"
)
xm
.
rendezvous
(
"saving_optimizer_states"
)
xm
.
save
(
self
.
optimizer
.
state_dict
(),
os
.
path
.
join
(
output_dir
,
OPTIMIZER_NAME
))
xm
.
save
(
self
.
optimizer
.
state_dict
(),
os
.
path
.
join
(
output_dir
,
OPTIMIZER_NAME
))
...
@@ -2388,7 +2394,11 @@ class Trainer:
...
@@ -2388,7 +2394,11 @@ class Trainer:
torch
.
save
(
self
.
scaler
.
state_dict
(),
os
.
path
.
join
(
output_dir
,
SCALER_NAME
))
torch
.
save
(
self
.
scaler
.
state_dict
(),
os
.
path
.
join
(
output_dir
,
SCALER_NAME
))
elif
self
.
args
.
should_save
and
not
self
.
deepspeed
:
elif
self
.
args
.
should_save
and
not
self
.
deepspeed
:
# deepspeed.save_checkpoint above saves model/optim/sched
# deepspeed.save_checkpoint above saves model/optim/sched
if
self
.
fsdp
:
torch
.
save
(
full_osd
,
os
.
path
.
join
(
output_dir
,
OPTIMIZER_NAME
))
else
:
torch
.
save
(
self
.
optimizer
.
state_dict
(),
os
.
path
.
join
(
output_dir
,
OPTIMIZER_NAME
))
torch
.
save
(
self
.
optimizer
.
state_dict
(),
os
.
path
.
join
(
output_dir
,
OPTIMIZER_NAME
))
with
warnings
.
catch_warnings
(
record
=
True
)
as
caught_warnings
:
with
warnings
.
catch_warnings
(
record
=
True
)
as
caught_warnings
:
torch
.
save
(
self
.
lr_scheduler
.
state_dict
(),
os
.
path
.
join
(
output_dir
,
SCHEDULER_NAME
))
torch
.
save
(
self
.
lr_scheduler
.
state_dict
(),
os
.
path
.
join
(
output_dir
,
SCHEDULER_NAME
))
reissue_pt_warnings
(
caught_warnings
)
reissue_pt_warnings
(
caught_warnings
)
...
@@ -2498,6 +2508,15 @@ class Trainer:
...
@@ -2498,6 +2508,15 @@ class Trainer:
# In distributed training however, we load directly on each GPU and risk the GPU OOM as it's more
# In distributed training however, we load directly on each GPU and risk the GPU OOM as it's more
# likely to get OOM on CPU (since we load num_gpu times the optimizer state
# likely to get OOM on CPU (since we load num_gpu times the optimizer state
map_location
=
self
.
args
.
device
if
self
.
args
.
world_size
>
1
else
"cpu"
map_location
=
self
.
args
.
device
if
self
.
args
.
world_size
>
1
else
"cpu"
if
self
.
fsdp
:
full_osd
=
None
# In FSDP, we need to load the full optimizer state dict on rank 0 and then shard it
if
self
.
args
.
process_index
==
0
:
full_osd
=
torch
.
load
(
os
.
path
.
join
(
checkpoint
,
OPTIMIZER_NAME
))
# call scatter_full_optim_state_dict on all ranks
sharded_osd
=
self
.
model
.
__class__
.
scatter_full_optim_state_dict
(
full_osd
,
self
.
model
)
self
.
optimizer
.
load_state_dict
(
sharded_osd
)
else
:
self
.
optimizer
.
load_state_dict
(
self
.
optimizer
.
load_state_dict
(
torch
.
load
(
os
.
path
.
join
(
checkpoint
,
OPTIMIZER_NAME
),
map_location
=
map_location
)
torch
.
load
(
os
.
path
.
join
(
checkpoint
,
OPTIMIZER_NAME
),
map_location
=
map_location
)
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment