Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
39ed68d5
Unverified
Commit
39ed68d5
authored
Sep 03, 2020
by
Sam Shleifer
Committed by
GitHub
Sep 03, 2020
Browse files
[s2s] allow task_specific_params=summarization_xsum (#6923)
parent
5a318f07
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
6 additions
and
4 deletions
+6
-4
examples/seq2seq/callbacks.py
examples/seq2seq/callbacks.py
+2
-2
examples/seq2seq/finetune.py
examples/seq2seq/finetune.py
+3
-2
examples/seq2seq/test_seq2seq_examples.py
examples/seq2seq/test_seq2seq_examples.py
+1
-0
No files found.
examples/seq2seq/callbacks.py
View file @
39ed68d5
...
...
@@ -75,7 +75,7 @@ class Seq2SeqLoggingCallback(pl.Callback):
return
self
.
_write_logs
(
trainer
,
pl_module
,
"test"
)
def
get_checkpoint_callback
(
output_dir
,
metric
):
def
get_checkpoint_callback
(
output_dir
,
metric
,
save_top_k
=
1
):
"""Saves the best model by validation ROUGE2 score."""
if
metric
==
"rouge2"
:
exp
=
"{val_avg_rouge2:.4f}-{step_count}"
...
...
@@ -90,7 +90,7 @@ def get_checkpoint_callback(output_dir, metric):
filepath
=
os
.
path
.
join
(
output_dir
,
exp
),
monitor
=
f
"val_
{
metric
}
"
,
mode
=
"max"
,
save_top_k
=
1
,
save_top_k
=
save_top_k
,
period
=
0
,
# maybe save a checkpoint every time val is run, not just end of epoch.
)
return
checkpoint_callback
...
...
examples/seq2seq/finetune.py
View file @
39ed68d5
...
...
@@ -306,6 +306,7 @@ class SummarizationModule(BaseTransformer):
parser
.
add_argument
(
"--tgt_lang"
,
type
=
str
,
default
=
""
,
required
=
False
)
parser
.
add_argument
(
"--eval_beams"
,
type
=
int
,
default
=
None
,
required
=
False
)
parser
.
add_argument
(
"--val_metric"
,
type
=
str
,
default
=
None
,
required
=
False
)
parser
.
add_argument
(
"--save_top_k"
,
type
=
int
,
default
=
1
,
required
=
False
,
help
=
"How many checkpoints to save"
)
parser
.
add_argument
(
"--early_stopping_patience"
,
type
=
int
,
...
...
@@ -336,7 +337,7 @@ def main(args, model=None) -> SummarizationModule:
if
len
(
os
.
listdir
(
args
.
output_dir
))
>
3
and
args
.
do_train
:
raise
ValueError
(
"Output directory ({}) already exists and is not empty."
.
format
(
args
.
output_dir
))
if
model
is
None
:
if
args
.
task
==
"summarization"
:
if
"summarization"
in
args
.
task
:
model
:
SummarizationModule
=
SummarizationModule
(
args
)
else
:
model
:
SummarizationModule
=
TranslationModule
(
args
)
...
...
@@ -368,7 +369,7 @@ def main(args, model=None) -> SummarizationModule:
model
,
args
,
logging_callback
=
Seq2SeqLoggingCallback
(),
checkpoint_callback
=
get_checkpoint_callback
(
args
.
output_dir
,
model
.
val_metric
),
checkpoint_callback
=
get_checkpoint_callback
(
args
.
output_dir
,
model
.
val_metric
,
args
.
save_top_k
),
early_stopping_callback
=
es_callback
,
logger
=
logger
,
# TODO: early stopping callback seems messed up
...
...
examples/seq2seq/test_seq2seq_examples.py
View file @
39ed68d5
...
...
@@ -34,6 +34,7 @@ CHEAP_ARGS = {
"label_smoothing"
:
0.2
,
"eval_beams"
:
1
,
"val_metric"
:
None
,
"save_top_k"
:
1
,
"adafactor"
:
True
,
"early_stopping_patience"
:
2
,
"logger_name"
:
"default"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment