chenpangpang / transformers · Commit 068e6b5e (unverified)

make files independent (#8267)

Authored Nov 03, 2020 by Patrick von Platen, committed via GitHub on Nov 03, 2020
Parent: cd360dcb
Showing 3 changed files, with 48 additions and 46 deletions:

examples/seq2seq/finetune_trainer.py (+3, -45)
examples/seq2seq/seq2seq_trainer.py (+0, -1)
examples/seq2seq/seq2seq_training_args.py (+45, -0)
examples/seq2seq/finetune_trainer.py

@@ -4,16 +4,9 @@ import sys
 from dataclasses import dataclass, field
 from typing import Optional
 
-from seq2seq_trainer import Seq2SeqTrainer, arg_to_scheduler_choices
-from transformers import (
-    AutoConfig,
-    AutoModelForSeq2SeqLM,
-    AutoTokenizer,
-    HfArgumentParser,
-    MBartTokenizer,
-    TrainingArguments,
-    set_seed,
-)
+from seq2seq_trainer import Seq2SeqTrainer
+from seq2seq_training_args import Seq2SeqTrainingArguments
+from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer, HfArgumentParser, MBartTokenizer, set_seed
 from transformers.trainer_utils import EvaluationStrategy
 from utils import (
     Seq2SeqDataCollator,
...
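With this hunk, finetune_trainer.py no longer imports TrainingArguments or arg_to_scheduler_choices; its argument class now comes from the new seq2seq_training_args module. Below is a minimal sketch of how the refactored argument class is parsed. It is not part of the commit: the argv values are invented for illustration, and it assumes you run from examples/seq2seq/ so the local module resolves.

from transformers import HfArgumentParser

from seq2seq_training_args import Seq2SeqTrainingArguments

# Parse an explicit argv list instead of sys.argv, purely for demonstration.
parser = HfArgumentParser(Seq2SeqTrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(
    args=["--output_dir", "out", "--label_smoothing", "0.1"]
)
print(training_args.label_smoothing)  # 0.1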
@@ -33,41 +26,6 @@ from utils import (
 logger = logging.getLogger(__name__)
 
 
-@dataclass
-class Seq2SeqTrainingArguments(TrainingArguments):
-    """
-    Parameters:
-        label_smoothing (:obj:`float`, `optional`, defaults to 0):
-            The label smoothing epsilon to apply (if not zero).
-        sortish_sampler (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            Whether to use SortishSampler or not. It sorts the inputs according to lengths in order to minimize the padding size.
-        predict_with_generate (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            Whether to use generate to calculate generative metrics (ROUGE, BLEU).
-    """
-
-    label_smoothing: Optional[float] = field(
-        default=0.0, metadata={"help": "The label smoothing epsilon to apply (if not zero)."}
-    )
-    sortish_sampler: bool = field(default=False, metadata={"help": "Whether to use SortishSampler or not."})
-    predict_with_generate: bool = field(
-        default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
-    )
-    adafactor: bool = field(default=False, metadata={"help": "Whether to use Adafactor."})
-    encoder_layerdrop: Optional[float] = field(
-        default=None, metadata={"help": "Encoder layer dropout probability. Goes into model.config."}
-    )
-    decoder_layerdrop: Optional[float] = field(
-        default=None, metadata={"help": "Decoder layer dropout probability. Goes into model.config."}
-    )
-    dropout: Optional[float] = field(default=None, metadata={"help": "Dropout probability. Goes into model.config."})
-    attention_dropout: Optional[float] = field(
-        default=None, metadata={"help": "Attention dropout probability. Goes into model.config."}
-    )
-    lr_scheduler: Optional[str] = field(
-        default="linear", metadata={"help": f"Which lr scheduler to use. Selected in {arg_to_scheduler_choices}"}
-    )
-
-
 @dataclass
 class ModelArguments:
     """
...
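Taken together, the two hunks above are the point of the commit title: the training-argument dataclass moves out of the training script into its own module. The resulting import graph, as read off this diff (the summary is editorial, not part of the commit):

# seq2seq_trainer.py        defines arg_to_scheduler and Seq2SeqTrainer
# seq2seq_training_args.py  imports arg_to_scheduler from seq2seq_trainer
# finetune_trainer.py       imports Seq2SeqTrainer from seq2seq_trainer and
#                           Seq2SeqTrainingArguments from seq2seq_training_args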
examples/seq2seq/seq2seq_trainer.py

@@ -30,7 +30,6 @@ arg_to_scheduler = {
     "constant": get_constant_schedule,
     "constant_w_warmup": get_constant_schedule_with_warmup,
 }
-arg_to_scheduler_choices = sorted(arg_to_scheduler.keys())
 
 
 class Seq2SeqTrainer(Trainer):
...
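With the module-level arg_to_scheduler_choices constant gone, the valid scheduler names are derived from the dict wherever they are needed. A minimal runnable sketch of that pattern, using the two scheduler getters visible in this hunk (the dict in the repo has more entries):

from transformers.optimization import get_constant_schedule, get_constant_schedule_with_warmup

arg_to_scheduler = {
    "constant": get_constant_schedule,
    "constant_w_warmup": get_constant_schedule_with_warmup,
}

# What the removed constant used to precompute, now done at the call site:
print(sorted(arg_to_scheduler.keys()))  # ['constant', 'constant_w_warmup']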
examples/seq2seq/seq2seq_training_args.py (new file, mode 100644)

import logging
from dataclasses import dataclass, field
from typing import Optional

from seq2seq_trainer import arg_to_scheduler
from transformers import TrainingArguments


logger = logging.getLogger(__name__)


@dataclass
class Seq2SeqTrainingArguments(TrainingArguments):
    """
    Parameters:
        label_smoothing (:obj:`float`, `optional`, defaults to 0):
            The label smoothing epsilon to apply (if not zero).
        sortish_sampler (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to use SortishSampler or not. It sorts the inputs according to lengths in order to minimize the padding size.
        predict_with_generate (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to use generate to calculate generative metrics (ROUGE, BLEU).
    """

    label_smoothing: Optional[float] = field(
        default=0.0, metadata={"help": "The label smoothing epsilon to apply (if not zero)."}
    )
    sortish_sampler: bool = field(default=False, metadata={"help": "Whether to use SortishSampler or not."})
    predict_with_generate: bool = field(
        default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
    )
    adafactor: bool = field(default=False, metadata={"help": "Whether to use Adafactor."})
    encoder_layerdrop: Optional[float] = field(
        default=None, metadata={"help": "Encoder layer dropout probability. Goes into model.config."}
    )
    decoder_layerdrop: Optional[float] = field(
        default=None, metadata={"help": "Decoder layer dropout probability. Goes into model.config."}
    )
    dropout: Optional[float] = field(default=None, metadata={"help": "Dropout probability. Goes into model.config."})
    attention_dropout: Optional[float] = field(
        default=None, metadata={"help": "Attention dropout probability. Goes into model.config."}
    )
    lr_scheduler: Optional[str] = field(
        default="linear",
        metadata={"help": f"Which lr scheduler to use. Selected in {sorted(arg_to_scheduler.keys())}"},
    )
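One subtlety in the new file: the f-string in lr_scheduler's help metadata runs once, at import time, so the sorted scheduler names are frozen into the help text at that point. A self-contained sketch of the same pattern, with a toy choices dict standing in for arg_to_scheduler:

from dataclasses import dataclass, field

# Toy stand-in for seq2seq_trainer.arg_to_scheduler.
choices = {"linear": None, "constant": None}


@dataclass
class Args:
    # The f-string is evaluated at class-definition time, freezing the help text.
    lr_scheduler: str = field(
        default="linear",
        metadata={"help": f"Which lr scheduler to use. Selected in {sorted(choices.keys())}"},
    )


print(Args.__dataclass_fields__["lr_scheduler"].metadata["help"])
# -> Which lr scheduler to use. Selected in ['constant', 'linear']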