chenpangpang/transformers — commit e95d262f

Unverified commit e95d262f, authored Sep 03, 2020 by Sam Shleifer, committed by GitHub on Sep 03, 2020.

[s2s] support early stopping based on loss, rather than rouge (#6927)

parent 207ed8cb
Showing 3 changed files with 38 additions and 21 deletions (+38 -21).
examples/seq2seq/callbacks.py              +7  -5
examples/seq2seq/finetune.py               +28 -13
examples/seq2seq/test_seq2seq_examples.py  +3  -3
examples/seq2seq/callbacks.py (view file @ e95d262f)
@@ -75,21 +75,23 @@ class Seq2SeqLoggingCallback(pl.Callback):
         return self._write_logs(trainer, pl_module, "test")
 
 
-def get_checkpoint_callback(output_dir, metric, save_top_k=1):
+def get_checkpoint_callback(output_dir, metric, save_top_k=1, lower_is_better=False):
     """Saves the best model by validation ROUGE2 score."""
     if metric == "rouge2":
         exp = "{val_avg_rouge2:.4f}-{step_count}"
     elif metric == "bleu":
         exp = "{val_avg_bleu:.4f}-{step_count}"
+    elif metric == "loss":
+        exp = "{val_avg_loss:.4f}-{step_count}"
     else:
         raise NotImplementedError(
-            f"seq2seq callbacks only support rouge2 and bleu, got {metric}, You can make your own by adding to this function."
+            f"seq2seq callbacks only support rouge2, bleu and loss, got {metric}, You can make your own by adding to this function."
         )
 
     checkpoint_callback = ModelCheckpoint(
         filepath=os.path.join(output_dir, exp),
         monitor=f"val_{metric}",
-        mode="max",
+        mode="min" if "loss" in metric else "max",
         save_top_k=save_top_k,
         period=0,  # maybe save a checkpoint every time val is run, not just end of epoch.
     )
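
A note on the new signature: main() below passes lower_is_better as a fourth positional argument, but inside this function the direction is actually derived from the metric name ("loss" in metric) rather than from that flag. The rule itself is simple; here is a minimal, runnable sketch of it in isolation (the helper name checkpoint_mode is ours, not from this commit):

    def checkpoint_mode(metric: str) -> str:
        # Lower validation loss is better; higher ROUGE2/BLEU is better.
        return "min" if "loss" in metric else "max"

    assert checkpoint_mode("loss") == "min"
    assert checkpoint_mode("rouge2") == "max"
    assert checkpoint_mode("bleu") == "max"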
@@ -98,8 +100,8 @@ def get_checkpoint_callback(output_dir, metric, save_top_k=1):
 def get_early_stopping_callback(metric, patience):
     return EarlyStopping(
-        monitor=f"val_{metric}",
-        mode="max",
+        monitor=f"val_{metric}",  # does this need avg?
+        mode="min" if "loss" in metric else "max",
         patience=patience,
         verbose=True,
     )
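
The new "# does this need avg?" comment can likely be answered from validation_epoch_end in finetune.py below: besides the *_avg_* entries in the log dict, the returned dict also carries a plain f"{prefix}_{self.val_metric}" key, which is exactly the key that monitor=f"val_{metric}" looks up. A tiny check of that key arithmetic (stand-in values, not from this commit):

    prefix, val_metric = "val", "loss"
    returned_key = f"{prefix}_{val_metric}"  # key in the dict returned by validation_epoch_end
    monitored_key = f"val_{val_metric}"      # key the EarlyStopping callback watches
    assert returned_key == monitored_key == "val_loss"  # so no "_avg_" appears to be needed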
examples/seq2seq/finetune.py (view file @ e95d262f)
@@ -148,10 +148,10 @@ class SummarizationModule(BaseTransformer):
         lm_logits = outputs[0]
         if self.hparams.label_smoothing == 0:
             # Same behavior as modeling_bart.py, besides ignoring pad_token_id
-            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
+            ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
 
             assert lm_logits.shape[-1] == self.model.config.vocab_size
-            loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
+            loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
         else:
             lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
             loss, nll_loss = label_smoothed_nll_loss(
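
The rename from loss_fct to ce_loss_fct is cosmetic; the computation is unchanged. For readers new to the pattern, here is a self-contained sketch of token-level cross-entropy over flattened (batch, seq_len, vocab) logits with padding masked out via ignore_index (toy shapes, chosen for illustration):

    import torch

    pad_token_id = 0
    batch, seq_len, vocab_size = 2, 5, 11
    lm_logits = torch.randn(batch, seq_len, vocab_size)
    tgt_ids = torch.randint(1, vocab_size, (batch, seq_len))
    tgt_ids[0, -2:] = pad_token_id  # pretend the first sequence is padded

    ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
    # Flatten to (batch * seq_len, vocab) vs. (batch * seq_len,);
    # target positions equal to pad_token_id contribute nothing to the loss.
    loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
    print(loss.item())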
@@ -178,15 +178,25 @@ class SummarizationModule(BaseTransformer):
         self.step_count += 1
         losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
         loss = losses["loss"]
-        rouges = {k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]}
-        rouge_tensor: torch.FloatTensor = torch.tensor(rouges[self.val_metric]).type_as(loss)
-        rouges.update({k: v.item() for k, v in losses.items()})
-        losses.update(rouges)
-        metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
-        metrics["step_count"] = self.step_count
-        self.save_metrics(metrics, prefix)  # writes to self.metrics_save_path
+        generative_metrics = {
+            k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]
+        }
+        metric_val = (
+            generative_metrics[self.val_metric] if self.val_metric in generative_metrics else losses[self.val_metric]
+        )
+        metric_tensor: torch.FloatTensor = torch.tensor(metric_val).type_as(loss)
+        generative_metrics.update({k: v.item() for k, v in losses.items()})
+        losses.update(generative_metrics)
+        all_metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
+        all_metrics["step_count"] = self.step_count
+        self.save_metrics(all_metrics, prefix)  # writes to self.metrics_save_path
         preds = flatten_list([x["preds"] for x in outputs])
-        return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_{self.val_metric}": rouge_tensor}
+        return {
+            "log": all_metrics,
+            "preds": preds,
+            f"{prefix}_loss": loss,
+            f"{prefix}_{self.val_metric}": metric_tensor,
+        }
 
     def save_metrics(self, latest_metrics, type_path) -> None:
         self.metrics[type_path].append(latest_metrics)
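
The core of this hunk is the fallback: the monitored value is taken from the generative metrics (ROUGE/BLEU plus gen_time and gen_len) when possible, and from the loss dictionary otherwise, which is what makes val_metric="loss" work end to end. The selection logic in isolation (toy dicts and values, ours):

    def pick_metric_val(val_metric, generative_metrics, losses):
        # Prefer a generation metric; fall back to a loss such as "loss".
        return generative_metrics[val_metric] if val_metric in generative_metrics else losses[val_metric]

    generative_metrics = {"rouge2": 21.3, "gen_time": 0.4, "gen_len": 55.0}
    losses = {"loss": 2.17}
    assert pick_metric_val("rouge2", generative_metrics, losses) == 21.3
    assert pick_metric_val("loss", generative_metrics, losses) == 2.17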
@@ -306,7 +316,9 @@ class SummarizationModule(BaseTransformer):
         parser.add_argument("--src_lang", type=str, default="", required=False)
         parser.add_argument("--tgt_lang", type=str, default="", required=False)
         parser.add_argument("--eval_beams", type=int, default=None, required=False)
-        parser.add_argument("--val_metric", type=str, default=None, required=False)
+        parser.add_argument(
+            "--val_metric", type=str, default=None, required=False, choices=["bleu", "rouge2", "loss", None]
+        )
         parser.add_argument("--save_top_k", type=int, default=1, required=False, help="How many checkpoints to save")
         parser.add_argument(
             "--early_stopping_patience",
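
The choices constraint now rejects metric typos at parse time; None is presumably listed so the unset default also counts as a valid value. A user cannot actually type None on the command line (argv tokens are strings), so omitting the flag is the only way to get it. A standalone illustration with a throwaway parser:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--val_metric", type=str, default=None, required=False, choices=["bleu", "rouge2", "loss", None]
    )

    print(parser.parse_args([]))                        # Namespace(val_metric=None)
    print(parser.parse_args(["--val_metric", "loss"]))  # Namespace(val_metric='loss')
    # parser.parse_args(["--val_metric", "meteor"]) exits with:
    # error: argument --val_metric: invalid choice: 'meteor' ...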
@@ -366,14 +378,17 @@ def main(args, model=None) -> SummarizationModule:
         es_callback = get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
     else:
         es_callback = False
+    lower_is_better = args.val_metric == "loss"
     trainer: pl.Trainer = generic_train(
         model,
         args,
         logging_callback=Seq2SeqLoggingCallback(),
-        checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric, args.save_top_k),
+        checkpoint_callback=get_checkpoint_callback(
+            args.output_dir, model.val_metric, args.save_top_k, lower_is_better
+        ),
         early_stopping_callback=es_callback,
         logger=logger,
         # TODO: early stopping callback seems messed up
     )
     pickle_save(model.hparams, model.output_dir / "hparams.pkl")
     if not args.do_predict:
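
With this wiring, passing --val_metric loss (together with --early_stopping_patience) makes both callbacks minimize rather than maximize. A condensed, runnable sketch of the direction logic, with SimpleNamespace standing in for the parsed args (values are hypothetical):

    from types import SimpleNamespace

    # As if the user ran: finetune.py --val_metric loss --early_stopping_patience 4 --save_top_k 1 ...
    args = SimpleNamespace(val_metric="loss", save_top_k=1, early_stopping_patience=4)

    lower_is_better = args.val_metric == "loss"           # True; forwarded to get_checkpoint_callback
    mode = "min" if "loss" in args.val_metric else "max"  # "min"; what both callbacks derive internally
    print(lower_is_better, mode)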
examples/seq2seq/test_seq2seq_examples.py (view file @ e95d262f)
@@ -33,7 +33,7 @@ CUDA_AVAILABLE = torch.cuda.is_available()
 CHEAP_ARGS = {
     "label_smoothing": 0.2,
     "eval_beams": 1,
-    "val_metric": None,
+    "val_metric": "loss",
     "save_top_k": 1,
     "adafactor": True,
     "early_stopping_patience": 2,
@@ -262,9 +262,9 @@ class TestSummarizationDistiller(unittest.TestCase):
         if not check_contents:
             return model
         contents = os.listdir(output_dir)
-        ckpt_name = "val_avg_rouge2=0.0000-step_count=2.ckpt"  # "val_avg_rouge2=0.0000-epoch=1.ckpt"  # "epoch=1-val_avg_rouge2=0.0000.ckpt"
         contents = {os.path.basename(p) for p in contents}
-        self.assertIn(ckpt_name, contents)
+        ckpt_files = [p for p in contents if p.endswith("ckpt")]
+        assert len(ckpt_files) > 0
 
         self.assertIn("test_generations.txt", contents)
         self.assertIn("test_results.txt", contents)
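
Loosening this assertion follows from the rest of the commit: with CHEAP_ARGS now setting val_metric to "loss", checkpoints are named from the "{val_avg_loss:.4f}-{step_count}" template rather than the hard-coded rouge2 filename, so the test checks only that some checkpoint was written. The new check in isolation (toy directory listing, ours):

    contents = {"val_avg_loss=2.1700-step_count=2.ckpt", "test_results.txt"}  # toy listing
    ckpt_files = [p for p in contents if p.endswith("ckpt")]
    assert len(ckpt_files) > 0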