Unverified Commit f86d6874 authored by Lintang Sutawika's avatar Lintang Sutawika Committed by GitHub
Browse files

Merge pull request #775 from EleutherAI/wmt

[Refactor] Add WMT tasks
parents 1a02d9df 01ad787d
......@@ -56,6 +56,55 @@ def matthews_corrcoef(items):
return sklearn.metrics.matthews_corrcoef(golds, preds)
@register_aggregation("bleu")
def bleu(items):
"""The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
for evaluating a generated sentence to a reference sentence. It counts matching
n-grams in the candidate translation to n-grams in the reference text, where
1-gram or unigram would be each token and a bigram comparison would be each
word pair. The comparison is made regardless of word order
Source: https://machinelearningmastery.com/calculate-bleu-score-for-text-python/
Paper: https://www.aclweb.org/anthology/P02-1040/
Higher is better
"""
refs = list(zip(*items))[0]
preds = list(zip(*items))[1]
refs, preds = _sacreformat(refs, preds)
return sacrebleu.corpus_bleu(preds, refs).score
@register_aggregation("chrf")
def chrf(items):
"""chrF++ is a tool for automatic evaluation of machine translation output
based on character n-gram precision and recall enhanced with word n-grams.
Source: https://github.com/m-popovic/chrF
Paper: https://www.aclweb.org/anthology/W15-3049.pdf
Higher is better # TODO I think
"""
refs = list(zip(*items))[0]
preds = list(zip(*items))[1]
refs, preds = _sacreformat(refs, preds)
return sacrebleu.corpus_chrf(preds, refs).score
@register_aggregation("ter")
def ter(items):
"""Translation Error Rate is an error metric for machine translation that
measures the number of edits required to change a system output into one
of the references
Source: http://www.cs.umd.edu/~snover/tercom/
Paper: http://mt-archive.info/AMTA-2006-Snover.pdf
Lower is better
"""
refs = list(zip(*items))[0]
preds = list(zip(*items))[1]
refs, preds = _sacreformat(refs, preds)
return sacrebleu.corpus_ter(preds, refs).score
@register_metric(
metric="acc",
higher_is_better=True,
......@@ -160,6 +209,36 @@ def f1_fn(items): # This is a passthrough function
return items
@register_metric(
metric="bleu",
higher_is_better=True,
output_type="greedy_until",
aggregation="bleu",
)
def bleu_fn(items): # This is a passthrough function
return items
@register_metric(
metric="chrf",
higher_is_better=True,
output_type="greedy_until",
aggregation="chrf",
)
def chrf_fn(items): # This is a passthrough function
return items
@register_metric(
metric="ter",
higher_is_better=True,
output_type="greedy_until",
aggregation="ter",
)
def ter_fn(items): # This is a passthrough function
return items
@register_metric(
metric="acc_all",
higher_is_better=True,
......@@ -217,55 +296,6 @@ def weighted_mean(items):
return sum(a) / sum(b)
@register_metric(metric="bleu", higher_is_better=True, aggregation="mean")
def bleu(items):
"""The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
for evaluating a generated sentence to a reference sentence. It counts matching
n-grams in the candidate translation to n-grams in the reference text, where
1-gram or unigram would be each token and a bigram comparison would be each
word pair. The comparison is made regardless of word order
Source: https://machinelearningmastery.com/calculate-bleu-score-for-text-python/
Paper: https://www.aclweb.org/anthology/P02-1040/
Higher is better
"""
refs = list(zip(*items))[0]
preds = list(zip(*items))[1]
refs, preds = _sacreformat(refs, preds)
return sacrebleu.corpus_bleu(preds, refs).score
@register_metric(metric="chrf", higher_is_better=True, aggregation="mean")
def chrf(items):
"""chrF++ is a tool for automatic evaluation of machine translation output
based on character n-gram precision and recall enhanced with word n-grams.
Source: https://github.com/m-popovic/chrF
Paper: https://www.aclweb.org/anthology/W15-3049.pdf
Higher is better # TODO I think
"""
refs = list(zip(*items))[0]
preds = list(zip(*items))[1]
refs, preds = _sacreformat(refs, preds)
return sacrebleu.corpus_chrf(preds, refs).score
@register_metric(metric="ter", higher_is_better=True, aggregation="mean")
def ter(items):
"""Translation Error Rate is an error metric for machine translation that
measures the number of edits required to change a system output into one
of the references
Source: http://www.cs.umd.edu/~snover/tercom/
Paper: http://mt-archive.info/AMTA-2006-Snover.pdf
Lower is better
"""
refs = list(zip(*items))[0]
preds = list(zip(*items))[1]
refs, preds = _sacreformat(refs, preds)
return sacrebleu.corpus_ter(preds, refs).score
def is_non_str_iterable(obj):
return isinstance(obj, Iterable) and not isinstance(obj, str)
......
......@@ -78,7 +78,7 @@ class TaskConfig(dict):
# runtime configuration options
num_fewshot: int = 0
# scoring options
metric_list: str = None
metric_list: list = None
output_type: str = "greedy_until"
generation_kwargs: dict = None
repeats: int = 1
......@@ -88,7 +88,6 @@ class TaskConfig(dict):
metadata: str = None # by default, not used in the code. allows for users to pass arbitrary info to tasks
def __post_init__(self) -> None:
if "." in self.dataset_path:
import inspect
......@@ -1073,25 +1072,33 @@ class ConfigurableTask(Task):
# TODO: this may break for multipLe_target, non zero-or-1 metrics
scores = []
for gold_option in gold:
res = self._metric_fn_list[metric](
references=[gold_option],
predictions=[result],
**self._metric_fn_kwargs[metric],
)
if isinstance(res, dict):
try:
result_score = self._metric_fn_list[metric](
references=[gold_option],
predictions=[result],
**self._metric_fn_kwargs[metric],
)
except TypeError: # TODO: this is hacky and I don't want to do it
result_score = self._metric_fn_list[metric](
[gold_option, result]
)
if isinstance(result_score, dict):
# TODO: this handles the case where HF evaluate returns a dict.
res = res[metric]
scores.append(res)
result_score = result_score[metric]
scores.append(result_score)
if any(scores):
result_score = 1.0
else:
result_score = 0.0
else:
result_score = self._metric_fn_list[metric](
references=[gold],
predictions=[result],
**self._metric_fn_kwargs[metric],
)
try:
result_score = self._metric_fn_list[metric](
references=[gold],
predictions=[result],
**self._metric_fn_kwargs[metric],
)
except TypeError: # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
result_score = self._metric_fn_list[metric]([gold, result])
if isinstance(result_score, dict):
# TODO: this handles the case where HF evaluate returns a dict.
result_score = result_score[metric]
......
......@@ -250,7 +250,8 @@ def evaluate(
# print the prompt for the first few documents
if inst.doc_id < 1:
eval_logger.info(
f"Task: {task_name}; document {inst.doc_id}; context prompt (starting on next line):\n{inst.args[0]}\n(end of prompt on previous line)"
f"Task: {task_name}; document {inst.doc_id}; context prompt (starting on next line):\
\n{inst.args[0]}\n(end of prompt on previous line)\ntarget string or answer choice index (starting on next line):\n{task.doc_to_target(inst.doc)}\n(end of target on previous line)"
)
eval_logger.info(f"Request: {str(inst)}")
......@@ -359,28 +360,35 @@ def evaluate(
if type(items[0]) == tuple:
numitem = len(items[0])
# distributed gather requires all ranks to have same dimensions
# so we pad out with float32 min value
pad_value = torch.finfo(torch.float32).min
metrics_tensor = torch.tensor(items, device=lm.device)
if isinstance(items[0], (str, list)):
# handle the string case
gathered_items = [None] * lm.accelerator.num_processes
torch.distributed.all_gather_object(gathered_items, items)
original_dtype = metrics_tensor.dtype # store original dtype
torch_device_tensor = lm.accelerator.pad_across_processes(
metrics_tensor.to(torch.float32), pad_index=pad_value
)
gathered_item = lm.accelerator.gather(torch_device_tensor)
if numitem > 0:
gathered_filtered = gathered_item[gathered_item[:, 0] != pad_value]
gathered_item = list(itertools.chain.from_iterable(gathered_items))
else:
gathered_filtered = gathered_item[gathered_item != pad_value]
# distributed gather requires all ranks to have same dimensions
# so we pad out with float32 min value
pad_value = torch.finfo(torch.float32).min
metrics_tensor = torch.tensor(items, device=lm.device)
original_dtype = metrics_tensor.dtype # store original dtype
torch_device_tensor = lm.accelerator.pad_across_processes(
metrics_tensor.to(torch.float32), pad_index=pad_value
)
gathered_item = lm.accelerator.gather(torch_device_tensor)
gathered_item = (
gathered_filtered.to(original_dtype).cpu().detach().numpy().tolist()
)
# reconvert if we were passed a tuple of values
if numitem > 0:
gathered_item = [tuple(g) for g in gathered_item]
if numitem > 0:
gathered_filtered = gathered_item[gathered_item[:, 0] != pad_value]
else:
gathered_filtered = gathered_item[gathered_item != pad_value]
gathered_item = (
gathered_filtered.to(original_dtype).cpu().detach().numpy().tolist()
)
# reconvert if we were passed a tuple of values
if numitem > 0:
gathered_item = [tuple(g) for g in gathered_item]
if lm.rank == 0:
vals_torch[(task_name, key, metric)] = gathered_item
......@@ -412,7 +420,7 @@ def evaluate(
# hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
# so we run them less iterations. still looking for a cleaner way to do this
if bootstrap_iters > 0:
if False: # bootstrap_iters > 0:
stderr = lm_eval.api.metrics.stderr_for_metric(
metric=task.aggregation()[metric],
bootstrap_iters=min(bootstrap_iters, 1000)
......
......@@ -6,3 +6,5 @@ logging.basicConfig(
level=logging.INFO,
)
eval_logger = logging.getLogger("lm-eval")
SPACING = " " * 47
......@@ -432,9 +432,9 @@ class HFLM(LM):
return encoding
def tok_batch_encode(
self,
strings: List[str],
padding_side: str = "left",
self,
strings: List[str],
padding_side: str = "left",
left_truncate_len: int = None,
truncation: bool = False,
):
......
# Translation Tasks
### Paper
### Citation
```
```
### Groups and Tasks
#### Groups
* `gpt3_translation_tasks`
* `wmt14`
* `wmt16`
* `wmt20`
* `iwslt2017`
#### Tasks
*
### Checklist
For adding novel benchmarks/datasets to the library:
* [x] Is the task an existing benchmark in the literature?
* [ ] Have you referenced the original paper that introduced the task?
* [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?
If other tasks on this dataset are already supported:
* [x] Is the "Main" variant of this task clearly denoted?
* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
* [ ] Checked for equivalence with v0.3.0 LM Evaluation Harness
# Generated by utils.py
dataset_name: iwslt2017-en-ar
dataset_path: iwslt2017
doc_to_target: ' {{translation["en"]}}'
doc_to_text: 'Arabic phrase: {{translation["ar"]}}
English phrase:'
group:
- greedy_until
- translation
- iwslt2017
include: wmt_common_yaml
task: iwslt2017-ar-en
# Generated by utils.py
dataset_name: iwslt2017-en-ar
dataset_path: iwslt2017
doc_to_target: ' {{translation["ar"]}}'
doc_to_text: 'English phrase: {{translation["en"]}}
Arabic phrase:'
group:
- greedy_until
- translation
- iwslt2017
include: wmt_common_yaml
task: iwslt2017-en-ar
import argparse
from typing import Dict, List
import yaml
import sacrebleu
try:
import pycountry
except ModuleNotFoundError:
raise Exception(
"`pycountry` is required for generating translation task prompt templates. \
please install pycountry via pip install lm-eval[multilingua] or pip install -e .[multilingual]",
)
# Different translation benchmarks included in the library. Mostly WMT.
# These correspond to dataset names (subsets) on HuggingFace for each dataset.
# A yaml file is generated by this script for each language pair.
gpt3_translation_benchmarks = {
"wmt14": ["fr-en"], # ["en-fr", "fr-en"], # French
"wmt16": [
"ro-en",
"de-en",
], # ["en-ro", "ro-en", "de-en", "en-de"], # German, Romanian
}
# 28 total
LANGUAGES = {
**gpt3_translation_benchmarks,
# "wmt20": sacrebleu.get_langpairs_for_testset("wmt20"),
"iwslt2017": ["en-ar"], # Arabic
}
def code_to_language(code):
# key is alpha_2 or alpha_3 depending on the code length
language_tuple = pycountry.languages.get(**{f"alpha_{len(code)}": code})
return language_tuple.name
def gen_lang_yamls(output_dir: str, overwrite: bool) -> None:
"""
Generate a yaml file for each language.
:param output_dir: The directory to output the files to.
:param overwrite: Whether to overwrite files if they already exist.
"""
err = []
for lang in LANGUAGES.keys():
for dataset_name in LANGUAGES[lang]:
src_lang, _, tgt_lang = dataset_name.partition("-")
for src, tgt in [[src_lang, tgt_lang], [tgt_lang, src_lang]]:
# both translation directions for each lang pair
lang_pair = src + "-" + tgt
file_name = f"{lang}_{lang_pair}.yaml"
try:
source, target = code_to_language(src), code_to_language(tgt)
groups = ["greedy_until", "translation", lang]
if lang in gpt3_translation_benchmarks.keys():
groups += ["gpt3_translation_benchmarks"]
with open(
f"{output_dir}/{file_name}",
"w" if overwrite else "x",
encoding="utf8",
) as f:
f.write("# Generated by utils.py\n")
yaml.dump(
{
"include": "wmt_common_yaml",
"group": groups,
"dataset_path": lang,
"dataset_name": dataset_name
if not (lang == "iwslt2017")
else "iwslt2017-" + dataset_name,
"task": f"{lang}-{lang_pair}",
"doc_to_text": f"{source} phrase: "
+ "{{translation["
+ f'"{src}"'
+ "]}}\n"
+ f"{target} phrase:",
"doc_to_target": " {{"
+ "translation["
+ f'"{tgt}"]'
+ "}}",
},
f,
)
except FileExistsError:
err.append(file_name)
if len(err) > 0:
raise FileExistsError(
"Files were not created because they already exist (use --overwrite flag):"
f" {', '.join(err)}"
)
def main() -> None:
"""Parse CLI args and generate language-specific yaml files."""
parser = argparse.ArgumentParser()
parser.add_argument(
"--overwrite",
default=False,
action="store_true",
help="Overwrite files if they already exist",
)
parser.add_argument(
"--output-dir", default=".", help="Directory to write yaml files to"
)
args = parser.parse_args()
gen_lang_yamls(output_dir=args.output_dir, overwrite=args.overwrite)
if __name__ == "__main__":
main()
# Generated by utils.py
dataset_name: fr-en
dataset_path: wmt14
doc_to_target: ' {{translation["fr"]}}'
doc_to_text: 'English phrase: {{translation["en"]}}
French phrase:'
group:
- greedy_until
- translation
- wmt14
- gpt3_translation_benchmarks
include: wmt_common_yaml
task: wmt14-en-fr
# Generated by utils.py
dataset_name: fr-en
dataset_path: wmt14
doc_to_target: ' {{translation["en"]}}'
doc_to_text: 'French phrase: {{translation["fr"]}}
English phrase:'
group:
- greedy_until
- translation
- wmt14
- gpt3_translation_benchmarks
include: wmt_common_yaml
task: wmt14-fr-en
# Generated by utils.py
dataset_name: de-en
dataset_path: wmt16
doc_to_target: ' {{translation["en"]}}'
doc_to_text: 'German phrase: {{translation["de"]}}
English phrase:'
group:
- greedy_until
- translation
- wmt16
- gpt3_translation_benchmarks
include: wmt_common_yaml
task: wmt16-de-en
# Generated by utils.py
dataset_name: de-en
dataset_path: wmt16
doc_to_target: ' {{translation["de"]}}'
doc_to_text: 'English phrase: {{translation["en"]}}
German phrase:'
group:
- greedy_until
- translation
- wmt16
- gpt3_translation_benchmarks
include: wmt_common_yaml
task: wmt16-en-de
# Generated by utils.py
dataset_name: ro-en
dataset_path: wmt16
doc_to_target: ' {{translation["ro"]}}'
doc_to_text: 'English phrase: {{translation["en"]}}
Romanian phrase:'
group:
- greedy_until
- translation
- wmt16
- gpt3_translation_benchmarks
include: wmt_common_yaml
task: wmt16-en-ro
# Generated by utils.py
dataset_name: ro-en
dataset_path: wmt16
doc_to_target: ' {{translation["en"]}}'
doc_to_text: 'Romanian phrase: {{translation["ro"]}}
English phrase:'
group:
- greedy_until
- translation
- wmt16
- gpt3_translation_benchmarks
include: wmt_common_yaml
task: wmt16-ro-en
output_type: greedy_until
training_split: train
validation_split: validation
fewshot_split: validation
test_split: test
metric_list:
- metric: bleu
- metric: ter
- metric: chrf
generation_kwargs:
until:
- "\n"
do_sample: false
temperature: 0.0
repeats: 1
......@@ -9,7 +9,7 @@ from pathlib import Path
from lm_eval import evaluator, utils
from lm_eval.api.registry import ALL_TASKS
from lm_eval.logger import eval_logger
from lm_eval.logger import eval_logger, SPACING
from lm_eval.tasks import include_task_folder
from lm_eval.benchmarks import include_benchmarks
......@@ -17,16 +17,18 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument("--model", required=True, help="Name of model e.g. `hf`")
parser.add_argument(
"--tasks",
default=None,
help="Available Tasks:\n - {}".format("\n - ".join(sorted(ALL_TASKS))),
)
parser.add_argument(
"--model_args",
default="",
help="String arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32`",
)
parser.add_argument(
"--tasks", default=None # , choices=utils.MultiChoice(sorted(ALL_TASKS))
)
parser.add_argument(
"--num_fewshot",
type=int,
......@@ -126,10 +128,21 @@ def main() -> None:
else:
tasks_list = args.tasks.split(",")
task_names = utils.pattern_match(tasks_list, ALL_TASKS)
task_missing = []
for task in [task for task in tasks_list if task not in task_names]:
if os.path.isfile(task):
config = utils.load_yaml_config(task)
task_names.append(config)
else:
task_missing.append(task)
if task_missing != []:
missing = ", ".join(task_missing)
eval_logger.error(
f"Tasks were not found: {missing}\n"
f"{SPACING}Try `lm-eval -h` for list of available tasks",
)
raise ValueError(f"Tasks {missing} were not found.")
if args.output_path:
path = Path(args.output_path)
......
......@@ -15,7 +15,7 @@ extras_require = {
],
"testing": ["pytest", "pytest-cov", "pytest-xdist"],
"multilingual": ["nagisa>=0.2.7", "jieba>=0.42.1"],
"sentencepiece": ["sentencepiece>=0.1.98", "protobuf>=4.22.1"],
"sentencepiece": ["sentencepiece>=0.1.98", "protobuf>=4.22.1", "pycountry"],
"promptsource": [
"promptsource @ git+https://github.com/bigscience-workshop/promptsource.git#egg=promptsource"
],
......@@ -62,10 +62,9 @@ setuptools.setup(
"omegaconf>=2.2",
"peft>=0.2.0",
"pybind11>=2.6.2",
"pycountry",
"pytablewriter",
"rouge-score>=0.0.4",
"sacrebleu==1.5.0",
"sacrebleu>=1.5.0",
"scikit-learn>=0.24.1",
"sqlitedict",
"torch>=1.8",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment