Commit 1f8a8c1d authored by jon-tow

Merge branch 'master' of https://github.com/EleutherAI/lm-evaluation-harness into remove-dataset

parents b4c0275d b0acb337
@@ -123,86 +123,111 @@ class Wikitext(datasets.GeneratorBasedBuilder):
            return [
                datasets.SplitGenerator(
                    name=datasets.Split.TEST,
                    gen_kwargs={
                        "data_file": os.path.join(data_dir, "wiki.test.tokens"),
                        "split": "test",
                    },
                ),
                datasets.SplitGenerator(
                    name=datasets.Split.TRAIN,
                    gen_kwargs={
                        "data_file": os.path.join(data_dir, "wiki.train.tokens"),
                        "split": "train",
                    },
                ),
                datasets.SplitGenerator(
                    name=datasets.Split.VALIDATION,
                    gen_kwargs={
                        "data_file": os.path.join(data_dir, "wiki.valid.tokens"),
                        "split": "valid",
                    },
                ),
            ]
        else:
            if self.config.name == "wikitext-103-raw-v1":
                data_file = dl_manager.download_and_extract(self.config.data_url)
                data_dir = os.path.join(data_file, "wikitext-103-raw")
                return [
                    datasets.SplitGenerator(
                        name=datasets.Split.TEST,
                        gen_kwargs={
                            "data_file": os.path.join(data_dir, "wiki.test.raw"),
                            "split": "test",
                        },
                    ),
                    datasets.SplitGenerator(
                        name=datasets.Split.TRAIN,
                        gen_kwargs={
                            "data_file": os.path.join(data_dir, "wiki.train.raw"),
                            "split": "train",
                        },
                    ),
                    datasets.SplitGenerator(
                        name=datasets.Split.VALIDATION,
                        gen_kwargs={
                            "data_file": os.path.join(data_dir, "wiki.valid.raw"),
                            "split": "valid",
                        },
                    ),
                ]
            else:
                if self.config.name == "wikitext-2-raw-v1":
                    data_file = dl_manager.download_and_extract(self.config.data_url)
                    data_dir = os.path.join(data_file, "wikitext-2-raw")
                    return [
                        datasets.SplitGenerator(
                            name=datasets.Split.TEST,
                            gen_kwargs={
                                "data_file": os.path.join(data_dir, "wiki.test.raw"),
                                "split": "test",
                            },
                        ),
                        datasets.SplitGenerator(
                            name=datasets.Split.TRAIN,
                            gen_kwargs={
                                "data_file": os.path.join(data_dir, "wiki.train.raw"),
                                "split": "train",
                            },
                        ),
                        datasets.SplitGenerator(
                            name=datasets.Split.VALIDATION,
                            gen_kwargs={
                                "data_file": os.path.join(data_dir, "wiki.valid.raw"),
                                "split": "valid",
                            },
                        ),
                    ]
                else:
                    if self.config.name == "wikitext-2-v1":
                        data_file = dl_manager.download_and_extract(
                            self.config.data_url
                        )
                        data_dir = os.path.join(data_file, "wikitext-2")
                        return [
                            datasets.SplitGenerator(
                                name=datasets.Split.TEST,
                                gen_kwargs={
                                    "data_file": os.path.join(
                                        data_dir, "wiki.test.tokens"
                                    ),
                                    "split": "test",
                                },
                            ),
                            datasets.SplitGenerator(
                                name=datasets.Split.TRAIN,
                                gen_kwargs={
                                    "data_file": os.path.join(
                                        data_dir, "wiki.train.tokens"
                                    ),
                                    "split": "train",
                                },
                            ),
                            datasets.SplitGenerator(
                                name=datasets.Split.VALIDATION,
                                gen_kwargs={
                                    "data_file": os.path.join(
                                        data_dir, "wiki.valid.tokens"
                                    ),
                                    "split": "valid",
                                },
                            ),
@@ -216,12 +241,12 @@ class Wikitext(datasets.GeneratorBasedBuilder):
            data = f.read().split("\n")
            for line in data:
                rline = line.replace("= = =", "===").replace("= =", "==").strip()
                if rline.startswith("= ") and rline.strip().endswith(" ="):
                    page = "\n".join(ret)
                    if page.strip():
                        yield key, {"page": page}
                        key += 1
                    ret = []
                ret.append(line)
            page = "\n".join(ret)
            yield key, {"page": page}
@@ -4,13 +4,18 @@ import json
import jsonlines
import io
import datetime
import mmap
import tqdm

from pathlib import Path


def json_serial(obj):
    """JSON serializer for objects not serializable by default json code"""

    if isinstance(obj, (datetime.datetime,)):
        return obj.isoformat()
    raise TypeError("Type %s not serializable" % type(obj))


# Modified version of lm_dataformat Archive for single file.
class Archive:
@@ -19,25 +24,31 @@ class Archive:
        dir_name = os.path.dirname(file_path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)
        self.fh = open(self.file_path, "wb")
        self.cctx = zstandard.ZstdCompressor(level=compression_level)
        self.compressor = self.cctx.stream_writer(self.fh)

    def add_data(self, data, meta={}):
        self.compressor.write(
            json.dumps({"text": data, "meta": meta}, default=json_serial).encode(
                "UTF-8"
            )
            + b"\n"
        )

    def commit(self):
        self.compressor.flush(zstandard.FLUSH_FRAME)
        self.fh.flush()
        self.fh.close()


# Modified version of lm_dataformat Reader with self.fh set, allowing peeking for tqdm.
class Reader:
    def __init__(self):
        pass

    def read(self, file, get_meta=False, autojoin_paragraphs=True, para_joiner="\n\n"):
        with open(file, "rb") as fh:
            self.fh = fh
            cctx = zstandard.ZstdDecompressor()
            reader = io.BufferedReader(cctx.stream_reader(fh))
@@ -49,42 +60,102 @@ class Reader:
                    yield ob
                    continue

                text = ob["text"]
                if autojoin_paragraphs and isinstance(text, list):
                    text = para_joiner.join(text)

                if get_meta:
                    yield text, (ob["meta"] if "meta" in ob else {})
                else:
                    yield text


# Simple text reader and writer with same interface as above
class TextArchive:
    def __init__(self, file_path, mode="rb+"):
        self.file_path = file_path
        dir_name = os.path.dirname(file_path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)

        if not os.path.exists(file_path):
            Path(file_path).touch()

        self.fh = open(self.file_path, mode)

    def add_data(self, data):
        self.fh.write(data.encode("UTF-8") + b"\n")

    def commit(self):
        self.fh.flush()
        self.fh.close()


class TextReader:
    def __init__(self, file_path):
        self.file_path = file_path

    # Optimized mmap read with infrequent tqdm updates to maintain speed
    # Tested up to 250MB/s.
    def read_tqdm(self, update_frequency=10000):
        current_file_position = 0
        line_counter = 0
        with open(self.file_path, "r") as fh, tqdm.tqdm(
            total=os.path.getsize(self.file_path),
            dynamic_ncols=True,
            unit="byte",
            unit_scale=1,
        ) as progress:
            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                for line in iter(mmap_obj.readline, b""):
                    line = line.decode("utf-8")
                    line_counter += 1
                    if line_counter == update_frequency:
                        new_file_pos = mmap_obj.tell()
                        bytes_read = new_file_pos - current_file_position
                        current_file_position = new_file_pos
                        progress.update(bytes_read)
                        line_counter = 0
                    yield line[:-1]

    def read_and_tell(self):
        current_file_position = 0
        with open(self.file_path, "r", encoding="utf8") as fh:
            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                for line in iter(mmap_obj.readline, b""):
                    line = line.decode("utf-8")
                    new_file_pos = mmap_obj.tell()
                    raw_bytes_read = new_file_pos - current_file_position
                    current_file_position = new_file_pos
                    yield line[:-1], raw_bytes_read

    def read(self):
        with open(self.file_path, "r", encoding="utf8") as fh:
            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                for line in iter(mmap_obj.readline, b""):
                    line = line.decode("utf-8")
                    yield line[:-1]

    def read_slow(self):
        with open(self.file_path, "r", encoding="utf8") as fh:
            while True:
                line = fh.readline()
                if line == -1 or line == "":
                    break
                else:
                    yield line[:-1]


# Optimized for speed. Decompresses the archive in shell before
# using the mmap'd TextReader.
class ZStdTextReader:
    def __init__(self, file):
        self.file = file

    def read_tqdm(self):
        decompressed_file = self.file[:-4]
        print("Decompressing file, please wait...")
        os.system(f"zstd -d {self.file}")  # linux decompress is faster
        reader = TextReader(decompressed_file)
        yield from reader.read_tqdm()
        os.remove(decompressed_file)
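
For orientation, a minimal usage sketch of the archiver classes above (the import path follows the package-relative imports in the decontamination module below; file names and the default compression level are assumptions, not part of the diff):

    from lm_eval.decontamination.archiver import Archive, Reader, TextArchive, TextReader

    # Write a compressed jsonl.zst archive of {"text", "meta"} records, then stream it back.
    archive = Archive("output/example.jsonl.zst")
    archive.add_data("first document", meta={"source": "demo"})
    archive.add_data("second document")
    archive.commit()

    for text, meta in Reader().read("output/example.jsonl.zst", get_meta=True):
        print(meta, text)

    # Plain-text archive: one line per record, read back via the mmap'd TextReader.
    text_archive = TextArchive("output/example.txt")
    text_archive.add_data("a 13-gram and its document id")
    text_archive.commit()

    for line in TextReader("output/example.txt").read():
        print(line)
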
import time
import random
import pickle
import json
import glob
import os
import collections

from .janitor import Janitor, word_ngrams
from .archiver import ZStdTextReader


# Was used for testing the evaluator decoupled from the full logic below
def get_train_overlap_stub(docs, ngrams_path, ngrams_n_size):
    simulated_overlap = 0.1
    contaminated = int(len(docs) * simulated_overlap)
    return random.sample(range(len(docs)), contaminated)


# Returns a dictionary containing all overlapping documents in each
# task. In the standard use case, an overlap occurs when any of the 13-grams
# found in the task document exist in the training set documents.
#
# To generate 13-grams for the pile see scripts/clean_training_data. The final output of these
# scripts are an info.json file containing the n_gram_size (13) and a bunch of "ngrams_{x}.bkt.txt.sorted.zst"
# files. These should exist in the "ngrams_path" provided to this function.
#
# Algorithm:
# 1. Build lookups for each dataset {ngram: list(document_ids)}
# 2. Merge into an overall lookup {ngram: [(task_name, task_set, doc_ids),]}
# 3. Full scan the 13-grams from the training set against the merged lookup,
#    saving matches in the "duplicates" dictionary {(task_name, task_set): set(doc_ids)}
# 4. Strip the task_set from the dictionary keys and return
#
# We cache the task+set lookups as well as the overlaps.
def get_train_overlap(docs_by_task_set, ngrams_path, limit):
    # return get_train_overlap_stub(docs, ngrams_path, ngrams_n_size)

    info_dict_path = os.path.join(ngrams_path, "info.json")
    info_dict = json.load(open(info_dict_path, "r"))
    ngrams_n_size = info_dict["ngram_size"]

    janitor = Janitor()

    # Build lookup for each dataset first in case we use different task combinations later
    print("Building Lookups...")
    start = time.perf_counter()

    def get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit):
        return f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.overlaps"

    lookups = {}
    duplicates = {}  # (task_name, task_set): set(doc_ids)}
    sets_to_decontaminate = len(docs_by_task_set.keys())

    for (task_name, task_set), docs in docs_by_task_set.items():
        if not os.path.exists(f"data/{task_name}"):
            os.mkdir(f"data/{task_name}")

        # Check if we've decontaminated this combination before
        overlaps_dump_path = get_overlaps_dump_path(
            task_name, task_set, ngrams_n_size, limit
        )
        if os.path.exists(overlaps_dump_path):
            duplicates[(task_name, task_set)] = pickle.load(
                open(overlaps_dump_path, "rb")
            )
            sets_to_decontaminate -= 1
            continue
        else:
            duplicates[(task_name, task_set)] = set()

        # Build/load the task lookup {ngram: set(documents)}.
        task_set_lookup_path = (
            f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.lookup"
        )
        if os.path.exists(task_set_lookup_path):
            print(f"{task_set_lookup_path} available, loading...")
            lookups[(task_name, task_set)] = pickle.load(
                open(task_set_lookup_path, "rb")
            )
        else:
            print(f"{task_set_lookup_path} not available, building...")
            lookup = collections.defaultdict(set)

            for doc_id, document in enumerate(docs):
                ngrams = word_ngrams(janitor.normalize_string(document), ngrams_n_size)
                for ngram in ngrams:
                    lookup[ngram].add(doc_id)

            pickle.dump(lookup, open(task_set_lookup_path, "wb"))
            lookups[(task_name, task_set)] = lookup

    elapsed = time.perf_counter() - start
    print(f"Building lookups took {elapsed:0.5f} seconds.")

    matched_ngrams = []

    if sets_to_decontaminate > 0:
        print("Merging lookups...")
        start = time.perf_counter()
        merged_lookup = collections.defaultdict(list)
        for (task_name, task_set), lookup in lookups.items():
            for ngram, doc_ids in lookup.items():
                merged_lookup[ngram].append((task_name, task_set, doc_ids))

        elapsed = time.perf_counter() - start
        print(f"Merging lookups took {elapsed:0.5f} seconds.")

        print(f"{ngrams_n_size} grams files found in {ngrams_path}:")
        files = glob.glob(os.path.join(ngrams_path, f"*.sorted.zst"))
        print(files)

        for file in files:
            start = time.perf_counter()
            print(f"Scanning {file}")
            reader = ZStdTextReader(file)
            total_ngrams = 0
            unique_ngrams = 0
            matching_unique = 0
            non_matching_unique = 0

            current_ngram = ""
            for line in reader.read_tqdm():  # Scan training set ngrams file
                total_ngrams += 1
                [ngram, document_id] = line.rsplit(" ", 1)
                if (
                    ngram != current_ngram
                ):  # Only need to match the ngram once in training set
                    unique_ngrams += 1
                    current_ngram = ngram
                    if ngram in merged_lookup:
                        matched_ngrams.append(ngram)  # For logging
                        matching_unique += 1
                        for task_name, task_set, doc_ids in merged_lookup[ngram]:
                            task_doc_set = duplicates[(task_name, task_set)]
                            for (
                                doc_id
                            ) in (
                                doc_ids
                            ):  # Record contamination across all relevant task/set combos
                                task_doc_set.add(doc_id)
                        del merged_lookup[ngram]  # No point matching again
                    else:
                        non_matching_unique += 1

            print(f"Total Ngrams: {total_ngrams}")
            print(f"Unique Ngrams: {unique_ngrams}")
            print(f"Unique Matching: {matching_unique}")
            print(f"Unique Non Matching: {non_matching_unique}")
            print("Matched ngrams:")
            for ngram in matched_ngrams:
                print(ngram)

            elapsed = time.perf_counter() - start
            print(f"Read took {elapsed:0.5f} seconds.")
            print(f"Speed: {(os.path.getsize(file)/1000000.0)/elapsed}MB/second")

        print(duplicates)

        # Dump overlaps separately
        for (task_name, task_set), doc_ids in duplicates.items():
            overlaps_dump_path = get_overlaps_dump_path(
                task_name, task_set, ngrams_n_size, limit
            )
            pickle.dump(doc_ids, open(overlaps_dump_path, "wb"))

    # Strip task set and return
    return {task_name: doc_ids for (task_name, task_set), doc_ids in duplicates.items()}
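
A sketch of calling this routine in isolation (the path and task names are placeholders, and the function also expects a writable local data/ cache directory; this is not part of the diff):

    from lm_eval.decontamination.decontaminate import get_train_overlap

    docs_by_task_set = {
        ("lambada", "test"): ["first eval document ...", "second eval document ..."],
    }
    # ngrams_path must hold info.json plus the "*.sorted.zst" files produced
    # by scripts/clean_training_data.
    overlaps = get_train_overlap(docs_by_task_set, "path/to/pile_13grams", limit=None)
    print(overlaps)  # e.g. {"lambada": {1}}: doc ids sharing a 13-gram with the training set
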
@@ -9,8 +9,9 @@ from pprint import pprint
# c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup
try:
    import janitor_util

    JANITOR_CPP = True
except Exception:
    print("WARNING: C++ module could not be loaded. Janitor running in python mode")
    traceback.print_exc()
    JANITOR_CPP = False
@@ -41,6 +42,7 @@ def word_ngrams(s, n):
    ngram_seqs = form_ngrams(iter(tokens), n)
    return (" ".join(ngram) for ngram in ngram_seqs)


# Does character sequences only - combined faster function to play around with later
# def word_ngrams_indices_combined(sequence, n):
#     current_word = ""
@@ -70,7 +72,7 @@ def split_indices(s):
    """Splits a string on whitespaces and records the indices of each in the original string.
    @:return generator((word, (start_idx, end_idx)), ...)
    """
    return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r"\S+", s))


def word_ngrams_indices(s, n):
@@ -90,10 +92,15 @@ def word_ngrams_indices(s, n):
    #   ([word, word, ...], [(start,end), (start,end), ...]),
    #   ...
    # )
    ngram_indices_pairs = (
        zip(*ngram_with_indices) for ngram_with_indices in ngram_seqs_with_indices
    )

    # Generator of ( (word_ngram, (start, end)), (word_ngram, start, end)), ...)
    return (
        (" ".join(ngram_seq), (indices[0][0], indices[-1][1]))
        for ngram_seq, indices in ngram_indices_pairs
    )


class Janitor:
@@ -105,7 +112,7 @@ class Janitor:
        window_to_remove=200,
        too_dirty_cutoff=10,
        minimum_slice_length=200,
        delete_chars=string.punctuation,
    ):
        self.ngram_n = ngram_n
        self.window_to_remove = window_to_remove
@@ -121,7 +128,7 @@ class Janitor:
        self.translation_table = str.maketrans(
            string.ascii_lowercase + string.ascii_uppercase,  # These characters
            string.ascii_lowercase * 2,  # Become these characters
            self.delete_chars,  # These are deleted
        )

    ##############
@@ -129,14 +136,13 @@ class Janitor:
    ##############

    def save_contamination_ngrams(self, filename):
        with open(filename, "wb") as fp:
            pickle.dump(filename, fp)

    def load_contamination_ngrams(self, filename):
        with open(filename, "rb") as fp:
            self.dirt_ngrams = pickle.load(fp)

    ##############
    # Call these :)
    ##############
@@ -152,7 +158,7 @@ class Janitor:
    def clean(self, dirty_string):
        """Clean a string (e.g. a training set) by removing all ngrams previously
        registered as contaminants. Returns a list of clean chunks, or empty if
        the string was too dirty"""
        if JANITOR_CPP:
            return self.clean_cpp(dirty_string)
@@ -171,11 +177,11 @@ class Janitor:
            end = min(len(dirty_string), end + self.window_to_remove)

            if start - splice_idx > self.minimum_slice_length:
                clean_chunks.append(dirty_string[splice_idx:start])

            splice_idx = end

        if end < len(dirty_string) - self.minimum_slice_length:
            clean_chunks.append(dirty_string[end + 1 :])

        return clean_chunks
@@ -184,10 +190,14 @@ class Janitor:
    ##############

    def register_contaminant_cpp(self, dirt_string):
        self.dirt_ngrams.update(
            janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n)
        )

    def clean_cpp(self, dirty_string):
        contamination_indices = janitor_util.clean_ngram_with_indices(
            dirty_string, self.delete_chars, self.ngram_n
        )
        return self._split_chunks(dirty_string, contamination_indices)

    ##############
@@ -198,7 +208,9 @@ class Janitor:
        return s.translate(self.translation_table)

    def register_contaminant_python(self, dirt_string):
        self.dirt_ngrams.update(
            word_ngrams(self.normalize_string(dirt_string), self.ngram_n)
        )

    def clean_python(self, dirty_string):
        contamination_indices = (
@@ -263,7 +275,7 @@ class Janitor:
# ultan Al Nahyan]] granted [[Petroleum]] concessions, and oil was first found in 1958. At first,
# oil money had a marginal impact. A few lowrise concete buildings were erected, and the first
# paved road was completed in 1961, but Sheikh Shakbut, uncertain whether the new oil royalties
# would last, took a cautious approach, preferring to save the revenue rather than investing it in
# development. His brother, [[Zayed bin Sultan Al Nahayan]], saw that oil wealth had the potential
# to transform Abu Dhabi. The ruling Al Nahayan family decided that Sheikh Zayed should replace his
# brother as Ruler and carry out his vision of developing the country. On [[August 6]], [[1966]],
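
A minimal sketch of the Janitor workflow (strings are made up, and a register_contaminant dispatcher mirroring clean is assumed here; the pure-Python path is used when the compiled janitor_util extension is unavailable):

    from lm_eval.decontamination.janitor import Janitor

    janitor = Janitor(ngram_n=13)
    # Register the benchmark text whose 13-grams must not survive in training data.
    janitor.register_contaminant("benchmark question text that must not leak " * 3)
    # Clean a training document; contaminated windows are cut out and the
    # remaining slices returned as chunks.
    clean_chunks = janitor.clean("a long training document " * 100)
    print(len(clean_chunks))
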
import collections
import itertools
import numpy as np
import random
import lm_eval.metrics
import lm_eval.models
import lm_eval.tasks
import lm_eval.base
from lm_eval.utils import positional_deprecated, run_task_tests


@positional_deprecated
def simple_evaluate(
    model,
    model_args=None,
    tasks=[],
    num_fewshot=0,
    batch_size=None,
    device=None,
    no_cache=False,
    limit=None,
    bootstrap_iters=100000,
    description_dict=None,
    check_integrity=False,
    decontamination_ngrams_path=None,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param model: Union[str, LM]
@@ -49,17 +59,23 @@ def simple_evaluate(model, model_args=None, tasks=[],
    assert tasks != [], "No tasks specified"

    if isinstance(model, str):
        if model_args is None:
            model_args = ""
        lm = lm_eval.models.get_model(model).create_from_arg_string(
            model_args, {"batch_size": batch_size, "device": device}
        )
    else:
        assert isinstance(model, lm_eval.base.LM)
        lm = model

    if not no_cache:
        lm = lm_eval.base.CachingLM(
            lm,
            "lm_cache/"
            + model
            + "_"
            + model_args.replace("=", "-").replace(",", "_").replace("/", "-")
            + ".db",
        )

    task_dict = lm_eval.tasks.get_task_dict(tasks)
@@ -72,7 +88,9 @@ def simple_evaluate(model, model_args=None, tasks=[],
        task_dict=task_dict,
        num_fewshot=num_fewshot,
        limit=limit,
        bootstrap_iters=bootstrap_iters,
        description_dict=description_dict,
        decontamination_ngrams_path=decontamination_ngrams_path,
    )

    # add info about the model and few shot config
@@ -85,14 +103,26 @@ def simple_evaluate(model, model_args=None, tasks=[],
        "no_cache": no_cache,
        "limit": limit,
        "bootstrap_iters": bootstrap_iters,
        "description_dict": description_dict,
    }

    return results


decontaminate_suffix = "_decontaminate"


@positional_deprecated
def evaluate(
    lm,
    task_dict,
    provide_description=None,
    num_fewshot=0,
    limit=None,
    bootstrap_iters=100000,
    description_dict=None,
    decontamination_ngrams_path=None,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param lm: obj
@@ -118,12 +148,16 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
    assert not provide_description  # not implemented.
    if provide_description is not None:
        # nudge people to not specify it at all
        print(
            "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
        )

    decontaminate = decontamination_ngrams_path is not None

    task_dict_items = [
        (name, task)
        for name, task in task_dict.items()
        if (task.has_validation_docs() or task.has_test_docs())
    ]

    results = collections.defaultdict(dict)
@@ -132,6 +166,8 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
    requests = collections.defaultdict(list)
    requests_origin = collections.defaultdict(list)

    overlaps = collections.defaultdict(list)  # {task_name: contaminated_docs}

    # If we ever run into issues where the eval tasks don't fit in memory and we can't afford a machine with bigger
    # memory, we can always modify this plumbing to support that, but I didn't want to include it just yet because
    # over-engineering is bad (or we could make it write the requests to disk and then read them back out again
@@ -140,6 +176,8 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
    # TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable
    docs = {}

    docs_for_decontamination = collections.defaultdict(list)

    # get lists of each type of request
    for task_name, task in task_dict_items:
        versions[task_name] = task.VERSION
@@ -147,7 +185,9 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
        # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
        if task.has_test_docs():
            task_doc_func = task.test_docs
            task_set = "test"  # Required for caching in the decontamination
        elif task.has_validation_docs():
            task_set = "val"  # Required for caching in the decontamination
            task_doc_func = task.validation_docs
        else:
            raise RuntimeError("Task has neither test_docs nor validation_docs")
@@ -158,15 +198,22 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
        rnd.seed(42)
        rnd.shuffle(task_docs)

        description = (
            description_dict[task_name]
            if description_dict and task_name in description_dict
            else ""
        )

        for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
            if decontaminate and task.should_decontaminate():
                docs_for_decontamination[(task_name, task_set)].append(
                    task.doc_to_decontamination_query(doc)
                )

            docs[(task_name, doc_id)] = doc
            ctx = task.fewshot_context(
                doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
            )
            reqs = task.construct_requests(doc, ctx)
            if not isinstance(reqs, (list, tuple)):
@@ -177,6 +224,15 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
                # doc_id: unique id that we can get back to a doc using `docs`
                requests_origin[req.request_type].append((i, task_name, doc, doc_id))

    # Compare all tasks/sets at once to ensure a single training set scan
    if decontaminate:
        from lm_eval.decontamination.decontaminate import get_train_overlap

        print("Finding train/test overlap, please wait...")
        overlaps = get_train_overlap(
            docs_for_decontamination, decontamination_ngrams_path, limit
        )

    # all responses for each (task, doc)
    process_res_queue = collections.defaultdict(list)
@@ -189,7 +245,9 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
        print("Running", reqtype, "requests")
        resps = getattr(lm, reqtype)([req.args for req in reqs])
        resps = [
            x if req.index is None else x[req.index] for x, req in zip(resps, reqs)
        ]

        for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
            process_res_queue[(task_name, doc_id)].append((i, resp))
@@ -208,24 +266,35 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
        for metric, value in metrics.items():
            vals[(task_name, metric)].append(value)

            # Re-use the evaluation for the decontaminated set by just ignoring the overlaps
            if decontaminate and task_name in overlaps:
                if doc_id not in overlaps[task_name]:
                    vals[(task_name, metric + decontaminate_suffix)].append(value)

    # aggregate results
    for (task_name, metric), items in vals.items():
        task = task_dict[task_name]
        real_metric = metric  # key when looking up the metric with task.aggregation
        if metric.endswith(decontaminate_suffix):
            real_metric = metric.replace(
                decontaminate_suffix, ""
            )  # decontaminated still uses the same metric
        results[task_name][metric] = task.aggregation()[real_metric](items)

        # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
        # so we run them less iterations. still looking for a cleaner way to do this
        stderr = lm_eval.metrics.stderr_for_metric(
            metric=task.aggregation()[real_metric],
            bootstrap_iters=min(bootstrap_iters, 1000)
            if metric in ["bleu", "chrf", "ter"]
            else bootstrap_iters,
        )
        if stderr is not None:
            results[task_name][metric + "_stderr"] = stderr(items)

    return {"results": dict(results), "versions": dict(versions)}


def make_table(result_dict):
@@ -247,9 +316,9 @@ def make_table(result_dict):
            if m + "_stderr" in dic:
                se = dic[m + "_stderr"]
                values.append([k, version, m, "%.4f" % v, "±", "%.4f" % se])
            else:
                values.append([k, version, m, "%.4f" % v, "", ""])
            k = ""
            version = ""
    md_writer.value_matrix = values
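
The new plumbing end to end, as a hedged example (the model, task, and ngrams path are placeholders, and it is assumed here that this function is exposed as lm_eval.evaluator.simple_evaluate):

    from lm_eval import evaluator

    results = evaluator.simple_evaluate(
        model="gpt2",
        tasks=["lambada"],
        num_fewshot=0,
        decontamination_ngrams_path="path/to/pile_13grams",
    )
    # Alongside each metric, results["results"]["lambada"] now also carries a
    # "<metric>_decontaminate" entry aggregated only over documents with no
    # 13-gram overlap against the training set.
    print(results["results"])
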
@@ -103,6 +103,7 @@ def weighted_mean(items):
def weighted_perplexity(items):
    return math.exp(-weighted_mean(items))


def bits_per_byte(items):
    return -weighted_mean(items) / math.log(2)
@@ -184,8 +185,10 @@ def _sacreformat(refs, preds):
    return refs, preds


# stderr stuff


class _bootstrap_internal:
    def __init__(self, f, n):
        self.f = f
@@ -203,6 +206,7 @@ class _bootstrap_internal:
def bootstrap_stderr(f, xs, iters):
    import multiprocessing as mp

    pool = mp.Pool(mp.cpu_count())
    # this gives a biased estimate of the stderr (i.e w/ the mean, it gives something
    # equivalent to stderr calculated without Bessel's correction in the stddev.
@@ -213,10 +217,15 @@ def bootstrap_stderr(f, xs, iters):
    res = []
    chunk_size = min(1000, iters)
    from tqdm import tqdm

    print("bootstrapping for stddev:", f.__name__)
    for bootstrap in tqdm(
        pool.imap(
            _bootstrap_internal(f, chunk_size),
            [(i, xs) for i in range(iters // chunk_size)],
        ),
        total=iters // chunk_size,
    ):
        # sample w replacement
        res.extend(bootstrap)
@@ -238,17 +247,13 @@ def stderr_for_metric(metric, bootstrap_iters):
    if metric in bootstrappable:
        return lambda x: bootstrap_stderr(metric, x, iters=bootstrap_iters)

    stderr = {mean: mean_stderr, acc_all: acc_all_stderr}

    return stderr.get(metric, None)


def yesno(x):
    if x:
        return "yes"
    else:
        return "no"
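
For reference, a small sketch of how stderr_for_metric is typically used (the values are made up):

    from lm_eval.metrics import mean, stderr_for_metric

    accuracies = [1, 0, 1, 1, 0, 1, 1, 1]
    stderr_fn = stderr_for_metric(mean, bootstrap_iters=1000)
    if stderr_fn is not None:
        print(mean(accuracies), stderr_fn(accuracies))
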
@@ -4,8 +4,15 @@ from lm_eval.base import BaseLM


class HFLM(BaseLM):
    def __init__(
        self,
        device="cuda",
        pretrained="gpt2",
        revision="main",
        subfolder=None,
        tokenizer=None,
        batch_size=1,
    ):
        super().__init__()

        assert isinstance(device, str)
@@ -13,30 +20,54 @@ class HFLM(BaseLM):
        assert isinstance(batch_size, int)

        if device:
            if device not in ["cuda", "cpu"]:
                device = int(device)
            self._device = torch.device(device)
            print(f"Using device '{device}'")
        else:
            print("Device not specified")
            print(f"Cuda Available? {torch.cuda.is_available()}")
            self._device = (
                torch.device("cuda")
                if torch.cuda.is_available()
                else torch.device("cpu")
            )

        # TODO: update this to be less of a hack once subfolder is fixed in HF
        self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
            pretrained,
            revision=revision + ("/" + subfolder if subfolder is not None else ""),
        ).to(self.device)
        self.gpt2.eval()

        # pretrained tokenizer for neo is broken for now so just hard-coding this to gpt2
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            pretrained if tokenizer is None else tokenizer,
            revision=revision,
            subfolder=subfolder,
        )

        assert isinstance(
            self.tokenizer,
            (
                transformers.GPT2Tokenizer,
                transformers.GPT2TokenizerFast,
                transformers.T5Tokenizer,
                transformers.T5TokenizerFast,
            ),
        ), "this tokenizer has not been checked for compatibility yet!"

        self.vocab_size = self.tokenizer.vocab_size

        if isinstance(
            self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)
        ):
            assert self.tokenizer.encode("hello\n\nhello") == [
                31373,
                198,
                198,
                31373,
            ], self.tokenizer.encode("hello\n\nhello")

        # multithreading and batching
        self.batch_size_per_gpu = batch_size  # todo: adaptive batch size
@@ -92,10 +123,7 @@ class HFLM(BaseLM):

    def _model_generate(self, context, max_length, eos_token_id):
        return self.gpt2.generate(
            context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False
        )
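
The new device handling above accepts a bare CUDA index passed as a string, in addition to "cuda" and "cpu". A sketch (assuming the class is exposed as lm_eval.models.gpt2.HFLM):

    from lm_eval.models.gpt2 import HFLM

    lm = HFLM(pretrained="gpt2", device="cpu")  # explicit device string
    # lm = HFLM(pretrained="gpt2", device="0")  # parsed as int -> torch.device(0)
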
@@ -36,17 +36,19 @@ def get_result(response, ctxlen):
def oa_completion(**kwargs):
    """Query OpenAI API for completion.

    Retry with back-off until they respond
    """
    import openai

    backoff_time = 3
    while True:
        try:
            return openai.Completion.create(**kwargs)
        except openai.error.OpenAIError:
            import traceback

            traceback.print_exc()
            time.sleep(backoff_time)
            backoff_time *= 1.5
@@ -66,16 +68,19 @@ class GPT3LM(BaseLM):
        super().__init__()

        import openai

        self.engine = engine
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")

        self.vocab_size = self.tokenizer.vocab_size

        # to make the annoying "Using pad_token, but it is not set yet." error go away
        self.tokenizer.pad_token = "<|endoftext|>"
        assert self.tokenizer.encode("hello\n\nhello") == [31373, 198, 198, 31373]
        self.truncate = truncate
        self.end_of_text_token_id = self.tokenizer.convert_tokens_to_ids(
            ["<|endoftext|>"]
        )[0]

        # Read from environment variable OPENAI_API_SECRET_KEY
        openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]
@@ -119,16 +124,21 @@ class GPT3LM(BaseLM):
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        re_ord = utils.Reorderer(requests, _collate)

        for chunk in tqdm(
            list(utils.chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE)),
            disable=disable_tqdm,
        ):
            inps = []
            ctxlens = []
            for cache_key, context_enc, continuation_enc in chunk:
                # max_length+1 because the API takes up to 2049 tokens, including the first context token
                inp = (context_enc + continuation_enc)[-(self.max_length + 1) :]
                # TODO: the logic is much simpler if we just look at the length of continuation tokens
                ctxlen = len(context_enc) - max(
                    0, len(context_enc) + len(continuation_enc) - (self.max_length + 1)
                )

                inps.append(inp)
                ctxlens.append(ctxlen)
@@ -137,11 +147,14 @@ class GPT3LM(BaseLM):
                engine=self.engine,
                prompt=inps,
                echo=True,
                max_tokens=0,
                temperature=0.0,
                logprobs=10,
            )

            for resp, ctxlen, (cache_key, context_enc, continuation_enc) in zip(
                response.choices, ctxlens, chunk
            ):
                answer = get_result(resp, ctxlen)

                res.append(answer)
@@ -150,7 +163,7 @@ class GPT3LM(BaseLM):
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)

        return re_ord.get_original(res)

    def greedy_until(self, requests):
        if not requests:
@@ -161,7 +174,7 @@ class GPT3LM(BaseLM):
            toks = self.tok_encode(x[0])
            return len(toks), x[0]

        re_ord = utils.Reorderer(requests, _collate)

        def sameuntil_chunks(xs, size):
            ret = []
@@ -177,24 +190,26 @@ class GPT3LM(BaseLM):
                yield ret, lastuntil

        # todo: more intelligent batching for heterogeneous `until`
        for chunk, until in tqdm(
            list(sameuntil_chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE))
        ):
            inps = []
            for context, _ in chunk:
                context_enc = self.tok_encode(context)
                inp = context_enc[-(self.max_length - self.max_gen_toks) :]
                inps.append(inp)

            response = oa_completion(
                engine=self.engine,
                prompt=inps,
                max_tokens=self.max_gen_toks,
                temperature=0.0,
                logprobs=10,
                stop=until,
            )

            for resp, (context, until_) in zip(response.choices, chunk):
                s = resp["text"]

                for term in until_:
                    s = s.split(term)[0]
@@ -204,7 +219,7 @@ class GPT3LM(BaseLM):
                res.append(s)

        return re_ord.get_original(res)

    def _model_call(self, inps):
        # Isn't used because we override _loglikelihood_tokens
@@ -22,14 +22,12 @@ from . import naturalqs
from . import sat
from . import arithmetic
from . import lambada
from . import piqa
from . import prost
from . import mc_taco
from . import triviaqa
from . import pubmedqa
from . import sciq
from . import qasper
from . import qa4mre
from . import translation
@@ -59,8 +57,8 @@ from . import storycloze

# 6 total
gpt3_translation_benchmarks = {
    "wmt14": ["en-fr", "fr-en"],  # French
    "wmt16": ["en-ro", "ro-en", "de-en", "en-de"],  # German, Romanian
}
@@ -68,7 +66,7 @@ gpt3_translation_benchmarks = {
selected_translation_benchmarks = {
    **gpt3_translation_benchmarks,
    "wmt20": sacrebleu.get_langpairs_for_testset("wmt20"),
    "iwslt17": ["en-ar", "ar-en"],  # Arabic
}

# 319 total
@@ -92,7 +90,7 @@ TASK_REGISTRY = {
    "rte": glue.RTE,
    "qnli": glue.QNLI,
    "qqp": glue.QQP,
    # "stsb": glue.STSB, # not implemented yet
    "sst": glue.SST,
    "wnli": glue.WNLI,
    # SuperGLUE
@@ -103,34 +101,26 @@ TASK_REGISTRY = {
    "record": superglue.ReCoRD,
    "wic": superglue.WordsInContext,
    "wsc": superglue.SGWinogradSchemaChallenge,
    # Order by benchmark/genre?
    "coqa": coqa.CoQA,
    "drop": drop.DROP,
    "lambada": lambada.LAMBADA,
    "lambada_cloze": lambada_cloze.LAMBADA_cloze,
    # multilingual lambada
    **lambada_multilingual.construct_tasks(),
    "wikitext": wikitext.WikiText,
    # "cbt-cn": cbt.CBTCN, # disabled pending context length fix
    # "cbt-ne": cbt.CBTNE, # disabled pending context length fix
    "piqa": piqa.PiQA,
    "prost": prost.PROST,
    "mc_taco": mc_taco.MCTACO,
    # Science related
    "pubmedqa": pubmedqa.Pubmed_QA,
    "sciq": sciq.SciQ,
    "qasper": qasper.QASPER,
    "qa4mre_2011": qa4mre.QA4MRE_2011,
    "qa4mre_2012": qa4mre.QA4MRE_2012,
    "qa4mre_2013": qa4mre.QA4MRE_2013,
    "triviaqa": triviaqa.TriviaQA,
    "arc_easy": arc.ARCEasy,
    "arc_challenge": arc.ARCChallenge,
@@ -152,21 +142,17 @@ TASK_REGISTRY = {
    "anli_r1": anli.ANLIRound1,
    "anli_r2": anli.ANLIRound2,
    "anli_r3": anli.ANLIRound3,
    "ethics_cm": hendrycks_ethics.EthicsCM,
    "ethics_deontology": hendrycks_ethics.EthicsDeontology,
    "ethics_justice": hendrycks_ethics.EthicsJustice,
    "ethics_utilitarianism_original": hendrycks_ethics.EthicsUtilitarianismOriginal,
    "ethics_utilitarianism": hendrycks_ethics.EthicsUtilitarianism,
    "ethics_virtue": hendrycks_ethics.EthicsVirtue,
    "truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice,
    "truthfulqa_gen": truthfulqa.TruthfulQAGeneration,
    # dialogue
    "mutual": mutual.MuTual,
    "mutual_plus": mutual.MuTualPlus,
    # math
    "math_algebra": hendrycks_math.MathAlgebra,
    "math_counting_and_prob": hendrycks_math.MathCountingAndProbability,
@@ -177,7 +163,6 @@ TASK_REGISTRY = {
    "math_precalc": hendrycks_math.MathPrecalculus,
    "math_asdiv": asdiv.Asdiv,
    "gsm8k": gsm8k.GradeSchoolMath8K,
    # arithmetic
    "arithmetic_2da": arithmetic.Arithmetic2DPlus,
    "arithmetic_2ds": arithmetic.Arithmetic2DMinus,
@@ -191,22 +176,18 @@ TASK_REGISTRY = {
    "arithmetic_1dc": arithmetic.Arithmetic1DComposite,
    # TODO Perhaps make these groups of tasks
    #   e.g. anli, arithmetic, openai_translations, harness_translations
    # hendrycksTest (57 tasks)
    **hendrycks_test.create_all_tasks(),
    # e.g. wmt14-fr-en
    **translation.create_tasks_from_benchmarks(gpt3_translation_benchmarks),
    # chef's selection, mostly wmt20
    **translation.create_tasks_from_benchmarks(selected_translation_benchmarks),
    # Word Scrambling and Manipulation Tasks
    "anagrams1": unscramble.Anagrams1,
    "anagrams2": unscramble.Anagrams2,
    "cycle_letters": unscramble.CycleLetters,
    "random_insertion": unscramble.RandomInsertion,
    "reversed_words": unscramble.ReversedWords,
    # Pile
    "pile_arxiv": pile.PileArxiv,
    "pile_books3": pile.PileBooks3,
@@ -230,7 +211,6 @@ TASK_REGISTRY = {
    "pile_ubuntu-irc": pile.PileUbuntuIrc,
    "pile_wikipedia": pile.PileWikipedia,
    "pile_youtubesubtitles": pile.PileYoutubeSubtitles,
    # BLiMP
    "blimp_adjunct_island": blimp.BlimpAdjunctIsland,
    "blimp_anaphor_gender_agreement": blimp.BlimpAnaphorGenderAgreement,
@@ -299,7 +279,6 @@ TASK_REGISTRY = {
    "blimp_wh_vs_that_no_gap_long_distance": blimp.BlimpWhVsThatNoGapLongDistance,
    "blimp_wh_vs_that_with_gap": blimp.BlimpWhVsThatWithGap,
    "blimp_wh_vs_that_with_gap_long_distance": blimp.BlimpWhVsThatWithGapLongDistance,
    # Requires manual download of data.
    # "storycloze_2016": storycloze.StoryCloze2016,
    # "storycloze_2018": storycloze.StoryCloze2018,
@@ -313,7 +292,7 @@ ALL_TASKS = sorted(list(TASK_REGISTRY))
def get_task(task_name):
    try:
        return TASK_REGISTRY[task_name]
    except KeyError:
print("Available tasks:") print("Available tasks:")
pprint(TASK_REGISTRY) pprint(TASK_REGISTRY)
raise KeyError(f"Missing task {task_name}") raise KeyError(f"Missing task {task_name}")
...@@ -325,17 +304,23 @@ def get_task_name_from_object(task_object): ...@@ -325,17 +304,23 @@ def get_task_name_from_object(task_object):
return name return name
# this gives a mechanism for non-registered tasks to have a custom name anyways when reporting # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
return task_object.EVAL_HARNESS_NAME if hasattr(task_object, "EVAL_HARNESS_NAME") else type(task_object).__name__ return (
task_object.EVAL_HARNESS_NAME
if hasattr(task_object, "EVAL_HARNESS_NAME")
else type(task_object).__name__
)
def get_task_dict(task_name_list: List[Union[str, lm_eval.base.Task]]): def get_task_dict(task_name_list: List[Union[str, lm_eval.base.Task]]):
task_name_dict = { task_name_dict = {
task_name: get_task(task_name)() task_name: get_task(task_name)()
for task_name in task_name_list if isinstance(task_name, str) for task_name in task_name_list
if isinstance(task_name, str)
} }
task_name_from_object_dict = { task_name_from_object_dict = {
get_task_name_from_object(task_object): task_object get_task_name_from_object(task_object): task_object
for task_object in task_name_list if not isinstance(task_object, str) for task_object in task_name_list
if not isinstance(task_object, str)
} }
assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys())) assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys()))
return {**task_name_dict, **task_name_from_object_dict} return {**task_name_dict, **task_name_from_object_dict}
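For orientation, a minimal usage sketch of the registry helpers above. The task names are real registry keys from this file, but the snippet is illustrative only (instantiating a task will trigger its dataset download):

    from lm_eval import tasks

    # Resolve registered task names into instantiated task objects.
    task_dict = tasks.get_task_dict(["lambada", "arc_easy"])
    # -> {"lambada": <LAMBADA instance>, "arc_easy": <ARCEasy instance>}

    # An unregistered name prints the available tasks and re-raises KeyError.
    # tasks.get_task("no_such_task")  # KeyError: "Missing task no_such_task"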
...@@ -64,16 +64,27 @@ class ANLIBase(Task): ...@@ -64,16 +64,27 @@ class ANLIBase(Task):
# of the prompt (yes, repeating it!). also, " True, False, or Neither?" is directly # of the prompt (yes, repeating it!). also, " True, False, or Neither?" is directly
# appended onto the question, with no "Answer:" or even a newline. Do we *really* # appended onto the question, with no "Answer:" or even a newline. Do we *really*
# want to do it exactly as OA did? # want to do it exactly as OA did?
return doc['premise'] + '\nQuestion: ' + doc['hypothesis'] + ' True, False, or Neither?\nAnswer:' return (
doc["premise"]
+ "\nQuestion: "
+ doc["hypothesis"]
+ " True, False, or Neither?\nAnswer:"
)
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["premise"]
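The two methods just added form a small opt-in protocol: a task declares that its documents may be checked against the pretraining data and supplies the raw text to check. A hypothetical consumer (the function name and loop below are illustrative, not the harness's actual decontamination code) would look roughly like:

    # Hypothetical sketch of gathering decontamination queries for one task.
    def collect_decontamination_queries(task, docs):
        if not task.should_decontaminate():
            return []
        # One query string per document, e.g. the ANLI premise above.
        return [task.doc_to_decontamination_query(doc) for doc in docs]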
def doc_to_target(self, doc): def doc_to_target(self, doc):
# True = entailment # True = entailment
# False = contradiction # False = contradiction
# Neither = neutral # Neither = neutral
return " " + ["True", "Neither", "False"][doc['label']] return " " + ["True", "Neither", "False"][doc["label"]]
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of """Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM. Requests which will be sent to the LM.
:param doc: :param doc:
...@@ -100,9 +111,7 @@ class ANLIBase(Task): ...@@ -100,9 +111,7 @@ class ANLIBase(Task):
""" """
gold = doc["label"] gold = doc["label"]
pred = np.argmax(results) pred = np.argmax(results)
return { return {"acc": pred == gold}
"acc": pred == gold
}
def aggregation(self): def aggregation(self):
""" """
...@@ -110,9 +119,7 @@ class ANLIBase(Task): ...@@ -110,9 +119,7 @@ class ANLIBase(Task):
A dictionary where keys are the names of submetrics and values are A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics functions that aggregate a list of metrics
""" """
return { return {"acc": mean}
"acc": mean
}
def higher_is_better(self): def higher_is_better(self):
""" """
...@@ -120,9 +127,7 @@ class ANLIBase(Task): ...@@ -120,9 +127,7 @@ class ANLIBase(Task):
A dictionary where keys are the names of submetrics and values are A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better whether a higher value of the submetric is better
""" """
return { return {"acc": True}
"acc": True
}
class ANLIRound1(ANLIBase): class ANLIRound1(ANLIBase):
......
...@@ -67,6 +67,12 @@ class ARCEasy(MultipleChoiceTask): ...@@ -67,6 +67,12 @@ class ARCEasy(MultipleChoiceTask):
def doc_to_text(self, doc): def doc_to_text(self, doc):
return doc["query"] return doc["query"]
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["query"]
class ARCChallenge(ARCEasy): class ARCChallenge(ARCEasy):
DATASET_PATH = "ai2_arc" DATASET_PATH = "ai2_arc"
......
...@@ -53,6 +53,12 @@ class Arithmetic(Task): ...@@ -53,6 +53,12 @@ class Arithmetic(Task):
def doc_to_text(self, doc): def doc_to_text(self, doc):
return doc["context"] return doc["context"]
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["context"]
def doc_to_target(self, doc): def doc_to_target(self, doc):
return doc["completion"] return doc["completion"]
...@@ -61,10 +67,8 @@ class Arithmetic(Task): ...@@ -61,10 +67,8 @@ class Arithmetic(Task):
return is_prediction return is_prediction
def process_results(self, doc, results): def process_results(self, doc, results):
is_prediction, = results (is_prediction,) = results
return { return {"acc": is_prediction}
"acc": is_prediction
}
def aggregation(self): def aggregation(self):
return { return {
...@@ -72,9 +76,7 @@ class Arithmetic(Task): ...@@ -72,9 +76,7 @@ class Arithmetic(Task):
} }
def higher_is_better(self): def higher_is_better(self):
return { return {"acc": True}
"acc": True
}
class Arithmetic2DPlus(Arithmetic): class Arithmetic2DPlus(Arithmetic):
......
...@@ -54,23 +54,28 @@ class Asdiv(Task): ...@@ -54,23 +54,28 @@ class Asdiv(Task):
def test_docs(self): def test_docs(self):
raise NotImplementedError("This dataset has no test docs") raise NotImplementedError("This dataset has no test docs")
def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None): def fewshot_context(
self, doc, num_fewshot, provide_description=None, rnd=None, description=None
):
assert num_fewshot == 0, "ASDiv is intended only for the zero-shot setting." assert num_fewshot == 0, "ASDiv is intended only for the zero-shot setting."
return super().fewshot_context( return super().fewshot_context(
doc=doc, doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
num_fewshot=num_fewshot,
rnd=rnd,
description=description
) )
def doc_to_text(self, doc): def doc_to_text(self, doc):
# TODO: add solution-type # TODO: add solution-type
return doc['body'] + '\n' + 'Question:' + doc['question'] + '\n' + 'Answer:' return doc["body"] + "\n" + "Question:" + doc["question"] + "\n" + "Answer:"
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["body"] + " " + doc["question"]
def doc_to_target(self, doc): def doc_to_target(self, doc):
# TODO: add formula # TODO: add formula
answer = doc['answer'].split(' (')[0] answer = doc["answer"].split(" (")[0]
return " " + answer return " " + answer
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
...@@ -80,16 +85,10 @@ class Asdiv(Task): ...@@ -80,16 +85,10 @@ class Asdiv(Task):
def process_results(self, doc, results): def process_results(self, doc, results):
ll, is_greedy = results ll, is_greedy = results
return { return {"acc": int(is_greedy)}
'acc': int(is_greedy)
}
def aggregation(self): def aggregation(self):
return { return {"acc": mean}
'acc': mean
}
def higher_is_better(self): def higher_is_better(self):
return { return {"acc": True}
'acc': True
}
...@@ -28,7 +28,7 @@ _CITATION = """ ...@@ -28,7 +28,7 @@ _CITATION = """
eprint = {https://doi.org/10.1162/tacl_a_00321}, eprint = {https://doi.org/10.1162/tacl_a_00321},
abstract = { We introduce The Benchmark of Linguistic Minimal Pairs (BLiMP),1 a challenge set for evaluating the linguistic knowledge of language models (LMs) on major grammatical phenomena in English. BLiMP consists of 67 individual datasets, each containing 1,000 minimal pairs—that is, pairs of minimally different sentences that contrast in grammatical acceptability and isolate specific phenomenon in syntax, morphology, or semantics. We generate the data according to linguist-crafted grammar templates, and human aggregate agreement with the labels is 96.4\%. We evaluate n-gram, LSTM, and Transformer (GPT-2 and Transformer-XL) LMs by observing whether they assign a higher probability to the acceptable sentence in each minimal pair. We find that state-of-the-art models identify morphological contrasts related to agreement reliably, but they struggle with some subtle semantic and syntactic phenomena, such as negative polarity items and extraction islands. } abstract = { We introduce The Benchmark of Linguistic Minimal Pairs (BLiMP),1 a challenge set for evaluating the linguistic knowledge of language models (LMs) on major grammatical phenomena in English. BLiMP consists of 67 individual datasets, each containing 1,000 minimal pairs—that is, pairs of minimally different sentences that contrast in grammatical acceptability and isolate specific phenomenon in syntax, morphology, or semantics. We generate the data according to linguist-crafted grammar templates, and human aggregate agreement with the labels is 96.4\%. We evaluate n-gram, LSTM, and Transformer (GPT-2 and Transformer-XL) LMs by observing whether they assign a higher probability to the acceptable sentence in each minimal pair. We find that state-of-the-art models identify morphological contrasts related to agreement reliably, but they struggle with some subtle semantic and syntactic phenomena, such as negative polarity items and extraction islands. }
} }
""" """ # noqa: W605
class BlimpTask(Task): class BlimpTask(Task):
...@@ -50,9 +50,13 @@ class BlimpTask(Task): ...@@ -50,9 +50,13 @@ class BlimpTask(Task):
# trained on this data. # trained on this data.
return self.dataset["train"] return self.dataset["train"]
def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None): def fewshot_context(
self, doc, num_fewshot, provide_description=None, rnd=None, description=None
):
assert num_fewshot == 0 assert num_fewshot == 0
assert rnd is not None, "A `random.Random` generator argument must be provided to `rnd`" assert (
rnd is not None
), "A `random.Random` generator argument must be provided to `rnd`"
assert not provide_description, ( assert not provide_description, (
"The `provide_description` arg will be removed in future versions. To prepend " "The `provide_description` arg will be removed in future versions. To prepend "
"a custom description to the context, supply the corresponding string via the " "a custom description to the context, supply the corresponding string via the "
...@@ -60,7 +64,9 @@ class BlimpTask(Task): ...@@ -60,7 +64,9 @@ class BlimpTask(Task):
) )
if provide_description is not None: if provide_description is not None:
# nudge people to not specify it at all # nudge people to not specify it at all
print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict") print(
"WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
)
return "" return ""
...@@ -68,6 +74,12 @@ class BlimpTask(Task): ...@@ -68,6 +74,12 @@ class BlimpTask(Task):
# this method is invoked by tests only # this method is invoked by tests only
return "" return ""
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["sentence_good"] + " " + doc["sentence_bad"]
def doc_to_target(self, doc): def doc_to_target(self, doc):
# this method is invoked by tests only # this method is invoked by tests only
return "" return ""
......
...@@ -75,11 +75,20 @@ class CBTBase(Task): ...@@ -75,11 +75,20 @@ class CBTBase(Task):
text = "Passage: " + passage + "\nQuestion: " + doc["question"] text = "Passage: " + passage + "\nQuestion: " + doc["question"]
return self.detokenize(text) return self.detokenize(text)
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
passage = " ".join(doc["sentences"])
return passage
def doc_to_target(self, doc): def doc_to_target(self, doc):
return "" return ""
def fewshot_examples(self, k, rnd): def fewshot_examples(self, k, rnd):
assert k == 0, f"CBT is only implemented for the zero-shot setting. Given k={k}." assert (
k == 0
), f"CBT is only implemented for the zero-shot setting. Given k={k}."
return super().fewshot_examples(k, rnd) return super().fewshot_examples(k, rnd)
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
...@@ -113,9 +122,7 @@ class CBTBase(Task): ...@@ -113,9 +122,7 @@ class CBTBase(Task):
""" """
gold = doc["options"].index(doc["answer"]) gold = doc["options"].index(doc["answer"])
pred = np.argmax(results) pred = np.argmax(results)
return { return {"acc": pred == gold}
"acc": pred == gold
}
def aggregation(self): def aggregation(self):
""" """
...@@ -123,9 +130,7 @@ class CBTBase(Task): ...@@ -123,9 +130,7 @@ class CBTBase(Task):
A dictionary where keys are the names of submetrics and values are A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics functions that aggregate a list of metrics
""" """
return { return {"acc": mean}
"acc": mean
}
def higher_is_better(self): def higher_is_better(self):
""" """
...@@ -133,9 +138,7 @@ class CBTBase(Task): ...@@ -133,9 +138,7 @@ class CBTBase(Task):
A dictionary where keys are the names of submetrics and values are A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better whether a higher value of the submetric is better
""" """
return { return {"acc": True}
"acc": True
}
class CBTCN(CBTBase): class CBTCN(CBTBase):
......
...@@ -54,13 +54,21 @@ class CoQA(Task): ...@@ -54,13 +54,21 @@ class CoQA(Task):
def doc_to_text(self, doc): def doc_to_text(self, doc):
# Given a passage p, the conversation history {q1, a1, . . . qi−1, ai−1} # Given a passage p, the conversation history {q1, a1, . . . qi−1, ai−1}
# and a question qi, the task is to predict the answer ai # and a question qi, the task is to predict the answer ai
doc_text = doc["story"] + '\n\n' doc_text = doc["story"] + "\n\n"
for (q, a) in zip_longest(doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]): # omit target answer ai for (q, a) in zip_longest(
doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]
): # omit target answer ai
question = f"Q: {q}\n\n" question = f"Q: {q}\n\n"
answer = f"A: {a}\n\n" if a is not None else "A:" answer = f"A: {a}\n\n" if a is not None else "A:"
doc_text += question + answer doc_text += question + answer
return doc_text return doc_text
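To make the loop above concrete, a tiny invented two-turn conversation and the text it yields; zip_longest leaves the final answer as None, so the prompt ends at "A:":

    # Invented CoQA-style document.
    doc = {
        "story": "Anna moved to Paris in 2019.",
        "questions": {"input_text": ["Who moved?", "Where to?"]},
        "answers": {"input_text": ["Anna", "Paris"]},
    }
    # doc_to_text(doc) produces:
    #   Anna moved to Paris in 2019.
    #
    #   Q: Who moved?
    #
    #   A: Anna
    #
    #   Q: Where to?
    #
    #   A: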
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["story"] + " " + "\n".join(doc["questions"]["input_text"])
@classmethod @classmethod
def get_answers(cls, doc, turn_id): def get_answers(cls, doc, turn_id):
# Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers). # Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers).
...@@ -71,7 +79,9 @@ class CoQA(Task): ...@@ -71,7 +79,9 @@ class CoQA(Task):
additional_answers = doc.get("additional_answers") additional_answers = doc.get("additional_answers")
if additional_answers: if additional_answers:
for key in additional_answers: for key in additional_answers:
additional_answer_for_turn = additional_answers[key]["input_text"][turn_id - 1] additional_answer_for_turn = additional_answers[key]["input_text"][
turn_id - 1
]
if additional_answer_for_turn.lower() not in map(str.lower, answers): if additional_answer_for_turn.lower() not in map(str.lower, answers):
answers.append(additional_answer_for_turn) answers.append(additional_answer_for_turn)
return answers return answers
...@@ -83,12 +93,12 @@ class CoQA(Task): ...@@ -83,12 +93,12 @@ class CoQA(Task):
# ~ 2/3 of the CoQA answers are span-based # ~ 2/3 of the CoQA answers are span-based
# (answers overlap with the passage ignoring punctuation and case mismatch) # (answers overlap with the passage ignoring punctuation and case mismatch)
if raw_text == "unknown": if raw_text == "unknown":
return '0' return "0"
if squad_metrics.normalize_answer(raw_text) == "yes": if squad_metrics.normalize_answer(raw_text) == "yes":
return '1' return "1"
if squad_metrics.normalize_answer(raw_text) == "no": if squad_metrics.normalize_answer(raw_text) == "no":
return '2' return "2"
return '3' # Not a yes/no question return "3" # Not a yes/no question
@staticmethod @staticmethod
def compute_scores(gold_list, pred): def compute_scores(gold_list, pred):
...@@ -98,25 +108,30 @@ class CoQA(Task): ...@@ -98,25 +108,30 @@ class CoQA(Task):
em_sum = 0.0 em_sum = 0.0
if len(gold_list) > 1: if len(gold_list) > 1:
for i in range(len(gold_list)): for i in range(len(gold_list)):
gold_answers = gold_list[0:i] + gold_list[i + 1:] gold_answers = gold_list[0:i] + gold_list[i + 1 :]
# predictions compared against (n) golds and take maximum # predictions compared against (n) golds and take maximum
em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_answers) em_sum += max(
squad_metrics.compute_exact(a, pred) for a in gold_answers
)
f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers) f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers)
else: else:
em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list) em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list)
f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_list) f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_list)
return {'em': em_sum / max(1, len(gold_list)), 'f1': f1_sum / max(1, len(gold_list))} return {
"em": em_sum / max(1, len(gold_list)),
"f1": f1_sum / max(1, len(gold_list)),
}
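When several gold answers exist, the scoring above averages, per reference, the best match against the remaining references. A simplified stand-alone sketch of the exact-match branch, using plain string equality in place of the SQuAD normalization and metrics:

    # Simplified stand-in for compute_scores: exact match only, no answer normalization.
    def leave_one_out_em(gold_list, pred):
        em_sum = 0.0
        if len(gold_list) > 1:
            for i in range(len(gold_list)):
                others = gold_list[:i] + gold_list[i + 1:]
                em_sum += max(float(pred == g) for g in others)
        else:
            em_sum += max(float(pred == g) for g in gold_list)
        return em_sum / max(1, len(gold_list))

    print(leave_one_out_em(["Paris", "Paris", "France"], "Paris"))  # -> 1.0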
def doc_to_target(self, doc, turnid=None): def doc_to_target(self, doc, turnid=None):
# Default to prediction of last turn. # Default to prediction of last turn.
if turnid is None: if turnid is None:
turnid = len(doc["questions"]["input_text"]) turnid = len(doc["questions"]["input_text"])
raw_text = doc['answers']["input_text"][turnid - 1] raw_text = doc["answers"]["input_text"][turnid - 1]
return " " + raw_text return " " + raw_text
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of """Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM. Requests which will be sent to the LM.
:param doc: :param doc:
...@@ -126,7 +141,7 @@ class CoQA(Task): ...@@ -126,7 +141,7 @@ class CoQA(Task):
language description, as well as the few shot examples, and the question language description, as well as the few shot examples, and the question
part of the document for `doc`. part of the document for `doc`.
""" """
cont_request = rf.greedy_until(ctx, ['\nQ:']) cont_request = rf.greedy_until(ctx, ["\nQ:"])
return cont_request return cont_request
def process_results(self, doc, results): def process_results(self, doc, results):
...@@ -141,13 +156,13 @@ class CoQA(Task): ...@@ -141,13 +156,13 @@ class CoQA(Task):
""" """
turn_id = len(doc["questions"]["input_text"]) turn_id = len(doc["questions"]["input_text"])
gold_list = self.get_answers(doc, turn_id) gold_list = self.get_answers(doc, turn_id)
pred = results[0].strip().split('\n')[0] pred = results[0].strip().split("\n")[0]
scores = self.compute_scores(gold_list, pred) scores = self.compute_scores(gold_list, pred)
return { return {
"f1": scores['f1'], "f1": scores["f1"],
"em": scores['em'], "em": scores["em"],
} }
def higher_is_better(self): def higher_is_better(self):
......
...@@ -70,21 +70,26 @@ class DROP(Task): ...@@ -70,21 +70,26 @@ class DROP(Task):
@classmethod @classmethod
def get_answers(cls, qa): def get_answers(cls, qa):
def _flatten_validated_answers(validated_answers): def _flatten_validated_answers(validated_answers):
""" Flattens a dict of lists of validated answers. """Flattens a dict of lists of validated answers.
{"number": ['1', '8'], ...} {"number": ['1', '8'], ...}
-> [{"number": ['1'], ...}, {"number": ['8'], ...}] -> [{"number": ['1'], ...}, {"number": ['8'], ...}]
""" """
vas = [] valid_answers = []
for i in range(len(validated_answers["number"])): for i in range(len(validated_answers["number"])):
vas.append({ valid_answers.append(
{
"number": validated_answers["number"][i], "number": validated_answers["number"][i],
"date": validated_answers["date"][i], "date": validated_answers["date"][i],
"spans": validated_answers["spans"][i], "spans": validated_answers["spans"][i],
}) }
return vas )
return valid_answers
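The docstring above already gives the shape; spelled out with an invented record, the flattening of the column-wise dict into per-answer dicts looks like this:

    # Invented validated_answers entry with two alternatives.
    validated_answers = {
        "number": ["1", ""],
        "date": [{"day": "", "month": "", "year": ""}, {"day": "", "month": "", "year": ""}],
        "spans": [[], ["one"]],
    }
    # _flatten_validated_answers(validated_answers) ->
    # [
    #     {"number": "1", "date": {"day": "", "month": "", "year": ""}, "spans": []},
    #     {"number": "",  "date": {"day": "", "month": "", "year": ""}, "spans": ["one"]},
    # ]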
answers = [] answers = []
answers_set = set() answers_set = set()
candidates = [qa["answer"]] + _flatten_validated_answers(qa["validated_answers"]) candidates = [qa["answer"]] + _flatten_validated_answers(
qa["validated_answers"]
)
for candidate in candidates: for candidate in candidates:
answer = cls.parse_answer(candidate) answer = cls.parse_answer(candidate)
if answer in answers_set: if answer in answers_set:
...@@ -100,13 +105,21 @@ class DROP(Task): ...@@ -100,13 +105,21 @@ class DROP(Task):
return (str(answer["number"]),) return (str(answer["number"]),)
if answer["spans"] != []: if answer["spans"] != []:
return tuple(answer["spans"]) return tuple(answer["spans"])
return (" ".join([answer["date"]["day"], return (
answer["date"]["month"], " ".join(
answer["date"]["year"]]).strip(),) [answer["date"]["day"], answer["date"]["month"], answer["date"]["year"]]
).strip(),
)
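parse_answer prefers the number field, then spans, then a day/month/year date joined with spaces. Assuming the usual guard on an empty number field (that check falls outside this hunk), invented candidates behave roughly as follows:

    # Invented DROP answer candidates showing the priority order in parse_answer.
    number_ans = {"number": "42", "spans": [], "date": {"day": "", "month": "", "year": ""}}
    span_ans = {"number": "", "spans": ["over the bridge"], "date": {"day": "", "month": "", "year": ""}}
    date_ans = {"number": "", "spans": [], "date": {"day": "7", "month": "May", "year": "1945"}}
    # parse_answer(number_ans) -> ("42",)
    # parse_answer(span_ans)   -> ("over the bridge",)
    # parse_answer(date_ans)   -> ("7 May 1945",)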
def doc_to_text(self, doc): def doc_to_text(self, doc):
return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:" return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:"
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["passage"] + " " + doc["question"]
def doc_to_target(self, doc): def doc_to_target(self, doc):
return " " + ", ".join(doc["answers"][0]) return " " + ", ".join(doc["answers"][0])
...@@ -142,10 +155,7 @@ class DROP(Task): ...@@ -142,10 +155,7 @@ class DROP(Task):
if gold_answer[0].strip(): if gold_answer[0].strip():
max_em = max(max_em, exact_match) max_em = max(max_em, exact_match)
max_f1 = max(max_f1, f1_score) max_f1 = max(max_f1, f1_score)
return { return {"em": max_em, "f1": max_f1}
"em": max_em,
"f1": max_f1
}
def get_metrics(self, predicted, gold): def get_metrics(self, predicted, gold):
""" """
...@@ -158,7 +168,9 @@ class DROP(Task): ...@@ -158,7 +168,9 @@ class DROP(Task):
predicted_bags = self._answer_to_bags(predicted) predicted_bags = self._answer_to_bags(predicted)
gold_bags = self._answer_to_bags(gold) gold_bags = self._answer_to_bags(gold)
if set(predicted_bags[0]) == set(gold_bags[0]) and len(predicted_bags[0]) == len(gold_bags[0]): if set(predicted_bags[0]) == set(gold_bags[0]) and len(
predicted_bags[0]
) == len(gold_bags[0]):
exact_match = 1.0 exact_match = 1.0
else: else:
exact_match = 0.0 exact_match = 0.0
...@@ -190,7 +202,9 @@ class DROP(Task): ...@@ -190,7 +202,9 @@ class DROP(Task):
for gold_index, gold_item in enumerate(gold): for gold_index, gold_item in enumerate(gold):
for pred_index, pred_item in enumerate(predicted): for pred_index, pred_item in enumerate(predicted):
if self._match_numbers_if_present(gold_item, pred_item): if self._match_numbers_if_present(gold_item, pred_item):
scores[gold_index, pred_index] = self._compute_f1(pred_item, gold_item) scores[gold_index, pred_index] = self._compute_f1(
pred_item, gold_item
)
row_ind, col_ind = linear_sum_assignment(-scores) row_ind, col_ind = linear_sum_assignment(-scores)
max_scores = np.zeros([max(len(gold), len(predicted))]) max_scores = np.zeros([max(len(gold), len(predicted))])
...@@ -256,7 +270,11 @@ class DROP(Task): ...@@ -256,7 +270,11 @@ class DROP(Task):
def _normalize(self, answer): def _normalize(self, answer):
tokens = [ tokens = [
self._white_space_fix(self._remove_articles(self._fix_number(self._remove_punc(token.lower())))) self._white_space_fix(
self._remove_articles(
self._fix_number(self._remove_punc(token.lower()))
)
)
for token in self._tokenize(answer) for token in self._tokenize(answer)
] ]
tokens = [token for token in tokens if token.strip()] tokens = [token for token in tokens if token.strip()]
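A rough illustration of the normalization pipeline above; the helper functions are not shown in this hunk, so their exact behavior (especially number handling) is assumed:

    # Per visible token: lower-case -> remove punctuation -> fix number -> remove articles -> fix whitespace.
    answer = "The White House!"
    # plausible result after the pipeline and the empty-token filter:
    #   ["white", "house"]   (then presumably rejoined as "white house")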
...@@ -269,10 +287,7 @@ class DROP(Task): ...@@ -269,10 +287,7 @@ class DROP(Task):
A dictionary where keys are the names of submetrics and values are A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics functions that aggregate a list of metrics
""" """
return { return {"em": mean, "f1": mean}
"em": mean,
"f1": mean
}
def higher_is_better(self): def higher_is_better(self):
""" """
...@@ -280,7 +295,4 @@ class DROP(Task): ...@@ -280,7 +295,4 @@ class DROP(Task):
A dictionary where keys are the names of submetrics and values are A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better whether a higher value of the submetric is better
""" """
return { return {"em": True, "f1": True}
"em": True,
"f1": True
}
...@@ -68,7 +68,15 @@ class CoLA(Task): ...@@ -68,7 +68,15 @@ class CoLA(Task):
return self.dataset["validation"] return self.dataset["validation"]
def doc_to_text(self, doc): def doc_to_text(self, doc):
return "{}\nQuestion: Does this sentence make sense?\nAnswer:".format(doc["sentence"]) return "{}\nQuestion: Does this sentence make sense?\nAnswer:".format(
doc["sentence"]
)
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["sentence"]
def doc_to_target(self, doc): def doc_to_target(self, doc):
return " {}".format({1: "yes", 0: "no"}[doc["label"]]) return " {}".format({1: "yes", 0: "no"}[doc["label"]])
...@@ -82,19 +90,13 @@ class CoLA(Task): ...@@ -82,19 +90,13 @@ class CoLA(Task):
ll_true, ll_false = results ll_true, ll_false = results
pred = ll_true > ll_false pred = ll_true > ll_false
gold = doc["label"] gold = doc["label"]
return { return {"mcc": (gold, pred)}
"mcc": (gold, pred)
}
def higher_is_better(self): def higher_is_better(self):
return { return {"mcc": True}
"mcc": True
}
def aggregation(self): def aggregation(self):
return { return {"mcc": matthews_corrcoef}
"mcc": matthews_corrcoef
}
class SST(Task): class SST(Task):
...@@ -136,19 +138,13 @@ class SST(Task): ...@@ -136,19 +138,13 @@ class SST(Task):
ll_positive, ll_negative = results ll_positive, ll_negative = results
pred = ll_positive > ll_negative pred = ll_positive > ll_negative
gold = doc["label"] gold = doc["label"]
return { return {"acc": pred == gold}
"acc": pred == gold
}
def higher_is_better(self): def higher_is_better(self):
return { return {"acc": True}
"acc": True
}
def aggregation(self): def aggregation(self):
return { return {"acc": mean}
"acc": mean
}
# Inference Tasks # Inference Tasks
...@@ -184,7 +180,8 @@ class MNLI(Task): ...@@ -184,7 +180,8 @@ class MNLI(Task):
def doc_to_text(self, doc): def doc_to_text(self, doc):
return "{}\nQuestion: {} True, False or Neither?\nAnswer:".format( return "{}\nQuestion: {} True, False or Neither?\nAnswer:".format(
doc["premise"], doc["premise"],
doc["hypothesis"].strip() + ('' if doc["hypothesis"].strip().endswith('.') else '.'), doc["hypothesis"].strip()
+ ("" if doc["hypothesis"].strip().endswith(".") else "."),
) )
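The conditional above simply guarantees the hypothesis ends in a period before the question suffix; for example, with invented sentences:

    # Invented MNLI-style fields.
    doc = {"premise": "A man inspects the uniform.", "hypothesis": "The man is sleeping"}
    # doc_to_text(doc) produces:
    #   A man inspects the uniform.
    #   Question: The man is sleeping. True, False or Neither?
    #   Answer: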
def doc_to_target(self, doc): def doc_to_target(self, doc):
...@@ -202,19 +199,13 @@ class MNLI(Task): ...@@ -202,19 +199,13 @@ class MNLI(Task):
def process_results(self, doc, results): def process_results(self, doc, results):
gold = doc["label"] gold = doc["label"]
pred = np.argmax(results) pred = np.argmax(results)
return { return {"acc": pred == gold}
"acc": pred == gold
}
def higher_is_better(self): def higher_is_better(self):
return { return {"acc": True}
"acc": True
}
def aggregation(self): def aggregation(self):
return { return {"acc": mean}
"acc": mean
}
class MNLIMismatched(MNLI): class MNLIMismatched(MNLI):
...@@ -252,10 +243,12 @@ class QNLI(Task): ...@@ -252,10 +243,12 @@ class QNLI(Task):
return self.dataset["validation"] return self.dataset["validation"]
def doc_to_text(self, doc): def doc_to_text(self, doc):
return "{}\n{}\nQuestion: Does this response answer the question?\nAnswer:".format( return (
"{}\n{}\nQuestion: Does this response answer the question?\nAnswer:".format(
doc["question"], doc["question"],
doc["sentence"], doc["sentence"],
) )
)
def doc_to_target(self, doc): def doc_to_target(self, doc):
# True = entailment # True = entailment
...@@ -271,19 +264,13 @@ class QNLI(Task): ...@@ -271,19 +264,13 @@ class QNLI(Task):
ll_yes, ll_no = results ll_yes, ll_no = results
pred = ll_no > ll_yes pred = ll_no > ll_yes
gold = doc["label"] gold = doc["label"]
return { return {"acc": pred == gold}
"acc": pred == gold
}
def higher_is_better(self): def higher_is_better(self):
return { return {"acc": True}
"acc": True
}
def aggregation(self): def aggregation(self):
return { return {"acc": mean}
"acc": mean
}
class WNLI(Task): class WNLI(Task):
...@@ -328,19 +315,13 @@ class WNLI(Task): ...@@ -328,19 +315,13 @@ class WNLI(Task):
ll_true, ll_false = results ll_true, ll_false = results
pred = ll_true > ll_false pred = ll_true > ll_false
gold = doc["label"] gold = doc["label"]
return { return {"acc": pred == gold}
"acc": pred == gold
}
def higher_is_better(self): def higher_is_better(self):
return { return {"acc": True}
"acc": True
}
def aggregation(self): def aggregation(self):
return { return {"acc": mean}
"acc": mean
}
class RTE(Task): class RTE(Task):
...@@ -385,19 +366,13 @@ class RTE(Task): ...@@ -385,19 +366,13 @@ class RTE(Task):
ll_true, ll_false = results ll_true, ll_false = results
pred = ll_false > ll_true pred = ll_false > ll_true
gold = doc["label"] gold = doc["label"]
return { return {"acc": pred == gold}
"acc": pred == gold
}
def higher_is_better(self): def higher_is_better(self):
return { return {"acc": True}
"acc": True
}
def aggregation(self): def aggregation(self):
return { return {"acc": mean}
"acc": mean
}
# Similarity and Paraphrase Tasks # Similarity and Paraphrase Tasks
...@@ -449,16 +424,10 @@ class MRPC(Task): ...@@ -449,16 +424,10 @@ class MRPC(Task):
} }
def higher_is_better(self): def higher_is_better(self):
return { return {"acc": True, "f1": True}
"acc": True,
"f1": True
}
def aggregation(self): def aggregation(self):
return { return {"acc": mean, "f1": f1_score}
"acc": mean,
"f1": f1_score
}
class QQP(Task): class QQP(Task):
...@@ -507,16 +476,10 @@ class QQP(Task): ...@@ -507,16 +476,10 @@ class QQP(Task):
} }
def higher_is_better(self): def higher_is_better(self):
return { return {"acc": True, "f1": True}
"acc": True,
"f1": True
}
def aggregation(self): def aggregation(self):
return { return {"acc": mean, "f1": f1_score}
"acc": mean,
"f1": f1_score
}
class STSB(Task): class STSB(Task):
...@@ -554,7 +517,7 @@ class STSB(Task): ...@@ -554,7 +517,7 @@ class STSB(Task):
return " {}".format(doc["label"]) return " {}".format(doc["label"])
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of """Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM. Requests which will be sent to the LM.
:param doc: :param doc:
...@@ -565,7 +528,7 @@ class STSB(Task): ...@@ -565,7 +528,7 @@ class STSB(Task):
part of the document for `doc`. part of the document for `doc`.
""" """
# TODO: implement evaluation. # TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented') raise NotImplementedError("Evaluation not implemented")
def process_results(self, doc, results): def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a """Take a single document and the LM results and evaluates, returning a
...@@ -578,7 +541,7 @@ class STSB(Task): ...@@ -578,7 +541,7 @@ class STSB(Task):
The results of the requests created in construct_requests. The results of the requests created in construct_requests.
""" """
# TODO: implement evaluation. # TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented') raise NotImplementedError("Evaluation not implemented")
def aggregation(self): def aggregation(self):
""" """
...@@ -587,7 +550,7 @@ class STSB(Task): ...@@ -587,7 +550,7 @@ class STSB(Task):
functions that aggregate a list of metrics functions that aggregate a list of metrics
""" """
# TODO: implement evaluation. # TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented') raise NotImplementedError("Evaluation not implemented")
def higher_is_better(self): def higher_is_better(self):
""" """
...@@ -596,4 +559,4 @@ class STSB(Task): ...@@ -596,4 +559,4 @@ class STSB(Task):
whether a higher value of the submetric is better whether a higher value of the submetric is better
""" """
# TODO: implement evaluation. # TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented') raise NotImplementedError("Evaluation not implemented")