"...git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "033cc4cfad07e8467ca6cb9e4401752c9c05d32c"
Unverified commit ff80b731, authored by Lysandre Debut and committed by GitHub
Browse files

Add option to choose T5 model size. (#3480)

T5-small in test


isort
parent e2c05f06
...@@ -14,13 +14,13 @@ def chunks(lst, n): ...@@ -14,13 +14,13 @@ def chunks(lst, n):
yield lst[i : i + n] yield lst[i : i + n]
def generate_summaries(lns, output_file_path, batch_size, device): def generate_summaries(lns, output_file_path, model_size, batch_size, device):
output_file = Path(output_file_path).open("w") output_file = Path(output_file_path).open("w")
model = T5ForConditionalGeneration.from_pretrained("t5-large") model = T5ForConditionalGeneration.from_pretrained(model_size)
model.to(device) model.to(device)
tokenizer = T5Tokenizer.from_pretrained("t5-large") tokenizer = T5Tokenizer.from_pretrained(model_size)
# update config with summarization specific params # update config with summarization specific params
task_specific_params = model.config.task_specific_params task_specific_params = model.config.task_specific_params
...@@ -61,6 +61,12 @@ def calculate_rouge(output_lns, reference_lns, score_path): ...@@ -61,6 +61,12 @@ def calculate_rouge(output_lns, reference_lns, score_path):
def run_generate(): def run_generate():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument(
"model_size",
type=str,
help="T5 model size, either 't5-small', 't5-base' or 't5-large'. Defaults to base.",
default="t5-base",
)
parser.add_argument( parser.add_argument(
"input_path", type=str, help="like cnn_dm/test_articles_input.txt", "input_path", type=str, help="like cnn_dm/test_articles_input.txt",
) )
...@@ -83,7 +89,7 @@ def run_generate(): ...@@ -83,7 +89,7 @@ def run_generate():
source_lns = [x.rstrip() for x in open(args.input_path).readlines()] source_lns = [x.rstrip() for x in open(args.input_path).readlines()]
generate_summaries(source_lns, args.output_path, args.batch_size, args.device) generate_summaries(source_lns, args.output_path, args.model_size, args.batch_size, args.device)
output_lns = [x.rstrip() for x in open(args.output_path).readlines()] output_lns = [x.rstrip() for x in open(args.output_path).readlines()]
reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()] reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()]
......
...@@ -22,7 +22,7 @@ class TestT5Examples(unittest.TestCase): ...@@ -22,7 +22,7 @@ class TestT5Examples(unittest.TestCase):
tmp = Path(tempfile.gettempdir()) / "utest_generations.hypo" tmp = Path(tempfile.gettempdir()) / "utest_generations.hypo"
with tmp.open("w") as f: with tmp.open("w") as f:
f.write("\n".join(articles)) f.write("\n".join(articles))
testargs = ["evaluate_cnn.py", str(tmp), "output.txt", str(tmp), "score.txt"] testargs = ["evaluate_cnn.py", "t5-small", str(tmp), "output.txt", str(tmp), "score.txt"]
with patch.object(sys, "argv", testargs): with patch.object(sys, "argv", testargs):
run_generate() run_generate()
self.assertTrue(Path("output.txt").exists()) self.assertTrue(Path("output.txt").exists())
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment