chenpangpang / transformers · Commits · 06a973fd

Unverified commit 06a973fd, authored Oct 08, 2020 by Suraj Patil, committed by GitHub on Oct 08, 2020

[s2s] configure lr_scheduler from command line (#7641)

Parent: 4a00613c
Showing 2 changed files with 37 additions and 3 deletions:

    examples/seq2seq/finetune_trainer.py    +4   -1
    examples/seq2seq/seq2seq_trainer.py     +33  -2
examples/seq2seq/finetune_trainer.py  (+4, -1)

@@ -4,7 +4,7 @@ import sys
 from dataclasses import dataclass, field
 from typing import Optional
 
-from seq2seq_trainer import Seq2SeqTrainer
+from seq2seq_trainer import Seq2SeqTrainer, arg_to_scheduler_choices
 from transformers import (
     AutoConfig,
     AutoModelForSeq2SeqLM,
@@ -63,6 +63,9 @@ class Seq2SeqTrainingArguments(TrainingArguments):
     attention_dropout: Optional[float] = field(
         default=None, metadata={"help": "Attention dropout probability. Goes into model.config."}
     )
+    lr_scheduler: Optional[str] = field(
+        default="linear", metadata={"help": f"Which lr scheduler to use. Selected in {arg_to_scheduler_choices}"}
+    )
 
 
 @dataclass
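The new field rides on the dataclass-to-argparse bridge the example scripts already use, so it surfaces as a `--lr_scheduler` flag (for instance `python finetune_trainer.py --lr_scheduler cosine ...`). Below is a minimal sketch of that mechanism, assuming `transformers`' `HfArgumentParser`; `TinyArguments` is a hypothetical stand-in for `Seq2SeqTrainingArguments`, which additionally carries every base `TrainingArguments` option:

```python
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser

# Stand-in for the list exported by seq2seq_trainer.py.
arg_to_scheduler_choices = sorted(
    ["constant", "constant_w_warmup", "cosine", "cosine_w_restarts", "linear", "polynomial"]
)


@dataclass
class TinyArguments:
    """Hypothetical, trimmed-down stand-in for Seq2SeqTrainingArguments."""

    lr_scheduler: Optional[str] = field(
        default="linear",
        metadata={"help": f"Which lr scheduler to use. Selected in {arg_to_scheduler_choices}"},
    )


if __name__ == "__main__":
    parser = HfArgumentParser(TinyArguments)
    # Parse an explicit argv so the sketch is self-contained; in the real script
    # the value comes straight from the command line.
    (args,) = parser.parse_args_into_dataclasses(args=["--lr_scheduler", "cosine"])
    print(args.lr_scheduler)  # cosine
```

Note that the field stays a plain string: an unsupported name is not rejected at parse time, it only fails later when `Seq2SeqTrainer._get_lr_scheduler` indexes `arg_to_scheduler` (a `KeyError`).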
examples/seq2seq/seq2seq_trainer.py  (+33, -2)

@@ -8,7 +8,16 @@ from torch.utils.data import DistributedSampler, RandomSampler
 
 from transformers import Trainer
 from transformers.configuration_fsmt import FSMTConfig
 from transformers.file_utils import is_torch_tpu_available
-from transformers.optimization import Adafactor, AdamW, get_linear_schedule_with_warmup
+from transformers.optimization import (
+    Adafactor,
+    AdamW,
+    get_constant_schedule,
+    get_constant_schedule_with_warmup,
+    get_cosine_schedule_with_warmup,
+    get_cosine_with_hard_restarts_schedule_with_warmup,
+    get_linear_schedule_with_warmup,
+    get_polynomial_decay_schedule_with_warmup,
+)
 from transformers.trainer_pt_utils import get_tpu_sampler
@@ -20,6 +29,16 @@ except ImportError:
 
 logger = logging.getLogger(__name__)
 
+arg_to_scheduler = {
+    "linear": get_linear_schedule_with_warmup,
+    "cosine": get_cosine_schedule_with_warmup,
+    "cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup,
+    "polynomial": get_polynomial_decay_schedule_with_warmup,
+    "constant": get_constant_schedule,
+    "constant_w_warmup": get_constant_schedule_with_warmup,
+}
+arg_to_scheduler_choices = sorted(arg_to_scheduler.keys())
+
 
 class Seq2SeqTrainer(Trainer):
     def __init__(self, config, data_args, *args, **kwargs):
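The keys of this new mapping are exactly the strings `--lr_scheduler` accepts, and `arg_to_scheduler_choices` exists only so `finetune_trainer.py` can interpolate them into the flag's help text. A rough, self-contained sketch of how a chosen name resolves into a working scheduler, using a throwaway optimizer (the parameter, learning rate, and step counts are arbitrary, not values taken from the trainer):

```python
import torch
from transformers.optimization import (
    get_constant_schedule,
    get_constant_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
    get_cosine_with_hard_restarts_schedule_with_warmup,
    get_linear_schedule_with_warmup,
    get_polynomial_decay_schedule_with_warmup,
)

# Same mapping as in the hunk above, repeated so this sketch runs on its own.
arg_to_scheduler = {
    "linear": get_linear_schedule_with_warmup,
    "cosine": get_cosine_schedule_with_warmup,
    "cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup,
    "polynomial": get_polynomial_decay_schedule_with_warmup,
    "constant": get_constant_schedule,
    "constant_w_warmup": get_constant_schedule_with_warmup,
}

# Throwaway optimizer over a single dummy parameter, just so there is something to schedule.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=3e-5)

# This is the lookup Seq2SeqTrainer performs with self.args.lr_scheduler.
schedule_func = arg_to_scheduler["cosine"]
scheduler = schedule_func(optimizer, num_warmup_steps=10, num_training_steps=100)
print(type(scheduler).__name__)  # LambdaLR: every factory here wraps torch's LambdaLR
```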
examples/seq2seq/seq2seq_trainer.py (continued)

@@ -62,9 +81,21 @@ class Seq2SeqTrainer(Trainer):
             )
 
         if self.lr_scheduler is None:
-            self.lr_scheduler = get_linear_schedule_with_warmup(
+            self.lr_scheduler = self._get_lr_scheduler(num_training_steps)
+        else:  # ignoring --lr_scheduler
+            logger.warn("scheduler is passed to `Seq2SeqTrainer`, `--lr_scheduler` arg is ignored.")
+
+    def _get_lr_scheduler(self, num_training_steps):
+        schedule_func = arg_to_scheduler[self.args.lr_scheduler]
+        if self.args.lr_scheduler == "constant":
+            scheduler = schedule_func(self.optimizer)
+        elif self.args.lr_scheduler == "constant_w_warmup":
+            scheduler = schedule_func(self.optimizer, num_warmup_steps=self.args.warmup_steps)
+        else:
+            scheduler = schedule_func(
                 self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
             )
+        return scheduler
 
     def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
         if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
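The three branches in `_get_lr_scheduler` exist because the factories take different arguments: `get_constant_schedule` wants only the optimizer, `get_constant_schedule_with_warmup` additionally wants `num_warmup_steps`, and the decaying schedules also need `num_training_steps` to know when to finish. A rough standalone sketch of that dispatch, stepping a dummy optimizer so the resulting learning-rate curves can be inspected (only three of the six choices are wired up, and the warmup/step counts are arbitrary):

```python
import torch
from transformers.optimization import (
    get_constant_schedule,
    get_constant_schedule_with_warmup,
    get_linear_schedule_with_warmup,
)


def build_lr_scheduler(name, optimizer, num_warmup_steps, num_training_steps):
    """Standalone restatement of the dispatch in Seq2SeqTrainer._get_lr_scheduler
    (trimmed to three of the six supported choices)."""
    schedule_func = {
        "constant": get_constant_schedule,
        "constant_w_warmup": get_constant_schedule_with_warmup,
        "linear": get_linear_schedule_with_warmup,
    }[name]
    if name == "constant":
        # Takes only the optimizer: no warmup, no training horizon.
        return schedule_func(optimizer)
    if name == "constant_w_warmup":
        # Needs num_warmup_steps, but never decays, so no num_training_steps.
        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
    # Decaying schedules (linear, cosine, cosine_w_restarts, polynomial) need the full horizon.
    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)


for name in ("constant", "constant_w_warmup", "linear"):
    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.SGD([param], lr=1e-3)
    scheduler = build_lr_scheduler(name, optimizer, num_warmup_steps=2, num_training_steps=10)
    lrs = []
    for _ in range(10):
        optimizer.step()  # step the optimizer first to avoid the PyTorch ordering warning
        scheduler.step()
        lrs.append(round(scheduler.get_last_lr()[0], 5))
    print(name, lrs)
```

Because the `--lr_scheduler` default is `"linear"`, existing training commands keep the previous behaviour (linear warmup followed by linear decay) unless they explicitly pick another schedule.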