Unverified Commit 393b8dc0 authored by Sam Shleifer, committed by GitHub

examples/seq2seq/run_eval.py fixes and docs (#5322)

parent 5543b30a
......@@ -37,13 +37,50 @@ export ENRO_DIR=${PWD}/wmt_en_ro
If you are using your own data, it must be formatted as one directory with 6 files: train.source, train.target, val.source, val.target, test.source, test.target.
The `.source` files are the input, the `.target` files are the desired output.
### Evaluation
### Evaluation Commands
To create summaries for each article in dataset, run:
To create summaries for each article in the dataset, we use `run_eval.py`. Here are a few commands that run evaluation for different tasks and models.
If `translation` is in your task name, the computed metric will be BLEU; otherwise, ROUGE will be used.
For T5, you need to specify `--task translation_{src}_to_{tgt}`, as follows:
```bash
export DATA_DIR=wmt_en_ro
python run_eval.py t5-base \
$DATA_DIR/val.source t5_val_generations.txt \
--reference_path $DATA_DIR/val.target \
--score_path enro_bleu.json \
--task translation_en_to_ro \
--n_obs 100 \
--device cuda \
--fp16 \
--bs 32
```
This command works for MBART, although the BLEU score is suspiciously low.
```bash
export DATA_DIR=wmt_en_ro
python run_eval.py facebook/mbart-large-en-ro $DATA_DIR/val.source mbart_val_generations.txt \
--reference_path $DATA_DIR/val.target \
--score_path enro_bleu.json \
--task translation \
--n_obs 100 \
--device cuda \
--fp16 \
--bs 32
```
Summarization (xsum will be very similar):
```bash
python run_eval.py <path_to_test.source> test_generations.txt <model-name> --score_path rouge_scores.txt
export DATA_DIR=cnn_dm
python run_eval.py sshleifer/distilbart-cnn-12-6 $DATA_DIR/val.source dbart_val_generations.txt \
--reference_path $DATA_DIR/val.target \
--score_path cnn_rouge.json \
--task summarization \
--n_obs 100 \
--device cuda \
--fp16 \
--bs 32
```
The default batch size, 4, fits in 16GB GPU memory, but may need to be adjusted to fit your system.
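The same evaluation can also be driven from Python instead of the command line. The sketch below is illustrative rather than part of the repo: it assumes you are inside `examples/seq2seq` (so `run_eval.py` and `utils.py` import cleanly), uses placeholder paths and counts, and mirrors what `run_eval.py` does internally.
```python
# Illustrative sketch: programmatic use of the eval helpers. Paths and the
# 16-example cap are placeholders; adjust batch_size for your GPU.
from run_eval import generate_summaries_or_translations
from utils import calculate_rouge

examples = [x.rstrip() for x in open("cnn_dm/val.source").readlines()][:16]

# Writes one generated summary per input line to val_generations.txt.
generate_summaries_or_translations(
    examples,
    "val_generations.txt",
    "sshleifer/distilbart-cnn-12-6",
    batch_size=4,
    task="summarization",
)

output_lns = [x.rstrip() for x in open("val_generations.txt").readlines()]
reference_lns = [x.rstrip() for x in open("cnn_dm/val.target").readlines()][: len(output_lns)]
print(calculate_rouge(output_lns, reference_lns))
```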
### Summarization Finetuning
......
......@@ -9,9 +9,9 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
try:
from .utils import calculate_rouge, use_task_specific_params, calculate_bleu_score
from .utils import calculate_rouge, use_task_specific_params, calculate_bleu_score, trim_batch
except ImportError:
from utils import calculate_rouge, use_task_specific_params, calculate_bleu_score
from utils import calculate_rouge, use_task_specific_params, calculate_bleu_score, trim_batch
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
......@@ -29,6 +29,7 @@ def generate_summaries_or_translations(
batch_size: int = 8,
device: str = DEFAULT_DEVICE,
fp16=False,
task="summarization",
**gen_kwargs,
) -> None:
fout = Path(out_file).open("w", encoding="utf-8")
......@@ -40,7 +41,7 @@ def generate_summaries_or_translations(
tokenizer = AutoTokenizer.from_pretrained(model_name)
# update config with summarization specific params
use_task_specific_params(model, "summarization")
use_task_specific_params(model, task)
for batch in tqdm(list(chunks(examples, batch_size))):
if "t5" in model_name:
......@@ -48,7 +49,8 @@ def generate_summaries_or_translations(
batch = tokenizer(batch, max_length=1024, return_tensors="pt", truncation=True, padding="max_length").to(
device
)
summaries = model.generate(**batch, **gen_kwargs)
input_ids, attention_mask = trim_batch(**batch, pad_token_id=tokenizer.pad_token_id)
summaries = model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)
dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
for hypothesis in dec:
fout.write(hypothesis + "\n")
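The body of `trim_batch` is collapsed in this diff. Conceptually, a helper like this drops the input columns that contain only padding before `generate` runs, which avoids wasted work when a batch is padded to `max_length`. A rough sketch of that idea, not necessarily the exact `utils.py` implementation:
```python
import torch

def trim_batch_sketch(input_ids, pad_token_id, attention_mask=None):
    """Drop the columns of input_ids (and attention_mask) that are entirely pad_token_id."""
    keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
    if attention_mask is None:
        return input_ids[:, keep_column_mask]
    return input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask]

ids = torch.tensor([[5, 6, 0, 0], [7, 0, 0, 0]])  # pretend 0 is the pad id
mask = ids.ne(0).long()
trimmed_ids, trimmed_mask = trim_batch_sketch(ids, pad_token_id=0, attention_mask=mask)
# trimmed_ids now has shape (2, 2): the two all-pad columns were removed.
```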
......@@ -57,30 +59,42 @@ def generate_summaries_or_translations(
def run_generate():
parser = argparse.ArgumentParser()
parser.add_argument("input_path", type=str, help="like cnn_dm/test.source")
parser.add_argument("output_path", type=str, help="where to save summaries")
parser.add_argument("model_name", type=str, help="like facebook/bart-large-cnn,t5-base, etc.")
parser.add_argument("input_path", type=str, help="like cnn_dm/test.source")
parser.add_argument("save_path", type=str, help="where to save summaries")
parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test_reference_summaries.txt")
parser.add_argument("--score_path", type=str, required=False, help="where to save the rouge score in json format")
parser.add_argument("--metric", type=str, choices=["bleu", "rouge"], default="rouge")
parser.add_argument("--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.")
parser.add_argument("--task", type=str, default="summarization", help="typically translation or summarization")
parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
parser.add_argument(
"--n_obs", type=int, default=-1, required=False, help="How many observations. Defaults to all."
)
parser.add_argument("--fp16", action="store_true")
args = parser.parse_args()
examples = [" " + x.rstrip() if "t5" in args.model_name else x.rstrip() for x in open(args.input_path).readlines()]
if args.n_obs > 0:
examples = examples[: args.n_obs]
generate_summaries_or_translations(
examples, args.output_path, args.model_name, batch_size=args.bs, device=args.device, fp16=args.fp16
examples,
args.save_path,
args.model_name,
batch_size=args.bs,
device=args.device,
fp16=args.fp16,
task=args.task,
)
output_lns = [x.rstrip() for x in open(args.output_path).readlines()]
scores = {}
if args.reference_path is not None:
score_fn = {"bleu": calculate_bleu_score, "rouge": calculate_rouge}[args.metric]
reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()]
scores: dict = score_fn(output_lns, reference_lns)
if args.score_path is not None:
json.dump(scores, open("score_path", "w+"))
if args.reference_path is None:
return
# Compute scores
score_fn = calculate_bleu_score if "translation" in args.task else calculate_rouge
output_lns = [x.rstrip() for x in open(args.save_path).readlines()]
reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()][: len(output_lns)]
scores: dict = score_fn(output_lns, reference_lns)
if args.score_path is not None:
json.dump(scores, open(args.score_path, "w+"))
return scores
......
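With this change `run_generate()` returns the scores dict and, when `--score_path` is passed, also dumps it to that file, so the metrics from the README commands above can simply be read back:
```python
import json

# "enro_bleu.json" is the --score_path used in the README translation examples.
with open("enro_bleu.json") as f:
    scores = json.load(f)
print(scores)  # {"bleu": ...} for translation tasks; ROUGE keys otherwise
```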
......@@ -198,7 +198,7 @@ def test_run_eval_bart(model):
assert not output_file_name.exists()
articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
_dump_articles(input_file_name, articles)
testargs = ["run_eval.py", str(input_file_name), str(output_file_name), model] # TODO: test score_path
testargs = ["run_eval.py", model, str(input_file_name), str(output_file_name)] # TODO: test score_path
with patch.object(sys, "argv", testargs):
run_generate()
assert Path(output_file_name).exists()
......
......@@ -60,8 +60,9 @@ def lmap(f: Callable, x: Iterable) -> List:
return list(map(f, x))
def calculate_bleu_score(output_lns, refs_lns) -> dict:
return {"bleu": corpus_bleu(output_lns, [refs_lns]).score}
def calculate_bleu_score(output_lns, refs_lns, **kwargs) -> dict:
"""Uses sacrebleu's corpus_bleu implementation."""
return {"bleu": corpus_bleu(output_lns, [refs_lns], **kwargs).score}
def trim_batch(
......
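The new `**kwargs` on `calculate_bleu_score` are forwarded straight to sacrebleu's `corpus_bleu`. A small usage sketch, run from `examples/seq2seq` and assuming your sacrebleu version accepts the `lowercase` flag (shown only as an example of something that can be forwarded):
```python
from utils import calculate_bleu_score  # run from examples/seq2seq

hypotheses = ["the cat sat on the mat"]
references = ["The cat sat on the mat"]

print(calculate_bleu_score(hypotheses, references))                  # case-sensitive BLEU
print(calculate_bleu_score(hypotheses, references, lowercase=True))  # ~100.0 after lowercasing
```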
......@@ -253,9 +253,9 @@ class MBartIntegrationTests(unittest.TestCase):
with torch.no_grad():
logits, *other_stuff = model(**net_input)
expected_slice = torch.tensor([9.0078, 10.1113, 14.4787], device=torch_device, dtype=model.dtype)
result_slice = logits[0][0][:3]
self.assertTrue(torch.allclose(expected_slice, result_slice, atol=TOLERANCE))
expected_slice = [9.0078, 10.1113, 14.4787]
result_slice = logits[0][0][:3].tolist()
self.assertListEqual(expected_slice, result_slice)
@slow
def test_enro_generate(self):
......