Commit 121b7096 authored by Fabrizio Milo

add pre-commit

parent 7a038118
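
Note: the hook configuration itself is not part of the hunks shown below, but the quote, spacing, and trailing-comma changes throughout the diff are consistent with a black-based pre-commit hook. As a hedged illustration only (the sample source string is invented), the same style of reformatting can be reproduced programmatically:

```python
# Illustrative only: reproduce black-style reformatting on a made-up line.
# Assumes the `black` package is installed; not part of this commit.
import black

src = "features=datasets.Features({'question': datasets.Value('string')})\n"
formatted = black.format_str(src, mode=black.Mode())
print(formatted)
# Single quotes become double quotes and spacing is normalized,
# matching the changes in the diff below.
```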
......@@ -65,13 +65,14 @@ class TruthfulqaConfig(datasets.BuilderConfig):
class Truthfulqa(datasets.GeneratorBasedBuilder):
"""TruthfulQA is a benchmark to measure whether a language model is truthful in
generating answers to questions."""
generating answers to questions."""
BUILDER_CONFIGS = [
TruthfulqaConfig(
name="multiple_choice",
url="https://raw.githubusercontent.com/sylinrl/TruthfulQA/013686a06be7a7bde5bf8223943e106c7250123c/data/mc_task.json",
features=datasets.Features({
features=datasets.Features(
{
"question": datasets.Value("string"),
"mc1_targets": {
"choices": datasets.features.Sequence(datasets.Value("string")),
......@@ -80,23 +81,30 @@ generating answers to questions."""
"mc2_targets": {
"choices": datasets.features.Sequence(datasets.Value("string")),
"labels": datasets.features.Sequence(datasets.Value("int32")),
},
}
}),
description="The multiple choice TruthfulQA task"
),
description="The multiple choice TruthfulQA task",
),
TruthfulqaConfig(
name="generation",
url="https://raw.githubusercontent.com/sylinrl/TruthfulQA/013686a06be7a7bde5bf8223943e106c7250123c/TruthfulQA.csv",
features=datasets.Features({
features=datasets.Features(
{
"category": datasets.Value("string"),
"question": datasets.Value("string"),
"best_answer": datasets.Value("string"),
"correct_answers": datasets.features.Sequence(datasets.Value("string")),
"incorrect_answers": datasets.features.Sequence(datasets.Value("string")),
"correct_answers": datasets.features.Sequence(
datasets.Value("string")
),
"incorrect_answers": datasets.features.Sequence(
datasets.Value("string")
),
"source": datasets.Value("string"),
}),
description="The generative TruthfulQA task"
)
}
),
description="The generative TruthfulQA task",
),
]
def _info(self):
......@@ -138,15 +146,15 @@ generating answers to questions."""
"mc2_targets": {
"choices": row["mc2_targets"].keys(),
"labels": row["mc2_targets"].values(),
}
},
}
else:
# Generation data is in a `CSV` file.
with open(filepath, newline='') as f:
with open(filepath, newline="") as f:
contents = csv.DictReader(f)
for key, row in enumerate(contents):
# Ensure that references exist.
if not row['Correct Answers'] or not row['Incorrect Answers']:
if not row["Correct Answers"] or not row["Incorrect Answers"]:
continue
yield key, {
"category": row["Category"],
......@@ -154,6 +162,8 @@ generating answers to questions."""
"best_answer": row["Best Answer"],
# split on ";"
"correct_answers": row["Correct Answers"].strip().split(";"),
"incorrect_answers": row["Incorrect Answers"].strip().split(";"),
"incorrect_answers": row["Incorrect Answers"]
.strip()
.split(";"),
"source": row["Source"],
}
......@@ -64,8 +64,9 @@ class Unscramble(datasets.GeneratorBasedBuilder):
VERSION = datasets.Version("0.0.1")
BUILDER_CONFIGS = [
datasets.BuilderConfig(name=name, version=version,
description=_DESCRIPTIONS[name])
datasets.BuilderConfig(
name=name, version=version, description=_DESCRIPTIONS[name]
)
for name, version in zip(_NAMES, [VERSION] * len(_NAMES))
]
......
......@@ -123,86 +123,111 @@ class Wikitext(datasets.GeneratorBasedBuilder):
return [
datasets.SplitGenerator(
name=datasets.Split.TEST,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.test.tokens"), "split": "test"},
gen_kwargs={
"data_file": os.path.join(data_dir, "wiki.test.tokens"),
"split": "test",
},
),
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.train.tokens"), "split": "train"},
gen_kwargs={
"data_file": os.path.join(data_dir, "wiki.train.tokens"),
"split": "train",
},
),
datasets.SplitGenerator(
name=datasets.Split.VALIDATION,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.valid.tokens"), "split": "valid"},
gen_kwargs={
"data_file": os.path.join(data_dir, "wiki.valid.tokens"),
"split": "valid",
},
),
]
else:
if self.config.name == "wikitext-103-raw-v1":
data_file = dl_manager.download_and_extract(
self.config.data_url)
data_file = dl_manager.download_and_extract(self.config.data_url)
data_dir = os.path.join(data_file, "wikitext-103-raw")
return [
datasets.SplitGenerator(
name=datasets.Split.TEST,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.test.raw"), "split": "test"},
gen_kwargs={
"data_file": os.path.join(data_dir, "wiki.test.raw"),
"split": "test",
},
),
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.train.raw"), "split": "train"},
gen_kwargs={
"data_file": os.path.join(data_dir, "wiki.train.raw"),
"split": "train",
},
),
datasets.SplitGenerator(
name=datasets.Split.VALIDATION,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.valid.raw"), "split": "valid"},
gen_kwargs={
"data_file": os.path.join(data_dir, "wiki.valid.raw"),
"split": "valid",
},
),
]
else:
if self.config.name == "wikitext-2-raw-v1":
data_file = dl_manager.download_and_extract(
self.config.data_url)
data_file = dl_manager.download_and_extract(self.config.data_url)
data_dir = os.path.join(data_file, "wikitext-2-raw")
return [
datasets.SplitGenerator(
name=datasets.Split.TEST,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.test.raw"), "split": "test"},
gen_kwargs={
"data_file": os.path.join(data_dir, "wiki.test.raw"),
"split": "test",
},
),
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.train.raw"), "split": "train"},
gen_kwargs={
"data_file": os.path.join(data_dir, "wiki.train.raw"),
"split": "train",
},
),
datasets.SplitGenerator(
name=datasets.Split.VALIDATION,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.valid.raw"), "split": "valid"},
gen_kwargs={
"data_file": os.path.join(data_dir, "wiki.valid.raw"),
"split": "valid",
},
),
]
else:
if self.config.name == "wikitext-2-v1":
data_file = dl_manager.download_and_extract(
self.config.data_url)
self.config.data_url
)
data_dir = os.path.join(data_file, "wikitext-2")
return [
datasets.SplitGenerator(
name=datasets.Split.TEST,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.test.tokens"), "split": "test"},
gen_kwargs={
"data_file": os.path.join(
data_dir, "wiki.test.tokens"
),
"split": "test",
},
),
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={
"data_file": os.path.join(data_dir, "wiki.train.tokens"),
"data_file": os.path.join(
data_dir, "wiki.train.tokens"
),
"split": "train",
},
),
datasets.SplitGenerator(
name=datasets.Split.VALIDATION,
gen_kwargs={
"data_file": os.path.join(data_dir, "wiki.valid.tokens"),
"data_file": os.path.join(
data_dir, "wiki.valid.tokens"
),
"split": "valid",
},
),
......@@ -216,12 +241,12 @@ class Wikitext(datasets.GeneratorBasedBuilder):
data = f.read().split("\n")
for line in data:
rline = line.replace("= = =", "===").replace("= =", "==").strip()
if rline.startswith('= ') and rline.strip().endswith(' ='):
page = '\n'.join(ret)
if rline.startswith("= ") and rline.strip().endswith(" ="):
page = "\n".join(ret)
if page.strip():
yield key, {"page": page}
key += 1
ret = []
ret.append(line)
page = '\n'.join(ret)
page = "\n".join(ret)
yield key, {"page": page}
......@@ -8,12 +8,14 @@ import mmap
import tqdm
from pathlib import Path
def json_serial(obj):
"""JSON serializer for objects not serializable by default json code"""
if isinstance(obj, (datetime.datetime,)):
return obj.isoformat()
raise TypeError ("Type %s not serializable" % type(obj))
raise TypeError("Type %s not serializable" % type(obj))
# Modified version of lm_dataformat Archive for single file.
class Archive:
......@@ -22,25 +24,31 @@ class Archive:
dir_name = os.path.dirname(file_path)
if dir_name:
os.makedirs(dir_name, exist_ok=True)
self.fh = open(self.file_path, 'wb')
self.fh = open(self.file_path, "wb")
self.cctx = zstandard.ZstdCompressor(level=compression_level)
self.compressor = self.cctx.stream_writer(self.fh)
def add_data(self, data, meta={}):
self.compressor.write(json.dumps({'text': data, 'meta': meta}, default=json_serial).encode('UTF-8') + b'\n')
self.compressor.write(
json.dumps({"text": data, "meta": meta}, default=json_serial).encode(
"UTF-8"
)
+ b"\n"
)
def commit(self):
self.compressor.flush(zstandard.FLUSH_FRAME)
self.fh.flush()
self.fh.close()
# Modified version of lm_dataformat Reader with self.fh set, allowing peeking for tqdm.
class Reader:
def __init__(self):
pass
def read(self, file, get_meta=False, autojoin_paragraphs=True, para_joiner='\n\n'):
with open(file, 'rb') as fh:
def read(self, file, get_meta=False, autojoin_paragraphs=True, para_joiner="\n\n"):
with open(file, "rb") as fh:
self.fh = fh
cctx = zstandard.ZstdDecompressor()
reader = io.BufferedReader(cctx.stream_reader(fh))
......@@ -52,16 +60,17 @@ class Reader:
yield ob
continue
text = ob['text']
text = ob["text"]
if autojoin_paragraphs and isinstance(text, list):
text = para_joiner.join(text)
if get_meta:
yield text, (ob['meta'] if 'meta' in ob else {})
yield text, (ob["meta"] if "meta" in ob else {})
else:
yield text
class TextArchive:
def __init__(self, file_path, mode="rb+"):
self.file_path = file_path
......@@ -75,12 +84,13 @@ class TextArchive:
self.fh = open(self.file_path, mode)
def add_data(self, data):
self.fh.write(data.encode('UTF-8') + b'\n')
self.fh.write(data.encode("UTF-8") + b"\n")
def commit(self):
self.fh.flush()
self.fh.close()
class TextReader:
def __init__(self, file_path):
self.file_path = file_path
......@@ -90,9 +100,12 @@ class TextReader:
def read_tqdm(self, update_frequency=10000):
current_file_position = 0
line_counter = 0
with open(self.file_path, 'r') as fh, \
tqdm.tqdm(total=os.path.getsize(self.file_path), dynamic_ncols=True,
unit="byte", unit_scale=1) as progress:
with open(self.file_path, "r") as fh, tqdm.tqdm(
total=os.path.getsize(self.file_path),
dynamic_ncols=True,
unit="byte",
unit_scale=1,
) as progress:
with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
for line in iter(mmap_obj.readline, b""):
line = line.decode("utf-8")
......@@ -107,7 +120,7 @@ class TextReader:
def read_and_tell(self):
current_file_position = 0
with open(self.file_path, 'r', encoding="utf8") as fh:
with open(self.file_path, "r", encoding="utf8") as fh:
with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
for line in iter(mmap_obj.readline, b""):
line = line.decode("utf-8")
......@@ -117,14 +130,14 @@ class TextReader:
yield line[:-1], raw_bytes_read
def read(self):
with open(self.file_path, 'r', encoding="utf8") as fh:
with open(self.file_path, "r", encoding="utf8") as fh:
with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
for line in iter(mmap_obj.readline, b""):
line = line.decode("utf-8")
yield line[:-1]
def read_slow(self):
with open(self.file_path, 'r', encoding="utf8") as fh:
with open(self.file_path, "r", encoding="utf8") as fh:
while True:
line = fh.readline()
if line == -1 or line == "":
......@@ -132,6 +145,7 @@ class TextReader:
else:
yield line[:-1]
# Optimized for speed. Decompresses the archive in shell before
# using the mmap'd TextReader.
class ZStdTextReader:
......
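
For context, a minimal usage sketch of the Archive/Reader layout reformatted above: one JSON document per line, zstd-compressed on write and streamed back on read. This is an independent sketch, not the module's own classes; the file name and sample data are illustrative.

```python
# Sketch of the jsonl-in-zstd layout used by Archive/Reader above.
import io
import json
import zstandard

# Write: compress newline-delimited JSON, as Archive.add_data/commit do.
with open("example.jsonl.zst", "wb") as fh:
    cctx = zstandard.ZstdCompressor(level=3)
    with cctx.stream_writer(fh) as compressor:
        compressor.write(
            json.dumps({"text": "hello world", "meta": {}}).encode("UTF-8") + b"\n"
        )

# Read: decompress as a stream and iterate lines, as Reader.read does.
with open("example.jsonl.zst", "rb") as fh:
    reader = io.BufferedReader(zstandard.ZstdDecompressor().stream_reader(fh))
    for line in reader:
        print(json.loads(line)["text"])
```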
......@@ -57,19 +57,27 @@ def get_train_overlap(docs_by_task_set, ngrams_path, limit):
os.mkdir(f"data/{task_name}")
# Check if we've decontaminated this combination before
overlaps_dump_path = get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit)
overlaps_dump_path = get_overlaps_dump_path(
task_name, task_set, ngrams_n_size, limit
)
if os.path.exists(overlaps_dump_path):
duplicates[(task_name, task_set)] = pickle.load(open(overlaps_dump_path, "rb"))
duplicates[(task_name, task_set)] = pickle.load(
open(overlaps_dump_path, "rb")
)
sets_to_decontaminate -= 1
continue
else:
duplicates[(task_name, task_set)] = set()
# Build/load the task lookup {ngram: set(documents)}.
task_set_lookup_path = f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.lookup"
task_set_lookup_path = (
f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.lookup"
)
if os.path.exists(task_set_lookup_path):
print(f"{task_set_lookup_path} available, loading...")
lookups[(task_name, task_set)] = pickle.load(open(task_set_lookup_path, "rb"))
lookups[(task_name, task_set)] = pickle.load(
open(task_set_lookup_path, "rb")
)
else:
print(f"{task_set_lookup_path} not available, building...")
lookup = collections.defaultdict(set)
......@@ -79,7 +87,7 @@ def get_train_overlap(docs_by_task_set, ngrams_path, limit):
for ngram in ngrams:
lookup[ngram].add(doc_id)
pickle.dump(lookup, open(task_set_lookup_path,"wb"))
pickle.dump(lookup, open(task_set_lookup_path, "wb"))
lookups[(task_name, task_set)] = lookup
elapsed = time.perf_counter() - start
......@@ -115,7 +123,9 @@ def get_train_overlap(docs_by_task_set, ngrams_path, limit):
for line in reader.read_tqdm(): # Scan training set ngrams file
total_ngrams += 1
[ngram, document_id] = line.rsplit(" ", 1)
if ngram != current_ngram: # Only need to match the ngram once in training set
if (
ngram != current_ngram
): # Only need to match the ngram once in training set
unique_ngrams += 1
current_ngram = ngram
if ngram in merged_lookup:
......@@ -123,7 +133,11 @@ def get_train_overlap(docs_by_task_set, ngrams_path, limit):
matching_unique += 1
for task_name, task_set, doc_ids in merged_lookup[ngram]:
task_doc_set = duplicates[(task_name, task_set)]
for doc_id in doc_ids: # Record contamination across all relevant task/set combos
for (
doc_id
) in (
doc_ids
): # Record contamination across all relevant task/set combos
task_doc_set.add(doc_id)
del merged_lookup[ngram] # No point matching again
else:
......@@ -145,9 +159,10 @@ def get_train_overlap(docs_by_task_set, ngrams_path, limit):
# Dump overlaps separately
for (task_name, task_set), doc_ids in duplicates.items():
overlaps_dump_path = get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit)
pickle.dump(doc_ids, open(overlaps_dump_path,"wb"))
overlaps_dump_path = get_overlaps_dump_path(
task_name, task_set, ngrams_n_size, limit
)
pickle.dump(doc_ids, open(overlaps_dump_path, "wb"))
# Strip task set and return
return {task_name: doc_ids for (task_name, task_set), doc_ids in duplicates.items()}
......@@ -9,6 +9,7 @@ from pprint import pprint
# c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup
try:
import janitor_util
JANITOR_CPP = True
except Exception as e:
print("WARNING: C++ module could not be loaded. Janitor running in python mode")
......@@ -41,6 +42,7 @@ def word_ngrams(s, n):
ngram_seqs = form_ngrams(iter(tokens), n)
return (" ".join(ngram) for ngram in ngram_seqs)
# Does character sequences only - combined faster function to play around with later
# def word_ngrams_indices_combined(sequence, n):
# current_word = ""
......@@ -70,7 +72,7 @@ def split_indices(s):
"""Splits a string on whitespaces and records the indices of each in the original string.
@:return generator((word, (start_idx, end_idx)), ...)
"""
return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r'\S+', s))
return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r"\S+", s))
def word_ngrams_indices(s, n):
......@@ -90,10 +92,15 @@ def word_ngrams_indices(s, n):
# ([word, word, ...], [(start,end), (start,end), ...]),
# ...
# )
ngram_indices_pairs = (zip(*ngram_with_indices) for ngram_with_indices in ngram_seqs_with_indices)
ngram_indices_pairs = (
zip(*ngram_with_indices) for ngram_with_indices in ngram_seqs_with_indices
)
# Generator of ( (word_ngram, (start, end)), (word_ngram, start, end)), ...)
return ((" ".join(ngram_seq), (indices[0][0], indices[-1][1])) for ngram_seq, indices in ngram_indices_pairs)
return (
(" ".join(ngram_seq), (indices[0][0], indices[-1][1]))
for ngram_seq, indices in ngram_indices_pairs
)
class Janitor:
......@@ -105,7 +112,7 @@ class Janitor:
window_to_remove=200,
too_dirty_cutoff=10,
minimum_slice_length=200,
delete_chars=string.punctuation
delete_chars=string.punctuation,
):
self.ngram_n = ngram_n
self.window_to_remove = window_to_remove
......@@ -121,7 +128,7 @@ class Janitor:
self.translation_table = str.maketrans(
string.ascii_lowercase + string.ascii_uppercase, # These characters
string.ascii_lowercase * 2, # Become these characters
self.delete_chars # These are deleted
self.delete_chars, # These are deleted
)
##############
......@@ -129,14 +136,13 @@ class Janitor:
##############
def save_contamination_ngrams(self, filename):
with open(filename, 'wb') as fp:
with open(filename, "wb") as fp:
pickle.dump(filename, fp)
def load_contamination_ngrams(self, filename):
with open(filename, 'rb') as fp:
with open(filename, "rb") as fp:
self.dirt_ngrams = pickle.load(fp)
##############
# Call these :)
##############
......@@ -171,11 +177,11 @@ class Janitor:
end = min(len(dirty_string), end + self.window_to_remove)
if start - splice_idx > self.minimum_slice_length:
clean_chunks.append(dirty_string[splice_idx: start])
clean_chunks.append(dirty_string[splice_idx:start])
splice_idx = end
if end < len(dirty_string) - self.minimum_slice_length:
clean_chunks.append(dirty_string[end+1:])
clean_chunks.append(dirty_string[end + 1 :])
return clean_chunks
......@@ -184,10 +190,14 @@ class Janitor:
##############
def register_contaminant_cpp(self, dirt_string):
self.dirt_ngrams.update(janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n))
self.dirt_ngrams.update(
janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n)
)
def clean_cpp(self, dirty_string):
contamination_indices = janitor_util.clean_ngram_with_indices(dirty_string, self.delete_chars, self.ngram_n)
contamination_indices = janitor_util.clean_ngram_with_indices(
dirty_string, self.delete_chars, self.ngram_n
)
return self._split_chunks(dirty_string, contamination_indices)
##############
......@@ -198,7 +208,9 @@ class Janitor:
return s.translate(self.translation_table)
def register_contaminant_python(self, dirt_string):
self.dirt_ngrams.update(word_ngrams(self.normalize_string(dirt_string), self.ngram_n))
self.dirt_ngrams.update(
word_ngrams(self.normalize_string(dirt_string), self.ngram_n)
)
def clean_python(self, dirty_string):
contamination_indices = (
......
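
The Janitor above matches documents by word n-grams over a normalized string (lowercased, punctuation deleted via a translation table). A small self-contained sketch of that idea, independent of the module's own form_ngrams/word_ngrams helpers:

```python
# Independent sketch of word n-grams over a normalized string (illustrative).
import string


def word_ngrams(s, n):
    tokens = s.split()
    return (" ".join(tokens[i : i + n]) for i in range(len(tokens) - n + 1))


# Lowercase and delete punctuation, mirroring Janitor.normalize_string's
# str.maketrans translation table.
table = str.maketrans(
    string.ascii_uppercase, string.ascii_lowercase, string.punctuation
)
print(list(word_ngrams("The quick, brown fox!".translate(table), 2)))
# ['the quick', 'quick brown', 'brown fox']
```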
......@@ -11,12 +11,22 @@ import numpy as np
from lm_eval.utils import positional_deprecated, run_task_tests
from lm_eval.decontamination.decontaminate import get_train_overlap
@positional_deprecated
def simple_evaluate(model, model_args=None, tasks=[],
num_fewshot=0, batch_size=None, device=None,
no_cache=False, limit=None, bootstrap_iters=100000,
description_dict=None, check_integrity=False,
decontamination_ngrams_path=None):
def simple_evaluate(
model,
model_args=None,
tasks=[],
num_fewshot=0,
batch_size=None,
device=None,
no_cache=False,
limit=None,
bootstrap_iters=100000,
description_dict=None,
check_integrity=False,
decontamination_ngrams_path=None,
):
"""Instantiate and evaluate a model on a list of tasks.
......@@ -52,17 +62,23 @@ def simple_evaluate(model, model_args=None, tasks=[],
assert tasks != [], "No tasks specified"
if isinstance(model, str):
if model_args is None: model_args = ""
lm = lm_eval.models.get_model(model).create_from_arg_string(model_args, {
'batch_size': batch_size, 'device': device
})
if model_args is None:
model_args = ""
lm = lm_eval.models.get_model(model).create_from_arg_string(
model_args, {"batch_size": batch_size, "device": device}
)
else:
assert isinstance(model, lm_eval.base.LM)
lm = model
if not no_cache:
lm = lm_eval.base.CachingLM(
lm, 'lm_cache/' + model + '_' + model_args.replace('=', '-').replace(',', '_').replace('/', '-') + '.db'
lm,
"lm_cache/"
+ model
+ "_"
+ model_args.replace("=", "-").replace(",", "_").replace("/", "-")
+ ".db",
)
task_dict = lm_eval.tasks.get_task_dict(tasks)
......@@ -89,16 +105,26 @@ def simple_evaluate(model, model_args=None, tasks=[],
"no_cache": no_cache,
"limit": limit,
"bootstrap_iters": bootstrap_iters,
"description_dict": description_dict
"description_dict": description_dict,
}
return results
decontaminate_suffix = "_decontaminate"
@positional_deprecated
def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, bootstrap_iters=100000, description_dict=None,
decontamination_ngrams_path=None):
def evaluate(
lm,
task_dict,
provide_description=None,
num_fewshot=0,
limit=None,
bootstrap_iters=100000,
description_dict=None,
decontamination_ngrams_path=None,
):
"""Instantiate and evaluate a model on a list of tasks.
:param lm: obj
......@@ -124,14 +150,16 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
assert not provide_description # not implemented.
if provide_description is not None:
# nudge people to not specify it at all
print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")
print(
"WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
)
decontaminate = decontamination_ngrams_path is not None
task_dict_items = [
(name, task)
for name, task in task_dict.items()
if(task.has_validation_docs() or task.has_test_docs())
if (task.has_validation_docs() or task.has_test_docs())
]
results = collections.defaultdict(dict)
......@@ -172,19 +200,22 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
rnd.seed(42)
rnd.shuffle(task_docs)
description = description_dict[task_name] if description_dict and task_name in description_dict else ""
description = (
description_dict[task_name]
if description_dict and task_name in description_dict
else ""
)
for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
if decontaminate and task.should_decontaminate():
docs_for_decontamination[(task_name, task_set)].append(task.doc_to_decontamination_query(doc))
docs_for_decontamination[(task_name, task_set)].append(
task.doc_to_decontamination_query(doc)
)
docs[(task_name, doc_id)] = doc
ctx = task.fewshot_context(
doc=doc,
num_fewshot=num_fewshot,
rnd=rnd,
description=description
doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
)
reqs = task.construct_requests(doc, ctx)
if not isinstance(reqs, (list, tuple)):
......@@ -198,7 +229,9 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
# Compare all tasks/sets at once to ensure a single training set scan
if decontaminate:
print("Finding train/test overlap, please wait...")
overlaps = get_train_overlap(docs_for_decontamination, decontamination_ngrams_path, limit)
overlaps = get_train_overlap(
docs_for_decontamination, decontamination_ngrams_path, limit
)
# all responses for each (task, doc)
process_res_queue = collections.defaultdict(list)
......@@ -212,7 +245,9 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
print("Running", reqtype, "requests")
resps = getattr(lm, reqtype)([req.args for req in reqs])
resps = [x if req.index is None else x[req.index] for x, req in zip(resps, reqs)]
resps = [
x if req.index is None else x[req.index] for x, req in zip(resps, reqs)
]
for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
process_res_queue[(task_name, doc_id)].append((i, resp))
......@@ -241,7 +276,9 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
task = task_dict[task_name]
real_metric = metric # key when looking up the metric with task.aggregation
if metric.endswith(decontaminate_suffix):
real_metric = metric.replace(decontaminate_suffix, "") # decontaminated still uses the same metric
real_metric = metric.replace(
decontaminate_suffix, ""
) # decontaminated still uses the same metric
results[task_name][metric] = task.aggregation()[real_metric](items)
# hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
......@@ -249,16 +286,15 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
stderr = lm_eval.metrics.stderr_for_metric(
metric=task.aggregation()[real_metric],
bootstrap_iters=min(bootstrap_iters, 1000) if metric in ["bleu", "chrf", "ter"] else bootstrap_iters,
bootstrap_iters=min(bootstrap_iters, 1000)
if metric in ["bleu", "chrf", "ter"]
else bootstrap_iters,
)
if stderr is not None:
results[task_name][metric + "_stderr"] = stderr(items)
return {
"results": dict(results),
"versions": dict(versions)
}
return {"results": dict(results), "versions": dict(versions)}
def make_table(result_dict):
......@@ -280,9 +316,9 @@ def make_table(result_dict):
if m + "_stderr" in dic:
se = dic[m + "_stderr"]
values.append([k, version, m, '%.4f' % v, '±', '%.4f' % se])
values.append([k, version, m, "%.4f" % v, "±", "%.4f" % se])
else:
values.append([k, version, m, '%.4f' % v, '', ''])
values.append([k, version, m, "%.4f" % v, "", ""])
k = ""
version = ""
md_writer.value_matrix = values
......
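
A hedged usage sketch of the simple_evaluate/make_table entry points whose signatures appear above; the module path lm_eval.evaluator and the model/task argument values are assumptions chosen for illustration.

```python
# Illustrative call into the evaluator API shown above (argument values assumed).
import lm_eval.evaluator

results = lm_eval.evaluator.simple_evaluate(
    model="gpt2",                  # registered model name
    model_args="pretrained=gpt2",  # parsed by create_from_arg_string
    tasks=["lambada", "arc_easy"],
    num_fewshot=0,
    batch_size=1,
)
print(lm_eval.evaluator.make_table(results))
```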
......@@ -103,6 +103,7 @@ def weighted_mean(items):
def weighted_perplexity(items):
return math.exp(-weighted_mean(items))
def bits_per_byte(items):
return -weighted_mean(items) / math.log(2)
......@@ -184,8 +185,10 @@ def _sacreformat(refs, preds):
return refs, preds
# stderr stuff
class _bootstrap_internal:
def __init__(self, f, n):
self.f = f
......@@ -203,6 +206,7 @@ class _bootstrap_internal:
def bootstrap_stderr(f, xs, iters):
import multiprocessing as mp
pool = mp.Pool(mp.cpu_count())
# this gives a biased estimate of the stderr (i.e w/ the mean, it gives something
# equivalent to stderr calculated without Bessel's correction in the stddev.
......@@ -213,10 +217,15 @@ def bootstrap_stderr(f, xs, iters):
res = []
chunk_size = min(1000, iters)
from tqdm import tqdm
print("bootstrapping for stddev:", f.__name__)
for bootstrap in tqdm(pool.imap(
for bootstrap in tqdm(
pool.imap(
_bootstrap_internal(f, chunk_size),
[(i, xs) for i in range(iters // chunk_size)]), total=iters // chunk_size):
[(i, xs) for i in range(iters // chunk_size)],
),
total=iters // chunk_size,
):
# sample w replacement
res.extend(bootstrap)
......@@ -238,17 +247,13 @@ def stderr_for_metric(metric, bootstrap_iters):
if metric in bootstrappable:
return lambda x: bootstrap_stderr(metric, x, iters=bootstrap_iters)
stderr = {
mean: mean_stderr,
acc_all: acc_all_stderr
}
stderr = {mean: mean_stderr, acc_all: acc_all_stderr}
return stderr.get(metric, None)
def yesno(x):
if x:
return 'yes'
return "yes"
else:
return 'no'
return "no"
......@@ -4,8 +4,15 @@ from lm_eval.base import BaseLM
class HFLM(BaseLM):
def __init__(self, device='cuda', pretrained='gpt2', revision='main', subfolder=None, tokenizer=None, batch_size=1):
def __init__(
self,
device="cuda",
pretrained="gpt2",
revision="main",
subfolder=None,
tokenizer=None,
batch_size=1,
):
super().__init__()
assert isinstance(device, str)
......@@ -20,28 +27,47 @@ class HFLM(BaseLM):
else:
print("Device not specificed")
print(f"Cuda Available? {torch.cuda.is_available()}")
self._device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
self._device = (
torch.device("cuda")
if torch.cuda.is_available()
else torch.device("cpu")
)
# TODO: update this to be less of a hack once subfolder is fixed in HF
self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
pretrained, revision=revision + ("/" + subfolder if subfolder is not None else "")
pretrained,
revision=revision + ("/" + subfolder if subfolder is not None else ""),
).to(self.device)
self.gpt2.eval()
# pretrained tokenizer for neo is broken for now so just hard-coding this to gpt2
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
pretrained if tokenizer is None else tokenizer, revision=revision, subfolder=subfolder)
pretrained if tokenizer is None else tokenizer,
revision=revision,
subfolder=subfolder,
)
assert isinstance(self.tokenizer, (
transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast,
transformers.T5Tokenizer, transformers.T5TokenizerFast,
)), "this tokenizer has not been checked for compatibility yet!"
assert isinstance(
self.tokenizer,
(
transformers.GPT2Tokenizer,
transformers.GPT2TokenizerFast,
transformers.T5Tokenizer,
transformers.T5TokenizerFast,
),
), "this tokenizer has not been checked for compatibility yet!"
self.vocab_size = self.tokenizer.vocab_size
if isinstance(self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)):
assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373], \
self.tokenizer.encode('hello\n\nhello')
if isinstance(
self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)
):
assert self.tokenizer.encode("hello\n\nhello") == [
31373,
198,
198,
31373,
], self.tokenizer.encode("hello\n\nhello")
# multithreading and batching
self.batch_size_per_gpu = batch_size # todo: adaptive batch size
......@@ -97,10 +123,7 @@ class HFLM(BaseLM):
def _model_generate(self, context, max_length, eos_token_id):
return self.gpt2.generate(
context,
max_length=max_length,
eos_token_id=eos_token_id,
do_sample=False
context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False
)
......
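
A hedged instantiation sketch of the HFLM wrapper reformatted above; the module path lm_eval.models.gpt2 and the argument values are assumptions for illustration ("cpu" avoids requiring a GPU here).

```python
# Illustrative construction of the HuggingFace-backed LM wrapper shown above.
from lm_eval.models.gpt2 import HFLM

lm = HFLM(device="cpu", pretrained="gpt2", batch_size=1)
# The class asserts this exact encoding for GPT-2 tokenizers:
print(lm.tokenizer.encode("hello\n\nhello"))  # [31373, 198, 198, 31373]
```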
......@@ -36,17 +36,19 @@ def get_result(response, ctxlen):
def oa_completion(**kwargs):
""" Query OpenAI API for completion.
"""Query OpenAI API for completion.
Retry with back-off until they respond
"""
import openai
backoff_time = 3
while True:
try:
return openai.Completion.create(**kwargs)
except openai.error.OpenAIError:
import traceback
traceback.print_exc()
time.sleep(backoff_time)
backoff_time *= 1.5
......@@ -66,16 +68,19 @@ class GPT3LM(BaseLM):
super().__init__()
import openai
self.engine = engine
self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")
self.vocab_size = self.tokenizer.vocab_size
# to make the annoying "Using pad_token, but it is not set yet." error go away
self.tokenizer.pad_token = "<|endoftext|>"
assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373]
assert self.tokenizer.encode("hello\n\nhello") == [31373, 198, 198, 31373]
self.truncate = truncate
self.end_of_text_token_id = self.tokenizer.convert_tokens_to_ids(["<|endoftext|>"])[0]
self.end_of_text_token_id = self.tokenizer.convert_tokens_to_ids(
["<|endoftext|>"]
)[0]
# Read from environment variable OPENAI_API_SECRET_KEY
openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]
......@@ -121,14 +126,19 @@ class GPT3LM(BaseLM):
reord = utils.Reorderer(requests, _collate)
for chunk in tqdm(list(utils.chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE)), disable=disable_tqdm):
for chunk in tqdm(
list(utils.chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE)),
disable=disable_tqdm,
):
inps = []
ctxlens = []
for cache_key, context_enc, continuation_enc in chunk:
# max_length+1 because the API takes up to 2049 tokens, including the first context token
inp = (context_enc + continuation_enc)[-(self.max_length+1):]
inp = (context_enc + continuation_enc)[-(self.max_length + 1) :]
# TODO: the logic is much simpler if we just look at the length of continuation tokens
ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - (self.max_length+1))
ctxlen = len(context_enc) - max(
0, len(context_enc) + len(continuation_enc) - (self.max_length + 1)
)
inps.append(inp)
ctxlens.append(ctxlen)
......@@ -137,11 +147,14 @@ class GPT3LM(BaseLM):
engine=self.engine,
prompt=inps,
echo=True,
max_tokens=0, temperature=0.,
max_tokens=0,
temperature=0.0,
logprobs=10,
)
for resp, ctxlen, (cache_key, context_enc, continuation_enc) in zip(response.choices, ctxlens, chunk):
for resp, ctxlen, (cache_key, context_enc, continuation_enc) in zip(
response.choices, ctxlens, chunk
):
answer = get_result(resp, ctxlen)
res.append(answer)
......@@ -177,24 +190,26 @@ class GPT3LM(BaseLM):
yield ret, lastuntil
# todo: more intelligent batching for heterogeneous `until`
for chunk, until in tqdm(list(sameuntil_chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE))):
for chunk, until in tqdm(
list(sameuntil_chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE))
):
inps = []
for context, _ in chunk:
context_enc = self.tok_encode(context)
inp = context_enc[-(self.max_length - self.max_gen_toks):]
inp = context_enc[-(self.max_length - self.max_gen_toks) :]
inps.append(inp)
response = oa_completion(
engine=self.engine,
prompt=inps,
max_tokens=self.max_gen_toks,
temperature=0.,
temperature=0.0,
logprobs=10,
stop=until,
)
for resp, (context, until_) in zip(response.choices, chunk):
s = resp['text']
s = resp["text"]
for term in until_:
s = s.split(term)[0]
......
......@@ -59,8 +59,8 @@ from . import storycloze
# 6 total
gpt3_translation_benchmarks = {
"wmt14": ['en-fr', 'fr-en'], # French
"wmt16": ['en-ro', 'ro-en', 'de-en', 'en-de'], # German, Romanian
"wmt14": ["en-fr", "fr-en"], # French
"wmt16": ["en-ro", "ro-en", "de-en", "en-de"], # German, Romanian
}
......@@ -68,7 +68,7 @@ gpt3_translation_benchmarks = {
selected_translation_benchmarks = {
**gpt3_translation_benchmarks,
"wmt20": sacrebleu.get_langpairs_for_testset("wmt20"),
"iwslt17": ['en-ar', 'ar-en'] # Arabic
"iwslt17": ["en-ar", "ar-en"], # Arabic
}
# 319 total
......@@ -92,7 +92,7 @@ TASK_REGISTRY = {
"rte": glue.RTE,
"qnli": glue.QNLI,
"qqp": glue.QQP,
#"stsb": glue.STSB, # not implemented yet
# "stsb": glue.STSB, # not implemented yet
"sst": glue.SST,
"wnli": glue.WNLI,
# SuperGLUE
......@@ -103,34 +103,26 @@ TASK_REGISTRY = {
"record": superglue.ReCoRD,
"wic": superglue.WordsInContext,
"wsc": superglue.SGWinogradSchemaChallenge,
# Order by benchmark/genre?
"coqa": coqa.CoQA,
"drop": drop.DROP,
"lambada": lambada.LAMBADA,
"lambada_cloze": lambada_cloze.LAMBADA_cloze,
# multilingual lambada
**lambada_multilingual.construct_tasks(),
"wikitext": wikitext.WikiText,
# "cbt-cn": cbt.CBTCN, # disabled pending context length fix
# "cbt-ne": cbt.CBTNE, # disabled pending context length fix
"piqa": piqa.PiQA,
"prost": prost.PROST,
"mc_taco": mc_taco.MCTACO,
# Science related
"pubmedqa" : pubmedqa.Pubmed_QA,
"sciq" : sciq.SciQ,
"pubmedqa": pubmedqa.Pubmed_QA,
"sciq": sciq.SciQ,
"qasper": qasper.QASPER,
"qa4mre_2011" : qa4mre.QA4MRE_2011,
"qa4mre_2012" : qa4mre.QA4MRE_2012,
"qa4mre_2013" : qa4mre.QA4MRE_2013,
"qa4mre_2011": qa4mre.QA4MRE_2011,
"qa4mre_2012": qa4mre.QA4MRE_2012,
"qa4mre_2013": qa4mre.QA4MRE_2013,
"triviaqa": triviaqa.TriviaQA,
"arc_easy": arc.ARCEasy,
"arc_challenge": arc.ARCChallenge,
......@@ -152,21 +144,17 @@ TASK_REGISTRY = {
"anli_r1": anli.ANLIRound1,
"anli_r2": anli.ANLIRound2,
"anli_r3": anli.ANLIRound3,
"ethics_cm": hendrycks_ethics.EthicsCM,
"ethics_deontology": hendrycks_ethics.EthicsDeontology,
"ethics_justice": hendrycks_ethics.EthicsJustice,
"ethics_utilitarianism_original": hendrycks_ethics.EthicsUtilitarianismOriginal,
"ethics_utilitarianism": hendrycks_ethics.EthicsUtilitarianism,
"ethics_virtue": hendrycks_ethics.EthicsVirtue,
"truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice,
"truthfulqa_gen": truthfulqa.TruthfulQAGeneration,
# dialogue
"mutual": mutual.MuTual,
"mutual_plus": mutual.MuTualPlus,
# math
"math_algebra": hendrycks_math.MathAlgebra,
"math_counting_and_prob": hendrycks_math.MathCountingAndProbability,
......@@ -177,7 +165,6 @@ TASK_REGISTRY = {
"math_precalc": hendrycks_math.MathPrecalculus,
"math_asdiv": asdiv.Asdiv,
"gsm8k": gsm8k.GradeSchoolMath8K,
# arithmetic
"arithmetic_2da": arithmetic.Arithmetic2DPlus,
"arithmetic_2ds": arithmetic.Arithmetic2DMinus,
......@@ -191,22 +178,18 @@ TASK_REGISTRY = {
"arithmetic_1dc": arithmetic.Arithmetic1DComposite,
# TODO Perhaps make these groups of tasks
# e.g. anli, arithmetic, openai_translations, harness_translations
# hendrycksTest (57 tasks)
**hendrycks_test.create_all_tasks(),
# e.g. wmt14-fr-en
**translation.create_tasks_from_benchmarks(gpt3_translation_benchmarks),
# chef's selection, mostly wmt20
**translation.create_tasks_from_benchmarks(selected_translation_benchmarks),
# Word Scrambling and Manipulation Tasks
"anagrams1": unscramble.Anagrams1,
"anagrams2": unscramble.Anagrams2,
"cycle_letters": unscramble.CycleLetters,
"random_insertion": unscramble.RandomInsertion,
"reversed_words": unscramble.ReversedWords,
# Pile
"pile_arxiv": pile.PileArxiv,
"pile_books3": pile.PileBooks3,
......@@ -230,7 +213,6 @@ TASK_REGISTRY = {
"pile_ubuntu-irc": pile.PileUbuntuIrc,
"pile_wikipedia": pile.PileWikipedia,
"pile_youtubesubtitles": pile.PileYoutubeSubtitles,
# BLiMP
"blimp_adjunct_island": blimp.BlimpAdjunctIsland,
"blimp_anaphor_gender_agreement": blimp.BlimpAnaphorGenderAgreement,
......@@ -299,7 +281,6 @@ TASK_REGISTRY = {
"blimp_wh_vs_that_no_gap_long_distance": blimp.BlimpWhVsThatNoGapLongDistance,
"blimp_wh_vs_that_with_gap": blimp.BlimpWhVsThatWithGap,
"blimp_wh_vs_that_with_gap_long_distance": blimp.BlimpWhVsThatWithGapLongDistance,
# Requires manual download of data.
# "storycloze_2016": storycloze.StoryCloze2016,
# "storycloze_2018": storycloze.StoryCloze2018,
......@@ -325,17 +306,23 @@ def get_task_name_from_object(task_object):
return name
# this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
return task_object.EVAL_HARNESS_NAME if hasattr(task_object, "EVAL_HARNESS_NAME") else type(task_object).__name__
return (
task_object.EVAL_HARNESS_NAME
if hasattr(task_object, "EVAL_HARNESS_NAME")
else type(task_object).__name__
)
def get_task_dict(task_name_list: List[Union[str, lm_eval.base.Task]]):
task_name_dict = {
task_name: get_task(task_name)()
for task_name in task_name_list if isinstance(task_name, str)
for task_name in task_name_list
if isinstance(task_name, str)
}
task_name_from_object_dict = {
get_task_name_from_object(task_object): task_object
for task_object in task_name_list if not isinstance(task_object, str)
for task_object in task_name_list
if not isinstance(task_object, str)
}
assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys()))
return {**task_name_dict, **task_name_from_object_dict}
......@@ -64,7 +64,12 @@ class ANLIBase(Task):
# of the prompt (yes, repeating it!). also, " True, False, or Neither?" is directly
# appended onto the question, with no "Answer:" or even a newline. Do we *really*
# want to do it exactly as OA did?
return doc['premise'] + '\nQuestion: ' + doc['hypothesis'] + ' True, False, or Neither?\nAnswer:'
return (
doc["premise"]
+ "\nQuestion: "
+ doc["hypothesis"]
+ " True, False, or Neither?\nAnswer:"
)
def should_decontaminate(self):
return True
......@@ -76,10 +81,10 @@ class ANLIBase(Task):
# True = entailment
# False = contradiction
# Neither = neutral
return " " + ["True", "Neither", "False"][doc['label']]
return " " + ["True", "Neither", "False"][doc["label"]]
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
......@@ -106,9 +111,7 @@ class ANLIBase(Task):
"""
gold = doc["label"]
pred = np.argmax(results)
return {
"acc": pred == gold
}
return {"acc": pred == gold}
def aggregation(self):
"""
......@@ -116,9 +119,7 @@ class ANLIBase(Task):
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return {
"acc": mean
}
return {"acc": mean}
def higher_is_better(self):
"""
......@@ -126,9 +127,7 @@ class ANLIBase(Task):
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return {
"acc": True
}
return {"acc": True}
class ANLIRound1(ANLIBase):
......
......@@ -67,10 +67,8 @@ class Arithmetic(Task):
return is_prediction
def process_results(self, doc, results):
is_prediction, = results
return {
"acc": is_prediction
}
(is_prediction,) = results
return {"acc": is_prediction}
def aggregation(self):
return {
......@@ -78,9 +76,7 @@ class Arithmetic(Task):
}
def higher_is_better(self):
return {
"acc": True
}
return {"acc": True}
class Arithmetic2DPlus(Arithmetic):
......
......@@ -54,29 +54,28 @@ class Asdiv(Task):
def test_docs(self):
raise NotImplementedError("This dataset has no test docs")
def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
def fewshot_context(
self, doc, num_fewshot, provide_description=None, rnd=None, description=None
):
assert num_fewshot == 0, "ASDiv is intended only for the zero-shot setting."
return super().fewshot_context(
doc=doc,
num_fewshot=num_fewshot,
rnd=rnd,
description=description
doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
)
def doc_to_text(self, doc):
# TODO: add solution-type
return doc['body'] + '\n' + 'Question:' + doc['question'] + '\n' + 'Answer:'
return doc["body"] + "\n" + "Question:" + doc["question"] + "\n" + "Answer:"
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc['body'] + " " + doc['question']
return doc["body"] + " " + doc["question"]
def doc_to_target(self, doc):
# TODO: add formula
answer = doc['answer'].split(' (')[0]
answer = doc["answer"].split(" (")[0]
return " " + answer
def construct_requests(self, doc, ctx):
......@@ -86,16 +85,10 @@ class Asdiv(Task):
def process_results(self, doc, results):
ll, is_greedy = results
return {
'acc': int(is_greedy)
}
return {"acc": int(is_greedy)}
def aggregation(self):
return {
'acc': mean
}
return {"acc": mean}
def higher_is_better(self):
return {
'acc': True
}
return {"acc": True}
......@@ -50,9 +50,13 @@ class BlimpTask(Task):
# trained on this data.
return self.dataset["train"]
def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
def fewshot_context(
self, doc, num_fewshot, provide_description=None, rnd=None, description=None
):
assert num_fewshot == 0
assert rnd is not None, "A `random.Random` generator argument must be provided to `rnd`"
assert (
rnd is not None
), "A `random.Random` generator argument must be provided to `rnd`"
assert not provide_description, (
"The `provide_description` arg will be removed in future versions. To prepend "
"a custom description to the context, supply the corresponding string via the "
......@@ -60,7 +64,9 @@ class BlimpTask(Task):
)
if provide_description is not None:
# nudge people to not specify it at all
print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")
print(
"WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
)
return ""
......
......@@ -86,7 +86,9 @@ class CBTBase(Task):
return ""
def fewshot_examples(self, k, rnd):
assert k == 0, f"CBT is only implemented for the zero-shot setting. Given k={k}."
assert (
k == 0
), f"CBT is only implemented for the zero-shot setting. Given k={k}."
return super().fewshot_examples(k, rnd)
def construct_requests(self, doc, ctx):
......@@ -120,9 +122,7 @@ class CBTBase(Task):
"""
gold = doc["options"].index(doc["answer"])
pred = np.argmax(results)
return {
"acc": pred == gold
}
return {"acc": pred == gold}
def aggregation(self):
"""
......@@ -130,9 +130,7 @@ class CBTBase(Task):
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return {
"acc": mean
}
return {"acc": mean}
def higher_is_better(self):
"""
......@@ -140,9 +138,7 @@ class CBTBase(Task):
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return {
"acc": True
}
return {"acc": True}
class CBTCN(CBTBase):
......
......@@ -54,8 +54,10 @@ class CoQA(Task):
def doc_to_text(self, doc):
# Given a passage p, the conversation history {q1, a1, . . . qi−1, ai−1}
# and a question qi, the task is to predict the answer ai
doc_text = doc["story"] + '\n\n'
for (q, a) in zip_longest(doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]): # omit target answer ai
doc_text = doc["story"] + "\n\n"
for (q, a) in zip_longest(
doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]
): # omit target answer ai
question = f"Q: {q}\n\n"
answer = f"A: {a}\n\n" if a is not None else "A:"
doc_text += question + answer
......@@ -77,7 +79,9 @@ class CoQA(Task):
additional_answers = doc.get("additional_answers")
if additional_answers:
for key in additional_answers:
additional_answer_for_turn = additional_answers[key]["input_text"][turn_id - 1]
additional_answer_for_turn = additional_answers[key]["input_text"][
turn_id - 1
]
if additional_answer_for_turn.lower() not in map(str.lower, answers):
answers.append(additional_answer_for_turn)
return answers
......@@ -89,12 +93,12 @@ class CoQA(Task):
# ~ 2/3 of the CoQA answers are span-based
# (answers overlap with the passage ignoring punctuation and case mismatch)
if raw_text == "unknown":
return '0'
return "0"
if squad_metrics.normalize_answer(raw_text) == "yes":
return '1'
return "1"
if squad_metrics.normalize_answer(raw_text) == "no":
return '2'
return '3' # Not a yes/no question
return "2"
return "3" # Not a yes/no question
@staticmethod
def compute_scores(gold_list, pred):
......@@ -104,25 +108,30 @@ class CoQA(Task):
em_sum = 0.0
if len(gold_list) > 1:
for i in range(len(gold_list)):
gold_answers = gold_list[0:i] + gold_list[i + 1:]
gold_answers = gold_list[0:i] + gold_list[i + 1 :]
# predictions compared against (n) golds and take maximum
em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_answers)
em_sum += max(
squad_metrics.compute_exact(a, pred) for a in gold_answers
)
f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers)
else:
em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list)
f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_list)
return {'em': em_sum / max(1, len(gold_list)), 'f1': f1_sum / max(1, len(gold_list))}
return {
"em": em_sum / max(1, len(gold_list)),
"f1": f1_sum / max(1, len(gold_list)),
}
def doc_to_target(self, doc, turnid=None):
# Default to prediction of last turn.
if turnid is None:
turnid = len(doc["questions"]["input_text"])
raw_text = doc['answers']["input_text"][turnid - 1]
raw_text = doc["answers"]["input_text"][turnid - 1]
return " " + raw_text
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
......@@ -132,7 +141,7 @@ class CoQA(Task):
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
cont_request = rf.greedy_until(ctx, ['\nQ:'])
cont_request = rf.greedy_until(ctx, ["\nQ:"])
return cont_request
def process_results(self, doc, results):
......@@ -147,13 +156,13 @@ class CoQA(Task):
"""
turn_id = len(doc["questions"]["input_text"])
gold_list = self.get_answers(doc, turn_id)
pred = results[0].strip().split('\n')[0]
pred = results[0].strip().split("\n")[0]
scores = self.compute_scores(gold_list, pred)
return {
"f1": scores['f1'],
"em": scores['em'],
"f1": scores["f1"],
"em": scores["em"],
}
def higher_is_better(self):
......
......@@ -70,21 +70,26 @@ class DROP(Task):
@classmethod
def get_answers(cls, qa):
def _flatten_validated_answers(validated_answers):
""" Flattens a dict of lists of validated answers.
"""Flattens a dict of lists of validated answers.
{"number": ['1', '8'], ...}
-> [{"number": ['1'], ...}, {"number": ['8'], ...}]
"""
vas = []
for i in range(len(validated_answers["number"])):
vas.append({
vas.append(
{
"number": validated_answers["number"][i],
"date": validated_answers["date"][i],
"spans": validated_answers["spans"][i],
})
}
)
return vas
answers = []
answers_set = set()
candidates = [qa["answer"]] + _flatten_validated_answers(qa["validated_answers"])
candidates = [qa["answer"]] + _flatten_validated_answers(
qa["validated_answers"]
)
for candidate in candidates:
answer = cls.parse_answer(candidate)
if answer in answers_set:
......@@ -100,9 +105,11 @@ class DROP(Task):
return (str(answer["number"]),)
if answer["spans"] != []:
return tuple(answer["spans"])
return (" ".join([answer["date"]["day"],
answer["date"]["month"],
answer["date"]["year"]]).strip(),)
return (
" ".join(
[answer["date"]["day"], answer["date"]["month"], answer["date"]["year"]]
).strip(),
)
def doc_to_text(self, doc):
return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:"
......@@ -111,7 +118,7 @@ class DROP(Task):
return True
def doc_to_decontamination_query(self, doc):
return doc['passage'] + " " + doc['question']
return doc["passage"] + " " + doc["question"]
def doc_to_target(self, doc):
return " " + ", ".join(doc["answers"][0])
......@@ -148,10 +155,7 @@ class DROP(Task):
if gold_answer[0].strip():
max_em = max(max_em, exact_match)
max_f1 = max(max_f1, f1_score)
return {
"em": max_em,
"f1": max_f1
}
return {"em": max_em, "f1": max_f1}
def get_metrics(self, predicted, gold):
"""
......@@ -164,7 +168,9 @@ class DROP(Task):
predicted_bags = self._answer_to_bags(predicted)
gold_bags = self._answer_to_bags(gold)
if set(predicted_bags[0]) == set(gold_bags[0]) and len(predicted_bags[0]) == len(gold_bags[0]):
if set(predicted_bags[0]) == set(gold_bags[0]) and len(
predicted_bags[0]
) == len(gold_bags[0]):
exact_match = 1.0
else:
exact_match = 0.0
......@@ -196,7 +202,9 @@ class DROP(Task):
for gold_index, gold_item in enumerate(gold):
for pred_index, pred_item in enumerate(predicted):
if self._match_numbers_if_present(gold_item, pred_item):
scores[gold_index, pred_index] = self._compute_f1(pred_item, gold_item)
scores[gold_index, pred_index] = self._compute_f1(
pred_item, gold_item
)
row_ind, col_ind = linear_sum_assignment(-scores)
max_scores = np.zeros([max(len(gold), len(predicted))])
......@@ -262,7 +270,11 @@ class DROP(Task):
def _normalize(self, answer):
tokens = [
self._white_space_fix(self._remove_articles(self._fix_number(self._remove_punc(token.lower()))))
self._white_space_fix(
self._remove_articles(
self._fix_number(self._remove_punc(token.lower()))
)
)
for token in self._tokenize(answer)
]
tokens = [token for token in tokens if token.strip()]
......@@ -275,10 +287,7 @@ class DROP(Task):
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return {
"em": mean,
"f1": mean
}
return {"em": mean, "f1": mean}
def higher_is_better(self):
"""
......@@ -286,7 +295,4 @@ class DROP(Task):
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return {
"em": True,
"f1": True
}
return {"em": True, "f1": True}