Unverified Commit a2cada5d authored by Jonathan Tow's avatar Jonathan Tow Committed by GitHub
Browse files

Merge pull request #317 from EleutherAI/Mistobaan/add-pre-commit

Add pre-commit
parents 7a038118 83507c4b
......@@ -65,13 +65,14 @@ class TruthfulqaConfig(datasets.BuilderConfig):
class Truthfulqa(datasets.GeneratorBasedBuilder):
"""TruthfulQA is a benchmark to measure whether a language model is truthful in
generating answers to questions."""
generating answers to questions."""
BUILDER_CONFIGS = [
TruthfulqaConfig(
name="multiple_choice",
url="https://raw.githubusercontent.com/sylinrl/TruthfulQA/013686a06be7a7bde5bf8223943e106c7250123c/data/mc_task.json",
features=datasets.Features({
features=datasets.Features(
{
"question": datasets.Value("string"),
"mc1_targets": {
"choices": datasets.features.Sequence(datasets.Value("string")),
......@@ -80,23 +81,30 @@ generating answers to questions."""
"mc2_targets": {
"choices": datasets.features.Sequence(datasets.Value("string")),
"labels": datasets.features.Sequence(datasets.Value("int32")),
},
}
}),
description="The multiple choice TruthfulQA task"
),
description="The multiple choice TruthfulQA task",
),
TruthfulqaConfig(
name="generation",
url="https://raw.githubusercontent.com/sylinrl/TruthfulQA/013686a06be7a7bde5bf8223943e106c7250123c/TruthfulQA.csv",
features=datasets.Features({
features=datasets.Features(
{
"category": datasets.Value("string"),
"question": datasets.Value("string"),
"best_answer": datasets.Value("string"),
"correct_answers": datasets.features.Sequence(datasets.Value("string")),
"incorrect_answers": datasets.features.Sequence(datasets.Value("string")),
"correct_answers": datasets.features.Sequence(
datasets.Value("string")
),
"incorrect_answers": datasets.features.Sequence(
datasets.Value("string")
),
"source": datasets.Value("string"),
}),
description="The generative TruthfulQA task"
)
}
),
description="The generative TruthfulQA task",
),
]
def _info(self):
......@@ -138,15 +146,15 @@ generating answers to questions."""
"mc2_targets": {
"choices": row["mc2_targets"].keys(),
"labels": row["mc2_targets"].values(),
}
},
}
else:
# Generation data is in a `CSV` file.
with open(filepath, newline='') as f:
with open(filepath, newline="") as f:
contents = csv.DictReader(f)
for key, row in enumerate(contents):
# Ensure that references exist.
if not row['Correct Answers'] or not row['Incorrect Answers']:
if not row["Correct Answers"] or not row["Incorrect Answers"]:
continue
yield key, {
"category": row["Category"],
......@@ -154,6 +162,8 @@ generating answers to questions."""
"best_answer": row["Best Answer"],
# split on ";"
"correct_answers": row["Correct Answers"].strip().split(";"),
"incorrect_answers": row["Incorrect Answers"].strip().split(";"),
"incorrect_answers": row["Incorrect Answers"]
.strip()
.split(";"),
"source": row["Source"],
}
......@@ -64,8 +64,9 @@ class Unscramble(datasets.GeneratorBasedBuilder):
VERSION = datasets.Version("0.0.1")
BUILDER_CONFIGS = [
datasets.BuilderConfig(name=name, version=version,
description=_DESCRIPTIONS[name])
datasets.BuilderConfig(
name=name, version=version, description=_DESCRIPTIONS[name]
)
for name, version in zip(_NAMES, [VERSION] * len(_NAMES))
]
......
......@@ -123,86 +123,111 @@ class Wikitext(datasets.GeneratorBasedBuilder):
return [
datasets.SplitGenerator(
name=datasets.Split.TEST,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.test.tokens"), "split": "test"},
gen_kwargs={
"data_file": os.path.join(data_dir, "wiki.test.tokens"),
"split": "test",
},
),
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.train.tokens"), "split": "train"},
gen_kwargs={
"data_file": os.path.join(data_dir, "wiki.train.tokens"),
"split": "train",
},
),
datasets.SplitGenerator(
name=datasets.Split.VALIDATION,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.valid.tokens"), "split": "valid"},
gen_kwargs={
"data_file": os.path.join(data_dir, "wiki.valid.tokens"),
"split": "valid",
},
),
]
else:
if self.config.name == "wikitext-103-raw-v1":
data_file = dl_manager.download_and_extract(
self.config.data_url)
data_file = dl_manager.download_and_extract(self.config.data_url)
data_dir = os.path.join(data_file, "wikitext-103-raw")
return [
datasets.SplitGenerator(
name=datasets.Split.TEST,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.test.raw"), "split": "test"},
gen_kwargs={
"data_file": os.path.join(data_dir, "wiki.test.raw"),
"split": "test",
},
),
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.train.raw"), "split": "train"},
gen_kwargs={
"data_file": os.path.join(data_dir, "wiki.train.raw"),
"split": "train",
},
),
datasets.SplitGenerator(
name=datasets.Split.VALIDATION,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.valid.raw"), "split": "valid"},
gen_kwargs={
"data_file": os.path.join(data_dir, "wiki.valid.raw"),
"split": "valid",
},
),
]
else:
if self.config.name == "wikitext-2-raw-v1":
data_file = dl_manager.download_and_extract(
self.config.data_url)
data_file = dl_manager.download_and_extract(self.config.data_url)
data_dir = os.path.join(data_file, "wikitext-2-raw")
return [
datasets.SplitGenerator(
name=datasets.Split.TEST,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.test.raw"), "split": "test"},
gen_kwargs={
"data_file": os.path.join(data_dir, "wiki.test.raw"),
"split": "test",
},
),
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.train.raw"), "split": "train"},
gen_kwargs={
"data_file": os.path.join(data_dir, "wiki.train.raw"),
"split": "train",
},
),
datasets.SplitGenerator(
name=datasets.Split.VALIDATION,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.valid.raw"), "split": "valid"},
gen_kwargs={
"data_file": os.path.join(data_dir, "wiki.valid.raw"),
"split": "valid",
},
),
]
else:
if self.config.name == "wikitext-2-v1":
data_file = dl_manager.download_and_extract(
self.config.data_url)
self.config.data_url
)
data_dir = os.path.join(data_file, "wikitext-2")
return [
datasets.SplitGenerator(
name=datasets.Split.TEST,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.test.tokens"), "split": "test"},
gen_kwargs={
"data_file": os.path.join(
data_dir, "wiki.test.tokens"
),
"split": "test",
},
),
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={
"data_file": os.path.join(data_dir, "wiki.train.tokens"),
"data_file": os.path.join(
data_dir, "wiki.train.tokens"
),
"split": "train",
},
),
datasets.SplitGenerator(
name=datasets.Split.VALIDATION,
gen_kwargs={
"data_file": os.path.join(data_dir, "wiki.valid.tokens"),
"data_file": os.path.join(
data_dir, "wiki.valid.tokens"
),
"split": "valid",
},
),
......@@ -216,12 +241,12 @@ class Wikitext(datasets.GeneratorBasedBuilder):
data = f.read().split("\n")
for line in data:
rline = line.replace("= = =", "===").replace("= =", "==").strip()
if rline.startswith('= ') and rline.strip().endswith(' ='):
page = '\n'.join(ret)
if rline.startswith("= ") and rline.strip().endswith(" ="):
page = "\n".join(ret)
if page.strip():
yield key, {"page": page}
key += 1
ret = []
ret.append(line)
page = '\n'.join(ret)
page = "\n".join(ret)
yield key, {"page": page}
......@@ -8,12 +8,14 @@ import mmap
import tqdm
from pathlib import Path
def json_serial(obj):
"""JSON serializer for objects not serializable by default json code"""
if isinstance(obj, (datetime.datetime,)):
return obj.isoformat()
raise TypeError ("Type %s not serializable" % type(obj))
raise TypeError("Type %s not serializable" % type(obj))
# Modified version of lm_dataformat Archive for single file.
class Archive:
......@@ -22,25 +24,31 @@ class Archive:
dir_name = os.path.dirname(file_path)
if dir_name:
os.makedirs(dir_name, exist_ok=True)
self.fh = open(self.file_path, 'wb')
self.fh = open(self.file_path, "wb")
self.cctx = zstandard.ZstdCompressor(level=compression_level)
self.compressor = self.cctx.stream_writer(self.fh)
def add_data(self, data, meta={}):
self.compressor.write(json.dumps({'text': data, 'meta': meta}, default=json_serial).encode('UTF-8') + b'\n')
self.compressor.write(
json.dumps({"text": data, "meta": meta}, default=json_serial).encode(
"UTF-8"
)
+ b"\n"
)
def commit(self):
self.compressor.flush(zstandard.FLUSH_FRAME)
self.fh.flush()
self.fh.close()
# Modified version of lm_dataformat Reader with self.fh set, allowing peeking for tqdm.
class Reader:
def __init__(self):
pass
def read(self, file, get_meta=False, autojoin_paragraphs=True, para_joiner='\n\n'):
with open(file, 'rb') as fh:
def read(self, file, get_meta=False, autojoin_paragraphs=True, para_joiner="\n\n"):
with open(file, "rb") as fh:
self.fh = fh
cctx = zstandard.ZstdDecompressor()
reader = io.BufferedReader(cctx.stream_reader(fh))
......@@ -52,16 +60,17 @@ class Reader:
yield ob
continue
text = ob['text']
text = ob["text"]
if autojoin_paragraphs and isinstance(text, list):
text = para_joiner.join(text)
if get_meta:
yield text, (ob['meta'] if 'meta' in ob else {})
yield text, (ob["meta"] if "meta" in ob else {})
else:
yield text
class TextArchive:
def __init__(self, file_path, mode="rb+"):
self.file_path = file_path
......@@ -75,12 +84,13 @@ class TextArchive:
self.fh = open(self.file_path, mode)
def add_data(self, data):
self.fh.write(data.encode('UTF-8') + b'\n')
self.fh.write(data.encode("UTF-8") + b"\n")
def commit(self):
self.fh.flush()
self.fh.close()
class TextReader:
def __init__(self, file_path):
self.file_path = file_path
......@@ -90,9 +100,12 @@ class TextReader:
def read_tqdm(self, update_frequency=10000):
current_file_position = 0
line_counter = 0
with open(self.file_path, 'r') as fh, \
tqdm.tqdm(total=os.path.getsize(self.file_path), dynamic_ncols=True,
unit="byte", unit_scale=1) as progress:
with open(self.file_path, "r") as fh, tqdm.tqdm(
total=os.path.getsize(self.file_path),
dynamic_ncols=True,
unit="byte",
unit_scale=1,
) as progress:
with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
for line in iter(mmap_obj.readline, b""):
line = line.decode("utf-8")
......@@ -107,7 +120,7 @@ class TextReader:
def read_and_tell(self):
current_file_position = 0
with open(self.file_path, 'r', encoding="utf8") as fh:
with open(self.file_path, "r", encoding="utf8") as fh:
with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
for line in iter(mmap_obj.readline, b""):
line = line.decode("utf-8")
......@@ -117,14 +130,14 @@ class TextReader:
yield line[:-1], raw_bytes_read
def read(self):
with open(self.file_path, 'r', encoding="utf8") as fh:
with open(self.file_path, "r", encoding="utf8") as fh:
with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
for line in iter(mmap_obj.readline, b""):
line = line.decode("utf-8")
yield line[:-1]
def read_slow(self):
with open(self.file_path, 'r', encoding="utf8") as fh:
with open(self.file_path, "r", encoding="utf8") as fh:
while True:
line = fh.readline()
if line == -1 or line == "":
......@@ -132,6 +145,7 @@ class TextReader:
else:
yield line[:-1]
# Optimized for speed. Decompresses the archive in shell before
# using the mmap'd TextReader.
class ZStdTextReader:
......
......@@ -9,6 +9,7 @@ import collections
from .janitor import Janitor, word_ngrams
from .archiver import ZStdTextReader
# Was used for testing the evaluator decoupled from the full logic below
def get_train_overlap_stub(docs, ngrams_path, ngrams_n_size):
simulated_overlap = 0.1
......@@ -57,19 +58,27 @@ def get_train_overlap(docs_by_task_set, ngrams_path, limit):
os.mkdir(f"data/{task_name}")
# Check if we've decontaminated this combination before
overlaps_dump_path = get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit)
overlaps_dump_path = get_overlaps_dump_path(
task_name, task_set, ngrams_n_size, limit
)
if os.path.exists(overlaps_dump_path):
duplicates[(task_name, task_set)] = pickle.load(open(overlaps_dump_path, "rb"))
duplicates[(task_name, task_set)] = pickle.load(
open(overlaps_dump_path, "rb")
)
sets_to_decontaminate -= 1
continue
else:
duplicates[(task_name, task_set)] = set()
# Build/load the task lookup {ngram: set(documents)}.
task_set_lookup_path = f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.lookup"
task_set_lookup_path = (
f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.lookup"
)
if os.path.exists(task_set_lookup_path):
print(f"{task_set_lookup_path} available, loading...")
lookups[(task_name, task_set)] = pickle.load(open(task_set_lookup_path, "rb"))
lookups[(task_name, task_set)] = pickle.load(
open(task_set_lookup_path, "rb")
)
else:
print(f"{task_set_lookup_path} not available, building...")
lookup = collections.defaultdict(set)
......@@ -79,7 +88,7 @@ def get_train_overlap(docs_by_task_set, ngrams_path, limit):
for ngram in ngrams:
lookup[ngram].add(doc_id)
pickle.dump(lookup, open(task_set_lookup_path,"wb"))
pickle.dump(lookup, open(task_set_lookup_path, "wb"))
lookups[(task_name, task_set)] = lookup
elapsed = time.perf_counter() - start
......@@ -115,7 +124,9 @@ def get_train_overlap(docs_by_task_set, ngrams_path, limit):
for line in reader.read_tqdm(): # Scan training set ngrams file
total_ngrams += 1
[ngram, document_id] = line.rsplit(" ", 1)
if ngram != current_ngram: # Only need to match the ngram once in training set
if (
ngram != current_ngram
): # Only need to match the ngram once in training set
unique_ngrams += 1
current_ngram = ngram
if ngram in merged_lookup:
......@@ -123,7 +134,11 @@ def get_train_overlap(docs_by_task_set, ngrams_path, limit):
matching_unique += 1
for task_name, task_set, doc_ids in merged_lookup[ngram]:
task_doc_set = duplicates[(task_name, task_set)]
for doc_id in doc_ids: # Record contamination across all relevant task/set combos
for (
doc_id
) in (
doc_ids
): # Record contamination across all relevant task/set combos
task_doc_set.add(doc_id)
del merged_lookup[ngram] # No point matching again
else:
......@@ -145,9 +160,10 @@ def get_train_overlap(docs_by_task_set, ngrams_path, limit):
# Dump overlaps separately
for (task_name, task_set), doc_ids in duplicates.items():
overlaps_dump_path = get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit)
pickle.dump(doc_ids, open(overlaps_dump_path,"wb"))
overlaps_dump_path = get_overlaps_dump_path(
task_name, task_set, ngrams_n_size, limit
)
pickle.dump(doc_ids, open(overlaps_dump_path, "wb"))
# Strip task set and return
return {task_name: doc_ids for (task_name, task_set), doc_ids in duplicates.items()}
......@@ -9,8 +9,9 @@ from pprint import pprint
# c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup
try:
import janitor_util
JANITOR_CPP = True
except Exception as e:
except Exception:
print("WARNING: C++ module could not be loaded. Janitor running in python mode")
traceback.print_exc()
JANITOR_CPP = False
......@@ -41,6 +42,7 @@ def word_ngrams(s, n):
ngram_seqs = form_ngrams(iter(tokens), n)
return (" ".join(ngram) for ngram in ngram_seqs)
# Does character sequences only - combined faster function to play around with later
# def word_ngrams_indices_combined(sequence, n):
# current_word = ""
......@@ -70,7 +72,7 @@ def split_indices(s):
"""Splits a string on whitespaces and records the indices of each in the original string.
@:return generator((word, (start_idx, end_idx)), ...)
"""
return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r'\S+', s))
return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r"\S+", s))
def word_ngrams_indices(s, n):
......@@ -90,10 +92,15 @@ def word_ngrams_indices(s, n):
# ([word, word, ...], [(start,end), (start,end), ...]),
# ...
# )
ngram_indices_pairs = (zip(*ngram_with_indices) for ngram_with_indices in ngram_seqs_with_indices)
ngram_indices_pairs = (
zip(*ngram_with_indices) for ngram_with_indices in ngram_seqs_with_indices
)
# Generator of ( (word_ngram, (start, end)), (word_ngram, start, end)), ...)
return ((" ".join(ngram_seq), (indices[0][0], indices[-1][1])) for ngram_seq, indices in ngram_indices_pairs)
return (
(" ".join(ngram_seq), (indices[0][0], indices[-1][1]))
for ngram_seq, indices in ngram_indices_pairs
)
class Janitor:
......@@ -105,7 +112,7 @@ class Janitor:
window_to_remove=200,
too_dirty_cutoff=10,
minimum_slice_length=200,
delete_chars=string.punctuation
delete_chars=string.punctuation,
):
self.ngram_n = ngram_n
self.window_to_remove = window_to_remove
......@@ -121,7 +128,7 @@ class Janitor:
self.translation_table = str.maketrans(
string.ascii_lowercase + string.ascii_uppercase, # These characters
string.ascii_lowercase * 2, # Become these characters
self.delete_chars # These are deleted
self.delete_chars, # These are deleted
)
##############
......@@ -129,14 +136,13 @@ class Janitor:
##############
def save_contamination_ngrams(self, filename):
with open(filename, 'wb') as fp:
with open(filename, "wb") as fp:
pickle.dump(filename, fp)
def load_contamination_ngrams(self, filename):
with open(filename, 'rb') as fp:
with open(filename, "rb") as fp:
self.dirt_ngrams = pickle.load(fp)
##############
# Call these :)
##############
......@@ -152,7 +158,7 @@ class Janitor:
def clean(self, dirty_string):
"""Clean a string (e.g. a training set) by removing all ngrams previously
reigstered as contaminants. Returns a list of clean chunks, or empty if
registered as contaminants. Returns a list of clean chunks, or empty if
the string was too dirty"""
if JANITOR_CPP:
return self.clean_cpp(dirty_string)
......@@ -171,11 +177,11 @@ class Janitor:
end = min(len(dirty_string), end + self.window_to_remove)
if start - splice_idx > self.minimum_slice_length:
clean_chunks.append(dirty_string[splice_idx: start])
clean_chunks.append(dirty_string[splice_idx:start])
splice_idx = end
if end < len(dirty_string) - self.minimum_slice_length:
clean_chunks.append(dirty_string[end+1:])
clean_chunks.append(dirty_string[end + 1 :])
return clean_chunks
......@@ -184,10 +190,14 @@ class Janitor:
##############
def register_contaminant_cpp(self, dirt_string):
self.dirt_ngrams.update(janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n))
self.dirt_ngrams.update(
janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n)
)
def clean_cpp(self, dirty_string):
contamination_indices = janitor_util.clean_ngram_with_indices(dirty_string, self.delete_chars, self.ngram_n)
contamination_indices = janitor_util.clean_ngram_with_indices(
dirty_string, self.delete_chars, self.ngram_n
)
return self._split_chunks(dirty_string, contamination_indices)
##############
......@@ -198,7 +208,9 @@ class Janitor:
return s.translate(self.translation_table)
def register_contaminant_python(self, dirt_string):
self.dirt_ngrams.update(word_ngrams(self.normalize_string(dirt_string), self.ngram_n))
self.dirt_ngrams.update(
word_ngrams(self.normalize_string(dirt_string), self.ngram_n)
)
def clean_python(self, dirty_string):
contamination_indices = (
......@@ -263,7 +275,7 @@ class Janitor:
# ultan Al Nahyan]] granted [[Petroleum]] concessions, and oil was first found in 1958. At first,
# oil money had a marginal impact. A few lowrise concete buildings were erected, and the first
# paved road was completed in 1961, but Sheikh Shakbut, uncertain whether the new oil royalties
# would last, took a cautious approach, prefering to save the revenue rather than investing it in
# would last, took a cautious approach, preferring to save the revenue rather than investing it in
# development. His brother, [[Zayed bin Sultan Al Nahayan]], saw that oil wealth had the potential
# to transform Abu Dhabi. The ruling Al Nahayan family decided that Sheikh Zayed should replace his
# brother as Ruler and carry out his vision of developing the country. On [[August 6]], [[1966]],
......
......@@ -11,12 +11,22 @@ import numpy as np
from lm_eval.utils import positional_deprecated, run_task_tests
from lm_eval.decontamination.decontaminate import get_train_overlap
@positional_deprecated
def simple_evaluate(model, model_args=None, tasks=[],
num_fewshot=0, batch_size=None, device=None,
no_cache=False, limit=None, bootstrap_iters=100000,
description_dict=None, check_integrity=False,
decontamination_ngrams_path=None):
def simple_evaluate(
model,
model_args=None,
tasks=[],
num_fewshot=0,
batch_size=None,
device=None,
no_cache=False,
limit=None,
bootstrap_iters=100000,
description_dict=None,
check_integrity=False,
decontamination_ngrams_path=None,
):
"""Instantiate and evaluate a model on a list of tasks.
......@@ -52,17 +62,23 @@ def simple_evaluate(model, model_args=None, tasks=[],
assert tasks != [], "No tasks specified"
if isinstance(model, str):
if model_args is None: model_args = ""
lm = lm_eval.models.get_model(model).create_from_arg_string(model_args, {
'batch_size': batch_size, 'device': device
})
if model_args is None:
model_args = ""
lm = lm_eval.models.get_model(model).create_from_arg_string(
model_args, {"batch_size": batch_size, "device": device}
)
else:
assert isinstance(model, lm_eval.base.LM)
lm = model
if not no_cache:
lm = lm_eval.base.CachingLM(
lm, 'lm_cache/' + model + '_' + model_args.replace('=', '-').replace(',', '_').replace('/', '-') + '.db'
lm,
"lm_cache/"
+ model
+ "_"
+ model_args.replace("=", "-").replace(",", "_").replace("/", "-")
+ ".db",
)
task_dict = lm_eval.tasks.get_task_dict(tasks)
......@@ -89,16 +105,26 @@ def simple_evaluate(model, model_args=None, tasks=[],
"no_cache": no_cache,
"limit": limit,
"bootstrap_iters": bootstrap_iters,
"description_dict": description_dict
"description_dict": description_dict,
}
return results
decontaminate_suffix = "_decontaminate"
@positional_deprecated
def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, bootstrap_iters=100000, description_dict=None,
decontamination_ngrams_path=None):
def evaluate(
lm,
task_dict,
provide_description=None,
num_fewshot=0,
limit=None,
bootstrap_iters=100000,
description_dict=None,
decontamination_ngrams_path=None,
):
"""Instantiate and evaluate a model on a list of tasks.
:param lm: obj
......@@ -124,14 +150,16 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
assert not provide_description # not implemented.
if provide_description is not None:
# nudge people to not specify it at all
print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")
print(
"WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
)
decontaminate = decontamination_ngrams_path is not None
task_dict_items = [
(name, task)
for name, task in task_dict.items()
if(task.has_validation_docs() or task.has_test_docs())
if (task.has_validation_docs() or task.has_test_docs())
]
results = collections.defaultdict(dict)
......@@ -172,19 +200,22 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
rnd.seed(42)
rnd.shuffle(task_docs)
description = description_dict[task_name] if description_dict and task_name in description_dict else ""
description = (
description_dict[task_name]
if description_dict and task_name in description_dict
else ""
)
for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
if decontaminate and task.should_decontaminate():
docs_for_decontamination[(task_name, task_set)].append(task.doc_to_decontamination_query(doc))
docs_for_decontamination[(task_name, task_set)].append(
task.doc_to_decontamination_query(doc)
)
docs[(task_name, doc_id)] = doc
ctx = task.fewshot_context(
doc=doc,
num_fewshot=num_fewshot,
rnd=rnd,
description=description
doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
)
reqs = task.construct_requests(doc, ctx)
if not isinstance(reqs, (list, tuple)):
......@@ -198,7 +229,9 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
# Compare all tasks/sets at once to ensure a single training set scan
if decontaminate:
print("Finding train/test overlap, please wait...")
overlaps = get_train_overlap(docs_for_decontamination, decontamination_ngrams_path, limit)
overlaps = get_train_overlap(
docs_for_decontamination, decontamination_ngrams_path, limit
)
# all responses for each (task, doc)
process_res_queue = collections.defaultdict(list)
......@@ -212,7 +245,9 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
print("Running", reqtype, "requests")
resps = getattr(lm, reqtype)([req.args for req in reqs])
resps = [x if req.index is None else x[req.index] for x, req in zip(resps, reqs)]
resps = [
x if req.index is None else x[req.index] for x, req in zip(resps, reqs)
]
for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
process_res_queue[(task_name, doc_id)].append((i, resp))
......@@ -241,7 +276,9 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
task = task_dict[task_name]
real_metric = metric # key when looking up the metric with task.aggregation
if metric.endswith(decontaminate_suffix):
real_metric = metric.replace(decontaminate_suffix, "") # decontaminated still uses the same metric
real_metric = metric.replace(
decontaminate_suffix, ""
) # decontaminated still uses the same metric
results[task_name][metric] = task.aggregation()[real_metric](items)
# hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
......@@ -249,16 +286,15 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
stderr = lm_eval.metrics.stderr_for_metric(
metric=task.aggregation()[real_metric],
bootstrap_iters=min(bootstrap_iters, 1000) if metric in ["bleu", "chrf", "ter"] else bootstrap_iters,
bootstrap_iters=min(bootstrap_iters, 1000)
if metric in ["bleu", "chrf", "ter"]
else bootstrap_iters,
)
if stderr is not None:
results[task_name][metric + "_stderr"] = stderr(items)
return {
"results": dict(results),
"versions": dict(versions)
}
return {"results": dict(results), "versions": dict(versions)}
def make_table(result_dict):
......@@ -280,9 +316,9 @@ def make_table(result_dict):
if m + "_stderr" in dic:
se = dic[m + "_stderr"]
values.append([k, version, m, '%.4f' % v, '±', '%.4f' % se])
values.append([k, version, m, "%.4f" % v, "±", "%.4f" % se])
else:
values.append([k, version, m, '%.4f' % v, '', ''])
values.append([k, version, m, "%.4f" % v, "", ""])
k = ""
version = ""
md_writer.value_matrix = values
......
......@@ -103,6 +103,7 @@ def weighted_mean(items):
def weighted_perplexity(items):
return math.exp(-weighted_mean(items))
def bits_per_byte(items):
return -weighted_mean(items) / math.log(2)
......@@ -184,8 +185,10 @@ def _sacreformat(refs, preds):
return refs, preds
# stderr stuff
class _bootstrap_internal:
def __init__(self, f, n):
self.f = f
......@@ -203,6 +206,7 @@ class _bootstrap_internal:
def bootstrap_stderr(f, xs, iters):
import multiprocessing as mp
pool = mp.Pool(mp.cpu_count())
# this gives a biased estimate of the stderr (i.e w/ the mean, it gives something
# equivalent to stderr calculated without Bessel's correction in the stddev.
......@@ -213,10 +217,15 @@ def bootstrap_stderr(f, xs, iters):
res = []
chunk_size = min(1000, iters)
from tqdm import tqdm
print("bootstrapping for stddev:", f.__name__)
for bootstrap in tqdm(pool.imap(
for bootstrap in tqdm(
pool.imap(
_bootstrap_internal(f, chunk_size),
[(i, xs) for i in range(iters // chunk_size)]), total=iters // chunk_size):
[(i, xs) for i in range(iters // chunk_size)],
),
total=iters // chunk_size,
):
# sample w replacement
res.extend(bootstrap)
......@@ -238,17 +247,13 @@ def stderr_for_metric(metric, bootstrap_iters):
if metric in bootstrappable:
return lambda x: bootstrap_stderr(metric, x, iters=bootstrap_iters)
stderr = {
mean: mean_stderr,
acc_all: acc_all_stderr
}
stderr = {mean: mean_stderr, acc_all: acc_all_stderr}
return stderr.get(metric, None)
def yesno(x):
if x:
return 'yes'
return "yes"
else:
return 'no'
return "no"
......@@ -4,8 +4,15 @@ from lm_eval.base import BaseLM
class HFLM(BaseLM):
def __init__(self, device='cuda', pretrained='gpt2', revision='main', subfolder=None, tokenizer=None, batch_size=1):
def __init__(
self,
device="cuda",
pretrained="gpt2",
revision="main",
subfolder=None,
tokenizer=None,
batch_size=1,
):
super().__init__()
assert isinstance(device, str)
......@@ -18,30 +25,49 @@ class HFLM(BaseLM):
self._device = torch.device(device)
print(f"Using device '{device}'")
else:
print("Device not specificed")
print("Device not specified")
print(f"Cuda Available? {torch.cuda.is_available()}")
self._device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
self._device = (
torch.device("cuda")
if torch.cuda.is_available()
else torch.device("cpu")
)
# TODO: update this to be less of a hack once subfolder is fixed in HF
self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
pretrained, revision=revision + ("/" + subfolder if subfolder is not None else "")
pretrained,
revision=revision + ("/" + subfolder if subfolder is not None else ""),
).to(self.device)
self.gpt2.eval()
# pretrained tokenizer for neo is broken for now so just hard-coding this to gpt2
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
pretrained if tokenizer is None else tokenizer, revision=revision, subfolder=subfolder)
pretrained if tokenizer is None else tokenizer,
revision=revision,
subfolder=subfolder,
)
assert isinstance(self.tokenizer, (
transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast,
transformers.T5Tokenizer, transformers.T5TokenizerFast,
)), "this tokenizer has not been checked for compatibility yet!"
assert isinstance(
self.tokenizer,
(
transformers.GPT2Tokenizer,
transformers.GPT2TokenizerFast,
transformers.T5Tokenizer,
transformers.T5TokenizerFast,
),
), "this tokenizer has not been checked for compatibility yet!"
self.vocab_size = self.tokenizer.vocab_size
if isinstance(self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)):
assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373], \
self.tokenizer.encode('hello\n\nhello')
if isinstance(
self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)
):
assert self.tokenizer.encode("hello\n\nhello") == [
31373,
198,
198,
31373,
], self.tokenizer.encode("hello\n\nhello")
# multithreading and batching
self.batch_size_per_gpu = batch_size # todo: adaptive batch size
......@@ -97,10 +123,7 @@ class HFLM(BaseLM):
def _model_generate(self, context, max_length, eos_token_id):
return self.gpt2.generate(
context,
max_length=max_length,
eos_token_id=eos_token_id,
do_sample=False
context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False
)
......
......@@ -36,17 +36,19 @@ def get_result(response, ctxlen):
def oa_completion(**kwargs):
""" Query OpenAI API for completion.
"""Query OpenAI API for completion.
Retry with back-off until they respond
"""
import openai
backoff_time = 3
while True:
try:
return openai.Completion.create(**kwargs)
except openai.error.OpenAIError:
import traceback
traceback.print_exc()
time.sleep(backoff_time)
backoff_time *= 1.5
......@@ -66,16 +68,19 @@ class GPT3LM(BaseLM):
super().__init__()
import openai
self.engine = engine
self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")
self.vocab_size = self.tokenizer.vocab_size
# to make the annoying "Using pad_token, but it is not set yet." error go away
self.tokenizer.pad_token = "<|endoftext|>"
assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373]
assert self.tokenizer.encode("hello\n\nhello") == [31373, 198, 198, 31373]
self.truncate = truncate
self.end_of_text_token_id = self.tokenizer.convert_tokens_to_ids(["<|endoftext|>"])[0]
self.end_of_text_token_id = self.tokenizer.convert_tokens_to_ids(
["<|endoftext|>"]
)[0]
# Read from environment variable OPENAI_API_SECRET_KEY
openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]
......@@ -119,16 +124,21 @@ class GPT3LM(BaseLM):
toks = x[1] + x[2]
return -len(toks), tuple(toks)
reord = utils.Reorderer(requests, _collate)
re_ord = utils.Reorderer(requests, _collate)
for chunk in tqdm(list(utils.chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE)), disable=disable_tqdm):
for chunk in tqdm(
list(utils.chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE)),
disable=disable_tqdm,
):
inps = []
ctxlens = []
for cache_key, context_enc, continuation_enc in chunk:
# max_length+1 because the API takes up to 2049 tokens, including the first context token
inp = (context_enc + continuation_enc)[-(self.max_length+1):]
inp = (context_enc + continuation_enc)[-(self.max_length + 1) :]
# TODO: the logic is much simpler if we just look at the length of continuation tokens
ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - (self.max_length+1))
ctxlen = len(context_enc) - max(
0, len(context_enc) + len(continuation_enc) - (self.max_length + 1)
)
inps.append(inp)
ctxlens.append(ctxlen)
......@@ -137,11 +147,14 @@ class GPT3LM(BaseLM):
engine=self.engine,
prompt=inps,
echo=True,
max_tokens=0, temperature=0.,
max_tokens=0,
temperature=0.0,
logprobs=10,
)
for resp, ctxlen, (cache_key, context_enc, continuation_enc) in zip(response.choices, ctxlens, chunk):
for resp, ctxlen, (cache_key, context_enc, continuation_enc) in zip(
response.choices, ctxlens, chunk
):
answer = get_result(resp, ctxlen)
res.append(answer)
......@@ -150,7 +163,7 @@ class GPT3LM(BaseLM):
if cache_key is not None:
self.cache_hook.add_partial("loglikelihood", cache_key, answer)
return reord.get_original(res)
return re_ord.get_original(res)
def greedy_until(self, requests):
if not requests:
......@@ -161,7 +174,7 @@ class GPT3LM(BaseLM):
toks = self.tok_encode(x[0])
return len(toks), x[0]
reord = utils.Reorderer(requests, _collate)
re_ord = utils.Reorderer(requests, _collate)
def sameuntil_chunks(xs, size):
ret = []
......@@ -177,24 +190,26 @@ class GPT3LM(BaseLM):
yield ret, lastuntil
# todo: more intelligent batching for heterogeneous `until`
for chunk, until in tqdm(list(sameuntil_chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE))):
for chunk, until in tqdm(
list(sameuntil_chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE))
):
inps = []
for context, _ in chunk:
context_enc = self.tok_encode(context)
inp = context_enc[-(self.max_length - self.max_gen_toks):]
inp = context_enc[-(self.max_length - self.max_gen_toks) :]
inps.append(inp)
response = oa_completion(
engine=self.engine,
prompt=inps,
max_tokens=self.max_gen_toks,
temperature=0.,
temperature=0.0,
logprobs=10,
stop=until,
)
for resp, (context, until_) in zip(response.choices, chunk):
s = resp['text']
s = resp["text"]
for term in until_:
s = s.split(term)[0]
......@@ -204,7 +219,7 @@ class GPT3LM(BaseLM):
res.append(s)
return reord.get_original(res)
return re_ord.get_original(res)
def _model_call(self, inps):
# Isn't used because we override _loglikelihood_tokens
......
......@@ -22,14 +22,12 @@ from . import naturalqs
from . import sat
from . import arithmetic
from . import lambada
from . import race
from . import piqa
from . import prost
from . import mc_taco
from . import triviaqa
from . import pubmedqa
from . import sciq
from . import webqs
from . import qasper
from . import qa4mre
from . import translation
......@@ -59,8 +57,8 @@ from . import storycloze
# 6 total
gpt3_translation_benchmarks = {
"wmt14": ['en-fr', 'fr-en'], # French
"wmt16": ['en-ro', 'ro-en', 'de-en', 'en-de'], # German, Romanian
"wmt14": ["en-fr", "fr-en"], # French
"wmt16": ["en-ro", "ro-en", "de-en", "en-de"], # German, Romanian
}
......@@ -68,7 +66,7 @@ gpt3_translation_benchmarks = {
selected_translation_benchmarks = {
**gpt3_translation_benchmarks,
"wmt20": sacrebleu.get_langpairs_for_testset("wmt20"),
"iwslt17": ['en-ar', 'ar-en'] # Arabic
"iwslt17": ["en-ar", "ar-en"], # Arabic
}
# 319 total
......@@ -92,7 +90,7 @@ TASK_REGISTRY = {
"rte": glue.RTE,
"qnli": glue.QNLI,
"qqp": glue.QQP,
#"stsb": glue.STSB, # not implemented yet
# "stsb": glue.STSB, # not implemented yet
"sst": glue.SST,
"wnli": glue.WNLI,
# SuperGLUE
......@@ -103,34 +101,26 @@ TASK_REGISTRY = {
"record": superglue.ReCoRD,
"wic": superglue.WordsInContext,
"wsc": superglue.SGWinogradSchemaChallenge,
# Order by benchmark/genre?
"coqa": coqa.CoQA,
"drop": drop.DROP,
"lambada": lambada.LAMBADA,
"lambada_cloze": lambada_cloze.LAMBADA_cloze,
# multilingual lambada
**lambada_multilingual.construct_tasks(),
"wikitext": wikitext.WikiText,
# "cbt-cn": cbt.CBTCN, # disabled pending context length fix
# "cbt-ne": cbt.CBTNE, # disabled pending context length fix
"piqa": piqa.PiQA,
"prost": prost.PROST,
"mc_taco": mc_taco.MCTACO,
# Science related
"pubmedqa" : pubmedqa.Pubmed_QA,
"sciq" : sciq.SciQ,
"pubmedqa": pubmedqa.Pubmed_QA,
"sciq": sciq.SciQ,
"qasper": qasper.QASPER,
"qa4mre_2011" : qa4mre.QA4MRE_2011,
"qa4mre_2012" : qa4mre.QA4MRE_2012,
"qa4mre_2013" : qa4mre.QA4MRE_2013,
"qa4mre_2011": qa4mre.QA4MRE_2011,
"qa4mre_2012": qa4mre.QA4MRE_2012,
"qa4mre_2013": qa4mre.QA4MRE_2013,
"triviaqa": triviaqa.TriviaQA,
"arc_easy": arc.ARCEasy,
"arc_challenge": arc.ARCChallenge,
......@@ -152,21 +142,17 @@ TASK_REGISTRY = {
"anli_r1": anli.ANLIRound1,
"anli_r2": anli.ANLIRound2,
"anli_r3": anli.ANLIRound3,
"ethics_cm": hendrycks_ethics.EthicsCM,
"ethics_deontology": hendrycks_ethics.EthicsDeontology,
"ethics_justice": hendrycks_ethics.EthicsJustice,
"ethics_utilitarianism_original": hendrycks_ethics.EthicsUtilitarianismOriginal,
"ethics_utilitarianism": hendrycks_ethics.EthicsUtilitarianism,
"ethics_virtue": hendrycks_ethics.EthicsVirtue,
"truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice,
"truthfulqa_gen": truthfulqa.TruthfulQAGeneration,
# dialogue
"mutual": mutual.MuTual,
"mutual_plus": mutual.MuTualPlus,
# math
"math_algebra": hendrycks_math.MathAlgebra,
"math_counting_and_prob": hendrycks_math.MathCountingAndProbability,
......@@ -177,7 +163,6 @@ TASK_REGISTRY = {
"math_precalc": hendrycks_math.MathPrecalculus,
"math_asdiv": asdiv.Asdiv,
"gsm8k": gsm8k.GradeSchoolMath8K,
# arithmetic
"arithmetic_2da": arithmetic.Arithmetic2DPlus,
"arithmetic_2ds": arithmetic.Arithmetic2DMinus,
......@@ -191,22 +176,18 @@ TASK_REGISTRY = {
"arithmetic_1dc": arithmetic.Arithmetic1DComposite,
# TODO Perhaps make these groups of tasks
# e.g. anli, arithmetic, openai_translations, harness_translations
# hendrycksTest (57 tasks)
**hendrycks_test.create_all_tasks(),
# e.g. wmt14-fr-en
**translation.create_tasks_from_benchmarks(gpt3_translation_benchmarks),
# chef's selection, mostly wmt20
**translation.create_tasks_from_benchmarks(selected_translation_benchmarks),
# Word Scrambling and Manipulation Tasks
"anagrams1": unscramble.Anagrams1,
"anagrams2": unscramble.Anagrams2,
"cycle_letters": unscramble.CycleLetters,
"random_insertion": unscramble.RandomInsertion,
"reversed_words": unscramble.ReversedWords,
# Pile
"pile_arxiv": pile.PileArxiv,
"pile_books3": pile.PileBooks3,
......@@ -230,7 +211,6 @@ TASK_REGISTRY = {
"pile_ubuntu-irc": pile.PileUbuntuIrc,
"pile_wikipedia": pile.PileWikipedia,
"pile_youtubesubtitles": pile.PileYoutubeSubtitles,
# BLiMP
"blimp_adjunct_island": blimp.BlimpAdjunctIsland,
"blimp_anaphor_gender_agreement": blimp.BlimpAnaphorGenderAgreement,
......@@ -299,7 +279,6 @@ TASK_REGISTRY = {
"blimp_wh_vs_that_no_gap_long_distance": blimp.BlimpWhVsThatNoGapLongDistance,
"blimp_wh_vs_that_with_gap": blimp.BlimpWhVsThatWithGap,
"blimp_wh_vs_that_with_gap_long_distance": blimp.BlimpWhVsThatWithGapLongDistance,
# Requires manual download of data.
# "storycloze_2016": storycloze.StoryCloze2016,
# "storycloze_2018": storycloze.StoryCloze2018,
......@@ -313,7 +292,7 @@ ALL_TASKS = sorted(list(TASK_REGISTRY))
def get_task(task_name):
try:
return TASK_REGISTRY[task_name]
except KeyError as e:
except KeyError:
print("Available tasks:")
pprint(TASK_REGISTRY)
raise KeyError(f"Missing task {task_name}")
......@@ -325,17 +304,23 @@ def get_task_name_from_object(task_object):
return name
# this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
return task_object.EVAL_HARNESS_NAME if hasattr(task_object, "EVAL_HARNESS_NAME") else type(task_object).__name__
return (
task_object.EVAL_HARNESS_NAME
if hasattr(task_object, "EVAL_HARNESS_NAME")
else type(task_object).__name__
)
def get_task_dict(task_name_list: List[Union[str, lm_eval.base.Task]]):
task_name_dict = {
task_name: get_task(task_name)()
for task_name in task_name_list if isinstance(task_name, str)
for task_name in task_name_list
if isinstance(task_name, str)
}
task_name_from_object_dict = {
get_task_name_from_object(task_object): task_object
for task_object in task_name_list if not isinstance(task_object, str)
for task_object in task_name_list
if not isinstance(task_object, str)
}
assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys()))
return {**task_name_dict, **task_name_from_object_dict}
......@@ -64,7 +64,12 @@ class ANLIBase(Task):
# of the prompt (yes, repeating it!). also, " True, False, or Neither?" is directly
# appended onto the question, with no "Answer:" or even a newline. Do we *really*
# want to do it exactly as OA did?
return doc['premise'] + '\nQuestion: ' + doc['hypothesis'] + ' True, False, or Neither?\nAnswer:'
return (
doc["premise"]
+ "\nQuestion: "
+ doc["hypothesis"]
+ " True, False, or Neither?\nAnswer:"
)
def should_decontaminate(self):
return True
......@@ -76,10 +81,10 @@ class ANLIBase(Task):
# True = entailment
# False = contradiction
# Neither = neutral
return " " + ["True", "Neither", "False"][doc['label']]
return " " + ["True", "Neither", "False"][doc["label"]]
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
......@@ -106,9 +111,7 @@ class ANLIBase(Task):
"""
gold = doc["label"]
pred = np.argmax(results)
return {
"acc": pred == gold
}
return {"acc": pred == gold}
def aggregation(self):
"""
......@@ -116,9 +119,7 @@ class ANLIBase(Task):
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return {
"acc": mean
}
return {"acc": mean}
def higher_is_better(self):
"""
......@@ -126,9 +127,7 @@ class ANLIBase(Task):
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return {
"acc": True
}
return {"acc": True}
class ANLIRound1(ANLIBase):
......
......@@ -67,10 +67,8 @@ class Arithmetic(Task):
return is_prediction
def process_results(self, doc, results):
is_prediction, = results
return {
"acc": is_prediction
}
(is_prediction,) = results
return {"acc": is_prediction}
def aggregation(self):
return {
......@@ -78,9 +76,7 @@ class Arithmetic(Task):
}
def higher_is_better(self):
return {
"acc": True
}
return {"acc": True}
class Arithmetic2DPlus(Arithmetic):
......
......@@ -54,29 +54,28 @@ class Asdiv(Task):
def test_docs(self):
raise NotImplementedError("This dataset has no test docs")
def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
def fewshot_context(
self, doc, num_fewshot, provide_description=None, rnd=None, description=None
):
assert num_fewshot == 0, "ASDiv is intended only for the zero-shot setting."
return super().fewshot_context(
doc=doc,
num_fewshot=num_fewshot,
rnd=rnd,
description=description
doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
)
def doc_to_text(self, doc):
# TODO: add solution-type
return doc['body'] + '\n' + 'Question:' + doc['question'] + '\n' + 'Answer:'
return doc["body"] + "\n" + "Question:" + doc["question"] + "\n" + "Answer:"
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc['body'] + " " + doc['question']
return doc["body"] + " " + doc["question"]
def doc_to_target(self, doc):
# TODO: add formula
answer = doc['answer'].split(' (')[0]
answer = doc["answer"].split(" (")[0]
return " " + answer
def construct_requests(self, doc, ctx):
......@@ -86,16 +85,10 @@ class Asdiv(Task):
def process_results(self, doc, results):
ll, is_greedy = results
return {
'acc': int(is_greedy)
}
return {"acc": int(is_greedy)}
def aggregation(self):
return {
'acc': mean
}
return {"acc": mean}
def higher_is_better(self):
return {
'acc': True
}
return {"acc": True}
......@@ -28,7 +28,7 @@ _CITATION = """
eprint = {https://doi.org/10.1162/tacl_a_00321},
abstract = { We introduce The Benchmark of Linguistic Minimal Pairs (BLiMP),1 a challenge set for evaluating the linguistic knowledge of language models (LMs) on major grammatical phenomena in English. BLiMP consists of 67 individual datasets, each containing 1,000 minimal pairs—that is, pairs of minimally different sentences that contrast in grammatical acceptability and isolate specific phenomenon in syntax, morphology, or semantics. We generate the data according to linguist-crafted grammar templates, and human aggregate agreement with the labels is 96.4\%. We evaluate n-gram, LSTM, and Transformer (GPT-2 and Transformer-XL) LMs by observing whether they assign a higher probability to the acceptable sentence in each minimal pair. We find that state-of-the-art models identify morphological contrasts related to agreement reliably, but they struggle with some subtle semantic and syntactic phenomena, such as negative polarity items and extraction islands. }
}
"""
""" # noqa: W605
class BlimpTask(Task):
......@@ -50,9 +50,13 @@ class BlimpTask(Task):
# trained on this data.
return self.dataset["train"]
def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
def fewshot_context(
self, doc, num_fewshot, provide_description=None, rnd=None, description=None
):
assert num_fewshot == 0
assert rnd is not None, "A `random.Random` generator argument must be provided to `rnd`"
assert (
rnd is not None
), "A `random.Random` generator argument must be provided to `rnd`"
assert not provide_description, (
"The `provide_description` arg will be removed in future versions. To prepend "
"a custom description to the context, supply the corresponding string via the "
......@@ -60,7 +64,9 @@ class BlimpTask(Task):
)
if provide_description is not None:
# nudge people to not specify it at all
print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")
print(
"WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
)
return ""
......
......@@ -86,7 +86,9 @@ class CBTBase(Task):
return ""
def fewshot_examples(self, k, rnd):
assert k == 0, f"CBT is only implemented for the zero-shot setting. Given k={k}."
assert (
k == 0
), f"CBT is only implemented for the zero-shot setting. Given k={k}."
return super().fewshot_examples(k, rnd)
def construct_requests(self, doc, ctx):
......@@ -120,9 +122,7 @@ class CBTBase(Task):
"""
gold = doc["options"].index(doc["answer"])
pred = np.argmax(results)
return {
"acc": pred == gold
}
return {"acc": pred == gold}
def aggregation(self):
"""
......@@ -130,9 +130,7 @@ class CBTBase(Task):
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return {
"acc": mean
}
return {"acc": mean}
def higher_is_better(self):
"""
......@@ -140,9 +138,7 @@ class CBTBase(Task):
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return {
"acc": True
}
return {"acc": True}
class CBTCN(CBTBase):
......
......@@ -54,8 +54,10 @@ class CoQA(Task):
def doc_to_text(self, doc):
# Given a passage p, the conversation history {q1, a1, . . . qi−1, ai−1}
# and a question qi, the task is to predict the answer ai
doc_text = doc["story"] + '\n\n'
for (q, a) in zip_longest(doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]): # omit target answer ai
doc_text = doc["story"] + "\n\n"
for (q, a) in zip_longest(
doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]
): # omit target answer ai
question = f"Q: {q}\n\n"
answer = f"A: {a}\n\n" if a is not None else "A:"
doc_text += question + answer
......@@ -77,7 +79,9 @@ class CoQA(Task):
additional_answers = doc.get("additional_answers")
if additional_answers:
for key in additional_answers:
additional_answer_for_turn = additional_answers[key]["input_text"][turn_id - 1]
additional_answer_for_turn = additional_answers[key]["input_text"][
turn_id - 1
]
if additional_answer_for_turn.lower() not in map(str.lower, answers):
answers.append(additional_answer_for_turn)
return answers
......@@ -89,12 +93,12 @@ class CoQA(Task):
# ~ 2/3 of the CoQA answers are span-based
# (answers overlap with the passage ignoring punctuation and case mismatch)
if raw_text == "unknown":
return '0'
return "0"
if squad_metrics.normalize_answer(raw_text) == "yes":
return '1'
return "1"
if squad_metrics.normalize_answer(raw_text) == "no":
return '2'
return '3' # Not a yes/no question
return "2"
return "3" # Not a yes/no question
@staticmethod
def compute_scores(gold_list, pred):
......@@ -104,25 +108,30 @@ class CoQA(Task):
em_sum = 0.0
if len(gold_list) > 1:
for i in range(len(gold_list)):
gold_answers = gold_list[0:i] + gold_list[i + 1:]
gold_answers = gold_list[0:i] + gold_list[i + 1 :]
# predictions compared against (n) golds and take maximum
em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_answers)
em_sum += max(
squad_metrics.compute_exact(a, pred) for a in gold_answers
)
f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers)
else:
em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list)
f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_list)
return {'em': em_sum / max(1, len(gold_list)), 'f1': f1_sum / max(1, len(gold_list))}
return {
"em": em_sum / max(1, len(gold_list)),
"f1": f1_sum / max(1, len(gold_list)),
}
def doc_to_target(self, doc, turnid=None):
# Default to prediction of last turn.
if turnid is None:
turnid = len(doc["questions"]["input_text"])
raw_text = doc['answers']["input_text"][turnid - 1]
raw_text = doc["answers"]["input_text"][turnid - 1]
return " " + raw_text
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
......@@ -132,7 +141,7 @@ class CoQA(Task):
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
cont_request = rf.greedy_until(ctx, ['\nQ:'])
cont_request = rf.greedy_until(ctx, ["\nQ:"])
return cont_request
def process_results(self, doc, results):
......@@ -147,13 +156,13 @@ class CoQA(Task):
"""
turn_id = len(doc["questions"]["input_text"])
gold_list = self.get_answers(doc, turn_id)
pred = results[0].strip().split('\n')[0]
pred = results[0].strip().split("\n")[0]
scores = self.compute_scores(gold_list, pred)
return {
"f1": scores['f1'],
"em": scores['em'],
"f1": scores["f1"],
"em": scores["em"],
}
def higher_is_better(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment