Unverified commit 80a16945, authored by Patrick von Platen and committed via GitHub
Browse files

[Examples, T5] Change newstest2013 to newstest2014 and clean up (#3817)



* Refactored use of newstest2013 to newstest2014. Fixed bug where argparse consumed first command line argument as model_size argument rather than using default model_size by forcing explicit --model_size flag inclusion

* More pythonic file handling through 'with' context

* COSMETIC - ran Black and isort

* Fixed reference to number of lines in newstest2014

* Fixed failing test. More pythonic file handling

* finish PR from tholiao

* remove outcommented lines

* make style

* make isort happy
Co-authored-by: Thomas Liao <tholiao@gmail.com>
parent d4867951
...@@ -9,17 +9,17 @@ evaluated on the WMT English-German dataset. ...@@ -9,17 +9,17 @@ evaluated on the WMT English-German dataset.
To be able to reproduce the authors' results on WMT English to German, you first need to download To be able to reproduce the authors' results on WMT English to German, you first need to download
the WMT14 en-de news datasets. the WMT14 en-de news datasets.
Go on Stanford's official NLP [website](https://nlp.stanford.edu/projects/nmt/) and find "newstest2013.en" and "newstest2013.de" under WMT'14 English-German data or download the dataset directly via: Go on Stanford's official NLP [website](https://nlp.stanford.edu/projects/nmt/) and find "newstest2014.en" and "newstest2014.de" under WMT'14 English-German data or download the dataset directly via:
```bash ```bash
curl https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2013.en > newstest2013.en curl https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2014.en > newstest2014.en
curl https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2013.de > newstest2013.de curl https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2014.de > newstest2014.de
``` ```
You should have 3000 sentence in each file. You can verify this by running: You should have 2737 sentences in each file. You can verify this by running:
```bash ```bash
wc -l newstest2013.en # should give 3000 wc -l newstest2014.en # should give 2737
``` ```
### Usage ### Usage
...@@ -29,8 +29,8 @@ Let's check the longest and shortest sentence in our file to find reasonable dec ...@@ -29,8 +29,8 @@ Let's check the longest and shortest sentence in our file to find reasonable dec
Get the longest and shortest sentence: Get the longest and shortest sentence:
```bash ```bash
awk '{print NF}' newstest2013.en | sort -n | head -1 # shortest sentence has 1 word awk '{print NF}' newstest2014.en | sort -n | head -1 # shortest sentence has 2 words
awk '{print NF}' newstest2013.en | sort -n | tail -1 # longest sentence has 106 words awk '{print NF}' newstest2014.en | sort -n | tail -1 # longest sentence has 91 words
``` ```
We will set our `max_length` to ~3 times the longest sentence and leave `min_length` to its default value of 0. We will set our `max_length` to ~3 times the longest sentence and leave `min_length` to its default value of 0.
...@@ -38,7 +38,7 @@ We decode with beam search `num_beams=4` as proposed in the paper. Also as is co ...@@ -38,7 +38,7 @@ We decode with beam search `num_beams=4` as proposed in the paper. Also as is co
To create a translation for each sentence in the dataset and get a final BLEU score, run: To create a translation for each sentence in the dataset and get a final BLEU score, run:
```bash ```bash
python evaluate_wmt.py <path_to_newstest2013.en> newstest2013_de_translations.txt <path_to_newstest2013.de> newsstest2013_en_de_bleu.txt python evaluate_wmt.py <path_to_newstest2014.en> newstest2014_de_translations.txt <path_to_newstest2014.de> newstest2014_en_de_bleu.txt
``` ```
the default batch size, 16, fits in 16GB GPU memory, but may need to be adjusted to fit your system. the default batch size, 16, fits in 16GB GPU memory, but may need to be adjusted to fit your system.
......
...@@ -15,8 +15,6 @@ def chunks(lst, n): ...@@ -15,8 +15,6 @@ def chunks(lst, n):
def generate_translations(lns, output_file_path, model_size, batch_size, device): def generate_translations(lns, output_file_path, model_size, batch_size, device):
output_file = Path(output_file_path).open("w")
model = T5ForConditionalGeneration.from_pretrained(model_size) model = T5ForConditionalGeneration.from_pretrained(model_size)
model.to(device) model.to(device)
...@@ -27,27 +25,29 @@ def generate_translations(lns, output_file_path, model_size, batch_size, device) ...@@ -27,27 +25,29 @@ def generate_translations(lns, output_file_path, model_size, batch_size, device)
if task_specific_params is not None: if task_specific_params is not None:
model.config.update(task_specific_params.get("translation_en_to_de", {})) model.config.update(task_specific_params.get("translation_en_to_de", {}))
for batch in tqdm(list(chunks(lns, batch_size))): with Path(output_file_path).open("w") as output_file:
batch = [model.config.prefix + text for text in batch] for batch in tqdm(list(chunks(lns, batch_size))):
batch = [model.config.prefix + text for text in batch]
dct = tokenizer.batch_encode_plus(batch, max_length=512, return_tensors="pt", pad_to_max_length=True) dct = tokenizer.batch_encode_plus(batch, max_length=512, return_tensors="pt", pad_to_max_length=True)
input_ids = dct["input_ids"].to(device) input_ids = dct["input_ids"].to(device)
attention_mask = dct["attention_mask"].to(device) attention_mask = dct["attention_mask"].to(device)
translations = model.generate(input_ids=input_ids, attention_mask=attention_mask) translations = model.generate(input_ids=input_ids, attention_mask=attention_mask)
dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in translations] dec = [
tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in translations
]
for hypothesis in dec: for hypothesis in dec:
output_file.write(hypothesis + "\n") output_file.write(hypothesis + "\n")
output_file.flush()
def calculate_bleu_score(output_lns, refs_lns, score_path):
    """Compute the corpus BLEU score of translations against references and save it.

    Args:
        output_lns: list of hypothesis (model-translated) sentences.
        refs_lns: list of reference sentences, parallel to ``output_lns``.
        score_path: path of the file the formatted score is written to.
    """
    # corpus_bleu takes a list of reference *streams*, hence the wrapping list.
    # NOTE(review): corpus_bleu is presumably sacrebleu's — its import is outside
    # this view; confirm against the file's import block.
    bleu = corpus_bleu(output_lns, [refs_lns])
    result = "BLEU score: {}".format(bleu.score)
    # 'with' guarantees the handle is closed even if the write raises.
    with Path(score_path).open("w") as score_file:
        score_file.write(result)
def run_generate(): def run_generate():
...@@ -59,13 +59,13 @@ def run_generate(): ...@@ -59,13 +59,13 @@ def run_generate():
default="t5-base", default="t5-base",
) )
parser.add_argument( parser.add_argument(
"input_path", type=str, help="like wmt/newstest2013.en", "input_path", type=str, help="like wmt/newstest2014.en",
) )
parser.add_argument( parser.add_argument(
"output_path", type=str, help="where to save translation", "output_path", type=str, help="where to save translation",
) )
parser.add_argument( parser.add_argument(
"reference_path", type=str, help="like wmt/newstest2013.de", "reference_path", type=str, help="like wmt/newstest2014.de",
) )
parser.add_argument( parser.add_argument(
"score_path", type=str, help="where to save the bleu score", "score_path", type=str, help="where to save the bleu score",
...@@ -82,12 +82,19 @@ def run_generate(): ...@@ -82,12 +82,19 @@ def run_generate():
dash_pattern = (" ##AT##-##AT## ", "-") dash_pattern = (" ##AT##-##AT## ", "-")
input_lns = [x.strip().replace(dash_pattern[0], dash_pattern[1]) for x in open(args.input_path).readlines()] # Read input lines into python
with open(args.input_path, "r") as input_file:
input_lns = [x.strip().replace(dash_pattern[0], dash_pattern[1]) for x in input_file.readlines()]
generate_translations(input_lns, args.output_path, args.model_size, args.batch_size, args.device) generate_translations(input_lns, args.output_path, args.model_size, args.batch_size, args.device)
output_lns = [x.strip() for x in open(args.output_path).readlines()] # Read generated lines into python
refs_lns = [x.strip().replace(dash_pattern[0], dash_pattern[1]) for x in open(args.reference_path).readlines()] with open(args.output_path, "r") as output_file:
output_lns = [x.strip() for x in output_file.readlines()]
# Read reference lines into python
with open(args.reference_path, "r") as reference_file:
refs_lns = [x.strip().replace(dash_pattern[0], dash_pattern[1]) for x in reference_file.readlines()]
calculate_bleu_score(output_lns, refs_lns, args.score_path) calculate_bleu_score(output_lns, refs_lns, args.score_path)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment