Unverified Commit 61518e2d authored by Sam Shleifer, committed by GitHub

[s2s] run_eval.py QOL improvements and cleanup (#6746)

parent 434936f3
run_eval.py:

```diff
 import argparse
 import json
+import time
+import warnings
+from logging import getLogger
 from pathlib import Path
+from typing import Dict, List

 import torch
 from tqdm import tqdm
@@ -8,10 +12,12 @@ from tqdm import tqdm
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

+logger = getLogger(__name__)
+
 try:
-    from .utils import calculate_bleu, calculate_rouge, trim_batch, use_task_specific_params
+    from .utils import calculate_bleu, calculate_rouge, use_task_specific_params
 except ImportError:
-    from utils import calculate_bleu, calculate_rouge, trim_batch, use_task_specific_params
+    from utils import calculate_bleu, calculate_rouge, use_task_specific_params

 DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
```
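A note on the dropped `trim_batch` import: it was only needed to strip the all-padding columns that `padding="max_length"` produced, and the generation loop below now tokenizes with `padding="longest"` instead. A minimal sketch of the difference, using an arbitrary example checkpoint:

```python
# Sketch: padding="longest" pads a batch only to its longest member, so there are
# no all-pad columns left to trim; padding="max_length" pads to the model maximum.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("facebook/bart-base")  # arbitrary example checkpoint
texts = ["a short input", "a noticeably longer input with several more tokens in it"]

longest = tok(texts, return_tensors="pt", padding="longest")
max_len = tok(texts, return_tensors="pt", padding="max_length", truncation=True)

print(longest.input_ids.shape)  # (2, length of the longest example)
print(max_len.input_ids.shape)  # (2, tok.model_max_length)
```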
```diff
@@ -23,7 +29,7 @@ def chunks(lst, n):
 def generate_summaries_or_translations(
-    examples: list,
+    examples: List[str],
     out_file: str,
     model_name: str,
     batch_size: int = 8,
@@ -31,36 +37,39 @@ def generate_summaries_or_translations(
     fp16=False,
     task="summarization",
     decoder_start_token_id=None,
-    **gen_kwargs,
-) -> None:
+    **generate_kwargs,
+) -> Dict:
+    """Save model.generate results to <out_file>, and return how long it took."""
     fout = Path(out_file).open("w", encoding="utf-8")
     model_name = str(model_name)
     model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
     if fp16:
         model = model.half()
-    if decoder_start_token_id is None:
-        decoder_start_token_id = gen_kwargs.pop("decoder_start_token_id", None)
     tokenizer = AutoTokenizer.from_pretrained(model_name)
+    logger.info(f"Inferred tokenizer type: {tokenizer.__class__}")  # if this is wrong, check config.model_type.

-    # update config with summarization specific params
+    start_time = time.time()
+    # update config with task specific params
     use_task_specific_params(model, task)
-    for batch in tqdm(list(chunks(examples, batch_size))):
+    for examples_chunk in tqdm(list(chunks(examples, batch_size))):
         if "t5" in model_name:
-            batch = [model.config.prefix + text for text in batch]
-        batch = tokenizer(batch, return_tensors="pt", truncation=True, padding="max_length").to(device)
-        input_ids, attention_mask = trim_batch(**batch, pad_token_id=tokenizer.pad_token_id)
+            examples_chunk = [model.config.prefix + text for text in examples_chunk]
+        batch = tokenizer(examples_chunk, return_tensors="pt", truncation=True, padding="longest").to(device)
         summaries = model.generate(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
+            input_ids=batch.input_ids,
+            attention_mask=batch.attention_mask,
             decoder_start_token_id=decoder_start_token_id,
-            **gen_kwargs,
+            **generate_kwargs,
         )
         dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
         for hypothesis in dec:
             fout.write(hypothesis + "\n")
             fout.flush()
+    fout.close()
+    runtime = time.time() - start_time
+    n_obs = len(examples)
+    return dict(n_obs=n_obs, runtime=runtime, seconds_per_sample=round(runtime / n_obs, 4))
```
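The helper now returns timing metadata instead of `None`. A minimal sketch of calling it directly, assuming `run_eval.py` is importable from the working directory; the checkpoint, file name, inputs, and the `device` keyword (its default is elided in this diff) are placeholders rather than values from the commit:

```python
from run_eval import generate_summaries_or_translations

examples = ["New York (CNN) When Liana Barrientos was 23 years old, she got married."]
metrics = generate_summaries_or_translations(
    examples,
    "predictions.txt",             # out_file: one hypothesis written per line
    "sshleifer/bart-tiny-random",  # placeholder checkpoint
    batch_size=2,
    device="cpu",                  # assumed keyword; its default is not shown in this hunk
    num_beams=2,                   # forwarded to model.generate via **generate_kwargs
)
print(metrics)  # dict with n_obs, runtime, seconds_per_sample
```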
```diff
@@ -70,7 +79,13 @@ def run_generate():
     parser.add_argument("save_path", type=str, help="where to save summaries")
     parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test_reference_summaries.txt")
-    parser.add_argument("--score_path", type=str, required=False, help="where to save the rouge score in json format")
+    parser.add_argument(
+        "--score_path",
+        type=str,
+        required=False,
+        default="metrics.json",
+        help="where to save the rouge score in json format",
+    )
     parser.add_argument("--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.")
     parser.add_argument("--task", type=str, default="summarization", help="typically translation or summarization")
     parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
@@ -79,7 +94,7 @@ def run_generate():
         type=int,
         default=None,
         required=False,
-        help="decoder_start_token_id (otherwise will look at config)",
+        help="Defaults to using config",
     )
     parser.add_argument(
         "--n_obs", type=int, default=-1, required=False, help="How many observations. Defaults to all."
@@ -90,7 +105,9 @@ def run_generate():
     if args.n_obs > 0:
         examples = examples[: args.n_obs]
     Path(args.save_path).parent.mkdir(exist_ok=True)
-    generate_summaries_or_translations(
+    if args.reference_path is None and Path(args.score_path).exists():
+        warnings.warn(f"score_path {args.score_path} will be overwritten unless you type ctrl-c.")
+    runtime_metrics = generate_summaries_or_translations(
         examples,
         args.save_path,
         args.model_name,
@@ -107,9 +124,10 @@ def run_generate():
     output_lns = [x.rstrip() for x in open(args.save_path).readlines()]
     reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()][: len(output_lns)]
     scores: dict = score_fn(output_lns, reference_lns)
+    scores.update(runtime_metrics)
     print(scores)
     if args.score_path is not None:
-        json.dump(scores, open(args.score_path, "w+"))
+        json.dump(scores, open(args.score_path, "w"))
     return scores
```
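One way to drive the updated entry point from Python, mirroring the test change below; the checkpoint and file names are placeholders, not values from the commit:

```python
# Hypothetical programmatic invocation of run_eval.py's CLI entry point.
import sys
from unittest.mock import patch

from run_eval import run_generate

argv = [
    "run_eval.py",
    "sshleifer/bart-tiny-random",       # placeholder checkpoint
    "test.source",                      # one input example per line
    "predictions.txt",
    "--reference_path", "test.target",
    "--score_path", "metrics.json",     # now also the default
    "--task", "summarization",
    "--bs", "4",
]
with patch.object(sys, "argv", argv):
    scores = run_generate()  # metric scores merged with n_obs/runtime/seconds_per_sample
```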
The commit also updates the accompanying seq2seq test:

```diff
@@ -252,13 +252,24 @@ class TestSummarizationDistiller(unittest.TestCase):
 @pytest.mark.parametrize(["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY)])
-def test_run_eval_bart(model):
+def test_run_eval(model):
     input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
     output_file_name = input_file_name.parent / "utest_output.txt"
     assert not output_file_name.exists()
     articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
     _dump_articles(input_file_name, articles)
-    testargs = ["run_eval.py", model, str(input_file_name), str(output_file_name)]  # TODO: test score_path
+    score_path = str(Path(tempfile.mkdtemp()) / "scores.json")
+    task = "translation_en_to_de" if model == T5_TINY else "summarization"
+    testargs = [
+        "run_eval.py",
+        model,
+        str(input_file_name),
+        str(output_file_name),
+        "--score_path",
+        score_path,
+        "--task",
+        task,
+    ]
     with patch.object(sys, "argv", testargs):
         run_generate()
     assert Path(output_file_name).exists()
```
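Because the test is parametrized, a single name filter exercises all three tiny checkpoints; for instance (the test module's path is not named in this commit, so the selection is by test name only):

```python
# Runs the renamed test; pytest expands it to the T5_TINY, BART_TINY and
# MBART_TINY parametrizations. `pytest -k test_run_eval` from the shell is equivalent.
import pytest

pytest.main(["-k", "test_run_eval"])
```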