Unverified Commit ae6834e0, authored by Patrick von Platen and committed by GitHub

[Examples] Clean summarization and translation example testing files for T5 and Bart (#3514)

* fix conflicts

* add model size argument to summarization

* correct wrong import

* fix isort

* correct imports

* other isort make style

* make style
parent 0373b60c
BART summarization script (evaluate_cnn.py):

@@ -45,7 +45,7 @@ def generate_summaries(
         fout.flush()
 
 
-def _run_generate():
+def run_generate():
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "source_path", type=str, help="like cnn_dm/test.source",
@@ -68,4 +68,4 @@ def _run_generate():
 
 
 if __name__ == "__main__":
-    _run_generate()
+    run_generate()
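
After the rename, the BART evaluation script exposes a public run_generate() entry point that the test below imports directly. A minimal, hedged sketch of driving it the same way the test does, by faking the command line; the source path and output name are placeholders, and the import assumes evaluate_cnn.py is importable from the working directory:

import sys
from unittest.mock import patch

from evaluate_cnn import run_generate  # public name after this change

# argument order taken from the test below: source file, output file, model name
cli_args = ["evaluate_cnn.py", "cnn_dm/test.source", "bart_summaries.hypo", "sshleifer/bart-tiny-random"]
with patch.object(sys, "argv", cli_args):
    run_generate()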
BART summarization test file (class TestBartExamples):

 import logging
+import os
 import sys
 import tempfile
 import unittest
 from pathlib import Path
 from unittest.mock import patch
 
-from .evaluate_cnn import _run_generate
+from .evaluate_cnn import run_generate
 
-output_file_name = "output_bart_sum.txt"
 
 articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
 
 logging.basicConfig(level=logging.DEBUG)

@@ -26,8 +23,10 @@ class TestBartExamples(unittest.TestCase):
         with tmp.open("w") as f:
             f.write("\n".join(articles))
-        testargs = ["evaluate_cnn.py", str(tmp), output_file_name, "sshleifer/bart-tiny-random"]
+        output_file_name = Path(tempfile.gettempdir()) / "utest_output_bart_sum.hypo"
+        testargs = ["evaluate_cnn.py", str(tmp), str(output_file_name), "sshleifer/bart-tiny-random"]
         with patch.object(sys, "argv", testargs):
-            _run_generate()
+            run_generate()
             self.assertTrue(Path(output_file_name).exists())
+            os.remove(Path(output_file_name))
T5 summarization script (evaluate_cnn.py):

@@ -64,7 +64,7 @@ def run_generate():
     parser.add_argument(
         "model_size",
         type=str,
-        help="T5 model size, either 't5-small', 't5-base' or 't5-large'. Defaults to base.",
+        help="T5 model size, either 't5-small', 't5-base', 't5-large', 't5-3b', 't5-11b'. Defaults to 't5-base'.",
         default="t5-base",
     )
     parser.add_argument(
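
One argparse detail worth keeping in mind when reading this hunk: for a positional argument, a default value only takes effect if nargs="?" (or "*") is also given; otherwise the argument stays required and the default is never used. A small, self-contained sketch; the names here are illustrative, not the example's actual code:

import argparse

parser = argparse.ArgumentParser()
# nargs="?" makes the positional optional so the default can apply
parser.add_argument("model_size", type=str, nargs="?", default="t5-base",
                    help="T5 checkpoint, e.g. 't5-small' or 't5-base'")

print(parser.parse_args([]).model_size)            # prints: t5-base
print(parser.parse_args(["t5-small"]).model_size)  # prints: t5-small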
T5 summarization test file (class TestT5Examples):

 import logging
+import os
 import sys
 import tempfile
 import unittest

@@ -26,10 +25,13 @@ class TestT5Examples(unittest.TestCase):
         tmp = Path(tempfile.gettempdir()) / "utest_generations_t5_sum.hypo"
         with tmp.open("w") as f:
             f.write("\n".join(articles))
-        testargs = ["evaluate_cnn.py", "t5-small", str(tmp), output_file_name, str(tmp), score_file_name]
+        output_file_name = Path(tempfile.gettempdir()) / "utest_output_t5_sum.hypo"
+        score_file_name = Path(tempfile.gettempdir()) / "utest_score_t5_sum.hypo"
+        testargs = ["evaluate_cnn.py", "t5-small", str(tmp), str(output_file_name), str(tmp), str(score_file_name)]
         with patch.object(sys, "argv", testargs):
             run_generate()
             self.assertTrue(Path(output_file_name).exists())
             self.assertTrue(Path(score_file_name).exists())
+            os.remove(Path(output_file_name))
+            os.remove(Path(score_file_name))
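
The test above shows the argument order the T5 summarization script expects: model size first, then the source file, the summary output path, what appears to be the reference file, and the score file. A hedged sketch of an equivalent standalone invocation; all paths are placeholders, and the import assumes the script is importable from the working directory:

import sys
from unittest.mock import patch

from evaluate_cnn import run_generate

cli_args = [
    "evaluate_cnn.py",
    "t5-base",               # model_size: t5-small / t5-base / t5-large / t5-3b / t5-11b
    "cnn_dm/test.source",    # articles to summarize
    "t5_summaries.hypo",     # generated summaries are written here
    "cnn_dm/test.target",    # reference summaries
    "t5_scores.txt",         # score file
]
with patch.object(sys, "argv", cli_args):
    run_generate()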
T5 translation script (evaluate_wmt.py):

@@ -14,13 +14,13 @@ def chunks(lst, n):
         yield lst[i : i + n]
 
 
-def generate_translations(lns, output_file_path, batch_size, device):
+def generate_translations(lns, output_file_path, model_size, batch_size, device):
     output_file = Path(output_file_path).open("w")
 
-    model = T5ForConditionalGeneration.from_pretrained("t5-base")
+    model = T5ForConditionalGeneration.from_pretrained(model_size)
     model.to(device)
 
-    tokenizer = T5Tokenizer.from_pretrained("t5-base")
+    tokenizer = T5Tokenizer.from_pretrained(model_size)
 
     # update config with summarization specific params
     task_specific_params = model.config.task_specific_params
@@ -52,6 +52,12 @@ def calculate_bleu_score(output_lns, refs_lns, score_path):
 
 def run_generate():
     parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "model_size",
+        type=str,
+        help="T5 model size, either 't5-small', 't5-base', 't5-large', 't5-3b', 't5-11b'. Defaults to 't5-base'.",
+        default="t5-base",
+    )
     parser.add_argument(
         "input_path", type=str, help="like wmt/newstest2013.en",
     )
@@ -78,7 +84,7 @@ def run_generate():
     input_lns = [x.strip().replace(dash_pattern[0], dash_pattern[1]) for x in open(args.input_path).readlines()]
 
-    generate_translations(input_lns, args.output_path, args.batch_size, args.device)
+    generate_translations(input_lns, args.output_path, args.model_size, args.batch_size, args.device)
 
     output_lns = [x.strip() for x in open(args.output_path).readlines()]
     refs_lns = [x.strip().replace(dash_pattern[0], dash_pattern[1]) for x in open(args.reference_path).readlines()]
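
Because generate_translations() now takes the checkpoint name as a parameter, it can also be called directly with any T5 size. A hedged sketch under the assumption that the function handles tokenization and batching internally, as the diff suggests; the input line and output path are placeholders:

import torch
from evaluate_wmt import generate_translations

device = "cuda" if torch.cuda.is_available() else "cpu"
lines = ["The house is wonderful."]  # placeholder input sentence
generate_translations(lines, "translations.hypo", "t5-small", batch_size=8, device=device)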
T5 translation test file (class TestT5Examples):

 import logging
+import os
 import sys
 import tempfile
 import unittest

@@ -33,11 +32,19 @@ class TestT5Examples(unittest.TestCase):
         with tmp_target.open("w") as f:
             f.write("\n".join(translation))
-        testargs = ["evaluate_wmt.py", str(tmp_source), output_file_name, str(tmp_target), score_file_name]
+        output_file_name = Path(tempfile.gettempdir()) / "utest_output_trans.hypo"
+        score_file_name = Path(tempfile.gettempdir()) / "utest_score.hypo"
+        testargs = [
+            "evaluate_wmt.py",
+            "t5-small",
+            str(tmp_source),
+            str(output_file_name),
+            str(tmp_target),
+            str(score_file_name),
+        ]
         with patch.object(sys, "argv", testargs):
             run_generate()
             self.assertTrue(Path(output_file_name).exists())
             self.assertTrue(Path(score_file_name).exists())
+            os.remove(Path(output_file_name))
+            os.remove(Path(score_file_name))
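
Both test files are plain unittest.TestCase classes, so besides pytest they can also be collected and run programmatically. A minimal, hedged sketch; the start directory and file pattern are assumptions about the example layout, not something fixed by this diff:

import unittest

suite = unittest.defaultTestLoader.discover("examples", pattern="test_*.py")
unittest.TextTestRunner(verbosity=2).run(suite)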