"...gaoqiong/lm-evaluation-harness.git" did not exist on "37a6dbe2bfbb2c37e53b6f5f9a68b08d946bc437"
Commit 18c0fa29 authored by cardy20's avatar cardy20
Browse files

conflict solved

parents 09915adf 0542d35d
...@@ -10,23 +10,24 @@ import pytest ...@@ -10,23 +10,24 @@ import pytest
# TODO: more fine grained unit tests rather than this big honking integration # TODO: more fine grained unit tests rather than this big honking integration
# test once we break evaluator into smaller, more manageable pieces # test once we break evaluator into smaller, more manageable pieces
@pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items()) @pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
def test_evaluator(taskname, task_class): def test_evaluator(taskname, task_class):
task_dict = tasks.get_task_dict([taskname]) task_dict = tasks.get_task_dict([taskname])
os.system("rm test_cache.db") os.system("rm test_cache.db")
lm = base.CachingLM(models.get_model('dummy')(), "test_cache.db") lm = base.CachingLM(models.get_model("dummy")(), "test_cache.db")
def ll_fn(reqs): def ll_fn(reqs):
for ctx, cont in reqs: for ctx, cont in reqs:
if len(ctx) == 0: if len(ctx) == 0:
continue continue
# space convention # space convention
assert ctx[-1] != ' ' assert ctx[-1] != " "
assert cont[0] == ' ' or ctx[-1] == '\n' assert cont[0] == " " or ctx[-1] == "\n"
res = [] res = []
random.seed(42) random.seed(42)
for _ in reqs: for _ in reqs:
res.append((-random.random(), False)) res.append((-random.random(), False))
...@@ -34,7 +35,7 @@ def test_evaluator(taskname, task_class): ...@@ -34,7 +35,7 @@ def test_evaluator(taskname, task_class):
return res return res
def ll_perp_fn(reqs): def ll_perp_fn(reqs):
for string, in reqs: for (string,) in reqs:
assert isinstance(string, str) assert isinstance(string, str)
res = [] res = []
...@@ -49,20 +50,20 @@ def test_evaluator(taskname, task_class): ...@@ -49,20 +50,20 @@ def test_evaluator(taskname, task_class):
limit = 10 limit = 10
e1 = evaluator.evaluate( e1 = evaluator.evaluate(
lm=lm, lm=lm,
task_dict=task_dict, task_dict=task_dict,
num_fewshot=0, num_fewshot=0,
limit=limit, limit=limit,
bootstrap_iters=10, bootstrap_iters=10,
description_dict=None description_dict=None,
) )
e2 = evaluator.evaluate( e2 = evaluator.evaluate(
lm=lm, lm=lm,
task_dict=task_dict, task_dict=task_dict,
num_fewshot=0, num_fewshot=0,
limit=limit, limit=limit,
bootstrap_iters=10, bootstrap_iters=10,
description_dict=None description_dict=None,
) )
# check that caching is working # check that caching is working
......
...@@ -3,27 +3,32 @@ from collections import Counter ...@@ -3,27 +3,32 @@ from collections import Counter
import shutil import shutil
import glob import glob
from scripts.clean_training_data.janitor import * from lm_eval.decontamination.janitor import Janitor, word_ngrams
from scripts.clean_training_data.generate_13_grams import do_ngrams_in_buckets from scripts.clean_training_data.generate_13_grams import do_ngrams_in_buckets
from scripts.clean_training_data.archiver import Archive, TextReader from lm_eval.decontamination.archiver import Archive, TextReader
import logging
def test_generate_13_grams_1(): logger = logging.getLogger(__name__)
data = """A goose (plural geese) is a bird of any of several waterfowl species in the family Anatidae.
This group comprises the genera Anser (the grey geese and white geese) and Branta (the black geese).
Some other birds, mostly related to the shelducks, have "goose" as part of their names. def test_generate_13_grams_1(caplog):
More distantly related members of the family Anatidae are swans, most of which are larger data = """A goose (plural geese) is a bird of any of several waterfowl species in the family Anatidae.
than true geese, and ducks, which are smaller. The term "goose" may refer to either a male This group comprises the genera Anser (the grey geese and white geese) and Branta (the black geese).
or female bird, but when paired with "gander", refers specifically to a female one (the latter referring Some other birds, mostly related to the shelducks, have "goose" as part of their names.
to a male). Young birds before fledging are called goslings. The collective noun for a group of More distantly related members of the family Anatidae are swans, most of which are larger
geese on the ground is a gaggle; when in flight, they are called a skein, a team, or a wedge; when than true geese, and ducks, which are smaller. The term "goose" may refer to either a male
or female bird, but when paired with "gander", refers specifically to a female one (the latter referring
to a male). Young birds before fledging are called goslings. The collective noun for a group of
geese on the ground is a gaggle; when in flight, they are called a skein, a team, or a wedge; when
flying close together, they are called a plump.""" flying close together, they are called a plump."""
data = data + data data = data + data
# Simple Generation # Simple Generation
print("simple generation")
n = 13 n = 13
janitor = Janitor() janitor = Janitor()
ngrams = word_ngrams(janitor.normalize_string(data), n) ngrams = word_ngrams(janitor.normalize_string(data), n)
comparison = list(ngrams) comparison = list(ngrams)
comparison_counter = Counter(comparison) comparison_counter = Counter(comparison)
...@@ -31,35 +36,42 @@ def test_generate_13_grams_1(): ...@@ -31,35 +36,42 @@ def test_generate_13_grams_1():
# print(comparison) # print(comparison)
# Generating into buckets # Generating into buckets
print("bucket generation")
test_working_directory = "test_generate_13_grams" test_working_directory = "test_generate_13_grams"
output_directory = os.path.join(test_working_directory, "output")
try: try:
shutil.rmtree(output_directory) shutil.rmtree(test_working_directory)
except FileNotFoundError: except FileNotFoundError:
pass pass
os.makedirs(test_working_directory, exist_ok=True) os.makedirs(test_working_directory)
archive = Archive(os.path.join(test_working_directory, "test.jsonl.zst"))
assert not os.path.exists("pile")
os.makedirs("pile")
archive = Archive(os.path.join("pile", "test.jsonl.zst"))
archive.add_data(data) archive.add_data(data)
archive.commit() archive.commit()
bucket_count = 4 bucket_count = 4
do_ngrams_in_buckets(n, test_working_directory, bucket_count) do_ngrams_in_buckets(n, test_working_directory, bucket_count)
# Rebuild from buckets # Rebuild from buckets
print("rebuild")
rebuilt_ngrams = [] rebuilt_ngrams = []
bucket_file_paths = glob.glob(
bucket_file_paths = glob.glob(os.path.join(test_working_directory, "output", f"*.bkt.txt")) os.path.join(test_working_directory, "output", f"*.bkt.txt")
)
for bucket_file_path in bucket_file_paths: for bucket_file_path in bucket_file_paths:
reader = TextReader(bucket_file_path) reader = TextReader(bucket_file_path)
for line in reader.read(): for line in reader.read():
[ngram, document_id] = line.rsplit(" ", 1) [ngram, document_id] = line.rsplit(" ", 1)
rebuilt_ngrams.append(ngram) rebuilt_ngrams.append(ngram)
# Compare # Compare
print("compare")
result_counter = Counter(rebuilt_ngrams) result_counter = Counter(rebuilt_ngrams)
# print(len(result_counter)) # print(len(result_counter))
# print(len(comparison_counter)) # print(len(comparison_counter))
assert(len(result_counter) == len(comparison_counter)) assert len(result_counter) == len(comparison_counter)
# print(result_counter) # print(result_counter)
# print(comparison_counter) # print(comparison_counter)
assert(comparison_counter == result_counter) assert comparison_counter == result_counter
\ No newline at end of file
import lm_eval.models as models
import pytest
import os
import json
import openai
import mock
import pickle
import hashlib
def mock_completion(**kwargs):
# Mock completion function
# Loads from a cached+pickled response if it exists, otherwise it will actually try to ping
os.makedirs("tests/testdata", exist_ok=True)
hash = hashlib.sha256(json.dumps(kwargs, sort_keys=True).encode('utf-8')).hexdigest()
fname = f"tests/testdata/gpt3_test_{hash}.pkl"
if os.path.exists(fname):
with open(fname, 'rb') as fh:
return pickle.load(fh)
ret = openai.Completion.create(**kwargs)
ret.api_key = ""
with open(fname, 'wb') as fh:
pickle.dump(ret, fh)
return ret
@mock.patch("lm_eval.models.gpt3.oa_completion", new=mock_completion)
def test_gpt3():
if "OPENAI_API_SECRET_KEY" not in os.environ: os.environ["OPENAI_API_SECRET_KEY"] = ""
gpt3 = models.get_model('gpt3').create_from_arg_string("engine=ada")
(ll_dog, ig_dog), (ll_cat, ig_cat), (_, ll_max_0), (_, ll_max_1), (_, ll_max_2), *vals = gpt3.loglikelihood([
('The quick brown fox jumps over the lazy', ' dog'),
('The quick brown fox jumps over the lazy', ' cat'),
('The quick brown fox jumps over the lazy', ', lazy dog'),
('The quick brown fox jumps over the lazy', ', lazy fox'),
('The quick brown fox jumps over the lazy', ', lazy fox and they both fall to the ground'),
("""A mult""", """ilayer perceptron (MLP) is a class of feedforward artificial neural network (ANN)"""),
("""The term MLP is used ambiguously, sometimes loosely to any feedforward ANN, sometimes strictly to refer to networks composed of multiple layers of perceptrons""", """ (with threshold activation); see § Terminology"""),
("""Multilayer perceptrons are sometimes coll""", """oquially referred to as "vanilla" neural networks, especially when they have a single hidden layer.[1]"""),
("""An MLP consists of at least three layers of nodes: an input layer, a hidden layer and an output layer. Except for the input nodes, each node is a neuron that uses a nonlinear""", """ activation function."""),
("""MLP utilizes a supervised""", """ learning technique called backpropagation for training.[2][3] Its multiple layers and non-linear activation distinguish MLP from a linear perceptron. It can distinguish data that is not linearly separable.[4]"""),
("""Recent work has demonstrated substantial gains on many NLP tasks and benchmarks by pre-training on a large corpus of text followed by fine-tuning on a specific task. While typically task-agnostic""", """ in architecture, this method still requires task-specific fine-tuning datasets of thousands or tens of thousands of examples. By contrast, humans can generally perform a new language task from only a few examples or from simple instructions - something which current NLP systems still largely struggle to do. Here we show that scaling up language models greatly improves task-agnostic, few-shot performance, sometimes even reaching competitiveness with prior state-of-the-art fine-tuning approaches. """),
("""Specifically, we train GPT-3, an autoregressive language model with 175""", """ billion parameters, 10x more than any previous non-sparse language model, and test its performance in the few-shot setting. For all tasks, GPT-3 is applied without any gradient updates or fine-tuning, with tasks and few-shot demonstrations specified purely via text interaction with the model. GPT-3 achieves strong performance on many NLP datasets, including translation, question-answering, and cloze tasks, as well as several tasks that require on-the-fly reasoning or domain adaptation, such as unscrambling words, using a novel word in a sentence, or performing 3-digit arithmetic. At the same time, we also identify some datasets where GPT-3's few-shot learning still struggles, as well as some datasets where GPT-3 faces methodological issues related to training on large web corpora. Finally, we find that GPT-3 can generate samples of news articles which human evaluators have difficulty distinguishing from articles written by humans. We discuss broader societal impacts of this finding and of GPT-3 in general."""),
("""A mult""", """ilayer perceptron (MLP) is a class of feedforward artificial neural network (ANN)"""),
("""Hello""", """ World"""),
])
assert ll_dog > ll_cat
assert not ig_cat
assert ig_dog
assert not ll_max_0
assert not ll_max_1
assert not ll_max_2
# test empty context
gpt3.loglikelihood([('', 'test')])
gen, = gpt3.greedy_until([
('The quick brown fox jumps over the lazy', ['.', '\n'])
])
assert gen == ' dog'
print([x[0] for x in vals])
targets = [
-34.848301606999996, -47.148329679999996, -45.44380149599999, -5.285246016, -133.97821690686004,
-321.2616693239001, -658.0299524401041, -34.848301606999996, -7.525115,
]
for (pred, _), tgt in zip(vals, targets):
assert pred == pytest.approx(tgt, rel=1e-3)
@mock.patch("lm_eval.models.gpt3.oa_completion", new=mock_completion)
def test_gpt3_perplexity():
if "OPENAI_API_SECRET_KEY" not in os.environ: os.environ["OPENAI_API_SECRET_KEY"] = ""
gpt3 = models.get_model('gpt3').create_from_arg_string("engine=ada")
test_string = "We study empirical scaling laws for language model performance on the cross-entropy loss."
perplexity = gpt3.loglikelihood_rolling([(test_string,)])[0]
tgt = -84.38819608
assert perplexity == pytest.approx(tgt, rel=1e-3)
# Hack: modify gpt3 to have shorter context length to induce rolling windows
with mock.patch.object(models.gpt3.GPT3LM, 'max_length', new_callable=mock.PropertyMock) as mock_max_length:
mock_max_length.return_value = 5
gpt3 = models.get_model('gpt3').create_from_arg_string("engine=ada")
perplexity = gpt3.loglikelihood_rolling([(test_string,)])[0]
tgt = -101.81967209999999
assert perplexity == pytest.approx(tgt, rel=1e-3)
import re import re
from collections import defaultdict from collections import defaultdict
from scripts.clean_training_data.janitor import * from lm_eval.decontamination.janitor import (
Janitor,
form_ngrams,
word_ngrams,
split_indices,
word_ngrams_indices,
)
def simple_ngram(sequence, n): def simple_ngram(sequence, n):
ngrams = list() ngrams = list()
...@@ -16,8 +23,10 @@ def simple_ngram(sequence, n): ...@@ -16,8 +23,10 @@ def simple_ngram(sequence, n):
def test_form_ngrams(): def test_form_ngrams():
sequence = "Hello my name is Bob, I like eating pizza, chicken, chips and ice cream. Maybe I should eat some" \ sequence = (
" more salad but it's so booooring. I just... like eating pizza, chicken, chips and ice cream so much." "Hello my name is Bob, I like eating pizza, chicken, chips and ice cream. Maybe I should eat some"
" more salad but it's so booooring. I just... like eating pizza, chicken, chips and ice cream so much."
)
n_values = [1, 2, 3, 5, 13] n_values = [1, 2, 3, 5, 13]
for n in n_values: for n in n_values:
...@@ -26,9 +35,12 @@ def test_form_ngrams(): ...@@ -26,9 +35,12 @@ def test_form_ngrams():
assert len(comparison) == len(result_to_test) assert len(comparison) == len(result_to_test)
assert comparison == result_to_test assert comparison == result_to_test
def test_word_ngrams(): def test_word_ngrams():
sequence = "Hello my name is Bob, I like eating pizza, chicken, chips and ice cream. Maybe I should eat some" \ sequence = (
" more salad but it's so booooring. I just... like eating pizza, chicken, chips and ice cream so much." "Hello my name is Bob, I like eating pizza, chicken, chips and ice cream. Maybe I should eat some"
" more salad but it's so booooring. I just... like eating pizza, chicken, chips and ice cream so much."
)
words = sequence.split() words = sequence.split()
...@@ -40,9 +52,12 @@ def test_word_ngrams(): ...@@ -40,9 +52,12 @@ def test_word_ngrams():
assert len(comparison) == len(result_to_test) assert len(comparison) == len(result_to_test)
assert result_to_test == comparison assert result_to_test == comparison
def test_split_indices(): def test_split_indices():
sequence = "Hello my name is Bob, I like eating pizza, chicken, chips and ice cream. Maybe I should eat some" \ sequence = (
" more salad but it's so booooring. I just... like eating pizza, chicken, chips and ice cream so much." "Hello my name is Bob, I like eating pizza, chicken, chips and ice cream. Maybe I should eat some"
" more salad but it's so booooring. I just... like eating pizza, chicken, chips and ice cream so much."
)
comparison = [] comparison = []
current_word = "" current_word = ""
...@@ -55,17 +70,22 @@ def test_split_indices(): ...@@ -55,17 +70,22 @@ def test_split_indices():
current_word = "" current_word = ""
if current_word: if current_word:
comparison.append((current_word, (len(sequence) - len(current_word), len(sequence) - 1))) comparison.append(
current_word = "" (current_word, (len(sequence) - len(current_word), len(sequence) - 1))
)
current_word = ""
result_to_test = list(split_indices(sequence)) result_to_test = list(split_indices(sequence))
assert len(comparison) == len(result_to_test) assert len(comparison) == len(result_to_test)
assert(comparison == result_to_test) assert comparison == result_to_test
def test_word_ngrams_indices(): def test_word_ngrams_indices():
sequence = "Hello my name is Bob, I like eating pizza, chicken, chips and ice cream. Maybe I should eat some" \ sequence = (
" more salad but it's so booooring. I just... like eating pizza, chicken, chips and ice cream so much." "Hello my name is Bob, I like eating pizza, chicken, chips and ice cream. Maybe I should eat some"
" more salad but it's so booooring. I just... like eating pizza, chicken, chips and ice cream so much."
)
n_values = [1, 2, 3, 5, 13] n_values = [1, 2, 3, 5, 13]
...@@ -76,55 +96,62 @@ def test_word_ngrams_indices(): ...@@ -76,55 +96,62 @@ def test_word_ngrams_indices():
for ngram in ngrams: for ngram in ngrams:
while True: while True:
start = sequence.find(ngram, tracker[ngram]) start = sequence.find(ngram, tracker[ngram])
assert start != -1 # testing the test assert start != -1 # testing the test
end = start + len(ngram) - 1 end = start + len(ngram) - 1
tracker[ngram] = end + 1 tracker[ngram] = end + 1
# ignore partial word matches # ignore partial word matches
if (start != 0 and sequence[start - 1] != " ") or \ if (start != 0 and sequence[start - 1] != " ") or (
(end != len(sequence) - 1 and sequence[end + 1] != " "): end != len(sequence) - 1 and sequence[end + 1] != " "
):
pass pass
else: else:
break break
comparison.append((ngram, (start, end))) comparison.append((ngram, (start, end)))
result_to_test = list(word_ngrams_indices(sequence, n)) result_to_test = list(word_ngrams_indices(sequence, n))
assert len(result_to_test) == len(comparison) assert len(result_to_test) == len(comparison)
assert result_to_test == comparison assert result_to_test == comparison
# Assumptions from GPT3 Paper: # Assumptions from GPT3 Paper:
# the 200 characters to remove include punctuation and is actually a half-window # the 200 characters to remove include punctuation and is actually a half-window
# All tests below initially test without any registered contaminants, expecting the same sequence back. # All tests below initially test without any registered contaminants, expecting the same sequence back.
def test_janitor1(): def test_janitor1():
# First test using a 1gram and expected the first block before the filth to have some remaining # First test using a 1gram and expected the first block before the filth to have some remaining
# characters, but the second block should be completely removed. # characters, but the second block should be completely removed.
sequence = "This is a @line #containing a certain number of characters, 76 to be exact. " \ sequence = (
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"FILTH. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "FILTH. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
)
filth = "filth" filth = "filth"
expected_result = "This is a @line #containing a certain number of characters, 76 to be exact. " \ expected_result = (
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing " "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing "
)
janitor = Janitor(ngram_n=1, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200) janitor = Janitor(
ngram_n=1, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200
)
result = janitor.clean_python(sequence) result = janitor.clean_python(sequence)
result = "".join(result) result = "".join(result)
assert result == sequence assert result == sequence
...@@ -133,42 +160,47 @@ def test_janitor1(): ...@@ -133,42 +160,47 @@ def test_janitor1():
assert janitor.dirt_ngrams == {filth} assert janitor.dirt_ngrams == {filth}
result = janitor.clean_python(sequence) result = janitor.clean_python(sequence)
result = "".join(result) result = "".join(result)
assert result == expected_result assert result == expected_result
def test_janitor2(): def test_janitor2():
# Second test using a 1gram and expected the first block before the filth to have some remaining # Second test using a 1gram and expected the first block before the filth to have some remaining
# characters, and the second block is longer then 200 characters so should also have some remaining. # characters, and the second block is longer then 200 characters so should also have some remaining.
sequence = "This is a @line #containing a certain number of characters, 76 to be exact. " \ sequence = (
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"FILTH. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "FILTH. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
)
filth = "filth" filth = "filth"
expected_result = "This is a @line #containing a certain number of characters, 76 to be exact. " \ expected_result = (
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
" characters, 76 to be exact. " \ "This is a @line #containing "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ " characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
janitor = Janitor(ngram_n=1, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200) )
janitor = Janitor(
ngram_n=1, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200
)
result = janitor.clean_python(sequence) result = janitor.clean_python(sequence)
result = "".join(result) result = "".join(result)
assert result == sequence assert result == sequence
...@@ -180,37 +212,43 @@ def test_janitor2(): ...@@ -180,37 +212,43 @@ def test_janitor2():
result = "".join(result) result = "".join(result)
assert result == expected_result assert result == expected_result
def test_janitor3(): def test_janitor3():
# Same test as above but with a 6gram. # Same test as above but with a 6gram.
sequence = "This is a @line #containing a certain number of characters, 76 to be exact. " \ sequence = (
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"FILTH. lots of dirty filtHy FIlTh " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "FILTH. lots of dirty filtHy FIlTh "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
)
filth = "filth lots of dirty filthy filth" filth = "filth lots of dirty filthy filth"
expected_result = "This is a @line #containing a certain number of characters, 76 to be exact. " \ expected_result = (
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
" characters, 76 to be exact. " \ "This is a @line #containing "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ " characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
janitor = Janitor(ngram_n=6, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200) )
janitor = Janitor(
ngram_n=6, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200
)
result = janitor.clean_python(sequence) result = janitor.clean_python(sequence)
result = "".join(result) result = "".join(result)
assert result == sequence assert result == sequence
...@@ -222,45 +260,51 @@ def test_janitor3(): ...@@ -222,45 +260,51 @@ def test_janitor3():
result = "".join(result) result = "".join(result)
assert result == expected_result assert result == expected_result
def test_janitor4(): def test_janitor4():
# This test adds another block to that from the previous. The middle block should be entirely # This test adds another block to that from the previous. The middle block should be entirely
# removed as the 200 characters are removed from each side. # removed as the 200 characters are removed from each side.
sequence = "This is a @line #containing a certain number of characters, 76 to be exact. " \ sequence = (
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"FILTH. lots of dirty filtHy FIlTh " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "FILTH. lots of dirty filtHy FIlTh "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"FILTH. lots of dirty filtHy FIlTh " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "FILTH. lots of dirty filtHy FIlTh "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
)
filth = "filth lots of dirty filthy filth" filth = "filth lots of dirty filthy filth"
expected_result = "This is a @line #containing a certain number of characters, 76 to be exact. " \ expected_result = (
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
" characters, 76 to be exact. " \ "This is a @line #containing "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ " characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
janitor = Janitor(ngram_n=6, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200) )
janitor = Janitor(
ngram_n=6, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200
)
result = janitor.clean_python(sequence) result = janitor.clean_python(sequence)
result = "".join(result) result = "".join(result)
assert result == sequence assert result == sequence
...@@ -272,49 +316,55 @@ def test_janitor4(): ...@@ -272,49 +316,55 @@ def test_janitor4():
result = "".join(result) result = "".join(result)
assert result == expected_result assert result == expected_result
def test_janitor5(): def test_janitor5():
# Same as above but using multiple different filth 6grams. # Same as above but using multiple different filth 6grams.
sequence = "This is a @line #containing a certain number of characters, 76 to be exact. " \ sequence = (
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"FILTH. lots of dirty filtHy FIlTh " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "FILTH. lots of dirty filtHy FIlTh "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"FILTH. lots of filtHy dirty FIlTh " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "FILTH. lots of filtHy dirty FIlTh "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
)
filths = ["filth lots of dirty filthy filth", "filth lots of filthy dirty filth"]
filths = ["filth lots of dirty filthy filth", "filth lots of filthy dirty filth"]
expected_result = "This is a @line #containing a certain number of characters, 76 to be exact. " \
"This is a @line #containing a certain number of characters, 76 to be exact. " \ expected_result = (
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
" characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ " characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
janitor = Janitor(ngram_n=6, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200) "This is a @line #containing a certain number of characters, 76 to be exact. "
)
janitor = Janitor(
ngram_n=6, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200
)
result = janitor.clean_python(sequence) result = janitor.clean_python(sequence)
result = "".join(result) result = "".join(result)
assert result == sequence assert result == sequence
for filth in filths: for filth in filths:
janitor.register_contaminant(filth) janitor.register_contaminant(filth)
assert janitor.dirt_ngrams == set(filths) assert janitor.dirt_ngrams == set(filths)
...@@ -322,57 +372,63 @@ def test_janitor5(): ...@@ -322,57 +372,63 @@ def test_janitor5():
result = "".join(result) result = "".join(result)
assert result == expected_result assert result == expected_result
def test_janitor6(): def test_janitor6():
# Same as above but now we add 10 filths and expect the same result, the following test does 11. # Same as above but now we add 10 filths and expect the same result, the following test does 11.
sequence = "This is a @line #containing a certain number of characters, 76 to be exact. " \ sequence = (
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"FILTH. lots of dirty filtHy FIlTh " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"FILTH. lots of dirty filtHy FIlTh " \ "FILTH. lots of dirty filtHy FIlTh "
"FILTH. lots of dirty filtHy FIlTh " \ "FILTH. lots of dirty filtHy FIlTh "
"FILTH. lots of dirty filtHy FIlTh " \ "FILTH. lots of dirty filtHy FIlTh "
"FILTH. lots of dirty filtHy FIlTh " \ "FILTH. lots of dirty filtHy FIlTh "
"FILTH. lots of dirty filtHy FIlTh " \ "FILTH. lots of dirty filtHy FIlTh "
"FILTH. lots of dirty filtHy FIlTh " \ "FILTH. lots of dirty filtHy FIlTh "
"FILTH. lots of dirty filtHy FIlTh " \ "FILTH. lots of dirty filtHy FIlTh "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "FILTH. lots of dirty filtHy FIlTh "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"FILTH. lots of filtHy dirty FIlTh " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"FILTH. lots of filtHy dirty FIlTh " \ "FILTH. lots of filtHy dirty FIlTh "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "FILTH. lots of filtHy dirty FIlTh "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
)
filths = ["filth lots of dirty filthy filth", "filth lots of filthy dirty filth"]
filths = ["filth lots of dirty filthy filth", "filth lots of filthy dirty filth"]
expected_result = "This is a @line #containing a certain number of characters, 76 to be exact. " \
"This is a @line #containing a certain number of characters, 76 to be exact. " \ expected_result = (
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
" characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ " characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
janitor = Janitor(ngram_n=6, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200) "This is a @line #containing a certain number of characters, 76 to be exact. "
)
janitor = Janitor(
ngram_n=6, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200
)
result = janitor.clean_python(sequence) result = janitor.clean_python(sequence)
result = "".join(result) result = "".join(result)
assert result == sequence assert result == sequence
for filth in filths: for filth in filths:
janitor.register_contaminant(filth) janitor.register_contaminant(filth)
assert janitor.dirt_ngrams == set(filths) assert janitor.dirt_ngrams == set(filths)
...@@ -380,51 +436,55 @@ def test_janitor6(): ...@@ -380,51 +436,55 @@ def test_janitor6():
result = "".join(result) result = "".join(result)
assert result == expected_result assert result == expected_result
def test_janitor7(): def test_janitor7():
# Same as above but now we add 9 filths and expect the same result, the following test does 10. # Same as above but now we add 9 filths and expect the same result, the following test does 10.
sequence = "This is a @line #containing a certain number of characters, 76 to be exact. " \ sequence = (
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"FILTH. lots of dirty filtHy FIlTh " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"FILTH. lots of dirty filtHy FIlTh " \ "FILTH. lots of dirty filtHy FIlTh "
"FILTH. lots of dirty filtHy FIlTh " \ "FILTH. lots of dirty filtHy FIlTh "
"FILTH. lots of dirty filtHy FIlTh " \ "FILTH. lots of dirty filtHy FIlTh "
"FILTH. lots of dirty filtHy FIlTh " \ "FILTH. lots of dirty filtHy FIlTh "
"FILTH. lots of dirty filtHy FIlTh " \ "FILTH. lots of dirty filtHy FIlTh "
"FILTH. lots of dirty filtHy FIlTh " \ "FILTH. lots of dirty filtHy FIlTh "
"FILTH. lots of dirty filtHy FIlTh " \ "FILTH. lots of dirty filtHy FIlTh "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "FILTH. lots of dirty filtHy FIlTh "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"FILTH. lots of filtHy dirty FIlTh " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"FILTH. lots of filtHy dirty FIlTh " \ "FILTH. lots of filtHy dirty FIlTh "
"FILTH. lots of filtHy dirty FIlTh " \ "FILTH. lots of filtHy dirty FIlTh "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "FILTH. lots of filtHy dirty FIlTh "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " \ "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. " "This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
)
filths = ["filth lots of dirty filthy filth", "filth lots of filthy dirty filth"]
filths = ["filth lots of dirty filthy filth", "filth lots of filthy dirty filth"]
expected_result = "" expected_result = ""
janitor = Janitor(ngram_n=6, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200) janitor = Janitor(
ngram_n=6, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200
)
result = janitor.clean_python(sequence) result = janitor.clean_python(sequence)
result = "".join(result) result = "".join(result)
assert result == sequence assert result == sequence
for filth in filths: for filth in filths:
janitor.register_contaminant(filth) janitor.register_contaminant(filth)
assert janitor.dirt_ngrams == set(filths) assert janitor.dirt_ngrams == set(filths)
...@@ -453,23 +513,3 @@ def test_janitor8(): ...@@ -453,23 +513,3 @@ def test_janitor8():
# cleaned = " ".join(jan.clean(source)) # cleaned = " ".join(jan.clean(source))
# for contam in jan.dirt_ngrams: # for contam in jan.dirt_ngrams:
# assert contam not in cleaned, contam # assert contam not in cleaned, contam
import hashlib
import json
import openai
import os
import pickle
import pytest import pytest
import unittest.mock as mock import unittest.mock as mock
import lm_eval.models as models import lm_eval.models as models
LOGLIKELIHOOD_TEST_CASES = [
("The quick brown fox jumps over the lazy", " dog"),
("The quick brown fox jumps over the lazy", " cat"),
("The quick brown fox jumps over the lazy", ", lazy dog"),
("The quick brown fox jumps over the lazy", ", lazy fox"),
(
"The quick brown fox jumps over the lazy",
", lazy fox and they both fall to the ground",
),
(
"""A mult""",
"""ilayer perceptron (MLP) is a class of feedforward artificial neural network (ANN)""",
),
(
"""The term MLP is used ambiguously, sometimes loosely to any feedforward ANN, sometimes strictly to refer to networks composed of multiple layers of perceptrons""",
""" (with threshold activation); see § Terminology""",
),
(
"""Multilayer perceptrons are sometimes coll""",
"""oquially referred to as "vanilla" neural networks, especially when they have a single hidden layer.[1]""",
),
(
"""An MLP consists of at least three layers of nodes: an input layer, a hidden layer and an output layer. Except for the input nodes, each node is a neuron that uses a nonlinear""",
""" activation function.""",
),
(
"""MLP utilizes a supervised""",
""" learning technique called backpropagation for training.[2][3] Its multiple layers and non-linear activation distinguish MLP from a linear perceptron. It can distinguish data that is not linearly separable.[4]""",
),
(
"""Recent work has demonstrated substantial gains on many NLP tasks and benchmarks by pre-training on a large corpus of text followed by fine-tuning on a specific task. While typically task-agnostic""",
""" in architecture, this method still requires task-specific fine-tuning datasets of thousands or tens of thousands of examples. By contrast, humans can generally perform a new language task from only a few examples or from simple instructions - something which current NLP systems still largely struggle to do. Here we show that scaling up language models greatly improves task-agnostic, few-shot performance, sometimes even reaching competitiveness with prior state-of-the-art fine-tuning approaches. """,
),
(
"""Specifically, we train GPT-3, an autoregressive language model with 175""",
""" billion parameters, 10x more than any previous non-sparse language model, and test its performance in the few-shot setting. For all tasks, GPT-3 is applied without any gradient updates or fine-tuning, with tasks and few-shot demonstrations specified purely via text interaction with the model. GPT-3 achieves strong performance on many NLP datasets, including translation, question-answering, and cloze tasks, as well as several tasks that require on-the-fly reasoning or domain adaptation, such as unscrambling words, using a novel word in a sentence, or performing 3-digit arithmetic. At the same time, we also identify some datasets where GPT-3's few-shot learning still struggles, as well as some datasets where GPT-3 faces methodological issues related to training on large web corpora. Finally, we find that GPT-3 can generate samples of news articles which human evaluators have difficulty distinguishing from articles written by humans. We discuss broader societal impacts of this finding and of GPT-3 in general.""",
),
(
"""A mult""",
"""ilayer perceptron (MLP) is a class of feedforward artificial neural network (ANN)""",
),
("""Hello""", """ World"""),
]
# Test HuggingFace Models (GPT-2)
def test_gpt2(): def test_gpt2():
gpt2 = models.get_model('gpt2').create_from_arg_string("device=cpu") gpt2 = models.get_model("gpt2").create_from_arg_string("device=cpu")
(ll_dog, ig_dog), (ll_cat, ig_cat), (_, ll_max_0), (_, ll_max_1), (_, ll_max_2), *vals = gpt2.loglikelihood([ (
('The quick brown fox jumps over the lazy', ' dog'), (ll_dog, ig_dog),
('The quick brown fox jumps over the lazy', ' cat'), (ll_cat, ig_cat),
('The quick brown fox jumps over the lazy', ', lazy dog'), (_, ll_max_0),
('The quick brown fox jumps over the lazy', ', lazy fox'), (_, ll_max_1),
('The quick brown fox jumps over the lazy', ', lazy fox and they both fall to the ground'), (_, ll_max_2),
*vals,
("""A mult""", """ilayer perceptron (MLP) is a class of feedforward artificial neural network (ANN)"""), ) = gpt2.loglikelihood(LOGLIKELIHOOD_TEST_CASES)
("""The term MLP is used ambiguously, sometimes loosely to any feedforward ANN, sometimes strictly to refer to networks composed of multiple layers of perceptrons""", """ (with threshold activation); see § Terminology"""),
("""Multilayer perceptrons are sometimes coll""", """oquially referred to as "vanilla" neural networks, especially when they have a single hidden layer.[1]"""),
("""An MLP consists of at least three layers of nodes: an input layer, a hidden layer and an output layer. Except for the input nodes, each node is a neuron that uses a nonlinear""", """ activation function."""),
("""MLP utilizes a supervised""", """ learning technique called backpropagation for training.[2][3] Its multiple layers and non-linear activation distinguish MLP from a linear perceptron. It can distinguish data that is not linearly separable.[4]"""),
("""Recent work has demonstrated substantial gains on many NLP tasks and benchmarks by pre-training on a large corpus of text followed by fine-tuning on a specific task. While typically task-agnostic""", """ in architecture, this method still requires task-specific fine-tuning datasets of thousands or tens of thousands of examples. By contrast, humans can generally perform a new language task from only a few examples or from simple instructions - something which current NLP systems still largely struggle to do. Here we show that scaling up language models greatly improves task-agnostic, few-shot performance, sometimes even reaching competitiveness with prior state-of-the-art fine-tuning approaches. """),
("""Specifically, we train GPT-3, an autoregressive language model with 175""", """ billion parameters, 10x more than any previous non-sparse language model, and test its performance in the few-shot setting. For all tasks, GPT-3 is applied without any gradient updates or fine-tuning, with tasks and few-shot demonstrations specified purely via text interaction with the model. GPT-3 achieves strong performance on many NLP datasets, including translation, question-answering, and cloze tasks, as well as several tasks that require on-the-fly reasoning or domain adaptation, such as unscrambling words, using a novel word in a sentence, or performing 3-digit arithmetic. At the same time, we also identify some datasets where GPT-3's few-shot learning still struggles, as well as some datasets where GPT-3 faces methodological issues related to training on large web corpora. Finally, we find that GPT-3 can generate samples of news articles which human evaluators have difficulty distinguishing from articles written by humans. We discuss broader societal impacts of this finding and of GPT-3 in general."""),
("""A mult""", """ilayer perceptron (MLP) is a class of feedforward artificial neural network (ANN)"""),
("""Hello""", """ World"""),
])
assert ll_dog > ll_cat assert ll_dog > ll_cat
assert not ig_cat assert not ig_cat
...@@ -31,17 +76,24 @@ def test_gpt2(): ...@@ -31,17 +76,24 @@ def test_gpt2():
assert ll_max_2 assert ll_max_2
# test empty context # test empty context
gpt2.loglikelihood([('', 'test')]) gpt2.loglikelihood([("", "test")])
gen, = gpt2.greedy_until([ (gen,) = gpt2.greedy_until(
('The quick brown fox jumps over the lazy', ['.', '\n']) [("The quick brown fox jumps over the lazy", [".", "\n"])]
]) )
assert gen == ', lazy fox and they both fall to the ground' assert gen == ", lazy fox and they both fall to the ground"
targets = [ targets = [
-61.60536193847656, -56.57843780517578, -62.131004333496094, -9.799489974975586, -153.96334838867188, -61.60536193847656,
-341.222900390625, -731.1475830078125, -61.60536193847656, -8.682319641113281 -56.57843780517578,
-62.131004333496094,
-9.799489974975586,
-153.96334838867188,
-341.222900390625,
-731.1475830078125,
-61.60536193847656,
-8.682319641113281,
] ]
for (pred, _), tgt in zip(vals, targets): for (pred, _), tgt in zip(vals, targets):
...@@ -49,21 +101,224 @@ def test_gpt2(): ...@@ -49,21 +101,224 @@ def test_gpt2():
def test_gpt2_perplexity(): def test_gpt2_perplexity():
gpt2 = models.get_model('gpt2').create_from_arg_string("device=cpu") gpt2 = models.get_model("gpt2").create_from_arg_string("device=cpu")
test_string = "We study empirical scaling laws for language model performance on the cross-entropy loss." test_string = "We study empirical scaling laws for language model performance on the cross-entropy loss."
perplexity = gpt2.loglikelihood_rolling([(test_string,)])[0] perplexity = gpt2.loglikelihood_rolling([(test_string,)])[0]
tgt = sum([ tgt = sum(
-4.9599953, -8.069298, -8.308624, -10.178513, -8.906924, -1.9318912, -7.745445, -7.146077, -5.2072, [
-3.5882986, -1.9957212, -8.044922, -0.20841774, -5.1096807, -0.099879116, -8.888423, -4.6180487, -4.9599953,
]) -8.069298,
-8.308624,
-10.178513,
-8.906924,
-1.9318912,
-7.745445,
-7.146077,
-5.2072,
-3.5882986,
-1.9957212,
-8.044922,
-0.20841774,
-5.1096807,
-0.099879116,
-8.888423,
-4.6180487,
]
)
assert perplexity == pytest.approx(tgt, rel=1e-3) assert perplexity == pytest.approx(tgt, rel=1e-3)
with mock.patch.object(models.gpt2.HFLM, 'max_length', new_callable=mock.PropertyMock) as mock_max_length: with mock.patch.object(
models.gpt2.HFLM, "max_length", new_callable=mock.PropertyMock
) as mock_max_length:
mock_max_length.return_value = 5 mock_max_length.return_value = 5
gpt2 = models.get_model('gpt2').create_from_arg_string("device=cpu") gpt2 = models.get_model("gpt2").create_from_arg_string("device=cpu")
perplexity = gpt2.loglikelihood_rolling([(test_string,)])[0] perplexity = gpt2.loglikelihood_rolling([(test_string,)])[0]
tgt = sum([ tgt = sum(
-4.96001, -8.069275, -8.308612, -10.178482, -8.90691, -4.037338, -8.09261, -11.662385, -10.206891, [
-4.425003, -2.2563353, -7.909143, -1.9304147, -7.3610134, -2.3120654, -7.3229, -2.1643813, -4.96001,
]) -8.069275,
-8.308612,
-10.178482,
-8.90691,
-4.037338,
-8.09261,
-11.662385,
-10.206891,
-4.425003,
-2.2563353,
-7.909143,
-1.9304147,
-7.3610134,
-2.3120654,
-7.3229,
-2.1643813,
]
)
assert perplexity == pytest.approx(tgt, rel=1e-3) assert perplexity == pytest.approx(tgt, rel=1e-3)
# Test OpenAI Models (GPT-3)
def openai_mock_completion(**kwargs):
# Mock completion function
# Loads from a cached+pickled response if it exists, otherwise it will actually try to ping
os.makedirs("tests/testdata", exist_ok=True)
hash = hashlib.sha256(
json.dumps(kwargs, sort_keys=True).encode("utf-8")
).hexdigest()
fname = f"tests/testdata/gpt3_test_{hash}.pkl"
if os.path.exists(fname):
with open(fname, "rb") as fh:
return pickle.load(fh)
ret = openai.Completion.create(**kwargs)
ret.api_key = ""
with open(fname, "wb") as fh:
pickle.dump(ret, fh)
return ret
@mock.patch("lm_eval.models.gpt3.oa_completion", new=openai_mock_completion)
def test_gpt3():
if "OPENAI_API_SECRET_KEY" not in os.environ:
os.environ["OPENAI_API_SECRET_KEY"] = ""
gpt3 = models.get_model("gpt3").create_from_arg_string("engine=ada")
(
(ll_dog, ig_dog),
(ll_cat, ig_cat),
(_, ll_max_0),
(_, ll_max_1),
(_, ll_max_2),
*vals,
) = gpt3.loglikelihood(LOGLIKELIHOOD_TEST_CASES)
assert ll_dog > ll_cat
assert not ig_cat
assert ig_dog
assert not ll_max_0
assert not ll_max_1
assert not ll_max_2
# test empty context
gpt3.loglikelihood([("", "test")])
(gen,) = gpt3.greedy_until(
[("The quick brown fox jumps over the lazy", [".", "\n"])]
)
assert gen == " dog"
print([x[0] for x in vals])
targets = [
-34.848301606999996,
-47.148329679999996,
-45.44380149599999,
-5.285246016,
-133.97821690686004,
-321.2616693239001,
-658.0299524401041,
-34.848301606999996,
-7.525115,
]
for (pred, _), tgt in zip(vals, targets):
assert pred == pytest.approx(tgt, rel=1e-3)
@mock.patch("lm_eval.models.gpt3.oa_completion", new=openai_mock_completion)
def test_gpt3_perplexity():
if "OPENAI_API_SECRET_KEY" not in os.environ:
os.environ["OPENAI_API_SECRET_KEY"] = ""
gpt3 = models.get_model("gpt3").create_from_arg_string("engine=ada")
test_string = "We study empirical scaling laws for language model performance on the cross-entropy loss."
perplexity = gpt3.loglikelihood_rolling([(test_string,)])[0]
tgt = -84.38819608
assert perplexity == pytest.approx(tgt, rel=1e-3)
# Hack: modify gpt3 to have shorter context length to induce rolling windows
with mock.patch.object(
models.gpt3.GPT3LM, "max_length", new_callable=mock.PropertyMock
) as mock_max_length:
mock_max_length.return_value = 5
gpt3 = models.get_model("gpt3").create_from_arg_string("engine=ada")
perplexity = gpt3.loglikelihood_rolling([(test_string,)])[0]
tgt = -101.81967209999999
assert perplexity == pytest.approx(tgt, rel=1e-3)
# Test TextSynth Models (GPT-J)
def textsynth_mock_completion(**kwargs):
# Mock completion function
# Loads from a cached+pickled response if it exists, otherwise it will actually try to ping
import requests
os.makedirs("tests/testdata", exist_ok=True)
hash_kwargs = {k: v for k, v in kwargs.items() if k != "headers"}
hash = hashlib.sha256(
json.dumps(hash_kwargs, sort_keys=True).encode("utf-8")
).hexdigest()
fname = f"tests/testdata/textsynth_test_{hash}.pkl"
if os.path.exists(fname):
with open(fname, "rb") as fh:
return pickle.load(fh)
ret = requests.post(**kwargs)
with open(fname, "wb") as fh:
pickle.dump(ret, fh)
return ret
@mock.patch(
"lm_eval.models.textsynth.textsynth_completion", new=textsynth_mock_completion
)
def test_textsynth():
if "TEXTSYNTH_API_SECRET_KEY" not in os.environ:
os.environ["TEXTSYNTH_API_SECRET_KEY"] = ""
textsynth = models.get_model("textsynth").create_from_arg_string("engine=gptj_6B")
(
(ll_dog, ig_dog),
(ll_cat, ig_cat),
(_, ll_max_0),
(_, ll_max_1),
(_, ll_max_2),
*vals,
) = textsynth.loglikelihood(LOGLIKELIHOOD_TEST_CASES)
assert ll_dog > ll_cat
assert not ig_cat
assert ig_dog
assert not ll_max_0
assert not ll_max_1
assert not ll_max_2
# test empty context
textsynth.loglikelihood([("", "test")])
(gen,) = textsynth.greedy_until(
[("The quick brown fox jumps over the lazy", [".", "\n"])]
)
assert gen == " dog"
print([x[0] for x in vals])
targets = [
-17.90513712817,
-41.83518912287,
-33.82445643841,
-2.377361565302,
-99.53018069754,
-243.5642283598,
-528.6862613790,
-17.90513712817,
-5.041000672142,
]
for (pred, _), tgt in zip(vals, targets):
assert pred == pytest.approx(tgt, rel=1e-3)
...@@ -6,11 +6,8 @@ from itertools import islice ...@@ -6,11 +6,8 @@ from itertools import islice
@pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items()) @pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
def test_basic_interface(taskname, task_class): def test_basic_interface(taskname, task_class):
print('Evaluating task', taskname) print("Evaluating task", taskname)
# dl = task_class.download
# task_class.download = MagicMock()
task = task_class() task = task_class()
# task_class.download = dl
assert task.has_training_docs() in [True, False] assert task.has_training_docs() in [True, False]
assert task.has_validation_docs() in [True, False] assert task.has_validation_docs() in [True, False]
...@@ -42,7 +39,7 @@ def test_basic_interface(taskname, task_class): ...@@ -42,7 +39,7 @@ def test_basic_interface(taskname, task_class):
reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr] reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2] reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
assert reqs == reqs2 assert reqs == reqs2
if task.has_test_docs(): if task.has_test_docs():
...@@ -53,7 +50,7 @@ def test_basic_interface(taskname, task_class): ...@@ -53,7 +50,7 @@ def test_basic_interface(taskname, task_class):
reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr] reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2] reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
assert reqs == reqs2 assert reqs == reqs2
if task.has_training_docs(): if task.has_training_docs():
...@@ -64,13 +61,13 @@ def test_basic_interface(taskname, task_class): ...@@ -64,13 +61,13 @@ def test_basic_interface(taskname, task_class):
reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr] reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2] reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
assert reqs == reqs2 assert reqs == reqs2
@pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items()) @pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
def test_documents_and_requests(taskname, task_class): def test_documents_and_requests(taskname, task_class):
print('Evaluating task', taskname) print("Evaluating task", taskname)
task = task_class() task = task_class()
fns = [] fns = []
if task.has_training_docs(): if task.has_training_docs():
...@@ -83,21 +80,21 @@ def test_documents_and_requests(taskname, task_class): ...@@ -83,21 +80,21 @@ def test_documents_and_requests(taskname, task_class):
for fn in fns: for fn in fns:
# print(list(islice(fn(), 10))) # print(list(islice(fn(), 10)))
for doc in islice(fn(), 10): for doc in islice(fn(), 10):
txt = task.doc_to_text(doc) txt = task.doc_to_text(doc)
tgt = task.doc_to_target(doc) tgt = task.doc_to_target(doc)
assert isinstance(txt, str) assert isinstance(txt, str)
assert isinstance(tgt, str) assert isinstance(tgt, str)
# space convention # space convention
# allow txt to have length 0 for perplexity-like tasks since the model tacks an <|endoftext|> on # allow txt to have length 0 for perplexity-like tasks since the model tacks an <|endoftext|> on
if len(txt) != 0: if len(txt) != 0:
assert txt[-1] != ' ' assert txt[-1] != " "
assert tgt[0] == ' ' or txt[-1] == '\n' assert tgt[0] == " " or txt[-1] == "\n"
reqs = task.construct_requests(doc, txt) reqs = task.construct_requests(doc, txt)
# construct_requests can return just one request # construct_requests can return just one request
if not isinstance(reqs, (list, tuple)): if not isinstance(reqs, (list, tuple)):
reqs = [reqs] reqs = [reqs]
......
...@@ -5,8 +5,14 @@ from lm_eval.utils import get_rolling_token_windows, make_disjoint_window ...@@ -5,8 +5,14 @@ from lm_eval.utils import get_rolling_token_windows, make_disjoint_window
def test_get_rolling_token_windows_v1(): def test_get_rolling_token_windows_v1():
gold = [ gold = [
([-100, 0, 1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), ([-100, 0, 1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
([9, 10, 11, 12, 13, 14, 15, 16, 17, 18], [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]), (
([19, 20, 21, 22, 23, 24, 25, 26, 27, 28], [20, 21, 22, 23, 24, 25, 26, 27, 28, 29]), [9, 10, 11, 12, 13, 14, 15, 16, 17, 18],
[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
),
(
[19, 20, 21, 22, 23, 24, 25, 26, 27, 28],
[20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
),
([23, 24, 25, 26, 27, 28, 29, 30, 31, 32], [30, 31, 32, 33]), ([23, 24, 25, 26, 27, 28, 29, 30, 31, 32], [30, 31, 32, 33]),
] ]
x = list(range(34)) x = list(range(34))
...@@ -123,7 +129,6 @@ def test_get_rolling_token_windows_v4(): ...@@ -123,7 +129,6 @@ def test_get_rolling_token_windows_v4():
([17, 18, 19, 20, 21, 22, 23, 24, 25, 26], [27]), ([17, 18, 19, 20, 21, 22, 23, 24, 25, 26], [27]),
([18, 19, 20, 21, 22, 23, 24, 25, 26, 27], [28]), ([18, 19, 20, 21, 22, 23, 24, 25, 26, 27], [28]),
([19, 20, 21, 22, 23, 24, 25, 26, 27, 28], [29]), ([19, 20, 21, 22, 23, 24, 25, 26, 27, 28], [29]),
] ]
x = list(range(30)) x = list(range(30))
generator = get_rolling_token_windows( generator = get_rolling_token_windows(
...@@ -145,8 +150,14 @@ def test_get_rolling_token_windows_v4(): ...@@ -145,8 +150,14 @@ def test_get_rolling_token_windows_v4():
def test_get_rolling_token_windows_v5(): def test_get_rolling_token_windows_v5():
gold = [ gold = [
([-100, 0, 1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), ([-100, 0, 1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
([9, 10, 11, 12, 13, 14, 15, 16, 17, 18], [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]), (
([19, 20, 21, 22, 23, 24, 25, 26, 27, 28], [20, 21, 22, 23, 24, 25, 26, 27, 28, 29]), [9, 10, 11, 12, 13, 14, 15, 16, 17, 18],
[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
),
(
[19, 20, 21, 22, 23, 24, 25, 26, 27, 28],
[20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
),
] ]
x = list(range(30)) x = list(range(30))
generator = get_rolling_token_windows( generator = get_rolling_token_windows(
...@@ -203,5 +214,9 @@ def test_get_rolling_token_windows_empty(): ...@@ -203,5 +214,9 @@ def test_get_rolling_token_windows_empty():
def test_make_disjoint_window(): def test_make_disjoint_window():
assert make_disjoint_window(([1,2,3,4,5], [2,3,4,5,6])) == ([1], [2,3,4,5,6]) assert make_disjoint_window(([1, 2, 3, 4, 5], [2, 3, 4, 5, 6])) == (
assert make_disjoint_window(([1,2,3,4,5], [4,5,6])) == ([1,2,3], [4,5,6]) [1],
\ No newline at end of file [2, 3, 4, 5, 6],
)
assert make_disjoint_window(([1, 2, 3, 4, 5], [4, 5, 6])) == ([1, 2, 3], [4, 5, 6])
assert make_disjoint_window(([1, 2, 3, 4, 5], [6])) == ([1, 2, 3, 4, 5], [6])
...@@ -16,13 +16,14 @@ def assert_target(name, ob): ...@@ -16,13 +16,14 @@ def assert_target(name, ob):
fname = f"tests/testdata/{name}.json" fname = f"tests/testdata/{name}.json"
if os.path.exists(fname): if os.path.exists(fname):
with open(fname) as fh: with open(fname) as fh:
# Use relative tolerance of 1e-5 and absolute tolerance of 1e-8 # Use relative tolerance of 1e-5 and absolute tolerance of 1e-8
# assuming most metrics work on `float32` values, which is the common # assuming most metrics work on `float32` values, which is the common
# default floating type across popular libraries (PyTorch, Tensorflow, and JAX). # default floating type across popular libraries (PyTorch, Tensorflow, and JAX).
assert flatten(json.load(fh)) == pytest.approx( assert flatten(json.load(fh)) == pytest.approx(
flatten(json.loads(json.dumps(ob, sort_keys=True))), rel=1e-5, abs=1e-8) flatten(json.loads(json.dumps(ob, sort_keys=True))), rel=1e-5, abs=1e-8
)
else: else:
with open(fname, 'w') as fh: with open(fname, "w") as fh:
json.dump(ob, fh, sort_keys=True) json.dump(ob, fh, sort_keys=True)
...@@ -30,41 +31,52 @@ def assert_target_hashed(name, ob): ...@@ -30,41 +31,52 @@ def assert_target_hashed(name, ob):
fname = f"tests/testdata/{name}" fname = f"tests/testdata/{name}"
if os.path.exists(fname): if os.path.exists(fname):
with open(fname) as fh: with open(fname) as fh:
assert fh.read() == hashlib.sha256(json.dumps(ob, sort_keys=True).encode('utf-8')).hexdigest() assert (
fh.read()
== hashlib.sha256(
json.dumps(ob, sort_keys=True).encode("utf-8")
).hexdigest()
)
else: else:
with open(fname, 'w') as fh: with open(fname, "w") as fh:
fh.write(hashlib.sha256(json.dumps(ob, sort_keys=True).encode('utf-8')).hexdigest()) fh.write(
hashlib.sha256(
json.dumps(ob, sort_keys=True).encode("utf-8")
).hexdigest()
)
# from https://stackoverflow.com/a/6027615 # from https://stackoverflow.com/a/6027615
def flatten(d, parent_key='', sep='.'): def flatten(d, parent_key="", sep="."):
items = [] items = []
for k, v in d.items(): for k, v in d.items():
new_key = parent_key + sep + k if parent_key else k new_key = parent_key + sep + k if parent_key else k
if isinstance(v, collections.MutableMapping): if isinstance(v, collections.abc.MutableMapping):
items.extend(flatten(v, new_key, sep=sep).items()) items.extend(flatten(v, new_key, sep=sep).items())
else: else:
items.append((new_key, v)) items.append((new_key, v))
return dict(items) return dict(items)
# make sure eval results for a task version are stable # make sure eval results for a task version are stable
@pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items()) @pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
def test_versions_stable(taskname, task_class): def test_versions_stable(taskname, task_class):
task_dict = tasks.get_task_dict([taskname]) task_dict = tasks.get_task_dict([taskname])
lm = models.get_model('dummy')() lm = models.get_model("dummy")()
def ll_fn(reqs): def ll_fn(reqs):
for ctx, cont in reqs: for ctx, cont in reqs:
if len(ctx) == 0: if len(ctx) == 0:
continue continue
# space convention # space convention
assert ctx[-1] != ' ' assert ctx[-1] != " "
assert cont[0] == ' ' or ctx[-1] == '\n' assert cont[0] == " " or ctx[-1] == "\n"
assert_target_hashed(f"{taskname}-v{task_class.VERSION}-loglikelihood", reqs) assert_target_hashed(f"{taskname}-v{task_class.VERSION}-loglikelihood", reqs)
res = [] res = []
random.seed(42) random.seed(42)
for _ in reqs: for _ in reqs:
res.append((-random.random(), False)) res.append((-random.random(), False))
...@@ -72,10 +84,12 @@ def test_versions_stable(taskname, task_class): ...@@ -72,10 +84,12 @@ def test_versions_stable(taskname, task_class):
return res return res
def ll_perp_fn(reqs): def ll_perp_fn(reqs):
for string, in reqs: for (string,) in reqs:
assert isinstance(string, str) assert isinstance(string, str)
assert_target_hashed(f"{taskname}-v{task_class.VERSION}-loglikelihood_rolling", reqs) assert_target_hashed(
f"{taskname}-v{task_class.VERSION}-loglikelihood_rolling", reqs
)
res = [] res = []
random.seed(42) random.seed(42)
...@@ -83,14 +97,14 @@ def test_versions_stable(taskname, task_class): ...@@ -83,14 +97,14 @@ def test_versions_stable(taskname, task_class):
res.append(-random.random()) res.append(-random.random())
return res return res
def greedy_until(reqs): def greedy_until(reqs):
res = [] res = []
assert_target_hashed(f"{taskname}-v{task_class.VERSION}-greedy_until", reqs) assert_target_hashed(f"{taskname}-v{task_class.VERSION}-greedy_until", reqs)
for ctx, _ in reqs: for ctx, _ in reqs:
res.append("lol") res.append("lol")
assert ctx.strip() != '' assert ctx.strip() != ""
return res return res
...@@ -100,12 +114,12 @@ def test_versions_stable(taskname, task_class): ...@@ -100,12 +114,12 @@ def test_versions_stable(taskname, task_class):
limit = None limit = None
result = evaluator.evaluate( result = evaluator.evaluate(
lm=lm, lm=lm,
task_dict=task_dict, task_dict=task_dict,
num_fewshot=0, num_fewshot=0,
limit=limit, limit=limit,
bootstrap_iters=10, bootstrap_iters=10,
description_dict=None description_dict=None,
) )
assert_target(f"{taskname}-v{task_class.VERSION}-res", result) assert_target(f"{taskname}-v{task_class.VERSION}-res", result)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment