Unverified Commit a2cada5d authored by Jonathan Tow, committed by GitHub

Merge pull request #317 from EleutherAI/Mistobaan/add-pre-commit

Add pre-commit
parents 7a038118 83507c4b
@@ -65,13 +65,14 @@ class TruthfulqaConfig(datasets.BuilderConfig):
class Truthfulqa(datasets.GeneratorBasedBuilder):
    """TruthfulQA is a benchmark to measure whether a language model is truthful in
    generating answers to questions."""

    BUILDER_CONFIGS = [
        TruthfulqaConfig(
            name="multiple_choice",
            url="https://raw.githubusercontent.com/sylinrl/TruthfulQA/013686a06be7a7bde5bf8223943e106c7250123c/data/mc_task.json",
            features=datasets.Features(
                {
                    "question": datasets.Value("string"),
                    "mc1_targets": {
                        "choices": datasets.features.Sequence(datasets.Value("string")),
@@ -80,23 +81,30 @@ generating answers to questions."""
                    "mc2_targets": {
                        "choices": datasets.features.Sequence(datasets.Value("string")),
                        "labels": datasets.features.Sequence(datasets.Value("int32")),
                    },
                }
            ),
            description="The multiple choice TruthfulQA task",
        ),
        TruthfulqaConfig(
            name="generation",
            url="https://raw.githubusercontent.com/sylinrl/TruthfulQA/013686a06be7a7bde5bf8223943e106c7250123c/TruthfulQA.csv",
            features=datasets.Features(
                {
                    "category": datasets.Value("string"),
                    "question": datasets.Value("string"),
                    "best_answer": datasets.Value("string"),
                    "correct_answers": datasets.features.Sequence(
                        datasets.Value("string")
                    ),
                    "incorrect_answers": datasets.features.Sequence(
                        datasets.Value("string")
                    ),
                    "source": datasets.Value("string"),
                }
            ),
            description="The generative TruthfulQA task",
        ),
    ]

    def _info(self):
@@ -138,15 +146,15 @@ generating answers to questions."""
                    "mc2_targets": {
                        "choices": row["mc2_targets"].keys(),
                        "labels": row["mc2_targets"].values(),
                    },
                }
        else:
            # Generation data is in a `CSV` file.
            with open(filepath, newline="") as f:
                contents = csv.DictReader(f)
                for key, row in enumerate(contents):
                    # Ensure that references exist.
                    if not row["Correct Answers"] or not row["Incorrect Answers"]:
                        continue
                    yield key, {
                        "category": row["Category"],
@@ -154,6 +162,8 @@ generating answers to questions."""
                        "best_answer": row["Best Answer"],
                        # split on ";"
                        "correct_answers": row["Correct Answers"].strip().split(";"),
                        "incorrect_answers": row["Incorrect Answers"]
                        .strip()
                        .split(";"),
                        "source": row["Source"],
                    }
@@ -64,8 +64,9 @@ class Unscramble(datasets.GeneratorBasedBuilder):
    VERSION = datasets.Version("0.0.1")

    BUILDER_CONFIGS = [
        datasets.BuilderConfig(
            name=name, version=version, description=_DESCRIPTIONS[name]
        )
        for name, version in zip(_NAMES, [VERSION] * len(_NAMES))
    ]
......
@@ -123,86 +123,111 @@ class Wikitext(datasets.GeneratorBasedBuilder):
            return [
                datasets.SplitGenerator(
                    name=datasets.Split.TEST,
                    gen_kwargs={
                        "data_file": os.path.join(data_dir, "wiki.test.tokens"),
                        "split": "test",
                    },
                ),
                datasets.SplitGenerator(
                    name=datasets.Split.TRAIN,
                    gen_kwargs={
                        "data_file": os.path.join(data_dir, "wiki.train.tokens"),
                        "split": "train",
                    },
                ),
                datasets.SplitGenerator(
                    name=datasets.Split.VALIDATION,
                    gen_kwargs={
                        "data_file": os.path.join(data_dir, "wiki.valid.tokens"),
                        "split": "valid",
                    },
                ),
            ]
        else:
            if self.config.name == "wikitext-103-raw-v1":
                data_file = dl_manager.download_and_extract(self.config.data_url)
                data_dir = os.path.join(data_file, "wikitext-103-raw")
                return [
                    datasets.SplitGenerator(
                        name=datasets.Split.TEST,
                        gen_kwargs={
                            "data_file": os.path.join(data_dir, "wiki.test.raw"),
                            "split": "test",
                        },
                    ),
                    datasets.SplitGenerator(
                        name=datasets.Split.TRAIN,
                        gen_kwargs={
                            "data_file": os.path.join(data_dir, "wiki.train.raw"),
                            "split": "train",
                        },
                    ),
                    datasets.SplitGenerator(
                        name=datasets.Split.VALIDATION,
                        gen_kwargs={
                            "data_file": os.path.join(data_dir, "wiki.valid.raw"),
                            "split": "valid",
                        },
                    ),
                ]
            else:
                if self.config.name == "wikitext-2-raw-v1":
                    data_file = dl_manager.download_and_extract(self.config.data_url)
                    data_dir = os.path.join(data_file, "wikitext-2-raw")
                    return [
                        datasets.SplitGenerator(
                            name=datasets.Split.TEST,
                            gen_kwargs={
                                "data_file": os.path.join(data_dir, "wiki.test.raw"),
                                "split": "test",
                            },
                        ),
                        datasets.SplitGenerator(
                            name=datasets.Split.TRAIN,
                            gen_kwargs={
                                "data_file": os.path.join(data_dir, "wiki.train.raw"),
                                "split": "train",
                            },
                        ),
                        datasets.SplitGenerator(
                            name=datasets.Split.VALIDATION,
                            gen_kwargs={
                                "data_file": os.path.join(data_dir, "wiki.valid.raw"),
                                "split": "valid",
                            },
                        ),
                    ]
                else:
                    if self.config.name == "wikitext-2-v1":
                        data_file = dl_manager.download_and_extract(
                            self.config.data_url
                        )
                        data_dir = os.path.join(data_file, "wikitext-2")
                        return [
                            datasets.SplitGenerator(
                                name=datasets.Split.TEST,
                                gen_kwargs={
                                    "data_file": os.path.join(
                                        data_dir, "wiki.test.tokens"
                                    ),
                                    "split": "test",
                                },
                            ),
                            datasets.SplitGenerator(
                                name=datasets.Split.TRAIN,
                                gen_kwargs={
                                    "data_file": os.path.join(
                                        data_dir, "wiki.train.tokens"
                                    ),
                                    "split": "train",
                                },
                            ),
                            datasets.SplitGenerator(
                                name=datasets.Split.VALIDATION,
                                gen_kwargs={
                                    "data_file": os.path.join(
                                        data_dir, "wiki.valid.tokens"
                                    ),
                                    "split": "valid",
                                },
                            ),
@@ -216,12 +241,12 @@ class Wikitext(datasets.GeneratorBasedBuilder):
            data = f.read().split("\n")
            for line in data:
                rline = line.replace("= = =", "===").replace("= =", "==").strip()
                if rline.startswith("= ") and rline.strip().endswith(" ="):
                    page = "\n".join(ret)
                    if page.strip():
                        yield key, {"page": page}
                        key += 1
                    ret = []
                ret.append(line)
            page = "\n".join(ret)
            yield key, {"page": page}
@@ -8,12 +8,14 @@ import mmap
import tqdm
from pathlib import Path


def json_serial(obj):
    """JSON serializer for objects not serializable by default json code"""
    if isinstance(obj, (datetime.datetime,)):
        return obj.isoformat()
    raise TypeError("Type %s not serializable" % type(obj))


# Modified version of lm_dataformat Archive for single file.
class Archive:
@@ -22,25 +24,31 @@ class Archive:
        dir_name = os.path.dirname(file_path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)
        self.fh = open(self.file_path, "wb")
        self.cctx = zstandard.ZstdCompressor(level=compression_level)
        self.compressor = self.cctx.stream_writer(self.fh)

    def add_data(self, data, meta={}):
        self.compressor.write(
            json.dumps({"text": data, "meta": meta}, default=json_serial).encode(
                "UTF-8"
            )
            + b"\n"
        )

    def commit(self):
        self.compressor.flush(zstandard.FLUSH_FRAME)
        self.fh.flush()
        self.fh.close()


# Modified version of lm_dataformat Reader with self.fh set, allowing peeking for tqdm.
class Reader:
    def __init__(self):
        pass

    def read(self, file, get_meta=False, autojoin_paragraphs=True, para_joiner="\n\n"):
        with open(file, "rb") as fh:
            self.fh = fh
            cctx = zstandard.ZstdDecompressor()
            reader = io.BufferedReader(cctx.stream_reader(fh))
@@ -52,16 +60,17 @@ class Reader:
                    yield ob
                    continue

                text = ob["text"]

                if autojoin_paragraphs and isinstance(text, list):
                    text = para_joiner.join(text)

                if get_meta:
                    yield text, (ob["meta"] if "meta" in ob else {})
                else:
                    yield text


class TextArchive:
    def __init__(self, file_path, mode="rb+"):
        self.file_path = file_path
@@ -75,12 +84,13 @@ class TextArchive:
        self.fh = open(self.file_path, mode)

    def add_data(self, data):
        self.fh.write(data.encode("UTF-8") + b"\n")

    def commit(self):
        self.fh.flush()
        self.fh.close()


class TextReader:
    def __init__(self, file_path):
        self.file_path = file_path
@@ -90,9 +100,12 @@ class TextReader:
    def read_tqdm(self, update_frequency=10000):
        current_file_position = 0
        line_counter = 0
        with open(self.file_path, "r") as fh, tqdm.tqdm(
            total=os.path.getsize(self.file_path),
            dynamic_ncols=True,
            unit="byte",
            unit_scale=1,
        ) as progress:
            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                for line in iter(mmap_obj.readline, b""):
                    line = line.decode("utf-8")
@@ -107,7 +120,7 @@ class TextReader:
    def read_and_tell(self):
        current_file_position = 0
        with open(self.file_path, "r", encoding="utf8") as fh:
            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                for line in iter(mmap_obj.readline, b""):
                    line = line.decode("utf-8")
@@ -117,14 +130,14 @@ class TextReader:
                    yield line[:-1], raw_bytes_read

    def read(self):
        with open(self.file_path, "r", encoding="utf8") as fh:
            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                for line in iter(mmap_obj.readline, b""):
                    line = line.decode("utf-8")
                    yield line[:-1]

    def read_slow(self):
        with open(self.file_path, "r", encoding="utf8") as fh:
            while True:
                line = fh.readline()
                if line == -1 or line == "":
@@ -132,6 +145,7 @@ class TextReader:
                else:
                    yield line[:-1]


# Optimized for speed. Decompresses the archive in shell before
# using the mmap'd TextReader.
class ZStdTextReader:
......
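For orientation, a minimal sketch of how the archiver classes above fit together (a zstd-compressed JSONL writer and its matching reader); the file path and document strings are made up for illustration, assuming the classes are importable from lm_eval.decontamination.archiver.

# Hedged example: round trip through Archive and Reader, assuming the import path below.
from lm_eval.decontamination.archiver import Archive, Reader

archive = Archive("data/example.jsonl.zst")  # hypothetical output path
archive.add_data("The first document.", meta={"source": "example"})
archive.add_data("The second document.")
archive.commit()  # flushes the zstd frame and closes the file

reader = Reader()
for text, meta in reader.read("data/example.jsonl.zst", get_meta=True):
    print(meta.get("source", "unknown"), text)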
@@ -9,6 +9,7 @@ import collections
from .janitor import Janitor, word_ngrams
from .archiver import ZStdTextReader


# Was used for testing the evaluator decoupled from the full logic below
def get_train_overlap_stub(docs, ngrams_path, ngrams_n_size):
    simulated_overlap = 0.1
@@ -57,19 +58,27 @@ def get_train_overlap(docs_by_task_set, ngrams_path, limit):
            os.mkdir(f"data/{task_name}")

        # Check if we've decontaminated this combination before
        overlaps_dump_path = get_overlaps_dump_path(
            task_name, task_set, ngrams_n_size, limit
        )
        if os.path.exists(overlaps_dump_path):
            duplicates[(task_name, task_set)] = pickle.load(
                open(overlaps_dump_path, "rb")
            )
            sets_to_decontaminate -= 1
            continue
        else:
            duplicates[(task_name, task_set)] = set()

        # Build/load the task lookup {ngram: set(documents)}.
        task_set_lookup_path = (
            f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.lookup"
        )
        if os.path.exists(task_set_lookup_path):
            print(f"{task_set_lookup_path} available, loading...")
            lookups[(task_name, task_set)] = pickle.load(
                open(task_set_lookup_path, "rb")
            )
        else:
            print(f"{task_set_lookup_path} not available, building...")
            lookup = collections.defaultdict(set)
@@ -79,7 +88,7 @@ def get_train_overlap(docs_by_task_set, ngrams_path, limit):
                for ngram in ngrams:
                    lookup[ngram].add(doc_id)
            pickle.dump(lookup, open(task_set_lookup_path, "wb"))
            lookups[(task_name, task_set)] = lookup

        elapsed = time.perf_counter() - start
@@ -115,7 +124,9 @@ def get_train_overlap(docs_by_task_set, ngrams_path, limit):
        for line in reader.read_tqdm():  # Scan training set ngrams file
            total_ngrams += 1
            [ngram, document_id] = line.rsplit(" ", 1)
            if (
                ngram != current_ngram
            ):  # Only need to match the ngram once in training set
                unique_ngrams += 1
                current_ngram = ngram

                if ngram in merged_lookup:
@@ -123,7 +134,11 @@ def get_train_overlap(docs_by_task_set, ngrams_path, limit):
                    matching_unique += 1
                    for task_name, task_set, doc_ids in merged_lookup[ngram]:
                        task_doc_set = duplicates[(task_name, task_set)]
                        for (
                            doc_id
                        ) in (
                            doc_ids
                        ):  # Record contamination across all relevant task/set combos
                            task_doc_set.add(doc_id)
                    del merged_lookup[ngram]  # No point matching again
                else:
@@ -145,9 +160,10 @@ def get_train_overlap(docs_by_task_set, ngrams_path, limit):
    # Dump overlaps separately
    for (task_name, task_set), doc_ids in duplicates.items():
        overlaps_dump_path = get_overlaps_dump_path(
            task_name, task_set, ngrams_n_size, limit
        )
        pickle.dump(doc_ids, open(overlaps_dump_path, "wb"))

    # Strip task set and return
    return {task_name: doc_ids for (task_name, task_set), doc_ids in duplicates.items()}
@@ -9,8 +9,9 @@ from pprint import pprint
# c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup
try:
    import janitor_util

    JANITOR_CPP = True
except Exception:
    print("WARNING: C++ module could not be loaded. Janitor running in python mode")
    traceback.print_exc()
    JANITOR_CPP = False
@@ -41,6 +42,7 @@ def word_ngrams(s, n):
    ngram_seqs = form_ngrams(iter(tokens), n)
    return (" ".join(ngram) for ngram in ngram_seqs)


# Does character sequences only - combined faster function to play around with later
# def word_ngrams_indices_combined(sequence, n):
#     current_word = ""
@@ -70,7 +72,7 @@ def split_indices(s):
    """Splits a string on whitespaces and records the indices of each in the original string.
    @:return generator((word, (start_idx, end_idx)), ...)
    """
    return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r"\S+", s))


def word_ngrams_indices(s, n):
@@ -90,10 +92,15 @@ def word_ngrams_indices(s, n):
    #   ([word, word, ...], [(start,end), (start,end), ...]),
    #   ...
    # )
    ngram_indices_pairs = (
        zip(*ngram_with_indices) for ngram_with_indices in ngram_seqs_with_indices
    )

    # Generator of ( (word_ngram, (start, end)), (word_ngram, start, end)), ...)
    return (
        (" ".join(ngram_seq), (indices[0][0], indices[-1][1]))
        for ngram_seq, indices in ngram_indices_pairs
    )


class Janitor:
@@ -105,7 +112,7 @@ class Janitor:
        window_to_remove=200,
        too_dirty_cutoff=10,
        minimum_slice_length=200,
        delete_chars=string.punctuation,
    ):
        self.ngram_n = ngram_n
        self.window_to_remove = window_to_remove
@@ -121,7 +128,7 @@ class Janitor:
        self.translation_table = str.maketrans(
            string.ascii_lowercase + string.ascii_uppercase,  # These characters
            string.ascii_lowercase * 2,  # Become these characters
            self.delete_chars,  # These are deleted
        )

    ##############
@@ -129,14 +136,13 @@ class Janitor:
    ##############

    def save_contamination_ngrams(self, filename):
        with open(filename, "wb") as fp:
            pickle.dump(filename, fp)

    def load_contamination_ngrams(self, filename):
        with open(filename, "rb") as fp:
            self.dirt_ngrams = pickle.load(fp)

    ##############
    # Call these :)
    ##############
@@ -152,7 +158,7 @@ class Janitor:
    def clean(self, dirty_string):
        """Clean a string (e.g. a training set) by removing all ngrams previously
        registered as contaminants. Returns a list of clean chunks, or empty if
        the string was too dirty"""
        if JANITOR_CPP:
            return self.clean_cpp(dirty_string)
@@ -171,11 +177,11 @@ class Janitor:
            end = min(len(dirty_string), end + self.window_to_remove)

            if start - splice_idx > self.minimum_slice_length:
                clean_chunks.append(dirty_string[splice_idx:start])

            splice_idx = end

        if end < len(dirty_string) - self.minimum_slice_length:
            clean_chunks.append(dirty_string[end + 1 :])

        return clean_chunks
@@ -184,10 +190,14 @@ class Janitor:
    ##############

    def register_contaminant_cpp(self, dirt_string):
        self.dirt_ngrams.update(
            janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n)
        )

    def clean_cpp(self, dirty_string):
        contamination_indices = janitor_util.clean_ngram_with_indices(
            dirty_string, self.delete_chars, self.ngram_n
        )
        return self._split_chunks(dirty_string, contamination_indices)

    ##############
@@ -198,7 +208,9 @@ class Janitor:
        return s.translate(self.translation_table)

    def register_contaminant_python(self, dirt_string):
        self.dirt_ngrams.update(
            word_ngrams(self.normalize_string(dirt_string), self.ngram_n)
        )

    def clean_python(self, dirty_string):
        contamination_indices = (
@@ -263,7 +275,7 @@ class Janitor:
# ultan Al Nahyan]] granted [[Petroleum]] concessions, and oil was first found in 1958. At first,
# oil money had a marginal impact. A few lowrise concete buildings were erected, and the first
# paved road was completed in 1961, but Sheikh Shakbut, uncertain whether the new oil royalties
# would last, took a cautious approach, preferring to save the revenue rather than investing it in
# development. His brother, [[Zayed bin Sultan Al Nahayan]], saw that oil wealth had the potential
# to transform Abu Dhabi. The ruling Al Nahayan family decided that Sheikh Zayed should replace his
# brother as Ruler and carry out his vision of developing the country. On [[August 6]], [[1966]],
......
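As a rough illustration of the Janitor API touched above (register an n-gram contaminant, then clean a training document), under the assumption that register_contaminant and clean dispatch to the _cpp/_python variants shown in the diff; the strings are invented.

# Hedged example: registering a contaminant and cleaning a document with Janitor.
from lm_eval.decontamination.janitor import Janitor

janitor = Janitor(ngram_n=13)  # 13-grams; window/cutoff defaults are shown above
janitor.register_contaminant(
    "the exact benchmark question that we never want the model to see during training at all"
)

dirty = (
    "some training text ... the exact benchmark question that we never want "
    "the model to see during training at all ... more training text"
)
# Returns the clean slices around each match; slices shorter than
# minimum_slice_length are dropped, so short toy inputs may yield [].
print(janitor.clean(dirty))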
@@ -11,12 +11,22 @@ import numpy as np
from lm_eval.utils import positional_deprecated, run_task_tests
from lm_eval.decontamination.decontaminate import get_train_overlap


@positional_deprecated
def simple_evaluate(
    model,
    model_args=None,
    tasks=[],
    num_fewshot=0,
    batch_size=None,
    device=None,
    no_cache=False,
    limit=None,
    bootstrap_iters=100000,
    description_dict=None,
    check_integrity=False,
    decontamination_ngrams_path=None,
):
    """Instantiate and evaluate a model on a list of tasks.
@@ -52,17 +62,23 @@ def simple_evaluate(model, model_args=None, tasks=[],
    assert tasks != [], "No tasks specified"

    if isinstance(model, str):
        if model_args is None:
            model_args = ""
        lm = lm_eval.models.get_model(model).create_from_arg_string(
            model_args, {"batch_size": batch_size, "device": device}
        )
    else:
        assert isinstance(model, lm_eval.base.LM)
        lm = model

    if not no_cache:
        lm = lm_eval.base.CachingLM(
            lm,
            "lm_cache/"
            + model
            + "_"
            + model_args.replace("=", "-").replace(",", "_").replace("/", "-")
            + ".db",
        )

    task_dict = lm_eval.tasks.get_task_dict(tasks)
@@ -89,16 +105,26 @@ def simple_evaluate(model, model_args=None, tasks=[],
        "no_cache": no_cache,
        "limit": limit,
        "bootstrap_iters": bootstrap_iters,
        "description_dict": description_dict,
    }

    return results


decontaminate_suffix = "_decontaminate"


@positional_deprecated
def evaluate(
    lm,
    task_dict,
    provide_description=None,
    num_fewshot=0,
    limit=None,
    bootstrap_iters=100000,
    description_dict=None,
    decontamination_ngrams_path=None,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param lm: obj
@@ -124,14 +150,16 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
    assert not provide_description  # not implemented.
    if provide_description is not None:
        # nudge people to not specify it at all
        print(
            "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
        )

    decontaminate = decontamination_ngrams_path is not None

    task_dict_items = [
        (name, task)
        for name, task in task_dict.items()
        if (task.has_validation_docs() or task.has_test_docs())
    ]

    results = collections.defaultdict(dict)
@@ -172,19 +200,22 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
        rnd.seed(42)
        rnd.shuffle(task_docs)

        description = (
            description_dict[task_name]
            if description_dict and task_name in description_dict
            else ""
        )

        for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
            if decontaminate and task.should_decontaminate():
                docs_for_decontamination[(task_name, task_set)].append(
                    task.doc_to_decontamination_query(doc)
                )

            docs[(task_name, doc_id)] = doc
            ctx = task.fewshot_context(
                doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
            )
            reqs = task.construct_requests(doc, ctx)
            if not isinstance(reqs, (list, tuple)):
@@ -198,7 +229,9 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
    # Compare all tasks/sets at once to ensure a single training set scan
    if decontaminate:
        print("Finding train/test overlap, please wait...")
        overlaps = get_train_overlap(
            docs_for_decontamination, decontamination_ngrams_path, limit
        )

    # all responses for each (task, doc)
    process_res_queue = collections.defaultdict(list)
@@ -212,7 +245,9 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
        print("Running", reqtype, "requests")
        resps = getattr(lm, reqtype)([req.args for req in reqs])
        resps = [
            x if req.index is None else x[req.index] for x, req in zip(resps, reqs)
        ]

        for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
            process_res_queue[(task_name, doc_id)].append((i, resp))
@@ -241,7 +276,9 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
        task = task_dict[task_name]
        real_metric = metric  # key when looking up the metric with task.aggregation
        if metric.endswith(decontaminate_suffix):
            real_metric = metric.replace(
                decontaminate_suffix, ""
            )  # decontaminated still uses the same metric
        results[task_name][metric] = task.aggregation()[real_metric](items)

        # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
@@ -249,16 +286,15 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
        stderr = lm_eval.metrics.stderr_for_metric(
            metric=task.aggregation()[real_metric],
            bootstrap_iters=min(bootstrap_iters, 1000)
            if metric in ["bleu", "chrf", "ter"]
            else bootstrap_iters,
        )
        if stderr is not None:
            results[task_name][metric + "_stderr"] = stderr(items)

    return {"results": dict(results), "versions": dict(versions)}


def make_table(result_dict):
@@ -280,9 +316,9 @@ def make_table(result_dict):
            if m + "_stderr" in dic:
                se = dic[m + "_stderr"]
                values.append([k, version, m, "%.4f" % v, "±", "%.4f" % se])
            else:
                values.append([k, version, m, "%.4f" % v, "", ""])
        k = ""
        version = ""
    md_writer.value_matrix = values
......
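For context on the evaluator entry point whose signature was reflowed above, a minimal sketch of a typical call; the model string and task names are illustrative placeholders, not part of this commit.

# Hedged example: driving the harness through simple_evaluate and printing the results table.
import lm_eval.evaluator

results = lm_eval.evaluator.simple_evaluate(
    model="gpt2",                    # resolved through lm_eval.models.get_model
    model_args="pretrained=gpt2",    # parsed by create_from_arg_string
    tasks=["lambada"],               # keys of lm_eval.tasks.TASK_REGISTRY
    num_fewshot=0,
    batch_size=1,
    no_cache=True,                   # skip the CachingLM sqlite wrapper
)
print(lm_eval.evaluator.make_table(results))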
@@ -103,6 +103,7 @@ def weighted_mean(items):
def weighted_perplexity(items):
    return math.exp(-weighted_mean(items))


def bits_per_byte(items):
    return -weighted_mean(items) / math.log(2)
@@ -184,8 +185,10 @@ def _sacreformat(refs, preds):
    return refs, preds


# stderr stuff


class _bootstrap_internal:
    def __init__(self, f, n):
        self.f = f
@@ -203,6 +206,7 @@ class _bootstrap_internal:
def bootstrap_stderr(f, xs, iters):
    import multiprocessing as mp

    pool = mp.Pool(mp.cpu_count())

    # this gives a biased estimate of the stderr (i.e w/ the mean, it gives something
    # equivalent to stderr calculated without Bessel's correction in the stddev.
@@ -213,10 +217,15 @@ def bootstrap_stderr(f, xs, iters):
    res = []
    chunk_size = min(1000, iters)

    from tqdm import tqdm

    print("bootstrapping for stddev:", f.__name__)
    for bootstrap in tqdm(
        pool.imap(
            _bootstrap_internal(f, chunk_size),
            [(i, xs) for i in range(iters // chunk_size)],
        ),
        total=iters // chunk_size,
    ):
        # sample w replacement
        res.extend(bootstrap)
@@ -238,17 +247,13 @@ def stderr_for_metric(metric, bootstrap_iters):
    if metric in bootstrappable:
        return lambda x: bootstrap_stderr(metric, x, iters=bootstrap_iters)

    stderr = {mean: mean_stderr, acc_all: acc_all_stderr}

    return stderr.get(metric, None)


def yesno(x):
    if x:
        return "yes"
    else:
        return "no"
@@ -4,8 +4,15 @@ from lm_eval.base import BaseLM
class HFLM(BaseLM):
    def __init__(
        self,
        device="cuda",
        pretrained="gpt2",
        revision="main",
        subfolder=None,
        tokenizer=None,
        batch_size=1,
    ):
        super().__init__()

        assert isinstance(device, str)
@@ -18,30 +25,49 @@ class HFLM(BaseLM):
            self._device = torch.device(device)
            print(f"Using device '{device}'")
        else:
            print("Device not specified")
            print(f"Cuda Available? {torch.cuda.is_available()}")
            self._device = (
                torch.device("cuda")
                if torch.cuda.is_available()
                else torch.device("cpu")
            )

        # TODO: update this to be less of a hack once subfolder is fixed in HF
        self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
            pretrained,
            revision=revision + ("/" + subfolder if subfolder is not None else ""),
        ).to(self.device)
        self.gpt2.eval()

        # pretrained tokenizer for neo is broken for now so just hard-coding this to gpt2
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            pretrained if tokenizer is None else tokenizer,
            revision=revision,
            subfolder=subfolder,
        )

        assert isinstance(
            self.tokenizer,
            (
                transformers.GPT2Tokenizer,
                transformers.GPT2TokenizerFast,
                transformers.T5Tokenizer,
                transformers.T5TokenizerFast,
            ),
        ), "this tokenizer has not been checked for compatibility yet!"

        self.vocab_size = self.tokenizer.vocab_size

        if isinstance(
            self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)
        ):
            assert self.tokenizer.encode("hello\n\nhello") == [
                31373,
                198,
                198,
                31373,
            ], self.tokenizer.encode("hello\n\nhello")

        # multithreading and batching
        self.batch_size_per_gpu = batch_size  # todo: adaptive batch size
@@ -97,10 +123,7 @@ class HFLM(BaseLM):
    def _model_generate(self, context, max_length, eos_token_id):
        return self.gpt2.generate(
            context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False
        )
......
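A short, hedged sketch of instantiating the HFLM wrapper whose constructor was reformatted above; the checkpoint name, device, and example pair are placeholders.

# Hedged example: using HFLM through the BaseLM interface.
from lm_eval.models.gpt2 import HFLM

lm = HFLM(pretrained="gpt2", device="cpu", batch_size=1)
# loglikelihood takes (context, continuation) string pairs and returns (logprob, is_greedy) tuples.
print(lm.loglikelihood([("The capital of France is", " Paris")]))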
@@ -36,17 +36,19 @@ def get_result(response, ctxlen):
def oa_completion(**kwargs):
    """Query OpenAI API for completion.
    Retry with back-off until they respond
    """
    import openai

    backoff_time = 3
    while True:
        try:
            return openai.Completion.create(**kwargs)
        except openai.error.OpenAIError:
            import traceback

            traceback.print_exc()
            time.sleep(backoff_time)
            backoff_time *= 1.5
@@ -66,16 +68,19 @@ class GPT3LM(BaseLM):
        super().__init__()

        import openai

        self.engine = engine
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")
        self.vocab_size = self.tokenizer.vocab_size

        # to make the annoying "Using pad_token, but it is not set yet." error go away
        self.tokenizer.pad_token = "<|endoftext|>"
        assert self.tokenizer.encode("hello\n\nhello") == [31373, 198, 198, 31373]
        self.truncate = truncate
        self.end_of_text_token_id = self.tokenizer.convert_tokens_to_ids(
            ["<|endoftext|>"]
        )[0]

        # Read from environment variable OPENAI_API_SECRET_KEY
        openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]
@@ -119,16 +124,21 @@ class GPT3LM(BaseLM):
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        re_ord = utils.Reorderer(requests, _collate)

        for chunk in tqdm(
            list(utils.chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE)),
            disable=disable_tqdm,
        ):
            inps = []
            ctxlens = []
            for cache_key, context_enc, continuation_enc in chunk:
                # max_length+1 because the API takes up to 2049 tokens, including the first context token
                inp = (context_enc + continuation_enc)[-(self.max_length + 1) :]
                # TODO: the logic is much simpler if we just look at the length of continuation tokens
                ctxlen = len(context_enc) - max(
                    0, len(context_enc) + len(continuation_enc) - (self.max_length + 1)
                )

                inps.append(inp)
                ctxlens.append(ctxlen)
@@ -137,11 +147,14 @@ class GPT3LM(BaseLM):
                engine=self.engine,
                prompt=inps,
                echo=True,
                max_tokens=0,
                temperature=0.0,
                logprobs=10,
            )

            for resp, ctxlen, (cache_key, context_enc, continuation_enc) in zip(
                response.choices, ctxlens, chunk
            ):
                answer = get_result(resp, ctxlen)
                res.append(answer)
@@ -150,7 +163,7 @@ class GPT3LM(BaseLM):
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)

        return re_ord.get_original(res)

    def greedy_until(self, requests):
        if not requests:
@@ -161,7 +174,7 @@ class GPT3LM(BaseLM):
            toks = self.tok_encode(x[0])
            return len(toks), x[0]

        re_ord = utils.Reorderer(requests, _collate)

        def sameuntil_chunks(xs, size):
            ret = []
@@ -177,24 +190,26 @@ class GPT3LM(BaseLM):
            yield ret, lastuntil

        # todo: more intelligent batching for heterogeneous `until`
        for chunk, until in tqdm(
            list(sameuntil_chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE))
        ):
            inps = []
            for context, _ in chunk:
                context_enc = self.tok_encode(context)
                inp = context_enc[-(self.max_length - self.max_gen_toks) :]
                inps.append(inp)

            response = oa_completion(
                engine=self.engine,
                prompt=inps,
                max_tokens=self.max_gen_toks,
                temperature=0.0,
                logprobs=10,
                stop=until,
            )

            for resp, (context, until_) in zip(response.choices, chunk):
                s = resp["text"]

                for term in until_:
                    s = s.split(term)[0]
@@ -204,7 +219,7 @@ class GPT3LM(BaseLM):
                res.append(s)

        return re_ord.get_original(res)

    def _model_call(self, inps):
        # Isn't used because we override _loglikelihood_tokens
......
@@ -22,14 +22,12 @@ from . import naturalqs
from . import sat
from . import arithmetic
from . import lambada
from . import piqa
from . import prost
from . import mc_taco
from . import triviaqa
from . import pubmedqa
from . import sciq
from . import qasper
from . import qa4mre
from . import translation
@@ -59,8 +57,8 @@ from . import storycloze
# 6 total
gpt3_translation_benchmarks = {
    "wmt14": ["en-fr", "fr-en"],  # French
    "wmt16": ["en-ro", "ro-en", "de-en", "en-de"],  # German, Romanian
}
@@ -68,7 +66,7 @@ gpt3_translation_benchmarks = {
selected_translation_benchmarks = {
    **gpt3_translation_benchmarks,
    "wmt20": sacrebleu.get_langpairs_for_testset("wmt20"),
    "iwslt17": ["en-ar", "ar-en"],  # Arabic
}

# 319 total
@@ -92,7 +90,7 @@ TASK_REGISTRY = {
    "rte": glue.RTE,
    "qnli": glue.QNLI,
    "qqp": glue.QQP,
    # "stsb": glue.STSB, # not implemented yet
    "sst": glue.SST,
    "wnli": glue.WNLI,
    # SuperGLUE
@@ -103,34 +101,26 @@ TASK_REGISTRY = {
    "record": superglue.ReCoRD,
    "wic": superglue.WordsInContext,
    "wsc": superglue.SGWinogradSchemaChallenge,
    # Order by benchmark/genre?
    "coqa": coqa.CoQA,
    "drop": drop.DROP,
    "lambada": lambada.LAMBADA,
    "lambada_cloze": lambada_cloze.LAMBADA_cloze,
    # multilingual lambada
    **lambada_multilingual.construct_tasks(),
    "wikitext": wikitext.WikiText,
    # "cbt-cn": cbt.CBTCN, # disabled pending context length fix
    # "cbt-ne": cbt.CBTNE, # disabled pending context length fix
    "piqa": piqa.PiQA,
    "prost": prost.PROST,
    "mc_taco": mc_taco.MCTACO,
    # Science related
    "pubmedqa": pubmedqa.Pubmed_QA,
    "sciq": sciq.SciQ,
    "qasper": qasper.QASPER,
    "qa4mre_2011": qa4mre.QA4MRE_2011,
    "qa4mre_2012": qa4mre.QA4MRE_2012,
    "qa4mre_2013": qa4mre.QA4MRE_2013,
    "triviaqa": triviaqa.TriviaQA,
    "arc_easy": arc.ARCEasy,
    "arc_challenge": arc.ARCChallenge,
@@ -152,21 +142,17 @@ TASK_REGISTRY = {
    "anli_r1": anli.ANLIRound1,
    "anli_r2": anli.ANLIRound2,
    "anli_r3": anli.ANLIRound3,
    "ethics_cm": hendrycks_ethics.EthicsCM,
    "ethics_deontology": hendrycks_ethics.EthicsDeontology,
    "ethics_justice": hendrycks_ethics.EthicsJustice,
    "ethics_utilitarianism_original": hendrycks_ethics.EthicsUtilitarianismOriginal,
    "ethics_utilitarianism": hendrycks_ethics.EthicsUtilitarianism,
    "ethics_virtue": hendrycks_ethics.EthicsVirtue,
    "truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice,
    "truthfulqa_gen": truthfulqa.TruthfulQAGeneration,
    # dialogue
    "mutual": mutual.MuTual,
    "mutual_plus": mutual.MuTualPlus,
    # math
    "math_algebra": hendrycks_math.MathAlgebra,
    "math_counting_and_prob": hendrycks_math.MathCountingAndProbability,
@@ -177,7 +163,6 @@ TASK_REGISTRY = {
    "math_precalc": hendrycks_math.MathPrecalculus,
    "math_asdiv": asdiv.Asdiv,
    "gsm8k": gsm8k.GradeSchoolMath8K,
    # arithmetic
    "arithmetic_2da": arithmetic.Arithmetic2DPlus,
    "arithmetic_2ds": arithmetic.Arithmetic2DMinus,
@@ -191,22 +176,18 @@ TASK_REGISTRY = {
    "arithmetic_1dc": arithmetic.Arithmetic1DComposite,
    # TODO Perhaps make these groups of tasks
    # e.g. anli, arithmetic, openai_translations, harness_translations
    # hendrycksTest (57 tasks)
    **hendrycks_test.create_all_tasks(),
    # e.g. wmt14-fr-en
    **translation.create_tasks_from_benchmarks(gpt3_translation_benchmarks),
    # chef's selection, mostly wmt20
    **translation.create_tasks_from_benchmarks(selected_translation_benchmarks),
    # Word Scrambling and Manipulation Tasks
    "anagrams1": unscramble.Anagrams1,
    "anagrams2": unscramble.Anagrams2,
    "cycle_letters": unscramble.CycleLetters,
    "random_insertion": unscramble.RandomInsertion,
    "reversed_words": unscramble.ReversedWords,
    # Pile
    "pile_arxiv": pile.PileArxiv,
    "pile_books3": pile.PileBooks3,
@@ -230,7 +211,6 @@ TASK_REGISTRY = {
    "pile_ubuntu-irc": pile.PileUbuntuIrc,
    "pile_wikipedia": pile.PileWikipedia,
    "pile_youtubesubtitles": pile.PileYoutubeSubtitles,
    # BLiMP
    "blimp_adjunct_island": blimp.BlimpAdjunctIsland,
"blimp_anaphor_gender_agreement": blimp.BlimpAnaphorGenderAgreement, "blimp_anaphor_gender_agreement": blimp.BlimpAnaphorGenderAgreement,
...@@ -299,7 +279,6 @@ TASK_REGISTRY = { ...@@ -299,7 +279,6 @@ TASK_REGISTRY = {
"blimp_wh_vs_that_no_gap_long_distance": blimp.BlimpWhVsThatNoGapLongDistance, "blimp_wh_vs_that_no_gap_long_distance": blimp.BlimpWhVsThatNoGapLongDistance,
"blimp_wh_vs_that_with_gap": blimp.BlimpWhVsThatWithGap, "blimp_wh_vs_that_with_gap": blimp.BlimpWhVsThatWithGap,
"blimp_wh_vs_that_with_gap_long_distance": blimp.BlimpWhVsThatWithGapLongDistance, "blimp_wh_vs_that_with_gap_long_distance": blimp.BlimpWhVsThatWithGapLongDistance,
# Requires manual download of data. # Requires manual download of data.
# "storycloze_2016": storycloze.StoryCloze2016, # "storycloze_2016": storycloze.StoryCloze2016,
# "storycloze_2018": storycloze.StoryCloze2018, # "storycloze_2018": storycloze.StoryCloze2018,
...@@ -313,7 +292,7 @@ ALL_TASKS = sorted(list(TASK_REGISTRY)) ...@@ -313,7 +292,7 @@ ALL_TASKS = sorted(list(TASK_REGISTRY))
def get_task(task_name): def get_task(task_name):
try: try:
return TASK_REGISTRY[task_name] return TASK_REGISTRY[task_name]
except KeyError as e: except KeyError:
print("Available tasks:") print("Available tasks:")
pprint(TASK_REGISTRY) pprint(TASK_REGISTRY)
raise KeyError(f"Missing task {task_name}") raise KeyError(f"Missing task {task_name}")
...@@ -325,17 +304,23 @@ def get_task_name_from_object(task_object): ...@@ -325,17 +304,23 @@ def get_task_name_from_object(task_object):
return name return name
# this gives a mechanism for non-registered tasks to have a custom name anyways when reporting # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
return task_object.EVAL_HARNESS_NAME if hasattr(task_object, "EVAL_HARNESS_NAME") else type(task_object).__name__ return (
task_object.EVAL_HARNESS_NAME
if hasattr(task_object, "EVAL_HARNESS_NAME")
else type(task_object).__name__
)
def get_task_dict(task_name_list: List[Union[str, lm_eval.base.Task]]): def get_task_dict(task_name_list: List[Union[str, lm_eval.base.Task]]):
task_name_dict = { task_name_dict = {
task_name: get_task(task_name)() task_name: get_task(task_name)()
for task_name in task_name_list if isinstance(task_name, str) for task_name in task_name_list
if isinstance(task_name, str)
} }
task_name_from_object_dict = { task_name_from_object_dict = {
get_task_name_from_object(task_object): task_object get_task_name_from_object(task_object): task_object
for task_object in task_name_list if not isinstance(task_object, str) for task_object in task_name_list
if not isinstance(task_object, str)
} }
assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys())) assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys()))
return {**task_name_dict, **task_name_from_object_dict} return {**task_name_dict, **task_name_from_object_dict}
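For reference, a minimal usage sketch of the two lookup helpers defined above (assuming the lm_eval package is installed and the requested names exist in TASK_REGISTRY):

from lm_eval import tasks

# get_task returns the Task class registered under a name; get_task_dict
# accepts a mix of registered names and pre-built Task objects and merges
# them into a single name -> task-instance mapping.
task_dict = tasks.get_task_dict(["lambada", "arc_easy"])
for name, task in task_dict.items():
    print(name, type(task).__name__)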
...@@ -64,7 +64,12 @@ class ANLIBase(Task): ...@@ -64,7 +64,12 @@ class ANLIBase(Task):
# of the prompt (yes, repeating it!). also, " True, False, or Neither?" is directly # of the prompt (yes, repeating it!). also, " True, False, or Neither?" is directly
# appended onto the question, with no "Answer:" or even a newline. Do we *really* # appended onto the question, with no "Answer:" or even a newline. Do we *really*
# want to do it exactly as OA did? # want to do it exactly as OA did?
return doc['premise'] + '\nQuestion: ' + doc['hypothesis'] + ' True, False, or Neither?\nAnswer:' return (
doc["premise"]
+ "\nQuestion: "
+ doc["hypothesis"]
+ " True, False, or Neither?\nAnswer:"
)
def should_decontaminate(self): def should_decontaminate(self):
return True return True
...@@ -76,10 +81,10 @@ class ANLIBase(Task): ...@@ -76,10 +81,10 @@ class ANLIBase(Task):
# True = entailment # True = entailment
# False = contradiction # False = contradiction
# Neither = neutral # Neither = neutral
return " " + ["True", "Neither", "False"][doc['label']] return " " + ["True", "Neither", "False"][doc["label"]]
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of """Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM. Requests which will be sent to the LM.
:param doc: :param doc:
...@@ -106,9 +111,7 @@ class ANLIBase(Task): ...@@ -106,9 +111,7 @@ class ANLIBase(Task):
""" """
gold = doc["label"] gold = doc["label"]
pred = np.argmax(results) pred = np.argmax(results)
return { return {"acc": pred == gold}
"acc": pred == gold
}
def aggregation(self): def aggregation(self):
""" """
...@@ -116,9 +119,7 @@ class ANLIBase(Task): ...@@ -116,9 +119,7 @@ class ANLIBase(Task):
A dictionary where keys are the names of submetrics and values are A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics functions that aggregate a list of metrics
""" """
return { return {"acc": mean}
"acc": mean
}
def higher_is_better(self): def higher_is_better(self):
""" """
...@@ -126,9 +127,7 @@ class ANLIBase(Task): ...@@ -126,9 +127,7 @@ class ANLIBase(Task):
A dictionary where keys are the names of submetrics and values are A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better whether a higher value of the submetric is better
""" """
return { return {"acc": True}
"acc": True
}
class ANLIRound1(ANLIBase): class ANLIRound1(ANLIBase):
......
...@@ -67,10 +67,8 @@ class Arithmetic(Task): ...@@ -67,10 +67,8 @@ class Arithmetic(Task):
return is_prediction return is_prediction
def process_results(self, doc, results): def process_results(self, doc, results):
is_prediction, = results (is_prediction,) = results
return { return {"acc": is_prediction}
"acc": is_prediction
}
def aggregation(self): def aggregation(self):
return { return {
...@@ -78,9 +76,7 @@ class Arithmetic(Task): ...@@ -78,9 +76,7 @@ class Arithmetic(Task):
} }
def higher_is_better(self): def higher_is_better(self):
return { return {"acc": True}
"acc": True
}
class Arithmetic2DPlus(Arithmetic): class Arithmetic2DPlus(Arithmetic):
......
...@@ -54,29 +54,28 @@ class Asdiv(Task): ...@@ -54,29 +54,28 @@ class Asdiv(Task):
def test_docs(self): def test_docs(self):
raise NotImplementedError("This dataset has no test docs") raise NotImplementedError("This dataset has no test docs")
def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None): def fewshot_context(
self, doc, num_fewshot, provide_description=None, rnd=None, description=None
):
assert num_fewshot == 0, "ASDiv is intended only for the zero-shot setting." assert num_fewshot == 0, "ASDiv is intended only for the zero-shot setting."
return super().fewshot_context( return super().fewshot_context(
doc=doc, doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
num_fewshot=num_fewshot,
rnd=rnd,
description=description
) )
def doc_to_text(self, doc): def doc_to_text(self, doc):
# TODO: add solution-type # TODO: add solution-type
return doc['body'] + '\n' + 'Question:' + doc['question'] + '\n' + 'Answer:' return doc["body"] + "\n" + "Question:" + doc["question"] + "\n" + "Answer:"
def should_decontaminate(self): def should_decontaminate(self):
return True return True
def doc_to_decontamination_query(self, doc): def doc_to_decontamination_query(self, doc):
return doc['body'] + " " + doc['question'] return doc["body"] + " " + doc["question"]
def doc_to_target(self, doc): def doc_to_target(self, doc):
# TODO: add formula # TODO: add formula
answer = doc['answer'].split(' (')[0] answer = doc["answer"].split(" (")[0]
return " " + answer return " " + answer
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
...@@ -86,16 +85,10 @@ class Asdiv(Task): ...@@ -86,16 +85,10 @@ class Asdiv(Task):
def process_results(self, doc, results): def process_results(self, doc, results):
ll, is_greedy = results ll, is_greedy = results
return { return {"acc": int(is_greedy)}
'acc': int(is_greedy)
}
def aggregation(self): def aggregation(self):
return { return {"acc": mean}
'acc': mean
}
def higher_is_better(self): def higher_is_better(self):
return { return {"acc": True}
'acc': True
}
...@@ -28,7 +28,7 @@ _CITATION = """ ...@@ -28,7 +28,7 @@ _CITATION = """
eprint = {https://doi.org/10.1162/tacl_a_00321}, eprint = {https://doi.org/10.1162/tacl_a_00321},
abstract = { We introduce The Benchmark of Linguistic Minimal Pairs (BLiMP),1 a challenge set for evaluating the linguistic knowledge of language models (LMs) on major grammatical phenomena in English. BLiMP consists of 67 individual datasets, each containing 1,000 minimal pairs—that is, pairs of minimally different sentences that contrast in grammatical acceptability and isolate specific phenomenon in syntax, morphology, or semantics. We generate the data according to linguist-crafted grammar templates, and human aggregate agreement with the labels is 96.4\%. We evaluate n-gram, LSTM, and Transformer (GPT-2 and Transformer-XL) LMs by observing whether they assign a higher probability to the acceptable sentence in each minimal pair. We find that state-of-the-art models identify morphological contrasts related to agreement reliably, but they struggle with some subtle semantic and syntactic phenomena, such as negative polarity items and extraction islands. } abstract = { We introduce The Benchmark of Linguistic Minimal Pairs (BLiMP),1 a challenge set for evaluating the linguistic knowledge of language models (LMs) on major grammatical phenomena in English. BLiMP consists of 67 individual datasets, each containing 1,000 minimal pairs—that is, pairs of minimally different sentences that contrast in grammatical acceptability and isolate specific phenomenon in syntax, morphology, or semantics. We generate the data according to linguist-crafted grammar templates, and human aggregate agreement with the labels is 96.4\%. We evaluate n-gram, LSTM, and Transformer (GPT-2 and Transformer-XL) LMs by observing whether they assign a higher probability to the acceptable sentence in each minimal pair. We find that state-of-the-art models identify morphological contrasts related to agreement reliably, but they struggle with some subtle semantic and syntactic phenomena, such as negative polarity items and extraction islands. }
} }
""" """ # noqa: W605
class BlimpTask(Task): class BlimpTask(Task):
...@@ -50,9 +50,13 @@ class BlimpTask(Task): ...@@ -50,9 +50,13 @@ class BlimpTask(Task):
# trained on this data. # trained on this data.
return self.dataset["train"] return self.dataset["train"]
def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None): def fewshot_context(
self, doc, num_fewshot, provide_description=None, rnd=None, description=None
):
assert num_fewshot == 0 assert num_fewshot == 0
assert rnd is not None, "A `random.Random` generator argument must be provided to `rnd`" assert (
rnd is not None
), "A `random.Random` generator argument must be provided to `rnd`"
assert not provide_description, ( assert not provide_description, (
"The `provide_description` arg will be removed in future versions. To prepend " "The `provide_description` arg will be removed in future versions. To prepend "
"a custom description to the context, supply the corresponding string via the " "a custom description to the context, supply the corresponding string via the "
...@@ -60,7 +64,9 @@ class BlimpTask(Task): ...@@ -60,7 +64,9 @@ class BlimpTask(Task):
) )
if provide_description is not None: if provide_description is not None:
# nudge people to not specify it at all # nudge people to not specify it at all
print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict") print(
"WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
)
return "" return ""
......
...@@ -86,7 +86,9 @@ class CBTBase(Task): ...@@ -86,7 +86,9 @@ class CBTBase(Task):
return "" return ""
def fewshot_examples(self, k, rnd): def fewshot_examples(self, k, rnd):
assert k == 0, f"CBT is only implemented for the zero-shot setting. Given k={k}." assert (
k == 0
), f"CBT is only implemented for the zero-shot setting. Given k={k}."
return super().fewshot_examples(k, rnd) return super().fewshot_examples(k, rnd)
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
...@@ -120,9 +122,7 @@ class CBTBase(Task): ...@@ -120,9 +122,7 @@ class CBTBase(Task):
""" """
gold = doc["options"].index(doc["answer"]) gold = doc["options"].index(doc["answer"])
pred = np.argmax(results) pred = np.argmax(results)
return { return {"acc": pred == gold}
"acc": pred == gold
}
def aggregation(self): def aggregation(self):
""" """
...@@ -130,9 +130,7 @@ class CBTBase(Task): ...@@ -130,9 +130,7 @@ class CBTBase(Task):
A dictionary where keys are the names of submetrics and values are A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics functions that aggregate a list of metrics
""" """
return { return {"acc": mean}
"acc": mean
}
def higher_is_better(self): def higher_is_better(self):
""" """
...@@ -140,9 +138,7 @@ class CBTBase(Task): ...@@ -140,9 +138,7 @@ class CBTBase(Task):
A dictionary where keys are the names of submetrics and values are A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better whether a higher value of the submetric is better
""" """
return { return {"acc": True}
"acc": True
}
class CBTCN(CBTBase): class CBTCN(CBTBase):
......
...@@ -54,8 +54,10 @@ class CoQA(Task): ...@@ -54,8 +54,10 @@ class CoQA(Task):
def doc_to_text(self, doc): def doc_to_text(self, doc):
# Given a passage p, the conversation history {q1, a1, . . . qi−1, ai−1} # Given a passage p, the conversation history {q1, a1, . . . qi−1, ai−1}
# and a question qi, the task is to predict the answer ai # and a question qi, the task is to predict the answer ai
doc_text = doc["story"] + '\n\n' doc_text = doc["story"] + "\n\n"
for (q, a) in zip_longest(doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]): # omit target answer ai for (q, a) in zip_longest(
doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]
): # omit target answer ai
question = f"Q: {q}\n\n" question = f"Q: {q}\n\n"
answer = f"A: {a}\n\n" if a is not None else "A:" answer = f"A: {a}\n\n" if a is not None else "A:"
doc_text += question + answer doc_text += question + answer
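A minimal sketch of the context this loop produces, using a fabricated two-turn document; zip_longest pads the final turn's answer with None, so the prompt ends in a bare "A:" for the model to complete:

from itertools import zip_longest

doc = {
    "story": "Maya adopted a kitten last spring.",
    "questions": {"input_text": ["Who adopted a kitten?", "When was it adopted?"]},
    "answers": {"input_text": ["Maya", "last spring"]},
}
doc_text = doc["story"] + "\n\n"
for q, a in zip_longest(doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]):
    question = f"Q: {q}\n\n"
    answer = f"A: {a}\n\n" if a is not None else "A:"
    doc_text += question + answer
# doc_text now ends with "Q: When was it adopted?\n\nA:"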
...@@ -77,7 +79,9 @@ class CoQA(Task): ...@@ -77,7 +79,9 @@ class CoQA(Task):
additional_answers = doc.get("additional_answers") additional_answers = doc.get("additional_answers")
if additional_answers: if additional_answers:
for key in additional_answers: for key in additional_answers:
additional_answer_for_turn = additional_answers[key]["input_text"][turn_id - 1] additional_answer_for_turn = additional_answers[key]["input_text"][
turn_id - 1
]
if additional_answer_for_turn.lower() not in map(str.lower, answers): if additional_answer_for_turn.lower() not in map(str.lower, answers):
answers.append(additional_answer_for_turn) answers.append(additional_answer_for_turn)
return answers return answers
...@@ -89,12 +93,12 @@ class CoQA(Task): ...@@ -89,12 +93,12 @@ class CoQA(Task):
# ~ 2/3 of the CoQA answers are span-based # ~ 2/3 of the CoQA answers are span-based
# (answers overlap with the passage ignoring punctuation and case mismatch) # (answers overlap with the passage ignoring punctuation and case mismatch)
if raw_text == "unknown": if raw_text == "unknown":
return '0' return "0"
if squad_metrics.normalize_answer(raw_text) == "yes": if squad_metrics.normalize_answer(raw_text) == "yes":
return '1' return "1"
if squad_metrics.normalize_answer(raw_text) == "no": if squad_metrics.normalize_answer(raw_text) == "no":
return '2' return "2"
return '3' # Not a yes/no question return "3" # Not a yes/no question
@staticmethod @staticmethod
def compute_scores(gold_list, pred): def compute_scores(gold_list, pred):
...@@ -104,25 +108,30 @@ class CoQA(Task): ...@@ -104,25 +108,30 @@ class CoQA(Task):
em_sum = 0.0 em_sum = 0.0
if len(gold_list) > 1: if len(gold_list) > 1:
for i in range(len(gold_list)): for i in range(len(gold_list)):
gold_answers = gold_list[0:i] + gold_list[i + 1:] gold_answers = gold_list[0:i] + gold_list[i + 1 :]
# predictions compared against (n) golds and take maximum # predictions compared against (n) golds and take maximum
em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_answers) em_sum += max(
squad_metrics.compute_exact(a, pred) for a in gold_answers
)
f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers) f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers)
else: else:
em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list) em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list)
f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_list) f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_list)
return {'em': em_sum / max(1, len(gold_list)), 'f1': f1_sum / max(1, len(gold_list))} return {
"em": em_sum / max(1, len(gold_list)),
"f1": f1_sum / max(1, len(gold_list)),
}
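A usage sketch of compute_scores above (inputs made up): with more than one gold answer, the prediction is scored against every leave-one-out subset of golds and the best exact-match/F1 per subset is averaged, so paraphrases among the golds are not penalized:

from lm_eval.tasks.coqa import CoQA  # module path as in this repository

gold_list = ["a kitten", "the kitten", "kitten"]
pred = "kitten"
scores = CoQA.compute_scores(gold_list, pred)
# SQuAD normalization strips articles, so this comes out around {'em': 1.0, 'f1': 1.0}
print(scores)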
def doc_to_target(self, doc, turnid=None): def doc_to_target(self, doc, turnid=None):
# Default to prediction of last turn. # Default to prediction of last turn.
if turnid is None: if turnid is None:
turnid = len(doc["questions"]["input_text"]) turnid = len(doc["questions"]["input_text"])
raw_text = doc['answers']["input_text"][turnid - 1] raw_text = doc["answers"]["input_text"][turnid - 1]
return " " + raw_text return " " + raw_text
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of """Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM. Requests which will be sent to the LM.
:param doc: :param doc:
...@@ -132,7 +141,7 @@ class CoQA(Task): ...@@ -132,7 +141,7 @@ class CoQA(Task):
language description, as well as the few shot examples, and the question language description, as well as the few shot examples, and the question
part of the document for `doc`. part of the document for `doc`.
""" """
cont_request = rf.greedy_until(ctx, ['\nQ:']) cont_request = rf.greedy_until(ctx, ["\nQ:"])
return cont_request return cont_request
def process_results(self, doc, results): def process_results(self, doc, results):
...@@ -147,13 +156,13 @@ class CoQA(Task): ...@@ -147,13 +156,13 @@ class CoQA(Task):
""" """
turn_id = len(doc["questions"]["input_text"]) turn_id = len(doc["questions"]["input_text"])
gold_list = self.get_answers(doc, turn_id) gold_list = self.get_answers(doc, turn_id)
pred = results[0].strip().split('\n')[0] pred = results[0].strip().split("\n")[0]
scores = self.compute_scores(gold_list, pred) scores = self.compute_scores(gold_list, pred)
return { return {
"f1": scores['f1'], "f1": scores["f1"],
"em": scores['em'], "em": scores["em"],
} }
def higher_is_better(self): def higher_is_better(self):
......