Commit baa8b0d3 authored by bzantium's avatar bzantium
Browse files

fix for merge from master

parent a956bc63
...@@ -3,16 +3,21 @@ from itertools import islice ...@@ -3,16 +3,21 @@ from itertools import islice
ct = 3 ct = 3
for tname, Task in tasks.TASK_REGISTRY.items():#[('record', tasks.superglue.ReCoRD)]:# for (
tname,
Task,
) in tasks.TASK_REGISTRY.items(): # [('record', tasks.superglue.ReCoRD)]:#
task = Task() task = Task()
print('#', tname) print("#", tname)
docs = islice(task.validation_docs() if task.has_validation_docs() else task.test_docs(), ct) docs = islice(
task.validation_docs() if task.has_validation_docs() else task.test_docs(), ct
)
print() print()
for i in range(ct): for i in range(ct):
print() print()
doc = next(docs) doc = next(docs)
print("**Context**:", "\n```\n" + task.doc_to_text(doc) + "\n```\n") print("**Context**:", "\n```\n" + task.doc_to_text(doc) + "\n```\n")
print() print()
print('**Target**:', "\n```\n" + task.doc_to_target(doc) + "\n```\n") print("**Target**:", "\n```\n" + task.doc_to_target(doc) + "\n```\n")
print() print()
...@@ -10,7 +10,7 @@ random.seed(42) ...@@ -10,7 +10,7 @@ random.seed(42)
data = [ data = [
"A multilayer perceptron (MLP) is a class of feedforward artificial neural network (ANN)", "A multilayer perceptron (MLP) is a class of feedforward artificial neural network (ANN)",
"The term MLP is used ambiguously, sometimes loosely to any feedforward ANN, sometimes strictly to refer to networks composed of multiple layers of perceptrons (with threshold activation); see § Terminology", "The term MLP is used ambiguously, sometimes loosely to any feedforward ANN, sometimes strictly to refer to networks composed of multiple layers of perceptrons (with threshold activation); see § Terminology",
"Multilayer perceptrons are sometimes colloquially referred to as \"vanilla\" neural networks, especially when they have a single hidden layer.[1]", 'Multilayer perceptrons are sometimes colloquially referred to as "vanilla" neural networks, especially when they have a single hidden layer.[1]',
"An MLP consists of at least three layers of nodes: an input layer, a hidden layer and an output layer. Except for the input nodes, each node is a neuron that uses a nonlinear activation function.", "An MLP consists of at least three layers of nodes: an input layer, a hidden layer and an output layer. Except for the input nodes, each node is a neuron that uses a nonlinear activation function.",
"MLP utilizes a supervised learning technique called backpropagation for training.[2][3] Its multiple layers and non-linear activation distinguish MLP from a linear perceptron. It can distinguish data that is not linearly separable.[4]", "MLP utilizes a supervised learning technique called backpropagation for training.[2][3] Its multiple layers and non-linear activation distinguish MLP from a linear perceptron. It can distinguish data that is not linearly separable.[4]",
"Recent work has demonstrated substantial gains on many NLP tasks and benchmarks by pre-training on a large corpus of text followed by fine-tuning on a specific task. While typically task-agnostic in architecture, this method still requires task-specific fine-tuning datasets of thousands or tens of thousands of examples. By contrast, humans can generally perform a new language task from only a few examples or from simple instructions - something which current NLP systems still largely struggle to do. Here we show that scaling up language models greatly improves task-agnostic, few-shot performance, sometimes even reaching competitiveness with prior state-of-the-art fine-tuning approaches. ", "Recent work has demonstrated substantial gains on many NLP tasks and benchmarks by pre-training on a large corpus of text followed by fine-tuning on a specific task. While typically task-agnostic in architecture, this method still requires task-specific fine-tuning datasets of thousands or tens of thousands of examples. By contrast, humans can generally perform a new language task from only a few examples or from simple instructions - something which current NLP systems still largely struggle to do. Here we show that scaling up language models greatly improves task-agnostic, few-shot performance, sometimes even reaching competitiveness with prior state-of-the-art fine-tuning approaches. ",
...@@ -20,22 +20,28 @@ data = [ ...@@ -20,22 +20,28 @@ data = [
] ]
model = transformers.GPT2LMHeadModel.from_pretrained('gpt2') model = transformers.GPT2LMHeadModel.from_pretrained("gpt2")
tok = transformers.GPT2Tokenizer.from_pretrained('gpt2') tok = transformers.GPT2Tokenizer.from_pretrained("gpt2")
tgs = [] tgs = []
for dat in data: for dat in data:
random.seed(dat) random.seed(dat)
#print(model(tok.encode(dat, return_tensors="pt"))[0][0]) # print(model(tok.encode(dat, return_tensors="pt"))[0][0])
toks = tok.encode(dat, return_tensors="pt") toks = tok.encode(dat, return_tensors="pt")
ind = random.randrange(len(toks[0])-1) ind = random.randrange(len(toks[0]) - 1)
logits = F.log_softmax(model(toks)[0], dim=-1)[:, :-1] # [batch, seq, vocab] logits = F.log_softmax(model(toks)[0], dim=-1)[:, :-1] # [batch, seq, vocab]
res = torch.gather(logits, 2, toks[:, 1:].unsqueeze(-1)).squeeze(-1)[0] res = torch.gather(logits, 2, toks[:, 1:].unsqueeze(-1)).squeeze(-1)[0]
tgs.append( float(res[ind:].sum())) tgs.append(float(res[ind:].sum()))
print(r'("""' + tok.decode(toks[0, :ind+1]) + r'""", """' + tok.decode(toks[0, ind+1:]) + r'"""), ') print(
r'("""'
+ tok.decode(toks[0, : ind + 1])
+ r'""", """'
+ tok.decode(toks[0, ind + 1 :])
+ r'"""), '
)
print(tgs) print(tgs)
\ No newline at end of file
"""
Usage:
python make_table_tasks.py --output <markdown_filename>
"""
import argparse
import logging
from lm_eval import tasks from lm_eval import tasks
from pytablewriter import MarkdownTableWriter from pytablewriter import MarkdownTableWriter
writer = MarkdownTableWriter()
writer.headers = ["Task Name", "Train", "Val", "Test","Val/Test Docs", "Metrics"]
values = [] logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def chk(tf):
def check(tf):
if tf: if tf:
return '✓' return "✓"
else: else:
return ' ' return " "
for tname, Task in tasks.TASK_REGISTRY.items():
task = Task()
v = [tname,chk(task.has_training_docs()),chk(task.has_validation_docs()),chk(task.has_test_docs()), len(list(task.test_docs() if task.has_test_docs() else task.validation_docs())),', '.join(task.aggregation().keys())] if __name__ == "__main__":
print(v) parser = argparse.ArgumentParser()
values.append(v) parser.add_argument("--output", type=str, default="task_table.md")
args = parser.parse_args()
writer.value_matrix = values writer = MarkdownTableWriter()
writer.headers = ["Task Name", "Train", "Val", "Test", "Val/Test Docs", "Metrics"]
values = []
print(writer.dumps()) tasks = tasks.TASK_REGISTRY.items()
\ No newline at end of file tasks = sorted(tasks, key=lambda x: x[0])
for tname, Task in tasks:
task = Task()
v = [
tname,
check(task.has_training_docs()),
check(task.has_validation_docs()),
check(task.has_test_docs()),
len(
list(
task.test_docs() if task.has_test_docs() else task.validation_docs()
)
),
", ".join(task.aggregation().keys()),
]
logger.info(v)
values.append(v)
writer.value_matrix = values
table = writer.dumps()
with open(args.output, "w") as f:
f.write(table)
...@@ -11,14 +11,14 @@ EXAMPLE_DIVIDER = "!!@@##@@!! -- Example {i}\n" ...@@ -11,14 +11,14 @@ EXAMPLE_DIVIDER = "!!@@##@@!! -- Example {i}\n"
def parse_args(): def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--output_base_path', required=True) parser.add_argument("--output_base_path", required=True)
parser.add_argument('--tasks', default="all_tasks") parser.add_argument("--tasks", default="all_tasks")
parser.add_argument('--provide_description', action="store_true") parser.add_argument("--provide_description", action="store_true")
parser.add_argument('--sets', type=str, default="val") # example: val,test parser.add_argument("--sets", type=str, default="val") # example: val,test
parser.add_argument('--num_fewshot', type=int, default=1) parser.add_argument("--num_fewshot", type=int, default=1)
parser.add_argument('--seed', type=int, default=42) parser.add_argument("--seed", type=int, default=42)
parser.add_argument('--num_examples', type=int, default=1) parser.add_argument("--num_examples", type=int, default=1)
parser.add_argument('--description_dict_path', default=None) parser.add_argument("--description_dict_path", default=None)
return parser.parse_args() return parser.parse_args()
...@@ -34,7 +34,7 @@ def main(): ...@@ -34,7 +34,7 @@ def main():
description_dict = {} description_dict = {}
if args.description_dict_path: if args.description_dict_path:
with open(args.description_dict_path, 'r') as f: with open(args.description_dict_path, "r") as f:
description_dict = json.load(f) description_dict = json.load(f)
os.makedirs(args.output_base_path, exist_ok=True) os.makedirs(args.output_base_path, exist_ok=True)
...@@ -45,26 +45,34 @@ def main(): ...@@ -45,26 +45,34 @@ def main():
iters = [] iters = []
for set in args.sets.split(","): for set in args.sets.split(","):
if set == 'train' and task.has_training_docs(): if set == "train" and task.has_training_docs():
docs = task.training_docs() docs = task.training_docs()
if set == 'val' and task.has_validation_docs(): if set == "val" and task.has_validation_docs():
docs = task.validation_docs() docs = task.validation_docs()
if set == 'test' and task.has_test_docs(): if set == "test" and task.has_test_docs():
docs = task.test_docs() docs = task.test_docs()
iters.append(docs) iters.append(docs)
docs = join_iters(iters) docs = join_iters(iters)
description = description_dict[task_name] if description_dict and task_name in description_dict else "" description = (
description_dict[task_name]
if description_dict and task_name in description_dict
else ""
)
with open(os.path.join(args.output_base_path, task_name), "w") as f: with open(os.path.join(args.output_base_path, task_name), "w") as f:
for i, doc in zip(range(args.num_examples), docs) if args.num_examples > 0 else enumerate(docs): for i, doc in (
zip(range(args.num_examples), docs)
if args.num_examples > 0
else enumerate(docs)
):
f.write(EXAMPLE_DIVIDER.format(i=i)) f.write(EXAMPLE_DIVIDER.format(i=i))
ctx = task.fewshot_context( ctx = task.fewshot_context(
doc=doc, doc=doc,
num_fewshot=args.num_fewshot, num_fewshot=args.num_fewshot,
rnd=rnd, rnd=rnd,
description=description description=description,
) )
f.write(ctx + "\n") f.write(ctx + "\n")
......
...@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh: ...@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh:
setuptools.setup( setuptools.setup(
name="lm_eval", name="lm_eval",
version="0.2.0", version="0.3.0",
author="Leo Gao", author="Leo Gao",
author_email="lg@eleuther.ai", author_email="lg@eleuther.ai",
description="A framework for evaluating autoregressive language models", description="A framework for evaluating autoregressive language models",
...@@ -14,37 +14,35 @@ setuptools.setup( ...@@ -14,37 +14,35 @@ setuptools.setup(
url="https://github.com/EleutherAI/lm-evaluation-harness", url="https://github.com/EleutherAI/lm-evaluation-harness",
packages=setuptools.find_packages(), packages=setuptools.find_packages(),
classifiers=[ classifiers=[
"Development Status :: 3 - Alpha",
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License", "License :: OSI Approved :: MIT License",
"Operating System :: OS Independent", "Operating System :: OS Independent",
], ],
python_requires='>=3.6', python_requires=">=3.9",
install_requires=[ install_requires=[
"black", "datasets>=2.0.0",
"datasets==2.0.0", "jsonlines",
"click>=7.1", "numexpr",
"openai>=0.6.4",
"omegaconf>=2.2",
"peft>=0.2.0",
"pybind11>=2.6.2",
"pycountry",
"pytablewriter",
"rouge-score>=0.0.4",
"sacrebleu==1.5.0",
"scikit-learn>=0.24.1", "scikit-learn>=0.24.1",
"sqlitedict",
"torch>=1.7", "torch>=1.7",
"tqdm-multiprocess",
"transformers>=4.1", "transformers>=4.1",
"sqlitedict==1.6.0", "zstandard",
"pytablewriter==0.58.0", "accelerate>=0.17.1",
"sacrebleu==1.5.0",
"rouge-score==0.0.4",
"pycountry==20.7.3",
"numexpr==2.7.2",
"lm_dataformat==0.0.20",
"pytest==6.2.3",
"pybind11==2.6.2",
"tqdm-multiprocess==0.0.11",
"zstandard==0.15.2",
"jsonlines==2.0.0",
"mock==4.0.3",
"openai==0.6.4",
"jieba==0.42.1",
"nagisa==0.2.7",
"bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt"
], ],
dependency_links=[ extras_require={
"https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt", "dev": ["black", "flake8", "pre-commit", "pytest", "pytest-cov"],
] "multilingual": ["nagisa>=0.2.7", "jieba>=0.42.1"],
"sentencepiece": ["sentencepiece>=0.1.98", "protobuf>=4.22.1"],
},
) )
# TODO: Remove all TODO comments once the implementation is complete.
"""
TODO: Add the Paper Title on this line.
TODO: Add the paper's PDF URL (preferably from arXiv) on this line.
TODO: Write a Short Description of the task.
Homepage: TODO: Add the URL to the task's Homepage here.
"""
from lm_eval.base import MultipleChoiceTask
# TODO: Add the BibTeX citation for the task.
_CITATION = """
"""
# TODO: Replace `NewTask` with the name of your Task.
class NewTask(MultipleChoiceTask):
VERSION = 0
# TODO: Add the `DATASET_PATH` string. This will be the name of the `Task`
# dataset as denoted in HuggingFace `datasets`.
DATASET_PATH = ""
# TODO: Add the `DATASET_NAME` string. This is the name of a subset within
# `DATASET_PATH`. If there aren't specific subsets you need, leave this as `None`.
DATASET_NAME = None
def has_training_docs(self):
# TODO: Fill in the return with `True` if the Task has training data; else `False`.
return False
def has_validation_docs(self):
# TODO: Fill in the return with `True` if the Task has validation data; else `False`.
return False
def has_test_docs(self):
# TODO: Fill in the return with `True` if the Task has test data; else `False`.
return False
def training_docs(self):
if self.has_training_docs():
# We cache training documents in `self._training_docs` for faster
# few-shot processing. If the data is too large to fit in memory,
# return the training data as a generator instead of a list.
if self._training_docs is None:
# TODO: Return the training document generator from `self.dataset`.
# In most case you can leave this as is unless the dataset split is
# named differently than the default `"train"`.
self._training_docs = list(
map(self._process_doc, self.dataset["train"])
)
return self._training_docs
def validation_docs(self):
if self.has_validation_docs():
# TODO: Return the validation document generator from `self.dataset`.
# In most case you can leave this as is unless the dataset split is
# named differently than the default `"validation"`.
return map(self._process_doc, self.dataset["validation"])
def test_docs(self):
if self.has_test_docs():
# TODO: Return the test document generator from `self.dataset`.
# In most case you can leave this as is unless the dataset split is
# named differently than the default `"test"`.
return map(self._process_doc, self.dataset["test"])
def _process_doc(self, doc):
# TODO: Process the documents into a dictionary with the following keys:
return {
"query": "", # The query prompt.
"choices": [], # The list of choices.
"gold": 0, # The integer used to index into the correct element of `"choices"`.
}
def doc_to_text(self, doc):
# TODO: Format the query prompt portion of the document example.
return doc["query"]
# TODO: Remove all TODO comments once the implementation is complete.
"""
TODO: Add the Paper Title on this line.
TODO: Add the paper's PDF URL (preferably from arXiv) on this line.
TODO: Write a Short Description of the task.
Homepage: TODO: Add the URL to the task's Homepage here.
"""
from lm_eval.base import Task
# TODO: Add the BibTeX citation for the task.
_CITATION = """
"""
# TODO: Replace `NewTask` with the name of your Task.
class NewTask(Task):
VERSION = 0
# TODO: Add the `DATASET_PATH` string. This will be the name of the `Task`
# dataset as denoted in HuggingFace `datasets`.
DATASET_PATH = ""
# TODO: Add the `DATASET_NAME` string. This is the name of a subset within
# `DATASET_PATH`. If there aren't specific subsets you need, leave this as `None`.
DATASET_NAME = None
def has_training_docs(self):
# TODO: Fill in the return with `True` if the Task has training data; else `False`.
return False
def has_validation_docs(self):
# TODO: Fill in the return with `True` if the Task has validation data; else `False`.
return False
def has_test_docs(self):
# TODO: Fill in the return with `True` if the Task has test data; else `False`.
return False
def training_docs(self):
if self.has_training_docs():
# We cache training documents in `self._training_docs` for faster
# few-shot processing. If the data is too large to fit in memory,
# return the training data as a generator instead of a list.
if self._training_docs is None:
# TODO: Return the training document generator from `self.dataset`.
# If you need to process the data, `map` over the documents with
# the custom processing function, `self._process_doc`. E.g.
# `map(self._process_doc, self.dataset["validation"])`
# In most case you can leave this as is unless the dataset split is
# named differently than the default `"train"`.
self._training_docs = list(self.dataset["train"])
return self._training_docs
def validation_docs(self):
if self.has_validation_docs():
# TODO: Return the validation document generator from `self.dataset`.
# If you need to process the data, `map` over the documents with the
# custom processing function, `self._process_doc`. E.g.
# `map(self._process_doc, self.dataset["validation"])`
# In most case you can leave this as is unless the dataset split is
# named differently than the default `"validation"`.
return self.dataset["validation"]
def test_docs(self):
if self.has_test_docs():
# TODO: Return the test document generator from `self.dataset`.
# If you need to process the data, `map` over the documents with the
# custom processing function, `self._process_doc`. E.g.
# `map(self._process_doc, self.dataset["test"])`
# In most case you can leave this as is unless the dataset split is
# named differently than the default `"test"`.
return self.dataset["test"]
def _process_doc(self, doc):
# TODO: Process (detokenize, strip, replace etc.) each individual `doc`
# with this function. You can map this across the docs in each available
# dataset split. See the TODOs in `train_docs`, `validation_docs`, and
# `test_docs` for snippets.
# NOTE: DELETE THIS FUNCTION IF UNUSED.
return doc
def doc_to_text(self, doc):
# TODO: Format the query prompt portion of the document example.
return ""
def doc_to_target(self, doc):
# TODO: Fill in the `target` ("gold answer") variable.
# The prepended `" "` is required to space out the `doc_to_text` and
# `doc_to_target` strings.
target = ""
return " " + target
def construct_requests(self, doc, ctx):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or
test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
# TODO: Construct your language model requests with the request factory, `rf`,
# and return them as an iterable.
return []
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param results:
The results of the requests created in construct_requests.
"""
# TODO: For each (sub)metric in the task evaluation, add a key-value pair
# with the metric name as key and the corresponding metric result as value
# for the current `doc`.
return {}
def aggregation(self):
"""
:returns: {str: [metric_score] -> float}
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metric scores
"""
# TODO: For each (sub)metric in the task evaluation, add a key-value pair
# with the metric name as key and an aggregation function as value which
# determines how to combine results from each document in the dataset.
# Check `lm_eval.metrics` to find built-in aggregation functions.
return {}
def higher_is_better(self):
# TODO: For each (sub)metric in the task evaluation, add a key-value pair
# with the metric name as key and a `bool` value determining whether or
# not higher values of that metric are deemed better.
return {}
...@@ -10,23 +10,24 @@ import pytest ...@@ -10,23 +10,24 @@ import pytest
# TODO: more fine grained unit tests rather than this big honking integration # TODO: more fine grained unit tests rather than this big honking integration
# test once we break evaluator into smaller, more manageable pieces # test once we break evaluator into smaller, more manageable pieces
@pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items()) @pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
def test_evaluator(taskname, task_class): def test_evaluator(taskname, task_class):
task_dict = tasks.get_task_dict([taskname]) task_dict = tasks.get_task_dict([taskname])
os.system("rm test_cache.db") os.system("rm test_cache.db")
lm = base.CachingLM(models.get_model('dummy')(), "test_cache.db") lm = base.CachingLM(models.get_model("dummy")(), "test_cache.db")
def ll_fn(reqs): def ll_fn(reqs):
for ctx, cont in reqs: for ctx, cont in reqs:
if len(ctx) == 0: if len(ctx) == 0:
continue continue
# space convention # space convention
assert ctx[-1] != ' ' assert ctx[-1] != " "
assert cont[0] == ' ' or ctx[-1] == '\n' assert cont[0] == " " or ctx[-1] == "\n"
res = [] res = []
random.seed(42) random.seed(42)
for _ in reqs: for _ in reqs:
res.append((-random.random(), False)) res.append((-random.random(), False))
...@@ -34,7 +35,7 @@ def test_evaluator(taskname, task_class): ...@@ -34,7 +35,7 @@ def test_evaluator(taskname, task_class):
return res return res
def ll_perp_fn(reqs): def ll_perp_fn(reqs):
for string, in reqs: for (string,) in reqs:
assert isinstance(string, str) assert isinstance(string, str)
res = [] res = []
...@@ -49,20 +50,20 @@ def test_evaluator(taskname, task_class): ...@@ -49,20 +50,20 @@ def test_evaluator(taskname, task_class):
limit = 10 limit = 10
e1 = evaluator.evaluate( e1 = evaluator.evaluate(
lm=lm, lm=lm,
task_dict=task_dict, task_dict=task_dict,
num_fewshot=0, num_fewshot=0,
limit=limit, limit=limit,
bootstrap_iters=10, bootstrap_iters=10,
description_dict=None description_dict=None,
) )
e2 = evaluator.evaluate( e2 = evaluator.evaluate(
lm=lm, lm=lm,
task_dict=task_dict, task_dict=task_dict,
num_fewshot=0, num_fewshot=0,
limit=limit, limit=limit,
bootstrap_iters=10, bootstrap_iters=10,
description_dict=None description_dict=None,
) )
# check that caching is working # check that caching is working
......
...@@ -3,27 +3,32 @@ from collections import Counter ...@@ -3,27 +3,32 @@ from collections import Counter
import shutil import shutil
import glob import glob
from scripts.clean_training_data.janitor import * from lm_eval.decontamination.janitor import Janitor, word_ngrams
from scripts.clean_training_data.generate_13_grams import do_ngrams_in_buckets from scripts.clean_training_data.generate_13_grams import do_ngrams_in_buckets
from scripts.clean_training_data.archiver import Archive, TextReader from lm_eval.decontamination.archiver import Archive, TextReader
import logging
def test_generate_13_grams_1(): logger = logging.getLogger(__name__)
data = """A goose (plural geese) is a bird of any of several waterfowl species in the family Anatidae.
This group comprises the genera Anser (the grey geese and white geese) and Branta (the black geese).
Some other birds, mostly related to the shelducks, have "goose" as part of their names. def test_generate_13_grams_1(caplog):
More distantly related members of the family Anatidae are swans, most of which are larger data = """A goose (plural geese) is a bird of any of several waterfowl species in the family Anatidae.
than true geese, and ducks, which are smaller. The term "goose" may refer to either a male This group comprises the genera Anser (the grey geese and white geese) and Branta (the black geese).
or female bird, but when paired with "gander", refers specifically to a female one (the latter referring Some other birds, mostly related to the shelducks, have "goose" as part of their names.
to a male). Young birds before fledging are called goslings. The collective noun for a group of More distantly related members of the family Anatidae are swans, most of which are larger
geese on the ground is a gaggle; when in flight, they are called a skein, a team, or a wedge; when than true geese, and ducks, which are smaller. The term "goose" may refer to either a male
or female bird, but when paired with "gander", refers specifically to a female one (the latter referring
to a male). Young birds before fledging are called goslings. The collective noun for a group of
geese on the ground is a gaggle; when in flight, they are called a skein, a team, or a wedge; when
flying close together, they are called a plump.""" flying close together, they are called a plump."""
data = data + data data = data + data
# Simple Generation # Simple Generation
print("simple generation")
n = 13 n = 13
janitor = Janitor() janitor = Janitor()
ngrams = word_ngrams(janitor.normalize_string(data), n) ngrams = word_ngrams(janitor.normalize_string(data), n)
comparison = list(ngrams) comparison = list(ngrams)
comparison_counter = Counter(comparison) comparison_counter = Counter(comparison)
...@@ -31,35 +36,42 @@ def test_generate_13_grams_1(): ...@@ -31,35 +36,42 @@ def test_generate_13_grams_1():
# print(comparison) # print(comparison)
# Generating into buckets # Generating into buckets
print("bucket generation")
test_working_directory = "test_generate_13_grams" test_working_directory = "test_generate_13_grams"
output_directory = os.path.join(test_working_directory, "output")
try: try:
shutil.rmtree(output_directory) shutil.rmtree(test_working_directory)
except FileNotFoundError: except FileNotFoundError:
pass pass
os.makedirs(test_working_directory, exist_ok=True) os.makedirs(test_working_directory)
archive = Archive(os.path.join(test_working_directory, "test.jsonl.zst"))
assert not os.path.exists("pile")
os.makedirs("pile")
archive = Archive(os.path.join("pile", "test.jsonl.zst"))
archive.add_data(data) archive.add_data(data)
archive.commit() archive.commit()
bucket_count = 4 bucket_count = 4
do_ngrams_in_buckets(n, test_working_directory, bucket_count) do_ngrams_in_buckets(n, test_working_directory, bucket_count)
# Rebuild from buckets # Rebuild from buckets
print("rebuild")
rebuilt_ngrams = [] rebuilt_ngrams = []
bucket_file_paths = glob.glob(
bucket_file_paths = glob.glob(os.path.join(test_working_directory, "output", f"*.bkt.txt")) os.path.join(test_working_directory, "output", f"*.bkt.txt")
)
for bucket_file_path in bucket_file_paths: for bucket_file_path in bucket_file_paths:
reader = TextReader(bucket_file_path) reader = TextReader(bucket_file_path)
for line in reader.read(): for line in reader.read():
[ngram, document_id] = line.rsplit(" ", 1) [ngram, document_id] = line.rsplit(" ", 1)
rebuilt_ngrams.append(ngram) rebuilt_ngrams.append(ngram)
# Compare # Compare
print("compare")
result_counter = Counter(rebuilt_ngrams) result_counter = Counter(rebuilt_ngrams)
# print(len(result_counter)) # print(len(result_counter))
# print(len(comparison_counter)) # print(len(comparison_counter))
assert(len(result_counter) == len(comparison_counter)) assert len(result_counter) == len(comparison_counter)
# print(result_counter) # print(result_counter)
# print(comparison_counter) # print(comparison_counter)
assert(comparison_counter == result_counter) assert comparison_counter == result_counter
\ No newline at end of file
import lm_eval.models as models
import pytest
import os
import json
import openai
import mock
import pickle
import hashlib
def mock_completion(**kwargs):
# Mock completion function
# Loads from a cached+pickled response if it exists, otherwise it will actually try to ping
os.makedirs("tests/testdata", exist_ok=True)
hash = hashlib.sha256(json.dumps(kwargs, sort_keys=True).encode('utf-8')).hexdigest()
fname = f"tests/testdata/gpt3_test_{hash}.pkl"
if os.path.exists(fname):
with open(fname, 'rb') as fh:
return pickle.load(fh)
ret = openai.Completion.create(**kwargs)
ret.api_key = ""
with open(fname, 'wb') as fh:
pickle.dump(ret, fh)
return ret
@mock.patch("lm_eval.models.gpt3.oa_completion", new=mock_completion)
def test_gpt3():
    """Exercise the gpt3 model's loglikelihood and greedy_until APIs.

    All completions are served from mock_completion's on-disk pickle cache,
    so no live OpenAI access is required once the cache is populated.
    """
    # A dummy key is enough here: requests are intercepted by the mock, so the
    # variable only needs to exist for model construction.
    if "OPENAI_API_SECRET_KEY" not in os.environ: os.environ["OPENAI_API_SECRET_KEY"] = ""
    gpt3 = models.get_model('gpt3').create_from_arg_string("engine=ada")
    # First five results are checked qualitatively below; the rest (*vals)
    # are compared against pinned numeric targets at the end.
    (ll_dog, ig_dog), (ll_cat, ig_cat), (_, ll_max_0), (_, ll_max_1), (_, ll_max_2), *vals = gpt3.loglikelihood([
        ('The quick brown fox jumps over the lazy', ' dog'),
        ('The quick brown fox jumps over the lazy', ' cat'),
        ('The quick brown fox jumps over the lazy', ', lazy dog'),
        ('The quick brown fox jumps over the lazy', ', lazy fox'),
        ('The quick brown fox jumps over the lazy', ', lazy fox and they both fall to the ground'),
        ("""A mult""", """ilayer perceptron (MLP) is a class of feedforward artificial neural network (ANN)"""),
        ("""The term MLP is used ambiguously, sometimes loosely to any feedforward ANN, sometimes strictly to refer to networks composed of multiple layers of perceptrons""", """ (with threshold activation); see § Terminology"""),
        ("""Multilayer perceptrons are sometimes coll""", """oquially referred to as "vanilla" neural networks, especially when they have a single hidden layer.[1]"""),
        ("""An MLP consists of at least three layers of nodes: an input layer, a hidden layer and an output layer. Except for the input nodes, each node is a neuron that uses a nonlinear""", """ activation function."""),
        ("""MLP utilizes a supervised""", """ learning technique called backpropagation for training.[2][3] Its multiple layers and non-linear activation distinguish MLP from a linear perceptron. It can distinguish data that is not linearly separable.[4]"""),
        ("""Recent work has demonstrated substantial gains on many NLP tasks and benchmarks by pre-training on a large corpus of text followed by fine-tuning on a specific task. While typically task-agnostic""", """ in architecture, this method still requires task-specific fine-tuning datasets of thousands or tens of thousands of examples. By contrast, humans can generally perform a new language task from only a few examples or from simple instructions - something which current NLP systems still largely struggle to do. Here we show that scaling up language models greatly improves task-agnostic, few-shot performance, sometimes even reaching competitiveness with prior state-of-the-art fine-tuning approaches. """),
        ("""Specifically, we train GPT-3, an autoregressive language model with 175""", """ billion parameters, 10x more than any previous non-sparse language model, and test its performance in the few-shot setting. For all tasks, GPT-3 is applied without any gradient updates or fine-tuning, with tasks and few-shot demonstrations specified purely via text interaction with the model. GPT-3 achieves strong performance on many NLP datasets, including translation, question-answering, and cloze tasks, as well as several tasks that require on-the-fly reasoning or domain adaptation, such as unscrambling words, using a novel word in a sentence, or performing 3-digit arithmetic. At the same time, we also identify some datasets where GPT-3's few-shot learning still struggles, as well as some datasets where GPT-3 faces methodological issues related to training on large web corpora. Finally, we find that GPT-3 can generate samples of news articles which human evaluators have difficulty distinguishing from articles written by humans. We discuss broader societal impacts of this finding and of GPT-3 in general."""),
        ("""A mult""", """ilayer perceptron (MLP) is a class of feedforward artificial neural network (ANN)"""),
        ("""Hello""", """ World"""),
    ])

    # ' dog' should be the likelier continuation, and the greedy continuation
    # (is_greedy flag): True for ' dog', False for ' cat'.
    assert ll_dog > ll_cat
    assert not ig_cat
    assert ig_dog

    # The three ', lazy ...' continuations must not be the argmax completions.
    assert not ll_max_0
    assert not ll_max_1
    assert not ll_max_2

    # test empty context
    gpt3.loglikelihood([('', 'test')])

    gen, = gpt3.greedy_until([
        ('The quick brown fox jumps over the lazy', ['.', '\n'])
    ])

    assert gen == ' dog'

    print([x[0] for x in vals])

    # Pinned per-pair log-likelihoods for the remaining inputs — presumably
    # recorded from a real API run that populated the mock cache; verify
    # against the cached pickles if they ever drift.
    targets = [
        -34.848301606999996, -47.148329679999996, -45.44380149599999, -5.285246016, -133.97821690686004,
        -321.2616693239001, -658.0299524401041, -34.848301606999996, -7.525115,
    ]
    for (pred, _), tgt in zip(vals, targets):
        assert pred == pytest.approx(tgt, rel=1e-3)
@mock.patch("lm_eval.models.gpt3.oa_completion", new=mock_completion)
def test_gpt3_perplexity():
    """Check rolling-window log-likelihood for the gpt3 model.

    First with the model's default context length, then with max_length
    patched down to 5 so the string is forced through multiple rolling
    windows. Completions come from mock_completion's pickle cache.
    """
    # setdefault is the idiomatic, behaviorally identical replacement for the
    # previous one-line `if ... not in os.environ:` statement; a dummy key
    # suffices since requests are intercepted by the mock.
    os.environ.setdefault("OPENAI_API_SECRET_KEY", "")
    gpt3 = models.get_model("gpt3").create_from_arg_string("engine=ada")
    test_string = "We study empirical scaling laws for language model performance on the cross-entropy loss."
    perplexity = gpt3.loglikelihood_rolling([(test_string,)])[0]
    # Pinned reference value — presumably recorded from the run that
    # populated the mock cache.
    tgt = -84.38819608
    assert perplexity == pytest.approx(tgt, rel=1e-3)

    # Hack: modify gpt3 to have shorter context length to induce rolling windows
    with mock.patch.object(
        models.gpt3.GPT3LM, "max_length", new_callable=mock.PropertyMock
    ) as mock_max_length:
        mock_max_length.return_value = 5
        gpt3 = models.get_model("gpt3").create_from_arg_string("engine=ada")
        perplexity = gpt3.loglikelihood_rolling([(test_string,)])[0]
        tgt = -101.81967209999999
        assert perplexity == pytest.approx(tgt, rel=1e-3)
import re import re
from collections import defaultdict from collections import defaultdict
from scripts.clean_training_data.janitor import * from lm_eval.decontamination.janitor import (
Janitor,
form_ngrams,
word_ngrams,
split_indices,
word_ngrams_indices,
)
def simple_ngram(sequence, n): def simple_ngram(sequence, n):
ngrams = list() ngrams = list()
...@@ -16,8 +23,10 @@ def simple_ngram(sequence, n): ...@@ -16,8 +23,10 @@ def simple_ngram(sequence, n):
def test_form_ngrams(): def test_form_ngrams():
sequence = "Hello my name is Bob, I like eating pizza, chicken, chips and ice cream. Maybe I should eat some" \ sequence = (
" more salad but it's so booooring. I just... like eating pizza, chicken, chips and ice cream so much." "Hello my name is Bob, I like eating pizza, chicken, chips and ice cream. Maybe I should eat some"
" more salad but it's so booooring. I just... like eating pizza, chicken, chips and ice cream so much."
)
n_values = [1, 2, 3, 5, 13] n_values = [1, 2, 3, 5, 13]
for n in n_values: for n in n_values:
...@@ -26,9 +35,12 @@ def test_form_ngrams(): ...@@ -26,9 +35,12 @@ def test_form_ngrams():
assert len(comparison) == len(result_to_test) assert len(comparison) == len(result_to_test)
assert comparison == result_to_test assert comparison == result_to_test
def test_word_ngrams(): def test_word_ngrams():
sequence = "Hello my name is Bob, I like eating pizza, chicken, chips and ice cream. Maybe I should eat some" \ sequence = (
" more salad but it's so booooring. I just... like eating pizza, chicken, chips and ice cream so much." "Hello my name is Bob, I like eating pizza, chicken, chips and ice cream. Maybe I should eat some"
" more salad but it's so booooring. I just... like eating pizza, chicken, chips and ice cream so much."
)
words = sequence.split() words = sequence.split()
...@@ -40,9 +52,12 @@ def test_word_ngrams(): ...@@ -40,9 +52,12 @@ def test_word_ngrams():
assert len(comparison) == len(result_to_test) assert len(comparison) == len(result_to_test)
assert result_to_test == comparison assert result_to_test == comparison
def test_split_indices(): def test_split_indices():
sequence = "Hello my name is Bob, I like eating pizza, chicken, chips and ice cream. Maybe I should eat some" \ sequence = (
" more salad but it's so booooring. I just... like eating pizza, chicken, chips and ice cream so much." "Hello my name is Bob, I like eating pizza, chicken, chips and ice cream. Maybe I should eat some"
" more salad but it's so booooring. I just... like eating pizza, chicken, chips and ice cream so much."
)
comparison = [] comparison = []
current_word = "" current_word = ""
...@@ -55,17 +70,22 @@ def test_split_indices(): ...@@ -55,17 +70,22 @@ def test_split_indices():
current_word = "" current_word = ""
if current_word: if current_word:
comparison.append((current_word, (len(sequence) - len(current_word), len(sequence) - 1))) comparison.append(
current_word = "" (current_word, (len(sequence) - len(current_word), len(sequence) - 1))
)
current_word = ""
result_to_test = list(split_indices(sequence)) result_to_test = list(split_indices(sequence))
assert len(comparison) == len(result_to_test) assert len(comparison) == len(result_to_test)
assert(comparison == result_to_test) assert comparison == result_to_test
def test_word_ngrams_indices(): def test_word_ngrams_indices():
sequence = "Hello my name is Bob, I like eating pizza, chicken, chips and ice cream. Maybe I should eat some" \ sequence = (
" more salad but it's so booooring. I just... like eating pizza, chicken, chips and ice cream so much." "Hello my name is Bob, I like eating pizza, chicken, chips and ice cream. Maybe I should eat some"
" more salad but it's so booooring. I just... like eating pizza, chicken, chips and ice cream so much."
)
n_values = [1, 2, 3, 5, 13] n_values = [1, 2, 3, 5, 13]
...@@ -76,55 +96,62 @@ def test_word_ngrams_indices(): ...@@ -76,55 +96,62 @@ def test_word_ngrams_indices():
for ngram in ngrams: for ngram in ngrams:
while True: while True:
start = sequence.find(ngram, tracker[ngram]) start = sequence.find(ngram, tracker[ngram])
assert start != -1 # testing the test assert start != -1 # testing the test
end = start + len(ngram) - 1 end = start + len(ngram) - 1
tracker[ngram] = end + 1 tracker[ngram] = end + 1
# ignore partial word matches # ignore partial word matches
if (start != 0 and sequence[start - 1] != " ") or \ if (start != 0 and sequence[start - 1] != " ") or (
(end != len(sequence) - 1 and sequence[end + 1] != " "): end != len(sequence) - 1 and sequence[end + 1] != " "
):
pass pass
else: else:
break break
comparison.append((ngram, (start, end))) comparison.append((ngram, (start, end)))
result_to_test = list(word_ngrams_indices(sequence, n)) result_to_test = list(word_ngrams_indices(sequence, n))
assert len(result_to_test) == len(comparison) assert len(result_to_test) == len(comparison)
assert result_to_test == comparison assert result_to_test == comparison
# Assumptions from GPT3 Paper: # Assumptions from GPT3 Paper:
# the 200 characters to remove include punctuation and is actually a half-window # the 200 characters to remove include punctuation and is actually a half-window
# All tests below initially test without any registered contaminants, expecting the same sequence back. # All tests below initially test without any registered contaminants, expecting the same sequence back.
def test_janitor1(): def test_janitor1():
# First test using a 1gram and expected the first block before the filth to have some remaining # First test using a 1gram and expected the first block before the filth to have some remaining
# characters, but the second block should be completely removed. # characters, but the second block should be completely removed.
sequence = "This is a @line #containing a certain number of characters, 76 to be exact. " \ sequence = (
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"FILTH. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "FILTH. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
)
filth = "filth" filth = "filth"
expected_result = "This is a @line #containing a certain number of characters, 76 to be exact. " \ expected_result = (
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing " "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing "
)
janitor = Janitor(ngram_n=1, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200) janitor = Janitor(
ngram_n=1, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200
)
result = janitor.clean_python(sequence) result = janitor.clean_python(sequence)
result = "".join(result) result = "".join(result)
assert result == sequence assert result == sequence
...@@ -133,42 +160,47 @@ def test_janitor1(): ...@@ -133,42 +160,47 @@ def test_janitor1():
assert janitor.dirt_ngrams == {filth} assert janitor.dirt_ngrams == {filth}
result = janitor.clean_python(sequence) result = janitor.clean_python(sequence)
result = "".join(result) result = "".join(result)
assert result == expected_result assert result == expected_result
def test_janitor2(): def test_janitor2():
# Second test using a 1gram and expected the first block before the filth to have some remaining # Second test using a 1gram and expected the first block before the filth to have some remaining
# characters, and the second block is longer then 200 characters so should also have some remaining. # characters, and the second block is longer then 200 characters so should also have some remaining.
sequence = "This is a @line #containing a certain number of characters, 76 to be exact. " \ sequence = (
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"FILTH. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "FILTH. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
)
filth = "filth" filth = "filth"
expected_result = "This is a @line #containing a certain number of characters, 76 to be exact. " \ expected_result = (
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
" characters, 76 to be exact. " \ "This is a @line #containing "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ " characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
janitor = Janitor(ngram_n=1, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200) )
janitor = Janitor(
ngram_n=1, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200
)
result = janitor.clean_python(sequence) result = janitor.clean_python(sequence)
result = "".join(result) result = "".join(result)
assert result == sequence assert result == sequence
...@@ -180,37 +212,43 @@ def test_janitor2(): ...@@ -180,37 +212,43 @@ def test_janitor2():
result = "".join(result) result = "".join(result)
assert result == expected_result assert result == expected_result
def test_janitor3(): def test_janitor3():
# Same test as above but with a 6gram. # Same test as above but with a 6gram.
sequence = "This is a @line #containing a certain number of characters, 76 to be exact. " \ sequence = (
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"FILTH. lots of dirty filtHy FIlTh " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "FILTH. lots of dirty filtHy FIlTh "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
)
filth = "filth lots of dirty filthy filth" filth = "filth lots of dirty filthy filth"
expected_result = "This is a @line #containing a certain number of characters, 76 to be exact. " \ expected_result = (
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
" characters, 76 to be exact. " \ "This is a @line #containing "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ " characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
janitor = Janitor(ngram_n=6, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200) )
janitor = Janitor(
ngram_n=6, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200
)
result = janitor.clean_python(sequence) result = janitor.clean_python(sequence)
result = "".join(result) result = "".join(result)
assert result == sequence assert result == sequence
...@@ -222,45 +260,51 @@ def test_janitor3(): ...@@ -222,45 +260,51 @@ def test_janitor3():
result = "".join(result) result = "".join(result)
assert result == expected_result assert result == expected_result
def test_janitor4(): def test_janitor4():
# This test adds another block to that from the previous. The middle block should be entirely # This test adds another block to that from the previous. The middle block should be entirely
# removed as the 200 characters are removed from each side. # removed as the 200 characters are removed from each side.
sequence = "This is a @line #containing a certain number of characters, 76 to be exact. " \ sequence = (
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"FILTH. lots of dirty filtHy FIlTh " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "FILTH. lots of dirty filtHy FIlTh "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"FILTH. lots of dirty filtHy FIlTh " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "FILTH. lots of dirty filtHy FIlTh "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
)
filth = "filth lots of dirty filthy filth" filth = "filth lots of dirty filthy filth"
expected_result = "This is a @line #containing a certain number of characters, 76 to be exact. " \ expected_result = (
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
" characters, 76 to be exact. " \ "This is a @line #containing "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ " characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
janitor = Janitor(ngram_n=6, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200) )
janitor = Janitor(
ngram_n=6, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200
)
result = janitor.clean_python(sequence) result = janitor.clean_python(sequence)
result = "".join(result) result = "".join(result)
assert result == sequence assert result == sequence
...@@ -272,49 +316,55 @@ def test_janitor4(): ...@@ -272,49 +316,55 @@ def test_janitor4():
result = "".join(result) result = "".join(result)
assert result == expected_result assert result == expected_result
def test_janitor5(): def test_janitor5():
# Same as above but using multiple different filth 6grams. # Same as above but using multiple different filth 6grams.
sequence = "This is a @line #containing a certain number of characters, 76 to be exact. " \ sequence = (
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"FILTH. lots of dirty filtHy FIlTh " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "FILTH. lots of dirty filtHy FIlTh "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"FILTH. lots of filtHy dirty FIlTh " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "FILTH. lots of filtHy dirty FIlTh "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
)
filths = ["filth lots of dirty filthy filth", "filth lots of filthy dirty filth"]
filths = ["filth lots of dirty filthy filth", "filth lots of filthy dirty filth"]
expected_result = "This is a @line #containing a certain number of characters, 76 to be exact. " \
"This is a @line #containing a certain number of characters, 76 to be exact. " \ expected_result = (
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
" characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ " characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
janitor = Janitor(ngram_n=6, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200) "This is a @line #containing a certain number of characters, 76 to be exact. "
)
janitor = Janitor(
ngram_n=6, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200
)
result = janitor.clean_python(sequence) result = janitor.clean_python(sequence)
result = "".join(result) result = "".join(result)
assert result == sequence assert result == sequence
for filth in filths: for filth in filths:
janitor.register_contaminant(filth) janitor.register_contaminant(filth)
assert janitor.dirt_ngrams == set(filths) assert janitor.dirt_ngrams == set(filths)
...@@ -322,57 +372,63 @@ def test_janitor5(): ...@@ -322,57 +372,63 @@ def test_janitor5():
result = "".join(result) result = "".join(result)
assert result == expected_result assert result == expected_result
def test_janitor6(): def test_janitor6():
# Same as above but now we add 10 filths and expect the same result, the following test does 11. # Same as above but now we add 10 filths and expect the same result, the following test does 11.
sequence = "This is a @line #containing a certain number of characters, 76 to be exact. " \ sequence = (
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"FILTH. lots of dirty filtHy FIlTh " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"FILTH. lots of dirty filtHy FIlTh " \ "FILTH. lots of dirty filtHy FIlTh "
"FILTH. lots of dirty filtHy FIlTh " \ "FILTH. lots of dirty filtHy FIlTh "
"FILTH. lots of dirty filtHy FIlTh " \ "FILTH. lots of dirty filtHy FIlTh "
"FILTH. lots of dirty filtHy FIlTh " \ "FILTH. lots of dirty filtHy FIlTh "
"FILTH. lots of dirty filtHy FIlTh " \ "FILTH. lots of dirty filtHy FIlTh "
"FILTH. lots of dirty filtHy FIlTh " \ "FILTH. lots of dirty filtHy FIlTh "
"FILTH. lots of dirty filtHy FIlTh " \ "FILTH. lots of dirty filtHy FIlTh "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "FILTH. lots of dirty filtHy FIlTh "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"FILTH. lots of filtHy dirty FIlTh " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"FILTH. lots of filtHy dirty FIlTh " \ "FILTH. lots of filtHy dirty FIlTh "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "FILTH. lots of filtHy dirty FIlTh "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
)
filths = ["filth lots of dirty filthy filth", "filth lots of filthy dirty filth"]
filths = ["filth lots of dirty filthy filth", "filth lots of filthy dirty filth"]
expected_result = "This is a @line #containing a certain number of characters, 76 to be exact. " \
"This is a @line #containing a certain number of characters, 76 to be exact. " \ expected_result = (
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
" characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ " characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
janitor = Janitor(ngram_n=6, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200) "This is a @line #containing a certain number of characters, 76 to be exact. "
)
janitor = Janitor(
ngram_n=6, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200
)
result = janitor.clean_python(sequence) result = janitor.clean_python(sequence)
result = "".join(result) result = "".join(result)
assert result == sequence assert result == sequence
for filth in filths: for filth in filths:
janitor.register_contaminant(filth) janitor.register_contaminant(filth)
assert janitor.dirt_ngrams == set(filths) assert janitor.dirt_ngrams == set(filths)
...@@ -380,51 +436,55 @@ def test_janitor6(): ...@@ -380,51 +436,55 @@ def test_janitor6():
result = "".join(result) result = "".join(result)
assert result == expected_result assert result == expected_result
def test_janitor7(): def test_janitor7():
# Same as above but now we add 9 filths and expect the same result, the following test does 10. # Same as above but now we add 9 filths and expect the same result, the following test does 10.
sequence = "This is a @line #containing a certain number of characters, 76 to be exact. " \ sequence = (
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"FILTH. lots of dirty filtHy FIlTh " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"FILTH. lots of dirty filtHy FIlTh " \ "FILTH. lots of dirty filtHy FIlTh "
"FILTH. lots of dirty filtHy FIlTh " \ "FILTH. lots of dirty filtHy FIlTh "
"FILTH. lots of dirty filtHy FIlTh " \ "FILTH. lots of dirty filtHy FIlTh "
"FILTH. lots of dirty filtHy FIlTh " \ "FILTH. lots of dirty filtHy FIlTh "
"FILTH. lots of dirty filtHy FIlTh " \ "FILTH. lots of dirty filtHy FIlTh "
"FILTH. lots of dirty filtHy FIlTh " \ "FILTH. lots of dirty filtHy FIlTh "
"FILTH. lots of dirty filtHy FIlTh " \ "FILTH. lots of dirty filtHy FIlTh "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "FILTH. lots of dirty filtHy FIlTh "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"FILTH. lots of filtHy dirty FIlTh " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"FILTH. lots of filtHy dirty FIlTh " \ "FILTH. lots of filtHy dirty FIlTh "
"FILTH. lots of filtHy dirty FIlTh " \ "FILTH. lots of filtHy dirty FIlTh "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "FILTH. lots of filtHy dirty FIlTh "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
)
filths = ["filth lots of dirty filthy filth", "filth lots of filthy dirty filth"]
filths = ["filth lots of dirty filthy filth", "filth lots of filthy dirty filth"]
expected_result = "" expected_result = ""
janitor = Janitor(ngram_n=6, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200) janitor = Janitor(
ngram_n=6, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200
)
result = janitor.clean_python(sequence) result = janitor.clean_python(sequence)
result = "".join(result) result = "".join(result)
assert result == sequence assert result == sequence
for filth in filths: for filth in filths:
janitor.register_contaminant(filth) janitor.register_contaminant(filth)
assert janitor.dirt_ngrams == set(filths) assert janitor.dirt_ngrams == set(filths)
...@@ -453,23 +513,3 @@ def test_janitor8(): ...@@ -453,23 +513,3 @@ def test_janitor8():
# cleaned = " ".join(jan.clean(source)) # cleaned = " ".join(jan.clean(source))
# for contam in jan.dirt_ngrams: # for contam in jan.dirt_ngrams:
# assert contam not in cleaned, contam # assert contam not in cleaned, contam
import hashlib
import json
import openai
import os
import pickle
import pytest import pytest
import unittest.mock as mock import unittest.mock as mock
import lm_eval.models as models import lm_eval.models as models
# (context, continuation) pairs fed to `loglikelihood` by the model tests
# below (GPT-2 / GPT-3 / TextSynth). The first five share one context and
# are unpacked individually by the tests; the rest are compared against
# cached per-model loglikelihood targets.
LOGLIKELIHOOD_TEST_CASES = [
    ("The quick brown fox jumps over the lazy", " dog"),
    ("The quick brown fox jumps over the lazy", " cat"),
    ("The quick brown fox jumps over the lazy", ", lazy dog"),
    ("The quick brown fox jumps over the lazy", ", lazy fox"),
    (
        "The quick brown fox jumps over the lazy",
        ", lazy fox and they both fall to the ground",
    ),
    (
        """A mult""",
        """ilayer perceptron (MLP) is a class of feedforward artificial neural network (ANN)""",
    ),
    (
        """The term MLP is used ambiguously, sometimes loosely to any feedforward ANN, sometimes strictly to refer to networks composed of multiple layers of perceptrons""",
        """ (with threshold activation); see § Terminology""",
    ),
    (
        """Multilayer perceptrons are sometimes coll""",
        """oquially referred to as "vanilla" neural networks, especially when they have a single hidden layer.[1]""",
    ),
    (
        """An MLP consists of at least three layers of nodes: an input layer, a hidden layer and an output layer. Except for the input nodes, each node is a neuron that uses a nonlinear""",
        """ activation function.""",
    ),
    (
        """MLP utilizes a supervised""",
        """ learning technique called backpropagation for training.[2][3] Its multiple layers and non-linear activation distinguish MLP from a linear perceptron. It can distinguish data that is not linearly separable.[4]""",
    ),
    (
        """Recent work has demonstrated substantial gains on many NLP tasks and benchmarks by pre-training on a large corpus of text followed by fine-tuning on a specific task. While typically task-agnostic""",
        """ in architecture, this method still requires task-specific fine-tuning datasets of thousands or tens of thousands of examples. By contrast, humans can generally perform a new language task from only a few examples or from simple instructions - something which current NLP systems still largely struggle to do. Here we show that scaling up language models greatly improves task-agnostic, few-shot performance, sometimes even reaching competitiveness with prior state-of-the-art fine-tuning approaches. """,
    ),
    (
        """Specifically, we train GPT-3, an autoregressive language model with 175""",
        """ billion parameters, 10x more than any previous non-sparse language model, and test its performance in the few-shot setting. For all tasks, GPT-3 is applied without any gradient updates or fine-tuning, with tasks and few-shot demonstrations specified purely via text interaction with the model. GPT-3 achieves strong performance on many NLP datasets, including translation, question-answering, and cloze tasks, as well as several tasks that require on-the-fly reasoning or domain adaptation, such as unscrambling words, using a novel word in a sentence, or performing 3-digit arithmetic. At the same time, we also identify some datasets where GPT-3's few-shot learning still struggles, as well as some datasets where GPT-3 faces methodological issues related to training on large web corpora. Finally, we find that GPT-3 can generate samples of news articles which human evaluators have difficulty distinguishing from articles written by humans. We discuss broader societal impacts of this finding and of GPT-3 in general.""",
    ),
    (
        """A mult""",
        """ilayer perceptron (MLP) is a class of feedforward artificial neural network (ANN)""",
    ),
    ("""Hello""", """ World"""),
]
# Test HuggingFace Models (GPT-2)
def test_gpt2(): def test_gpt2():
gpt2 = models.get_model('gpt2').create_from_arg_string("device=cpu") gpt2 = models.get_model("gpt2").create_from_arg_string("device=cpu")
(ll_dog, ig_dog), (ll_cat, ig_cat), (_, ll_max_0), (_, ll_max_1), (_, ll_max_2), *vals = gpt2.loglikelihood([ (
('The quick brown fox jumps over the lazy', ' dog'), (ll_dog, ig_dog),
('The quick brown fox jumps over the lazy', ' cat'), (ll_cat, ig_cat),
('The quick brown fox jumps over the lazy', ', lazy dog'), (_, ll_max_0),
('The quick brown fox jumps over the lazy', ', lazy fox'), (_, ll_max_1),
('The quick brown fox jumps over the lazy', ', lazy fox and they both fall to the ground'), (_, ll_max_2),
*vals,
("""A mult""", """ilayer perceptron (MLP) is a class of feedforward artificial neural network (ANN)"""), ) = gpt2.loglikelihood(LOGLIKELIHOOD_TEST_CASES)
("""The term MLP is used ambiguously, sometimes loosely to any feedforward ANN, sometimes strictly to refer to networks composed of multiple layers of perceptrons""", """ (with threshold activation); see § Terminology"""),
("""Multilayer perceptrons are sometimes coll""", """oquially referred to as "vanilla" neural networks, especially when they have a single hidden layer.[1]"""),
("""An MLP consists of at least three layers of nodes: an input layer, a hidden layer and an output layer. Except for the input nodes, each node is a neuron that uses a nonlinear""", """ activation function."""),
("""MLP utilizes a supervised""", """ learning technique called backpropagation for training.[2][3] Its multiple layers and non-linear activation distinguish MLP from a linear perceptron. It can distinguish data that is not linearly separable.[4]"""),
("""Recent work has demonstrated substantial gains on many NLP tasks and benchmarks by pre-training on a large corpus of text followed by fine-tuning on a specific task. While typically task-agnostic""", """ in architecture, this method still requires task-specific fine-tuning datasets of thousands or tens of thousands of examples. By contrast, humans can generally perform a new language task from only a few examples or from simple instructions - something which current NLP systems still largely struggle to do. Here we show that scaling up language models greatly improves task-agnostic, few-shot performance, sometimes even reaching competitiveness with prior state-of-the-art fine-tuning approaches. """),
("""Specifically, we train GPT-3, an autoregressive language model with 175""", """ billion parameters, 10x more than any previous non-sparse language model, and test its performance in the few-shot setting. For all tasks, GPT-3 is applied without any gradient updates or fine-tuning, with tasks and few-shot demonstrations specified purely via text interaction with the model. GPT-3 achieves strong performance on many NLP datasets, including translation, question-answering, and cloze tasks, as well as several tasks that require on-the-fly reasoning or domain adaptation, such as unscrambling words, using a novel word in a sentence, or performing 3-digit arithmetic. At the same time, we also identify some datasets where GPT-3's few-shot learning still struggles, as well as some datasets where GPT-3 faces methodological issues related to training on large web corpora. Finally, we find that GPT-3 can generate samples of news articles which human evaluators have difficulty distinguishing from articles written by humans. We discuss broader societal impacts of this finding and of GPT-3 in general."""),
("""A mult""", """ilayer perceptron (MLP) is a class of feedforward artificial neural network (ANN)"""),
("""Hello""", """ World"""),
])
assert ll_dog > ll_cat assert ll_dog > ll_cat
assert not ig_cat assert not ig_cat
...@@ -31,17 +76,24 @@ def test_gpt2(): ...@@ -31,17 +76,24 @@ def test_gpt2():
assert ll_max_2 assert ll_max_2
# test empty context # test empty context
gpt2.loglikelihood([('', 'test')]) gpt2.loglikelihood([("", "test")])
gen, = gpt2.greedy_until([ (gen,) = gpt2.greedy_until(
('The quick brown fox jumps over the lazy', ['.', '\n']) [("The quick brown fox jumps over the lazy", [".", "\n"])]
]) )
assert gen == ', lazy fox and they both fall to the ground' assert gen == ", lazy fox and they both fall to the ground"
targets = [ targets = [
-61.60536193847656, -56.57843780517578, -62.131004333496094, -9.799489974975586, -153.96334838867188, -61.60536193847656,
-341.222900390625, -731.1475830078125, -61.60536193847656, -8.682319641113281 -56.57843780517578,
-62.131004333496094,
-9.799489974975586,
-153.96334838867188,
-341.222900390625,
-731.1475830078125,
-61.60536193847656,
-8.682319641113281,
] ]
for (pred, _), tgt in zip(vals, targets): for (pred, _), tgt in zip(vals, targets):
...@@ -49,21 +101,224 @@ def test_gpt2(): ...@@ -49,21 +101,224 @@ def test_gpt2():
def test_gpt2_perplexity(): def test_gpt2_perplexity():
gpt2 = models.get_model('gpt2').create_from_arg_string("device=cpu") gpt2 = models.get_model("gpt2").create_from_arg_string("device=cpu")
test_string = "We study empirical scaling laws for language model performance on the cross-entropy loss." test_string = "We study empirical scaling laws for language model performance on the cross-entropy loss."
perplexity = gpt2.loglikelihood_rolling([(test_string,)])[0] perplexity = gpt2.loglikelihood_rolling([(test_string,)])[0]
tgt = sum([ tgt = sum(
-4.9599953, -8.069298, -8.308624, -10.178513, -8.906924, -1.9318912, -7.745445, -7.146077, -5.2072, [
-3.5882986, -1.9957212, -8.044922, -0.20841774, -5.1096807, -0.099879116, -8.888423, -4.6180487, -4.9599953,
]) -8.069298,
-8.308624,
-10.178513,
-8.906924,
-1.9318912,
-7.745445,
-7.146077,
-5.2072,
-3.5882986,
-1.9957212,
-8.044922,
-0.20841774,
-5.1096807,
-0.099879116,
-8.888423,
-4.6180487,
]
)
assert perplexity == pytest.approx(tgt, rel=1e-3) assert perplexity == pytest.approx(tgt, rel=1e-3)
with mock.patch.object(models.gpt2.HFLM, 'max_length', new_callable=mock.PropertyMock) as mock_max_length: with mock.patch.object(
models.gpt2.HFLM, "max_length", new_callable=mock.PropertyMock
) as mock_max_length:
mock_max_length.return_value = 5 mock_max_length.return_value = 5
gpt2 = models.get_model('gpt2').create_from_arg_string("device=cpu") gpt2 = models.get_model("gpt2").create_from_arg_string("device=cpu")
perplexity = gpt2.loglikelihood_rolling([(test_string,)])[0] perplexity = gpt2.loglikelihood_rolling([(test_string,)])[0]
tgt = sum([ tgt = sum(
-4.96001, -8.069275, -8.308612, -10.178482, -8.90691, -4.037338, -8.09261, -11.662385, -10.206891, [
-4.425003, -2.2563353, -7.909143, -1.9304147, -7.3610134, -2.3120654, -7.3229, -2.1643813, -4.96001,
]) -8.069275,
-8.308612,
-10.178482,
-8.90691,
-4.037338,
-8.09261,
-11.662385,
-10.206891,
-4.425003,
-2.2563353,
-7.909143,
-1.9304147,
-7.3610134,
-2.3120654,
-7.3229,
-2.1643813,
]
)
assert perplexity == pytest.approx(tgt, rel=1e-3) assert perplexity == pytest.approx(tgt, rel=1e-3)
# Test OpenAI Models (GPT-3)
def openai_mock_completion(**kwargs):
    """Mock replacement for the OpenAI completion call.

    Loads a cached+pickled response when one exists for these exact
    kwargs; otherwise performs the real API call and caches the result
    (with the API key scrubbed) so subsequent runs are offline.
    """
    os.makedirs("tests/testdata", exist_ok=True)
    # Cache key is a deterministic digest of the request arguments.
    # Named `digest` (not `hash`) to avoid shadowing the builtin.
    digest = hashlib.sha256(
        json.dumps(kwargs, sort_keys=True).encode("utf-8")
    ).hexdigest()
    fname = f"tests/testdata/gpt3_test_{digest}.pkl"
    if os.path.exists(fname):
        with open(fname, "rb") as fh:
            return pickle.load(fh)
    ret = openai.Completion.create(**kwargs)
    # Never persist the API key inside the pickled fixture.
    ret.api_key = ""
    with open(fname, "wb") as fh:
        pickle.dump(ret, fh)
    return ret
@mock.patch("lm_eval.models.gpt3.oa_completion", new=openai_mock_completion)
def test_gpt3():
    """Check GPT-3 loglikelihoods, greedy generation, and cached target scores."""
    # The wrapper requires some key to be set; the mocked completion makes
    # its actual value irrelevant.
    if "OPENAI_API_SECRET_KEY" not in os.environ:
        os.environ["OPENAI_API_SECRET_KEY"] = ""
    gpt3 = models.get_model("gpt3").create_from_arg_string("engine=ada")

    results = list(gpt3.loglikelihood(LOGLIKELIHOOD_TEST_CASES))
    (ll_dog, ig_dog), (ll_cat, ig_cat) = results[:2]
    greedy_flags = [flag for _, flag in results[2:5]]
    vals = results[5:]

    assert ll_dog > ll_cat
    assert not ig_cat
    assert ig_dog
    # None of the ", lazy ..." continuations should be the greedy completion.
    assert not any(greedy_flags)

    # test empty context
    gpt3.loglikelihood([("", "test")])

    (completion,) = gpt3.greedy_until(
        [("The quick brown fox jumps over the lazy", [".", "\n"])]
    )
    assert completion == " dog"

    print([pred for pred, _ in vals])
    expected_lls = [
        -34.848301606999996,
        -47.148329679999996,
        -45.44380149599999,
        -5.285246016,
        -133.97821690686004,
        -321.2616693239001,
        -658.0299524401041,
        -34.848301606999996,
        -7.525115,
    ]
    for (pred, _), expected in zip(vals, expected_lls):
        assert pred == pytest.approx(expected, rel=1e-3)
@mock.patch("lm_eval.models.gpt3.oa_completion", new=openai_mock_completion)
def test_gpt3_perplexity():
    """Check rolling-window loglikelihood, with and without a tiny context limit."""
    if "OPENAI_API_SECRET_KEY" not in os.environ:
        os.environ["OPENAI_API_SECRET_KEY"] = ""
    text = "We study empirical scaling laws for language model performance on the cross-entropy loss."

    gpt3 = models.get_model("gpt3").create_from_arg_string("engine=ada")
    assert gpt3.loglikelihood_rolling([(text,)])[0] == pytest.approx(
        -84.38819608, rel=1e-3
    )

    # Hack: shrink the model's context window so the same text has to be
    # scored across several rolling windows instead of a single pass.
    with mock.patch.object(
        models.gpt3.GPT3LM, "max_length", new_callable=mock.PropertyMock
    ) as mock_max_length:
        mock_max_length.return_value = 5
        gpt3 = models.get_model("gpt3").create_from_arg_string("engine=ada")
        assert gpt3.loglikelihood_rolling([(text,)])[0] == pytest.approx(
            -101.81967209999999, rel=1e-3
        )
# Test TextSynth Models (GPT-J)
def textsynth_mock_completion(**kwargs):
    """Mock replacement for the TextSynth completion call.

    Loads a cached+pickled response when one exists for these kwargs
    (auth headers excluded from the cache key); otherwise performs the
    real HTTP POST and caches the result for subsequent offline runs.
    """
    os.makedirs("tests/testdata", exist_ok=True)
    # Exclude headers (they carry the API key) from the cache key so the
    # digest is stable across keys and the key never influences caching.
    hash_kwargs = {k: v for k, v in kwargs.items() if k != "headers"}
    # Named `digest` (not `hash`) to avoid shadowing the builtin.
    digest = hashlib.sha256(
        json.dumps(hash_kwargs, sort_keys=True).encode("utf-8")
    ).hexdigest()
    fname = f"tests/testdata/textsynth_test_{digest}.pkl"
    if os.path.exists(fname):
        with open(fname, "rb") as fh:
            return pickle.load(fh)
    # Defer the third-party import until we actually need to hit the API,
    # so cache hits work even without `requests` installed.
    import requests

    ret = requests.post(**kwargs)
    with open(fname, "wb") as fh:
        pickle.dump(ret, fh)
    return ret
@mock.patch(
    "lm_eval.models.textsynth.textsynth_completion", new=textsynth_mock_completion
)
def test_textsynth():
    """Check the TextSynth (GPT-J) binding against cached API responses."""
    if "TEXTSYNTH_API_SECRET_KEY" not in os.environ:
        os.environ["TEXTSYNTH_API_SECRET_KEY"] = ""
    lm = models.get_model("textsynth").create_from_arg_string("engine=gptj_6B")

    results = list(lm.loglikelihood(LOGLIKELIHOOD_TEST_CASES))
    (ll_dog, ig_dog), (ll_cat, ig_cat) = results[:2]
    greedy_flags = [flag for _, flag in results[2:5]]
    vals = results[5:]

    assert ll_dog > ll_cat
    assert not ig_cat
    assert ig_dog
    # None of the ", lazy ..." continuations should be the greedy completion.
    assert not any(greedy_flags)

    # test empty context
    lm.loglikelihood([("", "test")])

    (completion,) = lm.greedy_until(
        [("The quick brown fox jumps over the lazy", [".", "\n"])]
    )
    assert completion == " dog"

    print([pred for pred, _ in vals])
    expected_lls = [
        -17.90513712817,
        -41.83518912287,
        -33.82445643841,
        -2.377361565302,
        -99.53018069754,
        -243.5642283598,
        -528.6862613790,
        -17.90513712817,
        -5.041000672142,
    ]
    for (pred, _), expected in zip(vals, expected_lls):
        assert pred == pytest.approx(expected, rel=1e-3)
...@@ -6,11 +6,8 @@ from itertools import islice ...@@ -6,11 +6,8 @@ from itertools import islice
@pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items()) @pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
def test_basic_interface(taskname, task_class): def test_basic_interface(taskname, task_class):
print('Evaluating task', taskname) print("Evaluating task", taskname)
# dl = task_class.download
# task_class.download = MagicMock()
task = task_class() task = task_class()
# task_class.download = dl
assert task.has_training_docs() in [True, False] assert task.has_training_docs() in [True, False]
assert task.has_validation_docs() in [True, False] assert task.has_validation_docs() in [True, False]
...@@ -42,7 +39,7 @@ def test_basic_interface(taskname, task_class): ...@@ -42,7 +39,7 @@ def test_basic_interface(taskname, task_class):
reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr] reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2] reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
assert reqs == reqs2 assert reqs == reqs2
if task.has_test_docs(): if task.has_test_docs():
...@@ -53,7 +50,7 @@ def test_basic_interface(taskname, task_class): ...@@ -53,7 +50,7 @@ def test_basic_interface(taskname, task_class):
reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr] reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2] reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
assert reqs == reqs2 assert reqs == reqs2
if task.has_training_docs(): if task.has_training_docs():
...@@ -64,13 +61,13 @@ def test_basic_interface(taskname, task_class): ...@@ -64,13 +61,13 @@ def test_basic_interface(taskname, task_class):
reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr] reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2] reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
assert reqs == reqs2 assert reqs == reqs2
@pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items()) @pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
def test_documents_and_requests(taskname, task_class): def test_documents_and_requests(taskname, task_class):
print('Evaluating task', taskname) print("Evaluating task", taskname)
task = task_class() task = task_class()
fns = [] fns = []
if task.has_training_docs(): if task.has_training_docs():
...@@ -83,21 +80,21 @@ def test_documents_and_requests(taskname, task_class): ...@@ -83,21 +80,21 @@ def test_documents_and_requests(taskname, task_class):
for fn in fns: for fn in fns:
# print(list(islice(fn(), 10))) # print(list(islice(fn(), 10)))
for doc in islice(fn(), 10): for doc in islice(fn(), 10):
txt = task.doc_to_text(doc) txt = task.doc_to_text(doc)
tgt = task.doc_to_target(doc) tgt = task.doc_to_target(doc)
assert isinstance(txt, str) assert isinstance(txt, str)
assert isinstance(tgt, str) assert isinstance(tgt, str)
# space convention # space convention
# allow txt to have length 0 for perplexity-like tasks since the model tacks an <|endoftext|> on # allow txt to have length 0 for perplexity-like tasks since the model tacks an <|endoftext|> on
if len(txt) != 0: if len(txt) != 0:
assert txt[-1] != ' ' assert txt[-1] != " "
assert tgt[0] == ' ' or txt[-1] == '\n' assert tgt[0] == " " or txt[-1] == "\n"
reqs = task.construct_requests(doc, txt) reqs = task.construct_requests(doc, txt)
# construct_requests can return just one request # construct_requests can return just one request
if not isinstance(reqs, (list, tuple)): if not isinstance(reqs, (list, tuple)):
reqs = [reqs] reqs = [reqs]
......
...@@ -5,8 +5,14 @@ from lm_eval.utils import get_rolling_token_windows, make_disjoint_window ...@@ -5,8 +5,14 @@ from lm_eval.utils import get_rolling_token_windows, make_disjoint_window
def test_get_rolling_token_windows_v1(): def test_get_rolling_token_windows_v1():
gold = [ gold = [
([-100, 0, 1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), ([-100, 0, 1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
([9, 10, 11, 12, 13, 14, 15, 16, 17, 18], [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]), (
([19, 20, 21, 22, 23, 24, 25, 26, 27, 28], [20, 21, 22, 23, 24, 25, 26, 27, 28, 29]), [9, 10, 11, 12, 13, 14, 15, 16, 17, 18],
[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
),
(
[19, 20, 21, 22, 23, 24, 25, 26, 27, 28],
[20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
),
([23, 24, 25, 26, 27, 28, 29, 30, 31, 32], [30, 31, 32, 33]), ([23, 24, 25, 26, 27, 28, 29, 30, 31, 32], [30, 31, 32, 33]),
] ]
x = list(range(34)) x = list(range(34))
...@@ -123,7 +129,6 @@ def test_get_rolling_token_windows_v4(): ...@@ -123,7 +129,6 @@ def test_get_rolling_token_windows_v4():
([17, 18, 19, 20, 21, 22, 23, 24, 25, 26], [27]), ([17, 18, 19, 20, 21, 22, 23, 24, 25, 26], [27]),
([18, 19, 20, 21, 22, 23, 24, 25, 26, 27], [28]), ([18, 19, 20, 21, 22, 23, 24, 25, 26, 27], [28]),
([19, 20, 21, 22, 23, 24, 25, 26, 27, 28], [29]), ([19, 20, 21, 22, 23, 24, 25, 26, 27, 28], [29]),
] ]
x = list(range(30)) x = list(range(30))
generator = get_rolling_token_windows( generator = get_rolling_token_windows(
...@@ -145,8 +150,14 @@ def test_get_rolling_token_windows_v4(): ...@@ -145,8 +150,14 @@ def test_get_rolling_token_windows_v4():
def test_get_rolling_token_windows_v5(): def test_get_rolling_token_windows_v5():
gold = [ gold = [
([-100, 0, 1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), ([-100, 0, 1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
([9, 10, 11, 12, 13, 14, 15, 16, 17, 18], [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]), (
([19, 20, 21, 22, 23, 24, 25, 26, 27, 28], [20, 21, 22, 23, 24, 25, 26, 27, 28, 29]), [9, 10, 11, 12, 13, 14, 15, 16, 17, 18],
[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
),
(
[19, 20, 21, 22, 23, 24, 25, 26, 27, 28],
[20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
),
] ]
x = list(range(30)) x = list(range(30))
generator = get_rolling_token_windows( generator = get_rolling_token_windows(
...@@ -203,5 +214,9 @@ def test_get_rolling_token_windows_empty(): ...@@ -203,5 +214,9 @@ def test_get_rolling_token_windows_empty():
def test_make_disjoint_window(): def test_make_disjoint_window():
assert make_disjoint_window(([1,2,3,4,5], [2,3,4,5,6])) == ([1], [2,3,4,5,6]) assert make_disjoint_window(([1, 2, 3, 4, 5], [2, 3, 4, 5, 6])) == (
assert make_disjoint_window(([1,2,3,4,5], [4,5,6])) == ([1,2,3], [4,5,6]) [1],
\ No newline at end of file [2, 3, 4, 5, 6],
)
assert make_disjoint_window(([1, 2, 3, 4, 5], [4, 5, 6])) == ([1, 2, 3], [4, 5, 6])
assert make_disjoint_window(([1, 2, 3, 4, 5], [6])) == ([1, 2, 3, 4, 5], [6])
...@@ -16,13 +16,14 @@ def assert_target(name, ob): ...@@ -16,13 +16,14 @@ def assert_target(name, ob):
fname = f"tests/testdata/{name}.json" fname = f"tests/testdata/{name}.json"
if os.path.exists(fname): if os.path.exists(fname):
with open(fname) as fh: with open(fname) as fh:
# Use relative tolerance of 1e-5 and absolute tolerance of 1e-8 # Use relative tolerance of 1e-5 and absolute tolerance of 1e-8
# assuming most metrics work on `float32` values, which is the common # assuming most metrics work on `float32` values, which is the common
# default floating type across popular libraries (PyTorch, Tensorflow, and JAX). # default floating type across popular libraries (PyTorch, Tensorflow, and JAX).
assert flatten(json.load(fh)) == pytest.approx( assert flatten(json.load(fh)) == pytest.approx(
flatten(json.loads(json.dumps(ob, sort_keys=True))), rel=1e-5, abs=1e-8) flatten(json.loads(json.dumps(ob, sort_keys=True))), rel=1e-5, abs=1e-8
)
else: else:
with open(fname, 'w') as fh: with open(fname, "w") as fh:
json.dump(ob, fh, sort_keys=True) json.dump(ob, fh, sort_keys=True)
...@@ -30,41 +31,52 @@ def assert_target_hashed(name, ob): ...@@ -30,41 +31,52 @@ def assert_target_hashed(name, ob):
fname = f"tests/testdata/{name}" fname = f"tests/testdata/{name}"
if os.path.exists(fname): if os.path.exists(fname):
with open(fname) as fh: with open(fname) as fh:
assert fh.read() == hashlib.sha256(json.dumps(ob, sort_keys=True).encode('utf-8')).hexdigest() assert (
fh.read()
== hashlib.sha256(
json.dumps(ob, sort_keys=True).encode("utf-8")
).hexdigest()
)
else: else:
with open(fname, 'w') as fh: with open(fname, "w") as fh:
fh.write(hashlib.sha256(json.dumps(ob, sort_keys=True).encode('utf-8')).hexdigest()) fh.write(
hashlib.sha256(
json.dumps(ob, sort_keys=True).encode("utf-8")
).hexdigest()
)
# from https://stackoverflow.com/a/6027615 # from https://stackoverflow.com/a/6027615
def flatten(d, parent_key='', sep='.'): def flatten(d, parent_key="", sep="."):
items = [] items = []
for k, v in d.items(): for k, v in d.items():
new_key = parent_key + sep + k if parent_key else k new_key = parent_key + sep + k if parent_key else k
if isinstance(v, collections.MutableMapping): if isinstance(v, collections.abc.MutableMapping):
items.extend(flatten(v, new_key, sep=sep).items()) items.extend(flatten(v, new_key, sep=sep).items())
else: else:
items.append((new_key, v)) items.append((new_key, v))
return dict(items) return dict(items)
# make sure eval results for a task version are stable # make sure eval results for a task version are stable
@pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items()) @pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
def test_versions_stable(taskname, task_class): def test_versions_stable(taskname, task_class):
task_dict = tasks.get_task_dict([taskname]) task_dict = tasks.get_task_dict([taskname])
lm = models.get_model('dummy')() lm = models.get_model("dummy")()
def ll_fn(reqs): def ll_fn(reqs):
for ctx, cont in reqs: for ctx, cont in reqs:
if len(ctx) == 0: if len(ctx) == 0:
continue continue
# space convention # space convention
assert ctx[-1] != ' ' assert ctx[-1] != " "
assert cont[0] == ' ' or ctx[-1] == '\n' assert cont[0] == " " or ctx[-1] == "\n"
assert_target_hashed(f"{taskname}-v{task_class.VERSION}-loglikelihood", reqs) assert_target_hashed(f"{taskname}-v{task_class.VERSION}-loglikelihood", reqs)
res = [] res = []
random.seed(42) random.seed(42)
for _ in reqs: for _ in reqs:
res.append((-random.random(), False)) res.append((-random.random(), False))
...@@ -72,10 +84,12 @@ def test_versions_stable(taskname, task_class): ...@@ -72,10 +84,12 @@ def test_versions_stable(taskname, task_class):
return res return res
def ll_perp_fn(reqs): def ll_perp_fn(reqs):
for string, in reqs: for (string,) in reqs:
assert isinstance(string, str) assert isinstance(string, str)
assert_target_hashed(f"{taskname}-v{task_class.VERSION}-loglikelihood_rolling", reqs) assert_target_hashed(
f"{taskname}-v{task_class.VERSION}-loglikelihood_rolling", reqs
)
res = [] res = []
random.seed(42) random.seed(42)
...@@ -83,14 +97,14 @@ def test_versions_stable(taskname, task_class): ...@@ -83,14 +97,14 @@ def test_versions_stable(taskname, task_class):
res.append(-random.random()) res.append(-random.random())
return res return res
def greedy_until(reqs): def greedy_until(reqs):
res = [] res = []
assert_target_hashed(f"{taskname}-v{task_class.VERSION}-greedy_until", reqs) assert_target_hashed(f"{taskname}-v{task_class.VERSION}-greedy_until", reqs)
for ctx, _ in reqs: for ctx, _ in reqs:
res.append("lol") res.append("lol")
assert ctx.strip() != '' assert ctx.strip() != ""
return res return res
...@@ -100,12 +114,12 @@ def test_versions_stable(taskname, task_class): ...@@ -100,12 +114,12 @@ def test_versions_stable(taskname, task_class):
limit = None limit = None
result = evaluator.evaluate( result = evaluator.evaluate(
lm=lm, lm=lm,
task_dict=task_dict, task_dict=task_dict,
num_fewshot=0, num_fewshot=0,
limit=limit, limit=limit,
bootstrap_iters=10, bootstrap_iters=10,
description_dict=None description_dict=None,
) )
assert_target(f"{taskname}-v{task_class.VERSION}-res", result) assert_target(f"{taskname}-v{task_class.VERSION}-res", result)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment