chenpangpang / transformers · Commits

Commit 0d909f6b (unverified)
Authored Mar 09, 2021 by Sylvain Gugger, committed by GitHub on Mar 09, 2021

Fairscale FSDP fix model save (#10596)

* Hotfix fairscale FSDP
* Evaluation works
* Save on process zero
parent ac17f711

Showing 2 changed files with 19 additions and 15 deletions:

  examples/tests/trainer/test_trainer_ext.py  (+5, -7)
  src/transformers/trainer.py                 (+14, -8)
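Why the fix is shaped this way (my summary, not text from the commit): under fairscale's ZeRO-DP 2/3 sharding, `state_dict()` consolidates the full model weights across ranks, so it is effectively a collective call that every process must make, while the actual file write must happen on process zero only. A minimal sketch of that pattern, assuming those consolidation semantics (hypothetical helper, not the Trainer code):

import os

import torch

def save_sharded_model(model, output_dir, is_main_process):
    # Collective step: with a sharded wrapper, state_dict() gathers the full
    # weights, so every rank must call it; calling it on rank zero alone
    # could leave the other ranks hanging mid-collective.
    state_dict = model.state_dict()
    # Write step: only the main process touches the filesystem.
    if is_main_process:
        os.makedirs(output_dir, exist_ok=True)
        torch.save(state_dict, os.path.join(output_dir, "pytorch_model.bin"))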
examples/tests/trainer/test_trainer_ext.py
@@ -66,7 +66,7 @@ def require_apex(test_case):
 class TestTrainerExt(TestCasePlus):
-    def run_seq2seq_quick(self, distributed=False, extra_args_str=None, eval=True, predict_with_generate=True):
+    def run_seq2seq_quick(self, distributed=False, extra_args_str=None, predict_with_generate=True):
         output_dir = self.run_trainer(
             eval_steps=1,
             max_len=12,
@@ -83,9 +83,9 @@ class TestTrainerExt(TestCasePlus):
         if predict_with_generate:
             assert "eval_bleu" in first_step_stats
-        last_step_stats = eval_metrics[-1]
-        assert isinstance(last_step_stats["eval_bleu"], float)
-        assert not math.isnan(float(last_step_stats["eval_loss"])), "eval_loss must not be `nan`"
+            last_step_stats = eval_metrics[-1]
+            assert isinstance(last_step_stats["eval_bleu"], float)
+            assert not math.isnan(float(last_step_stats["eval_loss"])), "eval_loss must not be `nan`"

     @require_torch_non_multi_gpu
     def test_run_seq2seq_no_dist(self):
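A gloss on the hunk above (mine, not the commit's): `eval_bleu` is only reported when evaluation runs generation, so with the sharded tests now passing `predict_with_generate=False`, the BLEU assertions have to sit inside the guard. A toy illustration with hypothetical metric dicts:

# Hypothetical metric dicts, for illustration only.
metrics_with_generate = {"eval_loss": 2.7, "eval_bleu": 27.3}
metrics_without_generate = {"eval_loss": 2.7}

# Indexing "eval_bleu" unconditionally raises KeyError when generation is
# off, which is why the assertions moved under the guard.
assert "eval_bleu" in metrics_with_generate
assert "eval_bleu" not in metrics_without_generate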
@@ -116,14 +116,12 @@ class TestTrainerExt(TestCasePlus):
     # test --sharded_ddp zero_dp_2 w/o --fp16
     @require_torch_multi_gpu
     @require_fairscale
-    @unittest.skip("XXX: Fixme: hanging")
     def test_run_seq2seq_fully_sharded_ddp(self):
         self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp zero_dp_2", predict_with_generate=False)

     # test --sharded_ddp zero_dp_2 w/ --fp16
     @require_torch_multi_gpu
     @require_fairscale
-    @unittest.skip("XXX: Fixme: hanging")
     def test_run_seq2seq_fully_sharded_ddp_fp16(self):
         self.run_seq2seq_quick(
             distributed=True, extra_args_str="--sharded_ddp zero_dp_2 --fp16", predict_with_generate=False
@@ -206,8 +204,8 @@ class TestTrainerExt(TestCasePlus):
             --warmup_steps 8
             --evaluation_strategy steps
             --logging_steps 0
-            --save_steps {str(eval_steps)}
             --eval_steps {str(eval_steps)}
+            --save_steps {str(eval_steps)}
             --group_by_length
             --label_smoothing_factor 0.1
             --adafactor
src/transformers/trainer.py
@@ -1497,11 +1497,14 @@ class Trainer:
         """
         if is_torch_tpu_available():
             self._save_tpu(output_dir)
-        else:
+        elif (
+            ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
+        ):
+            state_dict = self.model.state_dict()
             if self.is_world_process_zero():
-                self._save(output_dir)
-            if self.args.local_rank != -1:
-                dist.barrier()
+                self._save(output_dir, state_dict=state_dict)
+        elif self.is_world_process_zero():
+            self._save(output_dir)

     def _save_tpu(self, output_dir: Optional[str] = None):
         output_dir = output_dir if output_dir is not None else self.args.output_dir
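For orientation (my gloss): `self.args.sharded_ddp` is a list of `ShardedDDPOption` values parsed from the `--sharded_ddp` training argument, so the membership tests above route ZeRO-DP 2/3 runs into the all-rank `state_dict()` branch. A small sketch, assuming the v4.4-era API where `sharded_ddp` accepts a space-separated option string:

from transformers import TrainingArguments

# Hypothetical sketch: mirrors the `--sharded_ddp zero_dp_2` flag the tests
# pass; TrainingArguments parses the string into ShardedDDPOption values.
args = TrainingArguments(output_dir="out", sharded_ddp="zero_dp_2")
print(args.sharded_ddp)  # expected: [ShardedDDPOption.ZERO_DP_2]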
@@ -1531,7 +1534,7 @@ class Trainer:
         if self.tokenizer is not None and self.is_world_process_zero():
             self.tokenizer.save_pretrained(output_dir)

-    def _save(self, output_dir: Optional[str] = None):
+    def _save(self, output_dir: Optional[str] = None, state_dict=None):
         # If we are executing this function, we are the process zero, so we don't check for that.
         output_dir = output_dir if output_dir is not None else self.args.output_dir
         os.makedirs(output_dir, exist_ok=True)
@@ -1540,13 +1543,16 @@ class Trainer:
         # They can then be reloaded using `from_pretrained()`
         if not isinstance(self.model, PreTrainedModel):
             if isinstance(unwrap_model(self.model), PreTrainedModel):
-                unwrap_model(self.model).save_pretrained(output_dir, state_dict=self.model.state_dict())
+                if state_dict is None:
+                    state_dict = self.model.state_dict()
+                unwrap_model(self.model).save_pretrained(output_dir, state_dict=state_dict)
             else:
                 logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
-                state_dict = self.model.state_dict()
+                if state_dict is None:
+                    state_dict = self.model.state_dict()
                 torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
         else:
-            self.model.save_pretrained(output_dir)
+            self.model.save_pretrained(output_dir, state_dict=state_dict)

         if self.tokenizer is not None:
             self.tokenizer.save_pretrained(output_dir)
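End-to-end effect (my note, not part of the commit): with the fix, `save_model()` writes a consolidated, unsharded checkpoint even when training used `--sharded_ddp zero_dp_2` or `zero_dp_3`, so it reloads through the usual API. A minimal sketch, with the model class chosen hypothetically:

from transformers import AutoModelForSeq2SeqLM

# Reload the checkpoint written by trainer.save_model("out"); the weights
# file now holds the full state dict gathered across ranks.
model = AutoModelForSeq2SeqLM.from_pretrained("out")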