Commit 18c0fa29 authored by cardy20

conflict solved

parents 09915adf 0542d35d
@@ -2,7 +2,7 @@
Pointer Sentinel Mixture Models
https://arxiv.org/pdf/1609.07843.pdf

The WikiText language modeling dataset is a collection of over 100 million tokens
extracted from the set of verified Good and Featured articles on Wikipedia.

NOTE: This `Task` is based on WikiText-2.

@@ -10,14 +10,12 @@ NOTE: This `Task` is based on WikiText-2.
Homepage: https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/
"""
import re

from lm_eval.base import PerplexityTask


_CITATION = """
@misc{merity2016pointer,
    title={Pointer Sentinel Mixture Models},
    author={Stephen Merity and Caiming Xiong and James Bradbury and Richard Socher},
    year={2016},
    eprint={1609.07843},

@@ -63,7 +61,7 @@ def wikitext_detokenizer(string):
class WikiText(PerplexityTask):
    VERSION = 1
    DATASET_PATH = "EleutherAI/wikitext_document_level"
    DATASET_NAME = "wikitext-2-raw-v1"

    def has_training_docs(self):
@@ -76,20 +74,23 @@ class WikiText(PerplexityTask):
        return True

    def training_docs(self):
        return map(self._process_doc, self.dataset["train"])

    def validation_docs(self):
        return map(self._process_doc, self.dataset["validation"])

    def test_docs(self):
        return map(self._process_doc, self.dataset["test"])

    def _process_doc(self, doc):
        return doc["page"]

    def doc_to_target(self, doc):
        return wikitext_detokenizer(doc)

    def should_decontaminate(self):
        return True

    def count_words(self, doc):
        # count number of words in *original doc before detokenization*
        return len(re.split(r"\s+", doc))
""" """
WinoGrande: An Adversarial Winograd Schema Challenge at Scale WinoGrande: An Adversarial Winograd Schema Challenge at Scale
https://arxiv.org/pdf/1907.10641.pdf https://arxiv.org/pdf/1907.10641.pdf
WinoGrande is a collection of 44k problems, inspired by Winograd Schema Challenge WinoGrande is a collection of 44k problems, inspired by Winograd Schema Challenge
(Levesque, Davis, and Morgenstern 2011), but adjusted to improve the scale and (Levesque, Davis, and Morgenstern 2011), but adjusted to improve the scale and
robustness against the dataset-specific bias. Formulated as a fill-in-a-blank robustness against the dataset-specific bias. Formulated as a fill-in-a-blank
task with binary options, the goal is to choose the right option for a given task with binary options, the goal is to choose the right option for a given
sentence which requires commonsense reasoning. sentence which requires commonsense reasoning.
NOTE: This evaluation of Winogrande uses partial evaluation as described by NOTE: This evaluation of Winogrande uses partial evaluation as described by
Trinh & Le in Simple Method for Commonsense Reasoning (2018). Trinh & Le in Simple Method for Commonsense Reasoning (2018).
See: https://arxiv.org/abs/1806.02847 See: https://arxiv.org/abs/1806.02847
Homepage: https://leaderboard.allenai.org/winogrande/submissions/public Homepage: https://leaderboard.allenai.org/winogrande/submissions/public
""" """
import numpy as np

from lm_eval.base import rf, Task
from lm_eval.metrics import mean


_CITATION = """
@article{sakaguchi2019winogrande,
    title={WinoGrande: An Adversarial Winograd Schema Challenge at Scale},
    author={Sakaguchi, Keisuke and Bras, Ronan Le and Bhagavatula, Chandra and Choi, Yejin},
    journal={arXiv preprint arXiv:1907.10641},
    year={2019}
}
"""


class Winogrande(Task):
    VERSION = 0
    DATASET_PATH = "winogrande"
    DATASET_NAME = "winogrande_xl"

    answer_to_num = {"1": 0, "2": 1}

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        if self._training_docs is None:
            self._training_docs = list(self.dataset["train"])
        return self._training_docs

    def validation_docs(self):
        return self.dataset["validation"]

    def doc_to_text(self, doc):
        return self.partial_context(doc, doc["option" + doc["answer"]])

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return doc["sentence"]

    @classmethod
    def partial_context(cls, doc, option):
        # Substitute the pronoun in the sentence with the specified option
        # and ignore everything after.
        pronoun_loc = doc["sentence"].index("_")
        return doc["sentence"][:pronoun_loc] + option

    def doc_to_target(self, doc):
        return self.partial_target(doc)

    @classmethod
    def partial_target(cls, doc):
        # The target is everything after the document specified pronoun.
        pronoun_loc = doc["sentence"].index("_") + 1
        return " " + doc["sentence"][pronoun_loc:].strip()

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        target = self.partial_target(doc)
        lls = []
        for option in [doc["option1"], doc["option2"]]:
            partial_ctx = self.partial_context(doc, option)
            full_ctx = self.append_context(ctx, partial_ctx)
            lls.append(rf.loglikelihood(full_ctx, target)[0])
        return lls

    @classmethod
    def append_context(cls, ctx, partial_ctx):
        ctx = ctx.split("\n\n")  # Each fewshot context is on its own new line.
        ctx.pop()  # Remove the correct context put in by `doc_to_text`.
        return "\n\n".join([*ctx, partial_ctx]) if ctx else partial_ctx

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        return {"acc": np.argmax(results) == self.answer_to_num[doc["answer"]]}

    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        return {"acc": mean}

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {"acc": True}
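For intuition, a minimal sketch of the partial-evaluation scheme described in the docstring above: every answer option is substituted for the blank, and the model scores only the continuation they share. `loglikelihood` here is a stand-in for any conditional log-probability function, not a harness API.

```
# Partial evaluation on one WinoGrande item (illustrative only).
doc = {
    "sentence": "The trophy didn't fit in the suitcase because _ was too big.",
    "option1": "the trophy",
    "option2": "the suitcase",
    "answer": "1",
}

pronoun_loc = doc["sentence"].index("_")
target = " " + doc["sentence"][pronoun_loc + 1 :].strip()  # " was too big."

scores = []
for option in (doc["option1"], doc["option2"]):
    context = doc["sentence"][:pronoun_loc] + option  # "... because the trophy"
    scores.append(loglikelihood(context, target))  # assumed scoring function

pred = scores.index(max(scores))  # 0 -> option1, 1 -> option2
```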
@@ -40,8 +40,19 @@ class WinogradSchemaChallenge273(Task):
    DATASET_PATH = "winograd_wsc"
    DATASET_NAME = "wsc273"

    upper_pronouns = [
        "A",
        "An",
        "The",
        "She",
        "He",
        "It",
        "They",
        "My",
        "His",
        "Her",
        "Their",
    ]

    def has_training_docs(self):
        return False
@@ -53,9 +64,9 @@ class WinogradSchemaChallenge273(Task):
        return True

    def test_docs(self):
        return map(self._process_doc, self.dataset["test"])

    def _process_doc(self, doc):
        # The HF implementation of `wsc273` is not `partial evaluation` friendly.
        doc["text"] = doc["text"].replace("  ", " ")
        doc["options"][0] = self.__normalize_option(doc, doc["options"][0])
@@ -68,7 +79,7 @@ class WinogradSchemaChallenge273(Task):
            option += "'s"
        # Appropriately lowercase the pronoun in the option.
        pronoun = option.split()[0]
        start_of_sentence = doc["text"][doc["pronoun_loc"] - 2] == "."
        if not start_of_sentence and pronoun in self.upper_pronouns:
            return option.replace(pronoun, pronoun.lower())
        return option
@@ -85,11 +96,17 @@ class WinogradSchemaChallenge273(Task):
    def doc_to_text(self, doc):
        return self.partial_context(doc, doc["options"][doc["label"]])

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return doc["text"]

    @classmethod
    def partial_context(cls, doc, option):
        # Substitute the pronoun in the original text with the specified
        # option and ignore everything after.
        return doc["text"][: doc["pronoun_loc"]] + option

    def doc_to_target(self, doc):
        return self.partial_target(doc)
@@ -135,9 +152,7 @@ class WinogradSchemaChallenge273(Task):
        :param results:
            The results of the requests created in construct_requests.
        """
        return {"acc": np.argmax(results) == doc["label"]}

    def aggregation(self):
        """
@@ -145,9 +160,7 @@ class WinogradSchemaChallenge273(Task):
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        return {"acc": mean}

    def higher_is_better(self):
        """
@@ -155,6 +168,4 @@ class WinogradSchemaChallenge273(Task):
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {"acc": True}
@@ -5,8 +5,11 @@ import collections
import functools
import inspect
import sys
from typing import List, Union

import torch
from omegaconf import OmegaConf


class ExitCodeError(Exception):
@@ -28,12 +31,10 @@ def simple_parse_args_string(args_string):
    if not args_string:
        return {}
    arg_list = args_string.split(",")
    args_dict = OmegaConf.to_object(OmegaConf.from_dotlist(arg_list))
    return args_dict
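A quick sketch of the OmegaConf behavior the new one-liner relies on (assumes omegaconf 2.1+ for `OmegaConf.to_object`): unlike the old manual `split("=")` loop, dotlist parsing infers value types, so numeric values come back as ints rather than strings.

```
from omegaconf import OmegaConf

arg_list = "pretrained=gpt2,batch_size=8".split(",")
args = OmegaConf.to_object(OmegaConf.from_dotlist(arg_list))
print(args)  # {'pretrained': 'gpt2', 'batch_size': 8}
```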
def join_iters(iters):
    for iter in iters:
        yield from iter
@@ -46,23 +47,26 @@ def chunks(iter, n):
        if len(arr) == n:
            yield arr
            arr = []

    if arr:
        yield arr


def group(arr, fn):
    res = collections.defaultdict(list)

    for ob in arr:
        res[fn(ob)].append(ob)

    return list(res.values())


def general_detokenize(string):
    string = string.replace(" n't", "n't")
    string = string.replace(" )", ")")
    string = string.replace("( ", "(")
    string = string.replace('" ', '"')
    string = string.replace(' "', '"')
    string = re.sub(r" (['.,])", r"\1", string)
    return string
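A worked example of what these substitutions undo (Penn-Treebank-style spacing):

```
print(general_detokenize("Do n't stop , he said ( quietly ) ."))
# -> "Don't stop, he said (quietly)."
```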
@@ -94,10 +98,7 @@ def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len
    # Special handling for first window: predict all tokens
    first_seq_len = min(max_seq_len, len(token_list))
    yield ([prefix_token] + token_list[: first_seq_len - 1], token_list[:first_seq_len])
    predicted += first_seq_len

    while predicted < len(token_list):
@@ -105,61 +106,84 @@ def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len
        window_end = predicted + window_pred_len

        yield (
            token_list[window_end - max_seq_len - 1 : window_end - 1],
            token_list[window_end - window_pred_len : window_end],
        )
        predicted += window_pred_len


def make_disjoint_window(pair):
    """Takes output from get_rolling_token_windows and makes the context not overlap with the continuation"""
    a, b = pair
    return a[: len(a) - (len(b) - 1)], b


def select_continuation_from_batch_left_padding(
    generations: Union[List[List[int]], torch.Tensor], max_context_size: int
):
    """Select the continuation from the batch, removing prompts of different lengths.

    Args:
        generations (Union[List[List[int]], torch.Tensor]):
            A tensor or list-of-lists of shape [batch_size, sequence length].
        max_context_size (int):
            The size of the biggest context; generations will proceed from that
            index.

    Example:
        PAD PAD Continue : The dog chased the cat [every day of the week]
        Riddle me this : The dog chased the cat [yesterday] PAD PAD PAD PAD

    Output:
        [every day of the week]
        [yesterday] PAD PAD PAD PAD
    """
    return generations[:, max_context_size:]
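A small sanity check of the windowing contract, assuming the loop bounds hidden in the collapsed hunk follow the signature above; only the first window's output is fully pinned down by the visible code.

```
# Rolling windows over a toy "tokenized" document.
windows = list(
    get_rolling_token_windows(
        token_list=list(range(10)), prefix_token=-1, max_seq_len=4, context_len=1
    )
)
# The first window predicts every one of its tokens:
#   ([-1, 0, 1, 2], [0, 1, 2, 3])
# Later windows keep some context tokens and predict only the new ones;
# make_disjoint_window then trims that overlap out of the context half.
disjoint = [make_disjoint_window(pair) for pair in windows]
```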
class Reorderer:
    def __init__(self, arr, fn):
        self.size = len(arr)
        arr = list(enumerate(arr))
        arr = group(arr, lambda x: fn(x[1]))
        arr = [([y[0] for y in x], x[0][1]) for x in arr]
        arr.sort(key=lambda x: fn(x[1]))

        self.arr = arr

    def get_reordered(self):
        return [x[1] for x in self.arr]

    def get_original(self, newarr):
        res = [None] * self.size
        cov = [False] * self.size

        for (inds, _), v in zip(self.arr, newarr):
            for ind in inds:
                res[ind] = v
                cov[ind] = True

        assert all(cov)

        return res
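A usage sketch: `Reorderer` is typically used to sort requests (e.g. by length) for efficient batching and then restore the original order afterwards.

```
reorderer = Reorderer(["bb", "a", "ccc"], fn=len)
batched = reorderer.get_reordered()     # ["a", "bb", "ccc"]
results = [s.upper() for s in batched]  # stand-in for a batched LM call
print(reorderer.get_original(results))  # ["BB", "A", "CCC"]
```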
def positional_deprecated(fn):
    """
    A decorator to nudge users into passing only keyword args (`kwargs`) to the
    wrapped function, `fn`.
    """

    @functools.wraps(fn)
    def _wrapper(*args, **kwargs):
        if len(args) != (1 if inspect.ismethod(fn) else 0):
            print(
                f"WARNING: using {fn.__name__} with positional arguments is "
                "deprecated and will be disallowed in a future version of "
                "lm-evaluation-harness!"
            )
        return fn(*args, **kwargs)

    return _wrapper
@positional_deprecated
def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
    """
@@ -169,22 +193,34 @@ def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
    cur_path = start_path.resolve()
    max_layers = 3
    for _ in range(max_layers):
        if (cur_path / "tests" / "test_version_stable.py").exists():
            return cur_path
        else:
            cur_path = cur_path.parent.resolve()
    raise FileNotFoundError(
        f"Unable to find package root within {max_layers} upwards of {start_path}"
    )
@positional_deprecated
def run_task_tests(task_list: List[str]):
    """
    Find the package root and run the tests for the given tasks
    """
    import pytest

    package_root = find_test_root(start_path=pathlib.Path(__file__))
    task_string = " or ".join(task_list)
    args = [
        f"{package_root}/tests/test_version_stable.py",
        f"--rootdir={package_root}",
        "-k",
        f"{task_string}",
    ]
    sys.path.append(str(package_root))
    pytest_return_val = pytest.main(args)
    if pytest_return_val:
        raise ValueError(
            f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}"
        )
import argparse
import json
import logging
import fnmatch

from lm_eval import tasks, evaluator

logging.getLogger("openai").setLevel(logging.WARNING)


class MultiChoice:
    def __init__(self, choices):
        self.choices = choices

    # Simple wildcard support (linux filename patterns)
    def __contains__(self, values):
        for value in values.split(","):
            if len(fnmatch.filter(self.choices, value)) == 0:
                return False
        return True

    def __iter__(self):
        for choice in self.choices:
            yield choice
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True)
    parser.add_argument("--model_args", default="")
    parser.add_argument("--tasks", default=None, choices=MultiChoice(tasks.ALL_TASKS))
    parser.add_argument("--provide_description", action="store_true")
    parser.add_argument("--num_fewshot", type=int, default=0)
    parser.add_argument("--batch_size", type=str, default=None)
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--output_path", default=None)
    parser.add_argument("--limit", type=int, default=None)
    parser.add_argument("--no_cache", action="store_true")
    parser.add_argument("--decontamination_ngrams_path", default=None)
    parser.add_argument("--description_dict_path", default=None)
    parser.add_argument("--check_integrity", action="store_true")

    return parser.parse_args()
# Returns a list containing all values of the source_list that
# match at least one of the patterns
def pattern_match(patterns, source_list):
    task_names = set()
    for pattern in patterns:
        for matching in fnmatch.filter(source_list, pattern):
            task_names.add(matching)
    return sorted(list(task_names))
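Together, `MultiChoice` validates comma-separated patterns at argparse time and `pattern_match` expands them; a sketch with illustrative task names:

```
all_tasks = ["hellaswag", "wikitext", "winogrande", "wsc273"]
print("w*,hellaswag" in MultiChoice(all_tasks))   # True
print(pattern_match("w*".split(","), all_tasks))  # ['wikitext', 'winogrande', 'wsc273']
```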
def main():
    args = parse_args()

    assert not args.provide_description  # not implemented

    if args.limit:
        print(
            "WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
        )

    if args.tasks is None:
        task_names = tasks.ALL_TASKS
    else:
        task_names = pattern_match(args.tasks.split(","), tasks.ALL_TASKS)

    print(f"Selected Tasks: {task_names}")

    description_dict = {}
    if args.description_dict_path:
        with open(args.description_dict_path, "r") as f:
            description_dict = json.load(f)

    results = evaluator.simple_evaluate(
@@ -51,11 +86,11 @@ def main():
        no_cache=args.no_cache,
        limit=args.limit,
        description_dict=description_dict,
        decontamination_ngrams_path=args.decontamination_ngrams_path,
        check_integrity=args.check_integrity,
    )

    dumped = json.dumps(results, indent=2)
    print(dumped)

    if args.output_path:
{
    "Data": "Pile statistics",
    "Document Count": 210607728,
    "Total Pile Characters": 421215456,
    "File Start Offsets": [
        0,
        7021438,
        14042822,
        21066113,
        28086515,
        35106072,
        42123306,
        49145091,
        56165817,
        63185587,
        70211208,
        77234322,
        84249267,
        91267634,
        98285983,
        105305110,
        112322489,
        119342491,
        126367373,
        133389153,
        140412039,
        147432373,
        154452516,
        161470190,
        168492733,
        175512521,
        182526939,
        189547478,
        196565318,
        203583306
    ]
}
janitor.py contains a script to remove benchmark data contamination from training data sets.
It uses the approach described in the [GPT-3 paper](https://arxiv.org/abs/2005.14165).

## Algorithm

1) Collects all contamination text files that are to be removed from training data
2) Filters training data by finding `N`gram matches between the training data
   and any contamination
   1) `N`grams ignore case and punctuation and are split on whitespace.
   2) Matching `N`gram substrings are removed, as is a `window_to_remove` character window around
      the match, splitting the training data into chunks
3) Any chunks less than `minimum_slice_length` are removed
4) Training data sets split into more than `too_dirty_cutoff` chunks are considered
   completely contaminated and removed

OpenAI used:
```
ngram_n = 13
@@ -20,7 +20,7 @@ minimum_slice_length = 200
too_dirty_cutoff = 10
```
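To make steps 2-4 concrete, here is a toy sketch of the chunk-splitting rule under those parameters; this is not the janitor.py implementation, and `match_spans` stands in for the character spans of 13-gram matches:

```
def split_on_matches(text, match_spans, window_to_remove=200,
                     minimum_slice_length=200, too_dirty_cutoff=10):
    # Cut out each match plus a window around it, keeping the clean slices.
    chunks, cursor = [], 0
    for start, end in sorted(match_spans):
        left = max(cursor, start - window_to_remove)
        if left > cursor:
            chunks.append(text[cursor:left])
        cursor = max(cursor, end + window_to_remove)
    chunks.append(text[cursor:])
    # Drop short slices; drop the document entirely if it got too fragmented.
    chunks = [c for c in chunks if len(c) >= minimum_slice_length]
    return [] if len(chunks) > too_dirty_cutoff else chunks
```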
## Compiling

Janitor can be used as a pure python program, but it is much faster if the ngram
code is run in C++. To compile the C++ code, run
@@ -31,4 +31,3 @@ c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor
```
If your compiler isn't linked to python, you may need to add `-undefined dynamic_lookup` to the above
""" """
Outputs all 13-grams found in The Pile. Outputs all 13-grams found in The Pile.
Loops through all documents and uses the logic found in janitor.py to extract 13-grams. Loops through all documents and uses the logic found in janitor.py to extract 13-grams.
We bucket each 13-gram by hash into separate file buckets to allow easy parallel processing in the We bucket each 13-gram by hash into separate file buckets to allow easy parallel processing in the
next stage. We also include the current pile document_id with each ngram instance to allow the next stage. We also include the current pile document_id with each ngram instance to allow the
filtering to exclude 13-grams that match more then 10 unique documents (done further down the pipeline). filtering to exclude 13-grams that match more then 10 unique documents (done further down the pipeline).
We didn't use lm_dataformat to output as it increases time 4x (slow jsonify) and makes We didn't use lm_dataformat to output as it increases time 4x (slow jsonify) and makes
...@@ -21,8 +21,10 @@ Arguments ...@@ -21,8 +21,10 @@ Arguments
""" """
import argparse
import json
import pickle
import os
import sys
from pathlib import Path
import glob
import signal
@@ -30,36 +32,32 @@ from signal import SIGINT
from tqdm import tqdm

from lm_eval.decontamination.janitor import Janitor, word_ngrams
from lm_eval.decontamination.archiver import TextArchive, Reader

import logging
from tqdm_multiprocess.logger import setup_logger_tqdm

logger = logging.getLogger(__name__)

terminate = False


def handler(signal_received, frame):
    global terminate
    terminate = True


def get_pile(directory):
    reader = Reader()
    for file in glob.glob(os.path.join(directory, f"*.jsonl.zst*")):
        for document in reader.read(file):
            yield document


def close_buckets(self):
    for bucket in self.buckets:
        bucket.commit()


def do_ngrams_in_buckets(n_value, working_directory, bucket_count):
    output_directory = os.path.join(working_directory, "output")
    os.makedirs(output_directory, exist_ok=True)

    logger.info(f"Generating {n_value}-grams and bucketing.")
@@ -71,59 +69,68 @@ def do_ngrams_in_buckets(n_value, working_directory, bucket_count):
        return

    # Checkpoint
    checkpoint_file = os.path.join(working_directory, f"pile_offset.ckpt")
    if os.path.exists(checkpoint_file):
        checkpoint_offset = pickle.load(open(checkpoint_file, "rb"))
        iterate = True
    else:
        checkpoint_offset = 0
        iterate = False

    logger.info(f"Starting at pile document index {checkpoint_offset}")

    buckets = Buckets(output_directory, bucket_count)

    janitor = Janitor()

    batch_size = 1000
    batch_counter = 0

    with tqdm(total=checkpoint_offset, dynamic_ncols=True, unit="docs") as progress:
        for offset, document in yield_pile(start_offsets, checkpoint_offset):
            if iterate:
                logger.info(f"Iterating to offset {checkpoint_offset} from {offset}")
                progress.update(offset)
                iterate = False

            if offset < checkpoint_offset:
                progress.update()

                if terminate:
                    return
                continue

            # Save checkpoint every "batch_size", only allow terminate after checkpoint
            if batch_counter == batch_size:
                progress.update(batch_size)
                batch_counter = 0
                buckets.save_checkpoint()
                pickle.dump(offset, open(checkpoint_file, "wb"))
                if terminate:
                    buckets.close_buckets()
                    return

            ngrams = word_ngrams(janitor.normalize_string(document), n_value)
            for ngram in ngrams:
                buckets.add_data(ngram, f"{ngram} {offset}")

            batch_counter += 1

    buckets.close_buckets()
    Path(done_file).touch()


parser = argparse.ArgumentParser(description="Generate 13 grams from Pile.")
parser.add_argument("-dir", "--working_directory", default="")
parser.add_argument("-sdir", "--save_directory", default="")
parser.add_argument("-n", "--n_value", type=int, default=13)
parser.add_argument("-buckets", "--bucket_count", type=int, default=500)

if __name__ == "__main__":
    version = 1.00
    print(f"Running version {version}")

    if "PYTHONHASHSEED" not in os.environ or os.environ["PYTHONHASHSEED"] != "0":
        print("Please run 'export PYTHONHASHSEED=0' before running generate.")
        sys.exit()

    # Handle sigint (ctrl-c) cleanly
    previous_signal_int = signal.signal(SIGINT, handler)
@@ -132,4 +139,4 @@ if __name__ == '__main__':
    setup_logger_tqdm(logfile_path)

    args = parser.parse_args()
    do_ngrams_in_buckets(args.n_value, args.working_directory, args.bucket_count)
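The bucketing relies on Python's built-in string hash, which is why the script refuses to run without `PYTHONHASHSEED=0`: the same ngram must land in the same bucket across runs and worker processes. A minimal sketch of the idea:

```
# Stable only when PYTHONHASHSEED=0 is exported before starting Python.
bucket_count = 500

def bucket_for(ngram: str) -> int:
    return hash(ngram) % bucket_count
```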
from lm_eval.decontamination.archiver import Reader
import os
import json
from functools import reduce
import glob
import tqdm
from tqdm_multiprocess import TqdmMultiProcessPool


def get_file_stats(file_path, tqdm_func, global_tqdm):
    reader = Reader()
    total_documents = 0
    total_size = 0
    update_frequency = 10000
    current_file_position = 0

    with tqdm_func(
        total=os.path.getsize(file_path), dynamic_ncols=True, unit="byte", unit_scale=1
    ) as progress:
        for document in reader.read(file_path, get_meta=True):
            total_size += len(document)
            total_documents += 1

            if total_documents % update_frequency == 0:
                new_file_pos = reader.fh.tell()
                bytes_read = new_file_pos - current_file_position
                current_file_position = new_file_pos
                progress.update(bytes_read)
                global_tqdm.update(bytes_read)

    return (total_documents, total_size)


def get_files_zst():
    directory = "pile"
    files = list(sorted(glob.glob(os.path.join(directory, "*.jsonl.zst*"))))
    print(files)
    return files


def get_files():
    """jsonl files in directory"""
    directory = "pile"
    files = list(sorted(glob.glob(os.path.join(directory, "*.jsonl"))))
    print(files)
    return files


def get_stats():
    files = get_files()
    total_size_bytes = sum(map(lambda x: os.path.getsize(x), files))

    pool = TqdmMultiProcessPool(4)
    global_tqdm = tqdm.tqdm(
        total=total_size_bytes, dynamic_ncols=True, unit="byte", unit_scale=1
    )

    # Gather per-file document counts and sizes with the pool
    tasks = [(get_file_stats, (file,)) for file in files]

    def on_done(_):
        return None

    def on_error(_):
        return None

    results = pool.map(global_tqdm, tasks, on_error, on_done)

    total_documents, total_size = reduce(
        lambda x, y: (x[0] + y[0], x[1] + y[1]), results
    )

    start_offsets = []
    current_offset = 0
    for file_document_count, _ in results:
        start_offsets.append(current_offset)
        current_offset += file_document_count

    return (total_documents, total_size, start_offsets)


if __name__ == "__main__":
    version = 1.01
    print(f"Running version {version}")

    stats_file_path = "pile_statistics.json"
    if os.path.exists(stats_file_path):
        stats = json.load(open(stats_file_path, "r"))
    else:
        document_count, total_document_size_chars, start_offsets = get_stats()
        stats = {
            "Data": "Pile statistics",
            "Document Count": document_count,
            "Total Pile Characters": total_document_size_chars,
            "File Start Offsets": start_offsets,
        }
        json.dump(stats, open(stats_file_path, "w"), indent=4)

    print(f"document_count: {stats['Document Count']}")
    print(f"total_chars: {stats['Total Pile Characters']}")
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <queue>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

bool is_whitespace(char ch) noexcept {
  // " \t\n\r\x0b\x0c" (python string.whitespace)
  return ch == 32 or (9 <= ch and ch <= 13);
  // return ch <= 32; // arguably too general, but slightly faster
}

bool is_punctuation(char c) noexcept {
  // '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~' ascii values: 33-47, 58-64,
  // 91-96, 123-126
  return (33 <= c and c <= 47) or (58 <= c and c <= 64) or
         (91 <= c and c <= 96) or (123 <= c and c <= 126);
}

// Takes a string and makes ngrams of length N, splitting grams on whitespace
// and ignoring ignored characters. Returns a LARGE array of ngrams
std::vector<std::string> clean_ngram(std::string const &input,
                                     std::string const &ignore,
                                     size_t ngram_n) noexcept {

  size_t num_grams = 0;
  std::vector<std::string> ngram_list;
  std::vector<uint8_t> gram_lengths;
  std::string current_ngram;

  // Max gram length is set to 10 below.
  current_ngram.reserve(11 * ngram_n);
  gram_lengths.reserve(ngram_n);

  bool started_gram = false;
  gram_lengths.push_back(0);

  // for (size_t i=0; i<input.length(); i++) {
  // this is slightly faster, and we don't need the index in this one
  for (auto iter = input.begin(); iter != input.end(); iter++) {

    // If whitespace, end the current ngram and start the next
    // alternatively, (perhaps marginally) faster: if (is_whitespace(ch)) { ...
    // }
    if (is_whitespace(*iter) || gram_lengths.back() > 10) {

      // Skip all whitespace
      while (++iter != input.end() && is_whitespace(*iter))
        ;
      iter--;

      if (started_gram) {
        num_grams += 1;

        // Building 1grams is a special case
        if (ngram_n == 1) {
          ngram_list.push_back(current_ngram);
          current_ngram = current_ngram.substr(gram_lengths.front());
          gram_lengths.back() = 0;

          // If there are enough grams to form an ngram, save
        } else if (num_grams >= ngram_n) {
          // Save the current ngram
          ngram_list.push_back(current_ngram);

          // Start the next ngram by dropping the first gram and its space from
          // the ngram
          current_ngram = current_ngram.substr(gram_lengths.front() + 1);
          current_ngram += ' ';

          // Drop the length of the first gram and prepare to record the length
          // of the new gram
          gram_lengths.erase(gram_lengths.begin());
          gram_lengths.push_back(0);

          // Otherwise, continue building
        } else {
          current_ngram += ' ';
          gram_lengths.push_back(0);
        }

        started_gram = false;
      }

      // Skip ignored characters
      // alternatively, (perhaps marginally) faster: if (is_punctuation(ch))
      // continue;
    } else if (ignore.find(*iter) != std::string::npos) {
      continue;
    }

    // If it is a non-ignored character, add it to the ngram and update the last
    // gram's length
    else {
      current_ngram += tolower(*iter);
      gram_lengths.back() += 1;
      started_gram = true;
    }
  }

  return ngram_list;
}

// Takes a string and makes ngrams of length N, splitting grams on whitespace
// and ignoring ignored characters. Returns a LARGE array of tuples of (ngram,
// start_idx, end_idx)
std::vector<std::tuple<std::string, size_t, size_t>>
clean_ngram_with_indices(std::string const &input, std::string const &ignore,
                         size_t ngram_n) noexcept {

  size_t num_grams = 0;
  std::vector<std::tuple<std::string, size_t, size_t>> ngram_list;
  std::vector<uint8_t> gram_lengths;
  std::vector<size_t> gram_start_indices;
  std::string current_ngram;

  // Max gram length is set to 10 below.
  current_ngram.reserve(11 * ngram_n);

  bool started_gram = false;
  gram_lengths.push_back(0);
  gram_start_indices.push_back(0);

  for (size_t i = 0; i < input.length(); i++) {
    char ch = input[i];

    // If whitespace, end the current ngram and start the next
    if (is_whitespace(ch) || gram_lengths.back() > 10) {

      // Skip all whitespace
      while (++i < input.length() && is_whitespace(input[i]))
        ;
      i--;

      if (started_gram) {
        num_grams += 1;

        // Building 1grams is a special case
        if (ngram_n == 1) {
          ngram_list.push_back(
              std::make_tuple(current_ngram, gram_start_indices.front(), i));
          current_ngram = current_ngram.substr(gram_lengths.front());
          gram_lengths.back() = 0;
          gram_start_indices.back() = i + 1;

          // If there are enough grams to form an ngram, save
        } else if (num_grams >= ngram_n) {
          // Save the current ngram
          ngram_list.push_back(
              std::make_tuple(current_ngram, gram_start_indices.front(), i));

          // Start the next ngram by dropping the first gram and its space from
          // the ngram
          current_ngram = current_ngram.substr(gram_lengths.front() + 1);
          current_ngram += ' ';

          // Drop the length of the first gram and prepare to record the length
          // of the new gram
          gram_lengths.erase(gram_lengths.begin());
          gram_lengths.push_back(0);
          gram_start_indices.erase(gram_start_indices.begin());
          gram_start_indices.push_back(i + 1);

          // Otherwise, continue building
        } else {
          current_ngram += ' ';
          gram_lengths.push_back(0);
          gram_start_indices.push_back(i + 1);
        }

        started_gram = false;
      }

      // Skip ignored characters
    } else if (ignore.find(ch) != std::string::npos) {
      continue;

      // If it is a non-ignored character, add it to the ngram and update the
      // last gram's length
    } else {
      current_ngram += tolower(ch);
      gram_lengths.back() += 1;
      started_gram = true;
    }
  }

  return ngram_list;
}

PYBIND11_MODULE(janitor_util, m) {
  m.doc() = "pybind11 example plugin"; // optional module docstring
  // m.def("add", &add, "A function which adds two numbers"); // example
  // function
  m.def("clean_ngram", &clean_ngram,
        "Create ngrams of words, ignoring some characters");
  m.def("clean_ngram_with_indices", &clean_ngram_with_indices,
        "Create ngrams of words with indices, ignoring some characters");
}

// Example compile
// c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes)
// janitor_util.cpp -o janitor_util$(python3-config --extension-suffix)
// If python and gcc aren't linked, append to the above: -undefined
// dynamic_lookup
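Once compiled, the module is callable from Python roughly as follows (a sketch; the expected output assumes the loop's behavior of emitting an ngram only at the whitespace that ends it, hence the trailing space in the input):

```
import string
import janitor_util  # the extension built with the command above

# Lowercased bigrams with punctuation stripped.
print(janitor_util.clean_ngram("Hello, World! Foo bar ", string.punctuation, 2))
# -> ['hello world', 'world foo', 'foo bar']
```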
@@ -27,25 +27,33 @@ from scripts.clean_training_data.archiver import TextReader, TextArchive
import logging
from tqdm_multiprocess.logger import setup_logger_tqdm

logger = logging.getLogger(__name__)


# Multiprocessed
def process_bucket(
    bucket_file_path, processed_directory, move_dir, tqdm_func, global_tqdm
):
    bucket_id = re.sub("\D", "", os.path.basename(bucket_file_path))  # noqa: W605
    done_file = os.path.join(
        processed_directory, f"ngram_bucket_processing_{bucket_id}.done"
    )
    if os.path.exists(done_file):
        logger.info(f"bucket {bucket_id} already processed, skipping")
        return

    # For managing tqdm
    file_size = os.path.getsize(bucket_file_path)
    bucket_progress = tqdm_func(
        total=file_size, dynamic_ncols=True, unit="byte", unit_scale=1
    )
    current_file_position = 0
    update_frequency = 100 * 1000000  # 100mb
    update_counter = 0

    # Iterate through and output ngrams which occur in more than 10 documents
    bucket = TextReader(bucket_file_path)

    output_file_path = bucket_file_path + ".processed"
@@ -56,10 +64,12 @@ def process_bucket(bucket_file_path, processed_directory, move_dir, tqdm_func, g
    for line in bucket.read():
        [ngram, document_id] = line.rsplit(" ", 1)

        # Write ngram if more than 10 unique document occurrences
        if ngram != current_ngram:
            if len(current_ngram_document_ids) > 10:
                output_archive.add_data(
                    f"{current_ngram} {len(current_ngram_document_ids)}"
                )
            current_ngram = ngram
            current_ngram_document_ids = set()
@@ -84,28 +94,38 @@ def process_bucket(bucket_file_path, processed_directory, move_dir, tqdm_func, g
    global_tqdm.update()


def process_sorted_buckets(working_directory, move_dir, process_count):
    bucket_file_paths = glob.glob(os.path.join(working_directory, f"*.bkt.txt.sorted"))
    processed_directory = os.path.join(working_directory, "processed")
    os.makedirs(processed_directory, exist_ok=True)

    pool = TqdmMultiProcessPool(process_count)
    tasks = [
        (process_bucket, (bucket_file, processed_directory, move_dir))
        for bucket_file in bucket_file_paths
    ]

    global_tqdm = tqdm(total=len(bucket_file_paths), dynamic_ncols=True, unit="bucket")

    def on_done(_):
        return None

    def on_error(_):
        return None

    _ = pool.map(global_tqdm, tasks, on_error, on_done)


parser = argparse.ArgumentParser(description="Process 13 grams from sorted buckets.")
parser.add_argument("-dir", "--working_directory", default="")
parser.add_argument("-move", "--move_dir", default="")
parser.add_argument("-procs", "--process_count", type=int, default=4)

if __name__ == "__main__":
    logfile_path = "process13grams.log"
    setup_logger_tqdm(logfile_path)

    args = parser.parse_args()
    process_sorted_buckets(args.working_directory, args.move_dir, args.process_count)
\ No newline at end of file
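
Because each bucket file is sorted by n-gram, every row for a given n-gram is adjacent, so a single linear pass with a running set of document IDs suffices to count unique occurrences. Below is a minimal, self-contained sketch of that counting pass under two assumptions: plain file I/O stands in for the harness's `TextReader` and output archive, and each line holds one `ngram doc_id` pair, as the `rsplit(" ", 1)` above implies.

```python
# A sketch of the counting pass, not the repository's code.
# Input: a file sorted by ngram, one "ngram doc_id" pair per line.
def emit_frequent_ngrams(sorted_path, output_path, min_docs=10):
    current_ngram = None
    current_ids = set()
    with open(sorted_path) as src, open(output_path, "w") as out:
        for line in src:
            ngram, doc_id = line.rstrip("\n").rsplit(" ", 1)
            if ngram != current_ngram:
                # All rows for the previous ngram have been seen; flush it.
                if current_ngram is not None and len(current_ids) > min_docs:
                    out.write(f"{current_ngram} {len(current_ids)}\n")
                current_ngram = ngram
                current_ids = set()
            current_ids.add(doc_id)
        # Flush the final group.
        if current_ngram is not None and len(current_ids) > min_docs:
            out.write(f"{current_ngram} {len(current_ids)}\n")
```
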
""" """
Iteratively runs gnu sort on each bucket, gnu handles the multiprocessing. Iteratively runs gnu sort on each bucket, uses up to 8 cores.
Arguments Arguments
--------- ---------
...@@ -11,25 +11,27 @@ Arguments ...@@ -11,25 +11,27 @@ Arguments
import glob import glob
import argparse import argparse
import os import os
from pathlib import Path
import signal import signal
from signal import SIGINT from signal import SIGINT
import re
import subprocess import subprocess
from tqdm import tqdm from tqdm import tqdm
import logging import logging
from tqdm_multiprocess.logger import setup_logger_tqdm from tqdm_multiprocess.logger import setup_logger_tqdm
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
terminate = False terminate = False
def handler(signal_received, frame): def handler(signal_received, frame):
global terminate global terminate
terminate = True terminate = True
def sort_13_gram_buckets(working_directory): def sort_13_gram_buckets(working_directory):
bucket_file_paths = glob.glob(os.path.join(working_directory, f"*.bkt.txt")) bucket_file_paths = glob.glob(os.path.join(working_directory, f"*.bkt.txt"))
for bucket_file_path in tqdm(bucket_file_paths, dynamic_ncols=True): for bucket_file_path in tqdm(bucket_file_paths, dynamic_ncols=True):
bucket_id = re.sub("\D", "", os.path.basename(bucket_file_path)) bucket_id = re.sub("\D", "", os.path.basename(bucket_file_path))
...@@ -41,19 +43,22 @@ def sort_13_gram_buckets(working_directory): ...@@ -41,19 +43,22 @@ def sort_13_gram_buckets(working_directory):
sorted_file_path = bucket_file_path + ".sorted" sorted_file_path = bucket_file_path + ".sorted"
command = f"sort {bucket_file_path} > {sorted_file_path}" command = f"sort {bucket_file_path} > {sorted_file_path}"
logger.info(command) logger.info(command)
subprocess.call(command, shell=True) subprocess.call(command, shell=True)
if terminate: if terminate:
return return
Path(done_file).touch()
os.remove(bucket_file_path) os.remove(bucket_file_path)
parser = argparse.ArgumentParser(description='sort 13gram buckets')
parser = argparse.ArgumentParser(description="sort 13gram buckets")
parser.add_argument("-dir", "--working_directory", default="") parser.add_argument("-dir", "--working_directory", default="")
if __name__ == '__main__': if __name__ == "__main__":
version = 1.00
print(f"Running version {version}")
# Handle sigint (ctrl-c) cleanly # Handle sigint (ctrl-c) cleanly
previous_signal_int = signal.signal(SIGINT, handler) previous_signal_int = signal.signal(SIGINT, handler)
...@@ -62,4 +67,4 @@ if __name__ == '__main__': ...@@ -62,4 +67,4 @@ if __name__ == '__main__':
setup_logger_tqdm(logfile_path) setup_logger_tqdm(logfile_path)
args = parser.parse_args() args = parser.parse_args()
sort_13_gram_buckets(args.working_directory) sort_13_gram_buckets(args.working_directory)
\ No newline at end of file
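
The script shells out with `shell=True` and output redirection. Below is a sketch of an equivalent call that avoids the shell by passing an argument list and using sort's `-o` flag; `--parallel` and `-S` are GNU coreutils options and may be absent from BSD/macOS `sort`.

```python
# A sketch, assuming GNU coreutils sort is on PATH.
import subprocess


def sort_bucket(bucket_file_path, cores=8, buffer_size="2G"):
    sorted_file_path = bucket_file_path + ".sorted"
    subprocess.check_call(
        ["sort", f"--parallel={cores}", "-S", buffer_size,
         "-o", sorted_file_path, bucket_file_path]
    )
    return sorted_file_path
```
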
...@@ -7,7 +7,7 @@ from lm_eval.base import LM
class DryrunLM(LM):
    def __init__(self):
        self.tokencost = 0
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")
        self.tokenizer.pad_token = "<|endoftext|>"

    @classmethod
...@@ -16,28 +16,28 @@ class DryrunLM(LM):
    def loglikelihood(self, requests):
        res = []

        for ctx, cont in requests:
            res.append((-random.random(), False))
            self.tokencost += len(self.tokenizer.tokenize(ctx + cont))

        return res

    def greedy_until(self, requests):
        res = []

        for ctx, _ in requests:
            res.append("lol")

            # assume worst case - generates up to 256 tokens
            self.tokencost += len(self.tokenizer.tokenize(ctx)) + 256

        return res

    def loglikelihood_rolling(self, requests):
        res = []

        for (s,) in requests:
            # assume worst case: extra full context
            self.tokencost += len(self.tokenizer.tokenize(s)) + 2048
...@@ -46,7 +46,7 @@ class DryrunLM(LM):
def main():
    lm = DryrunLM()

    task_list = "arc_challenge,arc_easy,boolq,cola,copa,headqa,hellaswag,lambada,logiqa,mathqa,mc_taco,mrpc,multirc,openbookqa,piqa,prost,pubmedqa,qnli,qqp,race,record,rte,sciq,sst,triviaqa,webqs,wic,wikitext,winogrande,wnli,wsc"
    values = []
    for taskname in task_list.split(","):
...@@ -57,11 +57,20 @@ def main():
            num_fewshot=0,
            limit=None,
            bootstrap_iters=10,
            description_dict=None,
        )

        print(taskname, lm.tokencost)
        values.append(
            [
                taskname,
                lm.tokencost,
                lm.tokencost / 1000 * 0.0008,
                lm.tokencost / 1000 * 0.0012,
                lm.tokencost / 1000 * 0.006,
                lm.tokencost / 1000 * 0.06,
            ]
        )
    from pytablewriter import MarkdownTableWriter

    writer = MarkdownTableWriter()
...@@ -69,10 +78,21 @@ def main():
    values.sort(key=lambda x: -x[1])
    totcost = sum([x[1] for x in values])
    values.append(
        [
            "**Total**",
            totcost,
            totcost / 1000 * 0.0008,
            totcost / 1000 * 0.0012,
            totcost / 1000 * 0.006,
            totcost / 1000 * 0.06,
        ]
    )

    writer.value_matrix = values

    print(writer.dumps())


if __name__ == "__main__":
    main()
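
Each cost column above is the token count scaled by a price per 1,000 tokens (0.0008, 0.0012, 0.006, and 0.06 USD). A small sketch of that arithmetic factored out; the model names are an assumption (the figures match OpenAI's old ada/babbage/curie/davinci rates), not something the script states.

```python
# Hypothetical helper; the rates are the ones hard-coded in the script above.
RATES_PER_1K_TOKENS = {"ada": 0.0008, "babbage": 0.0012, "curie": 0.006, "davinci": 0.06}


def estimate_cost_usd(token_count):
    # Scale the token count by each per-1k-token rate.
    return {name: token_count / 1000 * rate for name, rate in RATES_PER_1K_TOKENS.items()}


print(estimate_cost_usd(1_500_000))
# {'ada': 1.2, 'babbage': 1.8, 'curie': 9.0, 'davinci': 90.0}
```
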
...@@ -3,16 +3,21 @@ from itertools import islice
ct = 3

for (
    tname,
    Task,
) in tasks.TASK_REGISTRY.items():  # [('record', tasks.superglue.ReCoRD)]:#
    task = Task()
    print("#", tname)
    docs = islice(
        task.validation_docs() if task.has_validation_docs() else task.test_docs(), ct
    )
    print()
    for i in range(ct):
        print()
        doc = next(docs)
        print("**Context**:", "\n```\n" + task.doc_to_text(doc) + "\n```\n")
        print()
        print("**Target**:", "\n```\n" + task.doc_to_target(doc) + "\n```\n")
        print()
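
One caveat with the `next(docs)` pattern above: when a task has fewer than `ct` documents, `next` raises `StopIteration` mid-loop. A tiny sketch of the usual guard, which takes at most `n` items from any iterator and stops gracefully:

```python
from itertools import islice


def take(n, iterable):
    # Return at most n items; a short input yields a short list, no exception.
    return list(islice(iterable, n))


print(take(3, iter([1, 2])))  # [1, 2]
```
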
...@@ -10,7 +10,7 @@ random.seed(42)
data = [
    "A multilayer perceptron (MLP) is a class of feedforward artificial neural network (ANN)",
    "The term MLP is used ambiguously, sometimes loosely to any feedforward ANN, sometimes strictly to refer to networks composed of multiple layers of perceptrons (with threshold activation); see § Terminology",
    'Multilayer perceptrons are sometimes colloquially referred to as "vanilla" neural networks, especially when they have a single hidden layer.[1]',
    "An MLP consists of at least three layers of nodes: an input layer, a hidden layer and an output layer. Except for the input nodes, each node is a neuron that uses a nonlinear activation function.",
    "MLP utilizes a supervised learning technique called backpropagation for training.[2][3] Its multiple layers and non-linear activation distinguish MLP from a linear perceptron. It can distinguish data that is not linearly separable.[4]",
    "Recent work has demonstrated substantial gains on many NLP tasks and benchmarks by pre-training on a large corpus of text followed by fine-tuning on a specific task. While typically task-agnostic in architecture, this method still requires task-specific fine-tuning datasets of thousands or tens of thousands of examples. By contrast, humans can generally perform a new language task from only a few examples or from simple instructions - something which current NLP systems still largely struggle to do. Here we show that scaling up language models greatly improves task-agnostic, few-shot performance, sometimes even reaching competitiveness with prior state-of-the-art fine-tuning approaches. ",
...@@ -20,22 +20,28 @@ data = [
]

model = transformers.GPT2LMHeadModel.from_pretrained("gpt2")
tok = transformers.GPT2Tokenizer.from_pretrained("gpt2")

tgs = []

for dat in data:
    random.seed(dat)
    # print(model(tok.encode(dat, return_tensors="pt"))[0][0])
    toks = tok.encode(dat, return_tensors="pt")
    ind = random.randrange(len(toks[0]) - 1)
    logits = F.log_softmax(model(toks)[0], dim=-1)[:, :-1]  # [batch, seq, vocab]
    res = torch.gather(logits, 2, toks[:, 1:].unsqueeze(-1)).squeeze(-1)[0]

    tgs.append(float(res[ind:].sum()))
    print(
        r'("""'
        + tok.decode(toks[0, : ind + 1])
        + r'""", """'
        + tok.decode(toks[0, ind + 1 :])
        + r'"""), '
    )
print(tgs)
\ No newline at end of file
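
The `gather` line is the standard next-token log-probability trick: the logits at position `t` predict token `t + 1`, so dropping the last logit position and pairing it with the token IDs shifted left by one yields one log-probability per observed token. A self-contained sketch of just that step:

```python
# A sketch of the shift-and-gather step used above.
import torch
import torch.nn.functional as F


def per_token_logprobs(logits, token_ids):
    # logits: [batch, seq, vocab]; token_ids: [batch, seq]
    logprobs = F.log_softmax(logits, dim=-1)[:, :-1]  # predictions for tokens 1..seq-1
    targets = token_ids[:, 1:].unsqueeze(-1)  # the tokens actually observed
    return torch.gather(logprobs, 2, targets).squeeze(-1)  # [batch, seq - 1]
```

Summing a suffix of this vector, as `res[ind:].sum()` does above, gives the log-probability of the continuation conditioned on everything before it.
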
"""
Usage:
python make_table_tasks.py --output <markdown_filename>
"""
import argparse
import logging
from lm_eval import tasks
from pytablewriter import MarkdownTableWriter

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def check(tf):
    if tf:
        return "✓"
    else:
        return " "


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--output", type=str, default="task_table.md")
    args = parser.parse_args()

    writer = MarkdownTableWriter()
    writer.headers = ["Task Name", "Train", "Val", "Test", "Val/Test Docs", "Metrics"]
    values = []

    tasks = tasks.TASK_REGISTRY.items()
    tasks = sorted(tasks, key=lambda x: x[0])
    for tname, Task in tasks:
        task = Task()
        v = [
            tname,
            check(task.has_training_docs()),
            check(task.has_validation_docs()),
            check(task.has_test_docs()),
            len(
                list(
                    task.test_docs() if task.has_test_docs() else task.validation_docs()
                )
            ),
            ", ".join(task.aggregation().keys()),
        ]
        logger.info(v)
        values.append(v)
    writer.value_matrix = values
    table = writer.dumps()
    with open(args.output, "w") as f:
        f.write(table)
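
For reference, a minimal `pytablewriter` round trip showing the shape of the table this script writes; the example row is made up.

```python
from pytablewriter import MarkdownTableWriter

writer = MarkdownTableWriter()
writer.headers = ["Task Name", "Train", "Val", "Test", "Val/Test Docs", "Metrics"]
writer.value_matrix = [["some_task", " ", "✓", " ", 1000, "acc"]]  # illustrative row
print(writer.dumps())  # emits a GitHub-flavored markdown table
```
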
...@@ -11,14 +11,14 @@ EXAMPLE_DIVIDER = "!!@@##@@!! -- Example {i}\n"
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_base_path", required=True)
    parser.add_argument("--tasks", default="all_tasks")
    parser.add_argument("--provide_description", action="store_true")
    parser.add_argument("--sets", type=str, default="val")  # example: val,test
    parser.add_argument("--num_fewshot", type=int, default=1)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num_examples", type=int, default=1)
    parser.add_argument("--description_dict_path", default=None)
    return parser.parse_args()
...@@ -34,7 +34,7 @@ def main():
    description_dict = {}
    if args.description_dict_path:
        with open(args.description_dict_path, "r") as f:
            description_dict = json.load(f)

    os.makedirs(args.output_base_path, exist_ok=True)
...@@ -45,26 +45,34 @@ def main():
        iters = []

        for set in args.sets.split(","):
            if set == "train" and task.has_training_docs():
                docs = task.training_docs()
            if set == "val" and task.has_validation_docs():
                docs = task.validation_docs()
            if set == "test" and task.has_test_docs():
                docs = task.test_docs()
            iters.append(docs)

        docs = join_iters(iters)

        description = (
            description_dict[task_name]
            if description_dict and task_name in description_dict
            else ""
        )

        with open(os.path.join(args.output_base_path, task_name), "w") as f:
            for i, doc in (
                zip(range(args.num_examples), docs)
                if args.num_examples > 0
                else enumerate(docs)
            ):
                f.write(EXAMPLE_DIVIDER.format(i=i))
                ctx = task.fewshot_context(
                    doc=doc,
                    num_fewshot=args.num_fewshot,
                    rnd=rnd,
                    description=description,
                )
                f.write(ctx + "\n")
...
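
`join_iters` is defined in an elided hunk above; a plausible implementation simply chains the per-split iterators end to end:

```python
# A plausible sketch of join_iters, assuming it only needs to chain iterators.
def join_iters(iters):
    for it in iters:
        yield from it
```
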
...@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh:
setuptools.setup(
    name="lm_eval",
    version="0.3.0",
    author="Leo Gao",
    author_email="lg@eleuther.ai",
    description="A framework for evaluating autoregressive language models",
...@@ -14,37 +14,35 @@ setuptools.setup(
    url="https://github.com/EleutherAI/lm-evaluation-harness",
    packages=setuptools.find_packages(),
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ],
    python_requires=">=3.9",
    install_requires=[
        "datasets>=2.0.0",
        "jsonlines",
        "numexpr",
        "openai>=0.6.4",
        "omegaconf>=2.2",
        "peft>=0.2.0",
        "pybind11>=2.6.2",
        "pycountry",
        "pytablewriter",
        "rouge-score>=0.0.4",
        "sacrebleu==1.5.0",
        "scikit-learn>=0.24.1",
        "sqlitedict",
        "torch>=1.7",
        "tqdm-multiprocess",
        "transformers>=4.1",
        "zstandard",
        "accelerate>=0.17.1",
    ],
    extras_require={
        "dev": ["black", "flake8", "pre-commit", "pytest", "pytest-cov"],
        "multilingual": ["nagisa>=0.2.7", "jieba>=0.42.1"],
        "sentencepiece": ["sentencepiece>=0.1.98", "protobuf>=4.22.1"],
    },
)
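
With `extras_require`, the optional groups install via pip's extras syntax, e.g. `pip install -e ".[dev]"` or `pip install -e ".[multilingual]"`. A quick sketch for checking which extras an installed `lm_eval` distribution declares (assumes the package has been installed):

```python
from importlib.metadata import metadata

md = metadata("lm_eval")
print(md.get_all("Provides-Extra"))  # e.g. ['dev', 'multilingual', 'sentencepiece']
```
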
# TODO: Remove all TODO comments once the implementation is complete.
"""
TODO: Add the Paper Title on this line.
TODO: Add the paper's PDF URL (preferably from arXiv) on this line.

TODO: Write a Short Description of the task.

Homepage: TODO: Add the URL to the task's Homepage here.
"""
from lm_eval.base import MultipleChoiceTask


# TODO: Add the BibTeX citation for the task.
_CITATION = """
"""


# TODO: Replace `NewTask` with the name of your Task.
class NewTask(MultipleChoiceTask):
    VERSION = 0

    # TODO: Add the `DATASET_PATH` string. This will be the name of the `Task`
    # dataset as denoted in HuggingFace `datasets`.
    DATASET_PATH = ""

    # TODO: Add the `DATASET_NAME` string. This is the name of a subset within
    # `DATASET_PATH`. If there aren't specific subsets you need, leave this as `None`.
    DATASET_NAME = None

    def has_training_docs(self):
        # TODO: Fill in the return with `True` if the Task has training data; else `False`.
        return False

    def has_validation_docs(self):
        # TODO: Fill in the return with `True` if the Task has validation data; else `False`.
        return False

    def has_test_docs(self):
        # TODO: Fill in the return with `True` if the Task has test data; else `False`.
        return False

    def training_docs(self):
        if self.has_training_docs():
            # We cache training documents in `self._training_docs` for faster
            # few-shot processing. If the data is too large to fit in memory,
            # return the training data as a generator instead of a list.
            if self._training_docs is None:
                # TODO: Return the training document generator from `self.dataset`.
                # In most cases you can leave this as is unless the dataset split is
                # named differently than the default `"train"`.
                self._training_docs = list(
                    map(self._process_doc, self.dataset["train"])
                )
            return self._training_docs

    def validation_docs(self):
        if self.has_validation_docs():
            # TODO: Return the validation document generator from `self.dataset`.
            # In most cases you can leave this as is unless the dataset split is
            # named differently than the default `"validation"`.
            return map(self._process_doc, self.dataset["validation"])

    def test_docs(self):
        if self.has_test_docs():
            # TODO: Return the test document generator from `self.dataset`.
            # In most cases you can leave this as is unless the dataset split is
            # named differently than the default `"test"`.
            return map(self._process_doc, self.dataset["test"])

    def _process_doc(self, doc):
        # TODO: Process the documents into a dictionary with the following keys:
        return {
            "query": "",  # The query prompt.
            "choices": [],  # The list of choices.
            "gold": 0,  # The integer used to index into the correct element of `"choices"`.
        }

    def doc_to_text(self, doc):
        # TODO: Format the query prompt portion of the document example.
        return doc["query"]
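
To make the template concrete, here is a hypothetical `_process_doc` for a dataset whose rows look like `{"question": ..., "options": [...], "label": int}`; the input field names are invented, and only the returned keys are required by `MultipleChoiceTask`.

```python
from lm_eval.base import MultipleChoiceTask


class ToyMCTask(MultipleChoiceTask):  # hypothetical task with invented fields
    def _process_doc(self, doc):
        return {
            "query": "Question: " + doc["question"] + "\nAnswer:",  # prompt shown to the model
            "choices": doc["options"],  # candidate answer strings
            "gold": doc["label"],  # index of the correct choice
        }
```
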
# TODO: Remove all TODO comments once the implementation is complete.
"""
TODO: Add the Paper Title on this line.
TODO: Add the paper's PDF URL (preferably from arXiv) on this line.

TODO: Write a Short Description of the task.

Homepage: TODO: Add the URL to the task's Homepage here.
"""
from lm_eval.base import Task


# TODO: Add the BibTeX citation for the task.
_CITATION = """
"""


# TODO: Replace `NewTask` with the name of your Task.
class NewTask(Task):
    VERSION = 0

    # TODO: Add the `DATASET_PATH` string. This will be the name of the `Task`
    # dataset as denoted in HuggingFace `datasets`.
    DATASET_PATH = ""

    # TODO: Add the `DATASET_NAME` string. This is the name of a subset within
    # `DATASET_PATH`. If there aren't specific subsets you need, leave this as `None`.
    DATASET_NAME = None

    def has_training_docs(self):
        # TODO: Fill in the return with `True` if the Task has training data; else `False`.
        return False

    def has_validation_docs(self):
        # TODO: Fill in the return with `True` if the Task has validation data; else `False`.
        return False

    def has_test_docs(self):
        # TODO: Fill in the return with `True` if the Task has test data; else `False`.
        return False

    def training_docs(self):
        if self.has_training_docs():
            # We cache training documents in `self._training_docs` for faster
            # few-shot processing. If the data is too large to fit in memory,
            # return the training data as a generator instead of a list.
            if self._training_docs is None:
                # TODO: Return the training document generator from `self.dataset`.
                # If you need to process the data, `map` over the documents with
                # the custom processing function, `self._process_doc`. E.g.
                # `map(self._process_doc, self.dataset["train"])`
                # In most cases you can leave this as is unless the dataset split is
                # named differently than the default `"train"`.
                self._training_docs = list(self.dataset["train"])
            return self._training_docs

    def validation_docs(self):
        if self.has_validation_docs():
            # TODO: Return the validation document generator from `self.dataset`.
            # If you need to process the data, `map` over the documents with the
            # custom processing function, `self._process_doc`. E.g.
            # `map(self._process_doc, self.dataset["validation"])`
            # In most cases you can leave this as is unless the dataset split is
            # named differently than the default `"validation"`.
            return self.dataset["validation"]

    def test_docs(self):
        if self.has_test_docs():
            # TODO: Return the test document generator from `self.dataset`.
            # If you need to process the data, `map` over the documents with the
            # custom processing function, `self._process_doc`. E.g.
            # `map(self._process_doc, self.dataset["test"])`
            # In most cases you can leave this as is unless the dataset split is
            # named differently than the default `"test"`.
            return self.dataset["test"]

    def _process_doc(self, doc):
        # TODO: Process (detokenize, strip, replace etc.) each individual `doc`
        # with this function. You can map this across the docs in each available
        # dataset split. See the TODOs in `training_docs`, `validation_docs`, and
        # `test_docs` for snippets.
        # NOTE: DELETE THIS FUNCTION IF UNUSED.
        return doc

    def doc_to_text(self, doc):
        # TODO: Format the query prompt portion of the document example.
        return ""

    def doc_to_target(self, doc):
        # TODO: Fill in the `target` ("gold answer") variable.
        # The prepended `" "` is required to space out the `doc_to_text` and
        # `doc_to_target` strings.
        target = ""
        return " " + target

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or
            test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        # TODO: Construct your language model requests with the request factory, `rf`,
        # and return them as an iterable.
        return []

    def process_results(self, doc, results):
        """Takes a single document and the LM results, and evaluates them, returning
        a dict where keys are the names of submetrics and values are the values of
        the metric for that one document.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        # TODO: For each (sub)metric in the task evaluation, add a key-value pair
        # with the metric name as key and the corresponding metric result as value
        # for the current `doc`.
        return {}

    def aggregation(self):
        """
        :returns: {str: [metric_score] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metric scores
        """
        # TODO: For each (sub)metric in the task evaluation, add a key-value pair
        # with the metric name as key and an aggregation function as value which
        # determines how to combine results from each document in the dataset.
        # Check `lm_eval.metrics` to find built-in aggregation functions.
        return {}

    def higher_is_better(self):
        # TODO: For each (sub)metric in the task evaluation, add a key-value pair
        # with the metric name as key and a `bool` value determining whether or
        # not higher values of that metric are deemed better.
        return {}
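
To show how the hooks connect end to end, here is a hypothetical minimal yes/no task filled in against this template. `rf` and `mean` are the harness's request factory and aggregation helper; the dataset path and field names (`question`, `answer`) are invented for illustration.

```python
from lm_eval.base import Task, rf
from lm_eval.metrics import mean


class YesNoToyTask(Task):  # hypothetical example, not a shipped task
    VERSION = 0
    DATASET_PATH = "someuser/yes_no_dataset"  # placeholder dataset name
    DATASET_NAME = None

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def validation_docs(self):
        return self.dataset["validation"]

    def doc_to_text(self, doc):
        return "Question: " + doc["question"] + "\nAnswer:"

    def doc_to_target(self, doc):
        # The leading space separates the target from doc_to_text.
        return " " + ("yes" if doc["answer"] else "no")

    def construct_requests(self, doc, ctx):
        # One loglikelihood request per candidate answer.
        ll_yes, _ = rf.loglikelihood(ctx, " yes")
        ll_no, _ = rf.loglikelihood(ctx, " no")
        return ll_yes, ll_no

    def process_results(self, doc, results):
        ll_yes, ll_no = results
        pred = ll_yes > ll_no
        return {"acc": float(pred == bool(doc["answer"]))}

    def aggregation(self):
        return {"acc": mean}

    def higher_is_better(self):
        return {"acc": True}
```
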