Unverified commit 39ed68d5, authored by Sam Shleifer, committed by GitHub

[s2s] allow task_specific_params=summarization_xsum (#6923)

parent 5a318f07
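In short: finetune.py's task check becomes a substring match, so task names such as summarization_xsum (presumably matching a task_specific_params entry of that name in a model config) still select SummarizationModule, and a new --save_top_k flag is threaded through to the checkpoint callback. For context, a minimal sketch of the task_specific_params mechanism in transformers (the model name is only an example, not part of this commit):

    from transformers import AutoConfig

    # Many summarization checkpoints ship per-task generation defaults.
    config = AutoConfig.from_pretrained("facebook/bart-large-cnn")
    task_params = config.task_specific_params or {}
    # A task name like "summarization_xsum" would look up an entry of that
    # name; this commit itself only relaxes the module-routing check.
    print(task_params.get("summarization_xsum", task_params.get("summarization")))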
@@ -75,7 +75,7 @@ class Seq2SeqLoggingCallback(pl.Callback):
         return self._write_logs(trainer, pl_module, "test")


-def get_checkpoint_callback(output_dir, metric):
+def get_checkpoint_callback(output_dir, metric, save_top_k=1):
     """Saves the best model by validation ROUGE2 score."""
     if metric == "rouge2":
         exp = "{val_avg_rouge2:.4f}-{step_count}"
@@ -90,7 +90,7 @@ def get_checkpoint_callback(output_dir, metric):
         filepath=os.path.join(output_dir, exp),
         monitor=f"val_{metric}",
         mode="max",
-        save_top_k=1,
+        save_top_k=save_top_k,
         period=0,  # maybe save a checkpoint every time val is run, not just end of epoch.
     )
     return checkpoint_callback
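The helper change simply forwards the new parameter to PyTorch Lightning's ModelCheckpoint. A standalone sketch of the effect, written against the Lightning 0.x API this file already uses (filepath-style checkpoint naming; the output directory is hypothetical):

    import os
    from pytorch_lightning.callbacks import ModelCheckpoint

    output_dir = "outputs"  # hypothetical directory
    exp = "{val_avg_rouge2:.4f}-{step_count}"

    # save_top_k=3 keeps the three best checkpoints ranked by val_rouge2,
    # instead of the single best one the old hard-coded value allowed.
    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(output_dir, exp),
        monitor="val_rouge2",
        mode="max",
        save_top_k=3,
    )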
@@ -306,6 +306,7 @@ class SummarizationModule(BaseTransformer):
         parser.add_argument("--tgt_lang", type=str, default="", required=False)
         parser.add_argument("--eval_beams", type=int, default=None, required=False)
         parser.add_argument("--val_metric", type=str, default=None, required=False)
+        parser.add_argument("--save_top_k", type=int, default=1, required=False, help="How many checkpoints to save")
         parser.add_argument(
             "--early_stopping_patience",
             type=int,
@@ -336,7 +337,7 @@ def main(args, model=None) -> SummarizationModule:
     if len(os.listdir(args.output_dir)) > 3 and args.do_train:
         raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
     if model is None:
-        if args.task == "summarization":
+        if "summarization" in args.task:
             model: SummarizationModule = SummarizationModule(args)
         else:
             model: SummarizationModule = TranslationModule(args)
@@ -368,7 +369,7 @@ def main(args, model=None) -> SummarizationModule:
         model,
         args,
         logging_callback=Seq2SeqLoggingCallback(),
-        checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
+        checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric, args.save_top_k),
         early_stopping_callback=es_callback,
         logger=logger,
         # TODO: early stopping callback seems messed up
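The old equality test rejected any task name other than the exact string "summarization"; the substring test accepts variants like summarization_xsum while still sending translation tasks to TranslationModule. A tiny sketch of the new dispatch (plain strings stand in for the real module classes):

    def pick_module(task: str) -> str:
        # Matches "summarization", "summarization_xsum", "summarization_cnn", ...
        if "summarization" in task:
            return "SummarizationModule"
        return "TranslationModule"

    assert pick_module("summarization_xsum") == "SummarizationModule"
    assert pick_module("translation_en_ro") == "TranslationModule"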
@@ -34,6 +34,7 @@ CHEAP_ARGS = {
     "label_smoothing": 0.2,
     "eval_beams": 1,
     "val_metric": None,
+    "save_top_k": 1,
     "adafactor": True,
     "early_stopping_patience": 2,
     "logger_name": "default",