Unverified Commit 29acabd8 authored by Stas Bekman, committed by GitHub

[trainer] group fp16 args together (#9409)

* [t5 doc] typos

a few runaway backticks

@sgugger

* style

* [trainer] put fp16 args together

This PR proposes a purely cosmetic change that puts all the fp16 args together, so they are easier to manage/read

@sgugger

* style
parent 57a66269
@@ -147,6 +147,10 @@ class TrainingArguments:
         fp16_opt_level (:obj:`str`, `optional`, defaults to 'O1'):
             For :obj:`fp16` training, Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. See details
             on the `Apex documentation <https://nvidia.github.io/apex/amp.html>`__.
+        fp16_backend (:obj:`str`, `optional`, defaults to :obj:`"auto"`):
+            The backend to use for mixed precision training. Must be one of :obj:`"auto"`, :obj:`"amp"` or
+            :obj:`"apex"`. :obj:`"auto"` will use AMP or APEX depending on the PyTorch version detected, while the
+            other choices will force the requested backend.
         local_rank (:obj:`int`, `optional`, defaults to -1):
             Rank of the process during distributed training.
         tpu_num_cores (:obj:`int`, `optional`):
@@ -213,10 +217,6 @@ class TrainingArguments:
             When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
             stage as in the previous training. If set to :obj:`True`, the training will begin faster (as that skipping
             step can take a long time) but will not yield the same results as the interrupted training would have.
-        fp16_backend (:obj:`str`, `optional`, defaults to :obj:`"auto"`):
-            The backend to use for mixed precision training. Must be one of :obj:`"auto"`, :obj:`"amp"` or
-            :obj:`"apex"`. :obj:`"auto"` will use AMP or APEX depending on the PyTorch version detected, while the
-            other choices will force the requested backend.
         sharded_ddp (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Use Sharded DDP training from `FairScale <https://github.com/facebookresearch/fairscale>`__ (in distributed
             training only). This is an experimental feature.
@@ -341,6 +341,10 @@ class TrainingArguments:
             )
         },
     )
+    fp16_backend: str = field(
+        default="auto",
+        metadata={"help": "The backend to be used for mixed precision.", "choices": ["auto", "amp", "apex"]},
+    )
     local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"})
     tpu_num_cores: Optional[int] = field(
@@ -398,10 +402,6 @@ class TrainingArguments:
             "help": "When resuming training, whether or not to skip the first epochs and batches to get to the same training data."
         },
     )
-    fp16_backend: str = field(
-        default="auto",
-        metadata={"help": "The backend to be used for mixed precision.", "choices": ["auto", "amp", "apex"]},
-    )
     sharded_ddp: bool = field(
         default=False,
         metadata={"help": "Whether or not to use sharded DDP training (in distributed training only)."},
     )