Unverified Commit 393b8dc0 authored by Sam Shleifer, committed by GitHub

examples/seq2seq/run_eval.py fixes and docs (#5322)

parent 5543b30a
@@ -37,13 +37,50 @@ export ENRO_DIR=${PWD}/wmt_en_ro
If you are using your own data, it must be formatted as one directory with 6 files: train.source, train.target, val.source, val.target, test.source, test.target.
The `.source` files are the input, the `.target` files are the desired output.
### Evaluation Commands
To create summaries for each article in the dataset, we use `run_eval.py`. Here are a few commands that run evaluation for different tasks and models.
If `translation` is in your task name, the computed metric will be BLEU; otherwise, ROUGE will be used.
For T5, you need to specify `--task translation_{src}_to_{tgt}` as follows:
```bash
export DATA_DIR=wmt_en_ro
python run_eval.py t5-base \
$DATA_DIR/val.source t5_val_generations.txt \
--reference_path $DATA_DIR/val.target \
--score_path enro_bleu.json \
--task translation_en_to_ro \
--n_obs 100 \
--device cuda \
--fp16 \
--bs 32
```
The following command works for MBART, although the BLEU score is suspiciously low.
```bash
export DATA_DIR=wmt_en_ro
python run_eval.py facebook/mbart-large-en-ro $DATA_DIR/val.source mbart_val_generations.txt \
--reference_path $DATA_DIR/val.target \
--score_path enro_bleu.json \
--task translation \
--n_obs 100 \
--device cuda \
--fp16 \
--bs 32
```
Summarization (xsum will be very similar):
```bash
export DATA_DIR=cnn_dm
python run_eval.py sshleifer/distilbart-cnn-12-6 $DATA_DIR/val.source dbart_val_generations.txt \
--reference_path $DATA_DIR/val.target \
--score_path cnn_rouge.json \
--task summarization \
--n_obs 100 \
--device cuda \
--fp16 \
--bs 32
```
The default batch size (`--bs 8`) may need to be adjusted to fit your GPU memory.
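If you already have a generations file and only want to (re)compute scores, the metric helpers that `run_eval.py` imports can be called directly. A minimal sketch, assuming you run it from `examples/seq2seq/` so that `utils.py` (shown further down in this diff) is importable; the file paths are illustrative:

```python
# Illustrative sketch only: re-score an existing generations file without rerunning
# generation, using the same helpers run_eval.py imports from utils.py.
from utils import calculate_bleu_score, calculate_rouge

output_lns = [x.rstrip() for x in open("dbart_val_generations.txt").readlines()]
reference_lns = [x.rstrip() for x in open("cnn_dm/val.target").readlines()][: len(output_lns)]

task = "summarization"  # any task name containing "translation" switches the metric to BLEU
score_fn = calculate_bleu_score if "translation" in task else calculate_rouge
print(score_fn(output_lns, reference_lns))
```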
### Summarization Finetuning
...
```python
@@ -9,9 +9,9 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
try:
    from .utils import calculate_rouge, use_task_specific_params, calculate_bleu_score, trim_batch
except ImportError:
    from utils import calculate_rouge, use_task_specific_params, calculate_bleu_score, trim_batch

DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

@@ -29,6 +29,7 @@ def generate_summaries_or_translations(
    batch_size: int = 8,
    device: str = DEFAULT_DEVICE,
    fp16=False,
    task="summarization",
    **gen_kwargs,
) -> None:
    fout = Path(out_file).open("w", encoding="utf-8")

@@ -40,7 +41,7 @@ def generate_summaries_or_translations(
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # update config with task specific params
    use_task_specific_params(model, task)

    for batch in tqdm(list(chunks(examples, batch_size))):
        if "t5" in model_name:

@@ -48,7 +49,8 @@ def generate_summaries_or_translations(
        batch = tokenizer(batch, max_length=1024, return_tensors="pt", truncation=True, padding="max_length").to(
            device
        )
        input_ids, attention_mask = trim_batch(**batch, pad_token_id=tokenizer.pad_token_id)
        summaries = model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)
        dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        for hypothesis in dec:
            fout.write(hypothesis + "\n")
```
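For context, `use_task_specific_params(model, task)` is the `utils.py` helper that copies any task-specific generation settings stored in the model config (e.g. under `task_specific_params["summarization"]`) into the top-level config before `generate` is called. A rough sketch of the idea, not necessarily the exact helper:

```python
def use_task_specific_params(model, task):
    """Sketch: merge config.task_specific_params[task] (if present) into the config."""
    task_specific_params = model.config.task_specific_params
    if task_specific_params is not None:
        model.config.update(task_specific_params.get(task, {}))
```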
```python
@@ -57,30 +59,42 @@ def generate_summaries_or_translations(
def run_generate():
    parser = argparse.ArgumentParser()
    parser.add_argument("model_name", type=str, help="like facebook/bart-large-cnn,t5-base, etc.")
    parser.add_argument("input_path", type=str, help="like cnn_dm/test.source")
    parser.add_argument("save_path", type=str, help="where to save summaries")
    parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test_reference_summaries.txt")
    parser.add_argument("--score_path", type=str, required=False, help="where to save the rouge score in json format")
    parser.add_argument("--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.")
    parser.add_argument("--task", type=str, default="summarization", help="typically translation or summarization")
    parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
    parser.add_argument(
        "--n_obs", type=int, default=-1, required=False, help="How many observations. Defaults to all."
    )
    parser.add_argument("--fp16", action="store_true")
    args = parser.parse_args()
    examples = [" " + x.rstrip() if "t5" in args.model_name else x.rstrip() for x in open(args.input_path).readlines()]
    if args.n_obs > 0:
        examples = examples[: args.n_obs]

    generate_summaries_or_translations(
        examples,
        args.save_path,
        args.model_name,
        batch_size=args.bs,
        device=args.device,
        fp16=args.fp16,
        task=args.task,
    )
    if args.reference_path is None:
        return
    # Compute scores
    score_fn = calculate_bleu_score if "translation" in args.task else calculate_rouge
    output_lns = [x.rstrip() for x in open(args.save_path).readlines()]
    reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()][: len(output_lns)]
    scores: dict = score_fn(output_lns, reference_lns)
    if args.score_path is not None:
        json.dump(scores, open(args.score_path, "w+"))
    return scores
```
...
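`run_generate` is a thin CLI wrapper around `generate_summaries_or_translations`; extra keyword arguments are forwarded to `model.generate` via `**gen_kwargs`. A hedged sketch of calling the function directly (model name, paths, and values are illustrative, not part of the script):

```python
# Sketch of direct (non-CLI) use; the positional order matches the call in run_generate:
# (examples, out_file, model_name).
from run_eval import generate_summaries_or_translations

examples = ["New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
generate_summaries_or_translations(
    examples,
    "gens.txt",
    "sshleifer/distilbart-cnn-12-6",
    batch_size=1,
    device="cpu",
    task="summarization",
    num_beams=2,  # example of a **gen_kwargs passthrough to model.generate
)
```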
```python
@@ -198,7 +198,7 @@ def test_run_eval_bart(model):
    assert not output_file_name.exists()
    articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
    _dump_articles(input_file_name, articles)
    testargs = ["run_eval.py", model, str(input_file_name), str(output_file_name)]  # TODO: test score_path
    with patch.object(sys, "argv", testargs):
        run_generate()
    assert Path(output_file_name).exists()
```
...
```python
@@ -60,8 +60,9 @@ def lmap(f: Callable, x: Iterable) -> List:
    return list(map(f, x))


def calculate_bleu_score(output_lns, refs_lns, **kwargs) -> dict:
    """Uses sacrebleu's corpus_bleu implementation."""
    return {"bleu": corpus_bleu(output_lns, [refs_lns], **kwargs).score}


def trim_batch(
```
...
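`trim_batch` (whose body is collapsed above) is what lets `run_eval.py` tokenize with `padding="max_length"` and still generate efficiently: it drops the columns that contain only padding. A rough sketch of the idea, not necessarily the exact implementation:

```python
import torch


def trim_batch(input_ids, pad_token_id, attention_mask=None):
    """Sketch: remove columns of input_ids (and attention_mask) that are all padding."""
    keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
    if attention_mask is None:
        return input_ids[:, keep_column_mask]
    return input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask]
```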
```python
@@ -253,9 +253,9 @@ class MBartIntegrationTests(unittest.TestCase):
        with torch.no_grad():
            logits, *other_stuff = model(**net_input)
        expected_slice = [9.0078, 10.1113, 14.4787]
        result_slice = logits[0][0][:3].tolist()
        self.assertListEqual(expected_slice, result_slice)

    @slow
    def test_enro_generate(self):
```
...