Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
e8246f78
Unverified
Commit
e8246f78
authored
Mar 12, 2021
by
Sylvain Gugger
Committed by
GitHub
Mar 12, 2021
Browse files
Add auto_wrap option in fairscale integration (#10673)
* Add auto_wrap option in fairscale integration * Style
parent
184ef8ec
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
13 additions
and
6 deletions
+13
-6
docs/source/main_classes/trainer.rst
docs/source/main_classes/trainer.rst
+2
-2
src/transformers/trainer.py
src/transformers/trainer.py
+7
-1
src/transformers/trainer_utils.py
src/transformers/trainer_utils.py
+1
-0
src/transformers/training_args.py
src/transformers/training_args.py
+3
-3
No files found.
docs/source/main_classes/trainer.rst
View file @
e8246f78
...
...
@@ -335,8 +335,8 @@ Known caveats:
- This feature is incompatible with :obj:`--predict_with_generate` in the `run_seq2seq.py` script.
- Using :obj:`--sharded_ddp zero_dp_3` requires wrapping each layer of the model in the special container
:obj:`FullyShardedDataParallelism` of fairscale.
This is not done automatically by any of the example scripts of the
:class:`~transformers.Trainer
`.
:obj:`FullyShardedDataParallelism` of fairscale.
It should be used with the option :obj:`auto_wrap` if you are not
doing this yourself: :obj:`--sharded_ddp "zero_dp_3 auto_wrap"
`.
DeepSpeed
...
...
src/transformers/trainer.py
View file @
e8246f78
...
...
@@ -144,6 +144,7 @@ if is_fairscale_available():
if
version
.
parse
(
fairscale
.
__version__
)
>=
version
.
parse
(
"0.3"
):
from
fairscale.nn.data_parallel
import
FullyShardedDataParallel
as
FullyShardedDDP
from
fairscale.nn.wrap
import
auto_wrap
else
:
FullyShardedDDP
=
None
...
...
@@ -775,8 +776,13 @@ class Trainer:
cpu_offload
=
ShardedDDPOption
.
OFFLOAD
in
self
.
args
.
sharded_ddp
zero_3
=
self
.
sharded_ddp
==
ShardedDDPOption
.
ZERO_DP_3
# XXX: Breaking the self.model convention but I see no way around it for now.
if
ShardedDDPOption
.
AUTO_WRAP
in
self
.
args
.
sharded_ddp
:
model
=
auto_wrap
(
model
)
self
.
model
=
model
=
FullyShardedDDP
(
model
,
mixed_precision
=
mixed_precision
,
reshard_after_forward
=
zero_3
,
cpu_offload
=
cpu_offload
model
,
mixed_precision
=
mixed_precision
,
reshard_after_forward
=
zero_3
,
cpu_offload
=
cpu_offload
,
).
to
(
self
.
args
.
device
)
elif
is_sagemaker_distributed_available
():
...
...
src/transformers/trainer_utils.py
View file @
e8246f78
...
...
@@ -446,3 +446,4 @@ class ShardedDDPOption(ExplicitEnum):
ZERO_DP_2
=
"zero_dp_2"
ZERO_DP_3
=
"zero_dp_3"
OFFLOAD
=
"offload"
AUTO_WRAP
=
"auto_wrap"
src/transformers/training_args.py
View file @
e8246f78
...
...
@@ -470,10 +470,10 @@ class TrainingArguments:
sharded_ddp
:
str
=
field
(
default
=
""
,
metadata
=
{
"choices"
:
[
"simple"
,
"zero_dp_2"
,
"zero_dp_3"
,
"zero_dp_2 offload"
,
"zero_dp_3 offload"
],
"help"
:
"Whether or not to use sharded DDP training (in distributed training only). The base option "
"should be `simple`, `zero_dp_2` or `zero_dp_3` and you can add CPU-offload to `zero_dp_2` or `zero_dp_3` "
"like this: zero_dp_2 offload` or `zero_dp_3 offload`"
,
"like this: zero_dp_2 offload` or `zero_dp_3 offload`. You can add auto-wrap to `zero_dp_2` or "
"with the same syntax: zero_dp_2 auto_wrap` or `zero_dp_3 auto_wrap`."
,
},
)
deepspeed
:
Optional
[
str
]
=
field
(
...
...
@@ -570,7 +570,7 @@ class TrainingArguments:
"`--sharded_ddp offload` can't work on its own. It needs to be added to `--sharded_ddp zero_dp_2` or "
'`--sharded_ddp zero_dp_3`. For example, `--sharded_ddp "zero_dp_2 offload"`.'
)
elif
len
(
self
.
sharded_ddp
)
>
1
and
ShardedDDPOption
.
S
imple
in
self
.
sharded_ddp
:
elif
len
(
self
.
sharded_ddp
)
>
1
and
ShardedDDPOption
.
S
IMPLE
in
self
.
sharded_ddp
:
raise
ValueError
(
"`--sharded_ddp simple` is not compatible with any other option."
)
elif
ShardedDDPOption
.
ZERO_DP_2
in
self
.
sharded_ddp
and
ShardedDDPOption
.
ZERO_DP_3
in
self
.
sharded_ddp
:
raise
ValueError
(
"`--sharded_ddp zero_dp_2` is not compatible with `--sharded_ddp zero_dp_3`."
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment