chenpangpang / transformers · Commits

Commit de4d7b00 (unverified), authored Oct 01, 2020 by Sam Shleifer, committed by GitHub on Oct 01, 2020
Parent: d3a9601a

[s2s] Adafactor support for builtin trainer (#7522)
Showing 3 changed files with 40 additions and 0 deletions (+40, -0):
examples/seq2seq/finetune_trainer.py       +1  -0
examples/seq2seq/seq2seq_trainer.py        +38 -0
examples/seq2seq/test_finetune_trainer.py  +1  -0
examples/seq2seq/finetune_trainer.py
@@ -52,6 +52,7 @@ class Seq2SeqTrainingArguments(TrainingArguments):
     predict_with_generate: bool = field(
         default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
     )
+    adafactor: bool = field(default=False, metadata={"help": "whether to use adafactor"})


 @dataclass
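The new `adafactor` dataclass field is all that is needed on the arguments side: `finetune_trainer.py` parses `Seq2SeqTrainingArguments` with `HfArgumentParser`, which exposes a boolean field that defaults to `False` as a `store_true` command-line flag. A minimal sketch of that mechanism, using a trimmed-down copy of the dataclass (the parser call and the `--output_dir` value here are illustrative, not part of the commit):

```python
# Sketch only: shows how the new field surfaces as a `--adafactor` CLI flag.
from dataclasses import dataclass, field

from transformers import HfArgumentParser, TrainingArguments


@dataclass
class Seq2SeqTrainingArguments(TrainingArguments):
    # Trimmed copy of the example's arguments class.
    predict_with_generate: bool = field(
        default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
    )
    adafactor: bool = field(default=False, metadata={"help": "whether to use adafactor"})


parser = HfArgumentParser(Seq2SeqTrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(["--output_dir", "out", "--adafactor"])
assert training_args.adafactor is True
```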
examples/seq2seq/seq2seq_trainer.py
@@ -7,6 +7,7 @@ from torch.utils.data import DistributedSampler, RandomSampler

 from transformers import Trainer
 from transformers.file_utils import is_torch_tpu_available
+from transformers.optimization import Adafactor, AdamW, get_linear_schedule_with_warmup
 from transformers.trainer import get_tpu_sampler

@@ -28,6 +29,43 @@ class Seq2SeqTrainer(Trainer):
         self.pad_token_id = self.config.pad_token_id
         self.vocab_size = self.config.vocab_size

+    def create_optimizer_and_scheduler(self, num_training_steps: int):
+        """
+        Setup the optimizer and the learning rate scheduler.
+
+        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
+        Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass.
+        """
+        if self.optimizer is None:
+            no_decay = ["bias", "LayerNorm.weight"]
+            optimizer_grouped_parameters = [
+                {
+                    "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
+                    "weight_decay": self.args.weight_decay,
+                },
+                {
+                    "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
+                    "weight_decay": 0.0,
+                },
+            ]
+            if self.args.adafactor:
+                self.optimizer = Adafactor(
+                    optimizer_grouped_parameters,
+                    lr=self.args.learning_rate,
+                    scale_parameter=False,
+                    relative_step=False,
+                )
+            else:
+                self.optimizer = AdamW(
+                    optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon
+                )
+
+        if self.lr_scheduler is None:
+            self.lr_scheduler = get_linear_schedule_with_warmup(
+                self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
+            )
+
     def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
         if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
             return None
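The heart of the change is the `if self.args.adafactor:` branch above: Adafactor is constructed with `scale_parameter=False` and `relative_step=False`, which switches off its internal learning-rate heuristics so the explicitly passed `lr` and the linear-warmup scheduler control the step size, just as they do for AdamW. A self-contained sketch of the same choice outside the Trainer (the toy model, flag variable, and hyperparameter values are made up for illustration):

```python
# Standalone sketch of the optimizer selection added in create_optimizer_and_scheduler.
import torch
from transformers.optimization import Adafactor, AdamW, get_linear_schedule_with_warmup

model = torch.nn.Linear(8, 8)   # stand-in for the seq2seq model
use_adafactor = True            # corresponds to the new `--adafactor` flag

if use_adafactor:
    # relative_step=False disables Adafactor's time-dependent internal LR and
    # scale_parameter=False disables its RMS-based LR scaling, so the external
    # scheduler below drives the learning rate exactly as it would for AdamW.
    optimizer = Adafactor(model.parameters(), lr=3e-5, scale_parameter=False, relative_step=False)
else:
    optimizer = AdamW(model.parameters(), lr=3e-5, eps=1e-8)

# Same linear warmup/decay schedule the trainer attaches in both branches.
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=10, num_training_steps=100)
```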
examples/seq2seq/test_finetune_trainer.py
@@ -91,6 +91,7 @@ def run_trainer(eval_steps: int, max_len: str, model_name: str, num_train_epochs
         "0.1",
         # "--eval_beams",
         # "2",
+        "--adafactor",
         "--task",
         "translation",
         "--tgt_lang",
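On the test side, the only change is inserting `"--adafactor"` into the argument list that `run_trainer` forwards to `finetune_trainer.py`, so the existing end-to-end test now also exercises the Adafactor path. A hedged sketch of what such an argument list looks like around the new flag; apart from `--adafactor`, `--task translation`, and `--tgt_lang` (whose value is truncated in the visible diff), the surrounding flags are assumed standard `TrainingArguments` options rather than copied from the test:

```python
# Illustrative only: an argv fragment in the style of run_trainer's argument list.
argv = [
    "--output_dir", "/tmp/s2s_adafactor_test",   # assumed path, not from the test
    "--num_train_epochs", "1",
    "--learning_rate", "3e-4",
    "--warmup_steps", "10",
    "--adafactor",                # the new flag: Seq2SeqTrainer picks Adafactor over AdamW
    "--task", "translation",
    "--tgt_lang", "...",          # target-language value is truncated in the visible diff
]
```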