Unverified commit 0344428f, authored by Sam Shleifer, committed by GitHub

[s2s] round bleu, rouge to 4 digits (#6704)

parent b6512d23
@@ -20,7 +20,7 @@ try:
     from .utils import (
         any_requires_grad,
         assert_all_frozen,
-        calculate_bleu_score,
+        calculate_bleu,
         freeze_params,
         pickle_load,
         use_task_specific_params,
@@ -32,7 +32,7 @@ except ImportError:
     from utils import (
         any_requires_grad,
         assert_all_frozen,
-        calculate_bleu_score,
+        calculate_bleu,
         freeze_params,
         pickle_load,
         use_task_specific_params,
@@ -261,7 +261,7 @@ class BartTranslationDistiller(BartSummarizationDistiller):
         self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]

     def calc_generative_metrics(self, preds, target) -> dict:
-        return calculate_bleu_score(preds, target)
+        return calculate_bleu(preds, target)

     @staticmethod
     def add_model_specific_args(parser, root_dir):
...
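The try/except pair touched in the first two hunks is a dual-import pattern: the relative import resolves when the file is loaded as part of a package, and the absolute import is the fallback when it is executed directly as a script. A minimal sketch of the idea, assuming a sibling utils module that exports calculate_bleu:

try:
    from .utils import calculate_bleu  # resolves when imported as part of a package
except ImportError:
    from utils import calculate_bleu  # fallback when the file is run directly as a script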
@@ -23,7 +23,7 @@ try:
         Seq2SeqDataset,
         TranslationDataset,
         assert_all_frozen,
-        calculate_bleu_score,
+        calculate_bleu,
         calculate_rouge,
         flatten_list,
         freeze_params,
@@ -42,7 +42,7 @@ except ImportError:
         Seq2SeqDataset,
         TranslationDataset,
         assert_all_frozen,
-        calculate_bleu_score,
+        calculate_bleu,
         calculate_rouge,
         flatten_list,
         freeze_params,
@@ -325,7 +325,7 @@ class TranslationModule(SummarizationModule):
         self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang

     def calc_generative_metrics(self, preds, target) -> dict:
-        return calculate_bleu_score(preds, target)
+        return calculate_bleu(preds, target)

 def main(args, model=None) -> SummarizationModule:
...
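Both renamed call sites override the same hook, calc_generative_metrics, so that translation models report BLEU where the summarization base class reports ROUGE. A hypothetical stand-in sketch (the base class body is not shown in this diff, so its ROUGE behavior is an assumption here):

class SummarizationSketch:
    # Stand-in for SummarizationModule; assumed to report ROUGE.
    def calc_generative_metrics(self, preds, target) -> dict:
        return calculate_rouge(preds, target)


class TranslationSketch(SummarizationSketch):
    # Mirrors the diff: the translation subclass swaps the metric to BLEU.
    def calc_generative_metrics(self, preds, target) -> dict:
        return calculate_bleu(preds, target)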
@@ -9,9 +9,9 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

 try:
-    from .utils import calculate_bleu_score, calculate_rouge, trim_batch, use_task_specific_params
+    from .utils import calculate_bleu, calculate_rouge, trim_batch, use_task_specific_params
 except ImportError:
-    from utils import calculate_bleu_score, calculate_rouge, trim_batch, use_task_specific_params
+    from utils import calculate_bleu, calculate_rouge, trim_batch, use_task_specific_params

 DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -103,7 +103,7 @@ def run_generate():
     if args.reference_path is None:
         return

     # Compute scores
-    score_fn = calculate_bleu_score if "translation" in args.task else calculate_rouge
+    score_fn = calculate_bleu if "translation" in args.task else calculate_rouge
     output_lns = [x.rstrip() for x in open(args.save_path).readlines()]
     reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()][: len(output_lns)]
     scores: dict = score_fn(output_lns, reference_lns)
...
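The score_fn line dispatches on the task name: BLEU for translation tasks, ROUGE for everything else. A hedged usage sketch, with made-up strings standing in for the prediction and reference files:

output_lns = ["the cat sat on the mat"]     # stand-in for args.save_path contents
reference_lns = ["the cat sat on the mat"]  # stand-in for args.reference_path contents

task = "translation_en_to_ro"  # hypothetical task name
score_fn = calculate_bleu if "translation" in task else calculate_rouge
scores: dict = score_fn(output_lns, reference_lns)
print(scores)  # identical lines score perfectly, e.g. {'bleu': 100.0}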
@@ -57,9 +57,9 @@ def lmap(f: Callable, x: Iterable) -> List:
     return list(map(f, x))

-def calculate_bleu_score(output_lns, refs_lns, **kwargs) -> dict:
+def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict:
     """Uses sacrebleu's corpus_bleu implementation."""
-    return {"bleu": corpus_bleu(output_lns, [refs_lns], **kwargs).score}
+    return {"bleu": round(corpus_bleu(output_lns, [refs_lns], **kwargs).score, 4)}

 def trim_batch(
@@ -271,7 +271,7 @@ def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer
         aggregator.add_scores(scores)

     result = aggregator.aggregate()
-    return {k: v.mid.fmeasure * 100 for k, v in result.items()}
+    return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}

 def freeze_params(model: nn.Module):
...
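Putting the two changed helpers together, here is a self-contained sketch of the post-commit metric functions. It assumes sacrebleu and rouge_score are installed; the RougeScorer setup is reconstructed from that library's documented API rather than from lines shown in this diff:

from typing import List

from rouge_score import rouge_scorer, scoring
from sacrebleu import corpus_bleu


def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict:
    """Uses sacrebleu's corpus_bleu implementation, rounded to 4 digits."""
    return {"bleu": round(corpus_bleu(output_lns, [refs_lns], **kwargs).score, 4)}


def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer=True) -> dict:
    """Corpus-level ROUGE via bootstrap aggregation, scaled to 0-100 and rounded."""
    rouge_keys = ["rouge1", "rouge2", "rougeL"]
    scorer = rouge_scorer.RougeScorer(rouge_keys, use_stemmer=use_stemmer)
    aggregator = scoring.BootstrapAggregator()
    for reference_ln, output_ln in zip(reference_lns, output_lns):
        aggregator.add_scores(scorer.score(reference_ln, output_ln))
    result = aggregator.aggregate()
    # .mid is the middle of each metric's bootstrap confidence interval
    return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}

Since BLEU and ROUGE values are already on a 0-100 scale, four decimal places keep logged metrics compact without discarding meaningful precision.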