Commit baa8b0d3 authored by bzantium's avatar bzantium
Browse files

fix for merge from master

parent a956bc63
......@@ -49,29 +49,29 @@ class WordUnscrambleTask(Task):
def doc_to_text(self, doc):
return doc["context"]
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["context"]
def doc_to_target(self, doc):
return doc["completion"]
def construct_requests(self, doc, ctx):
completion = rf.greedy_until(ctx, ["\n"])
completion = rf.greedy_until(ctx, {"until": ["\n"]})
return completion
def process_results(self, doc, results):
pred = results[0]
gold = doc["completion"]
return {
"acc": int(pred == gold)
}
return {"acc": int(pred == gold)}
def aggregation(self):
return {
"acc": mean
}
return {"acc": mean}
def higher_is_better(self):
return {
"acc": True
}
return {"acc": True}
class Anagrams1(WordUnscrambleTask):
......
......@@ -54,14 +54,20 @@ class WebQs(Task):
return self.dataset["test"]
def doc_to_text(self, doc):
return "Question: " + doc['question'] + '\nAnswer:'
return "Question: " + doc["question"] + "\nAnswer:"
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["question"]
def doc_to_target(self, doc):
# this picks one answer to be the "correct" one, despite sometimes
# this picks one answer to be the "correct" one, despite sometimes
# multiple correct answers being possible.
# TODO: make sure we're actually handling multi-answer correctly
return " " + doc['answers'][0]
return " " + doc["answers"][0]
def _remove_prefixes(self, aliases):
# Optimization: Remove any alias that has a strict prefix elsewhere in the list
# we can do this because if the prefix is acceptable by isgreedy, we can stop looking
......@@ -75,15 +81,13 @@ class WebQs(Task):
def construct_requests(self, doc, ctx):
ret = []
for alias in self._remove_prefixes(doc['answers']):
for alias in self._remove_prefixes(doc["answers"]):
_, is_prediction = rf.loglikelihood(ctx, " " + alias)
ret.append(is_prediction)
return ret
def process_results(self, doc, results):
return {
"acc": float(any(results))
}
return {"acc": float(any(results))}
def aggregation(self):
return {
......@@ -91,6 +95,4 @@ class WebQs(Task):
}
def higher_is_better(self):
return {
"acc": True
}
return {"acc": True}
......@@ -2,7 +2,7 @@
Pointer Sentinel Mixture Models
https://arxiv.org/pdf/1609.07843.pdf
The WikiText language modeling dataset is a collection of over 100 million tokens
The WikiText language modeling dataset is a collection of over 100 million tokens
extracted from the set of verified Good and Featured articles on Wikipedia.
NOTE: This `Task` is based on WikiText-2.
......@@ -10,14 +10,12 @@ NOTE: This `Task` is based on WikiText-2.
Homepage: https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/
"""
import re
import inspect
import lm_eval.datasets.wikitext.wikitext
from lm_eval.base import PerplexityTask
_CITATION = """
@misc{merity2016pointer,
title={Pointer Sentinel Mixture Models},
title={Pointer Sentinel Mixture Models},
author={Stephen Merity and Caiming Xiong and James Bradbury and Richard Socher},
year={2016},
eprint={1609.07843},
......@@ -63,7 +61,7 @@ def wikitext_detokenizer(string):
class WikiText(PerplexityTask):
VERSION = 1
DATASET_PATH = inspect.getfile(lm_eval.datasets.wikitext.wikitext)
DATASET_PATH = "EleutherAI/wikitext_document_level"
DATASET_NAME = "wikitext-2-raw-v1"
def has_training_docs(self):
......@@ -76,20 +74,23 @@ class WikiText(PerplexityTask):
return True
def training_docs(self):
return map(self._load_doc, self.dataset["train"])
return map(self._process_doc, self.dataset["train"])
def validation_docs(self):
return map(self._load_doc, self.dataset["validation"])
return map(self._process_doc, self.dataset["validation"])
def test_docs(self):
return map(self._load_doc, self.dataset["test"])
return map(self._process_doc, self.dataset["test"])
def _load_doc(self, doc):
def _process_doc(self, doc):
return doc["page"]
def doc_to_target(self, doc):
return wikitext_detokenizer(doc)
def should_decontaminate(self):
return True
def count_words(self, doc):
# count number of words in *original doc before detokenization*
return len(re.split(r"\s+", doc))
"""
WinoGrande: An Adversarial Winograd Schema Challenge at Scale
https://arxiv.org/pdf/1907.10641.pdf
WinoGrande is a collection of 44k problems, inspired by Winograd Schema Challenge
(Levesque, Davis, and Morgenstern 2011), but adjusted to improve the scale and
robustness against the dataset-specific bias. Formulated as a fill-in-a-blank
task with binary options, the goal is to choose the right option for a given
sentence which requires commonsense reasoning.
NOTE: This evaluation of Winogrande uses partial evaluation as described by
Trinh & Le in Simple Method for Commonsense Reasoning (2018).
See: https://arxiv.org/abs/1806.02847
Homepage: https://leaderboard.allenai.org/winogrande/submissions/public
"""
import numpy as np
from lm_eval.base import rf, Task
from lm_eval.metrics import mean
_CITATION = """
@article{sakaguchi2019winogrande,
title={WinoGrande: An Adversarial Winograd Schema Challenge at Scale},
author={Sakaguchi, Keisuke and Bras, Ronan Le and Bhagavatula, Chandra and Choi, Yejin},
journal={arXiv preprint arXiv:1907.10641},
year={2019}
}
"""
class Winogrande(Task):
VERSION = 0
DATASET_PATH = "winogrande"
DATASET_NAME = "winogrande_xl"
answer_to_num = {'1': 0, '2': 1}
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return False
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
return self._training_docs
def validation_docs(self):
return self.dataset["validation"]
def doc_to_text(self, doc):
return self.partial_context(doc, doc["option" + doc["answer"]])
@classmethod
def partial_context(cls, doc, option):
# Substitute the pronoun in the sentence with the specified option
# and ignore everything after.
pronoun_loc = doc["sentence"].index("_")
return doc["sentence"][:pronoun_loc] + option
def doc_to_target(self, doc):
return self.partial_target(doc)
@classmethod
def partial_target(cls, doc):
# The target is everything after the document specified pronoun.
pronoun_loc = doc["sentence"].index("_") + 1
return " " + doc["sentence"][pronoun_loc:].strip()
def construct_requests(self, doc, ctx):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
target = self.partial_target(doc)
lls = []
for option in [doc["option1"], doc["option2"]]:
partial_ctx = self.partial_context(doc, option)
full_ctx = self.append_context(ctx, partial_ctx)
lls.append(rf.loglikelihood(full_ctx, target)[0])
return lls
@classmethod
def append_context(cls, ctx, partial_ctx):
ctx = ctx.split("\n\n") # Each fewshot context is on its own new line.
ctx.pop() # Remove the correct context put in by `doc_to_text`.
return "\n\n".join([*ctx, partial_ctx]) if ctx else partial_ctx
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param results:
The results of the requests created in construct_requests.
"""
return {
"acc": np.argmax(results) == self.answer_to_num[doc["answer"]]
}
def aggregation(self):
"""
:returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return {
"acc": mean
}
def higher_is_better(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return {
"acc": True
}
"""
WinoGrande: An Adversarial Winograd Schema Challenge at Scale
https://arxiv.org/pdf/1907.10641.pdf
WinoGrande is a collection of 44k problems, inspired by Winograd Schema Challenge
(Levesque, Davis, and Morgenstern 2011), but adjusted to improve the scale and
robustness against the dataset-specific bias. Formulated as a fill-in-a-blank
task with binary options, the goal is to choose the right option for a given
sentence which requires commonsense reasoning.
NOTE: This evaluation of Winogrande uses partial evaluation as described by
Trinh & Le in Simple Method for Commonsense Reasoning (2018).
See: https://arxiv.org/abs/1806.02847
Homepage: https://leaderboard.allenai.org/winogrande/submissions/public
"""
import numpy as np
from lm_eval.base import rf, Task
from lm_eval.metrics import mean
_CITATION = """
@article{sakaguchi2019winogrande,
title={WinoGrande: An Adversarial Winograd Schema Challenge at Scale},
author={Sakaguchi, Keisuke and Bras, Ronan Le and Bhagavatula, Chandra and Choi, Yejin},
journal={arXiv preprint arXiv:1907.10641},
year={2019}
}
"""
class Winogrande(Task):
VERSION = 0
DATASET_PATH = "winogrande"
DATASET_NAME = "winogrande_xl"
answer_to_num = {"1": 0, "2": 1}
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return False
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
return self._training_docs
def validation_docs(self):
return self.dataset["validation"]
def doc_to_text(self, doc):
return self.partial_context(doc, doc["option" + doc["answer"]])
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["sentence"]
@classmethod
def partial_context(cls, doc, option):
# Substitute the pronoun in the sentence with the specified option
# and ignore everything after.
pronoun_loc = doc["sentence"].index("_")
return doc["sentence"][:pronoun_loc] + option
def doc_to_target(self, doc):
return self.partial_target(doc)
@classmethod
def partial_target(cls, doc):
# The target is everything after the document specified pronoun.
pronoun_loc = doc["sentence"].index("_") + 1
return " " + doc["sentence"][pronoun_loc:].strip()
def construct_requests(self, doc, ctx):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
target = self.partial_target(doc)
lls = []
for option in [doc["option1"], doc["option2"]]:
partial_ctx = self.partial_context(doc, option)
full_ctx = self.append_context(ctx, partial_ctx)
lls.append(rf.loglikelihood(full_ctx, target)[0])
return lls
@classmethod
def append_context(cls, ctx, partial_ctx):
ctx = ctx.split("\n\n") # Each fewshot context is on its own new line.
ctx.pop() # Remove the correct context put in by `doc_to_text`.
return "\n\n".join([*ctx, partial_ctx]) if ctx else partial_ctx
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param results:
The results of the requests created in construct_requests.
"""
return {"acc": np.argmax(results) == self.answer_to_num[doc["answer"]]}
def aggregation(self):
"""
:returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return {"acc": mean}
def higher_is_better(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return {"acc": True}
......@@ -40,8 +40,19 @@ class WinogradSchemaChallenge273(Task):
DATASET_PATH = "winograd_wsc"
DATASET_NAME = "wsc273"
upper_pronouns = ["A", "An", "The", "She", "He",
"It", "They", "My", "His", "Her", "Their"]
upper_pronouns = [
"A",
"An",
"The",
"She",
"He",
"It",
"They",
"My",
"His",
"Her",
"Their",
]
def has_training_docs(self):
return False
......@@ -53,9 +64,9 @@ class WinogradSchemaChallenge273(Task):
return True
def test_docs(self):
return map(self._load_doc, self.dataset["test"])
return map(self._process_doc, self.dataset["test"])
def _load_doc(self, doc):
def _process_doc(self, doc):
# The HF implementation of `wsc273` is not `partial evaluation` friendly.
doc["text"] = doc["text"].replace(" ", " ")
doc["options"][0] = self.__normalize_option(doc, doc["options"][0])
......@@ -68,7 +79,7 @@ class WinogradSchemaChallenge273(Task):
option += "'s"
# Appropriately lowercase the pronoun in the option.
pronoun = option.split()[0]
start_of_sentence = doc["text"][doc['pronoun_loc'] - 2] == '.'
start_of_sentence = doc["text"][doc["pronoun_loc"] - 2] == "."
if not start_of_sentence and pronoun in self.upper_pronouns:
return option.replace(pronoun, pronoun.lower())
return option
......@@ -85,11 +96,17 @@ class WinogradSchemaChallenge273(Task):
def doc_to_text(self, doc):
return self.partial_context(doc, doc["options"][doc["label"]])
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["text"]
@classmethod
def partial_context(cls, doc, option):
# Substitute the pronoun in the original text with the specified
# option and ignore everything after.
return doc["text"][:doc["pronoun_loc"]] + option
return doc["text"][: doc["pronoun_loc"]] + option
def doc_to_target(self, doc):
return self.partial_target(doc)
......@@ -135,9 +152,7 @@ class WinogradSchemaChallenge273(Task):
:param results:
The results of the requests created in construct_requests.
"""
return {
"acc": np.argmax(results) == doc["label"]
}
return {"acc": np.argmax(results) == doc["label"]}
def aggregation(self):
"""
......@@ -145,9 +160,7 @@ class WinogradSchemaChallenge273(Task):
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return {
"acc": mean
}
return {"acc": mean}
def higher_is_better(self):
"""
......@@ -155,6 +168,4 @@ class WinogradSchemaChallenge273(Task):
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return {
"acc": True
}
return {"acc": True}
"""
XCOPA: A Multilingual Dataset for Causal Commonsense Reasoning
https://ducdauge.github.io/files/xcopa.pdf
The Cross-lingual Choice of Plausible Alternatives dataset is a benchmark to evaluate the ability of machine learning models to transfer commonsense reasoning across languages.
The dataset is the translation and reannotation of the English COPA (Roemmele et al. 2011) and covers 11 languages from 11 families and several areas around the globe.
The dataset is challenging as it requires both the command of world knowledge and the ability to generalise to new languages.
All the details about the creation of XCOPA and the implementation of the baselines are available in the paper.
Homepage: https://github.com/cambridgeltl/xcopa
"""
from .superglue import Copa
_CITATION = """
@inproceedings{ponti2020xcopa,
title={{XCOPA: A} Multilingual Dataset for Causal Commonsense Reasoning},
author={Edoardo M. Ponti, Goran Glava\v{s}, Olga Majewska, Qianchu Liu, Ivan Vuli\'{c} and Anna Korhonen},
booktitle={Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP)},
year={2020},
url={https://ducdauge.github.io/files/xcopa.pdf}
}
"""
class XCopa(Copa):
VERSION = 0
DATASET_PATH = "xcopa"
DATASET_NAME = None
CAUSE = "because"
EFFECT = "therefore"
def has_training_docs(self):
return False
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def validation_docs(self):
return self.dataset["validation"]
def test_docs(self):
return self.dataset["test"]
def doc_to_text(self, doc):
# Drop the period
connector = {
"cause": self.CAUSE,
"effect": self.EFFECT,
}[doc["question"]]
return doc["premise"].strip()[:-1] + f" {connector}"
class XCopaEt(XCopa):
DATASET_NAME = "et"
CAUSE = "sest"
EFFECT = "seetõttu"
class XCopaHt(XCopa):
DATASET_NAME = "ht"
CAUSE = "poukisa"
EFFECT = "donk sa"
class XCopaIt(XCopa):
DATASET_NAME = "it"
CAUSE = "perché"
EFFECT = "quindi"
class XCopaId(XCopa):
DATASET_NAME = "id"
CAUSE = "karena"
EFFECT = "maka"
class XCopaQu(XCopa):
DATASET_NAME = "qu"
CAUSE = "imataq"
EFFECT = "chaymi"
class XCopaSw(XCopa):
DATASET_NAME = "sw"
CAUSE = "kwa sababu"
EFFECT = "kwa hiyo"
class XCopaZh(XCopa):
DATASET_NAME = "zh"
CAUSE = "因为"
EFFECT = "所以"
class XCopaTa(XCopa):
DATASET_NAME = "ta"
CAUSE = "காரணமாக"
EFFECT = "எனவே"
class XCopaTh(XCopa):
DATASET_NAME = "th"
CAUSE = "เพราะ"
EFFECT = "ดังนั้น"
class XCopaTr(XCopa):
DATASET_NAME = "tr"
CAUSE = "çünkü"
EFFECT = "bu yüzden"
class XCopaVi(XCopa):
DATASET_NAME = "vi"
CAUSE = "bởi vì"
EFFECT = "vì vậy"
LANGS = ["et", "ht", "it", "id", "qu", "sw", "zh", "ta", "th", "tr", "vi"]
LANG_CLASSES = [
XCopaEt,
XCopaHt,
XCopaIt,
XCopaId,
XCopaQu,
XCopaSw,
XCopaZh,
XCopaTa,
XCopaTh,
XCopaTr,
XCopaVi,
]
def construct_tasks():
tasks = {}
for lang, lang_class in zip(LANGS, LANG_CLASSES):
tasks[f"xcopa_{lang}"] = lang_class
return tasks
"""
XNLI: Evaluating Cross-lingual Sentence Representations
https://arxiv.org/abs/1809.05053
Based on the implementation of @yongzx (see https://github.com/EleutherAI/lm-evaluation-harness/pull/258)
Prompt format (same as XGLM and mGPT):
sentence1 + ", right? " + mask = (Yes|Also|No) + ", " + sentence2
Predicition is the full sequence with the highest likelihood.
Language specific prompts are translated word-by-word with Google Translate
and may differ from the ones used by mGPT and XGLM (they do not provide their prompts).
Homepage: https://github.com/facebookresearch/XNLI
"""
import numpy as np
from lm_eval.base import rf, Task
from lm_eval.metrics import mean
_CITATIONS = """
@InProceedings{conneau2018xnli,
author = "Conneau, Alexis
and Rinott, Ruty
and Lample, Guillaume
and Williams, Adina
and Bowman, Samuel R.
and Schwenk, Holger
and Stoyanov, Veselin",
title = "XNLI: Evaluating Cross-lingual Sentence Representations",
booktitle = "Proceedings of the 2018 Conference on Empirical Methods
in Natural Language Processing",
year = "2018",
publisher = "Association for Computational Linguistics",
location = "Brussels, Belgium",
}
"""
class XNLIBase(Task):
VERSION = 0
DATASET_PATH = "xnli"
DATASET_NAME = None
QUESTION_WORD = None # 'right'
ENTAILMENT_LABEL = None # 'Yes'
NEUTRAL_LABEL = None # 'Also'
CONTRADICTION_LABEL = None # 'No'
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def training_docs(self):
return self.dataset["train"]
def validation_docs(self):
return self.dataset["validation"]
def test_docs(self):
return self.dataset["test"]
def doc_to_text(self, doc):
# Example:
# The girl that can help me is all the way across town, right? Yes, The girl I need help from lives a ways away.
# [MASK] is replaced with ENTAILMENT_LABEL, NEUTRAL_LABEL, or CONTRADICTION_LABEL
return (
doc["premise"]
+ ", "
+ self.QUESTION_WORD
+ "? [MASK], "
+ doc["hypothesis"]
)
def doc_to_target(self, doc):
# True = entailment
# False = contradiction
# Neither = neutral
return (
" "
+ [self.ENTAILMENT_LABEL, self.NEUTRAL_LABEL, self.CONTRADICTION_LABEL][
doc["label"]
]
)
def construct_requests(self, doc, ctx):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
ll_true = rf.loglikelihood_rolling(ctx.replace("[MASK]", self.ENTAILMENT_LABEL))
ll_neither = rf.loglikelihood_rolling(ctx.replace("[MASK]", self.NEUTRAL_LABEL))
ll_false = rf.loglikelihood_rolling(
ctx.replace("[MASK]", self.CONTRADICTION_LABEL)
)
return ll_true, ll_neither, ll_false
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param results:
The results of the requests created in construct_requests.
"""
gold = doc["label"]
pred = np.argmax(results)
return {"acc": pred == gold}
def aggregation(self):
"""
:returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return {"acc": mean}
def higher_is_better(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return {"acc": True}
class XNLI_en(XNLIBase): # English
DATASET_NAME = "en"
QUESTION_WORD = "right"
ENTAILMENT_LABEL = "Yes"
NEUTRAL_LABEL = "Also"
CONTRADICTION_LABEL = "No"
class XNLI_de(XNLIBase): # German
DATASET_NAME = "de"
QUESTION_WORD = "richtig"
ENTAILMENT_LABEL = "Ja"
NEUTRAL_LABEL = "Auch"
CONTRADICTION_LABEL = "Nein"
class XNLI_ar(XNLIBase): # Arabic
DATASET_NAME = "ar"
QUESTION_WORD = "صحيح"
ENTAILMENT_LABEL = "نعم"
NEUTRAL_LABEL = "لذا"
CONTRADICTION_LABEL = "رقم"
class XNLI_bg(XNLIBase): # Bulgarian
DATASET_NAME = "bg"
QUESTION_WORD = "правилно"
ENTAILMENT_LABEL = "да"
NEUTRAL_LABEL = "така"
CONTRADICTION_LABEL = "не"
class XNLI_el(XNLIBase): # Greek
DATASET_NAME = "el"
QUESTION_WORD = "σωστός"
ENTAILMENT_LABEL = "Ναί"
NEUTRAL_LABEL = "Έτσι"
CONTRADICTION_LABEL = "όχι"
class XNLI_es(XNLIBase): # Spanish
DATASET_NAME = "es"
QUESTION_WORD = "correcto"
ENTAILMENT_LABEL = "Sí"
NEUTRAL_LABEL = "Asi que"
CONTRADICTION_LABEL = "No"
class XNLI_fr(XNLIBase): # French
DATASET_NAME = "fr"
QUESTION_WORD = "correct"
ENTAILMENT_LABEL = "Oui"
NEUTRAL_LABEL = "Aussi"
CONTRADICTION_LABEL = "Non"
class XNLI_hi(XNLIBase): # Hindi
DATASET_NAME = "hi"
QUESTION_WORD = "सही"
ENTAILMENT_LABEL = "हाँ"
NEUTRAL_LABEL = "इसलिए"
CONTRADICTION_LABEL = "नहीं"
class XNLI_ru(XNLIBase): # Russian
DATASET_NAME = "ru"
QUESTION_WORD = "правильно"
ENTAILMENT_LABEL = "Да"
NEUTRAL_LABEL = "Так"
CONTRADICTION_LABEL = "Нет"
class XNLI_sw(XNLIBase): # Swahili
DATASET_NAME = "sw"
QUESTION_WORD = "sahihi"
ENTAILMENT_LABEL = "Ndiyo"
NEUTRAL_LABEL = "Hivyo"
CONTRADICTION_LABEL = "Hapana"
class XNLI_th(XNLIBase): # Thai
DATASET_NAME = "th"
QUESTION_WORD = "ถูกต้อง"
ENTAILMENT_LABEL = "ใช่"
NEUTRAL_LABEL = "ดังนั้น"
CONTRADICTION_LABEL = "ไม่"
class XNLI_tr(XNLIBase): # Turkish
DATASET_NAME = "tr"
QUESTION_WORD = "doğru"
ENTAILMENT_LABEL = "Evet"
NEUTRAL_LABEL = "Böylece"
CONTRADICTION_LABEL = "Hayır"
class XNLI_ur(XNLIBase): # Urdu
DATASET_NAME = "ur"
QUESTION_WORD = "صحیح"
ENTAILMENT_LABEL = "جی ہاں"
NEUTRAL_LABEL = "اس لئے"
CONTRADICTION_LABEL = "نہیں"
class XNLI_vi(XNLIBase): # Vietnamese
DATASET_NAME = "vi"
QUESTION_WORD = "đúng"
ENTAILMENT_LABEL = "Vâng"
NEUTRAL_LABEL = "Vì vậy"
CONTRADICTION_LABEL = "Không"
class XNLI_zh(XNLIBase): # Chinese
DATASET_NAME = "zh"
QUESTION_WORD = "正确"
ENTAILMENT_LABEL = "是的"
NEUTRAL_LABEL = "所以"
CONTRADICTION_LABEL = "不是的"
LANGS = [
"ar",
"bg",
"de",
"el",
"en",
"es",
"fr",
"hi",
"ru",
"sw",
"th",
"tr",
"ur",
"vi",
"zh",
]
LANG_CLASSES = [
XNLI_ar,
XNLI_bg,
XNLI_de,
XNLI_el,
XNLI_en,
XNLI_es,
XNLI_fr,
XNLI_hi,
XNLI_ru,
XNLI_sw,
XNLI_th,
XNLI_tr,
XNLI_ur,
XNLI_vi,
XNLI_zh,
]
def construct_tasks():
tasks = {}
for lang, lang_class in zip(LANGS, LANG_CLASSES):
tasks[f"xnli_{lang}"] = lang_class
return tasks
"""
Few-shot Learning with Multilingual Language Models
https://arxiv.org/abs/2112.10668
XStoryCloze consists of the professionally translated version of the [English StoryCloze dataset](https://cs.rochester.edu/nlp/rocstories/) (Spring 2016 version) to 10 non-English languages. This dataset is released by Meta AI.
Homepage: https://github.com/facebookresearch/fairseq/pull/4820
"""
from .storycloze import StoryCloze
_CITATION = """
@article{DBLP:journals/corr/abs-2112-10668,
author = {Xi Victoria Lin and
Todor Mihaylov and
Mikel Artetxe and
Tianlu Wang and
Shuohui Chen and
Daniel Simig and
Myle Ott and
Naman Goyal and
Shruti Bhosale and
Jingfei Du and
Ramakanth Pasunuru and
Sam Shleifer and
Punit Singh Koura and
Vishrav Chaudhary and
Brian O'Horo and
Jeff Wang and
Luke Zettlemoyer and
Zornitsa Kozareva and
Mona T. Diab and
Veselin Stoyanov and
Xian Li},
title = {Few-shot Learning with Multilingual Language Models},
journal = {CoRR},
volume = {abs/2112.10668},
year = {2021},
url = {https://arxiv.org/abs/2112.10668},
eprinttype = {arXiv},
eprint = {2112.10668},
timestamp = {Tue, 04 Jan 2022 15:59:27 +0100},
biburl = {https://dblp.org/rec/journals/corr/abs-2112-10668.bib},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
"""
_LANG = ["en", "ru", "zh", "es", "ar", "hi", "id", "te", "sw", "eu", "my"]
def create_all_tasks():
"""Creates a dictionary of tasks from a list of subjects
:return: {task_name: task}
"""
return {f"xstory_cloze_{lang}": create_task(lang) for lang in _LANG}
def create_task(lang):
class XStoryCloze(StoryCloze):
DATASET_PATH = "juletxara/xstory_cloze"
DATASET_NAME = lang
def __init__(self):
super().__init__(data_dir="")
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return False
def training_docs(self):
return self.dataset["train"]
def validation_docs(self):
return self.dataset["eval"]
def test_docs(self):
pass
return XStoryCloze
"""
It's All in the Heads: Using Attention Heads as a Baseline for Cross-Lingual Transfer in Commonsense Reasoning
https://arxiv.org/abs/2106.12066
Multilingual winograd schema challenge that includes English, French, Japanese, Portuguese, Russian and Chinese. Winograd schema challenges come from the XWinograd dataset introduced in Tikhonov et al. As it only contains 16 Chinese schemas, we add 488 Chinese schemas from clue/cluewsc2020.
Homepage: https://huggingface.co/datasets/Muennighoff/xwinograd
"""
from .winogrande import Winogrande
_CITATION = """
@misc{muennighoff2022crosslingual,
title={Crosslingual Generalization through Multitask Finetuning},
author={Niklas Muennighoff and Thomas Wang and Lintang Sutawika and Adam Roberts and Stella Biderman and Teven Le Scao and M Saiful Bari and Sheng Shen and Zheng-Xin Yong and Hailey Schoelkopf and Xiangru Tang and Dragomir Radev and Alham Fikri Aji and Khalid Almubarak and Samuel Albanie and Zaid Alyafeai and Albert Webson and Edward Raff and Colin Raffel},
year={2022},
eprint={2211.01786},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
@misc{tikhonov2021heads,
title={It's All in the Heads: Using Attention Heads as a Baseline for Cross-Lingual Transfer in Commonsense Reasoning},
author={Alexey Tikhonov and Max Ryabinin},
year={2021},
eprint={2106.12066},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
"""
_LANG = ["en", "fr", "jp", "pt", "ru", "zh"]
def create_all_tasks():
"""Creates a dictionary of tasks from a list of subjects
:return: {task_name: task}
"""
return {f"xwinograd_{lang}": create_task(lang) for lang in _LANG}
def create_task(lang):
class XWinograd(Winogrande):
DATASET_PATH = "Muennighoff/xwinograd"
DATASET_NAME = lang
def __init__(self):
super().__init__()
def has_training_docs(self):
return False
def has_validation_docs(self):
return False
def has_test_docs(self):
return True
def training_docs(self):
pass
def validation_docs(self):
pass
def test_docs(self):
return self.dataset["test"]
return XWinograd
......@@ -5,8 +5,11 @@ import collections
import functools
import inspect
import sys
import pytest
from typing import List
from typing import List, Union
import torch
from omegaconf import OmegaConf
class ExitCodeError(Exception):
......@@ -28,12 +31,10 @@ def simple_parse_args_string(args_string):
if not args_string:
return {}
arg_list = args_string.split(",")
args_dict = {}
for arg in arg_list:
k, v = arg.split("=")
args_dict[k] = v
args_dict = OmegaConf.to_object(OmegaConf.from_dotlist(arg_list))
return args_dict
def join_iters(iters):
for iter in iters:
yield from iter
......@@ -46,23 +47,26 @@ def chunks(iter, n):
if len(arr) == n:
yield arr
arr = []
if arr: yield arr
if arr:
yield arr
def group(arr, fn):
res = collections.defaultdict(list)
for ob in arr:
res[fn(ob)].append(ob)
return list(res.values())
def general_detokenize(string):
string = string.replace(" n't", "n't")
string = string.replace(" )", ")")
string = string.replace("( ", "(")
string = string.replace("\" ", "\"")
string = string.replace(" \"", "\"")
string = string.replace('" ', '"')
string = string.replace(' "', '"')
string = re.sub(r" (['.,])", r"\1", string)
return string
......@@ -94,10 +98,7 @@ def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len
# Special handling for first window: predict all tokens
first_seq_len = min(max_seq_len, len(token_list))
yield (
[prefix_token] + token_list[:first_seq_len - 1],
token_list[:first_seq_len]
)
yield ([prefix_token] + token_list[: first_seq_len - 1], token_list[:first_seq_len])
predicted += first_seq_len
while predicted < len(token_list):
......@@ -105,61 +106,84 @@ def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len
window_end = predicted + window_pred_len
yield (
token_list[window_end - max_seq_len - 1:window_end - 1],
token_list[window_end - window_pred_len:window_end],
token_list[window_end - max_seq_len - 1 : window_end - 1],
token_list[window_end - window_pred_len : window_end],
)
predicted += window_pred_len
def make_disjoint_window(pair):
""" Takes output from get_rolling_token_windows and makes the context not overlap with the continuation """
def make_disjoint_window(pair):
"""Takes output from get_rolling_token_windows and makes the context not overlap with the continuation"""
a, b = pair
return a[: len(a) - (len(b) - 1)], b
def select_continuation_from_batch_left_padding(
generations: Union[List[List[int]], torch.Tensor], max_context_size: int
):
"""Select the continuation from the batch, removing prompts of different lengths.
Args:
generations (Union[List[List[int]], torch.Tensor]):
A tensor or list-of-lists of shape [batch_size, sequence length].
max_context_size (int):
The size of the biggest context; generations will proceed from that
index.
Example:
PAD PAD Continue : The dog chased the cat [every day of the week]
Riddle me this : The dog chased the cat [yesterday] PAD PAD PAD PAD
Output:
[every day of the week]
[yesterday] PAD PAD PAD PAD
"""
return generations[:, max_context_size:]
return a[:-(len(b) - 1)], b
class Reorderer:
def __init__(self, arr, fn):
self.size = len(arr)
arr = list(enumerate(arr))
arr = group(arr, lambda x: fn(x[1]))
arr = [
([y[0] for y in x], x[0][1]) for x in arr
]
arr = [([y[0] for y in x], x[0][1]) for x in arr]
arr.sort(key=lambda x: fn(x[1]))
self.arr = arr
def get_reordered(self):
return [x[1] for x in self.arr]
def get_original(self, newarr):
res = [None] * self.size
cov = [False] * self.size
for (inds, _), v in zip(self.arr, newarr):
for ind in inds:
for ind in inds:
res[ind] = v
cov[ind] = True
assert all(cov)
return res
def positional_deprecated(fn):
"""
A decorator to nudge users into passing only keyword args (`kwargs`) to the
A decorator to nudge users into passing only keyword args (`kwargs`) to the
wrapped function, `fn`.
"""
@functools.wraps(fn)
def _wrapper(*args, **kwargs):
if len(args) != 1 if inspect.ismethod(fn) else 0:
print(f"WARNING: using {fn.__name__} with positional arguments is "
if len(args) != 1 if inspect.ismethod(fn) else 0:
print(
f"WARNING: using {fn.__name__} with positional arguments is "
"deprecated and will be disallowed in a future version of "
"lm-evaluation-harness!")
"lm-evaluation-harness!"
)
return fn(*args, **kwargs)
return _wrapper
@positional_deprecated
def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
"""
......@@ -169,22 +193,33 @@ def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
cur_path = start_path.resolve()
max_layers = 3
for _ in range(max_layers):
if (cur_path / 'tests' / 'test_version_stable.py').exists():
if (cur_path / "tests" / "test_version_stable.py").exists():
return cur_path
else:
cur_path = cur_path.parent.resolve()
raise FileNotFoundError(f"Unable to find package root within {max_layers} upwards" +\
f"of {start_path}")
raise FileNotFoundError(
f"Unable to find package root within {max_layers} upwards" + f"of {start_path}"
)
@positional_deprecated
def run_task_tests(task_list: List[str]):
"""
Find the package root and run the tests for the given tasks
"""
import pytest
package_root = find_test_root(start_path=pathlib.Path(__file__))
task_string = ' or '.join(task_list)
args = [f'{package_root}/tests/test_version_stable.py', f'--rootdir={package_root}', '-k', f'{task_string}']
task_string = " or ".join(task_list)
args = [
f"{package_root}/tests/test_version_stable.py",
f"--rootdir={package_root}",
"-k",
f"{task_string}",
]
sys.path.append(str(package_root))
pytest_return_val = pytest.main(args)
if pytest_return_val:
raise ValueError(f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}")
\ No newline at end of file
raise ValueError(
f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}"
)
import argparse
import json
import logging
import fnmatch
from lm_eval import tasks, evaluator
logging.getLogger("openai").setLevel(logging.WARNING)
class MultiChoice:
def __init__(self, choices):
self.choices = choices
# Simple wildcard support (linux filename patterns)
def __contains__(self, values):
for value in values.split(","):
if len(fnmatch.filter(self.choices, value)) == 0:
return False
return True
def __iter__(self):
for choice in self.choices:
yield choice
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--model', required=True)
parser.add_argument('--model_args', default="")
parser.add_argument('--tasks', default="all_tasks")
parser.add_argument('--provide_description', action="store_true")
parser.add_argument('--num_fewshot', type=int, default=0)
parser.add_argument('--batch_size', type=int, default=None)
parser.add_argument('--device', type=str, default=None)
parser.add_argument('--output_path', default=None)
parser.add_argument('--limit', type=int, default=None)
parser.add_argument('--no_cache', action="store_true")
parser.add_argument('--description_dict_path', default=None)
parser.add_argument('--check_integrity', action="store_true")
parser.add_argument("--model", required=True)
parser.add_argument("--model_args", default="")
parser.add_argument("--tasks", default=None, choices=MultiChoice(tasks.ALL_TASKS))
parser.add_argument("--provide_description", action="store_true")
parser.add_argument("--num_fewshot", type=int, default=0)
parser.add_argument("--batch_size", type=str, default=None)
parser.add_argument("--device", type=str, default=None)
parser.add_argument("--output_path", default=None)
parser.add_argument("--limit", type=int, default=None)
parser.add_argument("--no_cache", action="store_true")
parser.add_argument("--decontamination_ngrams_path", default=None)
parser.add_argument("--description_dict_path", default=None)
parser.add_argument("--check_integrity", action="store_true")
return parser.parse_args()
# Returns a list containing all values of the source_list that
# match at least one of the patterns
def pattern_match(patterns, source_list):
task_names = set()
for pattern in patterns:
for matching in fnmatch.filter(source_list, pattern):
task_names.add(matching)
return sorted(list(task_names))
def main():
args = parse_args()
assert not args.provide_description # not implemented
if args.limit:
print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")
print(
"WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
)
if args.tasks == "all_tasks":
if args.tasks is None:
task_names = tasks.ALL_TASKS
else:
task_names = args.tasks.split(",")
task_names = pattern_match(args.tasks.split(","), tasks.ALL_TASKS)
print(f"Selected Tasks: {task_names}")
description_dict = {}
if args.description_dict_path:
with open(args.description_dict_path, 'r') as f:
with open(args.description_dict_path, "r") as f:
description_dict = json.load(f)
results = evaluator.simple_evaluate(
......@@ -51,11 +86,11 @@ def main():
no_cache=args.no_cache,
limit=args.limit,
description_dict=description_dict,
check_integrity=args.check_integrity
decontamination_ngrams_path=args.decontamination_ngrams_path,
check_integrity=args.check_integrity,
)
dumped = json.dumps(results, indent=2)
print(dumped)
if args.output_path:
......
{
"Data": "Pile statistics",
"Document Count": 210607728,
"Total Pile Characters": 421215456,
"File Start Offsets": [
0,
7021438,
14042822,
21066113,
28086515,
35106072,
42123306,
49145091,
56165817,
63185587,
70211208,
77234322,
84249267,
91267634,
98285983,
105305110,
112322489,
119342491,
126367373,
133389153,
140412039,
147432373,
154452516,
161470190,
168492733,
175512521,
182526939,
189547478,
196565318,
203583306
]
}
janitor.py contains a script to remove benchmark data contamination from training data sets.
janitor.py contains a script to remove benchmark data contamination from training data sets.
It uses the approach described in the [GPT-3 paper](https://arxiv.org/abs/2005.14165).
## Algorithm
1) Collects all contamination text files that are to be removed from training data
2) Filters training data by finding `N`gram matches between the training data
2) Filters training data by finding `N`gram matches between the training data
and any contamination
1) `N`grams ignore case and punctation and are split on whitespace.
2) Matching `N`gram substrings are removed, as is a `window_to_remove` character window around
1) `N`grams ignore case and punctuation and are split on whitespace.
2) Matching `N`gram substrings are removed, as is a `window_to_remove` character window around
the match, splitting the training data into chunks
3) Any chunks less than `minimum_slice_length` are removed
4) Training data sets split into more than `too_dirty_cutoff` are considered
completey contaminated and removed
OpenAI used:
```
ngram_n = 13
......@@ -20,7 +20,7 @@ minimum_slice_length = 200
too_dirty_cutoff = 10
```
## Compling
## Compiling
Janitor can be used as a pure python program, but it is much faster if the ngram
code is run in C++. To compile the C++ code, run
......@@ -31,4 +31,3 @@ c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor
```
If your your compiler isn't linked to python, you may need to add to the above `-undefined dynamic_lookup`
import glob
import argparse
import os
import subprocess
import shutil
from tqdm import tqdm
from tqdm_multiprocess import TqdmMultiProcessPool
import logging
from tqdm_multiprocess.logger import setup_logger_tqdm
logger = logging.getLogger(__name__)
def process_task(
working_directory, output_directory, bucket_file_path, tqdm_func, global_tqdm
):
command = f"zstd {bucket_file_path}"
logger.info(command)
subprocess.call(command, shell=True)
compressed_file = bucket_file_path + ".zst"
if output_directory:
shutil.move(compressed_file, output_directory)
os.remove(bucket_file_path)
global_tqdm.update()
def compress_and_move(working_directory, output_directory, process_count):
os.makedirs(output_directory, exist_ok=True)
original_info_file_path = os.path.join(working_directory, "info.json")
assert os.path.exists(original_info_file_path)
tasks = []
bucket_file_paths = glob.glob(
os.path.join(working_directory, "output", f"*.bkt.txt.sorted")
)
for bucket_file_path in bucket_file_paths:
task = (process_task, (working_directory, output_directory, bucket_file_path))
tasks.append(task)
pool = TqdmMultiProcessPool(process_count)
def on_done(_):
return None
def on_error(_):
return None
global_progress = tqdm(
total=len(bucket_file_paths), dynamic_ncols=True, unit="file"
)
_ = pool.map(global_progress, tasks, on_error, on_done)
shutil.copy(original_info_file_path, os.path.join(output_directory, "info.json"))
parser = argparse.ArgumentParser(description="sort 13gram buckets")
parser.add_argument("-dir", "--working_directory", required=True)
parser.add_argument("-output", "--output_directory", required=True)
parser.add_argument("-procs", "--process_count", type=int, default=8)
if __name__ == "__main__":
version = 1.00
print(f"Running version {version}")
logfile_path = "compress_and_package.log"
setup_logger_tqdm(logfile_path)
args = parser.parse_args()
compress_and_move(args.working_directory, args.output_directory, args.process_count)
"""
Outputs all 13-grams found in The Pile.
Loops through all documents and uses the logic found in janitor.py to extract 13-grams.
We bucket each 13-gram by hash into separate file buckets to allow easy parallel processing in the
next stage. We also include the current pile document_id with each ngram instance to allow the
Loops through all documents and uses the logic found in janitor.py to extract 13-grams.
We bucket each 13-gram by hash into separate file buckets to allow easy parallel processing in the
next stage. We also include the current pile document_id with each ngram instance to allow the
filtering to exclude 13-grams that match more then 10 unique documents (done further down the pipeline).
We didn't use lm_dataformat to output as it increases time 4x (slow jsonify) and makes
......@@ -21,8 +21,10 @@ Arguments
"""
import argparse
import json
import pickle
import os
import sys
from pathlib import Path
import glob
import signal
......@@ -30,32 +32,98 @@ from signal import SIGINT
from tqdm import tqdm
from scripts.clean_training_data.janitor import Janitor, word_ngrams
from scripts.clean_training_data.archiver import TextArchive, Reader
from lm_eval.decontamination.janitor import Janitor, word_ngrams
from lm_eval.decontamination.archiver import TextArchive, Reader
import logging
from tqdm_multiprocess.logger import setup_logger_tqdm
logger = logging.getLogger(__name__)
pile_document_count = 210607728
logger = logging.getLogger(__name__)
terminate = False
def handler(signal_received, frame):
global terminate
terminate = True
def get_pile(directory):
reader = Reader()
for file in glob.glob(os.path.join(directory, f"*.jsonl.zst*")):
def yield_pile(start_offsets=None, checkpoint_offset=None):
directory = "pile"
if not os.path.exists(directory):
print(
"We expect the pile archives to be in the 'pile' directory, but this was not found."
)
raise Exception("Pile directory not found.")
files = list(sorted(glob.glob(os.path.join(directory, "*.jsonl.zst*"))))
pile_global_offset = 0
start_file = 0
if checkpoint_offset:
for file_i, start_offset in enumerate(start_offsets):
if start_offset > checkpoint_offset:
break
start_file = file_i
pile_global_offset = start_offset
for file_i, file in enumerate(files):
if file_i < start_file:
logger.info(f"Skipping file {file}")
continue
logger.info(f"Reading from pile file: {file}")
reader = Reader()
for document in reader.read(file):
yield document
yield (pile_global_offset, document)
pile_global_offset += 1
# Hash buckets > disk backed files. Supports file position checkpointing and resuming
# Allows you to write continuously and checkpoint intermittently. If a failure occurs
# the buckets are simply truncated at your last checkpoint.
class Buckets:
def __init__(self, directory, num_buckets):
self.bucket_files = [
os.path.join(directory, f"ngrams_{i}.bkt.txt") for i in range(num_buckets)
]
self.buckets = list(map(TextArchive, self.bucket_files))
self.checkpoint_file = os.path.join(directory, f"bucket_offsets.ckpt")
if os.path.exists(self.checkpoint_file):
self.bucket_offsets = pickle.load(open(self.checkpoint_file, "rb"))
else:
self.bucket_offsets = [0 for i in range(len(self.buckets))]
for i, offset in enumerate(self.bucket_offsets):
bucket = self.buckets[i]
bucket.fh.seek(offset)
bucket.fh.truncate()
def add_data(self, key, value):
i = hash(key) % len(self.buckets)
bucket = self.buckets[i]
bucket.add_data(value)
def save_checkpoint(self):
for bucket in self.buckets:
bucket.fh.flush()
bucket_offsets = [bucket.fh.tell() for bucket in self.buckets]
pickle.dump(bucket_offsets, open(self.checkpoint_file, "wb"))
def close_buckets(self):
for bucket in self.buckets:
bucket.commit()
def close_buckets(buckets):
for bucket in buckets:
bucket.commit()
def do_ngrams_in_buckets(n_value, working_directory, bucket_count):
pile_statistics = json.load(open("pile_statistics.json", "r"))
pile_document_count = pile_statistics["Document Count"]
start_offsets = pile_statistics["File Start Offsets"]
output_directory = os.path.join(working_directory, "output")
os.makedirs(output_directory, exist_ok=True)
......@@ -68,58 +136,71 @@ def do_ngrams_in_buckets(n_value, working_directory, bucket_count):
return
# Checkpoint
checkpoint_file = os.path.join(output_directory, f"ngram_buckets.ckpt")
checkpoint_file = os.path.join(working_directory, f"pile_offset.ckpt")
if os.path.exists(checkpoint_file):
start_id = pickle.load(open(checkpoint_file,"rb"))
checkpoint_offset = pickle.load(open(checkpoint_file, "rb"))
iterate = True
else:
start_id = 0
checkpoint_offset = 0
iterate = False
logger.info(f"Starting at pile document index {start_id}")
bucket_files = [os.path.join(output_directory, f"ngrams_{i}.bkt.txt") for i in range(bucket_count)]
buckets = list(map(TextArchive, bucket_files))
logger.info(f"Starting at pile document index {checkpoint_offset}")
buckets = Buckets(output_directory, bucket_count)
janitor = Janitor()
current_id = 0
batch_size = 1000
batch_counter = 0
with tqdm(total=pile_document_count, dynamic_ncols=True, unit="docs") as progress:
for document in get_pile(working_directory):
if current_id < start_id:
if terminate:
close_buckets(buckets)
return
current_id += 1
with tqdm(total=checkpoint_offset, dynamic_ncols=True, unit="docs") as progress:
for offset, document in yield_pile(start_offsets, checkpoint_offset):
if iterate:
logger.info(f"Iterating to offset {checkpoint_offset} from {offset}")
progress.update(offset)
iterate = False
if offset < checkpoint_offset:
progress.update()
if terminate:
return
continue
if offset == checkpoint_offset:
progress.reset(total=pile_document_count)
progress.update(checkpoint_offset)
# Save checkpoint every "batch_size", only allow terminate after checkpoint
if batch_counter == batch_size:
progress.update(batch_size)
batch_counter = 0
pickle.dump(current_id, open(checkpoint_file,"wb"))
buckets.save_checkpoint()
pickle.dump(offset, open(checkpoint_file, "wb"))
if terminate:
close_buckets(buckets)
buckets.close_buckets()
return
ngrams = word_ngrams(janitor.normalize_string(document), n_value)
for ngram in ngrams:
bucket = hash(ngram) % len(buckets)
buckets[bucket].add_data(f"{ngram} {current_id}")
buckets.add_data(ngram, f"{ngram} {offset}")
batch_counter += 1
current_id += 1
close_buckets(buckets)
buckets.close_buckets()
Path(done_file).touch()
parser = argparse.ArgumentParser(description='Generate 13 grams from Pile.')
parser = argparse.ArgumentParser(description="Generate 13 grams from Pile.")
parser.add_argument("-dir", "--working_directory", default="")
parser.add_argument("-n", "--n_value", type=int, default=13)
parser.add_argument("-buckets", "--bucket_count", type=int, default=500)
if __name__ == '__main__':
if __name__ == "__main__":
version = 1.00
print(f"Running version {version}")
if "PYTHONHASHSEED" not in os.environ or os.environ["PYTHONHASHSEED"] != "0":
print("Please run 'export PYTHONHASHSEED=0' before running generate.")
sys.exit()
# Handle sigint (ctrl-c) cleanly
previous_signal_int = signal.signal(SIGINT, handler)
......@@ -128,4 +209,8 @@ if __name__ == '__main__':
setup_logger_tqdm(logfile_path)
args = parser.parse_args()
do_ngrams_in_buckets(args.n_value, args.working_directory, args.bucket_count)
\ No newline at end of file
do_ngrams_in_buckets(args.n_value, args.working_directory, args.bucket_count)
info_dict = {"title": "dataset ngrams", "ngram_size": 13}
info_dict_path = os.path.join(args.working_directory, "info.json")
json.dump(info_dict, open(info_dict_path, "w"))
from lm_eval.decontamination.archiver import Reader
import os
import json
from functools import reduce
import glob
import tqdm
from tqdm_multiprocess import TqdmMultiProcessPool
def get_file_stats(file_path, tqdm_func, global_tqdm):
reader = Reader()
total_documents = 0
total_size = 0
update_frequency = 10000
current_file_position = 0
with tqdm_func(
total=os.path.getsize(file_path), dynamic_ncols=True, unit="byte", unit_scale=1
) as progress:
for document in reader.read(file_path, get_meta=True):
total_size += len(document)
total_documents += 1
if total_documents % update_frequency == 0:
new_file_pos = reader.fh.tell()
bytes_read = new_file_pos - current_file_position
current_file_position = new_file_pos
progress.update(bytes_read)
global_tqdm.update(bytes_read)
return (total_documents, total_size)
def get_files():
directory = "pile"
files = list(sorted(glob.glob(os.path.join(directory, "*.jsonl.zst*"))))
print(files)
return files
def get_stats():
files = get_files()
total_size_bytes = sum(map(lambda x: os.path.getsize(x), files))
pool = TqdmMultiProcessPool(4)
global_tqdm = tqdm.tqdm(
total=total_size_bytes, dynamic_ncols=True, unit="byte", unit_scale=1
)
# Generate minhashes with pool
tasks = [(get_file_stats, (file,)) for file in files]
def on_done(_):
return None
def on_error(_):
return None
results = pool.map(global_tqdm, tasks, on_error, on_done)
total_documents, total_size = reduce(
lambda x, y: (x[0] + y[0], x[1] + y[1]), results
)
start_offsets = []
current_offset = 0
for file_document_count, _ in results:
start_offsets.append(current_offset)
current_offset += file_document_count
return (total_documents, total_size, start_offsets)
if __name__ == "__main__":
version = 1.01
print(f"Running version {version}")
stats_file_path = "pile_statistics.json"
if os.path.exists(stats_file_path):
stats = json.load(open(stats_file_path, "r"))
else:
document_count, total_document_size_chars, start_offsets = get_stats()
stats = {
"Data": "Pile statistics",
"Document Count": document_count,
"Total Pile Characters": total_document_size_chars,
"File Start Offsets": start_offsets,
}
json.dump(stats, open(stats_file_path, "w"), indent=4)
print(f"document_count: {stats['Document Count']}")
print(f"total_chars: {stats['Total Pile Characters']}")
print(f"start_offsets: {stats['File Start Offsets']}")
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <utility>
#include <queue>
#include <string>
#include <vector>
#include <tuple>
#include <queue>
#include <utility>
#include <vector>
bool is_whitespace(char ch) noexcept {
// " \t\n\r\x0b\x0c" (python string.whitespace)
return ch == 32 or (9 <= ch and ch <= 13);
// return ch <= 32; // arguably too general, but slightly faster
// " \t\n\r\x0b\x0c" (python string.whitespace)
return ch == 32 or (9 <= ch and ch <= 13);
// return ch <= 32; // arguably too general, but slightly faster
}
bool is_punctuation(char c) noexcept {
// '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~' ascii values: 33-47, 58-64, 91-96, 123-126
return (33 <= c and c <= 47) or (58 <= c and c <= 64) or (91 <= c and c <= 96) or (123 <= c and c <= 126);
// '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~' ascii values: 33-47, 58-64,
// 91-96, 123-126
return (33 <= c and c <= 47) or (58 <= c and c <= 64) or
(91 <= c and c <= 96) or (123 <= c and c <= 126);
}
// Takes a string and makes ngrams of length N, splitting grams on whitespace and ignoring ignored characters
// Returns a LARGE array of ngrams
std::vector<std::string> clean_ngram(
std::string const & input, std::string const & ignore, size_t ngram_n
) noexcept {
size_t num_grams = 0;
std::vector<std::string> ngram_list;
std::vector<uint8_t> gram_lengths;
std::string current_ngram;
// Max gram length is set to 10 below.
current_ngram.reserve(11*ngram_n);
gram_lengths.reserve(ngram_n);
bool started_gram = false;
gram_lengths.push_back(0);
//for (size_t i=0; i<input.length(); i++) {
// this is slightly faster, and we don't need the index in this one
for (auto iter = input.begin(); iter != input.end(); iter++) {
// If whitespace, end the current ngram and start the next
// alternatively, (perhaps marginally) faster: if (is_whitespace(ch)) { ... }
if (is_whitespace(*iter) || gram_lengths.back() > 10) {
// Skip all whitespace
while (++iter != input.end() && is_whitespace(*iter));
iter--;
if (started_gram){
num_grams += 1;
// Building 1grams is a special case
if (ngram_n == 1){
ngram_list.push_back(current_ngram);
current_ngram = current_ngram.substr(gram_lengths.front());
gram_lengths.back() = 0;
// If there are enough grams to form an ngram, save
} else if (num_grams >= ngram_n){
// Save the current ngram
ngram_list.push_back(current_ngram);
// Start the next ngram by dropping the first gram and its space from the ngram
current_ngram = current_ngram.substr(gram_lengths.front() + 1);
current_ngram += ' ';
// Drop the length of the first gram and prepare to record the length of the new gram
gram_lengths.erase(gram_lengths.begin());
gram_lengths.push_back(0);
// Otherwise, continute building
} else {
current_ngram += ' ';
gram_lengths.push_back(0);
}
started_gram = false;
}
// Takes a string and makes ngrams of length N, splitting grams on whitespace
// and ignoring ignored characters Returns a LARGE array of ngrams
std::vector<std::string> clean_ngram(std::string const &input,
std::string const &ignore,
size_t ngram_n) noexcept {
size_t num_grams = 0;
std::vector<std::string> ngram_list;
std::vector<uint8_t> gram_lengths;
std::string current_ngram;
// Max gram length is set to 10 below.
current_ngram.reserve(11 * ngram_n);
gram_lengths.reserve(ngram_n);
bool started_gram = false;
gram_lengths.push_back(0);
// for (size_t i=0; i<input.length(); i++) {
// this is slightly faster, and we don't need the index in this one
for (auto iter = input.begin(); iter != input.end(); iter++) {
// If whitespace, end the current ngram and start the next
// alternatively, (perhaps marginally) faster: if (is_whitespace(ch)) { ...
// }
if (is_whitespace(*iter) || gram_lengths.back() > 10) {
// Skip all whitespace
while (++iter != input.end() && is_whitespace(*iter))
;
iter--;
if (started_gram) {
num_grams += 1;
// Building 1grams is a special case
if (ngram_n == 1) {
ngram_list.push_back(current_ngram);
current_ngram = current_ngram.substr(gram_lengths.front());
gram_lengths.back() = 0;
// If there are enough grams to form an ngram, save
} else if (num_grams >= ngram_n) {
// Save the current ngram
ngram_list.push_back(current_ngram);
// Start the next ngram by dropping the first gram and its space from
// the ngram
current_ngram = current_ngram.substr(gram_lengths.front() + 1);
current_ngram += ' ';
// Drop the length of the first gram and prepare to record the length
// of the new gram
gram_lengths.erase(gram_lengths.begin());
gram_lengths.push_back(0);
// Otherwise, continute building
} else {
current_ngram += ' ';
gram_lengths.push_back(0);
}
started_gram = false;
}
// Skip ignored characters
// alternatively, (perhaps marginally) faster: if (is_punctuation(ch)) continue;
} else if (ignore.find(*iter) != std::string::npos) {
continue;
}
// Skip ignored characters
// alternatively, (perhaps marginally) faster: if (is_punctuation(ch))
// continue;
} else if (ignore.find(*iter) != std::string::npos) {
continue;
}
// If it is a non-ignored character, add it to the ngram and update the last gram's length
else {
current_ngram += tolower(*iter);
gram_lengths.back() += 1;
started_gram = true;
}
// If it is a non-ignored character, add it to the ngram and update the last
// gram's length
else {
current_ngram += tolower(*iter);
gram_lengths.back() += 1;
started_gram = true;
}
}
return ngram_list;
return ngram_list;
}
// Takes a string and makes ngrams of length N, splitting grams on whitespace
// and ignoring ignored characters Returns a LARGE array of tuples of (ngram,
// start_idx, end_idx)
std::vector<std::tuple<std::string, size_t, size_t>>
clean_ngram_with_indices(std::string const &input, std::string const &ignore,
size_t ngram_n) noexcept {
size_t num_grams = 0;
std::vector<std::tuple<std::string, size_t, size_t>> ngram_list;
std::vector<uint8_t> gram_lengths;
std::vector<size_t> gram_start_indices;
std::string current_ngram;
// Max gram length is set to 10 below.
current_ngram.reserve(11 * ngram_n);
bool started_gram = false;
gram_lengths.push_back(0);
gram_start_indices.push_back(0);
for (size_t i = 0; i < input.length(); i++) {
char ch = input[i];
// If whitespace, end the current ngram and start the next
if (is_whitespace(ch) || gram_lengths.back() > 10) {
// Skip all whitespace
while (++i < input.length() && is_whitespace(input[i]))
;
i--;
if (started_gram) {
num_grams += 1;
// Building 1grams is a special case
if (ngram_n == 1) {
ngram_list.push_back(
std::make_tuple(current_ngram, gram_start_indices.front(), i));
current_ngram = current_ngram.substr(gram_lengths.front());
gram_lengths.back() = 0;
gram_start_indices.back() = i + 1;
// If there are enough grams to form an ngram, save
} else if (num_grams >= ngram_n) {
// Save the current ngram
ngram_list.push_back(
std::make_tuple(current_ngram, gram_start_indices.front(), i));
// Start the next ngram by dropping the first gram and its space from
// the ngram
current_ngram = current_ngram.substr(gram_lengths.front() + 1);
current_ngram += ' ';
// Drop the length of the first gram and prepare to record the length
// of the new gram
gram_lengths.erase(gram_lengths.begin());
gram_lengths.push_back(0);
gram_start_indices.erase(gram_start_indices.begin());
gram_start_indices.push_back(i + 1);
// Otherwise, continute building
} else {
current_ngram += ' ';
gram_lengths.push_back(0);
gram_start_indices.push_back(i + 1);
}
// Takes a string and makes ngrams of length N, splitting grams on whitespace and ignoring ignored characters
// Returns a LARGE array of tuples of (ngram, start_idx, end_idx)
std::vector<std::tuple<std::string, size_t, size_t> > clean_ngram_with_indices(
std::string const & input, std::string const & ignore, size_t ngram_n
) noexcept {
size_t num_grams = 0;
std::vector<std::tuple<std::string, size_t, size_t> > ngram_list;
std::vector<uint8_t> gram_lengths;
std::vector<size_t> gram_start_indices;
std::string current_ngram;
// Max gram length is set to 10 below.
current_ngram.reserve(11*ngram_n);
bool started_gram = false;
gram_lengths.push_back(0);
gram_start_indices.push_back(0);
for (size_t i=0; i<input.length(); i++) {
char ch = input[i];
// If whitespace, end the current ngram and start the next
if (is_whitespace(ch) || gram_lengths.back() > 10) {
// Skip all whitespace
while (++i < input.length() && is_whitespace(input[i]));
i--;
if (started_gram){
num_grams += 1;
// Building 1grams is a special case
if (ngram_n == 1){
ngram_list.push_back(std::make_tuple(current_ngram, gram_start_indices.front(), i));
current_ngram = current_ngram.substr(gram_lengths.front());
gram_lengths.back() = 0;
gram_start_indices.back() = i+1;
// If there are enough grams to form an ngram, save
} else if (num_grams >= ngram_n){
// Save the current ngram
ngram_list.push_back(
std::make_tuple(current_ngram, gram_start_indices.front(), i)
);
// Start the next ngram by dropping the first gram and its space from the ngram
current_ngram = current_ngram.substr(gram_lengths.front() + 1);
current_ngram += ' ';
// Drop the length of the first gram and prepare to record the length of the new gram
gram_lengths.erase(gram_lengths.begin());
gram_lengths.push_back(0);
gram_start_indices.erase(gram_start_indices.begin());
gram_start_indices.push_back(i+1);
// Otherwise, continute building
} else {
current_ngram += ' ';
gram_lengths.push_back(0);
gram_start_indices.push_back(i+1);
}
started_gram = false;
}
started_gram = false;
}
// Skip ignored characters
} else if (ignore.find(*iter) != std::string::npos) {
continue;
// Skip ignored characters
} else if (ignore.find(ch) != std::string::npos) {
continue;
// If it is a non-ignored character, add it to the ngram and update the last gram's length
} else {
current_ngram += tolower(ch);
gram_lengths.back() += 1;
started_gram = true;
}
// If it is a non-ignored character, add it to the ngram and update the
// last gram's length
} else {
current_ngram += tolower(ch);
gram_lengths.back() += 1;
started_gram = true;
}
}
return ngram_list;
return ngram_list;
}
PYBIND11_MODULE(janitor_util, m) {
m.doc() = "pybind11 example plugin"; // optional module docstring
// m.def("add", &add, "A function which adds two numbers"); // example function
m.def("clean_ngram", &clean_ngram, "Create ngrams of words, ignoring some characters");
m.def("clean_ngram_with_indices", &clean_ngram_with_indices, "Create ngrams of words with indices, ignoring some characters");
m.doc() = "pybind11 example plugin"; // optional module docstring
// m.def("add", &add, "A function which adds two numbers"); // example
// function
m.def("clean_ngram", &clean_ngram,
"Create ngrams of words, ignoring some characters");
m.def("clean_ngram_with_indices", &clean_ngram_with_indices,
"Create ngrams of words with indices, ignoring some characters");
}
// Example compile
// c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix)
// If python and gcc aren't linked, append to the above: -undefined dynamic_lookup
\ No newline at end of file
// c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes)
// janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) If
// python and gcc aren't linked, append to the above: -undefined
// dynamic_lookup
......@@ -27,25 +27,33 @@ from scripts.clean_training_data.archiver import TextReader, TextArchive
import logging
from tqdm_multiprocess.logger import setup_logger_tqdm
logger = logging.getLogger(__name__)
# Multiprocessed
def process_bucket(bucket_file_path, processed_directory, move_dir, tqdm_func, global_tqdm):
bucket_id = re.sub("\D", "", os.path.basename(bucket_file_path))
done_file = os.path.join(processed_directory, f"ngram_bucket_processing_{bucket_id}.done")
# Multiprocessed
def process_bucket(
bucket_file_path, processed_directory, move_dir, tqdm_func, global_tqdm
):
bucket_id = re.sub("\D", "", os.path.basename(bucket_file_path)) # noqa: W605
done_file = os.path.join(
processed_directory, f"ngram_bucket_processing_{bucket_id}.done"
)
if os.path.exists(done_file):
logger.info(f"bucket {bucket_id} already processed, skipping")
return
# For managing tqdm
file_size = os.path.getsize(bucket_file_path)
bucket_progress = tqdm_func(total=file_size, dynamic_ncols=True, unit="byte", unit_scale=1)
bucket_progress = tqdm_func(
total=file_size, dynamic_ncols=True, unit="byte", unit_scale=1
)
current_file_position = 0
update_frequency = 100 * 1000000 # 100mb
update_frequency = 100 * 1000000 # 100mb
update_counter = 0
# Iterate through and output ngrams which occur in more then 10 documents
# Iterate through and output ngrams which occur in more then 10 documents
bucket = TextReader(bucket_file_path)
output_file_path = bucket_file_path + ".processed"
......@@ -56,10 +64,12 @@ def process_bucket(bucket_file_path, processed_directory, move_dir, tqdm_func, g
for line in bucket.read():
[ngram, document_id] = line.rsplit(" ", 1)
# Write ngram if more then 10 unique document occurences
# Write ngram if more then 10 unique document occurrences
if ngram != current_ngram:
if len(current_ngram_document_ids) > 10:
output_archive.add_data(f"{current_ngram} {len(current_ngram_document_ids)}")
output_archive.add_data(
f"{current_ngram} {len(current_ngram_document_ids)}"
)
current_ngram = ngram
current_ngram_document_ids = set()
......@@ -84,28 +94,38 @@ def process_bucket(bucket_file_path, processed_directory, move_dir, tqdm_func, g
global_tqdm.update()
def process_sorted_buckets(working_directory, move_dir, process_count):
bucket_file_paths = glob.glob(os.path.join(working_directory, f"*.bkt.txt.sorted"))
processed_directory = os.path.join(working_directory, "processed")
os.makedirs(processed_directory, exist_ok=True)
pool = TqdmMultiProcessPool(process_count)
tasks = [(process_bucket, (bucket_file, processed_directory, move_dir)) for bucket_file in bucket_file_paths]
pool = TqdmMultiProcessPool(process_count)
tasks = [
(process_bucket, (bucket_file, processed_directory, move_dir))
for bucket_file in bucket_file_paths
]
global_tqdm = tqdm(total=len(bucket_file_paths), dynamic_ncols=True, unit="bucket")
on_done = lambda _ : None
on_error = lambda _ : None
def on_done(_):
return None
def on_error(_):
return None
_ = pool.map(global_tqdm, tasks, on_error, on_done)
parser = argparse.ArgumentParser(description='Process 13 grams from sorted buckets.')
parser = argparse.ArgumentParser(description="Process 13 grams from sorted buckets.")
parser.add_argument("-dir", "--working_directory", default="")
parser.add_argument("-move", "--move_dir", default="")
parser.add_argument("-procs", "--process_count", type=int, default=4)
if __name__ == '__main__':
if __name__ == "__main__":
logfile_path = "process13grams.log"
setup_logger_tqdm(logfile_path)
args = parser.parse_args()
process_sorted_buckets(args.working_directory, args.move_dir, args.process_count)
\ No newline at end of file
process_sorted_buckets(args.working_directory, args.move_dir, args.process_count)
"""
Iteratively runs gnu sort on each bucket, gnu handles the multiprocessing.
Iteratively runs gnu sort on each bucket, uses up to 8 cores.
Arguments
---------
......@@ -11,48 +11,47 @@ Arguments
import glob
import argparse
import os
from pathlib import Path
import signal
from signal import SIGINT
import re
import subprocess
from tqdm import tqdm
import logging
from tqdm_multiprocess.logger import setup_logger_tqdm
logger = logging.getLogger(__name__)
terminate = False
def handler(signal_received, frame):
global terminate
terminate = True
def sort_13_gram_buckets(working_directory):
bucket_file_paths = glob.glob(os.path.join(working_directory, f"*.bkt.txt"))
bucket_file_paths = glob.glob(os.path.join(working_directory, f"*.bkt.txt"))
for bucket_file_path in tqdm(bucket_file_paths, dynamic_ncols=True):
bucket_id = re.sub("\D", "", os.path.basename(bucket_file_path))
done_file = os.path.join(working_directory, f"ngram_bucket_sorting_{bucket_id}.done")
if os.path.exists(done_file):
logger.info(f"bucket {bucket_id} already processed, skipping")
return
sorted_file_path = bucket_file_path + ".sorted"
command = f"sort {bucket_file_path} > {sorted_file_path}"
logger.info(command)
logger.info(command)
subprocess.call(command, shell=True)
if terminate:
return
Path(done_file).touch()
os.remove(bucket_file_path)
parser = argparse.ArgumentParser(description='sort 13gram buckets')
parser = argparse.ArgumentParser(description="sort 13gram buckets")
parser.add_argument("-dir", "--working_directory", default="")
if __name__ == '__main__':
if __name__ == "__main__":
version = 1.00
print(f"Running version {version}")
# Handle sigint (ctrl-c) cleanly
previous_signal_int = signal.signal(SIGINT, handler)
......@@ -61,4 +60,4 @@ if __name__ == '__main__':
setup_logger_tqdm(logfile_path)
args = parser.parse_args()
sort_13_gram_buckets(args.working_directory)
\ No newline at end of file
sort_13_gram_buckets(args.working_directory)
......@@ -7,7 +7,7 @@ from lm_eval.base import LM
class DryrunLM(LM):
def __init__(self):
self.tokencost = 0
self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")
self.tokenizer.pad_token = "<|endoftext|>"
@classmethod
......@@ -16,28 +16,28 @@ class DryrunLM(LM):
def loglikelihood(self, requests):
res = []
for ctx, cont in requests:
res.append((-random.random(), False))
self.tokencost += len(self.tokenizer.tokenize(ctx + cont))
return res
def greedy_until(self, requests):
res = []
for ctx, until in requests:
for ctx, _ in requests:
res.append("lol")
# assume worst case - generates until 256
self.tokencost += len(self.tokenizer.tokenize(ctx)) + 256
return res
def loglikelihood_rolling(self, requests):
res = []
for s, in requests:
for (s,) in requests:
# assume worst case: extra full context
self.tokencost += len(self.tokenizer.tokenize(s)) + 2048
......@@ -46,7 +46,7 @@ class DryrunLM(LM):
def main():
lm = DryrunLM()
task_list = "arc_challenge,arc_easy,boolq,cola,copa,headqa,hellaswag,lambada,logiqa,mathqa,mc_taco,mrpc,multirc,openbookqa,piqa,prost,pubmedqa,qnli,qqp,race,record,rte,sciq,sst,triviaqa,webqs,wic,wikitext,winogrande,wnli,wsc"
values = []
for taskname in task_list.split(","):
......@@ -57,11 +57,20 @@ def main():
num_fewshot=0,
limit=None,
bootstrap_iters=10,
description_dict=None
description_dict=None,
)
print(taskname, lm.tokencost)
values.append([taskname, lm.tokencost, lm.tokencost / 1000 * 0.0008, lm.tokencost / 1000 * 0.0012, lm.tokencost / 1000 * 0.006, lm.tokencost / 1000 * 0.06])
values.append(
[
taskname,
lm.tokencost,
lm.tokencost / 1000 * 0.0008,
lm.tokencost / 1000 * 0.0012,
lm.tokencost / 1000 * 0.006,
lm.tokencost / 1000 * 0.06,
]
)
from pytablewriter import MarkdownTableWriter
writer = MarkdownTableWriter()
......@@ -69,10 +78,21 @@ def main():
values.sort(key=lambda x: -x[1])
totcost = sum([x[1] for x in values])
values.append(["**Total**", totcost, totcost / 1000 * 0.0008, totcost / 1000 * 0.0012, totcost / 1000 * 0.006, totcost / 1000 * 0.06])
values.append(
[
"**Total**",
totcost,
totcost / 1000 * 0.0008,
totcost / 1000 * 0.0012,
totcost / 1000 * 0.006,
totcost / 1000 * 0.06,
]
)
writer.value_matrix = values
print(writer.dumps())
if __name__ == "__main__":
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment