Unverified commit a2cada5d authored by Jonathan Tow, committed by GitHub

Merge pull request #317 from EleutherAI/Mistobaan/add-pre-commit

Add pre-commit
parents 7a038118 83507c4b
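
The reformatting in the hunks below (single quotes normalized to double quotes, long lines wrapped, one-entry dicts collapsed onto a single line) is consistent with running black via a pre-commit hook, and the added "# noqa: W605" marker suggests a flake8 lint hook as well. The actual configuration file added by this PR is in a collapsed diff and is not reproduced here; the following is only a sketch of what such a pre-commit setup typically looks like (repository revisions and the exact hook list are assumptions, not taken from the PR):

# Hypothetical .pre-commit-config.yaml sketch; not the PR's actual file.
repos:
  - repo: https://github.com/psf/black
    rev: 22.3.0          # assumed pin; the PR's real revision is not shown
    hooks:
      - id: black        # reformats Python sources, e.g. normalizing quote style
  - repo: https://github.com/pycqa/flake8
    rev: 4.0.1           # assumed pin
    hooks:
      - id: flake8       # lint warnings such as W605 are silenced inline with "# noqa"

With a config like this, `pre-commit install` registers the hooks to run on every commit, and `pre-commit run --all-files` applies them across the repository, which is what produces sweeping formatting-only diffs like the ones below.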
@@ -23,7 +23,7 @@ _CITATION = """
     booktitle={CLEF},
     year={2013}
 }
-"""
+"""  # noqa: W605


 class QA4MRE(MultipleChoiceTask):
@@ -47,7 +47,7 @@ class QA4MRE(MultipleChoiceTask):
     def _process_doc(self, doc):
         choices = doc["answer_options"]["answer_str"]
         out_doc = {
-            "source": doc["document_str"].strip().replace("\'", "'"),
+            "source": doc["document_str"].strip().replace("'", "'"),
             "query": doc["question_str"],
             "choices": choices,
             "gold": int(doc["correct_answer_id"]) - 1,
@@ -51,23 +51,34 @@ class QuAC(Task):
         raise NotImplementedError("QuAC has no test docs.")

     def _process_doc(self, doc):
-        doc["title"] = doc['title'] + ' - ' + doc['section_title']
+        doc["title"] = doc["title"] + " - " + doc["section_title"]
         return doc

     def doc_to_text(self, doc):
-        return 'TITLE: ' + doc['title'] + '\n' + 'PARAGRAPH: ' + doc['paragraph'] + '\n\n' + 'Q: ' + doc['question'] + '\n\n' + 'A: '
+        return (
+            "TITLE: "
+            + doc["title"]
+            + "\n"
+            + "PARAGRAPH: "
+            + doc["paragraph"]
+            + "\n\n"
+            + "Q: "
+            + doc["question"]
+            + "\n\n"
+            + "A: "
+        )

     def should_decontaminate(self):
         return True

     def doc_to_decontamination_query(self, doc):
-        return doc['paragraph']
+        return doc["paragraph"]

     def doc_to_target(self, doc):
-        return doc['answer']
+        return doc["answer"]

     def construct_requests(self, doc, ctx):
-        """ Uses RequestFactory to construct Requests and returns an iterable of
+        """Uses RequestFactory to construct Requests and returns an iterable of
         Requests which will be sent to the LM.

         :param doc:
@@ -78,7 +89,7 @@ class QuAC(Task):
             part of the document for `doc`.
         """
         # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        raise NotImplementedError("Evaluation not implemented")

     def process_results(self, doc, results):
         """Take a single document and the LM results and evaluates, returning a
@@ -91,7 +102,7 @@ class QuAC(Task):
             The results of the requests created in construct_requests.
         """
         # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        raise NotImplementedError("Evaluation not implemented")

     def aggregation(self):
         """
@@ -100,7 +111,7 @@ class QuAC(Task):
             functions that aggregate a list of metrics
         """
         # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        raise NotImplementedError("Evaluation not implemented")

     def higher_is_better(self):
         """
@@ -109,4 +120,4 @@ class QuAC(Task):
             whether a higher value of the submetric is better
         """
         # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        raise NotImplementedError("Evaluation not implemented")
@@ -40,7 +40,7 @@ class RACE(Task):
     DATASET_NAME = "high"

     cache = {}
-    letter_to_num = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
+    letter_to_num = {"A": 0, "B": 1, "C": 2, "D": 3}

     def has_training_docs(self):
         return True
@@ -59,17 +59,27 @@ class RACE(Task):
         # is shown that one document is made per passage.

         r = collections.defaultdict(list)
-        for item in datasets.load_dataset(path=self.DATASET_PATH, name=self.DATASET_NAME)[set]:
-            r[item['article']].append(item)
+        for item in datasets.load_dataset(
+            path=self.DATASET_PATH, name=self.DATASET_NAME
+        )[set]:
+            r[item["article"]].append(item)

-        res = list(r.values() >> each(lambda x: {
-            'article': x[0]['article'],
-            'problems': x >> each(lambda y: {
-                'question': y['question'],
-                'answer': y['answer'],
-                'options': y['options'],
-            })
-        }))
+        res = list(
+            r.values()
+            >> each(
+                lambda x: {
+                    "article": x[0]["article"],
+                    "problems": x
+                    >> each(
+                        lambda y: {
+                            "question": y["question"],
+                            "answer": y["answer"],
+                            "options": y["options"],
+                        }
+                    ),
+                }
+            )
+        )

         self.cache[set] = res
         return res
@@ -85,36 +95,38 @@ class RACE(Task):
     @classmethod
     def get_answer_option(cls, problem):
-        answer = cls.letter_to_num[problem['answer']]
-        return problem['options'][answer]
+        answer = cls.letter_to_num[problem["answer"]]
+        return problem["options"][answer]

     @classmethod
     def last_problem(cls, doc):
-        return doc['problems'][-1]
+        return doc["problems"][-1]

     def doc_to_text(self, doc):
-        text = 'Article: ' + doc['article'] + '\n\n'
-        for problem in doc['problems'][:-1]:
-            if problem['question'][-6:] == ' _ .':
-                text += problem['question'][-5:] + self.get_answer_option(problem) + '\n'
+        text = "Article: " + doc["article"] + "\n\n"
+        for problem in doc["problems"][:-1]:
+            if problem["question"][-6:] == " _ .":
+                text += (
+                    problem["question"][-5:] + self.get_answer_option(problem) + "\n"
+                )
             else:
-                question = 'Question: ' + problem['question'] + '\n'
-                answer = 'Answer: ' + self.get_answer_option(problem) + '\n'
+                question = "Question: " + problem["question"] + "\n"
+                answer = "Answer: " + self.get_answer_option(problem) + "\n"
                 text += question + answer
-        text += self.last_problem(doc)['question']
+        text += self.last_problem(doc)["question"]
         return text

     def should_decontaminate(self):
         return True

     def doc_to_decontamination_query(self, doc):
-        return doc['article']
+        return doc["article"]

     def doc_to_target(self, doc):
         return " " + self.get_answer_option(self.last_problem(doc))

     def construct_requests(self, doc, ctx):
-        """ Uses RequestFactory to construct Requests and returns an iterable of
+        """Uses RequestFactory to construct Requests and returns an iterable of
         Requests which will be sent to the LM.

         :param doc:
@@ -126,8 +138,7 @@ class RACE(Task):
         """
         problem = self.last_problem(doc)
         ll_choices = [
-            rf.loglikelihood(ctx, " " + problem['options'][i])[0]
-            for i in range(4)
+            rf.loglikelihood(ctx, " " + problem["options"][i])[0] for i in range(4)
         ]
         return ll_choices
@@ -141,11 +152,9 @@ class RACE(Task):
         :param results:
             The results of the requests created in construct_requests.
         """
-        gold = self.letter_to_num[self.last_problem(doc)['answer']]
+        gold = self.letter_to_num[self.last_problem(doc)["answer"]]
         pred = np.argmax(results)
-        return {
-            "acc": int(pred == gold)
-        }
+        return {"acc": int(pred == gold)}

     def aggregation(self):
         """
@@ -153,9 +162,7 @@ class RACE(Task):
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
         """
-        return {
-            "acc": mean
-        }
+        return {"acc": mean}

     def higher_is_better(self):
         """
@@ -163,6 +170,4 @@ class RACE(Task):
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
         """
-        return {
-            "acc": True
-        }
+        return {"acc": True}
@@ -59,14 +59,16 @@ class SATAnalogies(MultipleChoiceTask):
     def _process_doc(self, doc):
         return {
-            'source': doc['source'],
-            'query': doc['stem'].split(' ')[:2],
-            'choices': ["{} is to {}".format(*c.split(' ')[:2]) for c in doc["choices"]],
-            'gold': ['a', 'b', 'c', 'd', 'e'].index(doc['solution'].strip()),
+            "source": doc["source"],
+            "query": doc["stem"].split(" ")[:2],
+            "choices": [
+                "{} is to {}".format(*c.split(" ")[:2]) for c in doc["choices"]
+            ],
+            "gold": ["a", "b", "c", "d", "e"].index(doc["solution"].strip()),
         }

     def doc_to_text(self, doc):
-        return "{} is to {} as".format(*doc['query'])
+        return "{} is to {} as".format(*doc["query"])

     def should_decontaminate(self):
         return True
@@ -54,10 +54,10 @@ class SciQ(MultipleChoiceTask):
             doc["distractor3"],
             doc["correct_answer"],
         ]
-        src = doc['support']
+        src = doc["support"]
         out_doc = {
             "source": src,
-            "query": doc['question'],
+            "query": doc["question"],
             "choices": choices,
             "gold": 3,
         }
@@ -49,7 +49,9 @@ class SQuAD2(Task):
     DATASET_NAME = None

     # HF changed squad on us so we have to make sure we aren't running the old one
-    assert version.parse(datasets.__version__) >= version.parse("1.11.0"), "datasets v1.11.0 or later required for SQuAD"
+    assert version.parse(datasets.__version__) >= version.parse(
+        "1.11.0"
+    ), "datasets v1.11.0 or later required for SQuAD"

     def has_training_docs(self):
         return True
@@ -67,24 +69,35 @@ class SQuAD2(Task):
         return self.dataset["validation"]

     def doc_to_text(self, doc):
-        return 'Title: ' + doc['title'] + '\n\n' + 'Background: ' + doc['context'] + '\n\n' + 'Question: ' + doc['question'] + '\n\n' + 'Answer:'
+        return (
+            "Title: "
+            + doc["title"]
+            + "\n\n"
+            + "Background: "
+            + doc["context"]
+            + "\n\n"
+            + "Question: "
+            + doc["question"]
+            + "\n\n"
+            + "Answer:"
+        )

     def should_decontaminate(self):
         return True

     def doc_to_decontamination_query(self, doc):
-        return doc['context']
+        return doc["context"]

     def doc_to_target(self, doc):
-        answer_list = doc['answers']['text']
+        answer_list = doc["answers"]["text"]
         if len(answer_list) > 0:
             answer = answer_list[0]
         else:
-            answer = 'unanswerable'
+            answer = "unanswerable"
         return " " + answer

     def construct_requests(self, doc, ctx):
-        """ Uses RequestFactory to construct Requests and returns an iterable of
+        """Uses RequestFactory to construct Requests and returns an iterable of
         Requests which will be sent to the LM.

         :param doc:
@@ -94,7 +107,7 @@ class SQuAD2(Task):
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
         """
-        continuation = rf.greedy_until(ctx, ['\n'])
+        continuation = rf.greedy_until(ctx, ["\n"])
         is_unanswerable = rf.loglikelihood(ctx, " " + "unanswerable")
         return continuation, is_unanswerable
@@ -113,25 +126,46 @@ class SQuAD2(Task):
         no_answer_probability = exp(logprob_unanswerable)

         predictions = {
-            'id': doc['id'],
-            'prediction_text': continuation,
-            'no_answer_probability': no_answer_probability,
+            "id": doc["id"],
+            "prediction_text": continuation,
+            "no_answer_probability": no_answer_probability,
         }

         references = {
-            'id': doc['id'],
-            'answers': doc['answers'],
+            "id": doc["id"],
+            "answers": doc["answers"],
         }

         return {
-            'exact': (predictions, references), # Exact match (the normalized answer exactly match the gold answer)
-            'f1': (predictions, references), # The F-score of predicted tokens versus the gold answer
-            'HasAns_exact': (predictions, references), # Exact match (the normalized answer exactly match the gold answer)
-            'HasAns_f1': (predictions, references), # The F-score of predicted tokens versus the gold answer
-            'NoAns_exact': (predictions, references), # Exact match (the normalized answer exactly match the gold answer)
-            'NoAns_f1': (predictions, references), # The F-score of predicted tokens versus the gold answer
-            'best_exact': (predictions, references), # Best exact match (with varying threshold)
-            'best_f1': (predictions, references), # Best F1 (with varying threshold)
+            "exact": (
+                predictions,
+                references,
+            ),  # Exact match (the normalized answer exactly match the gold answer)
+            "f1": (
+                predictions,
+                references,
+            ),  # The F-score of predicted tokens versus the gold answer
+            "HasAns_exact": (
+                predictions,
+                references,
+            ),  # Exact match (the normalized answer exactly match the gold answer)
+            "HasAns_f1": (
+                predictions,
+                references,
+            ),  # The F-score of predicted tokens versus the gold answer
+            "NoAns_exact": (
+                predictions,
+                references,
+            ),  # Exact match (the normalized answer exactly match the gold answer)
+            "NoAns_f1": (
+                predictions,
+                references,
+            ),  # The F-score of predicted tokens versus the gold answer
+            "best_exact": (
+                predictions,
+                references,
+            ),  # Best exact match (with varying threshold)
+            "best_f1": (predictions, references),  # Best F1 (with varying threshold)
         }

     def aggregation(self):
@@ -141,14 +175,30 @@ class SQuAD2(Task):
            functions that aggregate a list of metrics
         """
         return {
-            'exact': partial(_squad_agg, 'exact'), # Exact match (the normalized answer exactly match the gold answer)
-            'f1': partial(_squad_agg, 'f1'), # The F-score of predicted tokens versus the gold answer
-            'HasAns_exact': partial(_squad_agg, 'HasAns_exact'), # Exact match (the normalized answer exactly match the gold answer)
-            'HasAns_f1': partial(_squad_agg, 'HasAns_f1'), # The F-score of predicted tokens versus the gold answer
-            'NoAns_exact': partial(_squad_agg, 'NoAns_exact'), # Exact match (the normalized answer exactly match the gold answer)
-            'NoAns_f1': partial(_squad_agg, 'NoAns_f1'), # The F-score of predicted tokens versus the gold answer
-            'best_exact': partial(_squad_agg, 'best_exact'), # Best exact match (with varying threshold)
-            'best_f1': partial(_squad_agg, 'best_f1'), # Best F1 (with varying threshold)
+            "exact": partial(
+                _squad_agg, "exact"
+            ),  # Exact match (the normalized answer exactly match the gold answer)
+            "f1": partial(
+                _squad_agg, "f1"
+            ),  # The F-score of predicted tokens versus the gold answer
+            "HasAns_exact": partial(
+                _squad_agg, "HasAns_exact"
+            ),  # Exact match (the normalized answer exactly match the gold answer)
+            "HasAns_f1": partial(
+                _squad_agg, "HasAns_f1"
+            ),  # The F-score of predicted tokens versus the gold answer
+            "NoAns_exact": partial(
+                _squad_agg, "NoAns_exact"
+            ),  # Exact match (the normalized answer exactly match the gold answer)
+            "NoAns_f1": partial(
+                _squad_agg, "NoAns_f1"
+            ),  # The F-score of predicted tokens versus the gold answer
+            "best_exact": partial(
+                _squad_agg, "best_exact"
+            ),  # Best exact match (with varying threshold)
+            "best_f1": partial(
+                _squad_agg, "best_f1"
+            ),  # Best F1 (with varying threshold)
         }

     def higher_is_better(self):
@@ -158,12 +208,12 @@ class SQuAD2(Task):
            whether a higher value of the submetric is better
         """
         return {
-            'exact': True, # Exact match (the normalized answer exactly match the gold answer)
-            'f1': True, # The F-score of predicted tokens versus the gold answer
-            'HasAns_exact': True, # Exact match (the normalized answer exactly match the gold answer)
-            'HasAns_f1': True, # The F-score of predicted tokens versus the gold answer
-            'NoAns_exact': True, # Exact match (the normalized answer exactly match the gold answer)
-            'NoAns_f1': True, # The F-score of predicted tokens versus the gold answer
-            'best_exact': True, # Best exact match (with varying threshold)
-            'best_f1': True, # Best F1 (with varying threshold)
+            "exact": True,  # Exact match (the normalized answer exactly match the gold answer)
+            "f1": True,  # The F-score of predicted tokens versus the gold answer
+            "HasAns_exact": True,  # Exact match (the normalized answer exactly match the gold answer)
+            "HasAns_f1": True,  # The F-score of predicted tokens versus the gold answer
+            "NoAns_exact": True,  # Exact match (the normalized answer exactly match the gold answer)
+            "NoAns_f1": True,  # The F-score of predicted tokens versus the gold answer
+            "best_exact": True,  # Best exact match (with varying threshold)
+            "best_f1": True,  # Best F1 (with varying threshold)
         }
@@ -65,23 +65,27 @@ class StoryCloze(Task):
         return self.dataset["test"]

     def doc_to_text(self, doc):
-        return ' '.join([
-            doc["input_sentence_1"],
-            doc["input_sentence_2"],
-            doc["input_sentence_3"],
-            doc["input_sentence_4"],
-        ])
+        return " ".join(
+            [
+                doc["input_sentence_1"],
+                doc["input_sentence_2"],
+                doc["input_sentence_3"],
+                doc["input_sentence_4"],
+            ]
+        )

     def should_decontaminate(self):
         return True

     def doc_to_decontamination_query(self, doc):
-        return ' '.join([
-            doc["input_sentence_1"],
-            doc["input_sentence_2"],
-            doc["input_sentence_3"],
-            doc["input_sentence_4"],
-        ])
+        return " ".join(
+            [
+                doc["input_sentence_1"],
+                doc["input_sentence_2"],
+                doc["input_sentence_3"],
+                doc["input_sentence_4"],
+            ]
+        )

     def doc_to_target(self, doc):
         clozes = [doc["sentence_quiz1"], doc["sentence_quiz2"]]
@@ -89,7 +93,7 @@ class StoryCloze(Task):
         return " " + clozes[doc["answer_right_ending"] - 1]

     def construct_requests(self, doc, ctx):
-        """ Uses RequestFactory to construct Requests and returns an iterable of
+        """Uses RequestFactory to construct Requests and returns an iterable of
         Requests which will be sent to the LM.

         :param doc:
@@ -100,10 +104,7 @@ class StoryCloze(Task):
            part of the document for `doc`.
         """
         clozes = [doc["sentence_quiz1"], doc["sentence_quiz2"]]
-        lls = [
-            rf.loglikelihood(ctx, " {}".format(choice))[0]
-            for choice in clozes
-        ]
+        lls = [rf.loglikelihood(ctx, " {}".format(choice))[0] for choice in clozes]
         return lls

     def process_results(self, doc, results):
@@ -117,10 +118,8 @@ class StoryCloze(Task):
            The results of the requests created in construct_requests.
         """
         gold = doc["answer_right_ending"] - 1
-        acc = 1. if np.argmax(results) == gold else 0.
-        return {
-            "acc": acc
-        }
+        acc = 1.0 if np.argmax(results) == gold else 0.0
+        return {"acc": acc}

     def aggregation(self):
         """
@@ -128,9 +127,7 @@ class StoryCloze(Task):
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
         """
-        return {
-            "acc": mean
-        }
+        return {"acc": mean}

     def higher_is_better(self):
         """
@@ -138,9 +135,7 @@ class StoryCloze(Task):
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
         """
-        return {
-            "acc": True
-        }
+        return {"acc": True}


 class StoryCloze2016(StoryCloze):
@@ -41,44 +41,57 @@ def create_tasks_from_benchmarks(benchmark_dict):
     :return: {task_name: task}
         e.g. {wmt14-fr-en: Task, wmt16-de-en: Task}
     """

     def version_of(dataset, language_pair):
         if language_pair[-2:] in ["zh", "ja"]:
             return 1  # changed to use jieba/nagisa
         return 0

     return {
-        f"{dataset}-{language_pair}": create_translation_task(dataset, language_pair, version_of(dataset, language_pair))
+        f"{dataset}-{language_pair}": create_translation_task(
+            dataset, language_pair, version_of(dataset, language_pair)
+        )
         for dataset, language_pairs in benchmark_dict.items()
         for language_pair in language_pairs
     }


 ########################################
 # Language Specifics
 ########################################


 def zh_split(zh_text: List[str]) -> List[str]:
     """Chinese splitting"""
     import jieba

     return [" ".join(jieba.cut(txt.strip())) for txt in zh_text]


 def ja_split(ja_text: List[str]) -> List[str]:
     """Japanese splitting"""
     import nagisa

     return [" ".join(nagisa.tagging(txt.strip()).words) for txt in ja_text]


 NO_SPACE_LANG = {"zh": zh_split, "ja": ja_split}


 ########################################
 # Tasks
 ########################################


 def create_translation_task(dataset, language_pair, version=0):
     class TranslationTask(GeneralTranslationTask):
         VERSION = version

         def __init__(self):
             super().__init__(dataset, language_pair)

     return TranslationTask


 class GeneralTranslationTask(Task):
     VERSION = 0
@@ -92,8 +105,9 @@ class GeneralTranslationTask(Task):

     def download(self, data_dir=None, cache_dir=None, download_mode=None):
         # This caches in the users home dir automatically
-        self.src_file, self.ref_file = \
-            sacrebleu.download_test_set(self.sacrebleu_dataset, self.sacrebleu_language_pair)
+        self.src_file, self.ref_file = sacrebleu.download_test_set(
+            self.sacrebleu_dataset, self.sacrebleu_language_pair
+        )
         self.src_data, self.ref_data = [
             [line.rstrip() for line in sacrebleu.smart_open(file)]
             for file in (self.src_file, self.ref_file)
@@ -117,10 +131,9 @@ class GeneralTranslationTask(Task):
         :return: Iterable[obj]
             A iterable of any object, that doc_to_text can handle
         """
-        return [{
-            "src": src,
-            "ref": ref
-        } for src, ref in zip(self.src_data, self.ref_data)]
+        return [
+            {"src": src, "ref": ref} for src, ref in zip(self.src_data, self.ref_data)
+        ]

     def doc_to_text(self, doc):
         language_codes = self.sacrebleu_language_pair.split("-")
@@ -139,7 +152,7 @@ class GeneralTranslationTask(Task):
         return " " + doc["ref"] if isinstance(doc["ref"], str) else doc["ref"][0]

     def construct_requests(self, doc, ctx):
-        """ Uses RequestFactory to construct Requests and returns an iterable of
+        """Uses RequestFactory to construct Requests and returns an iterable of
         Requests which will be sent to the LM.

         :param doc:
@@ -43,10 +43,10 @@ class TriviaQA(Task):
         return False

     def training_docs(self):
-        return self.dataset['train']
+        return self.dataset["train"]

     def validation_docs(self):
-        return self.dataset['validation']
+        return self.dataset["validation"]

     def test_docs(self):
         raise NotImplementedError()
@@ -58,10 +58,10 @@ class TriviaQA(Task):
         return True

     def doc_to_decontamination_query(self, doc):
-        return doc['question']
+        return doc["question"]

     def doc_to_target(self, doc):
-        return " " + doc['answer']['value']
+        return " " + doc["answer"]["value"]

     def _remove_prefixes(self, aliases):
         # Optimization: Remove any alias that has a strict prefix elsewhere in the list
@@ -75,15 +75,13 @@ class TriviaQA(Task):

     def construct_requests(self, doc, ctx):
         ret = []
-        for alias in self._remove_prefixes(doc['answer']['aliases']):
+        for alias in self._remove_prefixes(doc["answer"]["aliases"]):
             _, is_prediction = rf.loglikelihood(ctx, " " + alias)
             ret.append(is_prediction)
         return ret

     def process_results(self, doc, results):
-        return {
-            "acc": float(any(results))
-        }
+        return {"acc": float(any(results))}

     def aggregation(self):
         return {
@@ -91,6 +89,4 @@ class TriviaQA(Task):
         }

     def higher_is_better(self):
-        return {
-            "acc": True
-        }
+        return {"acc": True}
@@ -65,19 +65,13 @@ class WordUnscrambleTask(Task):
     def process_results(self, doc, results):
         pred = results[0]
         gold = doc["completion"]
-        return {
-            "acc": int(pred == gold)
-        }
+        return {"acc": int(pred == gold)}

     def aggregation(self):
-        return {
-            "acc": mean
-        }
+        return {"acc": mean}

     def higher_is_better(self):
-        return {
-            "acc": True
-        }
+        return {"acc": True}


 class Anagrams1(WordUnscrambleTask):
(Additional file diffs in this commit are collapsed and not shown in this view.)