Unverified Commit e86cece6 authored by Shivansh Pachnanda's avatar Shivansh Pachnanda Committed by GitHub
Browse files

Add MLQA (#2622)

* Add MLQA
* add mlqa_common_yaml

* add 49 tests of mlqa family

* update tasks/README.md

---------

* fix: mlqa ast error

* nit: removed .yaml ext from template_yaml

* nit changes: minor modifications generate_tasks.py

* deleted lm_eval/tasks/mlqa/mlqa_common_yaml.yaml

* tests updated

* nit
parent 5db23e2c
# NOTE(review): the thirteen five-line stanzas below are separate generated
# YAML task files (mlqa_vi_* and mlqa_zh_*) that were concatenated by the
# page scrape; in the repository each stanza lives in its own file. Each one
# extends mlqa_common_yaml, names the task, picks the dataset config, and
# binds the language-specific scoring function from utils.py.
# Generated by generate_tasks.py
include: mlqa_common_yaml
task: mlqa_vi_de
dataset_name: mlqa.vi.de
process_results: !function utils.process_results_vi
# Generated by generate_tasks.py
include: mlqa_common_yaml
task: mlqa_vi_en
dataset_name: mlqa.vi.en
process_results: !function utils.process_results_vi
# Generated by generate_tasks.py
include: mlqa_common_yaml
task: mlqa_vi_es
dataset_name: mlqa.vi.es
process_results: !function utils.process_results_vi
# Generated by generate_tasks.py
include: mlqa_common_yaml
task: mlqa_vi_hi
dataset_name: mlqa.vi.hi
process_results: !function utils.process_results_vi
# Generated by generate_tasks.py
include: mlqa_common_yaml
task: mlqa_vi_vi
dataset_name: mlqa.vi.vi
process_results: !function utils.process_results_vi
# Generated by generate_tasks.py
include: mlqa_common_yaml
task: mlqa_vi_zh
dataset_name: mlqa.vi.zh
process_results: !function utils.process_results_vi
# Generated by generate_tasks.py
include: mlqa_common_yaml
task: mlqa_zh_ar
dataset_name: mlqa.zh.ar
process_results: !function utils.process_results_zh
# Generated by generate_tasks.py
include: mlqa_common_yaml
task: mlqa_zh_de
dataset_name: mlqa.zh.de
process_results: !function utils.process_results_zh
# Generated by generate_tasks.py
include: mlqa_common_yaml
task: mlqa_zh_en
dataset_name: mlqa.zh.en
process_results: !function utils.process_results_zh
# Generated by generate_tasks.py
include: mlqa_common_yaml
task: mlqa_zh_es
dataset_name: mlqa.zh.es
process_results: !function utils.process_results_zh
# Generated by generate_tasks.py
include: mlqa_common_yaml
task: mlqa_zh_hi
dataset_name: mlqa.zh.hi
process_results: !function utils.process_results_zh
# Generated by generate_tasks.py
include: mlqa_common_yaml
task: mlqa_zh_vi
dataset_name: mlqa.zh.vi
process_results: !function utils.process_results_zh
# Generated by generate_tasks.py
include: mlqa_common_yaml
task: mlqa_zh_zh
dataset_name: mlqa.zh.zh
process_results: !function utils.process_results_zh
"""
Code based on Official evaluation script for the MLQA dataset.
Repo: https://github.com/facebookresearch/MLQA/blob/main/mlqa_evaluation_v1.py
"""
import re
import string
import sys
import unicodedata
from collections import Counter
import datasets
# Every Unicode code point whose category starts with "P" (punctuation),
# unioned with ASCII string.punctuation. Built once at import time; scanning
# all sys.maxunicode code points is slow, but mirrors the official MLQA
# evaluation script this module is based on.
PUNCT = {
chr(i)
for i in range(sys.maxunicode)
if unicodedata.category(chr(i)).startswith("P")
}.union(string.punctuation)
# Languages whose text is tokenized purely on whitespace, vs. languages that
# need mixed segmentation (one token per CJK character).
WHITESPACE_LANGS = ["en", "es", "hi", "vi", "de", "ar"]
MIXED_SEGMENTATION_LANGS = ["zh"]
def whitespace_tokenize(text):
    """Split *text* on runs of whitespace and return the tokens."""
    tokens = text.split()
    return tokens
def mixed_segmentation(text):
    """Tokenize *text* for mixed-script (Chinese) input.

    Each CJK character or punctuation mark becomes its own token; runs of
    any other characters are accumulated and split on whitespace.
    """
    tokens = []
    pending = ""
    for ch in text:
        is_cjk = re.search(r"[\u4e00-\u9fa5]", ch) is not None
        if is_cjk or ch in PUNCT:
            # Flush the accumulated non-CJK run before emitting this char.
            if pending:
                tokens.extend(whitespace_tokenize(pending))
                pending = ""
            tokens.append(ch)
        else:
            pending += ch
    # Flush any trailing non-CJK run.
    if pending:
        tokens.extend(whitespace_tokenize(pending))
    return tokens
def normalize_answer(s, lang):
    """Lower text and remove punctuation, articles and extra whitespace."""
    # Per-language article patterns, verbatim from the official MLQA script.
    article_patterns = {
        "en": r"\b(a|an|the)\b",
        "es": r"\b(un|una|unos|unas|el|la|los|las)\b",
        "vi": r"\b(của|là|cái|chiếc|những)\b",
        "de": r"\b(ein|eine|einen|einem|eines|einer|der|die|das|den|dem|des)\b",
        "ar": r"\sال^|ال",
    }

    def strip_articles(text):
        # Hindi and Chinese have no formal articles.
        if lang in ("hi", "zh"):
            return text
        if lang in article_patterns:
            return re.sub(article_patterns[lang], " ", text)
        raise Exception("Unknown Language {}".format(lang))

    def collapse_whitespace(text):
        if lang in WHITESPACE_LANGS:
            parts = whitespace_tokenize(text)
        elif lang in MIXED_SEGMENTATION_LANGS:
            parts = mixed_segmentation(text)
        else:
            raise Exception("Unknown Language {}".format(lang))
        return " ".join(p for p in parts if p.strip() != "")

    # Same pipeline as the official script: lowercase, drop punctuation,
    # strip articles, then re-join on normalized whitespace.
    no_punct = "".join(ch for ch in s.lower() if ch not in PUNCT)
    return collapse_whitespace(strip_articles(no_punct))
def f1_score(prediction, ground_truth, lang):
    """Token-level F1 between *prediction* and one *ground_truth* answer.

    Both strings are normalized for *lang* first; returns 0 when there is
    no token overlap.
    """
    pred_tokens = normalize_answer(prediction, lang).split()
    gold_tokens = normalize_answer(ground_truth, lang).split()
    # Multiset intersection counts each shared token at most min(freqs) times.
    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if overlap == 0:
        return 0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return (2 * precision * recall) / (precision + recall)
def exact_match_score(prediction, ground_truth, lang):
    """Return True iff prediction and ground truth are identical after
    language-aware normalization."""
    normalized_pred = normalize_answer(prediction, lang)
    normalized_gold = normalize_answer(ground_truth, lang)
    return normalized_pred == normalized_gold
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths, lang):
    """Return the best *metric_fn* score of *prediction* against any of the
    reference answers in *ground_truths*."""
    return max(metric_fn(prediction, gt, lang) for gt in ground_truths)
def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
    """Project each MLQA record down to the fields the task consumes:
    context, question, and the list of gold answer texts."""

    def _to_qa_fields(doc):
        return {
            "context": doc["context"],
            "question": doc["question"],
            # Keep only the answer strings, dropping the char offsets.
            "answers": doc["answers"]["text"],
        }

    return dataset.map(_to_qa_fields)
# Base function
def process_results_lang(doc, results, lang):
    """Score the first model completion against all gold answers for *lang*.

    Returns a dict with the best exact-match and F1 over the references,
    keyed by the metric names declared in the task YAML.
    """
    prediction = results[0].strip()
    references = doc["answers"]
    em = metric_max_over_ground_truths(
        exact_match_score, prediction, references, lang
    )
    best_f1 = metric_max_over_ground_truths(f1_score, prediction, references, lang)
    return {"exact_match": em, "f1": best_f1}
# Language Wrapper functions
# Thin per-language adapters: the generated YAML task configs bind one of
# these via `process_results: !function utils.process_results_<lang>`, so
# each needs the standard (doc, results) signature with the language fixed.
def process_results_en(doc, results):
    return process_results_lang(doc, results, "en")
def process_results_es(doc, results):
    return process_results_lang(doc, results, "es")
def process_results_hi(doc, results):
    return process_results_lang(doc, results, "hi")
def process_results_vi(doc, results):
    return process_results_lang(doc, results, "vi")
def process_results_de(doc, results):
    return process_results_lang(doc, results, "de")
def process_results_ar(doc, results):
    return process_results_lang(doc, results, "ar")
def process_results_zh(doc, results):
    return process_results_lang(doc, results, "zh")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment