Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
06a973fd
Unverified
Commit
06a973fd
authored
Oct 08, 2020
by
Suraj Patil
Committed by
GitHub
Oct 08, 2020
Browse files
[s2s] configure lr_scheduler from command line (#7641)
parent
4a00613c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
37 additions
and
3 deletions
+37
-3
examples/seq2seq/finetune_trainer.py
examples/seq2seq/finetune_trainer.py
+4
-1
examples/seq2seq/seq2seq_trainer.py
examples/seq2seq/seq2seq_trainer.py
+33
-2
No files found.
examples/seq2seq/finetune_trainer.py
View file @
06a973fd
...
...
@@ -4,7 +4,7 @@ import sys
from
dataclasses
import
dataclass
,
field
from
typing
import
Optional
from
seq2seq_trainer
import
Seq2SeqTrainer
from
seq2seq_trainer
import
Seq2SeqTrainer
,
arg_to_scheduler_choices
from
transformers
import
(
AutoConfig
,
AutoModelForSeq2SeqLM
,
...
...
@@ -63,6 +63,9 @@ class Seq2SeqTrainingArguments(TrainingArguments):
attention_dropout
:
Optional
[
float
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"Attention dropout probability. Goes into model.config."
}
)
lr_scheduler
:
Optional
[
str
]
=
field
(
default
=
"linear"
,
metadata
=
{
"help"
:
f
"Which lr scheduler to use. Selected in
{
arg_to_scheduler_choices
}
"
}
)
@
dataclass
...
...
examples/seq2seq/seq2seq_trainer.py
View file @
06a973fd
...
...
@@ -8,7 +8,16 @@ from torch.utils.data import DistributedSampler, RandomSampler
from
transformers
import
Trainer
from
transformers.configuration_fsmt
import
FSMTConfig
from
transformers.file_utils
import
is_torch_tpu_available
from
transformers.optimization
import
Adafactor
,
AdamW
,
get_linear_schedule_with_warmup
from
transformers.optimization
import
(
Adafactor
,
AdamW
,
get_constant_schedule
,
get_constant_schedule_with_warmup
,
get_cosine_schedule_with_warmup
,
get_cosine_with_hard_restarts_schedule_with_warmup
,
get_linear_schedule_with_warmup
,
get_polynomial_decay_schedule_with_warmup
,
)
from
transformers.trainer_pt_utils
import
get_tpu_sampler
...
...
@@ -20,6 +29,16 @@ except ImportError:
logger
=
logging
.
getLogger
(
__name__
)
arg_to_scheduler
=
{
"linear"
:
get_linear_schedule_with_warmup
,
"cosine"
:
get_cosine_schedule_with_warmup
,
"cosine_w_restarts"
:
get_cosine_with_hard_restarts_schedule_with_warmup
,
"polynomial"
:
get_polynomial_decay_schedule_with_warmup
,
"constant"
:
get_constant_schedule
,
"constant_w_warmup"
:
get_constant_schedule_with_warmup
,
}
arg_to_scheduler_choices
=
sorted
(
arg_to_scheduler
.
keys
())
class
Seq2SeqTrainer
(
Trainer
):
def
__init__
(
self
,
config
,
data_args
,
*
args
,
**
kwargs
):
...
...
@@ -62,9 +81,21 @@ class Seq2SeqTrainer(Trainer):
)
if
self
.
lr_scheduler
is
None
:
self
.
lr_scheduler
=
get_linear_schedule_with_warmup
(
self
.
lr_scheduler
=
self
.
_get_lr_scheduler
(
num_training_steps
)
else
:
# ignoring --lr_scheduler
logger
.
warn
(
"scheduler is passed to `Seq2SeqTrainer`, `--lr_scheduler` arg is ignored."
)
def
_get_lr_scheduler
(
self
,
num_training_steps
):
schedule_func
=
arg_to_scheduler
[
self
.
args
.
lr_scheduler
]
if
self
.
args
.
lr_scheduler
==
"constant"
:
scheduler
=
schedule_func
(
self
.
optimizer
)
elif
self
.
args
.
lr_scheduler
==
"constant_w_warmup"
:
scheduler
=
schedule_func
(
self
.
optimizer
,
num_warmup_steps
=
self
.
args
.
warmup_steps
)
else
:
scheduler
=
schedule_func
(
self
.
optimizer
,
num_warmup_steps
=
self
.
args
.
warmup_steps
,
num_training_steps
=
num_training_steps
)
return
scheduler
def
_get_train_sampler
(
self
)
->
Optional
[
torch
.
utils
.
data
.
sampler
.
Sampler
]:
if
isinstance
(
self
.
train_dataset
,
torch
.
utils
.
data
.
IterableDataset
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment