Unverified commit 7ad6bf45 authored by Leo Gao, committed by GitHub

Merge pull request #146 from EleutherAI/translation

Translation
parents 0601c909 ac47d481
env
*.pyc
data/
.idea
lm_cache
\ No newline at end of file
import abc
import random
import numpy as np
import sklearn
-import math
+from lm_eval.metrics import mean
class LM(abc.ABC):
@@ -30,6 +30,7 @@ class LM(abc.ABC):
"""
pass
# TODO: Add an optional max length
@abc.abstractmethod
def greedy_until(self, requests):
"""Generate greedily until a stopping sequence
@@ -61,6 +62,14 @@ class LM(abc.ABC):
class Task(abc.ABC):
"""A task represents an entire benchmark including its dataset, problems,
answers, and evaluation methods. See BoolQ for a simple example implementation
A `doc` can be any python object which represents one instance of evaluation.
This is usually a dictionary e.g.
{"question": ..., "answer": ...} or
{"question": ..., question, answer)
"""
def __init__(self):
self.download()
self._training_docs = None
@@ -148,9 +157,9 @@ class Task(abc.ABC):
@abc.abstractmethod
def aggregation(self):
"""
-:returns: {str: [float] -> float}
+:returns: {str: [metric_score] -> float}
A dictionary where keys are the names of submetrics and values are
-functions that aggregate a list of metrics
+functions that aggregate a list of metric scores
"""
pass
@@ -213,60 +222,6 @@ class MultipleChoiceTask(Task):
}
def mean(arr):
return sum(arr) / len(arr)
def median(arr):
return arr[len(arr) // 2]
def matthews_corrcoef(items):
unzipped_list = list(zip(*items))
golds = unzipped_list[0]
preds = unzipped_list[1]
return sklearn.metrics.matthews_corrcoef(golds, preds)
def f1_score(items):
unzipped_list = list(zip(*items))
golds = unzipped_list[0]
preds = unzipped_list[1]
fscore = sklearn.metrics.f1_score(golds, preds)
return np.max(fscore)
def acc_all(items):
# Only count as correct if all answers are labeled correctly for each question
question_scoring_dict = {}
preds = list(zip(*items))[0]
docs = list(zip(*items))[1]
for doc, pred in zip(docs, preds):
question_id = doc["idx"]["question"]
if question_id not in question_scoring_dict:
question_scoring_dict[question_id] = []
gold_label = doc["label"] == 1
question_scoring_dict[question_id].append(gold_label == pred)
acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
return acc
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
"""Compute max metric between prediction and each ground truth."""
scores_for_ground_truths = []
for ground_truth in ground_truths:
score = metric_fn(prediction, ground_truth)
scores_for_ground_truths.append(score)
return max(scores_for_ground_truths)
def perplexity(items):
return math.exp(-mean(items))
req_ret_lens = {
'loglikelihood': 2,
'greedy_until': None,
......
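The docstring changes above pin down the contract between a Task and the evaluation loop: each doc is an arbitrary per-example object (usually a dict), process_results emits one {submetric: score} dict per doc, and aggregation() maps every submetric name to a function that reduces the collected list of scores to a single float. A minimal sketch of that flow, assuming only the mean helper from the new lm_eval.metrics module shown next (the sample docs and scores are invented for illustration):

from lm_eval.metrics import mean

# Two toy docs in the usual dictionary form described in the Task docstring.
docs = [
    {"question": "2 + 2 = ?", "answer": "4"},
    {"question": "3 + 3 = ?", "answer": "6"},
]
# What process_results might emit for each doc: one score per submetric.
per_doc_results = [{"acc": 1.0}, {"acc": 0.0}]
# What aggregation() returns: submetric name -> reducer over the collected list.
aggregation = {"acc": mean}
final = {name: agg([r[name] for r in per_doc_results]) for name, agg in aggregation.items()}
print(final)  # {'acc': 0.5}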
import math
from pprint import pprint
import numpy as np
import sacrebleu
import sklearn
def mean(arr):
return sum(arr) / len(arr)
def median(arr):
return arr[len(arr) // 2]
def matthews_corrcoef(items):
unzipped_list = list(zip(*items))
golds = unzipped_list[0]
preds = unzipped_list[1]
return sklearn.metrics.matthews_corrcoef(golds, preds)
def f1_score(items):
unzipped_list = list(zip(*items))
golds = unzipped_list[0]
preds = unzipped_list[1]
fscore = sklearn.metrics.f1_score(golds, preds)
return np.max(fscore)
def acc_all(items):
# Only count as correct if all answers are labeled correctly for each question
question_scoring_dict = {}
preds = list(zip(*items))[0]
docs = list(zip(*items))[1]
for doc, pred in zip(docs, preds):
question_id = doc["idx"]["question"]
if question_id not in question_scoring_dict:
question_scoring_dict[question_id] = []
gold_label = doc["label"] == 1
question_scoring_dict[question_id].append(gold_label == pred)
acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
return acc
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
"""Compute max metric between prediction and each ground truth."""
scores_for_ground_truths = []
for ground_truth in ground_truths:
score = metric_fn(prediction, ground_truth)
scores_for_ground_truths.append(score)
return max(scores_for_ground_truths)
def perplexity(items):
return math.exp(-mean(items))
def bleu(items):
"""The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
for evaluating a generated sentence against a reference sentence. It counts matching
n-grams in the candidate translation against n-grams in the reference text, where
a 1-gram or unigram would be each token and a bigram comparison would be each
word pair. The comparison is made regardless of word order.
Source: https://machinelearningmastery.com/calculate-bleu-score-for-text-python/
Paper: https://www.aclweb.org/anthology/P02-1040/
Higher is better
"""
refs = list(zip(*items))[0]
preds = list(zip(*items))[1]
refs, preds = _sacreformat(refs, preds)
return sacrebleu.corpus_bleu(preds, refs).score
def chrf(items):
"""chrF++ is a tool for automatic evaluation of machine translation output
based on character n-gram precision and recall enhanced with word n-grams.
Source: https://github.com/m-popovic/chrF
Paper: https://www.aclweb.org/anthology/W15-3049.pdf
Higher is better # TODO I think
"""
refs = list(zip(*items))[0]
preds = list(zip(*items))[1]
refs, preds = _sacreformat(refs, preds)
return sacrebleu.corpus_chrf(preds, refs).score
def ter(items):
"""Translation Error Rate is an error metric for machine translation that
measures the number of edits required to change a system output into one
of the references
Source: http://www.cs.umd.edu/~snover/tercom/
Paper: http://mt-archive.info/AMTA-2006-Snover.pdf
Lower is better
"""
refs = list(zip(*items))[0]
preds = list(zip(*items))[1]
refs, preds = _sacreformat(refs, preds)
return sacrebleu.corpus_ter(preds, refs).score
def _sacreformat(refs, preds):
"""Format refs and preds for sacrebleu corpus calculation. It is very particular"""
# Sacrebleu expects (List[str], List[List[str])
# e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...])
# Note [ref1_stream] is the first reference for each pred.
# So lists are size N and (M, N) for N preds and M possible refs for each pred
# This is a different order of dimensions than I would expect
# We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds
# Must become List[List[str]] with the inner list corresponding to preds
if not isinstance(refs, list):
refs = list(refs)
if not isinstance(refs[0], list):
refs = [[ref] for ref in refs]
refs = list(zip(*refs))
# Note the number of refs in each ref list must match the number of preds
# We expect preds to be List[str] or List[List[str]]. Must become List[str]
if not isinstance(preds, list):
preds = list(preds)
if isinstance(preds[0], list):
assert len(preds[0]) == 1, f"Pred must be a str, was {preds[0]}"
preds = [pred[0] for pred in preds]
return refs, preds
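A hedged usage sketch for the corpus-level translation metrics above (not part of the commit; the sentence pairs are made up). Each item is a (ref, pred) tuple collected per document; bleu, chrf and ter unzip the pairs, reshape them with _sacreformat into sacrebleu's (predictions, reference streams) layout, and return a single corpus score:

from lm_eval import metrics

items = [
    ("The cat sits on the mat.", "The cat sits on the mat."),  # (reference, prediction)
    ("He bought three red apples.", "He bought apples."),
]
print(metrics.bleu(items))  # corpus BLEU, higher is better
print(metrics.chrf(items))  # chrF, higher is better
print(metrics.ter(items))   # TER, lower is better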
from pprint import pprint
from . import superglue
from . import glue
from . import arc
@@ -21,6 +23,7 @@ from . import pubmedqa
from . import sciq
from . import webqs
from . import qa4mre
from . import translation
from . import headqa
from . import mathqa
@@ -88,6 +91,11 @@ TASK_REGISTRY = {
"arithmetic_2dm": arithmetic.Arithmetic2DMultiplication,
"arithmetic_1dc": arithmetic.Arithmetic1DComposite,
# TODO Perhaps make these groups of tasks
# e.g. anli, arithmetic, openai_translations, harness_translations
# e.g. wmt14-fr-en
**translation.create_tasks_from_benchmarks(translation.selected_benchmarks)
}
@@ -95,7 +103,12 @@ ALL_TASKS = sorted(list(TASK_REGISTRY))
def get_task(task_name):
-return TASK_REGISTRY[task_name]
+try:
return TASK_REGISTRY[task_name]
except KeyError as e:
print("Available tasks:")
pprint(TASK_REGISTRY)
raise KeyError(f"Missing task {task_name}")
def get_task_dict(task_name_list):
......
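With the registry expanded this way, every selected benchmark and language pair becomes an ordinary task name of the form "<dataset>-<langpair>". A small sketch of how that is meant to be used (the unknown-name lookup is only there to illustrate the new error path):

from lm_eval import tasks

task_class = tasks.get_task("wmt14-fr-en")   # registered by create_tasks_from_benchmarks
task = task_class()                          # downloads the test set via sacrebleu on init
print(task)                                  # e.g. "WMT14 French to English Task"
# tasks.get_task("wmt14-xx-yy")              # unknown name: prints the registry, then raises KeyError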
import numpy as np
-from lm_eval.base import rf, mean
+from lm_eval.base import rf
from ..metrics import mean
from . common import HFTask
class ANLIBase(HFTask):
......
import numpy as np
from lm_eval.base import MultipleChoiceTask
-from .common import HFTask
+from ..metrics import mean
from . common import HFTask
class ARCEasy(HFTask, MultipleChoiceTask):
......
@@ -2,7 +2,8 @@ import abc
import json
import os
from collections import namedtuple
-from lm_eval.base import Task, mean, rf
+from lm_eval.base import Task, rf
from lm_eval.metrics import mean
from best_download import download_file
ArithmeticDoc = namedtuple('ArithmeticDoc', ['context', 'completion'])
......
import datasets
import numpy as np
import lm_eval.metrics
from ..base import Task
@@ -44,7 +46,7 @@ class HFTask(Task):
def simple_accuracy_metric(preds, golds):
-acc = float((np.array(preds) == np.array(golds)).mean())
+acc = float(lm_eval.metrics.mean())
return {
"major": acc,
"minor": {"acc": acc},
......
import numpy as np
-from lm_eval.base import rf, mean, f1_score, matthews_corrcoef
+from lm_eval.base import rf
from ..metrics import mean, matthews_corrcoef, f1_score
from scipy.stats import pearsonr, spearmanr
from tqdm import auto as tqdm_lib
from . common import HFTask, yesno
......
-from lm_eval.base import Task, rf, mean, perplexity
+from lm_eval.base import Task, rf
from lm_eval.metrics import mean, perplexity
from lm_eval.utils import sh
import json
import math
......
import numpy as np
-from lm_eval.base import rf, mean
+from lm_eval.base import rf
from ..metrics import mean
from . common import HFTask
......
@@ -2,7 +2,8 @@ import numpy as np
import json
import random
from .common import HFTask
-from lm_eval.base import rf, mean
+from lm_eval.base import rf
from ..metrics import mean
class Pubmed_QA(HFTask):
......
import os
import numpy as np
from best_download import download_file
-from lm_eval.base import MultipleChoiceTask, rf, mean
+from lm_eval.base import MultipleChoiceTask, rf
from lm_eval.metrics import mean
import xml.etree.ElementTree as ET
import random
......
import collections
import datasets
import numpy as np
-from lm_eval.base import rf, mean
+from lm_eval.base import rf
from ..metrics import mean
from . common import HFTask
import os
......
import json
import random
import os
-from lm_eval.base import MultipleChoiceTask, rf, mean
+from lm_eval.base import MultipleChoiceTask, rf
from ..metrics import mean
from tqdm import auto as tqdm_lib
from . common import simple_accuracy_metric
import numpy as np
......
import os
import json
from ..utils import sh
-from lm_eval.base import MultipleChoiceTask, rf, mean
+from lm_eval.base import MultipleChoiceTask, rf
from ..metrics import mean
import zipfile
from best_download import download_file
......
@@ -5,7 +5,8 @@ To-do:
"""
import numpy as np
from . common import HFTask, yesno
-from lm_eval.base import rf, mean, acc_all, metric_max_over_ground_truths
+from lm_eval.base import rf
from ..metrics import mean, acc_all, metric_max_over_ground_truths
import sklearn
import transformers.data.metrics.squad_metrics as squad_metrics
from ..utils import general_detokenize
......
import abc
import json
import random
import os
from pprint import pprint
import pycountry
from sacrebleu import sacrebleu
import logging
from lm_eval import metrics
from lm_eval.base import Task, rf
"""
This file implements translation tasks using datasets from WMT conferences, provided by sacrebleu.
Traditionally they are evaluated with BLEU scores. TER and CHRF are other options.
See sacrebleu.DATASETS for all available datasets. There are a lot!
"""
sacrebleu_datasets = sacrebleu.DATASETS
########################################
# Benchmarks one might want to run
########################################
# 6 total
gpt3_benchmarks = {
"wmt14": ['en-fr', 'fr-en'], # French
"wmt16": ['en-ro', 'ro-en', 'de-en', 'en-de'], # German, Romanian
}
# 14 total
selected_benchmarks = {
**gpt3_benchmarks,
"wmt20": ['fr-de', 'de-fr', 'en-ru', 'ru-en', 'en-iu', 'iu-en'], # French, German, Russian, Inuit
"iwslt17": ['en-ar', 'ar-en'] # Arabic
}
# 319 total
all_benchmarks = {
ts: sacrebleu.get_langpairs_for_testset(ts)
for ts in sacrebleu.get_available_testsets()
}
available_tests = {
"gpt3_tests": gpt3_benchmarks,
"selected_tests": selected_benchmarks,
"all_tests": all_benchmarks
}
def create_tasks_from_benchmarks(benchmark_dict):
"""Creates a dictionary of tasks from a dict
:param benchmark_dict: { dataset: [lang_pair, ...] }
:return: {task_name: task}
e.g. {wmt14-fr-en: Task, wmt16-de-en: Task}
"""
return {
f"{dataset}-{language_pair}": create_translation_task(dataset, language_pair)
for dataset, language_pairs in benchmark_dict.items()
for language_pair in language_pairs
}
########################################
# Tasks
########################################
def create_translation_task(dataset, language_pair):
class TranslationTask(GeneralTranslationTask):
def __init__(self):
super().__init__(dataset, language_pair)
return TranslationTask
class GeneralTranslationTask(Task):
# e.g. ("wmt14", "fr-en")
def __init__(self, sacrebleu_dataset, sacrebleu_language_pair=None):
self.sacrebleu_dataset = sacrebleu_dataset
self.sacrebleu_language_pair = sacrebleu_language_pair
self.src_file = self.ref_file = self.src_data = self.ref_data = None
super().__init__()
def download(self):
# This caches in the user's home dir automatically
self.src_file, self.ref_file = \
sacrebleu.download_test_set(self.sacrebleu_dataset, self.sacrebleu_language_pair)
self.src_data, self.ref_data = [
[line.rstrip() for line in sacrebleu.smart_open(file)]
for file in (self.src_file, self.ref_file)
]
def has_training_docs(self):
"""Whether the task has a training set"""
# TODO In the future we could be more discerning. Some more recent tests have train and dev sets
return False
def has_validation_docs(self):
"""Whether the task has a validation set"""
return False
def has_test_docs(self):
"""Whether the task has a test set"""
return True
def test_docs(self):
"""
:return: Iterable[obj]
An iterable of any object that doc_to_text can handle
"""
return [{
"src": src,
"ref": ref
} for src, ref in zip(self.src_data, self.ref_data)]
def doc_to_text(self, doc):
return doc["src"]
def doc_to_target(self, doc):
# TODO Note that some exotic tests have multiple ref lines.
# How does sacrebleu handle opening these files?
return doc["ref"]
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
return rf.greedy_until(ctx, ["\n"])
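# Each request generates greedily until the first newline, so every doc yields a
# single candidate translation line; that string comes back as `results` in
# process_results below, where it is paired with the reference for corpus scoring.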
def process_results(self, doc, results):
# These metrics are corpus-level, not sentence-level, so we'll hide the
# results in this dict and compute the corpus score in the aggregate method
ref_pred = (doc["ref"], results)
return {
"bleu": ref_pred,
"chrf": ref_pred,
"ter": ref_pred,
}
def aggregation(self):
"""
:returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return {
"bleu": metrics.bleu,
"chrf": metrics.chrf,
"ter": metrics.ter,
}
def higher_is_better(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return {
"bleu": True,
"chrf": True,
"ter": False,
}
def fewshot_description(self):
language_codes = self.sacrebleu_language_pair.split("-")
src_lang = code_to_language(language_codes[0])
tar_lang = code_to_language(language_codes[1])
return f"Translate these {src_lang} phrases to {tar_lang}."
# TODO This should be something like
# French: {src_line}
# English: {ref_line}
def fewshot_context(self, doc, num_fewshot, provide_description):
return ""
def __str__(self):
language_codes = self.sacrebleu_language_pair.split("-")
src_lang = code_to_language(language_codes[0])
tar_lang = code_to_language(language_codes[1])
return f"{self.sacrebleu_dataset.upper()} {src_lang} to {tar_lang} Task"
########################################
# Util
########################################
def code_to_language(code):
# key is alpha_2 or alpha_3 depending on the code length
language_tuple = pycountry.languages.get(**{f"alpha_{len(code)}": code})
return language_tuple.name
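# e.g. code_to_language("fr") calls pycountry.languages.get(alpha_2="fr") and returns
# "French"; a three-letter code such as "deu" would be looked up via alpha_3 instead.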
def print_available_tests():
pprint({ts: sacrebleu.get_langpairs_for_testset(ts) for ts in sacrebleu.get_available_testsets()})
def main():
# print(sacrebleu.download_test_set("wmt14", "en-fr"))
# print_available_tests()
# sacrebleu.print_test_set("wmt14", "fr-en", "src")
# # Print number of benchmarks
# print(sum([
# len(sacrebleu.get_langpairs_for_testset(ts))
# for ts in sacrebleu.get_available_testsets()
# ]))
# Test task dictionary
# for task, task_class in create_tasks_from_benchmarks(selected_benchmarks).items():
# print(task, task_class())
pass
if __name__ == "__main__":
main()
########################################
# Don't mind me...!
########################################
# Available tests as of 2020/02/11
"""
{'iwslt17': ['en-fr',
'fr-en',
'en-de',
'de-en',
'en-zh',
'zh-en',
'en-ar',
'ar-en',
'en-ja',
'ja-en',
'en-ko',
'ko-en'],
'iwslt17/dev2010': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'iwslt17/tst2010': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'iwslt17/tst2011': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'iwslt17/tst2012': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'iwslt17/tst2013': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'iwslt17/tst2014': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'iwslt17/tst2015': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'iwslt17/tst2016': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'mtnt1.1/test': ['en-fr', 'fr-en', 'en-ja', 'ja-en'],
'mtnt1.1/train': ['en-fr', 'fr-en', 'en-ja', 'ja-en'],
'mtnt1.1/valid': ['en-fr', 'fr-en', 'en-ja', 'ja-en'],
'mtnt2019': ['en-fr', 'fr-en', 'en-ja', 'ja-en'],
'multi30k/2016': ['en-fr', 'en-de', 'en-cs'],
'multi30k/2017': ['en-fr', 'en-de'],
'multi30k/2018': ['en-fr', 'en-de'],
'wmt08': ['cs-en',
'en-cs',
'de-en',
'en-de',
'es-en',
'en-es',
'fr-en',
'en-fr',
'hu-en',
'en-hu'],
'wmt08/europarl': ['de-en', 'en-de', 'es-en', 'en-es', 'fr-en', 'en-fr'],
'wmt08/nc': ['cs-en', 'en-cs'],
'wmt09': ['cs-en',
'en-cs',
'de-en',
'en-de',
'es-en',
'en-es',
'fr-en',
'en-fr',
'hu-en',
'en-hu',
'it-en',
'en-it'],
'wmt10': ['cs-en',
'en-cs',
'de-en',
'en-de',
'es-en',
'en-es',
'fr-en',
'en-fr'],
'wmt11': ['cs-en',
'en-cs',
'de-en',
'en-de',
'fr-en',
'en-fr',
'es-en',
'en-es'],
'wmt12': ['cs-en',
'en-cs',
'de-en',
'en-de',
'es-en',
'en-es',
'fr-en',
'en-fr'],
'wmt13': ['cs-en',
'en-cs',
'de-en',
'en-de',
'es-en',
'en-es',
'fr-en',
'en-fr',
'ru-en',
'en-ru'],
'wmt14': ['cs-en',
'en-cs',
'de-en',
'en-de',
'en-fr',
'fr-en',
'en-hi',
'hi-en',
'en-ru',
'ru-en'],
'wmt14/full': ['cs-en',
'en-cs',
'de-en',
'en-de',
'en-fr',
'fr-en',
'en-hi',
'hi-en',
'en-ru',
'ru-en'],
'wmt15': ['en-fr',
'fr-en',
'cs-en',
'de-en',
'en-cs',
'en-de',
'en-fi',
'en-ru',
'fi-en',
'ru-en'],
'wmt16': ['cs-en',
'de-en',
'en-cs',
'en-de',
'en-fi',
'en-ro',
'en-ru',
'en-tr',
'fi-en',
'ro-en',
'ru-en',
'tr-en'],
'wmt16/B': ['en-fi'],
'wmt16/dev': ['en-ro', 'en-tr', 'ro-en', 'tr-en'],
'wmt16/tworefs': ['en-fi'],
'wmt17': ['cs-en',
'de-en',
'en-cs',
'en-de',
'en-fi',
'en-lv',
'en-ru',
'en-tr',
'en-zh',
'fi-en',
'lv-en',
'ru-en',
'tr-en',
'zh-en'],
'wmt17/B': ['en-fi'],
'wmt17/dev': ['en-lv', 'en-zh', 'lv-en', 'zh-en'],
'wmt17/improved': ['en-zh', 'zh-en'],
'wmt17/ms': ['zh-en'],
'wmt17/tworefs': ['en-fi'],
'wmt18': ['cs-en',
'de-en',
'en-cs',
'en-de',
'en-et',
'en-fi',
'en-ru',
'et-en',
'fi-en',
'ru-en',
'en-tr',
'tr-en',
'en-zh',
'zh-en'],
'wmt18/dev': ['et-en', 'en-et'],
'wmt18/test-ts': ['cs-en',
'de-en',
'en-cs',
'en-de',
'en-et',
'en-fi',
'en-ru',
'et-en',
'fi-en',
'ru-en',
'en-tr',
'tr-en',
'en-zh',
'zh-en'],
'wmt19': ['cs-de',
'de-cs',
'de-en',
'de-fr',
'en-cs',
'en-de',
'en-fi',
'en-gu',
'en-kk',
'en-lt',
'en-ru',
'en-zh',
'fi-en',
'fr-de',
'gu-en',
'kk-en',
'lt-en',
'ru-en',
'zh-en'],
'wmt19/dev': ['lt-en', 'en-lt', 'gu-en', 'en-gu', 'kk-en', 'en-kk'],
'wmt19/google/ar': ['en-de'],
'wmt19/google/arp': ['en-de'],
'wmt19/google/hqall': ['en-de'],
'wmt19/google/hqp': ['en-de'],
'wmt19/google/hqr': ['en-de'],
'wmt19/google/wmtp': ['en-de'],
'wmt20': ['cs-en',
'de-en',
'de-fr',
'en-cs',
'en-de',
'en-iu',
'en-ja',
'en-km',
'en-pl',
'en-ps',
'en-ru',
'en-ta',
'en-zh',
'fr-de',
'iu-en',
'ja-en',
'km-en',
'pl-en',
'ps-en',
'ru-en',
'ta-en',
'zh-en'],
'wmt20/dev': ['iu-en',
'en-iu',
'ja-en',
'en-ja',
'pl-en',
'en-pl',
'ta-en',
'en-ta'],
'wmt20/robust/set1': ['en-ja', 'en-de'],
'wmt20/robust/set2': ['en-ja', 'ja-en'],
'wmt20/robust/set3': ['de-en'],
'wmt20/tworefs': ['de-en', 'en-de', 'en-zh', 'ru-en', 'zh-en']}
"""
\ No newline at end of file
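End-to-end, a translation task produced by this file behaves like any other harness task. The sketch below (not part of the commit; the perfect-prediction stand-in is invented) reuses only the calls defined above: build a task class, pull test docs, pair a model output with its reference, and let the corpus-level aggregations score it.

from lm_eval.tasks import translation

TaskClass = translation.create_translation_task("wmt16", "ro-en")
task = TaskClass()                          # downloads wmt16 ro-en via sacrebleu
doc = task.test_docs()[0]                   # {"src": <Romanian line>, "ref": <English line>}
prompt = task.doc_to_text(doc)              # the source sentence handed to the LM
fake_prediction = doc["ref"]                # pretend the model translated perfectly
per_doc = task.process_results(doc, fake_prediction)   # {"bleu": (ref, pred), ...}
aggs = task.aggregation()
print({name: agg([per_doc[name]]) for name, agg in aggs.items()})  # corpus BLEU/chrF/TER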
import os
import json
import random
-from lm_eval.base import Task, mean, rf
+from lm_eval.base import Task, rf
from ..metrics import mean
from ..utils import sh
class TriviaQA(Task):
......
from . common import HFTask
-from lm_eval.base import mean, rf
+from lm_eval.base import rf
from ..metrics import mean
class WebQs(HFTask):
DATASET_PATH = "web_questions"
......