chenpangpang/transformers · Commits

Unverified commit b9772897, authored Aug 31, 2020 by Sam Shleifer, committed via GitHub on Aug 31, 2020.
Parent: 8af1970e

[s2s] command line args for faster val steps (#6833)

3 changed files with 10 additions and 3 deletions (+10, -3):
examples/seq2seq/distillation.py           +1  -1
examples/seq2seq/finetune.py               +7  -2
examples/seq2seq/test_seq2seq_examples.py  +2  -0
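In short: the hard-coded val_metric class attribute becomes default_val_metric, and two new optional flags, --eval_beams and --val_metric, let a run override the beam count and checkpoint-selection metric used during validation. Passing something like --eval_beams 1 (and, plausibly, --val_metric loss) should make each validation step cheaper; the exact companion flags depend on the rest of the script.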
examples/seq2seq/distillation.py
@@ -262,7 +262,7 @@ class BartTranslationDistiller(BartSummarizationDistiller):
     mode = "translation"
     metric_names = ["bleu"]
-    val_metric = "bleu"
+    default_val_metric = "bleu"

     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
examples/seq2seq/finetune.py
@@ -63,7 +63,7 @@ class SummarizationModule(BaseTransformer):
     mode = "summarization"
     loss_names = ["loss"]
     metric_names = ROUGE_KEYS
-    val_metric = "rouge2"
+    default_val_metric = "rouge2"

     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
@@ -110,6 +110,9 @@ class SummarizationModule(BaseTransformer):
         self.dataset_class = (
             Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
         )
+        self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
+        assert self.eval_beams >= 1, f"got self.eval_beams={self.eval_beams}. Need an integer > 1"
+        self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric

     def freeze_embeds(self):
         """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
@@ -301,6 +304,8 @@ class SummarizationModule(BaseTransformer):
         parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
         parser.add_argument("--src_lang", type=str, default="", required=False)
         parser.add_argument("--tgt_lang", type=str, default="", required=False)
+        parser.add_argument("--eval_beams", type=int, default=None, required=False)
+        parser.add_argument("--val_metric", type=str, default=None, required=False)
         parser.add_argument("--early_stopping_patience", type=int, ...
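Both new flags default to None, which the constructor code above interprets as "use the default". A standalone sketch of just the parsing behavior; the toy parser below is not the real argument registration method, which also adds many other arguments:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--eval_beams", type=int, default=None, required=False)
parser.add_argument("--val_metric", type=str, default=None, required=False)

# Explicit values override the model config / class default downstream:
args = parser.parse_args(["--eval_beams", "1", "--val_metric", "loss"])
print(args.eval_beams, args.val_metric)  # 1 loss

# Omitted flags stay None and defer to the defaults:
args = parser.parse_args([])
print(args.eval_beams, args.val_metric)  # None None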
@@ -315,7 +320,7 @@ class TranslationModule(SummarizationModule):
     mode = "translation"
     loss_names = ["loss"]
     metric_names = ["bleu"]
-    val_metric = "bleu"
+    default_val_metric = "bleu"

     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
examples/seq2seq/test_seq2seq_examples.py
@@ -31,6 +31,8 @@ logger = logging.getLogger()
 CUDA_AVAILABLE = torch.cuda.is_available()
 CHEAP_ARGS = {
     "label_smoothing": 0.2,
+    "eval_beams": 1,
+    "val_metric": None,
     "adafactor": True,
     "early_stopping_patience": 2,
     "logger_name": "default",
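The test defaults exercise the cheapest settings: a single beam and no metric override. A small sketch of how a dict like CHEAP_ARGS is typically turned into the hparams object the modules read; the conversion via Namespace is an assumption about the test harness, and the dict here is abbreviated:

from argparse import Namespace

cheap_args = {"label_smoothing": 0.2, "eval_beams": 1, "val_metric": None}  # abbreviated
hparams = Namespace(**cheap_args)  # assumption: tests build hparams roughly this way
print(hparams.eval_beams)  # 1: single-beam decoding, the fastest validation setting
print(hparams.val_metric)  # None: falls back to default_val_metric ("rouge2" or "bleu")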