Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
e95d262f
Unverified
Commit
e95d262f
authored
Sep 03, 2020
by
Sam Shleifer
Committed by
GitHub
Sep 03, 2020
Browse files
[s2s] support early stopping based on loss, rather than rouge (#6927)
parent
207ed8cb
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
38 additions
and
21 deletions
+38
-21
examples/seq2seq/callbacks.py
examples/seq2seq/callbacks.py
+7
-5
examples/seq2seq/finetune.py
examples/seq2seq/finetune.py
+28
-13
examples/seq2seq/test_seq2seq_examples.py
examples/seq2seq/test_seq2seq_examples.py
+3
-3
No files found.
examples/seq2seq/callbacks.py
View file @
e95d262f
...
@@ -75,21 +75,23 @@ class Seq2SeqLoggingCallback(pl.Callback):
...
@@ -75,21 +75,23 @@ class Seq2SeqLoggingCallback(pl.Callback):
return
self
.
_write_logs
(
trainer
,
pl_module
,
"test"
)
return
self
.
_write_logs
(
trainer
,
pl_module
,
"test"
)
def
get_checkpoint_callback
(
output_dir
,
metric
,
save_top_k
=
1
):
def
get_checkpoint_callback
(
output_dir
,
metric
,
save_top_k
=
1
,
lower_is_better
=
False
):
"""Saves the best model by validation ROUGE2 score."""
"""Saves the best model by validation ROUGE2 score."""
if
metric
==
"rouge2"
:
if
metric
==
"rouge2"
:
exp
=
"{val_avg_rouge2:.4f}-{step_count}"
exp
=
"{val_avg_rouge2:.4f}-{step_count}"
elif
metric
==
"bleu"
:
elif
metric
==
"bleu"
:
exp
=
"{val_avg_bleu:.4f}-{step_count}"
exp
=
"{val_avg_bleu:.4f}-{step_count}"
elif
metric
==
"loss"
:
exp
=
"{val_avg_loss:.4f}-{step_count}"
else
:
else
:
raise
NotImplementedError
(
raise
NotImplementedError
(
f
"seq2seq callbacks only support rouge2 and
bleu
, got
{
metric
}
, You can make your own by adding to this function."
f
"seq2seq callbacks only support rouge2
, bleu
and
loss
, got
{
metric
}
, You can make your own by adding to this function."
)
)
checkpoint_callback
=
ModelCheckpoint
(
checkpoint_callback
=
ModelCheckpoint
(
filepath
=
os
.
path
.
join
(
output_dir
,
exp
),
filepath
=
os
.
path
.
join
(
output_dir
,
exp
),
monitor
=
f
"val_
{
metric
}
"
,
monitor
=
f
"val_
{
metric
}
"
,
mode
=
"max"
,
mode
=
"min"
if
"loss"
in
metric
else
"max"
,
save_top_k
=
save_top_k
,
save_top_k
=
save_top_k
,
period
=
0
,
# maybe save a checkpoint every time val is run, not just end of epoch.
period
=
0
,
# maybe save a checkpoint every time val is run, not just end of epoch.
)
)
...
@@ -98,8 +100,8 @@ def get_checkpoint_callback(output_dir, metric, save_top_k=1):
...
@@ -98,8 +100,8 @@ def get_checkpoint_callback(output_dir, metric, save_top_k=1):
def
get_early_stopping_callback
(
metric
,
patience
):
def
get_early_stopping_callback
(
metric
,
patience
):
return
EarlyStopping
(
return
EarlyStopping
(
monitor
=
f
"val_
{
metric
}
"
,
monitor
=
f
"val_
{
metric
}
"
,
# does this need avg?
mode
=
"max"
,
mode
=
"min"
if
"loss"
in
metric
else
"max"
,
patience
=
patience
,
patience
=
patience
,
verbose
=
True
,
verbose
=
True
,
)
)
examples/seq2seq/finetune.py
View file @
e95d262f
...
@@ -148,10 +148,10 @@ class SummarizationModule(BaseTransformer):
...
@@ -148,10 +148,10 @@ class SummarizationModule(BaseTransformer):
lm_logits
=
outputs
[
0
]
lm_logits
=
outputs
[
0
]
if
self
.
hparams
.
label_smoothing
==
0
:
if
self
.
hparams
.
label_smoothing
==
0
:
# Same behavior as modeling_bart.py, besides ignoring pad_token_id
# Same behavior as modeling_bart.py, besides ignoring pad_token_id
loss_fct
=
torch
.
nn
.
CrossEntropyLoss
(
ignore_index
=
pad_token_id
)
ce_
loss_fct
=
torch
.
nn
.
CrossEntropyLoss
(
ignore_index
=
pad_token_id
)
assert
lm_logits
.
shape
[
-
1
]
==
self
.
model
.
config
.
vocab_size
assert
lm_logits
.
shape
[
-
1
]
==
self
.
model
.
config
.
vocab_size
loss
=
loss_fct
(
lm_logits
.
view
(
-
1
,
lm_logits
.
shape
[
-
1
]),
tgt_ids
.
view
(
-
1
))
loss
=
ce_
loss_fct
(
lm_logits
.
view
(
-
1
,
lm_logits
.
shape
[
-
1
]),
tgt_ids
.
view
(
-
1
))
else
:
else
:
lprobs
=
torch
.
nn
.
functional
.
log_softmax
(
lm_logits
,
dim
=-
1
)
lprobs
=
torch
.
nn
.
functional
.
log_softmax
(
lm_logits
,
dim
=-
1
)
loss
,
nll_loss
=
label_smoothed_nll_loss
(
loss
,
nll_loss
=
label_smoothed_nll_loss
(
...
@@ -178,15 +178,25 @@ class SummarizationModule(BaseTransformer):
...
@@ -178,15 +178,25 @@ class SummarizationModule(BaseTransformer):
self
.
step_count
+=
1
self
.
step_count
+=
1
losses
=
{
k
:
torch
.
stack
([
x
[
k
]
for
x
in
outputs
]).
mean
()
for
k
in
self
.
loss_names
}
losses
=
{
k
:
torch
.
stack
([
x
[
k
]
for
x
in
outputs
]).
mean
()
for
k
in
self
.
loss_names
}
loss
=
losses
[
"loss"
]
loss
=
losses
[
"loss"
]
rouges
=
{
k
:
np
.
array
([
x
[
k
]
for
x
in
outputs
]).
mean
()
for
k
in
self
.
metric_names
+
[
"gen_time"
,
"gen_len"
]}
generative_metrics
=
{
rouge_tensor
:
torch
.
FloatTensor
=
torch
.
tensor
(
rouges
[
self
.
val_metric
]).
type_as
(
loss
)
k
:
np
.
array
([
x
[
k
]
for
x
in
outputs
]).
mean
()
for
k
in
self
.
metric_names
+
[
"gen_time"
,
"gen_len"
]
rouges
.
update
({
k
:
v
.
item
()
for
k
,
v
in
losses
.
items
()})
}
losses
.
update
(
rouges
)
metric_val
=
(
metrics
=
{
f
"
{
prefix
}
_avg_
{
k
}
"
:
x
for
k
,
x
in
losses
.
items
()}
generative_metrics
[
self
.
val_metric
]
if
self
.
val_metric
in
generative_metrics
else
losses
[
self
.
val_metric
]
metrics
[
"step_count"
]
=
self
.
step_count
)
self
.
save_metrics
(
metrics
,
prefix
)
# writes to self.metrics_save_path
metric_tensor
:
torch
.
FloatTensor
=
torch
.
tensor
(
metric_val
).
type_as
(
loss
)
generative_metrics
.
update
({
k
:
v
.
item
()
for
k
,
v
in
losses
.
items
()})
losses
.
update
(
generative_metrics
)
all_metrics
=
{
f
"
{
prefix
}
_avg_
{
k
}
"
:
x
for
k
,
x
in
losses
.
items
()}
all_metrics
[
"step_count"
]
=
self
.
step_count
self
.
save_metrics
(
all_metrics
,
prefix
)
# writes to self.metrics_save_path
preds
=
flatten_list
([
x
[
"preds"
]
for
x
in
outputs
])
preds
=
flatten_list
([
x
[
"preds"
]
for
x
in
outputs
])
return
{
"log"
:
metrics
,
"preds"
:
preds
,
f
"
{
prefix
}
_loss"
:
loss
,
f
"
{
prefix
}
_
{
self
.
val_metric
}
"
:
rouge_tensor
}
return
{
"log"
:
all_metrics
,
"preds"
:
preds
,
f
"
{
prefix
}
_loss"
:
loss
,
f
"
{
prefix
}
_
{
self
.
val_metric
}
"
:
metric_tensor
,
}
def
save_metrics
(
self
,
latest_metrics
,
type_path
)
->
None
:
def
save_metrics
(
self
,
latest_metrics
,
type_path
)
->
None
:
self
.
metrics
[
type_path
].
append
(
latest_metrics
)
self
.
metrics
[
type_path
].
append
(
latest_metrics
)
...
@@ -306,7 +316,9 @@ class SummarizationModule(BaseTransformer):
...
@@ -306,7 +316,9 @@ class SummarizationModule(BaseTransformer):
parser
.
add_argument
(
"--src_lang"
,
type
=
str
,
default
=
""
,
required
=
False
)
parser
.
add_argument
(
"--src_lang"
,
type
=
str
,
default
=
""
,
required
=
False
)
parser
.
add_argument
(
"--tgt_lang"
,
type
=
str
,
default
=
""
,
required
=
False
)
parser
.
add_argument
(
"--tgt_lang"
,
type
=
str
,
default
=
""
,
required
=
False
)
parser
.
add_argument
(
"--eval_beams"
,
type
=
int
,
default
=
None
,
required
=
False
)
parser
.
add_argument
(
"--eval_beams"
,
type
=
int
,
default
=
None
,
required
=
False
)
parser
.
add_argument
(
"--val_metric"
,
type
=
str
,
default
=
None
,
required
=
False
)
parser
.
add_argument
(
"--val_metric"
,
type
=
str
,
default
=
None
,
required
=
False
,
choices
=
[
"bleu"
,
"rouge2"
,
"loss"
,
None
]
)
parser
.
add_argument
(
"--save_top_k"
,
type
=
int
,
default
=
1
,
required
=
False
,
help
=
"How many checkpoints to save"
)
parser
.
add_argument
(
"--save_top_k"
,
type
=
int
,
default
=
1
,
required
=
False
,
help
=
"How many checkpoints to save"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--early_stopping_patience"
,
"--early_stopping_patience"
,
...
@@ -366,14 +378,17 @@ def main(args, model=None) -> SummarizationModule:
...
@@ -366,14 +378,17 @@ def main(args, model=None) -> SummarizationModule:
es_callback
=
get_early_stopping_callback
(
model
.
val_metric
,
args
.
early_stopping_patience
)
es_callback
=
get_early_stopping_callback
(
model
.
val_metric
,
args
.
early_stopping_patience
)
else
:
else
:
es_callback
=
False
es_callback
=
False
lower_is_better
=
args
.
val_metric
==
"loss"
trainer
:
pl
.
Trainer
=
generic_train
(
trainer
:
pl
.
Trainer
=
generic_train
(
model
,
model
,
args
,
args
,
logging_callback
=
Seq2SeqLoggingCallback
(),
logging_callback
=
Seq2SeqLoggingCallback
(),
checkpoint_callback
=
get_checkpoint_callback
(
args
.
output_dir
,
model
.
val_metric
,
args
.
save_top_k
),
checkpoint_callback
=
get_checkpoint_callback
(
args
.
output_dir
,
model
.
val_metric
,
args
.
save_top_k
,
lower_is_better
),
early_stopping_callback
=
es_callback
,
early_stopping_callback
=
es_callback
,
logger
=
logger
,
logger
=
logger
,
# TODO: early stopping callback seems messed up
)
)
pickle_save
(
model
.
hparams
,
model
.
output_dir
/
"hparams.pkl"
)
pickle_save
(
model
.
hparams
,
model
.
output_dir
/
"hparams.pkl"
)
if
not
args
.
do_predict
:
if
not
args
.
do_predict
:
...
...
examples/seq2seq/test_seq2seq_examples.py
View file @
e95d262f
...
@@ -33,7 +33,7 @@ CUDA_AVAILABLE = torch.cuda.is_available()
...
@@ -33,7 +33,7 @@ CUDA_AVAILABLE = torch.cuda.is_available()
CHEAP_ARGS
=
{
CHEAP_ARGS
=
{
"label_smoothing"
:
0.2
,
"label_smoothing"
:
0.2
,
"eval_beams"
:
1
,
"eval_beams"
:
1
,
"val_metric"
:
None
,
"val_metric"
:
"loss"
,
"save_top_k"
:
1
,
"save_top_k"
:
1
,
"adafactor"
:
True
,
"adafactor"
:
True
,
"early_stopping_patience"
:
2
,
"early_stopping_patience"
:
2
,
...
@@ -262,9 +262,9 @@ class TestSummarizationDistiller(unittest.TestCase):
...
@@ -262,9 +262,9 @@ class TestSummarizationDistiller(unittest.TestCase):
if
not
check_contents
:
if
not
check_contents
:
return
model
return
model
contents
=
os
.
listdir
(
output_dir
)
contents
=
os
.
listdir
(
output_dir
)
ckpt_name
=
"val_avg_rouge2=0.0000-step_count=2.ckpt"
# "val_avg_rouge2=0.0000-epoch=1.ckpt" # "epoch=1-val_avg_rouge2=0.0000.ckpt"
contents
=
{
os
.
path
.
basename
(
p
)
for
p
in
contents
}
contents
=
{
os
.
path
.
basename
(
p
)
for
p
in
contents
}
self
.
assertIn
(
ckpt_name
,
contents
)
ckpt_files
=
[
p
for
p
in
contents
if
p
.
endswith
(
"ckpt"
)]
assert
len
(
ckpt_files
)
>
0
self
.
assertIn
(
"test_generations.txt"
,
contents
)
self
.
assertIn
(
"test_generations.txt"
,
contents
)
self
.
assertIn
(
"test_results.txt"
,
contents
)
self
.
assertIn
(
"test_results.txt"
,
contents
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment