"vscode:/vscode.git/clone" did not exist on "0975ba99bcf38c34c28e738c14e9df0abb9cb10a"
Commit baa8b0d3 authored by bzantium's avatar bzantium
Browse files

fix for merge from master

parent a956bc63
@@ -49,29 +49,29 @@ class WordUnscrambleTask(Task):
    def doc_to_text(self, doc):
        return doc["context"]

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return doc["context"]

    def doc_to_target(self, doc):
        return doc["completion"]

    def construct_requests(self, doc, ctx):
        completion = rf.greedy_until(ctx, {"until": ["\n"]})
        return completion

    def process_results(self, doc, results):
        pred = results[0]
        gold = doc["completion"]
        return {"acc": int(pred == gold)}

    def aggregation(self):
        return {"acc": mean}

    def higher_is_better(self):
        return {"acc": True}


class Anagrams1(WordUnscrambleTask):
...
@@ -54,14 +54,20 @@ class WebQs(Task):
        return self.dataset["test"]

    def doc_to_text(self, doc):
        return "Question: " + doc["question"] + "\nAnswer:"

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return doc["question"]

    def doc_to_target(self, doc):
        # this picks one answer to be the "correct" one, despite sometimes
        # multiple correct answers being possible.
        # TODO: make sure we're actually handling multi-answer correctly
        return " " + doc["answers"][0]

    def _remove_prefixes(self, aliases):
        # Optimization: Remove any alias that has a strict prefix elsewhere in the list
        # we can do this because if the prefix is acceptable by isgreedy, we can stop looking
@@ -75,15 +81,13 @@ class WebQs(Task):
    def construct_requests(self, doc, ctx):
        ret = []
        for alias in self._remove_prefixes(doc["answers"]):
            _, is_prediction = rf.loglikelihood(ctx, " " + alias)
            ret.append(is_prediction)
        return ret

    def process_results(self, doc, results):
        return {"acc": float(any(results))}

    def aggregation(self):
        return {
@@ -91,6 +95,4 @@ class WebQs(Task):
        }

    def higher_is_better(self):
        return {"acc": True}
@@ -2,7 +2,7 @@
Pointer Sentinel Mixture Models
https://arxiv.org/pdf/1609.07843.pdf

The WikiText language modeling dataset is a collection of over 100 million tokens
extracted from the set of verified Good and Featured articles on Wikipedia.

NOTE: This `Task` is based on WikiText-2.
@@ -10,14 +10,12 @@ NOTE: This `Task` is based on WikiText-2.
Homepage: https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/
"""
import re

from lm_eval.base import PerplexityTask


_CITATION = """
@misc{merity2016pointer,
    title={Pointer Sentinel Mixture Models},
    author={Stephen Merity and Caiming Xiong and James Bradbury and Richard Socher},
    year={2016},
    eprint={1609.07843},
@@ -63,7 +61,7 @@ def wikitext_detokenizer(string):

class WikiText(PerplexityTask):
    VERSION = 1
    DATASET_PATH = "EleutherAI/wikitext_document_level"
    DATASET_NAME = "wikitext-2-raw-v1"

    def has_training_docs(self):
@@ -76,20 +74,23 @@ class WikiText(PerplexityTask):
        return True

    def training_docs(self):
        return map(self._process_doc, self.dataset["train"])

    def validation_docs(self):
        return map(self._process_doc, self.dataset["validation"])

    def test_docs(self):
        return map(self._process_doc, self.dataset["test"])

    def _process_doc(self, doc):
        return doc["page"]

    def doc_to_target(self, doc):
        return wikitext_detokenizer(doc)

    def should_decontaminate(self):
        return True

    def count_words(self, doc):
        # count number of words in *original doc before detokenization*
        return len(re.split(r"\s+", doc))
""" """
WinoGrande: An Adversarial Winograd Schema Challenge at Scale WinoGrande: An Adversarial Winograd Schema Challenge at Scale
https://arxiv.org/pdf/1907.10641.pdf https://arxiv.org/pdf/1907.10641.pdf
WinoGrande is a collection of 44k problems, inspired by Winograd Schema Challenge WinoGrande is a collection of 44k problems, inspired by Winograd Schema Challenge
(Levesque, Davis, and Morgenstern 2011), but adjusted to improve the scale and (Levesque, Davis, and Morgenstern 2011), but adjusted to improve the scale and
robustness against the dataset-specific bias. Formulated as a fill-in-a-blank robustness against the dataset-specific bias. Formulated as a fill-in-a-blank
task with binary options, the goal is to choose the right option for a given task with binary options, the goal is to choose the right option for a given
sentence which requires commonsense reasoning. sentence which requires commonsense reasoning.
NOTE: This evaluation of Winogrande uses partial evaluation as described by NOTE: This evaluation of Winogrande uses partial evaluation as described by
Trinh & Le in Simple Method for Commonsense Reasoning (2018). Trinh & Le in Simple Method for Commonsense Reasoning (2018).
See: https://arxiv.org/abs/1806.02847 See: https://arxiv.org/abs/1806.02847
Homepage: https://leaderboard.allenai.org/winogrande/submissions/public Homepage: https://leaderboard.allenai.org/winogrande/submissions/public
""" """
import numpy as np

from lm_eval.base import rf, Task
from lm_eval.metrics import mean


_CITATION = """
@article{sakaguchi2019winogrande,
    title={WinoGrande: An Adversarial Winograd Schema Challenge at Scale},
    author={Sakaguchi, Keisuke and Bras, Ronan Le and Bhagavatula, Chandra and Choi, Yejin},
    journal={arXiv preprint arXiv:1907.10641},
    year={2019}
}
"""
class Winogrande(Task):
    VERSION = 0
    DATASET_PATH = "winogrande"
    DATASET_NAME = "winogrande_xl"

    answer_to_num = {"1": 0, "2": 1}

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        if self._training_docs is None:
            self._training_docs = list(self.dataset["train"])
        return self._training_docs

    def validation_docs(self):
        return self.dataset["validation"]

    def doc_to_text(self, doc):
        return self.partial_context(doc, doc["option" + doc["answer"]])

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return doc["sentence"]

    @classmethod
    def partial_context(cls, doc, option):
        # Substitute the pronoun in the sentence with the specified option
        # and ignore everything after.
        pronoun_loc = doc["sentence"].index("_")
        return doc["sentence"][:pronoun_loc] + option

    def doc_to_target(self, doc):
        return self.partial_target(doc)

    @classmethod
    def partial_target(cls, doc):
        # The target is everything after the document specified pronoun.
        pronoun_loc = doc["sentence"].index("_") + 1
        return " " + doc["sentence"][pronoun_loc:].strip()

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        target = self.partial_target(doc)
        lls = []
        for option in [doc["option1"], doc["option2"]]:
            partial_ctx = self.partial_context(doc, option)
            full_ctx = self.append_context(ctx, partial_ctx)
            lls.append(rf.loglikelihood(full_ctx, target)[0])
        return lls

    @classmethod
    def append_context(cls, ctx, partial_ctx):
        ctx = ctx.split("\n\n")  # Each fewshot context is on its own new line.
        ctx.pop()  # Remove the correct context put in by `doc_to_text`.
        return "\n\n".join([*ctx, partial_ctx]) if ctx else partial_ctx

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        return {"acc": np.argmax(results) == self.answer_to_num[doc["answer"]]}

    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        return {"acc": mean}

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {"acc": True}
@@ -40,8 +40,19 @@ class WinogradSchemaChallenge273(Task):
    DATASET_PATH = "winograd_wsc"
    DATASET_NAME = "wsc273"

    upper_pronouns = [
        "A",
        "An",
        "The",
        "She",
        "He",
        "It",
        "They",
        "My",
        "His",
        "Her",
        "Their",
    ]

    def has_training_docs(self):
        return False
@@ -53,9 +64,9 @@ class WinogradSchemaChallenge273(Task):
        return True

    def test_docs(self):
        return map(self._process_doc, self.dataset["test"])

    def _process_doc(self, doc):
        # The HF implementation of `wsc273` is not `partial evaluation` friendly.
        doc["text"] = doc["text"].replace("  ", " ")
        doc["options"][0] = self.__normalize_option(doc, doc["options"][0])
@@ -68,7 +79,7 @@ class WinogradSchemaChallenge273(Task):
            option += "'s"
        # Appropriately lowercase the pronoun in the option.
        pronoun = option.split()[0]
        start_of_sentence = doc["text"][doc["pronoun_loc"] - 2] == "."
        if not start_of_sentence and pronoun in self.upper_pronouns:
            return option.replace(pronoun, pronoun.lower())
        return option
@@ -85,11 +96,17 @@ class WinogradSchemaChallenge273(Task):
    def doc_to_text(self, doc):
        return self.partial_context(doc, doc["options"][doc["label"]])

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return doc["text"]

    @classmethod
    def partial_context(cls, doc, option):
        # Substitute the pronoun in the original text with the specified
        # option and ignore everything after.
        return doc["text"][: doc["pronoun_loc"]] + option

    def doc_to_target(self, doc):
        return self.partial_target(doc)
@@ -135,9 +152,7 @@ class WinogradSchemaChallenge273(Task):
        :param results:
            The results of the requests created in construct_requests.
        """
        return {"acc": np.argmax(results) == doc["label"]}

    def aggregation(self):
        """
@@ -145,9 +160,7 @@ class WinogradSchemaChallenge273(Task):
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        return {"acc": mean}

    def higher_is_better(self):
        """
@@ -155,6 +168,4 @@ class WinogradSchemaChallenge273(Task):
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {"acc": True}
"""
XCOPA: A Multilingual Dataset for Causal Commonsense Reasoning
https://ducdauge.github.io/files/xcopa.pdf
The Cross-lingual Choice of Plausible Alternatives dataset is a benchmark to evaluate the ability of machine learning models to transfer commonsense reasoning across languages.
The dataset is the translation and reannotation of the English COPA (Roemmele et al. 2011) and covers 11 languages from 11 families and several areas around the globe.
The dataset is challenging as it requires both the command of world knowledge and the ability to generalise to new languages.
All the details about the creation of XCOPA and the implementation of the baselines are available in the paper.
Homepage: https://github.com/cambridgeltl/xcopa
"""
from .superglue import Copa
_CITATION = """
@inproceedings{ponti2020xcopa,
title={{XCOPA: A} Multilingual Dataset for Causal Commonsense Reasoning},
author={Edoardo M. Ponti, Goran Glava\v{s}, Olga Majewska, Qianchu Liu, Ivan Vuli\'{c} and Anna Korhonen},
booktitle={Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP)},
year={2020},
url={https://ducdauge.github.io/files/xcopa.pdf}
}
"""
class XCopa(Copa):
VERSION = 0
DATASET_PATH = "xcopa"
DATASET_NAME = None
CAUSE = "because"
EFFECT = "therefore"
def has_training_docs(self):
return False
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def validation_docs(self):
return self.dataset["validation"]
def test_docs(self):
return self.dataset["test"]
def doc_to_text(self, doc):
# Drop the period
connector = {
"cause": self.CAUSE,
"effect": self.EFFECT,
}[doc["question"]]
return doc["premise"].strip()[:-1] + f" {connector}"
class XCopaEt(XCopa):
DATASET_NAME = "et"
CAUSE = "sest"
EFFECT = "seetõttu"
class XCopaHt(XCopa):
DATASET_NAME = "ht"
CAUSE = "poukisa"
EFFECT = "donk sa"
class XCopaIt(XCopa):
DATASET_NAME = "it"
CAUSE = "perché"
EFFECT = "quindi"
class XCopaId(XCopa):
DATASET_NAME = "id"
CAUSE = "karena"
EFFECT = "maka"
class XCopaQu(XCopa):
DATASET_NAME = "qu"
CAUSE = "imataq"
EFFECT = "chaymi"
class XCopaSw(XCopa):
DATASET_NAME = "sw"
CAUSE = "kwa sababu"
EFFECT = "kwa hiyo"
class XCopaZh(XCopa):
DATASET_NAME = "zh"
CAUSE = "因为"
EFFECT = "所以"
class XCopaTa(XCopa):
DATASET_NAME = "ta"
CAUSE = "காரணமாக"
EFFECT = "எனவே"
class XCopaTh(XCopa):
DATASET_NAME = "th"
CAUSE = "เพราะ"
EFFECT = "ดังนั้น"
class XCopaTr(XCopa):
DATASET_NAME = "tr"
CAUSE = "çünkü"
EFFECT = "bu yüzden"
class XCopaVi(XCopa):
DATASET_NAME = "vi"
CAUSE = "bởi vì"
EFFECT = "vì vậy"
LANGS = ["et", "ht", "it", "id", "qu", "sw", "zh", "ta", "th", "tr", "vi"]
LANG_CLASSES = [
XCopaEt,
XCopaHt,
XCopaIt,
XCopaId,
XCopaQu,
XCopaSw,
XCopaZh,
XCopaTa,
XCopaTh,
XCopaTr,
XCopaVi,
]
def construct_tasks():
tasks = {}
for lang, lang_class in zip(LANGS, LANG_CLASSES):
tasks[f"xcopa_{lang}"] = lang_class
return tasks
"""
XNLI: Evaluating Cross-lingual Sentence Representations
https://arxiv.org/abs/1809.05053
Based on the implementation of @yongzx (see https://github.com/EleutherAI/lm-evaluation-harness/pull/258)
Prompt format (same as XGLM and mGPT):
sentence1 + ", right? " + mask = (Yes|Also|No) + ", " + sentence2
Prediction is the full sequence with the highest likelihood.
Language specific prompts are translated word-by-word with Google Translate
and may differ from the ones used by mGPT and XGLM (they do not provide their prompts).
Homepage: https://github.com/facebookresearch/XNLI
"""
import numpy as np
from lm_eval.base import rf, Task
from lm_eval.metrics import mean
_CITATIONS = """
@InProceedings{conneau2018xnli,
author = "Conneau, Alexis
and Rinott, Ruty
and Lample, Guillaume
and Williams, Adina
and Bowman, Samuel R.
and Schwenk, Holger
and Stoyanov, Veselin",
title = "XNLI: Evaluating Cross-lingual Sentence Representations",
booktitle = "Proceedings of the 2018 Conference on Empirical Methods
in Natural Language Processing",
year = "2018",
publisher = "Association for Computational Linguistics",
location = "Brussels, Belgium",
}
"""
class XNLIBase(Task):
VERSION = 0
DATASET_PATH = "xnli"
DATASET_NAME = None
QUESTION_WORD = None # 'right'
ENTAILMENT_LABEL = None # 'Yes'
NEUTRAL_LABEL = None # 'Also'
CONTRADICTION_LABEL = None # 'No'
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def training_docs(self):
return self.dataset["train"]
def validation_docs(self):
return self.dataset["validation"]
def test_docs(self):
return self.dataset["test"]
def doc_to_text(self, doc):
# Example:
# The girl that can help me is all the way across town, right? Yes, The girl I need help from lives a ways away.
# [MASK] is replaced with ENTAILMENT_LABEL, NEUTRAL_LABEL, or CONTRADICTION_LABEL
return (
doc["premise"]
+ ", "
+ self.QUESTION_WORD
+ "? [MASK], "
+ doc["hypothesis"]
)
def doc_to_target(self, doc):
# True = entailment
# False = contradiction
# Neither = neutral
return (
" "
+ [self.ENTAILMENT_LABEL, self.NEUTRAL_LABEL, self.CONTRADICTION_LABEL][
doc["label"]
]
)
def construct_requests(self, doc, ctx):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
ll_true = rf.loglikelihood_rolling(ctx.replace("[MASK]", self.ENTAILMENT_LABEL))
ll_neither = rf.loglikelihood_rolling(ctx.replace("[MASK]", self.NEUTRAL_LABEL))
ll_false = rf.loglikelihood_rolling(
ctx.replace("[MASK]", self.CONTRADICTION_LABEL)
)
return ll_true, ll_neither, ll_false
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param results:
The results of the requests created in construct_requests.
"""
gold = doc["label"]
pred = np.argmax(results)
return {"acc": pred == gold}
def aggregation(self):
"""
:returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return {"acc": mean}
def higher_is_better(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return {"acc": True}
class XNLI_en(XNLIBase): # English
DATASET_NAME = "en"
QUESTION_WORD = "right"
ENTAILMENT_LABEL = "Yes"
NEUTRAL_LABEL = "Also"
CONTRADICTION_LABEL = "No"
class XNLI_de(XNLIBase): # German
DATASET_NAME = "de"
QUESTION_WORD = "richtig"
ENTAILMENT_LABEL = "Ja"
NEUTRAL_LABEL = "Auch"
CONTRADICTION_LABEL = "Nein"
class XNLI_ar(XNLIBase): # Arabic
DATASET_NAME = "ar"
QUESTION_WORD = "صحيح"
ENTAILMENT_LABEL = "نعم"
NEUTRAL_LABEL = "لذا"
CONTRADICTION_LABEL = "رقم"
class XNLI_bg(XNLIBase): # Bulgarian
DATASET_NAME = "bg"
QUESTION_WORD = "правилно"
ENTAILMENT_LABEL = "да"
NEUTRAL_LABEL = "така"
CONTRADICTION_LABEL = "не"
class XNLI_el(XNLIBase): # Greek
DATASET_NAME = "el"
QUESTION_WORD = "σωστός"
ENTAILMENT_LABEL = "Ναί"
NEUTRAL_LABEL = "Έτσι"
CONTRADICTION_LABEL = "όχι"
class XNLI_es(XNLIBase): # Spanish
DATASET_NAME = "es"
QUESTION_WORD = "correcto"
ENTAILMENT_LABEL = "Sí"
NEUTRAL_LABEL = "Asi que"
CONTRADICTION_LABEL = "No"
class XNLI_fr(XNLIBase): # French
DATASET_NAME = "fr"
QUESTION_WORD = "correct"
ENTAILMENT_LABEL = "Oui"
NEUTRAL_LABEL = "Aussi"
CONTRADICTION_LABEL = "Non"
class XNLI_hi(XNLIBase): # Hindi
DATASET_NAME = "hi"
QUESTION_WORD = "सही"
ENTAILMENT_LABEL = "हाँ"
NEUTRAL_LABEL = "इसलिए"
CONTRADICTION_LABEL = "नहीं"
class XNLI_ru(XNLIBase): # Russian
DATASET_NAME = "ru"
QUESTION_WORD = "правильно"
ENTAILMENT_LABEL = "Да"
NEUTRAL_LABEL = "Так"
CONTRADICTION_LABEL = "Нет"
class XNLI_sw(XNLIBase): # Swahili
DATASET_NAME = "sw"
QUESTION_WORD = "sahihi"
ENTAILMENT_LABEL = "Ndiyo"
NEUTRAL_LABEL = "Hivyo"
CONTRADICTION_LABEL = "Hapana"
class XNLI_th(XNLIBase): # Thai
DATASET_NAME = "th"
QUESTION_WORD = "ถูกต้อง"
ENTAILMENT_LABEL = "ใช่"
NEUTRAL_LABEL = "ดังนั้น"
CONTRADICTION_LABEL = "ไม่"
class XNLI_tr(XNLIBase): # Turkish
DATASET_NAME = "tr"
QUESTION_WORD = "doğru"
ENTAILMENT_LABEL = "Evet"
NEUTRAL_LABEL = "Böylece"
CONTRADICTION_LABEL = "Hayır"
class XNLI_ur(XNLIBase): # Urdu
DATASET_NAME = "ur"
QUESTION_WORD = "صحیح"
ENTAILMENT_LABEL = "جی ہاں"
NEUTRAL_LABEL = "اس لئے"
CONTRADICTION_LABEL = "نہیں"
class XNLI_vi(XNLIBase): # Vietnamese
DATASET_NAME = "vi"
QUESTION_WORD = "đúng"
ENTAILMENT_LABEL = "Vâng"
NEUTRAL_LABEL = "Vì vậy"
CONTRADICTION_LABEL = "Không"
class XNLI_zh(XNLIBase): # Chinese
DATASET_NAME = "zh"
QUESTION_WORD = "正确"
ENTAILMENT_LABEL = "是的"
NEUTRAL_LABEL = "所以"
CONTRADICTION_LABEL = "不是的"
LANGS = [
"ar",
"bg",
"de",
"el",
"en",
"es",
"fr",
"hi",
"ru",
"sw",
"th",
"tr",
"ur",
"vi",
"zh",
]
LANG_CLASSES = [
XNLI_ar,
XNLI_bg,
XNLI_de,
XNLI_el,
XNLI_en,
XNLI_es,
XNLI_fr,
XNLI_hi,
XNLI_ru,
XNLI_sw,
XNLI_th,
XNLI_tr,
XNLI_ur,
XNLI_vi,
XNLI_zh,
]
def construct_tasks():
tasks = {}
for lang, lang_class in zip(LANGS, LANG_CLASSES):
tasks[f"xnli_{lang}"] = lang_class
return tasks
"""
Few-shot Learning with Multilingual Language Models
https://arxiv.org/abs/2112.10668
XStoryCloze consists of the professionally translated version of the [English StoryCloze dataset](https://cs.rochester.edu/nlp/rocstories/) (Spring 2016 version) to 10 non-English languages. This dataset is released by Meta AI.
Homepage: https://github.com/facebookresearch/fairseq/pull/4820
"""
from .storycloze import StoryCloze
_CITATION = """
@article{DBLP:journals/corr/abs-2112-10668,
author = {Xi Victoria Lin and
Todor Mihaylov and
Mikel Artetxe and
Tianlu Wang and
Shuohui Chen and
Daniel Simig and
Myle Ott and
Naman Goyal and
Shruti Bhosale and
Jingfei Du and
Ramakanth Pasunuru and
Sam Shleifer and
Punit Singh Koura and
Vishrav Chaudhary and
Brian O'Horo and
Jeff Wang and
Luke Zettlemoyer and
Zornitsa Kozareva and
Mona T. Diab and
Veselin Stoyanov and
Xian Li},
title = {Few-shot Learning with Multilingual Language Models},
journal = {CoRR},
volume = {abs/2112.10668},
year = {2021},
url = {https://arxiv.org/abs/2112.10668},
eprinttype = {arXiv},
eprint = {2112.10668},
timestamp = {Tue, 04 Jan 2022 15:59:27 +0100},
biburl = {https://dblp.org/rec/journals/corr/abs-2112-10668.bib},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
"""
_LANG = ["en", "ru", "zh", "es", "ar", "hi", "id", "te", "sw", "eu", "my"]
def create_all_tasks():
"""Creates a dictionary of tasks from a list of subjects
:return: {task_name: task}
"""
return {f"xstory_cloze_{lang}": create_task(lang) for lang in _LANG}
def create_task(lang):
class XStoryCloze(StoryCloze):
DATASET_PATH = "juletxara/xstory_cloze"
DATASET_NAME = lang
def __init__(self):
super().__init__(data_dir="")
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return False
def training_docs(self):
return self.dataset["train"]
def validation_docs(self):
return self.dataset["eval"]
def test_docs(self):
pass
return XStoryCloze
"""
It's All in the Heads: Using Attention Heads as a Baseline for Cross-Lingual Transfer in Commonsense Reasoning
https://arxiv.org/abs/2106.12066
Multilingual winograd schema challenge that includes English, French, Japanese, Portuguese, Russian and Chinese. Winograd schema challenges come from the XWinograd dataset introduced in Tikhonov et al. As it only contains 16 Chinese schemas, we add 488 Chinese schemas from clue/cluewsc2020.
Homepage: https://huggingface.co/datasets/Muennighoff/xwinograd
"""
from .winogrande import Winogrande
_CITATION = """
@misc{muennighoff2022crosslingual,
title={Crosslingual Generalization through Multitask Finetuning},
author={Niklas Muennighoff and Thomas Wang and Lintang Sutawika and Adam Roberts and Stella Biderman and Teven Le Scao and M Saiful Bari and Sheng Shen and Zheng-Xin Yong and Hailey Schoelkopf and Xiangru Tang and Dragomir Radev and Alham Fikri Aji and Khalid Almubarak and Samuel Albanie and Zaid Alyafeai and Albert Webson and Edward Raff and Colin Raffel},
year={2022},
eprint={2211.01786},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
@misc{tikhonov2021heads,
title={It's All in the Heads: Using Attention Heads as a Baseline for Cross-Lingual Transfer in Commonsense Reasoning},
author={Alexey Tikhonov and Max Ryabinin},
year={2021},
eprint={2106.12066},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
"""
_LANG = ["en", "fr", "jp", "pt", "ru", "zh"]
def create_all_tasks():
"""Creates a dictionary of tasks from a list of subjects
:return: {task_name: task}
"""
return {f"xwinograd_{lang}": create_task(lang) for lang in _LANG}
def create_task(lang):
class XWinograd(Winogrande):
DATASET_PATH = "Muennighoff/xwinograd"
DATASET_NAME = lang
def __init__(self):
super().__init__()
def has_training_docs(self):
return False
def has_validation_docs(self):
return False
def has_test_docs(self):
return True
def training_docs(self):
pass
def validation_docs(self):
pass
def test_docs(self):
return self.dataset["test"]
return XWinograd
@@ -5,8 +5,11 @@ import collections
import functools
import inspect
import sys
from typing import List, Union

import torch
from omegaconf import OmegaConf


class ExitCodeError(Exception):
@@ -28,12 +31,10 @@ def simple_parse_args_string(args_string):
    if not args_string:
        return {}
    arg_list = args_string.split(",")
    args_dict = OmegaConf.to_object(OmegaConf.from_dotlist(arg_list))
    return args_dict


def join_iters(iters):
    for iter in iters:
        yield from iter
@@ -46,23 +47,26 @@ def chunks(iter, n):
        if len(arr) == n:
            yield arr
            arr = []

    if arr:
        yield arr


def group(arr, fn):
    res = collections.defaultdict(list)

    for ob in arr:
        res[fn(ob)].append(ob)

    return list(res.values())


def general_detokenize(string):
    string = string.replace(" n't", "n't")
    string = string.replace(" )", ")")
    string = string.replace("( ", "(")
    string = string.replace('" ', '"')
    string = string.replace(' "', '"')
    string = re.sub(r" (['.,])", r"\1", string)
    return string
@@ -94,10 +98,7 @@ def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len
    # Special handling for first window: predict all tokens
    first_seq_len = min(max_seq_len, len(token_list))
    yield ([prefix_token] + token_list[: first_seq_len - 1], token_list[:first_seq_len])
    predicted += first_seq_len

    while predicted < len(token_list):
@@ -105,61 +106,84 @@
        window_end = predicted + window_pred_len

        yield (
            token_list[window_end - max_seq_len - 1 : window_end - 1],
            token_list[window_end - window_pred_len : window_end],
        )
        predicted += window_pred_len


def make_disjoint_window(pair):
    """Takes output from get_rolling_token_windows and makes the context not overlap with the continuation"""
    a, b = pair
    return a[: len(a) - (len(b) - 1)], b


def select_continuation_from_batch_left_padding(
    generations: Union[List[List[int]], torch.Tensor], max_context_size: int
):
    """Select the continuation from the batch, removing prompts of different lengths.

    Args:
        generations (Union[List[List[int]], torch.Tensor]):
            A tensor or list-of-lists of shape [batch_size, sequence length].
        max_context_size (int):
            The size of the biggest context; generations will proceed from that
            index.

    Example:
        PAD PAD Continue : The dog chased the cat [every day of the week]
        Riddle me this   : The dog chased the cat [yesterday] PAD PAD PAD PAD

    Output:
        [every day of the week]
        [yesterday] PAD PAD PAD PAD
    """
    return generations[:, max_context_size:]


class Reorderer:
    def __init__(self, arr, fn):
        self.size = len(arr)
        arr = list(enumerate(arr))
        arr = group(arr, lambda x: fn(x[1]))
        arr = [([y[0] for y in x], x[0][1]) for x in arr]
        arr.sort(key=lambda x: fn(x[1]))

        self.arr = arr

    def get_reordered(self):
        return [x[1] for x in self.arr]

    def get_original(self, newarr):
        res = [None] * self.size
        cov = [False] * self.size

        for (inds, _), v in zip(self.arr, newarr):
            for ind in inds:
                res[ind] = v
                cov[ind] = True

        assert all(cov)

        return res


def positional_deprecated(fn):
    """
    A decorator to nudge users into passing only keyword args (`kwargs`) to the
    wrapped function, `fn`.
    """

    @functools.wraps(fn)
    def _wrapper(*args, **kwargs):
        if len(args) != 1 if inspect.ismethod(fn) else 0:
            print(
                f"WARNING: using {fn.__name__} with positional arguments is "
                "deprecated and will be disallowed in a future version of "
                "lm-evaluation-harness!"
            )
        return fn(*args, **kwargs)

    return _wrapper


@positional_deprecated
def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
    """
@@ -169,22 +193,33 @@ def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
    cur_path = start_path.resolve()
    max_layers = 3
    for _ in range(max_layers):
        if (cur_path / "tests" / "test_version_stable.py").exists():
            return cur_path
        else:
            cur_path = cur_path.parent.resolve()
    raise FileNotFoundError(
        f"Unable to find package root within {max_layers} upwards" + f"of {start_path}"
    )


@positional_deprecated
def run_task_tests(task_list: List[str]):
    """
    Find the package root and run the tests for the given tasks
    """
    import pytest

    package_root = find_test_root(start_path=pathlib.Path(__file__))
    task_string = " or ".join(task_list)
    args = [
        f"{package_root}/tests/test_version_stable.py",
        f"--rootdir={package_root}",
        "-k",
        f"{task_string}",
    ]
    sys.path.append(str(package_root))
    pytest_return_val = pytest.main(args)
    if pytest_return_val:
        raise ValueError(
            f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}"
        )
import argparse
import json
import logging
import fnmatch

from lm_eval import tasks, evaluator

logging.getLogger("openai").setLevel(logging.WARNING)


class MultiChoice:
    def __init__(self, choices):
        self.choices = choices

    # Simple wildcard support (linux filename patterns)
    def __contains__(self, values):
        for value in values.split(","):
            if len(fnmatch.filter(self.choices, value)) == 0:
                return False

        return True

    def __iter__(self):
        for choice in self.choices:
            yield choice


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True)
    parser.add_argument("--model_args", default="")
    parser.add_argument("--tasks", default=None, choices=MultiChoice(tasks.ALL_TASKS))
    parser.add_argument("--provide_description", action="store_true")
    parser.add_argument("--num_fewshot", type=int, default=0)
    parser.add_argument("--batch_size", type=str, default=None)
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--output_path", default=None)
    parser.add_argument("--limit", type=int, default=None)
    parser.add_argument("--no_cache", action="store_true")
    parser.add_argument("--decontamination_ngrams_path", default=None)
    parser.add_argument("--description_dict_path", default=None)
    parser.add_argument("--check_integrity", action="store_true")

    return parser.parse_args()


# Returns a list containing all values of the source_list that
# match at least one of the patterns
def pattern_match(patterns, source_list):
    task_names = set()
    for pattern in patterns:
        for matching in fnmatch.filter(source_list, pattern):
            task_names.add(matching)
    return sorted(list(task_names))


def main():
    args = parse_args()

    assert not args.provide_description  # not implemented

    if args.limit:
        print(
            "WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
        )

    if args.tasks is None:
        task_names = tasks.ALL_TASKS
    else:
        task_names = pattern_match(args.tasks.split(","), tasks.ALL_TASKS)

    print(f"Selected Tasks: {task_names}")

    description_dict = {}
    if args.description_dict_path:
        with open(args.description_dict_path, "r") as f:
            description_dict = json.load(f)

    results = evaluator.simple_evaluate(
@@ -51,11 +86,11 @@ def main():
        no_cache=args.no_cache,
        limit=args.limit,
        description_dict=description_dict,
        decontamination_ngrams_path=args.decontamination_ngrams_path,
        check_integrity=args.check_integrity,
    )

    dumped = json.dumps(results, indent=2)
    print(dumped)

    if args.output_path:
...
{
"Data": "Pile statistics",
"Document Count": 210607728,
"Total Pile Characters": 421215456,
"File Start Offsets": [
0,
7021438,
14042822,
21066113,
28086515,
35106072,
42123306,
49145091,
56165817,
63185587,
70211208,
77234322,
84249267,
91267634,
98285983,
105305110,
112322489,
119342491,
126367373,
133389153,
140412039,
147432373,
154452516,
161470190,
168492733,
175512521,
182526939,
189547478,
196565318,
203583306
]
}
janitor.py contains a script to remove benchmark data contamination from training data sets.
It uses the approach described in the [GPT-3 paper](https://arxiv.org/abs/2005.14165).

## Algorithm
1) Collects all contamination text files that are to be removed from training data
2) Filters training data by finding `N`gram matches between the training data
   and any contamination
   1) `N`grams ignore case and punctuation and are split on whitespace.
   2) Matching `N`gram substrings are removed, as is a `window_to_remove` character window around
      the match, splitting the training data into chunks
   3) Any chunks less than `minimum_slice_length` are removed
   4) Training data sets split into more than `too_dirty_cutoff` chunks are considered
      completely contaminated and removed

OpenAI used:
```
ngram_n = 13
@@ -20,7 +20,7 @@ minimum_slice_length = 200
too_dirty_cutoff = 10
```
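To make the matching step (2) above concrete, here is a minimal pure-Python sketch of the n-gram comparison; it is illustrative only (simplified normalization, no window removal or chunk splitting) and is not the janitor.py implementation:

```python
import re


def normalize(text):
    # Lowercase, strip punctuation, split on whitespace.
    return re.sub(r"[^\w\s]", "", text.lower()).split()


def ngrams(tokens, n=13):
    # Set of all whitespace-joined n-grams in the token list.
    return {" ".join(tokens[i : i + n]) for i in range(len(tokens) - n + 1)}


def is_contaminated(training_doc, benchmark_docs, n=13):
    # A training document is flagged if it shares any n-gram with any benchmark document.
    doc_grams = ngrams(normalize(training_doc), n)
    return any(doc_grams & ngrams(normalize(b), n) for b in benchmark_docs)
```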
## Compiling

Janitor can be used as a pure python program, but it is much faster if the ngram
code is run in C++. To compile the C++ code, run
@@ -31,4 +31,3 @@ c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor
```
If your compiler isn't linked to python, you may need to add `-undefined dynamic_lookup` to the command above.
import glob
import argparse
import os
import subprocess
import shutil
from tqdm import tqdm
from tqdm_multiprocess import TqdmMultiProcessPool
import logging
from tqdm_multiprocess.logger import setup_logger_tqdm
logger = logging.getLogger(__name__)
def process_task(
working_directory, output_directory, bucket_file_path, tqdm_func, global_tqdm
):
command = f"zstd {bucket_file_path}"
logger.info(command)
subprocess.call(command, shell=True)
compressed_file = bucket_file_path + ".zst"
if output_directory:
shutil.move(compressed_file, output_directory)
os.remove(bucket_file_path)
global_tqdm.update()
def compress_and_move(working_directory, output_directory, process_count):
os.makedirs(output_directory, exist_ok=True)
original_info_file_path = os.path.join(working_directory, "info.json")
assert os.path.exists(original_info_file_path)
tasks = []
bucket_file_paths = glob.glob(
os.path.join(working_directory, "output", f"*.bkt.txt.sorted")
)
for bucket_file_path in bucket_file_paths:
task = (process_task, (working_directory, output_directory, bucket_file_path))
tasks.append(task)
pool = TqdmMultiProcessPool(process_count)
def on_done(_):
return None
def on_error(_):
return None
global_progress = tqdm(
total=len(bucket_file_paths), dynamic_ncols=True, unit="file"
)
_ = pool.map(global_progress, tasks, on_error, on_done)
shutil.copy(original_info_file_path, os.path.join(output_directory, "info.json"))
parser = argparse.ArgumentParser(description="sort 13gram buckets")
parser.add_argument("-dir", "--working_directory", required=True)
parser.add_argument("-output", "--output_directory", required=True)
parser.add_argument("-procs", "--process_count", type=int, default=8)
if __name__ == "__main__":
version = 1.00
print(f"Running version {version}")
logfile_path = "compress_and_package.log"
setup_logger_tqdm(logfile_path)
args = parser.parse_args()
compress_and_move(args.working_directory, args.output_directory, args.process_count)
""" """
Outputs all 13-grams found in The Pile. Outputs all 13-grams found in The Pile.
Loops through all documents and uses the logic found in janitor.py to extract 13-grams. Loops through all documents and uses the logic found in janitor.py to extract 13-grams.
We bucket each 13-gram by hash into separate file buckets to allow easy parallel processing in the We bucket each 13-gram by hash into separate file buckets to allow easy parallel processing in the
next stage. We also include the current pile document_id with each ngram instance to allow the next stage. We also include the current pile document_id with each ngram instance to allow the
filtering to exclude 13-grams that match more then 10 unique documents (done further down the pipeline). filtering to exclude 13-grams that match more then 10 unique documents (done further down the pipeline).
We didn't use lm_dataformat to output as it increases time 4x (slow jsonify) and makes We didn't use lm_dataformat to output as it increases time 4x (slow jsonify) and makes
...@@ -21,8 +21,10 @@ Arguments ...@@ -21,8 +21,10 @@ Arguments
""" """
import argparse import argparse
import json
import pickle import pickle
import os import os
import sys
from pathlib import Path from pathlib import Path
import glob import glob
import signal import signal
...@@ -30,32 +32,98 @@ from signal import SIGINT ...@@ -30,32 +32,98 @@ from signal import SIGINT
from tqdm import tqdm from tqdm import tqdm
from scripts.clean_training_data.janitor import Janitor, word_ngrams from lm_eval.decontamination.janitor import Janitor, word_ngrams
from scripts.clean_training_data.archiver import TextArchive, Reader from lm_eval.decontamination.archiver import TextArchive, Reader
import logging import logging
from tqdm_multiprocess.logger import setup_logger_tqdm from tqdm_multiprocess.logger import setup_logger_tqdm
logger = logging.getLogger(__name__)
pile_document_count = 210607728 logger = logging.getLogger(__name__)
terminate = False terminate = False
def handler(signal_received, frame): def handler(signal_received, frame):
global terminate global terminate
terminate = True terminate = True
def get_pile(directory):
reader = Reader() def yield_pile(start_offsets=None, checkpoint_offset=None):
for file in glob.glob(os.path.join(directory, f"*.jsonl.zst*")): directory = "pile"
if not os.path.exists(directory):
print(
"We expect the pile archives to be in the 'pile' directory, but this was not found."
)
raise Exception("Pile directory not found.")
files = list(sorted(glob.glob(os.path.join(directory, "*.jsonl.zst*"))))
pile_global_offset = 0
start_file = 0
if checkpoint_offset:
for file_i, start_offset in enumerate(start_offsets):
if start_offset > checkpoint_offset:
break
start_file = file_i
pile_global_offset = start_offset
for file_i, file in enumerate(files):
if file_i < start_file:
logger.info(f"Skipping file {file}")
continue
logger.info(f"Reading from pile file: {file}")
reader = Reader()
for document in reader.read(file): for document in reader.read(file):
yield document yield (pile_global_offset, document)
pile_global_offset += 1
# Hash buckets > disk backed files. Supports file position checkpointing and resuming
# Allows you to write continuously and checkpoint intermittently. If a failure occurs
# the buckets are simply truncated at your last checkpoint.
class Buckets:
def __init__(self, directory, num_buckets):
self.bucket_files = [
os.path.join(directory, f"ngrams_{i}.bkt.txt") for i in range(num_buckets)
]
self.buckets = list(map(TextArchive, self.bucket_files))
self.checkpoint_file = os.path.join(directory, f"bucket_offsets.ckpt")
if os.path.exists(self.checkpoint_file):
self.bucket_offsets = pickle.load(open(self.checkpoint_file, "rb"))
else:
self.bucket_offsets = [0 for i in range(len(self.buckets))]
for i, offset in enumerate(self.bucket_offsets):
bucket = self.buckets[i]
bucket.fh.seek(offset)
bucket.fh.truncate()
def add_data(self, key, value):
i = hash(key) % len(self.buckets)
bucket = self.buckets[i]
bucket.add_data(value)
def save_checkpoint(self):
for bucket in self.buckets:
bucket.fh.flush()
bucket_offsets = [bucket.fh.tell() for bucket in self.buckets]
pickle.dump(bucket_offsets, open(self.checkpoint_file, "wb"))
def close_buckets(self):
for bucket in self.buckets:
bucket.commit()
def close_buckets(buckets):
for bucket in buckets:
bucket.commit()
def do_ngrams_in_buckets(n_value, working_directory, bucket_count): def do_ngrams_in_buckets(n_value, working_directory, bucket_count):
pile_statistics = json.load(open("pile_statistics.json", "r"))
pile_document_count = pile_statistics["Document Count"]
start_offsets = pile_statistics["File Start Offsets"]
output_directory = os.path.join(working_directory, "output") output_directory = os.path.join(working_directory, "output")
os.makedirs(output_directory, exist_ok=True) os.makedirs(output_directory, exist_ok=True)
...@@ -68,58 +136,71 @@ def do_ngrams_in_buckets(n_value, working_directory, bucket_count): ...@@ -68,58 +136,71 @@ def do_ngrams_in_buckets(n_value, working_directory, bucket_count):
return return
# Checkpoint # Checkpoint
checkpoint_file = os.path.join(output_directory, f"ngram_buckets.ckpt") checkpoint_file = os.path.join(working_directory, f"pile_offset.ckpt")
if os.path.exists(checkpoint_file): if os.path.exists(checkpoint_file):
start_id = pickle.load(open(checkpoint_file,"rb")) checkpoint_offset = pickle.load(open(checkpoint_file, "rb"))
iterate = True
else: else:
start_id = 0 checkpoint_offset = 0
iterate = False
logger.info(f"Starting at pile document index {start_id}") logger.info(f"Starting at pile document index {checkpoint_offset}")
bucket_files = [os.path.join(output_directory, f"ngrams_{i}.bkt.txt") for i in range(bucket_count)] buckets = Buckets(output_directory, bucket_count)
buckets = list(map(TextArchive, bucket_files))
janitor = Janitor() janitor = Janitor()
current_id = 0
batch_size = 1000 batch_size = 1000
batch_counter = 0 batch_counter = 0
with tqdm(total=pile_document_count, dynamic_ncols=True, unit="docs") as progress:
for document in get_pile(working_directory):
if current_id < start_id:
if terminate:
close_buckets(buckets)
return
current_id += 1 with tqdm(total=checkpoint_offset, dynamic_ncols=True, unit="docs") as progress:
for offset, document in yield_pile(start_offsets, checkpoint_offset):
if iterate:
logger.info(f"Iterating to offset {checkpoint_offset} from {offset}")
progress.update(offset)
iterate = False
if offset < checkpoint_offset:
progress.update() progress.update()
if terminate:
return
continue continue
if offset == checkpoint_offset:
progress.reset(total=pile_document_count)
progress.update(checkpoint_offset)
# Save checkpoint every "batch_size", only allow terminate after checkpoint # Save checkpoint every "batch_size", only allow terminate after checkpoint
if batch_counter == batch_size: if batch_counter == batch_size:
progress.update(batch_size) progress.update(batch_size)
batch_counter = 0 batch_counter = 0
pickle.dump(current_id, open(checkpoint_file,"wb")) buckets.save_checkpoint()
pickle.dump(offset, open(checkpoint_file, "wb"))
if terminate: if terminate:
close_buckets(buckets) buckets.close_buckets()
return return
ngrams = word_ngrams(janitor.normalize_string(document), n_value) ngrams = word_ngrams(janitor.normalize_string(document), n_value)
for ngram in ngrams: for ngram in ngrams:
bucket = hash(ngram) % len(buckets) buckets.add_data(ngram, f"{ngram} {offset}")
buckets[bucket].add_data(f"{ngram} {current_id}")
batch_counter += 1 batch_counter += 1
current_id += 1
buckets.close_buckets()
close_buckets(buckets)
Path(done_file).touch() Path(done_file).touch()
parser = argparse.ArgumentParser(description='Generate 13 grams from Pile.') parser = argparse.ArgumentParser(description="Generate 13 grams from Pile.")
parser.add_argument("-dir", "--working_directory", default="") parser.add_argument("-dir", "--working_directory", default="")
parser.add_argument("-n", "--n_value", type=int, default=13) parser.add_argument("-n", "--n_value", type=int, default=13)
parser.add_argument("-buckets", "--bucket_count", type=int, default=500) parser.add_argument("-buckets", "--bucket_count", type=int, default=500)
if __name__ == '__main__': if __name__ == "__main__":
version = 1.00
print(f"Running version {version}")
if "PYTHONHASHSEED" not in os.environ or os.environ["PYTHONHASHSEED"] != "0":
print("Please run 'export PYTHONHASHSEED=0' before running generate.")
sys.exit()
# Handle sigint (ctrl-c) cleanly # Handle sigint (ctrl-c) cleanly
previous_signal_int = signal.signal(SIGINT, handler) previous_signal_int = signal.signal(SIGINT, handler)
...@@ -128,4 +209,8 @@ if __name__ == '__main__': ...@@ -128,4 +209,8 @@ if __name__ == '__main__':
setup_logger_tqdm(logfile_path) setup_logger_tqdm(logfile_path)
args = parser.parse_args() args = parser.parse_args()
do_ngrams_in_buckets(args.n_value, args.working_directory, args.bucket_count) do_ngrams_in_buckets(args.n_value, args.working_directory, args.bucket_count)
\ No newline at end of file
info_dict = {"title": "dataset ngrams", "ngram_size": 13}
info_dict_path = os.path.join(args.working_directory, "info.json")
json.dump(info_dict, open(info_dict_path, "w"))
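
The PYTHONHASHSEED guard above matters because the ngram-to-bucket routing relies on Python's built-in hash(), which is salted per process unless the seed is pinned; a resumed run would otherwise send the same 13-gram to a different bucket file. Below is a minimal sketch of that routing idea, using a hypothetical SimpleBuckets stand-in; the real Buckets helper is imported from the decontamination package and is not part of this diff.

# Minimal sketch of the bucketing idea behind Buckets.add_data(). The class
# below is a hypothetical stand-in for illustration only; the real helper
# lives elsewhere in the repository.
import os


class SimpleBuckets:
    def __init__(self, output_directory, bucket_count):
        os.makedirs(output_directory, exist_ok=True)
        self.files = [
            open(os.path.join(output_directory, f"ngrams_{i}.bkt.txt"), "a")
            for i in range(bucket_count)
        ]

    def add_data(self, key, line):
        # hash() of a str is only reproducible across runs when PYTHONHASHSEED
        # is fixed, which is why the script above insists on PYTHONHASHSEED=0.
        bucket = hash(key) % len(self.files)
        self.files[bucket].write(line + "\n")

    def close_buckets(self):
        for fh in self.files:
            fh.close()


if __name__ == "__main__":
    buckets = SimpleBuckets("output", bucket_count=4)
    buckets.add_data("the quick brown fox", "the quick brown fox 12345")
    buckets.close_buckets()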
from lm_eval.decontamination.archiver import Reader

import os
import json
from functools import reduce
import glob
import tqdm
from tqdm_multiprocess import TqdmMultiProcessPool


def get_file_stats(file_path, tqdm_func, global_tqdm):
    reader = Reader()
    total_documents = 0
    total_size = 0
    update_frequency = 10000
    current_file_position = 0

    with tqdm_func(
        total=os.path.getsize(file_path), dynamic_ncols=True, unit="byte", unit_scale=1
    ) as progress:
        for document in reader.read(file_path, get_meta=True):
            total_size += len(document)
            total_documents += 1

            if total_documents % update_frequency == 0:
                new_file_pos = reader.fh.tell()
                bytes_read = new_file_pos - current_file_position
                current_file_position = new_file_pos
                progress.update(bytes_read)
                global_tqdm.update(bytes_read)

    return (total_documents, total_size)


def get_files():
    directory = "pile"
    files = list(sorted(glob.glob(os.path.join(directory, "*.jsonl.zst*"))))
    print(files)
    return files


def get_stats():
    files = get_files()
    total_size_bytes = sum(map(lambda x: os.path.getsize(x), files))

    pool = TqdmMultiProcessPool(4)
    global_tqdm = tqdm.tqdm(
        total=total_size_bytes, dynamic_ncols=True, unit="byte", unit_scale=1
    )

    # Gather per-file stats with the pool
    tasks = [(get_file_stats, (file,)) for file in files]

    def on_done(_):
        return None

    def on_error(_):
        return None

    results = pool.map(global_tqdm, tasks, on_error, on_done)

    total_documents, total_size = reduce(
        lambda x, y: (x[0] + y[0], x[1] + y[1]), results
    )

    start_offsets = []
    current_offset = 0
    for file_document_count, _ in results:
        start_offsets.append(current_offset)
        current_offset += file_document_count

    return (total_documents, total_size, start_offsets)


if __name__ == "__main__":
    version = 1.01
    print(f"Running version {version}")

    stats_file_path = "pile_statistics.json"
    if os.path.exists(stats_file_path):
        stats = json.load(open(stats_file_path, "r"))
    else:
        document_count, total_document_size_chars, start_offsets = get_stats()
        stats = {
            "Data": "Pile statistics",
            "Document Count": document_count,
            "Total Pile Characters": total_document_size_chars,
            "File Start Offsets": start_offsets,
        }
        json.dump(stats, open(stats_file_path, "w"), indent=4)

    print(f"document_count: {stats['Document Count']}")
    print(f"total_chars: {stats['Total Pile Characters']}")
    print(f"start_offsets: {stats['File Start Offsets']}")
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <queue>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

bool is_whitespace(char ch) noexcept {
  // " \t\n\r\x0b\x0c" (python string.whitespace)
  return ch == 32 or (9 <= ch and ch <= 13);
  // return ch <= 32; // arguably too general, but slightly faster
}

bool is_punctuation(char c) noexcept {
  // '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~' ascii values: 33-47, 58-64,
  // 91-96, 123-126
  return (33 <= c and c <= 47) or (58 <= c and c <= 64) or
         (91 <= c and c <= 96) or (123 <= c and c <= 126);
}

// Takes a string and makes ngrams of length N, splitting grams on whitespace
// and ignoring ignored characters. Returns a LARGE array of ngrams
std::vector<std::string> clean_ngram(std::string const &input,
                                     std::string const &ignore,
                                     size_t ngram_n) noexcept {

  size_t num_grams = 0;
  std::vector<std::string> ngram_list;
  std::vector<uint8_t> gram_lengths;
  std::string current_ngram;

  // Max gram length is set to 10 below.
  current_ngram.reserve(11 * ngram_n);
  gram_lengths.reserve(ngram_n);

  bool started_gram = false;
  gram_lengths.push_back(0);

  // for (size_t i=0; i<input.length(); i++) {
  // this is slightly faster, and we don't need the index in this one
  for (auto iter = input.begin(); iter != input.end(); iter++) {

    // If whitespace, end the current ngram and start the next
    // alternatively, (perhaps marginally) faster: if (is_whitespace(ch)) { ... }
    if (is_whitespace(*iter) || gram_lengths.back() > 10) {

      // Skip all whitespace
      while (++iter != input.end() && is_whitespace(*iter))
        ;
      iter--;

      if (started_gram) {
        num_grams += 1;

        // Building 1grams is a special case
        if (ngram_n == 1) {
          ngram_list.push_back(current_ngram);
          current_ngram = current_ngram.substr(gram_lengths.front());
          gram_lengths.back() = 0;

          // If there are enough grams to form an ngram, save
        } else if (num_grams >= ngram_n) {
          // Save the current ngram
          ngram_list.push_back(current_ngram);

          // Start the next ngram by dropping the first gram and its space from
          // the ngram
          current_ngram = current_ngram.substr(gram_lengths.front() + 1);
          current_ngram += ' ';

          // Drop the length of the first gram and prepare to record the length
          // of the new gram
          gram_lengths.erase(gram_lengths.begin());
          gram_lengths.push_back(0);

          // Otherwise, continue building
        } else {
          current_ngram += ' ';
          gram_lengths.push_back(0);
        }

        started_gram = false;
      }

      // Skip ignored characters
      // alternatively, (perhaps marginally) faster: if (is_punctuation(ch)) continue;
    } else if (ignore.find(*iter) != std::string::npos) {
      continue;
    }

    // If it is a non-ignored character, add it to the ngram and update the last
    // gram's length
    else {
      current_ngram += tolower(*iter);
      gram_lengths.back() += 1;
      started_gram = true;
    }
  }

  return ngram_list;
}
// Takes a string and makes ngrams of length N, splitting grams on whitespace
// and ignoring ignored characters. Returns a LARGE array of tuples of (ngram,
// start_idx, end_idx)
std::vector<std::tuple<std::string, size_t, size_t>>
clean_ngram_with_indices(std::string const &input, std::string const &ignore,
                         size_t ngram_n) noexcept {

  size_t num_grams = 0;
  std::vector<std::tuple<std::string, size_t, size_t>> ngram_list;
  std::vector<uint8_t> gram_lengths;
  std::vector<size_t> gram_start_indices;
  std::string current_ngram;

  // Max gram length is set to 10 below.
  current_ngram.reserve(11 * ngram_n);

  bool started_gram = false;
  gram_lengths.push_back(0);
  gram_start_indices.push_back(0);

  for (size_t i = 0; i < input.length(); i++) {
    char ch = input[i];

    // If whitespace, end the current ngram and start the next
    if (is_whitespace(ch) || gram_lengths.back() > 10) {

      // Skip all whitespace
      while (++i < input.length() && is_whitespace(input[i]))
        ;
      i--;

      if (started_gram) {
        num_grams += 1;

        // Building 1grams is a special case
        if (ngram_n == 1) {
          ngram_list.push_back(
              std::make_tuple(current_ngram, gram_start_indices.front(), i));
          current_ngram = current_ngram.substr(gram_lengths.front());
          gram_lengths.back() = 0;
          gram_start_indices.back() = i + 1;

          // If there are enough grams to form an ngram, save
        } else if (num_grams >= ngram_n) {
          // Save the current ngram
          ngram_list.push_back(
              std::make_tuple(current_ngram, gram_start_indices.front(), i));

          // Start the next ngram by dropping the first gram and its space from
          // the ngram
          current_ngram = current_ngram.substr(gram_lengths.front() + 1);
          current_ngram += ' ';

          // Drop the length of the first gram and prepare to record the length
          // of the new gram
          gram_lengths.erase(gram_lengths.begin());
          gram_lengths.push_back(0);
          gram_start_indices.erase(gram_start_indices.begin());
          gram_start_indices.push_back(i + 1);

          // Otherwise, continue building
        } else {
          current_ngram += ' ';
          gram_lengths.push_back(0);
          gram_start_indices.push_back(i + 1);
        }

        started_gram = false;
      }

      // Skip ignored characters
    } else if (ignore.find(ch) != std::string::npos) {
      continue;

      // If it is a non-ignored character, add it to the ngram and update the
      // last gram's length
    } else {
      current_ngram += tolower(ch);
      gram_lengths.back() += 1;
      started_gram = true;
    }
  }

  return ngram_list;
}
PYBIND11_MODULE(janitor_util, m) {
  m.doc() = "pybind11 example plugin"; // optional module docstring
  // m.def("add", &add, "A function which adds two numbers"); // example function
  m.def("clean_ngram", &clean_ngram,
        "Create ngrams of words, ignoring some characters");
  m.def("clean_ngram_with_indices", &clean_ngram_with_indices,
        "Create ngrams of words with indices, ignoring some characters");
}

// Example compile
// c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix)
// If python and gcc aren't linked, append to the above: -undefined dynamic_lookup
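
A hedged usage sketch for the compiled extension, assuming it has been built with the command above and is importable as janitor_util; passing string.punctuation as the ignore set mirrors how the Python-side Janitor is expected to call it, but that choice is an assumption here.

# Usage sketch only: requires the janitor_util extension built as shown above.
import string

import janitor_util

text = "The quick brown fox -- it jumps over the lazy dog!"

# 2-grams, lowercased, punctuation stripped, words capped at 10 characters
print(janitor_util.clean_ngram(text, string.punctuation, 2))

# Same, but each ngram comes back as (ngram, start_idx, end_idx)
for ngram, start, end in janitor_util.clean_ngram_with_indices(
    text, string.punctuation, 2
):
    print(ngram, start, end)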
@@ -27,25 +27,33 @@ from scripts.clean_training_data.archiver import TextReader, TextArchive
import logging
from tqdm_multiprocess.logger import setup_logger_tqdm

logger = logging.getLogger(__name__)


# Multiprocessed
def process_bucket(
    bucket_file_path, processed_directory, move_dir, tqdm_func, global_tqdm
):
    bucket_id = re.sub("\D", "", os.path.basename(bucket_file_path))  # noqa: W605
    done_file = os.path.join(
        processed_directory, f"ngram_bucket_processing_{bucket_id}.done"
    )
    if os.path.exists(done_file):
        logger.info(f"bucket {bucket_id} already processed, skipping")
        return

    # For managing tqdm
    file_size = os.path.getsize(bucket_file_path)
    bucket_progress = tqdm_func(
        total=file_size, dynamic_ncols=True, unit="byte", unit_scale=1
    )
    current_file_position = 0
    update_frequency = 100 * 1000000  # 100mb
    update_counter = 0

    # Iterate through and output ngrams which occur in more than 10 documents
    bucket = TextReader(bucket_file_path)

    output_file_path = bucket_file_path + ".processed"
@@ -56,10 +64,12 @@ def process_bucket(bucket_file_path, processed_directory, move_dir, tqdm_func, global_tqdm):
    for line in bucket.read():
        [ngram, document_id] = line.rsplit(" ", 1)

        # Write ngram if more than 10 unique document occurrences
        if ngram != current_ngram:
            if len(current_ngram_document_ids) > 10:
                output_archive.add_data(
                    f"{current_ngram} {len(current_ngram_document_ids)}"
                )
            current_ngram = ngram
            current_ngram_document_ids = set()
@@ -84,28 +94,38 @@ def process_bucket(bucket_file_path, processed_directory, move_dir, tqdm_func, global_tqdm):
    global_tqdm.update()


def process_sorted_buckets(working_directory, move_dir, process_count):
    bucket_file_paths = glob.glob(os.path.join(working_directory, f"*.bkt.txt.sorted"))

    processed_directory = os.path.join(working_directory, "processed")
    os.makedirs(processed_directory, exist_ok=True)

    pool = TqdmMultiProcessPool(process_count)
    tasks = [
        (process_bucket, (bucket_file, processed_directory, move_dir))
        for bucket_file in bucket_file_paths
    ]

    global_tqdm = tqdm(total=len(bucket_file_paths), dynamic_ncols=True, unit="bucket")

    def on_done(_):
        return None

    def on_error(_):
        return None

    _ = pool.map(global_tqdm, tasks, on_error, on_done)


parser = argparse.ArgumentParser(description="Process 13 grams from sorted buckets.")
parser.add_argument("-dir", "--working_directory", default="")
parser.add_argument("-move", "--move_dir", default="")
parser.add_argument("-procs", "--process_count", type=int, default=4)

if __name__ == "__main__":
    logfile_path = "process13grams.log"
    setup_logger_tqdm(logfile_path)

    args = parser.parse_args()
    process_sorted_buckets(args.working_directory, args.move_dir, args.process_count)
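
The heart of process_bucket is a streaming frequency filter: because each bucket has already been sorted, all occurrences of an ngram are adjacent, and an ngram is kept only if it appears in more than 10 distinct documents. A self-contained sketch of the same rule, with a hypothetical file name:

# Standalone sketch of the filtering rule used in process_bucket. The bucket
# file is assumed to be sorted, so equal ngrams are adjacent and can be
# grouped in a single streaming pass.
from itertools import groupby


def ngrams_in_more_than_10_docs(sorted_bucket_path):
    def ngram_of(line):
        ngram, _document_id = line.rsplit(" ", 1)
        return ngram

    with open(sorted_bucket_path) as fh:
        stripped = (line.rstrip("\n") for line in fh)
        for ngram, lines in groupby(stripped, key=ngram_of):
            document_ids = {line.rsplit(" ", 1)[1] for line in lines}
            if len(document_ids) > 10:
                yield ngram, len(document_ids)


# Hypothetical file name, for illustration only.
for ngram, count in ngrams_in_more_than_10_docs("ngrams_0.bkt.txt.sorted"):
    print(f"{ngram} {count}")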
""" """
Iteratively runs gnu sort on each bucket, gnu handles the multiprocessing. Iteratively runs gnu sort on each bucket, uses up to 8 cores.
Arguments Arguments
--------- ---------
...@@ -11,48 +11,47 @@ Arguments ...@@ -11,48 +11,47 @@ Arguments
import glob import glob
import argparse import argparse
import os import os
from pathlib import Path
import signal import signal
from signal import SIGINT from signal import SIGINT
import re
import subprocess import subprocess
from tqdm import tqdm from tqdm import tqdm
import logging import logging
from tqdm_multiprocess.logger import setup_logger_tqdm from tqdm_multiprocess.logger import setup_logger_tqdm
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
terminate = False terminate = False
def handler(signal_received, frame): def handler(signal_received, frame):
global terminate global terminate
terminate = True terminate = True
def sort_13_gram_buckets(working_directory): def sort_13_gram_buckets(working_directory):
bucket_file_paths = glob.glob(os.path.join(working_directory, f"*.bkt.txt")) bucket_file_paths = glob.glob(os.path.join(working_directory, f"*.bkt.txt"))
for bucket_file_path in tqdm(bucket_file_paths, dynamic_ncols=True): for bucket_file_path in tqdm(bucket_file_paths, dynamic_ncols=True):
bucket_id = re.sub("\D", "", os.path.basename(bucket_file_path))
done_file = os.path.join(working_directory, f"ngram_bucket_sorting_{bucket_id}.done")
if os.path.exists(done_file):
logger.info(f"bucket {bucket_id} already processed, skipping")
return
sorted_file_path = bucket_file_path + ".sorted" sorted_file_path = bucket_file_path + ".sorted"
command = f"sort {bucket_file_path} > {sorted_file_path}" command = f"sort {bucket_file_path} > {sorted_file_path}"
logger.info(command) logger.info(command)
subprocess.call(command, shell=True) subprocess.call(command, shell=True)
if terminate: if terminate:
return return
Path(done_file).touch()
os.remove(bucket_file_path) os.remove(bucket_file_path)
parser = argparse.ArgumentParser(description='sort 13gram buckets')
parser = argparse.ArgumentParser(description="sort 13gram buckets")
parser.add_argument("-dir", "--working_directory", default="") parser.add_argument("-dir", "--working_directory", default="")
if __name__ == '__main__': if __name__ == "__main__":
version = 1.00
print(f"Running version {version}")
# Handle sigint (ctrl-c) cleanly # Handle sigint (ctrl-c) cleanly
previous_signal_int = signal.signal(SIGINT, handler) previous_signal_int = signal.signal(SIGINT, handler)
...@@ -61,4 +60,4 @@ if __name__ == '__main__': ...@@ -61,4 +60,4 @@ if __name__ == '__main__':
setup_logger_tqdm(logfile_path) setup_logger_tqdm(logfile_path)
args = parser.parse_args() args = parser.parse_args()
sort_13_gram_buckets(args.working_directory) sort_13_gram_buckets(args.working_directory)
\ No newline at end of file
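
Shelling out to GNU sort is a deliberate choice: it performs an external merge sort, so buckets larger than RAM can still be sorted. For intuition only, a hypothetical in-Python fallback for small buckets is sketched below; any total order works, because the downstream pass only needs equal ngrams to end up adjacent.

# Hypothetical fallback for buckets small enough to fit in memory. The line
# ordering may differ from GNU sort under some locales, which is fine as long
# as the same ordering is applied to the whole bucket.
def sort_small_bucket(bucket_file_path):
    sorted_file_path = bucket_file_path + ".sorted"
    with open(bucket_file_path) as src:
        lines = src.readlines()
    lines.sort()
    with open(sorted_file_path, "w") as dst:
        dst.writelines(lines)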
@@ -7,7 +7,7 @@ from lm_eval.base import LM
class DryrunLM(LM):
    def __init__(self):
        self.tokencost = 0
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")
        self.tokenizer.pad_token = "<|endoftext|>"

    @classmethod
@@ -16,28 +16,28 @@ class DryrunLM(LM):
    def loglikelihood(self, requests):
        res = []

        for ctx, cont in requests:
            res.append((-random.random(), False))
            self.tokencost += len(self.tokenizer.tokenize(ctx + cont))

        return res

    def greedy_until(self, requests):
        res = []

        for ctx, _ in requests:
            res.append("lol")

            # assume worst case - generates until 256
            self.tokencost += len(self.tokenizer.tokenize(ctx)) + 256

        return res

    def loglikelihood_rolling(self, requests):
        res = []

        for (s,) in requests:
            # assume worst case: extra full context
            self.tokencost += len(self.tokenizer.tokenize(s)) + 2048

@@ -46,7 +46,7 @@ class DryrunLM(LM):
def main():
    lm = DryrunLM()

    task_list = "arc_challenge,arc_easy,boolq,cola,copa,headqa,hellaswag,lambada,logiqa,mathqa,mc_taco,mrpc,multirc,openbookqa,piqa,prost,pubmedqa,qnli,qqp,race,record,rte,sciq,sst,triviaqa,webqs,wic,wikitext,winogrande,wnli,wsc"
    values = []
    for taskname in task_list.split(","):
@@ -57,11 +57,20 @@ def main():
            num_fewshot=0,
            limit=None,
            bootstrap_iters=10,
            description_dict=None,
        )

        print(taskname, lm.tokencost)
        values.append(
            [
                taskname,
                lm.tokencost,
                lm.tokencost / 1000 * 0.0008,
                lm.tokencost / 1000 * 0.0012,
                lm.tokencost / 1000 * 0.006,
                lm.tokencost / 1000 * 0.06,
            ]
        )

    from pytablewriter import MarkdownTableWriter

    writer = MarkdownTableWriter()
@@ -69,10 +78,21 @@ def main():
    values.sort(key=lambda x: -x[1])
    totcost = sum([x[1] for x in values])
    values.append(
        [
            "**Total**",
            totcost,
            totcost / 1000 * 0.0008,
            totcost / 1000 * 0.0012,
            totcost / 1000 * 0.006,
            totcost / 1000 * 0.06,
        ]
    )

    writer.value_matrix = values

    print(writer.dumps())


if __name__ == "__main__":
    main()
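
The extra columns in the table are the token count converted to dollars at per-1,000-token prices of 0.0008, 0.0012, 0.006 and 0.06; the script does not say which hosted models those prices correspond to, so the helper below keeps them generic and the token total is made up.

# Illustration of the cost arithmetic used above; the token count is invented.
def dollar_cost(token_count, price_per_1k_tokens):
    return token_count / 1000 * price_per_1k_tokens


tokens = 1_250_000  # hypothetical dry-run total
for price in (0.0008, 0.0012, 0.006, 0.06):
    print(f"${dollar_cost(tokens, price):,.2f} at ${price}/1K tokens")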