Unverified Commit b9772897 authored by Sam Shleifer, committed by GitHub

[s2s] command line args for faster val steps (#6833)

parent 8af1970e
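Summary of the diff below: the class attribute val_metric is renamed to default_val_metric in the summarization, translation, and distillation modules, and two optional command line flags are added, --eval_beams and --val_metric, so a run can shrink the validation beam count and swap the validation metric without code changes. The cheap test args gain matching entries so the example tests exercise the new path.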
@@ -262,7 +262,7 @@ class BartTranslationDistiller(BartSummarizationDistiller):
     mode = "translation"
     metric_names = ["bleu"]
-    val_metric = "bleu"
+    default_val_metric = "bleu"
 
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
...
@@ -63,7 +63,7 @@ class SummarizationModule(BaseTransformer):
     mode = "summarization"
     loss_names = ["loss"]
     metric_names = ROUGE_KEYS
-    val_metric = "rouge2"
+    default_val_metric = "rouge2"
 
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
@@ -110,6 +110,9 @@ class SummarizationModule(BaseTransformer):
         self.dataset_class = (
             Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
         )
+        self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
+        assert self.eval_beams >= 1, f"got self.eval_beams={self.eval_beams}. Need an integer >= 1"
+        self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric
 
     def freeze_embeds(self):
         """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
@@ -301,6 +304,8 @@ class SummarizationModule(BaseTransformer):
         parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
         parser.add_argument("--src_lang", type=str, default="", required=False)
         parser.add_argument("--tgt_lang", type=str, default="", required=False)
+        parser.add_argument("--eval_beams", type=int, default=None, required=False)
+        parser.add_argument("--val_metric", type=str, default=None, required=False)
         parser.add_argument(
             "--early_stopping_patience",
             type=int,
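Both flags default to None, so omitting them preserves the old behaviour. A runnable sketch of just these two additions, attached to a bare parser rather than the script's real add_model_specific_args parser (the bare parser is an assumption for illustration):

# Standalone sketch: only the two arguments added in this diff.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--eval_beams", type=int, default=None, required=False)
parser.add_argument("--val_metric", type=str, default=None, required=False)

args = parser.parse_args(["--eval_beams", "2", "--val_metric", "bleu"])
print(args.eval_beams, args.val_metric)  # 2 bleu

args = parser.parse_args([])
print(args.eval_beams, args.val_metric)  # None None -> fall back to defaults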
@@ -315,7 +320,7 @@ class TranslationModule(SummarizationModule):
     mode = "translation"
     loss_names = ["loss"]
     metric_names = ["bleu"]
-    val_metric = "bleu"
+    default_val_metric = "bleu"
 
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
...
@@ -31,6 +31,8 @@ logger = logging.getLogger()
 CUDA_AVAILABLE = torch.cuda.is_available()
 CHEAP_ARGS = {
     "label_smoothing": 0.2,
+    "eval_beams": 1,
+    "val_metric": None,
     "adafactor": True,
     "early_stopping_patience": 2,
     "logger_name": "default",
...
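The test fixture pins eval_beams to 1, the cheapest valid setting, and leaves val_metric at None so each module keeps its default_val_metric. A hedged sketch of how such a dict is typically turned into hparams for a fast test; the make_test_hparams helper below is an assumption for illustration, not code from this diff:

# Hypothetical test helper: converts the CHEAP_ARGS dict into the
# argparse.Namespace the modules expect.
from argparse import Namespace

CHEAP_ARGS = {
    "label_smoothing": 0.2,
    "eval_beams": 1,      # a single beam keeps validation generation fast
    "val_metric": None,   # None -> each module's default_val_metric
    "adafactor": True,
    "early_stopping_patience": 2,
    "logger_name": "default",
}

def make_test_hparams(**overrides):
    """Build hparams for a fast test run; any key can be overridden."""
    return Namespace(**{**CHEAP_ARGS, **overrides})

hparams = make_test_hparams(val_metric="bleu")
print(hparams.eval_beams, hparams.val_metric)  # 1 bleu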