import re
import string
import timeit
import pickle
import traceback
from pprint import pprint
# This is a cpp module. Compile janitor_util.cpp with:
# c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup
try:
import janitor_util
JANITOR_CPP = True
except Exception:
print("WARNING: C++ module could not be loaded. Janitor running in python mode")
traceback.print_exc()
JANITOR_CPP = False
# Implementation from nltk source
# https://www.nltk.org/_modules/nltk/util.html
def form_ngrams(sequence, n):
history = []
while n > 1:
# PEP 479, prevent RuntimeError from being raised when StopIteration bubbles out of generator
try:
next_item = next(sequence)
except StopIteration:
# no more data, terminate the generator
return
history.append(next_item)
n -= 1
for item in sequence:
history.append(item)
yield tuple(history)
del history[0]
def word_ngrams(s, n):
"""Splits a string into ngram words"""
tokens = s.split() # not a generator :(
ngram_seqs = form_ngrams(iter(tokens), n)
return (" ".join(ngram) for ngram in ngram_seqs)
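# Illustrative example of the two helpers above (not part of the original module):
#     >>> list(word_ngrams("a b c d", 2))
#     ['a b', 'b c', 'c d']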
# Does character sequences only - combined faster function to play around with later
# def word_ngrams_indices_combined(sequence, n):
# current_word = ""
# history = []
# gap = False;
# start = 0
# end = 0
# for character in sequence:
# if character == " ":
# if not gap:
# gap = True
# history.append(current_word)
# end += len(current_word) - 1
# current_word = ""
# if len(history) == n:
# yield (tuple(history), start, end)
# del history[0]
# start = end + 1
# end = start
# else:
# gap = False
# current_word += character
# https://stackoverflow.com/questions/13734451/string-split-with-indices-in-python
def split_indices(s):
"""Splits a string on whitespaces and records the indices of each in the original string.
    :return: generator((word, (start_idx, end_idx)), ...)
"""
return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r"\S+", s))
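# Illustrative example, derived from the regex above (indices are inclusive):
#     >>> list(split_indices("hi there"))
#     [('hi', (0, 1)), ('there', (3, 7))]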
def word_ngrams_indices(s, n):
"""Splits a string into pairs of (ngram words, their start/end indices)"""
tokens_with_indices = split_indices(s)
# Generator of ngrams of (word, idx_pairs)
# (
# [(word, (start,end)), (word, (start, end))...],
# [(word, (start, end)), ...],
# ...
# )
ngram_seqs_with_indices = form_ngrams(tokens_with_indices, n)
# Generator of pairs of word and index ngrams
# (
# ([word, word, ...], [(start,end), (start,end), ...]),
# ...
# )
ngram_indices_pairs = (
zip(*ngram_with_indices) for ngram_with_indices in ngram_seqs_with_indices
)
# Generator of ( (word_ngram, (start, end)), (word_ngram, start, end)), ...)
return (
(" ".join(ngram_seq), (indices[0][0], indices[-1][1]))
for ngram_seq, indices in ngram_indices_pairs
)
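# Illustrative example (a sketch of the expected output):
#     >>> list(word_ngrams_indices("a b c", 2))
#     [('a b', (0, 2)), ('b c', (2, 4))]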
class Janitor:
# FIXME delete_chars: Should anything else go here? Special chars?
def __init__(
self,
ngram_n=13,
window_to_remove=200,
too_dirty_cutoff=10,
minimum_slice_length=200,
delete_chars=string.punctuation,
):
self.ngram_n = ngram_n
self.window_to_remove = window_to_remove
self.too_dirty_cutoff = too_dirty_cutoff
self.minimum_slice_length = minimum_slice_length
self.delete_chars = delete_chars
self.dirt_ngrams = set()
# If in python, we'll translate uppercase to lowercase and delete naughty characters.
# This is fast by python standards
# https://stackoverflow.com/questions/638893/what-is-the-most-efficient-way-in-python-to-convert-a-string-to-all-lowercase-st
self.translation_table = str.maketrans(
string.ascii_lowercase + string.ascii_uppercase, # These characters
string.ascii_lowercase * 2, # Become these characters
self.delete_chars, # These are deleted
)
##############
# I/O for saving contamination ngrams
##############
def save_contamination_ngrams(self, filename):
with open(filename, "wb") as fp:
            pickle.dump(self.dirt_ngrams, fp)
def load_contamination_ngrams(self, filename):
with open(filename, "rb") as fp:
self.dirt_ngrams = pickle.load(fp)
##############
# Call these :)
##############
def register_contaminant(self, dirt_string):
"""Register a string as contamination to be removed, e.g. a test set
This breaks the dirt_string into ngrams to store for future cleaning"""
if JANITOR_CPP:
return self.register_contaminant_cpp(dirt_string)
else:
print("WARNING: Janitor running in python mode")
return self.register_contaminant_python(dirt_string)
def clean(self, dirty_string):
"""Clean a string (e.g. a training set) by removing all ngrams previously
registered as contaminants. Returns a list of clean chunks, or empty if
the string was too dirty"""
if JANITOR_CPP:
return self.clean_cpp(dirty_string)
else:
print("WARNING: Janitor running in python mode")
return self.clean_python(dirty_string)
def _split_chunks(self, dirty_string, dirty_parts):
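        # Walk the contaminated ngram spans in order, widen each by
        # `window_to_remove` characters on both sides, keep the clean text
        # between consecutive widened spans when it is at least
        # `minimum_slice_length` long, and give up (return []) once
        # `too_dirty_cutoff` spans have been seen.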
clean_chunks = []
splice_idx = 0
end = -1
for i, (ngram, start, end) in enumerate(dirty_parts):
if i >= self.too_dirty_cutoff:
return []
start = max(0, start - self.window_to_remove)
end = min(len(dirty_string), end + self.window_to_remove)
if start - splice_idx > self.minimum_slice_length:
clean_chunks.append(dirty_string[splice_idx:start])
splice_idx = end
if end < len(dirty_string) - self.minimum_slice_length:
clean_chunks.append(dirty_string[end + 1 :])
return clean_chunks
##############
# Fast C++
##############
def register_contaminant_cpp(self, dirt_string):
self.dirt_ngrams.update(
janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n)
)
def clean_cpp(self, dirty_string):
contamination_indices = janitor_util.clean_ngram_with_indices(
dirty_string, self.delete_chars, self.ngram_n
)
return self._split_chunks(dirty_string, contamination_indices)
##############
# Slow python
##############
def normalize_string(self, s):
return s.translate(self.translation_table)
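    # Illustrative example (assumes the default delete_chars=string.punctuation):
    #     >>> Janitor().normalize_string("Hello, World!")
    #     'hello world'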
def register_contaminant_python(self, dirt_string):
self.dirt_ngrams.update(
word_ngrams(self.normalize_string(dirt_string), self.ngram_n)
)
def clean_python(self, dirty_string):
contamination_indices = (
(None, *idx_pair)
for dirty_ngram, idx_pair in word_ngrams_indices(dirty_string, self.ngram_n)
if self.normalize_string(dirty_ngram) in self.dirt_ngrams
)
return self._split_chunks(dirty_string, contamination_indices)
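# Illustrative end-to-end usage (a sketch mirroring the commented-out tests below):
#     jan = Janitor(ngram_n=3)
#     jan.register_contaminant("dirty boy. Clean he he")
#     clean_chunks = jan.clean(source_text)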
##################################################################
# Tests
#################################################################
# def print_cpp():
# source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2
# for i in range(1, 10, 2):
# pprint(janitor_util.clean_ngram(source, string.punctuation, i))
# for ngram, start, end in \
# janitor_util.clean_ngram_with_indices(source, string.punctuation, i):
# print(ngram, "\t", start, end, source[start:end].replace("\n", "\\n"))
# def test_cpp():
# source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2
# contaminant = "dirty boy. Clean he he"
# jan_python = Janitor()
# jan_cpp = Janitor()
# jan_python.register_contaminant_python(contaminant)
# jan_cpp.register_contaminant(contaminant)
# assert jan_python.dirt_ngrams == jan_cpp.dirt_ngrams, (jan_python.dirt_ngrams, jan_cpp.dirt_ngrams)
# assert jan_python.clean_python(source) == jan_cpp.clean(source), \
# (jan_python.clean_python(source), jan_cpp.clean(source))
# print("Passed test, python==cpp")
# def benchmark():
# # Download and put in data folder: enwik8 (100 MB) from https://cs.fit.edu/~mmahoney/compression/textdata.html
# setup = \
# """
# with open("data/enwik8", "r") as f:
# data = f.read()
# jan = Janitor(too_dirty_cutoff=1000)
# jan.register_contaminant('''
# theories is that there is a connection between &quot;geekdom&quot; and autism.
# This is hinted, for instance, by a ''Wired Magazine'' article in 2001 entitled &quot;
# The [[Geek]] Syndrome&quot;, which is a point argued by many in the autism rights
# movement{{ref|Wired}}. This article, many professionals assert, is just one example of
# the media's application of mental disease labels to what is actually variant normal behavior
# &amp;mdash;they argue that shyness, lack of athletic ability or social skills, and intellectual
# interests, even when they seem unusual to others, are not in themselves signs of autism or
# Asperger's syndrome. Others assert that it is actually the medical profession which is applying
# mental disease labels to children who in the past would have simply been accepted as a little
# different or even labeled 'gifted'. See [[clinomorphism]] for further discussion of this issue.
# Due to the recent publicity surrounding autism and autis
# ultan Al Nahyan]] granted [[Petroleum]] concessions, and oil was first found in 1958. At first,
# oil money had a marginal impact. A few lowrise concete buildings were erected, and the first
# paved road was completed in 1961, but Sheikh Shakbut, uncertain whether the new oil royalties
# would last, took a cautious approach, preferring to save the revenue rather than investing it in
# development. His brother, [[Zayed bin Sultan Al Nahayan]], saw that oil wealth had the potential
# to transform Abu Dhabi. The ruling Al Nahayan family decided that Sheikh Zayed should replace his
# brother as Ruler and carry out his vision of developing the country. On [[August 6]], [[1966]],
# with the assistance of the British, Sheikh Zayed became the new ruler. See generally, Al-Fahim, M,
# ''From Rags to Riches: A Story of Abu Dhabi'', Chapter Six (London Centre of Arab Studies, 1995),
# ISBN 1 900404 00 1. With the announcement by Britain in 1968 that it would withdraw from the
# Gulf area by 1971, Sheikh Zayed became the main driving force behind the formation of the
# [[United Arab Emirates]]. After the Emirates gained independence in 1971,
# ''')
# """
# n = 1
# print(f"Timing {n} run on 100 MB")
# print("Register contaminant")
# # print("\tPython", timeit.timeit("jan.register_contaminant_python(data)", setup=setup, globals=globals(), number=n))
# print("\tCpp", timeit.timeit("jan.register_contaminant(data)", setup=setup, globals=globals(), number=n))
# print("Clean")
# # print("\tPython", timeit.timeit("jan.clean_python(data)", setup=setup, globals=globals(), number=n))
# print("\tCpp", timeit.timeit("jan.clean(data)", setup=setup, globals=globals(), number=n))
# def test_janitor_general():
# source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2
# contaminant = "dirty boy. Clean he he"
# jan = Janitor(ngram_n=3)
# jan.register_contaminant(contaminant)
# cleaned = " ".join(jan.clean(source))
# for contam in jan.dirt_ngrams:
# assert contam not in cleaned, contam
# filename = "data/saved_contam"
# jan.save_contamination_ngrams(filename)
# jan = Janitor(ngram_n=3)
# jan.load_contamination_ngrams(filename)
# cleaned = " ".join(jan.clean(source))
# for contam in jan.dirt_ngrams:
# assert contam not in cleaned, contam
# if __name__ == "__main__":
#     test_janitor_general()
# # print_cpp()
# # test_cpp()
# # benchmark()
import collections
import itertools
import numpy as np
import random
import lm_eval.metrics
import lm_eval.models
import lm_eval.tasks
import lm_eval.base
from lm_eval.utils import positional_deprecated, run_task_tests
@positional_deprecated
def simple_evaluate(
model,
model_args=None,
tasks=[],
num_fewshot=0,
batch_size=None,
max_batch_size=None,
device=None,
no_cache=False,
limit=None,
bootstrap_iters=100000,
description_dict=None,
check_integrity=False,
decontamination_ngrams_path=None,
write_out=False,
output_base_path=None,
log_samples=True,
gen_kwargs=None,
):
"""Instantiate and evaluate a model on a list of tasks.
:param model: Union[str, LM]
Name of model or LM object, see lm_eval.models.get_model
:param model_args: Optional[str]
String arguments for each model class, see LM.create_from_arg_string.
Ignored if `model` argument is a LM object.
:param tasks: list[Union[str, Task]]
List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
:param num_fewshot: int
Number of examples in few-shot context
:param batch_size: int or str, optional
Batch size for model
:param max_batch_size: int, optional
Maximal batch size to try with automatic batch size detection
:param device: str, optional
PyTorch device (e.g. "cpu" or "cuda:0") for running models
:param no_cache: bool
Whether or not to cache
:param limit: int or float, optional
        Limit the number of examples per task (only use this for testing). If <1, limit is a percentage of the total number of examples.
:param bootstrap_iters:
Number of iterations for bootstrap statistics
:param description_dict: dict[str, str]
Dictionary of custom task descriptions of the form: `task_name: description`
:param check_integrity: bool
Whether to run the relevant part of the test suite for the tasks
:param write_out: bool
If True, write details about prompts and logits to json for all tasks
:param output_base_path: str, optional
Directory to which detailed eval info will be written. Defaults to present working dir.
:return
Dictionary of results
"""
random.seed(1234)
np.random.seed(1234)
assert tasks != [], "No tasks specified"
if isinstance(model, str):
if model_args is None:
model_args = ""
lm = lm_eval.models.get_model(model).create_from_arg_string(
model_args, {"batch_size": batch_size, "max_batch_size": max_batch_size, "device": device}
)
else:
assert isinstance(model, lm_eval.base.LM)
lm = model
if not no_cache:
lm = lm_eval.base.CachingLM(
lm,
"lm_cache/"
+ (model if isinstance(model, str) else model.model.config._name_or_path)
+ "_"
+ model_args.replace("=", "-").replace(",", "_").replace("/", "-")
+ ".db",
)
task_dict = lm_eval.tasks.get_task_dict(tasks)
if check_integrity:
run_task_tests(task_list=tasks)
results = evaluate(
lm=lm,
task_dict=task_dict,
num_fewshot=num_fewshot,
limit=limit,
bootstrap_iters=bootstrap_iters,
description_dict=description_dict,
decontamination_ngrams_path=decontamination_ngrams_path,
write_out=write_out,
output_base_path=output_base_path,
)
# add info about the model and few shot config
results["config"] = {
"model": (model if isinstance(model, str) else model.model.config._name_or_path),
"model_args": model_args,
"num_fewshot": num_fewshot,
"batch_size": batch_size,
"batch_sizes": list(lm.batch_sizes.values()) if hasattr(lm, "batch_sizes") else [],
"device": device,
"no_cache": no_cache,
"limit": limit,
"bootstrap_iters": bootstrap_iters,
"description_dict": description_dict,
}
return results
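# Illustrative usage (a sketch; the task name is an assumption, any registered task works):
#     results = simple_evaluate(model="gpt2", tasks=["hellaswag"], num_fewshot=0)
#     print(make_table(results))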
decontaminate_suffix = "_decontaminate"
@positional_deprecated
def evaluate(
lm,
task_dict,
provide_description=None,
num_fewshot=0,
limit=None,
bootstrap_iters=100000,
description_dict=None,
decontamination_ngrams_path=None,
write_out=False,
output_base_path=None,
):
"""Instantiate and evaluate a model on a list of tasks.
:param lm: obj
Language Model
:param task_dict: dict[str, Task]
Dictionary of tasks. Tasks will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
:param provide_description: bool
Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
:param num_fewshot: int
Number of examples in few-shot context
:param limit: int, optional
Limit the number of examples per task (only use this for testing)
:param bootstrap_iters:
Number of iterations for bootstrap statistics
:param description_dict: dict[str, str]
Dictionary of custom task descriptions of the form: `task_name: description`
:param write_out: bool
If True, write all prompts, logits and metrics to json for offline analysis
:param output_base_path: str, optional
Directory to which detailed eval info will be written. Defaults to present working dir
:return
Dictionary of results
"""
# TODO: completely refactor this entire function to not be a huge mess, ideally breaking it down into smaller pieces
    # TODO: implement a proper description-providing system
assert not provide_description # not implemented.
if provide_description is not None:
# nudge people to not specify it at all
print(
"WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
)
decontaminate = decontamination_ngrams_path is not None
task_dict_items = [
(name, task)
for name, task in task_dict.items()
if (task.has_validation_docs() or task.has_test_docs())
]
results = collections.defaultdict(dict)
versions = collections.defaultdict(dict)
requests = collections.defaultdict(list)
requests_origin = collections.defaultdict(list)
overlaps = collections.defaultdict(list) # {task_name: contaminated_docs}
    # If we ever run into issues where the eval tasks don't fit in memory and we can't afford a machine with bigger
    # memory, we can always modify this plumbing to support that, but I didn't want to include it just yet because
    # over-engineering is bad (or we could make it write the requests to disk and then read them back out again,
    # probably using an sqlite db because of all the moving parts we have).
# TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable
docs = {}
write_out_info = {}
docs_for_decontamination = collections.defaultdict(list)
# get lists of each type of request
for task_name, task in task_dict_items:
versions[task_name] = task.VERSION
# default to test doc, fall back to val doc if validation unavailable
# TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
if task.has_test_docs():
task_doc_func = task.test_docs
task_set = "test" # Required for caching in the decontamination
elif task.has_validation_docs():
task_set = "val" # Required for caching in the decontamination
task_doc_func = task.validation_docs
else:
raise RuntimeError("Task has neither test_docs nor validation_docs")
# deterministically shuffle docs and chop off the first `limit` because sometimes docs are in some kind of order
task_docs = list(task_doc_func())
rnd = random.Random()
rnd.seed(42)
rnd.shuffle(task_docs)
print(f"Task: {task_name}; number of docs: {len(task_docs)}")
if write_out:
prompt_details = []
description = (
description_dict[task_name]
if description_dict and task_name in description_dict
else ""
)
if limit is not None:
limit = int(len(task_docs) * limit) if limit < 1.0 else int(limit)
for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
if decontaminate and task.should_decontaminate():
docs_for_decontamination[(task_name, task_set)].append(
task.doc_to_decontamination_query(doc)
)
docs[(task_name, doc_id)] = doc
ctx = task.fewshot_context(
doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
)
reqs = task.construct_requests(doc, ctx)
if write_out:
prompt_details.append({"doc_id": doc_id})
# print the prompt for the first few documents
if doc_id < 1:
print(
f"Task: {task_name}; document {doc_id}; context prompt (starting on next line):\n{ctx}\n(end of prompt on previous line)"
)
print("Requests:", reqs)
if not isinstance(reqs, (list, tuple)):
reqs = [reqs]
for i, req in enumerate(reqs):
requests[req.request_type].append(req)
# i: index in requests for a single task instance
# doc_id: unique id that we can get back to a doc using `docs`
requests_origin[req.request_type].append((i, task_name, doc, doc_id))
if write_out:
prompt_details[-1][f"prompt_{i}"] = "".join(
(map(lambda x: "".join(x), req.args))
)
if write_out:
write_out_info[task_name] = prompt_details
# Compare all tasks/sets at once to ensure a single training set scan
if decontaminate:
from lm_eval.decontamination.decontaminate import get_train_overlap
print("Finding train/test overlap, please wait...")
overlaps = get_train_overlap(
docs_for_decontamination, decontamination_ngrams_path, limit
)
# all responses for each (task, doc)
process_res_queue = collections.defaultdict(list)
# execute each type of request
for reqtype, reqs in requests.items():
# TODO: right now, this code runs multiple separate LM requests for multiple Requests differing
# only in index. We could implement some kind of caching, but that would be more of a band-aid
# solution. we could also implement some kind of auto-grouping here;
# they should end up next to each other.
print("Running", reqtype, "requests")
resps = getattr(lm, reqtype)([req.args for req in reqs])
resps = [
x if req.index is None else x[req.index] for x, req in zip(resps, reqs)
]
for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
process_res_queue[(task_name, doc_id)].append((i, resp))
if write_out:
write_out_info[task_name][doc_id][f"logit_{i}"] = resp
task = task_dict[task_name]
if isinstance(task, lm_eval.base.MultipleChoiceTask):
write_out_info[task_name][doc_id]["truth"] = doc["gold"]
elif isinstance(task, lm_eval.tasks.winogrande.Winogrande):
write_out_info[task_name][doc_id]["truth"] = task.answer_to_num[
doc["answer"]
]
else:
write_out_info[task_name][doc_id]["truth"] = task.doc_to_target(doc)
vals = collections.defaultdict(list)
# unpack results and sort back in order and return control to Task
for (task_name, doc_id), requests in process_res_queue.items():
requests.sort(key=lambda x: x[0])
requests = [x[1] for x in requests]
task = task_dict[task_name]
doc = docs[(task_name, doc_id)]
metrics = task.process_results(doc, requests)
for metric, value in metrics.items():
vals[(task_name, metric)].append(value)
if write_out:
write_out_info[task_name][doc_id][metric] = str(value)
# Re-use the evaluation for the decontaminated set by just ignoring the overlaps
if decontaminate and task_name in overlaps:
if doc_id not in overlaps[task_name]:
vals[(task_name, metric + decontaminate_suffix)].append(value)
# aggregate results
for (task_name, metric), items in vals.items():
task = task_dict[task_name]
real_metric = metric # key when looking up the metric with task.aggregation
if metric.endswith(decontaminate_suffix):
real_metric = metric.replace(
decontaminate_suffix, ""
) # decontaminated still uses the same metric
results[task_name][metric] = task.aggregation()[real_metric](items)
        # hotfix: bleu, chrf and ter seem to be really expensive to bootstrap,
        # so we run them for fewer iterations. Still looking for a cleaner way to do this.
stderr = lm_eval.metrics.stderr_for_metric(
metric=task.aggregation()[real_metric],
bootstrap_iters=min(bootstrap_iters, 1000)
if metric in ["bleu", "chrf", "ter"]
else bootstrap_iters,
)
if stderr is not None:
results[task_name][metric + "_stderr"] = stderr(items)
if write_out:
import json
import pathlib
output_base_path = (
pathlib.Path(output_base_path)
if output_base_path is not None
else pathlib.Path(".")
)
try:
output_base_path.mkdir(parents=True, exist_ok=False)
except FileExistsError:
pass
for task_name, _ in task_dict_items:
with open(
output_base_path.joinpath(f"{task_name}_write_out_info.json"),
"w",
encoding="utf8",
) as fp:
json.dump(write_out_info[task_name], fp, indent=4, ensure_ascii=False)
return {"results": dict(results), "versions": dict(versions)}
def make_table(result_dict):
"""Generate table of results."""
from pytablewriter import MarkdownTableWriter, LatexTableWriter
md_writer = MarkdownTableWriter()
latex_writer = LatexTableWriter()
md_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]
latex_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]
values = []
for k, dic in result_dict["results"].items():
version = result_dict["versions"][k]
for m, v in dic.items():
if m.endswith("_stderr"):
continue
if m + "_stderr" in dic:
se = dic[m + "_stderr"]
values.append([k, version, m, "%.4f" % v, "±", "%.4f" % se])
else:
values.append([k, version, m, "%.4f" % v, "", ""])
k = ""
version = ""
md_writer.value_matrix = values
latex_writer.value_matrix = values
# todo: make latex table look good
# print(latex_writer.dumps())
return md_writer.dumps()
import math
from collections.abc import Iterable
import numpy as np
import sacrebleu
import sklearn.metrics
import random
def mean(arr):
return sum(arr) / len(arr)
def pop_stddev(arr):
mu = mean(arr)
return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))
def sample_stddev(arr):
mu = mean(arr)
return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / (len(arr) - 1))
def mean_stderr(arr):
return sample_stddev(arr) / math.sqrt(len(arr))
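# Worked example for the estimators above (illustrative): for arr = [1, 2, 3, 4],
# mean(arr) == 2.5, sample_stddev(arr) ~= 1.2910, and
# mean_stderr(arr) == sample_stddev(arr) / 2 ~= 0.6455.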
def median(arr):
return arr[len(arr) // 2]
def matthews_corrcoef(items):
unzipped_list = list(zip(*items))
golds = unzipped_list[0]
preds = unzipped_list[1]
return sklearn.metrics.matthews_corrcoef(golds, preds)
def f1_score(items):
unzipped_list = list(zip(*items))
golds = unzipped_list[0]
preds = unzipped_list[1]
fscore = sklearn.metrics.f1_score(golds, preds)
return np.max(fscore)
def acc_all(items):
# Only count as correct if all answers are labeled correctly for each question
question_scoring_dict = {}
preds = list(zip(*items))[0]
docs = list(zip(*items))[1]
for doc, pred in zip(docs, preds):
paragraph_id = doc["idx"]["paragraph"]
question_id = doc["idx"]["question"]
if (paragraph_id, question_id) not in question_scoring_dict:
question_scoring_dict[(paragraph_id, question_id)] = []
gold_label = doc["label"] == 1
question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred)
acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
return acc
def acc_all_stderr(items):
# Only count as correct if all answers are labeled correctly for each question
question_scoring_dict = {}
preds = list(zip(*items))[0]
docs = list(zip(*items))[1]
for doc, pred in zip(docs, preds):
question_id = doc["idx"]["question"]
if question_id not in question_scoring_dict:
question_scoring_dict[question_id] = []
gold_label = doc["label"] == 1
question_scoring_dict[question_id].append(gold_label == pred)
acc = mean_stderr([int(all(x)) for x in question_scoring_dict.values()])
return acc
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
"""Compute max metric between prediction and each ground truth."""
scores_for_ground_truths = []
for ground_truth in ground_truths:
score = metric_fn(prediction, ground_truth)
scores_for_ground_truths.append(score)
return max(scores_for_ground_truths)
def perplexity(items):
return math.exp(-mean(items))
def weighted_mean(items):
a, b = zip(*items)
return sum(a) / sum(b)
def weighted_perplexity(items):
return math.exp(-weighted_mean(items))
def bits_per_byte(items):
return -weighted_mean(items) / math.log(2)
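# Worked example (illustrative): if every token has log-likelihood log(0.5), then
#     perplexity([math.log(0.5)] * 4) == 2.0
#     weighted_perplexity([(4 * math.log(0.5), 4)]) == 2.0
#     bits_per_byte([(4 * math.log(0.5), 4)]) == 1.0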
def bleu(items):
"""The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
for evaluating a generated sentence to a reference sentence. It counts matching
n-grams in the candidate translation to n-grams in the reference text, where
1-gram or unigram would be each token and a bigram comparison would be each
word pair. The comparison is made regardless of word order
Source: https://machinelearningmastery.com/calculate-bleu-score-for-text-python/
Paper: https://www.aclweb.org/anthology/P02-1040/
Higher is better
"""
refs = list(zip(*items))[0]
preds = list(zip(*items))[1]
refs, preds = _sacreformat(refs, preds)
return sacrebleu.corpus_bleu(preds, refs).score
def chrf(items):
"""chrF++ is a tool for automatic evaluation of machine translation output
based on character n-gram precision and recall enhanced with word n-grams.
Source: https://github.com/m-popovic/chrF
Paper: https://www.aclweb.org/anthology/W15-3049.pdf
    Higher is better
"""
refs = list(zip(*items))[0]
preds = list(zip(*items))[1]
refs, preds = _sacreformat(refs, preds)
return sacrebleu.corpus_chrf(preds, refs).score
def ter(items):
"""Translation Error Rate is an error metric for machine translation that
measures the number of edits required to change a system output into one
of the references
Source: http://www.cs.umd.edu/~snover/tercom/
Paper: http://mt-archive.info/AMTA-2006-Snover.pdf
Lower is better
"""
refs = list(zip(*items))[0]
preds = list(zip(*items))[1]
refs, preds = _sacreformat(refs, preds)
return sacrebleu.corpus_ter(preds, refs).score
def is_non_str_iterable(obj):
return isinstance(obj, Iterable) and not isinstance(obj, str)
def _sacreformat(refs, preds):
"""Format refs and preds for sacrebleu corpus calculation. It is very particular"""
# Sacrebleu expects (List[str], List[List[str])
# e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...])
# Note [ref1_stream] is the first reference for each pred.
# So lists are size N and (M, N) for N preds and M possible refs for each pred
    # This is a different order of dimensions than I would expect
# We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds
# Must become List[List[str]] with the inner list corresponding to preds
if not is_non_str_iterable(refs):
refs = list(refs)
if not is_non_str_iterable(refs[0]):
refs = [[ref] for ref in refs]
refs = list(zip(*refs))
    # Note the number of refs in each ref list must match the number of preds
# We expect preds to be List[str] or List[List[str]]. Must become List[str]
if not is_non_str_iterable(preds):
preds = list(preds)
if is_non_str_iterable(preds[0]):
assert len(preds[0]) == 1, f"Pred must be a str, was {preds[0]}"
preds = [pred[0] for pred in preds]
return refs, preds
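# Illustrative shapes (derived from the comments above; single-reference case):
#     >>> _sacreformat(["ref a", "ref b"], ["pred a", "pred b"])
#     ([('ref a', 'ref b')], ['pred a', 'pred b'])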
# stderr stuff
class _bootstrap_internal:
def __init__(self, f, n):
self.f = f
self.n = n
def __call__(self, v):
i, xs = v
rnd = random.Random()
rnd.seed(i)
res = []
for _ in range(self.n):
res.append(self.f(rnd.choices(xs, k=len(xs))))
return res
def bootstrap_stderr(f, xs, iters):
import multiprocessing as mp
pool = mp.Pool(mp.cpu_count())
    # this gives a biased estimate of the stderr (i.e. w/ the mean, it gives something
    # equivalent to stderr calculated without Bessel's correction in the stddev.
    # Unfortunately, I haven't been able to figure out what the right correction is
    # to make the bootstrap unbiased - I considered multiplying by sqrt(n/(n-1)) but
    # that would be ad-hoc and I can't prove that it would actually be an unbiased estimator).
    # Thankfully, it shouldn't matter because our samples are usually pretty big anyway.
res = []
chunk_size = min(1000, iters)
from tqdm import tqdm
print("bootstrapping for stddev:", f.__name__)
for bootstrap in tqdm(
pool.imap(
_bootstrap_internal(f, chunk_size),
[(i, xs) for i in range(iters // chunk_size)],
),
total=iters // chunk_size,
):
# sample w replacement
res.extend(bootstrap)
pool.close()
return sample_stddev(res)
def stderr_for_metric(metric, bootstrap_iters):
bootstrappable = [
median,
matthews_corrcoef,
f1_score,
perplexity,
bleu,
chrf,
ter,
]
if metric in bootstrappable:
return lambda x: bootstrap_stderr(metric, x, iters=bootstrap_iters)
stderr = {mean: mean_stderr, acc_all: acc_all_stderr}
return stderr.get(metric, None)
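# Illustrative usage (a sketch): stderr_for_metric(mean, 100000) returns mean_stderr,
# while a bootstrappable metric such as median gets a closure that calls
# bootstrap_stderr(median, x, iters=100000).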
def yesno(x):
if x:
return "yes"
else:
return "no"
from . import gpt2
from . import gpt3
from . import anthropic_llms
from . import huggingface
from . import textsynth
from . import dummy
from . import gguf
MODEL_REGISTRY = {
"hf": gpt2.HFLM,
"hf-causal": gpt2.HFLM,
"hf-causal-experimental": huggingface.AutoCausalLM,
"hf-seq2seq": huggingface.AutoSeq2SeqLM,
"gpt2": gpt2.GPT2LM,
"gpt3": gpt3.GPT3LM,
"anthropic": anthropic_llms.AnthropicLM,
"textsynth": textsynth.TextSynthLM,
"dummy": dummy.DummyLM,
"gguf": gguf.GGUFLM
}
def get_model(model_name):
return MODEL_REGISTRY[model_name]
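# Illustrative usage (a sketch; the arg string and config contents are assumptions):
#     lm_class = get_model("gpt2")
#     lm = lm_class.create_from_arg_string("", {"batch_size": 1, "device": "cpu"})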
import os
from lm_eval.base import BaseLM
from tqdm import tqdm
import time
def anthropic_completion(client, model, prompt, max_tokens_to_sample, temperature, stop):
"""Query Anthropic API for completion.
Retry with back-off until they respond
"""
import anthropic
backoff_time = 3
while True:
try:
response = client.completion(
prompt=f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}",
model=model,
# NOTE: Claude really likes to do CoT, and overly aggressive stop sequences
# (e.g. gsm8k's ":") may truncate a lot of the input.
stop_sequences=[anthropic.HUMAN_PROMPT] + stop,
max_tokens_to_sample=max_tokens_to_sample,
temperature=temperature,
)
print(response)
return response["completion"]
except RuntimeError:
            # TODO: I don't actually know what error Anthropic raises when it times out,
            # so update this except clause when we find out.
import traceback
traceback.print_exc()
time.sleep(backoff_time)
backoff_time *= 1.5
class AnthropicLM(BaseLM):
REQ_CHUNK_SIZE = 20
def __init__(self, model):
"""
:param model: str
Anthropic model e.g. claude-instant-v1
"""
super().__init__()
import anthropic
self.model = model
self.client = anthropic.Client(os.environ['ANTHROPIC_API_KEY'])
@property
def eot_token_id(self):
raise NotImplementedError("No idea about anthropic tokenization.")
@property
def max_length(self):
return 2048
@property
def max_gen_toks(self):
return 256
@property
def batch_size(self):
# Isn't used because we override _loglikelihood_tokens
raise NotImplementedError()
@property
def device(self):
# Isn't used because we override _loglikelihood_tokens
raise NotImplementedError()
def tok_encode(self, string: str):
raise NotImplementedError("No idea about anthropic tokenization.")
def tok_decode(self, tokens):
raise NotImplementedError("No idea about anthropic tokenization.")
def _loglikelihood_tokens(self, requests, disable_tqdm=False):
raise NotImplementedError("No support for logits.")
def greedy_until(self, requests):
if not requests:
return []
res = []
for request in tqdm(requests):
inp = request[0]
request_args = request[1]
until = request_args["until"]
response = anthropic_completion(
client=self.client,
model=self.model,
prompt=inp,
max_tokens_to_sample=self.max_gen_toks,
temperature=0.0,
stop=until,
)
res.append(response)
return res
def _model_call(self, inps):
# Isn't used because we override _loglikelihood_tokens
raise NotImplementedError()
def _model_generate(self, context, max_length, eos_token_id):
# Isn't used because we override greedy_until
raise NotImplementedError()
import random
from lm_eval.base import LM
class DummyLM(LM):
def __init__(self):
pass
@classmethod
def create_from_arg_string(cls, arg_string, additional_config=None):
return cls()
def loglikelihood(self, requests):
res = []
for _ in requests:
res.append((-random.random(), False))
return res
def greedy_until(self, requests):
res = []
for ctx, _ in requests:
res.append("lol")
assert ctx.strip() != ""
return res
def loglikelihood_rolling(self, requests):
res = []
for _ in requests:
res.append(-random.random())
return res
import requests
import logging
import time
from tqdm import tqdm
from requests.exceptions import RequestException
import transformers
from lm_eval.utils import Reorderer
from lm_eval.base import BaseLM
logger = logging.getLogger(__name__)
def get_result(logprobs, context_length):
is_greedy = True
offsets = logprobs['text_offset']
tokens = logprobs['tokens']
tokens_logprobs = logprobs['token_logprobs']
idx = 0
while offsets[idx] < context_length:
idx += 1
continuation_logprobs = sum(tokens_logprobs[idx:-1])
for i in range(idx, len(tokens)):
token = tokens[i]
top_tokens = logprobs["top_logprobs"][i]
top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x])
if top_token != token:
is_greedy = False
break
return continuation_logprobs, is_greedy
class GGUFLM(BaseLM):
def __init__(self, base_url, max_length=2048):
super().__init__()
self.base_url = base_url
self.logprobs = 10
self.temperature = 0.0
        self._max_length = max_length
def gguf_completion(self, context, continuation=None, stop=None, retries=3, delay=5, **kwargs):
for _ in range(retries):
try:
prompt = context
request = {'prompt': prompt, 'logprobs': self.logprobs,
'temperature': self.temperature}
if continuation:
prompt += continuation
request.update({'prompt': prompt, 'max_tokens': 1, 'echo': True})
if stop is not None:
request['stop'] = stop
response = requests.post(f"{self.base_url}/v1/completions", json=request)
response.raise_for_status()
return response.json()
except RequestException as e:
logger.error(f"RequestException: {e}")
time.sleep(delay) # wait before retrying
else:
raise Exception(f"Failed to get a valid response after {retries} retries.")
def loglikelihood(self, requests):
if not requests:
return []
res = []
for context, continuation in tqdm(requests):
response = self.gguf_completion(context=context, continuation=continuation)
if response and "choices" in response and response["choices"]:
choice = response["choices"][0]
logprobs = choice.get("logprobs")
if logprobs and "token_logprobs" in logprobs and logprobs["token_logprobs"]:
logprob, is_greedy = get_result(logprobs, len(context))
res.append((logprob, is_greedy))
else:
logger.warning("Invalid logprobs data. Expected 'logprobs' to contain 'token_logprobs' list.")
else:
logger.error(f"Invalid response for loglikelihood. Response: {response}")
                assert False, "No valid loglikelihood response from GGUF server"
return res
def greedy_until(self, requests):
if not requests:
return []
res = []
for request in tqdm(requests):
inp = request[0]
request_args = request[1]
until = request_args["until"]
response = self.gguf_completion(context=inp, stop=until)
if response and "choices" in response and response["choices"]:
choice = response["choices"][0]
if "text" in choice:
generated_text = choice["text"].strip()
res.append(generated_text)
else:
logger.error(f"Invalid response for greedy_until. Response: {response}")
res.append(None) # Add default value in case of error
else:
logger.error(f"Invalid response for greedy_until. Response: {response}")
res.append(None) # Add default value in case of error
return res
def loglikelihood_rolling(self, requests):
raise NotImplementedError("loglikelihood_rolling not yet supported for GGUF models")
def _model_call(self, inps):
# Placeholder implementation
raise NotImplementedError()
def _model_generate(self, context, max_length, eos_token_id):
# Placeholder implementation
raise NotImplementedError()
def tok_encode(self, string: str):
raise NotImplementedError()
def tok_decode(self, tokens):
raise NotImplementedError()
@property
def batch_size(self):
# Placeholder implementation
raise NotImplementedError()
@property
def device(self):
# Placeholder implementation
raise NotImplementedError()
@property
def eot_token_id(self):
# Placeholder implementation
raise NotImplementedError()
    @property
    def max_length(self):
        return self._max_length
@property
def max_gen_toks(self):
# Placeholder implementation
raise NotImplementedError()
import torch
import transformers
from typing import Optional, Union
from lm_eval.base import BaseLM
def _get_dtype(
dtype: Union[str, torch.dtype]
) -> torch.dtype:
"""Converts `dtype` from `str` to torch.dtype when possible. Does not use an instantiated HF AutoConfig"""
if isinstance(dtype, str) and dtype != "auto":
# Convert `str` args torch dtype: `float16` -> `torch.float16`
_torch_dtype = getattr(torch, dtype)
else:
_torch_dtype = dtype
return _torch_dtype
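# Illustrative behavior: _get_dtype("float16") returns torch.float16, while
# _get_dtype("auto") passes the string "auto" through unchanged.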
class HFLM(BaseLM):
_DEFAULT_MAX_LENGTH = 2048
def __init__(
self,
device="cuda",
pretrained="gpt2",
revision="main",
low_cpu_mem_usage=None,
subfolder=None,
tokenizer=None,
batch_size=1,
max_length=None,
load_in_8bit: Optional[bool] = False,
trust_remote_code: Optional[bool] = False,
dtype: Optional[Union[str, torch.dtype]]="auto",
):
super().__init__()
assert isinstance(device, str)
assert isinstance(pretrained, str)
assert isinstance(batch_size, (int, str))
device_list = set(
["cuda", "cpu"] + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
)
if device and device in device_list:
self._device = torch.device(device)
print(f"Using device '{device}'")
else:
print("Device not specified")
print(f"Cuda Available? {torch.cuda.is_available()}")
self._device = (
torch.device("cuda")
if torch.cuda.is_available()
else torch.device("cpu")
)
# TODO: update this to be less of a hack once subfolder is fixed in HF
revision = revision + ("/" + subfolder if subfolder is not None else "")
self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
pretrained,
load_in_8bit=load_in_8bit,
low_cpu_mem_usage=low_cpu_mem_usage,
revision=revision,
torch_dtype=_get_dtype(dtype),
trust_remote_code=trust_remote_code,
).eval()
if not load_in_8bit:
try:
self.gpt2.to(self.device)
            except Exception:
print("Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes`. If the desired GPU is being used, this message is safe to ignore.")
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
pretrained if tokenizer is None else tokenizer,
revision=revision,
trust_remote_code=trust_remote_code,
)
self.vocab_size = self.tokenizer.vocab_size
# setup for automatic batch size detection
if batch_size == "auto":
self.batch_size_per_gpu = batch_size
else:
self.batch_size_per_gpu = int(batch_size)
self._max_length = max_length
@property
def eot_token_id(self):
# we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
return self.tokenizer.eos_token_id
@property
def max_length(self):
if self._max_length: # if max length manually set, return it
return self._max_length
seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
for attr in seqlen_config_attrs:
if hasattr(self.gpt2.config, attr):
return getattr(self.gpt2.config, attr)
if hasattr(self.tokenizer, "model_max_length"):
if self.tokenizer.model_max_length == 1000000000000000019884624838656:
return self._DEFAULT_MAX_LENGTH
return self.tokenizer.model_max_length
return self._DEFAULT_MAX_LENGTH
@property
def max_gen_toks(self):
return 256
@property
def batch_size(self):
# TODO: fix multi-gpu
return self.batch_size_per_gpu # * gpus
@property
def device(self):
# TODO: fix multi-gpu
return self._device
def tok_encode(self, string: str):
return self.tokenizer.encode(string, add_special_tokens=False)
def tok_decode(self, tokens):
return self.tokenizer.decode(tokens)
def _model_call(self, inps):
"""
inps: a torch tensor of shape [batch, sequence]
the size of sequence may vary from call to call
returns: a torch tensor of shape [batch, sequence, vocab] with the
logits returned from the model
"""
with torch.no_grad():
return self.gpt2(inps)[0]
def _model_generate(self, context, max_length, eos_token_id):
generation_kwargs = {"do_sample": False, "max_length": max_length}
if eos_token_id is not None:
generation_kwargs['eos_token_id'] = eos_token_id
generation_kwargs['pad_token_id'] = eos_token_id # setting eos_token_id as pad token
return self.gpt2.generate(context, **generation_kwargs)
# for backwards compatibility
GPT2LM = HFLM
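# Illustrative usage (a sketch; loading "gpt2" downloads the checkpoint from the HF Hub):
#     lm = HFLM(pretrained="gpt2", device="cpu", batch_size=1)
#     token_ids = lm.tok_encode("hello world")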
import os
import numpy as np
import transformers
from lm_eval.base import BaseLM
from lm_eval import utils
from tqdm import tqdm
import time
def get_result(response, ctxlen):
"""Process results from OpenAI API response.
:param response: dict
OpenAI API Response
:param ctxlen: int
Length of context (so we can slice them away and only keep the predictions)
:return:
continuation_logprobs: np.array
Log probabilities of continuation tokens
is_greedy: bool
whether argmax matches given continuation exactly
"""
is_greedy = True
logprobs = response["logprobs"]["token_logprobs"]
continuation_logprobs = sum(logprobs[ctxlen:])
for i in range(ctxlen, len(response["logprobs"]["tokens"])):
token = response["logprobs"]["tokens"][i]
top_tokens = response["logprobs"]["top_logprobs"][i]
top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x])
if top_token != token:
is_greedy = False
break
return continuation_logprobs, is_greedy
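# Illustrative fragment of the response shape this function expects (keys only,
# values made up): resp["logprobs"] holds parallel lists "tokens" and
# "token_logprobs", plus "top_logprobs", a list of {token: logprob} dicts.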
def oa_completion(**kwargs):
"""Query OpenAI API for completion.
Retry with back-off until they respond
"""
import openai
backoff_time = 3
while True:
try:
return openai.Completion.create(**kwargs)
except openai.error.OpenAIError:
import traceback
traceback.print_exc()
time.sleep(backoff_time)
backoff_time *= 1.5
class GPT3LM(BaseLM):
REQ_CHUNK_SIZE = 20
def __init__(self, engine, truncate=False):
"""
:param engine: str
OpenAI API engine (e.g. davinci)
:param truncate: bool
Truncate input if too long (if False and input is too long, throw error)
"""
super().__init__()
import openai
self.engine = engine
self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")
self.vocab_size = self.tokenizer.vocab_size
# to make the annoying "Using pad_token, but it is not set yet." error go away
self.tokenizer.pad_token = "<|endoftext|>"
assert self.tokenizer.encode("hello\n\nhello") == [31373, 198, 198, 31373]
self.truncate = truncate
self.end_of_text_token_id = self.tokenizer.convert_tokens_to_ids(
["<|endoftext|>"]
)[0]
# Read from environment variable OPENAI_API_SECRET_KEY
openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]
@property
def eot_token_id(self):
return self.tokenizer.eos_token_id
@property
def max_length(self):
# Note: the OpenAI API supports up to 2049 tokens, with the first token being the first input token
return 2048
@property
def max_gen_toks(self):
return 256
@property
def batch_size(self):
# Isn't used because we override _loglikelihood_tokens
raise NotImplementedError()
@property
def device(self):
# Isn't used because we override _loglikelihood_tokens
raise NotImplementedError()
def tok_encode(self, string: str):
return self.tokenizer.encode(string, add_special_tokens=False)
def tok_decode(self, tokens):
return self.tokenizer.decode(tokens)
def _loglikelihood_tokens(self, requests, disable_tqdm=False):
res = []
def _collate(x):
# this doesn't efficiently handle last-token differences yet, but those are kinda annoying because
# it's not guaranteed that the 100 or so logprobs we get to see actually contain all the continuations
# we care about and so we need some kind of backup for when it isn't
toks = x[1] + x[2]
return -len(toks), tuple(toks)
re_ord = utils.Reorderer(requests, _collate)
for chunk in tqdm(
list(utils.chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE)),
disable=disable_tqdm,
):
inps = []
ctxlens = []
for cache_key, context_enc, continuation_enc in chunk:
# max_length+1 because the API takes up to 2049 tokens, including the first context token
inp = (context_enc + continuation_enc)[-(self.max_length + 1) :]
# TODO: the logic is much simpler if we just look at the length of continuation tokens
ctxlen = len(context_enc) - max(
0, len(context_enc) + len(continuation_enc) - (self.max_length + 1)
)
inps.append(inp)
ctxlens.append(ctxlen)
response = oa_completion(
engine=self.engine,
prompt=inps,
echo=True,
max_tokens=0,
temperature=0.0,
logprobs=10,
)
for resp, ctxlen, (cache_key, context_enc, continuation_enc) in zip(
response.choices, ctxlens, chunk
):
answer = get_result(resp, ctxlen)
res.append(answer)
# partial caching
if cache_key is not None:
self.cache_hook.add_partial("loglikelihood", cache_key, answer)
return re_ord.get_original(res)
def greedy_until(self, requests):
if not requests:
return []
res = []
def _collate(x):
toks = self.tok_encode(x[0])
return len(toks), x[0]
re_ord = utils.Reorderer(requests, _collate)
def sameuntil_chunks(xs, size):
ret = []
lastuntil = xs[0][1]
for x in xs:
if len(ret) >= size or x[1] != lastuntil:
yield ret, lastuntil
ret = []
lastuntil = x[1]
ret.append(x)
if ret:
yield ret, lastuntil
# todo: more intelligent batching for heterogeneous `until`
for chunk, until in tqdm(
list(sameuntil_chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE))
):
inps = []
for context, _ in chunk:
context_enc = self.tok_encode(context)
inp = context_enc[-(self.max_length - self.max_gen_toks) :]
inps.append(inp)
response = oa_completion(
engine=self.engine,
prompt=inps,
max_tokens=self.max_gen_toks,
temperature=0.0,
logprobs=10,
stop=until,
)
for resp, (context, until_) in zip(response.choices, chunk):
s = resp["text"]
for term in until_:
s = s.split(term)[0]
# partial caching
self.cache_hook.add_partial("greedy_until", (context, until_), s)
res.append(s)
return re_ord.get_original(res)
def _model_call(self, inps):
# Isn't used because we override _loglikelihood_tokens
raise NotImplementedError()
def _model_generate(self, context, max_length, eos_token_id):
# Isn't used because we override greedy_until
raise NotImplementedError()
import math
import torch
import torch.nn.functional as F
import transformers
import peft
from peft import __version__ as PEFT_VERSION
from pathlib import Path
from typing import List, Mapping, NewType, Optional, Tuple, Union
from tqdm import tqdm
from transformers import BatchEncoding
from lm_eval import utils
from lm_eval.base import BaseLM
TokenSequence = Union[List[int], torch.LongTensor, torch.Tensor, BatchEncoding]
_DeviceMapping = NewType("DeviceMapping", Mapping[str, Union[int, str, torch.device]])
def _get_accelerate_args(
device_map_option: Optional[str] = "auto",
max_memory_per_gpu: Optional[Union[int, str]] = None,
max_cpu_memory: Optional[Union[int, str]] = None,
offload_folder: Optional[str] = "./offload",
) -> dict:
"""Returns the kwargs needed to apply `accelerate` in `AutoModel.from_pretrained`."""
max_memory = {}
if max_memory_per_gpu is not None:
max_memory_per_gpu_map = {
device_idx: max_memory_per_gpu
for device_idx in range(torch.cuda.device_count())
}
max_memory.update(max_memory_per_gpu_map)
if max_cpu_memory is not None:
max_memory["cpu"] = max_cpu_memory
args = {}
if max_memory:
args["max_memory"] = max_memory
args["device_map"] = device_map_option
args["offload_folder"] = offload_folder
return args
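# Illustrative result (assumes two visible GPUs):
#     _get_accelerate_args("auto", "20GiB", "48GiB", "./offload") ->
#     {"max_memory": {0: "20GiB", 1: "20GiB", "cpu": "48GiB"},
#      "device_map": "auto", "offload_folder": "./offload"}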
def _get_dtype(
dtype: Union[str, torch.dtype], config: Optional[transformers.AutoConfig] = None
) -> torch.dtype:
"""Converts `dtype` from `str` to torch.dtype when possible."""
if dtype is None and config is not None:
_torch_dtype = config.torch_dtype
elif isinstance(dtype, str) and dtype != "auto":
# Convert `str` args torch dtype: `float16` -> `torch.float16`
_torch_dtype = getattr(torch, dtype)
else:
_torch_dtype = dtype
return _torch_dtype
class HuggingFaceAutoLM(BaseLM):
AUTO_CONFIG_CLASS: transformers.AutoConfig = transformers.AutoConfig
AUTO_TOKENIZER_CLASS: transformers.AutoTokenizer = transformers.AutoTokenizer
AUTO_MODEL_CLASS: transformers.AutoModel = None
AUTO_PEFT_CLASS: peft.PeftModel = None
# Default max sequence length setting for when no `max_length` is provided
# or no max length config setting is found in the model or tokenizer.
_DEFAULT_MAX_LENGTH: int = 2048
def __init__(
self,
pretrained: str,
quantized: Optional[Union[bool, str]] = False,
tokenizer: Optional[str] = None,
subfolder: Optional[str] = None,
revision: Optional[str] = "main",
batch_size: Optional[Union[int, str]] = 1,
max_batch_size: Optional[int] = 512,
max_gen_toks: Optional[int] = 256,
max_length: Optional[int] = None,
add_special_tokens: Optional[bool] = None,
use_accelerate: Optional[bool] = False,
device_map_option: Optional[str] = "auto",
max_memory_per_gpu: Optional[Union[int, str]] = None,
max_cpu_memory: Optional[Union[int, str]] = None,
offload_folder: Optional[str] = "./offload",
dtype: Optional[Union[str, torch.dtype]] = None,
device: Optional[Union[int, str]] = "cuda",
peft: str = None,
load_in_8bit: Optional[bool] = False,
load_in_4bit: Optional[bool] = False,
trust_remote_code: Optional[bool] = False,
gptq_use_triton: Optional[bool] = False,
bnb_4bit_quant_type: Optional[str] = None,
bnb_4bit_compute_dtype: Optional[Union[str, torch.dtype]] = None,
):
"""Initializes a HuggingFace `AutoModel` and `AutoTokenizer` for evaluation.
Args:
pretrained (str):
The HuggingFace Hub model ID name or the path to a pre-trained
model to load. This is effectively the `pretrained_model_name_or_path`
argument of `from_pretrained` in the HuggingFace `transformers` API.
quantized (str or bool, optional, defaults to False):
File name of a GPTQ quantized model to load. Set to `True` to use the
default name of the quantized model.
add_special_tokens (bool, optional, defaults to True):
Whether to add special tokens to the input sequences. If `None`, the
default value will be set to `True` for seq2seq models (e.g. T5) and
`False` for causal models.
WARNING: Evaluating causal models with `add_special_tokens=True` is
currently __not__ supported.
> Large model loading `accelerate` arguments
use_accelerate (bool, optional, defaults to False):
If True, uses the `accelerate` library to load a large model across
multiple devices.
device_map_option (str, optional, defaults to "auto"):
The device map option to use when loading the model with
`accelerate`.
Options:
"auto", "balanced", "balanced_low_0", "sequential"
See the `accelerate` docs for more details on these options:
https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained.device_map
max_memory_per_gpu (Union[int, str], optional, defaults to None):
The maximum memory available for each GPU in bytes as `int` or in
the format f"{significand}{unit_symbol}" where {unit_symbol} is
any of ["GB", "MB", "GIB", "MIB"]. Refer to the `max_memory` arg in
the "Parameters for big model inference" section of the following
docs:
https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained.max_memory
max_cpu_memory (Union[int, str], optional, defaults to None):
The maximum available CPU RAM in bytes as `int` or in the format
f"{significand}{unit_symbol}" where {unit_symbol} is any of
["GB", "MB", "GIB", "MIB"]. Refer to the `max_memory` arg in the
"Parameters for big model inference" section of the following docs:
https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained.max_memory
offload_folder (str, optional, defaults to "./offload"):
The folder to offload weights into if `device_map` contains any
"disk" value.
            dtype (Union[str, torch.dtype], optional, defaults to None):
Converts the model weights to `dtype`, if specified. Strings get
converted to `torch.dtype` objects (e.g. `float16` -> `torch.float16`).
Use `dtype="auto"` to derive the type from the model’s weights.
peft (str, optional, defaults to None):
Path of the adapter weights to load from Huggingface. This will usually
include a directory that includes the files `adapter_config.json` and
`adapter_model.bin`. Compatible with [PEFT](https://github.com/huggingface/peft)
load_in_8bit (bool, optional, defaults to False):
If True, will convert the loaded model into mixed-8bit quantized model. See:
https://huggingface.co/docs/transformers/main/en/main_classes/quantization#load-a-large-model-in-8bit
load_in_4bit (bool, optional, defaults to False):
If True, will convert the loaded model into mixed-4bit quantized model. See:
https://huggingface.co/docs/transformers/main/en/main_classes/quantization#load-a-large-model-in-4bit
trust_remote_code (bool, optional, defaults to False):
If True, will trust the remote code when loading the model.
gptq_use_triton (bool, optional, defaults to False):
Use Triton for GPTQ inference.
bnb_4bit_quant_type (str, optional, defaults to None):
The quantization type to use for BnB 4bit quantization. See:
https://github.com/huggingface/transformers/blob/main/src/transformers/utils/quantization_config.py#L77
bnb_4bit_compute_dtype (Union[str, torch.dtype], optional, defaults to None):
The compute dtype to use for BnB 4bit quantization. See:
https://github.com/huggingface/transformers/blob/main/src/transformers/utils/quantization_config.py#L74
"""
super().__init__()
assert isinstance(pretrained, str)
assert isinstance(device, str)
assert isinstance(batch_size, (int, str))
if (
add_special_tokens is not None
and self.AUTO_MODEL_CLASS is transformers.AutoModelForCausalLM
):
# TODO: Support evaluating causal models with special tokens. Currently,
# this is not possible because the `_loglikelihood_tokens()` method for
# causal LMs makes a no-special-tokens assumption given that contexts
# and labels/continuations are tokenized separately without special
# tokens, concatenated, and then processed as inputs.
assert (
not add_special_tokens
), "Evaluating causal models with `add_special_tokens=True` is currently not supported."
# setup for automatic batch size detection
if str(batch_size).startswith("auto"):
batch_size = batch_size.split(":")
self._batch_size = batch_size[0]
self.batch_schedule = float(batch_size[1]) if len(batch_size) > 1 else 1
else:
self._batch_size = int(batch_size)
self.max_batch_size = max_batch_size
self._max_gen_toks = max_gen_toks
self._max_length = max_length
self._config = self.AUTO_CONFIG_CLASS.from_pretrained(
pretrained,
trust_remote_code=trust_remote_code,
revision=revision + ("/" + subfolder if subfolder is not None else ""),
)
self._add_special_tokens = add_special_tokens
self.tokenizer = self._create_auto_tokenizer(
pretrained=pretrained,
revision=revision,
subfolder=subfolder,
tokenizer=tokenizer,
)
self.tokenizer.model_max_length = self.max_length
model_kwargs = {}
if use_accelerate:
model_kwargs = _get_accelerate_args(
device_map_option,
max_memory_per_gpu,
max_cpu_memory,
offload_folder,
)
self.model = self._create_auto_model(
pretrained=pretrained,
quantized=quantized,
trust_remote_code=trust_remote_code,
revision=revision,
subfolder=subfolder,
torch_dtype=_get_dtype(dtype, self._config),
gptq_use_triton=gptq_use_triton,
load_in_8bit=load_in_8bit,
load_in_4bit=load_in_4bit,
bnb_4bit_quant_type=bnb_4bit_quant_type,
bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
**model_kwargs,
)
        # NOTE: the PEFT adapter path can differ from the pretrained model path
if peft is not None:
self.model = self._create_auto_model_peft(
model=self.model,
peft=peft,
revision=revision,
subfolder=subfolder,
load_in_4bit=load_in_4bit,
)
self.model.eval()
torch.set_grad_enabled(False)
self._device = device
if use_accelerate and "lm_head" in self.model.hf_device_map:
# `accelerate` can place `lm_head` weights on a different device than
# the user specified one so we force `self._device` to be the same as
# `lm_head`'s.
self._device = self.model.hf_device_map["lm_head"]
if not use_accelerate and not (load_in_4bit or load_in_8bit):
try:
self.model.to(self._device)
            except Exception:
                print(
                    "Failed to place model onto specified device. This may be "
                    "because the model is quantized via `bitsandbytes`. If the "
                    "desired GPU is being used, this message is safe to ignore."
                )
def _create_auto_model(
self,
*,
pretrained: str,
quantized: Optional[Union[bool, str]] = False,
revision: str,
subfolder: str,
device_map: Optional[Union[str, _DeviceMapping]] = None,
max_memory: Optional[dict] = None,
offload_folder: Optional[str] = None,
load_in_8bit: Optional[bool] = False,
load_in_4bit: Optional[bool] = False,
trust_remote_code: Optional[bool] = False,
torch_dtype: Optional[Union[str, torch.dtype]] = None,
gptq_use_triton: Optional[bool] = False,
bnb_4bit_quant_type: Optional[str] = None,
bnb_4bit_compute_dtype: Optional[Union[str, torch.dtype]] = None,
) -> transformers.AutoModel:
"""Returns a pre-trained pytorch model from a pre-trained model configuration."""
if not quantized:
            if load_in_4bit:
                assert transformers.__version__ >= "4.30.0", "load_in_4bit requires transformers >= 4.30.0"
            model_kwargs = {}
            if transformers.__version__ >= "4.30.0":
                model_kwargs["load_in_4bit"] = load_in_4bit
                if load_in_4bit:
                    # Guard the `None` defaults so we never pass bogus kwargs
                    # or call `getattr(torch, None)` below; also accept a
                    # `torch.dtype` in addition to its string name.
                    if bnb_4bit_quant_type is not None:
                        model_kwargs["bnb_4bit_quant_type"] = bnb_4bit_quant_type
                    if bnb_4bit_compute_dtype is not None:
                        model_kwargs["bnb_4bit_compute_dtype"] = (
                            bnb_4bit_compute_dtype
                            if isinstance(bnb_4bit_compute_dtype, torch.dtype)
                            else getattr(torch, bnb_4bit_compute_dtype)
                        )
model = self.AUTO_MODEL_CLASS.from_pretrained(
pretrained,
revision=revision + ("/" + subfolder if subfolder is not None else ""),
device_map=device_map,
max_memory=max_memory,
offload_folder=offload_folder,
load_in_8bit=load_in_8bit,
trust_remote_code=trust_remote_code,
torch_dtype=torch_dtype,
**model_kwargs,
)
else:
from auto_gptq import AutoGPTQForCausalLM
model = AutoGPTQForCausalLM.from_quantized(
pretrained,
                model_basename=None if quantized is True else Path(quantized).stem,
device_map=device_map,
max_memory=max_memory,
trust_remote_code=trust_remote_code,
                use_safetensors=True if quantized is True else quantized.endswith(".safetensors"),
use_triton=gptq_use_triton,
warmup_triton=gptq_use_triton,
)
return model
def _create_auto_model_peft(
self,
*,
model: transformers.PreTrainedModel,
peft: str,
revision: str,
subfolder: str,
load_in_4bit: Optional[bool] = False,
):
if load_in_4bit:
assert PEFT_VERSION >= "0.4.0", "load_in_4bit requires peft >= 0.4.0"
model = self.AUTO_PEFT_CLASS.from_pretrained(
model,
peft,
revision=revision + ("/" + subfolder if subfolder is not None else ""),
)
return model
def _create_auto_tokenizer(
self,
*,
pretrained: str,
revision: str,
subfolder: str,
tokenizer: Optional[str] = None,
) -> transformers.PreTrainedTokenizer:
"""Returns a pre-trained tokenizer from a pre-trained tokenizer configuration."""
tokenizer = self.AUTO_TOKENIZER_CLASS.from_pretrained(
pretrained if tokenizer is None else tokenizer,
revision=revision + ("/" + subfolder if subfolder is not None else ""),
)
tokenizer.pad_token = tokenizer.eos_token
return tokenizer
@property
def add_special_tokens(self) -> bool:
"""Whether to include special tokens in encoded text. This should be
determined by whether or not the model was trained with special tokens.
TODO: Remove these conditionals once HuggingFace supports a way to
check whether or not an arbitrary model was trained with special tokens.
"""
if self._add_special_tokens is not None:
return self._add_special_tokens
elif self.AUTO_MODEL_CLASS is transformers.AutoModelForCausalLM:
return False
elif self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM:
return True
else:
raise ValueError(
"Could not determine `add_special_tokens` value from the model "
"class. Set to `True` or `False` depending on whether the model "
"was pre-trained with special tokens."
)
@property
def eot_token(self) -> str:
return self.tokenizer.eos_token
@property
def eot_token_id(self) -> int:
return self.tokenizer.eos_token_id
@property
def max_gen_toks(self) -> int:
return self._max_gen_toks
@property
def max_length(self) -> int:
"""Return the maximum sequence length of the model.
NOTE: Different model configurations have different max sequence length
attribute names.
- n_positions: (CTRLConfig, T5Config)
- max_position_embeddings: (BartConfig, RoFormerConfig)
- n_ctx: (GPT2Config)
NOTE: For relative position encoded models you should specify the max
sequence length of the model in the constructor via `max_length`.
"""
if self._max_length is not None:
return self._max_length
# Try to get the sequence length from the model config.
seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
for attr in seqlen_config_attrs:
if hasattr(self._config, attr):
return getattr(self._config, attr)
if hasattr(self.tokenizer, "model_max_length"):
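            # Hugging Face tokenizers report this huge sentinel value when no
            # maximum length has been set for the model.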
if self.tokenizer.model_max_length == 1000000000000000019884624838656:
return self._DEFAULT_MAX_LENGTH
return self.tokenizer.model_max_length
return self._DEFAULT_MAX_LENGTH
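    # Example resolution (illustrative): a GPT-2 config exposes `n_ctx` (1024),
    # T5 exposes `n_positions`, and BART exposes `max_position_embeddings`; if
    # none are present and the tokenizer only reports the sentinel above,
    # `_DEFAULT_MAX_LENGTH` is used.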
@property
def batch_size(self) -> int:
# TODO: Add adaptive batch size.
return self._batch_size # * gpus
@property
def device(self) -> Union[int, str, torch.device]:
return self._device
def tok_encode(self, string: str) -> TokenSequence:
# TODO: Merge `tok_encode_batch` here.
return self.tokenizer.encode(string, add_special_tokens=self.add_special_tokens)
def tok_encode_batch(self, strings: List[str]) -> TokenSequence:
return self.tokenizer(
strings,
padding=True,
add_special_tokens=self.add_special_tokens,
return_tensors="pt",
)
def tok_decode(self, tokens: torch.LongTensor) -> List[str]:
return self.tokenizer.batch_decode(tokens, skip_special_tokens=True)
def greedy_until(
self, requests: List[Tuple[str, Union[List[str], str]]]
) -> List[str]:
def _collate(x):
tokens = self.tok_encode(x[0])
return len(tokens), x[0]
results = []
reorder = utils.Reorderer(requests, _collate)
adaptive_batch_size = None
if self.batch_size == "auto":
# using rolling window with maximum context
print("Passed argument batch_size = auto. Detecting largest batch size")
batch_size = self._detect_batch_size()
print(f"Determined Largest batch size: {batch_size}")
adaptive_batch_size = batch_size
for chunk in utils.chunks(
tqdm(reorder.get_reordered(), disable=False),
self.batch_size if self.batch_size != "auto" else adaptive_batch_size,
):
context = [c[0] for c in chunk]
request_args = chunk[0][1]
stop = request_args.get("until", None)
            stop_sequences = stop if (stop is None or isinstance(stop, list)) else [stop]
max_generation_length = request_args.get("max_length", None)
assert (
isinstance(max_generation_length, int) or max_generation_length is None
)
assert isinstance(stop_sequences, list) or stop_sequences is None
# TODO: Find a better way to handle stop sequences for 0-shot.
if stop_sequences is None:
until = [self.eot_token]
else:
until = stop_sequences + [self.eot_token]
if max_generation_length is None:
max_tokens = self.max_gen_toks
else:
max_tokens = max_generation_length
token_context = self.tok_encode_batch(context)
responses = self._model_generate(
inputs=token_context,
max_tokens=max_tokens,
stop=until,
)
responses = self.tok_decode(responses.tolist())
for response in responses:
# Ensure the generated responses do not contain the stop sequences.
for term in until:
response = response.split(term)[0]
# partial caching
self.cache_hook.add_partial("greedy_until", (context, until), response)
results.append(response)
return reorder.get_original(results)
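# For reference, each `greedy_until` request pairs a context string with an
# args dict; an illustrative (hypothetical) example:
#   ("Q: What is 2+2?\nA:", {"until": ["\n"], "max_length": 16})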
class AutoCausalLM(HuggingFaceAutoLM):
"""Causal language modeling.
You can find a set of supported models in the HF documentation:
https://huggingface.co/docs/transformers/main/model_doc/auto#transformers.AutoModelForCausalLM
"""
AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
AUTO_PEFT_CLASS = peft.PeftModel
def _create_auto_tokenizer(
self,
*,
pretrained: str,
revision: str,
subfolder: str,
tokenizer: Optional[str] = None,
) -> transformers.PreTrainedTokenizer:
tokenizer = super()._create_auto_tokenizer(
pretrained=pretrained,
revision=revision,
subfolder=subfolder,
tokenizer=tokenizer,
)
tokenizer.padding_side = "left"
return tokenizer
def _model_call(
self, inputs: TokenSequence, labels: Optional[TokenSequence] = None
) -> TokenSequence:
return self.model(inputs)["logits"]
def _model_generate(
self,
inputs: transformers.BatchEncoding,
max_tokens: int,
stop: Optional[List[str]] = None,
) -> TokenSequence:
# Ensure that the context does not encroach into the `space`
# for the generation.
input_ids = inputs["input_ids"][:, self.max_gen_toks - self.max_length :]
attention_mask = inputs["attention_mask"][
:, self.max_gen_toks - self.max_length :
]
input_ids = input_ids.to(self.device)
attention_mask = attention_mask.to(self.device)
stopping_criteria = stop_sequences_criteria(
self.tokenizer, stop, input_ids.shape[1], input_ids.shape[0]
)
generations = self.model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
# GPT style models require the `generate` `max_length` arg to include the
# context length, so we instead set `max_new_tokens` which is the number
# of new tokens to generate, excluding the current number of tokens.
max_new_tokens=max_tokens,
stopping_criteria=stopping_criteria,
do_sample=False,
)
return utils.select_continuation_from_batch_left_padding(
generations, max_context_size=inputs["input_ids"].size(1)
)
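# Worked example of the left-truncation in `AutoCausalLM._model_generate`
# (illustrative numbers): with max_length=2048 and max_gen_toks=256,
# `self.max_gen_toks - self.max_length` is -1792, so `input_ids[:, -1792:]`
# keeps only the rightmost 1792 context tokens and leaves 256 positions free
# for generation.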
class AutoSeq2SeqLM(HuggingFaceAutoLM):
"""Seq2Seq language modeling.
You can find a set of supported models in the following documentation:
https://huggingface.co/docs/transformers/main/model_doc/auto#transformers.AutoModelForSeq2SeqLM
"""
AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
AUTO_PEFT_CLASS = peft.PeftModel
def loglikelihood(
self, requests: List[Tuple[str, str]]
) -> List[Tuple[float, bool]]:
new_requests = []
for chunk in utils.chunks(requests, self.batch_size):
context, continuation = zip(*chunk)
# Fill empty contexts with the EOT token.
context = [
f"{self.eot_token}" if len(text) == 0 else text for text in context
]
context_enc = self.tok_encode_batch(context)
for key in context_enc:
context_enc[key] = context_enc[key][:, -self.max_length :]
# Remove leading whitespace introduced by the default
# `text_target_separator` since the context and continuation
# will not be concatenated as a single (decoder) input.
continuation = [text.lstrip() for text in continuation]
continuation_enc = self.tok_encode_batch(list(continuation))
for key in continuation_enc:
continuation_enc[key] = continuation_enc[key][:, -self.max_length :]
new_requests.append(
((context, continuation), context_enc, continuation_enc)
)
return self._loglikelihood_tokens(new_requests)
def loglikelihood_rolling(self, requests: List[Tuple[str, str]]) -> List[float]:
loglikelihoods = []
for (string,) in tqdm(requests):
rolling_token_windows = list(
map(
utils.make_disjoint_window,
utils.get_rolling_token_windows(
token_list=self.tok_encode(string),
prefix_token=self.eot_token_id,
max_seq_len=self.max_length,
context_len=1,
),
)
)
contexts, conts = utils.split_and_pad_windows(
rolling_token_windows,
pad_token_id=self.eot_token_id,
max_seq_len=self.max_length,
)
# Manually create BatchEncoding tensors with attention masks as
# expected by `self._model_call` in `self._loglikelihood_tokens`.
contexts_enc = torch.Tensor(contexts).long()
contexts_enc = transformers.tokenization_utils_base.BatchEncoding(
{
"input_ids": contexts_enc,
"attention_mask": (contexts_enc != self.eot_token_id).long(),
}
)
conts_enc = torch.Tensor(conts).long()
conts_enc = transformers.tokenization_utils_base.BatchEncoding(
{
"input_ids": conts_enc,
"attention_mask": (conts_enc != self.eot_token_id).long(),
}
)
            # TODO: Extract this call out so it only gets called once, and also
            # figure out partial caching for it.
rolling_token_windows_request = [
((contexts, conts), contexts_enc, conts_enc)
]
string_nll = self._loglikelihood_tokens(
rolling_token_windows_request, disable_tqdm=True
)
string_nll = [x[0] for x in string_nll] # discard is_greedy
string_nll = sum(string_nll)
loglikelihoods.append(string_nll)
return loglikelihoods
def _loglikelihood_tokens(
self,
requests: List[Tuple[Tuple[str, str], TokenSequence, TokenSequence]],
disable_tqdm: Optional[bool] = False,
) -> List[Tuple[float, bool]]:
results = []
for chunk in tqdm(
            requests, total=len(requests), disable=disable_tqdm
):
cache_keys, inputs_tokens, targets_tokens = chunk
inputs_tokens = inputs_tokens.to(self.device)
targets_tokens = targets_tokens.to(self.device)
outputs = self._model_call(inputs=inputs_tokens, labels=targets_tokens)
log_softmaxes = F.log_softmax(outputs.logits, dim=-1)
output_iterator = zip(
zip(cache_keys[0], cache_keys[1]),
log_softmaxes,
targets_tokens["input_ids"],
targets_tokens["attention_mask"],
)
for cache_key, log_softmax, target_tokens, target_mask in output_iterator:
length = target_mask.sum()
log_softmax = log_softmax[:length]
target_tokens = target_tokens[:length]
greedy_tokens = log_softmax.argmax(dim=-1)
max_equal = (greedy_tokens == target_tokens).all()
target_logits = torch.gather(
log_softmax, 1, target_tokens.unsqueeze(-1)
).squeeze(-1)
answer = (float(target_logits.sum()), bool(max_equal))
results.append(answer)
if cache_key is not None:
self.cache_hook.add_partial("loglikelihood", cache_key, answer)
return results
def _model_call(
self, inputs: TokenSequence, labels: Optional[TokenSequence] = None
) -> TokenSequence:
return self.model(**inputs, labels=labels["input_ids"])
def _model_generate(
self,
inputs: transformers.BatchEncoding,
max_tokens: int,
stop: Optional[List[str]] = None,
) -> TokenSequence:
input_ids = inputs["input_ids"][:, -self.max_length :].to(self.device)
attention_mask = inputs["attention_mask"][:, -self.max_length :].to(self.device)
# Generate one token to calculate the number of start tokens prepended to decoder_input_ids
# (leaving this here in case the below assumption is violated in the future)
# one_tok_gen = self.model.generate(
# input_ids=torch.zeros((1, 1), dtype=torch.int),
# min_length=2,
# max_new_tokens=1,
# ).squeeze()
# initial_decoder_input_length = len(one_tok_gen) - 1
# Assume that there will always only be one token in the decoder inputs, assumption holds for existing HF models
stopping_criteria = stop_sequences_criteria(
self.tokenizer, stop, 1, input_ids.shape[0]
)
generations = self.model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
max_new_tokens=max_tokens,
stopping_criteria=stopping_criteria,
do_sample=False,
)
return generations
class MultiTokenEOSCriteria(transformers.StoppingCriteria):
"""Criteria to stop on the specified multi-token sequence."""
def __init__(
self,
sequence: str,
tokenizer: transformers.PreTrainedTokenizer,
initial_decoder_input_length: int,
batch_size: int,
):
self.initial_decoder_input_length = initial_decoder_input_length
self.done_tracker = [False] * batch_size
self.sequence = sequence
self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
self.sequence_id_len = len(self.sequence_ids)
self.tokenizer = tokenizer
def __call__(self, input_ids, scores, **kwargs) -> bool:
# For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence
lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :][
:, -self.sequence_id_len :
]
lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
for i, done in enumerate(self.done_tracker):
if not done:
self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
return False not in self.done_tracker
def stop_sequences_criteria(
tokenizer: transformers.PreTrainedTokenizer,
stop_sequences: List[str],
initial_decoder_input_length: int,
batch_size: int,
) -> transformers.StoppingCriteriaList:
return transformers.StoppingCriteriaList(
[
*[
MultiTokenEOSCriteria(
sequence, tokenizer, initial_decoder_input_length, batch_size
)
for sequence in stop_sequences
],
]
)
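# A minimal, hedged usage sketch (not part of the original harness code) of
# how `stop_sequences_criteria` plugs into `transformers.generate`. The "gpt2"
# checkpoint, the prompt, and the "\n" stop string are illustrative
# assumptions; the helper itself is exercised by the model classes above.
def _example_stop_sequences_usage():
    tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
    model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
    inputs = tokenizer("The capital of France is", return_tensors="pt")
    # Stop as soon as every sequence in the batch has emitted a newline.
    criteria = stop_sequences_criteria(
        tokenizer, ["\n"], inputs["input_ids"].shape[1], inputs["input_ids"].shape[0]
    )
    return model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        stopping_criteria=criteria,
        max_new_tokens=20,
        do_sample=False,
    )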
""" TextSynth API
Implementation provided by Fabrice Bellard:
https://github.com/EleutherAI/lm-evaluation-harness/issues/295
In order to use the API, you must have a valid TextSynth account and
enough credits.
Example usage:
python main.py --model textsynth --model_args engine=gptj_6B --no_cache --tasks piqa
Homepage: https://textsynth.com/index.html
"""
import logging
import os
import requests as _requests
import time
from tqdm import tqdm
from lm_eval.base import BaseLM
logger = logging.getLogger(__name__)
def textsynth_completion(**kwargs):
"""Query TextSynth API for completion.
Retry with back-off until they respond.
"""
backoff_time = 3
while True:
try:
return _requests.post(**kwargs)
except _requests.exceptions.RequestException:
import traceback
traceback.print_exc()
time.sleep(backoff_time)
backoff_time *= 1.5
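# A hedged sketch (not part of the original file) of calling the retrying
# helper directly against the TextSynth `logprob` endpoint; the engine name
# and the example context/continuation are illustrative assumptions that
# mirror `TextSynthLM.loglikelihood` below.
def _example_textsynth_logprob(engine="gptj_6B"):
    api_key = os.environ["TEXTSYNTH_API_SECRET_KEY"]
    response = textsynth_completion(
        url="https://api.textsynth.com/v1/engines/" + engine + "/logprob",
        headers={"Authorization": "Bearer " + api_key},
        json={"context": "Paris is the capital of", "continuation": " France"},
    )
    # Per the handling in `TextSynthLM.loglikelihood`, a successful response
    # carries `logprob` and `is_greedy` keys.
    return response.json()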
class TextSynthLM(BaseLM):
def __init__(self, engine, truncate=False):
"""
:param engine: str
TextSynth API engine (e.g. `gptj_6B`)
:param truncate: bool
Truncate input if too long (if False and input is too long, throw error)
"""
super().__init__()
self.engine = engine
self.truncate = truncate
self.api_url = "https://api.textsynth.com"
# Read from environment variable TEXTSYNTH_API_SECRET_KEY
self.api_key = os.environ["TEXTSYNTH_API_SECRET_KEY"]
@property
def eot_token_id(self):
# Isn't used because we override loglikelihood, loglikelihood_rolling and greedy_until
raise NotImplementedError()
@property
def max_length(self):
# NOTE: Turn on truncation to avoid errors on long inputs.
return 2048
@property
def max_gen_toks(self):
return 256
@property
def batch_size(self):
# Isn't used because we override loglikelihood, loglikelihood_rolling and greedy_until
raise NotImplementedError()
@property
def device(self):
# Isn't used because we override loglikelihood, loglikelihood_rolling and greedy_until
raise NotImplementedError()
def tok_encode(self, string: str):
# Isn't used because we override loglikelihood, loglikelihood_rolling and greedy_until
raise NotImplementedError()
def tok_decode(self, tokens):
# Isn't used because we override loglikelihood, loglikelihood_rolling and greedy_until
raise NotImplementedError()
def loglikelihood(self, requests):
res = []
for context, continuation in tqdm(requests):
response = textsynth_completion(
url=self.api_url + "/v1/engines/" + self.engine + "/logprob",
headers={"Authorization": "Bearer " + self.api_key},
json={"context": context, "continuation": continuation},
)
resp = response.json()
if "logprob" in resp:
logprob = resp["logprob"]
is_greedy = resp["is_greedy"]
res.append((logprob, is_greedy))
else:
logger.error(
f"The following response does not contain `logprobs`. Got:\n{resp}"
)
assert False
return res
def loglikelihood_rolling(self, requests):
# TODO: The TextSynth API does not support tokenized inputs so we cannot
# manually partition long contexts into smaller rolling windows as
# done for other models derived from `BaseLM`. Override this method
# with a windowing scheme that works for direct string inputs.
raise NotImplementedError(
"`loglikelihood_rolling` is currently not supported due to lack of "
"input tokenization support from TextSynth."
)
def greedy_until(self, requests):
if not requests:
return []
res = []
for request in tqdm(requests):
inp = request[0]
request_args = request[1]
until = request_args["until"]
response = textsynth_completion(
url=self.api_url + "/v1/engines/" + self.engine + "/completions",
headers={"Authorization": "Bearer " + self.api_key},
json={
"prompt": inp,
"max_tokens": self.max_gen_toks,
"top_k": 1,
"stop": until,
},
)
resp = response.json()
if "text" in resp:
s = resp["text"]
res.append(s)
else:
logger.error(
f"The following response does not contain generated `text`. "
"Got:\n{resp}"
)
assert False
return res
def _model_call(self, inps):
# Isn't used because we override _loglikelihood_tokens
raise NotImplementedError()
def _model_generate(self, context, max_length, eos_token_id):
# Isn't used because we override greedy_until
raise NotImplementedError()
import os
import collections
from pprint import pprint
from typing import List, Union, Optional
import sacrebleu
import lm_eval.base
from . import superglue
from . import glue
from . import arc
from . import coqa
from . import race
from . import webqs
from . import anli
from . import wsc273
from . import winogrande
from . import quac
from . import hellaswag
from . import swag
from . import openbookqa
from . import squad
from . import naturalqs
from . import sat
from . import arithmetic
from . import lambada
from . import piqa
from . import prost
from . import mc_taco
from . import triviaqa
from . import pubmedqa
from . import sciq
from . import qasper
from . import qa4mre
from . import translation
from . import headqa
from . import mathqa
from . import hendrycks_ethics
from . import drop
from . import unscramble
from . import logiqa
from . import hendrycks_test
from . import hendrycks_math
from . import cbt
from . import lambada_cloze
from . import pile
from . import wikitext
from . import lambada_multilingual
from . import mutual
from . import truthfulqa
from . import blimp
from . import asdiv
from . import gsm8k
from . import storycloze
from . import toxigen
from . import crowspairs
from . import json
from . import xcopa
from . import bigbench
from . import xstorycloze
from . import xwinograd
from . import pawsx
from . import xnli
from . import mgsm
########################################
# Translation tasks
########################################
# 6 total
gpt3_translation_benchmarks = {
"wmt14": ["en-fr", "fr-en"], # French
"wmt16": ["en-ro", "ro-en", "de-en", "en-de"], # German, Romanian
}
# 28 total
selected_translation_benchmarks = {
**gpt3_translation_benchmarks,
"wmt20": sacrebleu.get_langpairs_for_testset("wmt20"),
"iwslt17": ["en-ar", "ar-en"], # Arabic
}
# 319 total
all_translation_benchmarks = {
ts: sacrebleu.get_langpairs_for_testset(ts)
for ts in sacrebleu.get_available_testsets()
}
########################################
# All tasks
########################################
TASK_REGISTRY = {
# GLUE
"cola": glue.CoLA,
"mnli": glue.MNLI,
"mnli_mismatched": glue.MNLIMismatched,
"mrpc": glue.MRPC,
"rte": glue.RTE,
"qnli": glue.QNLI,
"qqp": glue.QQP,
# "stsb": glue.STSB, # not implemented yet
"sst": glue.SST,
"wnli": glue.WNLI,
# SuperGLUE
"boolq": superglue.BoolQ,
"cb": superglue.CommitmentBank,
"copa": superglue.Copa,
"multirc": superglue.MultiRC,
"record": superglue.ReCoRD,
"wic": superglue.WordsInContext,
"wsc": superglue.SGWinogradSchemaChallenge,
# Order by benchmark/genre?
"coqa": coqa.CoQA,
"drop": drop.DROP,
"lambada_openai": lambada.LambadaOpenAI,
"lambada_standard": lambada.LambadaStandard,
"lambada_openai_cloze": lambada_cloze.LambadaOpenAICloze,
"lambada_standard_cloze": lambada_cloze.LambadaStandardCloze,
# multilingual lambada
**lambada_multilingual.construct_tasks(),
"wikitext": wikitext.WikiText,
# "cbt-cn": cbt.CBTCN, # disabled pending context length fix
# "cbt-ne": cbt.CBTNE, # disabled pending context length fix
"piqa": piqa.PiQA,
"prost": prost.PROST,
"mc_taco": mc_taco.MCTACO,
# Science related
"pubmedqa": pubmedqa.Pubmed_QA,
"sciq": sciq.SciQ,
"qasper": qasper.QASPER,
"qa4mre_2011": qa4mre.QA4MRE_2011,
"qa4mre_2012": qa4mre.QA4MRE_2012,
"qa4mre_2013": qa4mre.QA4MRE_2013,
"triviaqa": triviaqa.TriviaQA,
"arc_easy": arc.ARCEasy,
"arc_challenge": arc.ARCChallenge,
# "quac": quac.QuAC, # not implemented yet
"logiqa": logiqa.LogiQA,
"hellaswag": hellaswag.HellaSwag,
"swag": swag.SWAG,
"openbookqa": openbookqa.OpenBookQA,
"squad2": squad.SQuAD2,
"race": race.RACE,
# "naturalqs": naturalqs.NaturalQs, # not implemented yet
"headqa": headqa.HeadQAEsDeprecated, # for backwards compat - headqa used to default to es
"headqa_es": headqa.HeadQAEs,
"headqa_en": headqa.HeadQAEn,
"mathqa": mathqa.MathQA,
"webqs": webqs.WebQs,
"wsc273": wsc273.WinogradSchemaChallenge273,
"winogrande": winogrande.Winogrande,
"anli_r1": anli.ANLIRound1,
"anli_r2": anli.ANLIRound2,
"anli_r3": anli.ANLIRound3,
"ethics_cm": hendrycks_ethics.EthicsCM,
"ethics_deontology": hendrycks_ethics.EthicsDeontology,
"ethics_justice": hendrycks_ethics.EthicsJustice,
"ethics_utilitarianism_original": hendrycks_ethics.EthicsUtilitarianismOriginal,
"ethics_utilitarianism": hendrycks_ethics.EthicsUtilitarianism,
"ethics_virtue": hendrycks_ethics.EthicsVirtue,
"truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice,
"truthfulqa_gen": truthfulqa.TruthfulQAGeneration,
# dialogue
"mutual": mutual.MuTual,
"mutual_plus": mutual.MuTualPlus,
# math
"math_algebra": hendrycks_math.MathAlgebra,
"math_counting_and_prob": hendrycks_math.MathCountingAndProbability,
"math_geometry": hendrycks_math.MathGeometry,
"math_intermediate_algebra": hendrycks_math.MathIntermediateAlgebra,
"math_num_theory": hendrycks_math.MathNumberTheory,
"math_prealgebra": hendrycks_math.MathPrealgebra,
"math_precalc": hendrycks_math.MathPrecalculus,
"math_asdiv": asdiv.Asdiv,
"gsm8k": gsm8k.GradeSchoolMath8K,
# arithmetic
"arithmetic_2da": arithmetic.Arithmetic2DPlus,
"arithmetic_2ds": arithmetic.Arithmetic2DMinus,
"arithmetic_3da": arithmetic.Arithmetic3DPlus,
"arithmetic_3ds": arithmetic.Arithmetic3DMinus,
"arithmetic_4da": arithmetic.Arithmetic4DPlus,
"arithmetic_4ds": arithmetic.Arithmetic4DMinus,
"arithmetic_5da": arithmetic.Arithmetic5DPlus,
"arithmetic_5ds": arithmetic.Arithmetic5DMinus,
"arithmetic_2dm": arithmetic.Arithmetic2DMultiplication,
"arithmetic_1dc": arithmetic.Arithmetic1DComposite,
# TODO Perhaps make these groups of tasks
# e.g. anli, arithmetic, openai_translations, harness_translations
# hendrycksTest (57 tasks)
**hendrycks_test.create_all_tasks(),
# e.g. wmt14-fr-en
**translation.create_tasks_from_benchmarks(gpt3_translation_benchmarks),
# chef's selection, mostly wmt20
**translation.create_tasks_from_benchmarks(selected_translation_benchmarks),
# Word Scrambling and Manipulation Tasks
"anagrams1": unscramble.Anagrams1,
"anagrams2": unscramble.Anagrams2,
"cycle_letters": unscramble.CycleLetters,
"random_insertion": unscramble.RandomInsertion,
"reversed_words": unscramble.ReversedWords,
# Pile
"pile_arxiv": pile.PileArxiv,
"pile_books3": pile.PileBooks3,
"pile_bookcorpus2": pile.PileBookCorpus2,
"pile_dm-mathematics": pile.PileDmMathematics,
"pile_enron": pile.PileEnron,
"pile_europarl": pile.PileEuroparl,
"pile_freelaw": pile.PileFreeLaw,
"pile_github": pile.PileGithub,
"pile_gutenberg": pile.PileGutenberg,
"pile_hackernews": pile.PileHackernews,
"pile_nih-exporter": pile.PileNIHExporter,
"pile_opensubtitles": pile.PileOpenSubtitles,
"pile_openwebtext2": pile.PileOpenWebText2,
"pile_philpapers": pile.PilePhilPapers,
"pile_pile-cc": pile.PilePileCc,
"pile_pubmed-abstracts": pile.PilePubmedAbstracts,
"pile_pubmed-central": pile.PilePubmedCentral,
"pile_stackexchange": pile.PileStackExchange,
"pile_uspto": pile.PileUspto,
"pile_ubuntu-irc": pile.PileUbuntuIrc,
"pile_wikipedia": pile.PileWikipedia,
"pile_youtubesubtitles": pile.PileYoutubeSubtitles,
# BLiMP
"blimp_adjunct_island": blimp.BlimpAdjunctIsland,
"blimp_anaphor_gender_agreement": blimp.BlimpAnaphorGenderAgreement,
"blimp_anaphor_number_agreement": blimp.BlimpAnaphorNumberAgreement,
"blimp_animate_subject_passive": blimp.BlimpAnimateSubjectPassive,
"blimp_animate_subject_trans": blimp.BlimpAnimateSubjectTrans,
"blimp_causative": blimp.BlimpCausative,
"blimp_complex_NP_island": blimp.BlimpComplex_NPIsland,
"blimp_coordinate_structure_constraint_complex_left_branch": blimp.BlimpCoordinateStructureConstraintComplexLeftBranch,
"blimp_coordinate_structure_constraint_object_extraction": blimp.BlimpCoordinateStructureConstraintObjectExtraction,
"blimp_determiner_noun_agreement_1": blimp.BlimpDeterminerNounAgreement_1,
"blimp_determiner_noun_agreement_2": blimp.BlimpDeterminerNounAgreement_2,
"blimp_determiner_noun_agreement_irregular_1": blimp.BlimpDeterminerNounAgreementIrregular_1,
"blimp_determiner_noun_agreement_irregular_2": blimp.BlimpDeterminerNounAgreementIrregular_2,
"blimp_determiner_noun_agreement_with_adj_2": blimp.BlimpDeterminerNounAgreementWithAdj_2,
"blimp_determiner_noun_agreement_with_adj_irregular_1": blimp.BlimpDeterminerNounAgreementWithAdjIrregular_1,
"blimp_determiner_noun_agreement_with_adj_irregular_2": blimp.BlimpDeterminerNounAgreementWithAdjIrregular_2,
"blimp_determiner_noun_agreement_with_adjective_1": blimp.BlimpDeterminerNounAgreementWithAdjective_1,
"blimp_distractor_agreement_relational_noun": blimp.BlimpDistractorAgreementRelationalNoun,
"blimp_distractor_agreement_relative_clause": blimp.BlimpDistractorAgreementRelativeClause,
"blimp_drop_argument": blimp.BlimpDropArgument,
"blimp_ellipsis_n_bar_1": blimp.BlimpEllipsisNBar_1,
"blimp_ellipsis_n_bar_2": blimp.BlimpEllipsisNBar_2,
"blimp_existential_there_object_raising": blimp.BlimpExistentialThereObjectRaising,
"blimp_existential_there_quantifiers_1": blimp.BlimpExistentialThereQuantifiers_1,
"blimp_existential_there_quantifiers_2": blimp.BlimpExistentialThereQuantifiers_2,
"blimp_existential_there_subject_raising": blimp.BlimpExistentialThereSubjectRaising,
"blimp_expletive_it_object_raising": blimp.BlimpExpletiveItObjectRaising,
"blimp_inchoative": blimp.BlimpInchoative,
"blimp_intransitive": blimp.BlimpIntransitive,
"blimp_irregular_past_participle_adjectives": blimp.BlimpIrregularPastParticipleAdjectives,
"blimp_irregular_past_participle_verbs": blimp.BlimpIrregularPastParticipleVerbs,
"blimp_irregular_plural_subject_verb_agreement_1": blimp.BlimpIrregularPluralSubjectVerbAgreement_1,
"blimp_irregular_plural_subject_verb_agreement_2": blimp.BlimpIrregularPluralSubjectVerbAgreement_2,
"blimp_left_branch_island_echo_question": blimp.BlimpLeftBranchIslandEchoQuestion,
"blimp_left_branch_island_simple_question": blimp.BlimpLeftBranchIslandSimpleQuestion,
"blimp_matrix_question_npi_licensor_present": blimp.BlimpMatrixQuestionNpiLicensorPresent,
"blimp_npi_present_1": blimp.BlimpNpiPresent_1,
"blimp_npi_present_2": blimp.BlimpNpiPresent_2,
"blimp_only_npi_licensor_present": blimp.BlimpOnlyNpiLicensorPresent,
"blimp_only_npi_scope": blimp.BlimpOnlyNpiScope,
"blimp_passive_1": blimp.BlimpPassive_1,
"blimp_passive_2": blimp.BlimpPassive_2,
"blimp_principle_A_c_command": blimp.BlimpPrinciple_ACCommand,
"blimp_principle_A_case_1": blimp.BlimpPrinciple_ACase_1,
"blimp_principle_A_case_2": blimp.BlimpPrinciple_ACase_2,
"blimp_principle_A_domain_1": blimp.BlimpPrinciple_ADomain_1,
"blimp_principle_A_domain_2": blimp.BlimpPrinciple_ADomain_2,
"blimp_principle_A_domain_3": blimp.BlimpPrinciple_ADomain_3,
"blimp_principle_A_reconstruction": blimp.BlimpPrinciple_AReconstruction,
"blimp_regular_plural_subject_verb_agreement_1": blimp.BlimpRegularPluralSubjectVerbAgreement_1,
"blimp_regular_plural_subject_verb_agreement_2": blimp.BlimpRegularPluralSubjectVerbAgreement_2,
"blimp_sentential_negation_npi_licensor_present": blimp.BlimpSententialNegationNpiLicensorPresent,
"blimp_sentential_negation_npi_scope": blimp.BlimpSententialNegationNpiScope,
"blimp_sentential_subject_island": blimp.BlimpSententialSubjectIsland,
"blimp_superlative_quantifiers_1": blimp.BlimpSuperlativeQuantifiers_1,
"blimp_superlative_quantifiers_2": blimp.BlimpSuperlativeQuantifiers_2,
"blimp_tough_vs_raising_1": blimp.BlimpToughVsRaising_1,
"blimp_tough_vs_raising_2": blimp.BlimpToughVsRaising_2,
"blimp_transitive": blimp.BlimpTransitive,
"blimp_wh_island": blimp.BlimpWhIsland,
"blimp_wh_questions_object_gap": blimp.BlimpWhQuestionsObjectGap,
"blimp_wh_questions_subject_gap": blimp.BlimpWhQuestionsSubjectGap,
"blimp_wh_questions_subject_gap_long_distance": blimp.BlimpWhQuestionsSubjectGapLongDistance,
"blimp_wh_vs_that_no_gap": blimp.BlimpWhVsThatNoGap,
"blimp_wh_vs_that_no_gap_long_distance": blimp.BlimpWhVsThatNoGapLongDistance,
"blimp_wh_vs_that_with_gap": blimp.BlimpWhVsThatWithGap,
"blimp_wh_vs_that_with_gap_long_distance": blimp.BlimpWhVsThatWithGapLongDistance,
"toxigen": toxigen.ToxiGen,
"crows_pairs_english": crowspairs.CrowsPairsEnglish,
"crows_pairs_english_race_color": crowspairs.CrowsPairsEnglishRaceColor,
"crows_pairs_english_socioeconomic": crowspairs.CrowsPairsEnglishSocioeconomic,
"crows_pairs_english_gender": crowspairs.CrowsPairsEnglishGender,
"crows_pairs_english_age": crowspairs.CrowsPairsEnglishAge,
"crows_pairs_english_religion": crowspairs.CrowsPairsEnglishReligion,
"crows_pairs_english_disability": crowspairs.CrowsPairsEnglishDisability,
"crows_pairs_english_sexual_orientation": crowspairs.CrowsPairsEnglishSexualOrientation,
"crows_pairs_english_nationality": crowspairs.CrowsPairsEnglishNationality,
"crows_pairs_english_physical_appearance": crowspairs.CrowsPairsEnglishPhysicalAppearance,
"crows_pairs_english_autre": crowspairs.CrowsPairsEnglishAutre,
"crows_pairs_french": crowspairs.CrowsPairsFrench,
"crows_pairs_french_race_color": crowspairs.CrowsPairsFrenchRaceColor,
"crows_pairs_french_socioeconomic": crowspairs.CrowsPairsFrenchSocioeconomic,
"crows_pairs_french_gender": crowspairs.CrowsPairsFrenchGender,
"crows_pairs_french_age": crowspairs.CrowsPairsFrenchAge,
"crows_pairs_french_religion": crowspairs.CrowsPairsFrenchReligion,
"crows_pairs_french_disability": crowspairs.CrowsPairsFrenchDisability,
"crows_pairs_french_sexual_orientation": crowspairs.CrowsPairsFrenchSexualOrientation,
"crows_pairs_french_nationality": crowspairs.CrowsPairsFrenchNationality,
"crows_pairs_french_physical_appearance": crowspairs.CrowsPairsFrenchPhysicalAppearance,
"crows_pairs_french_autre": crowspairs.CrowsPairsFrenchAutre,
# Requires manual download of data.
# "storycloze_2016": storycloze.StoryCloze2016,
# "storycloze_2018": storycloze.StoryCloze2018,
# "sat": sat.SATAnalogies,
**xcopa.construct_tasks(),
**bigbench.create_all_tasks(),
**xstorycloze.create_all_tasks(),
**xwinograd.create_all_tasks(),
**pawsx.construct_tasks(),
**xnli.construct_tasks(),
**mgsm.construct_tasks(),
}
ALL_TASKS = sorted(list(TASK_REGISTRY))
_EXAMPLE_JSON_PATH = "split:key:/absolute/path/to/data.json"
def add_json_task(task_name):
"""Add a JSON perplexity task if the given task name matches the
JSON task specification.
See `json.JsonPerplexity`.
"""
if not task_name.startswith("json"):
return
def create_json_task():
splits = task_name.split("=", 1)
if len(splits) != 2 or not splits[1]:
raise ValueError(
"json tasks need a path argument pointing to the local "
"dataset, specified like this: json="
+ _EXAMPLE_JSON_PATH
+ ' (if there are no splits, use "train")'
)
json_path = splits[1]
if json_path == _EXAMPLE_JSON_PATH:
raise ValueError(
"please do not copy the example path directly, but substitute "
"it with a path to your local dataset"
)
return lambda: json.JsonPerplexity(json_path)
TASK_REGISTRY[task_name] = create_json_task()
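# Illustrative only: `add_json_task` registers a zero-argument factory in
# TASK_REGISTRY (`create_json_task()` itself returns a lambda), so resolving
# and instantiating a JSON task looks like this (the path is a placeholder):
#   add_json_task("json=train:text:/path/to/data.json")
#   task = TASK_REGISTRY["json=train:text:/path/to/data.json"]()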
def get_task(task_name):
try:
add_json_task(task_name)
return TASK_REGISTRY[task_name]
except KeyError:
print("Available tasks:")
pprint(TASK_REGISTRY)
raise KeyError(f"Missing task {task_name}")
def get_task_name_from_object(task_object):
for name, class_ in TASK_REGISTRY.items():
if class_ is task_object:
return name
    # This gives non-registered tasks a mechanism to have a custom name when reporting.
return (
task_object.EVAL_HARNESS_NAME
if hasattr(task_object, "EVAL_HARNESS_NAME")
else type(task_object).__name__
)
def get_task_dict(task_name_list: List[Union[str, lm_eval.base.Task]]):
task_name_dict = {
task_name: get_task(task_name)()
for task_name in task_name_list
if isinstance(task_name, str)
}
task_name_from_object_dict = {
get_task_name_from_object(task_object): task_object
for task_object in task_name_list
if not isinstance(task_object, str)
}
assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys()))
return {**task_name_dict, **task_name_from_object_dict}
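# Usage sketch (task names are illustrative): `get_task_dict` accepts a mix of
# registered task names and pre-built Task objects, e.g.
#   task_dict = get_task_dict(["piqa", "arc_easy"])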
# Added from the original lm-evaluation-harness implementation.
def _config_is_task(config) -> bool:
    return ("task" in config) and isinstance(config["task"], str)
def _config_is_group(config) -> bool:
    return ("task" in config) and isinstance(config["task"], list)
def _config_is_python_task(config) -> bool:
    return "class" in config
def _get_task_and_group(task_dir: str):
"""Creates a dictionary of tasks index with the following metadata,
- `type`, that can be either `task`, `python_task`, or `group`.
`task` refer to regular task configs, `python_task` are special
yaml files that only consists of `task` and `class` parameters.
`group` are group configs.
- `yaml_path`, path to the yaml file. If the entry is a `group` that
was configured through a task config, the yaml_path will be -1
and all subtasks will be listed in `task` (see below)
- `task`, reserved for entries with `type` as `group`. This will list
all subtasks. When a group config is created (as opposed to task
config having `group` parameter set), this will be set to -1 to
avoid recursive indexing. The whole list of subtasks will be loaded
at evaluation.
:param task_dir: str
A directory to check for tasks
:return
Dictionary of task names as key and task metadata
"""
tasks_and_groups = collections.defaultdict()
for root, _, file_list in os.walk(task_dir):
for f in file_list:
if f.endswith(".yaml"):
yaml_path = os.path.join(root, f)
config = utils.load_yaml_config(yaml_path, mode="simple")
if _config_is_python_task(config):
# This is a python class config
tasks_and_groups[config["task"]] = {
"type": "python_task",
"yaml_path": yaml_path,
}
elif _config_is_group(config):
# This is a group config
tasks_and_groups[config["group"]] = {
"type": "group",
"task": -1, # This signals that
# we don't need to know
# the task list for indexing
# as it can be loaded
# when called.
"yaml_path": yaml_path,
}
elif _config_is_task(config):
# This is a task config
task = config["task"]
tasks_and_groups[task] = {
"type": "task",
"yaml_path": yaml_path,
}
if "group" in config:
groups = config["group"]
if isinstance(config["group"], str):
groups = [groups]
for group in groups:
if group not in tasks_and_groups:
tasks_and_groups[group] = {
"type": "group",
"task": [task],
"yaml_path": -1,
}
else:
tasks_and_groups[group]["task"].append(task)
else:
print(f"File {f} in {root} could not be loaded")
return tasks_and_groups
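# For reference, a hypothetical index for a directory containing one task
# config tagged with a group might look like:
#   {
#       "my_task": {"type": "task", "yaml_path": "/tasks/my_task.yaml"},
#       "my_group": {"type": "group", "task": ["my_task"], "yaml_path": -1},
#   }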
def initialize_tasks(include_path: Optional[str] = None):
"""Creates a dictionary of tasks index.
:param include_path: str = None
An additional path to be searched for tasks
:return
Dictionary of task names as key and task metadata
"""
all_paths = [os.path.dirname(os.path.abspath(__file__)) + "/"]
if include_path is not None:
if isinstance(include_path, str):
include_path = [include_path]
all_paths.extend(include_path)
task_index = {}
for task_dir in all_paths:
tasks = _get_task_and_group(task_dir)
task_index = {**tasks, **task_index}
return task_index
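# Minimal usage sketch (the extra directory is a placeholder): index the
# built-in task directory plus one user-supplied directory.
#   task_index = initialize_tasks(include_path="/home/me/my_yaml_tasks")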
"""
Adversarial NLI: A New Benchmark for Natural Language Understanding
https://arxiv.org/pdf/1910.14599.pdf
Adversarial NLI (ANLI) is a dataset collected via an iterative, adversarial
human-and-model-in-the-loop procedure. It consists of three rounds that progressively
increase in difficulty and complexity, and each question-answer includes annotator-
provided explanations.
Homepage: "https://github.com/facebookresearch/anli"
"""
import numpy as np
from lm_eval.base import rf, Task
from lm_eval.metrics import mean
_CITATION = """
@inproceedings{nie-etal-2020-adversarial,
title = "Adversarial {NLI}: A New Benchmark for Natural Language Understanding",
author = "Nie, Yixin and
Williams, Adina and
Dinan, Emily and
Bansal, Mohit and
Weston, Jason and
Kiela, Douwe",
booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics",
year = "2020",
publisher = "Association for Computational Linguistics",
}
"""
class ANLIBase(Task):
VERSION = 0
DATASET_PATH = "anli"
DATASET_NAME = None
SPLIT = None
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def training_docs(self):
if self.has_training_docs():
if self._training_docs is None:
self._training_docs = list(self.dataset["train_r" + str(self.SPLIT)])
return self._training_docs
def validation_docs(self):
if self.has_validation_docs():
return self.dataset["dev_r" + str(self.SPLIT)]
def test_docs(self):
if self.has_test_docs():
return self.dataset["test_r" + str(self.SPLIT)]
def doc_to_text(self, doc):
# OA does this a bit weirdly: they prepend "anli 1: anli 1: " to the beginning
# of the prompt (yes, repeating it!). also, " True, False, or Neither?" is directly
# appended onto the question, with no "Answer:" or even a newline. Do we *really*
# want to do it exactly as OA did?
return (
doc["premise"]
+ "\nQuestion: "
+ doc["hypothesis"]
+ " True, False, or Neither?\nAnswer:"
)
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["premise"]
def doc_to_target(self, doc):
# True = entailment
# False = contradiction
# Neither = neutral
return " " + ["True", "Neither", "False"][doc["label"]]
def construct_requests(self, doc, ctx):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
ll_true, _ = rf.loglikelihood(ctx, " True")
ll_neither, _ = rf.loglikelihood(ctx, " Neither")
ll_false, _ = rf.loglikelihood(ctx, " False")
return ll_true, ll_neither, ll_false
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param results:
The results of the requests created in construct_requests.
"""
gold = doc["label"]
pred = np.argmax(results)
return {"acc": pred == gold}
def aggregation(self):
"""
:returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return {"acc": mean}
def higher_is_better(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return {"acc": True}
class ANLIRound1(ANLIBase):
SPLIT = 1
class ANLIRound2(ANLIBase):
SPLIT = 2
class ANLIRound3(ANLIBase):
SPLIT = 3
"""
Think you have Solved Question Answering? Try ARC, the AI2 Reasoning Challenge
https://arxiv.org/pdf/1803.05457.pdf
The ARC dataset consists of 7,787 science exam questions drawn from a variety
of sources, including science questions provided under license by a research
partner affiliated with AI2. These are text-only, English language exam questions
that span several grade levels as indicated in the files. Each question has a
multiple choice structure (typically 4 answer options). The questions are sorted
into a Challenge Set of 2,590 “hard” questions (those that both a retrieval and
a co-occurrence method fail to answer correctly) and an Easy Set of 5,197 questions.
Homepage: https://allenai.org/data/arc
"""
from lm_eval.base import MultipleChoiceTask
_CITATION = """
@article{Clark2018ThinkYH,
title={Think you have Solved Question Answering? Try ARC, the AI2 Reasoning Challenge},
author={Peter Clark and Isaac Cowhey and Oren Etzioni and Tushar Khot and Ashish Sabharwal and Carissa Schoenick and Oyvind Tafjord},
journal={ArXiv},
year={2018},
volume={abs/1803.05457}
}
"""
class ARCEasy(MultipleChoiceTask):
VERSION = 0
DATASET_PATH = "ai2_arc"
DATASET_NAME = "ARC-Easy"
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(map(self._process_doc, self.dataset["train"]))
return self._training_docs
def validation_docs(self):
return map(self._process_doc, self.dataset["validation"])
def test_docs(self):
return map(self._process_doc, self.dataset["test"])
def _process_doc(self, doc):
# NOTE: Some `doc["answerKey"]`s are in numeric string format being one
# of {'1', '2', '3', '4', '5'}. We map them back to letters.
num_to_letter = {"1": "A", "2": "B", "3": "C", "4": "D", "5": "E"}
doc["answerKey"] = num_to_letter.get(doc["answerKey"], doc["answerKey"])
out_doc = {
"id": doc["id"],
"query": "Question: " + doc["question"] + "\nAnswer:",
"choices": doc["choices"]["text"],
"gold": ["A", "B", "C", "D", "E"].index(doc["answerKey"]),
}
return out_doc
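    # e.g. an `answerKey` of "3" with choices ["red", "green", "blue", "gray"]
    # is remapped to "C", yielding gold index 2 (values are illustrative).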
def doc_to_text(self, doc):
return doc["query"]
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["query"]
class ARCChallenge(ARCEasy):
DATASET_PATH = "ai2_arc"
DATASET_NAME = "ARC-Challenge"
"""
Language Models are Few-Shot Learners
https://arxiv.org/pdf/2005.14165.pdf
A small battery of 10 tests that involve asking language models a simple arithmetic
problem in natural language.
Homepage: https://github.com/openai/gpt-3/tree/master/data
"""
from lm_eval.base import Task, rf
from lm_eval.metrics import mean
_CITATION = """
@inproceedings{NEURIPS2020_1457c0d6,
author = {Brown, Tom and Mann, Benjamin and Ryder, Nick and Subbiah, Melanie and Kaplan, Jared D and Dhariwal, Prafulla and Neelakantan, Arvind and Shyam, Pranav and Sastry, Girish and Askell, Amanda and Agarwal, Sandhini and Herbert-Voss, Ariel and Krueger, Gretchen and Henighan, Tom and Child, Rewon and Ramesh, Aditya and Ziegler, Daniel and Wu, Jeffrey and Winter, Clemens and Hesse, Chris and Chen, Mark and Sigler, Eric and Litwin, Mateusz and Gray, Scott and Chess, Benjamin and Clark, Jack and Berner, Christopher and McCandlish, Sam and Radford, Alec and Sutskever, Ilya and Amodei, Dario},
booktitle = {Advances in Neural Information Processing Systems},
editor = {H. Larochelle and M. Ranzato and R. Hadsell and M. F. Balcan and H. Lin},
pages = {1877--1901},
publisher = {Curran Associates, Inc.},
title = {Language Models are Few-Shot Learners},
url = {https://proceedings.neurips.cc/paper/2020/file/1457c0d6bfcb4967418bfb8ac142f64a-Paper.pdf},
volume = {33},
year = {2020}
}
"""
class Arithmetic(Task):
VERSION = 0
DATASET_PATH = "EleutherAI/arithmetic"
def has_training_docs(self):
return False
def has_validation_docs(self):
return True
def has_test_docs(self):
return False
def training_docs(self):
        raise NotImplementedError()
def validation_docs(self):
return self.dataset["validation"]
def test_docs(self):
        raise NotImplementedError()
def doc_to_text(self, doc):
return doc["context"]
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["context"]
def doc_to_target(self, doc):
return doc["completion"]
def construct_requests(self, doc, ctx):
ll, is_prediction = rf.loglikelihood(ctx, doc["completion"])
return is_prediction
def process_results(self, doc, results):
(is_prediction,) = results
return {"acc": is_prediction}
def aggregation(self):
return {
"acc": mean,
}
def higher_is_better(self):
return {"acc": True}
class Arithmetic2DPlus(Arithmetic):
DATASET_NAME = "arithmetic_2da"
class Arithmetic2DMinus(Arithmetic):
DATASET_NAME = "arithmetic_2ds"
class Arithmetic3DPlus(Arithmetic):
DATASET_NAME = "arithmetic_3da"
class Arithmetic3DMinus(Arithmetic):
DATASET_NAME = "arithmetic_3ds"
class Arithmetic4DPlus(Arithmetic):
DATASET_NAME = "arithmetic_4da"
class Arithmetic4DMinus(Arithmetic):
DATASET_NAME = "arithmetic_4ds"
class Arithmetic5DPlus(Arithmetic):
DATASET_NAME = "arithmetic_5da"
class Arithmetic5DMinus(Arithmetic):
DATASET_NAME = "arithmetic_5ds"
class Arithmetic2DMultiplication(Arithmetic):
DATASET_NAME = "arithmetic_2dm"
class Arithmetic1DComposite(Arithmetic):
DATASET_NAME = "arithmetic_1dc"
"""
ASDiv: A Diverse Corpus for Evaluating and Developing English Math Word Problem Solvers
https://arxiv.org/abs/2106.15772
ASDiv (Academia Sinica Diverse MWP Dataset) is a diverse (in terms of both language
patterns and problem types) English math word problem (MWP) corpus for evaluating
the capability of various MWP solvers. Existing MWP corpora for studying AI progress
remain limited either in language usage patterns or in problem types. We thus present
a new English MWP corpus with 2,305 MWPs that cover more text patterns and most problem
types taught in elementary school. Each MWP is annotated with its problem type and grade
level (for indicating the level of difficulty).
NOTE: We currently ignore formulas for answer generation.
Homepage: https://github.com/chaochun/nlu-asdiv-dataset
"""
import inspect
import lm_eval.datasets.asdiv.asdiv
from lm_eval.base import rf, Task
from lm_eval.metrics import mean
_CITATION = """
@misc{miao2021diverse,
title={A Diverse Corpus for Evaluating and Developing English Math Word Problem Solvers},
author={Shen-Yun Miao and Chao-Chun Liang and Keh-Yih Su},
year={2021},
eprint={2106.15772},
archivePrefix={arXiv},
primaryClass={cs.AI}
}
"""
class Asdiv(Task):
VERSION = 0
DATASET_PATH = inspect.getfile(lm_eval.datasets.asdiv.asdiv)
def has_training_docs(self):
return False
def has_validation_docs(self):
return True
def has_test_docs(self):
return False
def training_docs(self):
raise NotImplementedError("This dataset has no training docs")
def validation_docs(self):
return self.dataset["validation"]
def test_docs(self):
raise NotImplementedError("This dataset has no test docs")
def fewshot_context(
self, doc, num_fewshot, provide_description=None, rnd=None, description=None
):
assert num_fewshot == 0, "ASDiv is intended only for the zero-shot setting."
return super().fewshot_context(
doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
)
def doc_to_text(self, doc):
# TODO: add solution-type
return doc["body"] + "\n" + "Question:" + doc["question"] + "\n" + "Answer:"
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["body"] + " " + doc["question"]
def doc_to_target(self, doc):
# TODO: add formula
answer = doc["answer"].split(" (")[0]
return " " + answer
def construct_requests(self, doc, ctx):
ll, is_greedy = rf.loglikelihood(ctx, self.doc_to_target(doc))
return ll, is_greedy
def process_results(self, doc, results):
ll, is_greedy = results
return {"acc": int(is_greedy)}
def aggregation(self):
return {"acc": mean}
def higher_is_better(self):
return {"acc": True}
"""
Tasks missing from BIG-bench-hard:
programmatic - boolean_expressions, web of lies, multistep_arithmetic
"""
import os
import json
import hashlib
import functools
import numpy as np
import re
import importlib.resources
from lm_eval.base import rf, Task
from lm_eval.metrics import mean
_CITATION = """
@misc{srivastava2022imitation,
title={Beyond the Imitation Game: Quantifying and extrapolating the capabilities of language models},
author={Aarohi Srivastava and Abhinav Rastogi and Abhishek Rao and Abu Awal Md Shoeb and Abubakar Abid and Adam Fisch and Adam R. Brown and Adam Santoro and Aditya Gupta and Adrià Garriga-Alonso and Agnieszka Kluska and Aitor Lewkowycz and Akshat Agarwal and Alethea Power and Alex Ray and Alex Warstadt and Alexander W. Kocurek and Ali Safaya and Ali Tazarv and Alice Xiang and Alicia Parrish and Allen Nie and Aman Hussain and Amanda Askell and Amanda Dsouza and Ambrose Slone and Ameet Rahane and Anantharaman S. Iyer and Anders Andreassen and Andrea Madotto and Andrea Santilli and Andreas Stuhlmüller and Andrew Dai and Andrew La and Andrew Lampinen and Andy Zou and Angela Jiang and Angelica Chen and Anh Vuong and Animesh Gupta and Anna Gottardi and Antonio Norelli and Anu Venkatesh and Arash Gholamidavoodi and Arfa Tabassum and Arul Menezes and Arun Kirubarajan and Asher Mullokandov and Ashish Sabharwal and Austin Herrick and Avia Efrat and Aykut Erdem and Ayla Karakaş and B. Ryan Roberts and Bao Sheng Loe and Barret Zoph and Bartłomiej Bojanowski and Batuhan Özyurt and Behnam Hedayatnia and Behnam Neyshabur and Benjamin Inden and Benno Stein and Berk Ekmekci and Bill Yuchen Lin and Blake Howald and Cameron Diao and Cameron Dour and Catherine Stinson and Cedrick Argueta and César Ferri Ramírez and Chandan Singh and Charles Rathkopf and Chenlin Meng and Chitta Baral and Chiyu Wu and Chris Callison-Burch and Chris Waites and Christian Voigt and Christopher D. Manning and Christopher Potts and Cindy Ramirez and Clara E. Rivera and Clemencia Siro and Colin Raffel and Courtney Ashcraft and Cristina Garbacea and Damien Sileo and Dan Garrette and Dan Hendrycks and Dan Kilman and Dan Roth and Daniel Freeman and Daniel Khashabi and Daniel Levy and Daniel Moseguí González and Danielle Perszyk and Danny Hernandez and Danqi Chen and Daphne Ippolito and Dar Gilboa and David Dohan and David Drakard and David Jurgens and Debajyoti Datta and Deep Ganguli and Denis Emelin and Denis Kleyko and Deniz Yuret and Derek Chen and Derek Tam and Dieuwke Hupkes and Diganta Misra and Dilyar Buzan and Dimitri Coelho Mollo and Diyi Yang and Dong-Ho Lee and Ekaterina Shutova and Ekin Dogus Cubuk and Elad Segal and Eleanor Hagerman and Elizabeth Barnes and Elizabeth Donoway and Ellie Pavlick and Emanuele Rodola and Emma Lam and Eric Chu and Eric Tang and Erkut Erdem and Ernie Chang and Ethan A. Chi and Ethan Dyer and Ethan Jerzak and Ethan Kim and Eunice Engefu Manyasi and Evgenii Zheltonozhskii and Fanyue Xia and Fatemeh Siar and Fernando Martínez-Plumed and Francesca Happé and Francois Chollet and Frieda Rong and Gaurav Mishra and Genta Indra Winata and Gerard de Melo and Germán Kruszewski and Giambattista Parascandolo and Giorgio Mariani and Gloria Wang and Gonzalo Jaimovitch-López and Gregor Betz and Guy Gur-Ari and Hana Galijasevic and Hannah Kim and Hannah Rashkin and Hannaneh Hajishirzi and Harsh Mehta and Hayden Bogar and Henry Shevlin and Hinrich Schütze and Hiromu Yakura and Hongming Zhang and Hugh Mee Wong and Ian Ng and Isaac Noble and Jaap Jumelet and Jack Geissinger and Jackson Kernion and Jacob Hilton and Jaehoon Lee and Jaime Fernández Fisac and James B. 
Simon and James Koppel and James Zheng and James Zou and Jan Kocoń and Jana Thompson and Jared Kaplan and Jarema Radom and Jascha Sohl-Dickstein and Jason Phang and Jason Wei and Jason Yosinski and Jekaterina Novikova and Jelle Bosscher and Jennifer Marsh and Jeremy Kim and Jeroen Taal and Jesse Engel and Jesujoba Alabi and Jiacheng Xu and Jiaming Song and Jillian Tang and Joan Waweru and John Burden and John Miller and John U. Balis and Jonathan Berant and Jörg Frohberg and Jos Rozen and Jose Hernandez-Orallo and Joseph Boudeman and Joseph Jones and Joshua B. Tenenbaum and Joshua S. Rule and Joyce Chua and Kamil Kanclerz and Karen Livescu and Karl Krauth and Karthik Gopalakrishnan and Katerina Ignatyeva and Katja Markert and Kaustubh D. Dhole and Kevin Gimpel and Kevin Omondi and Kory Mathewson and Kristen Chiafullo and Ksenia Shkaruta and Kumar Shridhar and Kyle McDonell and Kyle Richardson and Laria Reynolds and Leo Gao and Li Zhang and Liam Dugan and Lianhui Qin and Lidia Contreras-Ochando and Louis-Philippe Morency and Luca Moschella and Lucas Lam and Lucy Noble and Ludwig Schmidt and Luheng He and Luis Oliveros Colón and Luke Metz and Lütfi Kerem Şenel and Maarten Bosma and Maarten Sap and Maartje ter Hoeve and Maheen Farooqi and Manaal Faruqui and Mantas Mazeika and Marco Baturan and Marco Marelli and Marco Maru and Maria Jose Ramírez Quintana and Marie Tolkiehn and Mario Giulianelli and Martha Lewis and Martin Potthast and Matthew L. Leavitt and Matthias Hagen and Mátyás Schubert and Medina Orduna Baitemirova and Melody Arnaud and Melvin McElrath and Michael A. Yee and Michael Cohen and Michael Gu and Michael Ivanitskiy and Michael Starritt and Michael Strube and Michał Swędrowski and Michele Bevilacqua and Michihiro Yasunaga and Mihir Kale and Mike Cain and Mimee Xu and Mirac Suzgun and Mo Tiwari and Mohit Bansal and Moin Aminnaseri and Mor Geva and Mozhdeh Gheini and Mukund Varma T and Nanyun Peng and Nathan Chi and Nayeon Lee and Neta Gur-Ari Krakover and Nicholas Cameron and Nicholas Roberts and Nick Doiron and Nikita Nangia and Niklas Deckers and Niklas Muennighoff and Nitish Shirish Keskar and Niveditha S. Iyer and Noah Constant and Noah Fiedel and Nuan Wen and Oliver Zhang and Omar Agha and Omar Elbaghdadi and Omer Levy and Owain Evans and Pablo Antonio Moreno Casares and Parth Doshi and Pascale Fung and Paul Pu Liang and Paul Vicol and Pegah Alipoormolabashi and Peiyuan Liao and Percy Liang and Peter Chang and Peter Eckersley and Phu Mon Htut and Pinyu Hwang and Piotr Miłkowski and Piyush Patil and Pouya Pezeshkpour and Priti Oli and Qiaozhu Mei and Qing Lyu and Qinlang Chen and Rabin Banjade and Rachel Etta Rudolph and Raefer Gabriel and Rahel Habacker and Ramón Risco Delgado and Raphaël Millière and Rhythm Garg and Richard Barnes and Rif A. Saurous and Riku Arakawa and Robbe Raymaekers and Robert Frank and Rohan Sikand and Roman Novak and Roman Sitelew and Ronan LeBras and Rosanne Liu and Rowan Jacobs and Rui Zhang and Ruslan Salakhutdinov and Ryan Chi and Ryan Lee and Ryan Stovall and Ryan Teehan and Rylan Yang and Sahib Singh and Saif M. Mohammad and Sajant Anand and Sam Dillavou and Sam Shleifer and Sam Wiseman and Samuel Gruetter and Samuel R. Bowman and Samuel S. Schoenholz and Sanghyun Han and Sanjeev Kwatra and Sarah A. 
Rous and Sarik Ghazarian and Sayan Ghosh and Sean Casey and Sebastian Bischoff and Sebastian Gehrmann and Sebastian Schuster and Sepideh Sadeghi and Shadi Hamdan and Sharon Zhou and Shashank Srivastava and Sherry Shi and Shikhar Singh and Shima Asaadi and Shixiang Shane Gu and Shubh Pachchigar and Shubham Toshniwal and Shyam Upadhyay and Shyamolima and Debnath and Siamak Shakeri and Simon Thormeyer and Simone Melzi and Siva Reddy and Sneha Priscilla Makini and Soo-Hwan Lee and Spencer Torene and Sriharsha Hatwar and Stanislas Dehaene and Stefan Divic and Stefano Ermon and Stella Biderman and Stephanie Lin and Stephen Prasad and Steven T. Piantadosi and Stuart M. Shieber and Summer Misherghi and Svetlana Kiritchenko and Swaroop Mishra and Tal Linzen and Tal Schuster and Tao Li and Tao Yu and Tariq Ali and Tatsu Hashimoto and Te-Lin Wu and Théo Desbordes and Theodore Rothschild and Thomas Phan and Tianle Wang and Tiberius Nkinyili and Timo Schick and Timofei Kornev and Timothy Telleen-Lawton and Titus Tunduny and Tobias Gerstenberg and Trenton Chang and Trishala Neeraj and Tushar Khot and Tyler Shultz and Uri Shaham and Vedant Misra and Vera Demberg and Victoria Nyamai and Vikas Raunak and Vinay Ramasesh and Vinay Uday Prabhu and Vishakh Padmakumar and Vivek Srikumar and William Fedus and William Saunders and William Zhang and Wout Vossen and Xiang Ren and Xiaoyu Tong and Xinran Zhao and Xinyi Wu and Xudong Shen and Yadollah Yaghoobzadeh and Yair Lakretz and Yangqiu Song and Yasaman Bahri and Yejin Choi and Yichi Yang and Yiding Hao and Yifu Chen and Yonatan Belinkov and Yu Hou and Yufang Hou and Yuntao Bai and Zachary Seid and Zhuoye Zhao and Zijian Wang and Zijie J. Wang and Zirui Wang and Ziyi Wu},
year={2022},
eprint={2206.04615},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
"""
_DEFAULT_REGEX = r"[^\.\?\!\;\n]+"
class BigBenchJsonTask(Task):
VERSION = 0
def __init__(self, json_path):
self._random_seed = 42
with open(json_path) as file:
self._task_json = json.load(file)
self._has_multi_choice = "multiple_choice_grade" in self._task_json["metrics"]
self._has_generative = "exact_str_match" in self._task_json["metrics"]
self.output_regex = self._task_json.get("output_regex", None)
self.stop_string = self._task_json.get("stop_string", None)
if self.output_regex is None and self.stop_string is None:
self.output_regex = _DEFAULT_REGEX
# differs from the default 30 when evaluating HF models in the BIG-bench codebase
self.max_length = 128
def has_training_docs(self):
return False
def has_validation_docs(self):
return False
def has_test_docs(self):
return True
def test_docs(self):
return _get_unique_examples(self._task_json["examples"])
def doc_to_text(self, doc):
example_input_prefix = self._task_json.get("example_input_prefix", "\nQ: ")
res = f"{example_input_prefix}{doc['input']}"
rng = np.random.RandomState(seed=self._random_seed)
choice_prefix = self._task_json.get("choice_prefix", "\n choice: ")
append_choices = self._task_json.get("append_choices_to_input", True)
if "target_scores" in doc and append_choices:
choice_dict = doc["target_scores"]
permuted_choices = rng.permutation(sorted(list(choice_dict.keys())))
res = f"{res}{choice_prefix}{choice_prefix.join(permuted_choices)}"
example_output_prefix = self._task_json.get("example_output_prefix", "\nA: ")
res = f"{res}{example_output_prefix}"
return res
def doc_to_target(self, doc):
return max(doc["target_scores"].items(), key=lambda x: x[1])[0]
def _doc_to_queries(self, doc):
if "target_scores" in doc:
return list(doc["target_scores"].keys())
return doc["target"] if isinstance(doc["target"], list) else [doc["target"]]
def construct_requests(self, doc, ctx):
requests = []
if self._has_multi_choice:
queries = self._doc_to_queries(doc)
requests += [
rf.loglikelihood(ctx, continuation)[0] for continuation in queries
]
if self._has_generative:
requests.append(
rf.greedy_until(ctx, {"until": [], "max_length": self.max_length})
)
return requests
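# Note added for clarity (not in the original source): the ordering of the
# requests matters. process_results below assumes the loglikelihood results
# come first and the single generative result, if any, is last
# (results[:-1] vs. results[-1]).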
def process_results(self, doc, results):
res = {}
for metric in self._task_json["metrics"]:
if metric == "multiple_choice_grade":
likelihoods = results[:-1] if self._has_generative else results
queries = self._doc_to_queries(doc)
highest_score_index = _argmax(likelihoods)
highest_score_key = queries[highest_score_index]
res["multiple_choice_grade"] = doc["target_scores"][highest_score_key]
elif metric == "exact_str_match":
postprocessed = _postprocess_output(
results[-1],
max_length=self.max_length,
stop_string=self.stop_string,
output_regex=self.output_regex,
)
res["exact_str_match"] = int(postprocessed == doc["target"])
else:
raise NotImplementedError(f"Metric {metric} isn't implemented")
return res
def aggregation(self):
return {
"multiple_choice_grade": mean,
"exact_str_match": mean,
}
def higher_is_better(self):
return {
"multiple_choice_grade": True,
"exact_str_match": True,
}
@functools.lru_cache()
def _doc_to_few_shot_context(self, shots):
rng = np.random.RandomState(seed=self._random_seed)
res = {}
samples = self.test_docs()
separator = self._task_json.get("few_shot_example_separator", "\n")
for sample in rng.choice(samples, len(samples), replace=False):
valid_samples = [x for x in samples if x != sample]
shot_examples = list(rng.choice(valid_samples, shots, replace=False))
if self._has_multi_choice:
context = separator.join(
[
self.doc_to_text(example)
+ rng.choice(_get_valid_answers(example["target_scores"]))
for example in shot_examples
]
)
else:
context = separator.join(
[
self.doc_to_text(example) + example["target"]
for example in shot_examples
]
)
res[json.dumps(sample)] = context + separator + self.doc_to_text(sample)
return res
def fewshot_context(self, doc, num_fewshot, **kwargs):
if num_fewshot == 0:
res = self.doc_to_text(doc)
else:
res = self._doc_to_few_shot_context(shots=num_fewshot)[json.dumps(doc)]
res = f"{self._task_json.get('task_prefix', '')}{res}"
return res
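# Clarifying note (not in the original source): _doc_to_few_shot_context
# builds the few-shot prompt for every test document in one pass and memoizes
# the whole mapping via functools.lru_cache, keyed only by `shots`;
# fewshot_context then retrieves an individual document's prompt via its JSON
# serialization, which is why both sides use json.dumps(doc) as the dict key.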
def _get_valid_answers(scores):
max_value = max(scores.values())
return [key for key, value in scores.items() if value == max_value]
def _get_unique_examples(examples):
seen_examples, res = set(), []
for example in examples:
example_string = json.dumps(example)
if example_string not in seen_examples:
res.append(example)
seen_examples.add(example_string)
return res
def _argmax(array):
"""argmax with deterministic pseudorandom tie breaking."""
max_indices = np.arange(len(array))[array == np.max(array)]
idx = int(hashlib.sha256(np.asarray(array).tobytes()).hexdigest(), 16) % len(
max_indices
)
return max_indices[idx]
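# Worked example (hypothetical, added for illustration): for
# array = [1.0, 3.0, 3.0] the tied indices are [1, 2]. The SHA-256 digest of
# the array's bytes picks one of them modulo the number of ties, so the same
# input always yields the same winner across runs and processes, unlike
# seeding a global RNG, while plain np.argmax would always return the first
# maximum.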
def _postprocess_output(text, max_length, stop_string, output_regex):
if isinstance(text, list):
return [
_postprocess_output(mo, max_length, stop_string, output_regex)
for mo in text
]
# Ensure it is a string (will convert from bytes, ... as needed)
if not isinstance(text, str):
text = str(text, "utf-8")
# truncate at max_length
if max_length:
text = text[:max_length]
# Remove all text after any stop_string
if stop_string:
index = text.find(stop_string)
if index > 0:
text = text[: index + len(stop_string)]
# extract substring matching regex (empty string for no match)
if output_regex:
_text = text
text = next(iter(re.findall(output_regex, text)), "")
assert (
not type(text) is tuple
), f'Regex {output_regex} returned multiple matching groups when applied to string {_text}. Try using non-capturing groups, by starting regex groups with ?: (e.g. "(stuff)" -> "(?:stuff)").'
return text
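# Hypothetical walk-through of the post-processing order above (truncate to
# max_length, cut after the first stop_string, then keep the first regex
# match). With the module default regex:
# _postprocess_output("4.\nQ: next", max_length=128, stop_string=None,
#                     output_regex=_DEFAULT_REGEX)
# returns "4", because _DEFAULT_REGEX matches a run of characters up to the
# first sentence terminator or newline.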
def create_task_from_path(json_path):
class WrappedTask(BigBenchJsonTask):
def __init__(self):
super().__init__(json_path)
return WrappedTask
def create_all_tasks():
resources_dir = importlib.resources.files("lm_eval.datasets") / "bigbench_resources"
supported_tasks = [os.path.splitext(x)[0] for x in os.listdir(resources_dir)]
res = {}
for task_name in supported_tasks:
task_path = os.path.join(resources_dir, f"{task_name}.json")
res[f"bigbench_{task_name}"] = create_task_from_path(task_path)
return res
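# Illustrative sketch (hypothetical, added for clarity): given a minimal
# multiple-choice task JSON such as
# {"metrics": ["multiple_choice_grade"],
#  "examples": [{"input": "What is 2 + 2?",
#                "target_scores": {"4": 1, "5": 0}}]}
# doc_to_text renders a prompt of the form
# "\nQ: What is 2 + 2?\n choice: 5\n choice: 4\nA: "
# with the choices in a seed-42 pseudorandom order. A task class for a local
# JSON file (the path below is made up) can also be built without the registry:
# MyTask = create_task_from_path("/tmp/my_task.json")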
"""
BLiMP: A Benchmark of Linguistic Minimal Pairs for English
https://arxiv.org/abs/1912.00582
BLiMP is a challenge set for evaluating what language models (LMs) know about
major grammatical phenomena in English. BLiMP consists of 67 sub-datasets, each
containing 1000 minimal pairs isolating specific contrasts in syntax, morphology,
or semantics. The data is automatically generated according to expert-crafted
grammars.
Homepage: https://github.com/alexwarstadt/blimp
"""
from lm_eval.base import rf, Task
from lm_eval.metrics import mean
_CITATION = """
@article{warstadt2019blimp,
author = {Warstadt, Alex and Parrish, Alicia and Liu, Haokun and Mohananey, Anhad and Peng, Wei and Wang, Sheng-Fu and Bowman, Samuel R.},
title = {BLiMP: The Benchmark of Linguistic Minimal Pairs for English},
journal = {Transactions of the Association for Computational Linguistics},
volume = {8},
number = {},
pages = {377-392},
year = {2020},
doi = {10.1162/tacl\_a\_00321},
URL = {https://doi.org/10.1162/tacl_a_00321},
eprint = {https://doi.org/10.1162/tacl_a_00321},
abstract = { We introduce The Benchmark of Linguistic Minimal Pairs (BLiMP),1 a challenge set for evaluating the linguistic knowledge of language models (LMs) on major grammatical phenomena in English. BLiMP consists of 67 individual datasets, each containing 1,000 minimal pairs—that is, pairs of minimally different sentences that contrast in grammatical acceptability and isolate specific phenomenon in syntax, morphology, or semantics. We generate the data according to linguist-crafted grammar templates, and human aggregate agreement with the labels is 96.4\%. We evaluate n-gram, LSTM, and Transformer (GPT-2 and Transformer-XL) LMs by observing whether they assign a higher probability to the acceptable sentence in each minimal pair. We find that state-of-the-art models identify morphological contrasts related to agreement reliably, but they struggle with some subtle semantic and syntactic phenomena, such as negative polarity items and extraction islands. }
}
""" # noqa: W605
class BlimpTask(Task):
VERSION = 0
DATASET_PATH = "blimp"
def has_training_docs(self):
return False
def has_validation_docs(self):
return True
def has_test_docs(self):
return False
def validation_docs(self):
# The HF dataset only contains a "train" dataset, but the harness expects a "validation"
# dataset. Let's use the training dataset, on the assumption that the model wasn't actually
# trained on this data.
return self.dataset["train"]
def fewshot_context(
self, doc, num_fewshot, provide_description=None, rnd=None, description=None
):
assert num_fewshot == 0
assert (
rnd is not None
), "A `random.Random` generator argument must be provided to `rnd`"
assert not provide_description, (
"The `provide_description` arg will be removed in future versions. To prepend "
"a custom description to the context, supply the corresponding string via the "
"`description` arg."
)
if provide_description is not None:
# nudge people to not specify it at all
print(
"WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
)
return ""
def doc_to_text(self, doc):
# this method is invoked by tests only
return ""
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["sentence_good"] + " " + doc["sentence_bad"]
def doc_to_target(self, doc):
# this method is invoked by tests only
return ""
def construct_requests(self, doc, ctx):
assert not ctx
# Calculate the loglikelihood for the good and the bad sentence.
# Note that loglikelihood translates the "" prefix to the "<|endoftext|>" token
return [
rf.loglikelihood("", doc["sentence_good"]),
rf.loglikelihood("", doc["sentence_bad"]),
]
def process_results(self, doc, results):
likelihood1, likelihood2 = results
# the model got this case right iff the good sentence scored higher than the bad sentence
acc = 1.0 if likelihood1 > likelihood2 else 0.0
return {
"acc": acc,
}
def higher_is_better(self):
return {
"acc": True,
}
def aggregation(self):
return {
"acc": mean,
}
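# Minimal sketch (hypothetical helper mirroring BlimpTask.process_results
# above, not called by the harness): a BLiMP minimal pair is scored with two
# unconditional loglikelihood requests, and the model is credited iff it
# assigns the grammatical sentence the higher probability.
def _score_minimal_pair(loglikelihood_good, loglikelihood_bad):
    """Return 1.0 if the acceptable sentence is preferred, else 0.0."""
    return 1.0 if loglikelihood_good > loglikelihood_bad else 0.0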
class BlimpAdjunctIsland(BlimpTask):
DATASET_NAME = "adjunct_island"
class BlimpAnaphorGenderAgreement(BlimpTask):
DATASET_NAME = "anaphor_gender_agreement"
class BlimpAnaphorNumberAgreement(BlimpTask):
DATASET_NAME = "anaphor_number_agreement"
class BlimpAnimateSubjectPassive(BlimpTask):
DATASET_NAME = "animate_subject_passive"
class BlimpAnimateSubjectTrans(BlimpTask):
DATASET_NAME = "animate_subject_trans"
class BlimpCausative(BlimpTask):
DATASET_NAME = "causative"
class BlimpComplex_NPIsland(BlimpTask):
DATASET_NAME = "complex_NP_island"
class BlimpCoordinateStructureConstraintComplexLeftBranch(BlimpTask):
DATASET_NAME = "coordinate_structure_constraint_complex_left_branch"
class BlimpCoordinateStructureConstraintObjectExtraction(BlimpTask):
DATASET_NAME = "coordinate_structure_constraint_object_extraction"
class BlimpDeterminerNounAgreement_1(BlimpTask):
DATASET_NAME = "determiner_noun_agreement_1"
class BlimpDeterminerNounAgreement_2(BlimpTask):
DATASET_NAME = "determiner_noun_agreement_2"
class BlimpDeterminerNounAgreementIrregular_1(BlimpTask):
DATASET_NAME = "determiner_noun_agreement_irregular_1"
class BlimpDeterminerNounAgreementIrregular_2(BlimpTask):
DATASET_NAME = "determiner_noun_agreement_irregular_2"
class BlimpDeterminerNounAgreementWithAdj_2(BlimpTask):
DATASET_NAME = "determiner_noun_agreement_with_adj_2"
class BlimpDeterminerNounAgreementWithAdjIrregular_1(BlimpTask):
DATASET_NAME = "determiner_noun_agreement_with_adj_irregular_1"
class BlimpDeterminerNounAgreementWithAdjIrregular_2(BlimpTask):
DATASET_NAME = "determiner_noun_agreement_with_adj_irregular_2"
class BlimpDeterminerNounAgreementWithAdjective_1(BlimpTask):
DATASET_NAME = "determiner_noun_agreement_with_adjective_1"
class BlimpDistractorAgreementRelationalNoun(BlimpTask):
DATASET_NAME = "distractor_agreement_relational_noun"
class BlimpDistractorAgreementRelativeClause(BlimpTask):
DATASET_NAME = "distractor_agreement_relative_clause"
class BlimpDropArgument(BlimpTask):
DATASET_NAME = "drop_argument"
class BlimpEllipsisNBar_1(BlimpTask):
DATASET_NAME = "ellipsis_n_bar_1"
class BlimpEllipsisNBar_2(BlimpTask):
DATASET_NAME = "ellipsis_n_bar_2"
class BlimpExistentialThereObjectRaising(BlimpTask):
DATASET_NAME = "existential_there_object_raising"
class BlimpExistentialThereQuantifiers_1(BlimpTask):
DATASET_NAME = "existential_there_quantifiers_1"
class BlimpExistentialThereQuantifiers_2(BlimpTask):
DATASET_NAME = "existential_there_quantifiers_2"
class BlimpExistentialThereSubjectRaising(BlimpTask):
DATASET_NAME = "existential_there_subject_raising"
class BlimpExpletiveItObjectRaising(BlimpTask):
DATASET_NAME = "expletive_it_object_raising"
class BlimpInchoative(BlimpTask):
DATASET_NAME = "inchoative"
class BlimpIntransitive(BlimpTask):
DATASET_NAME = "intransitive"
class BlimpIrregularPastParticipleAdjectives(BlimpTask):
DATASET_NAME = "irregular_past_participle_adjectives"
class BlimpIrregularPastParticipleVerbs(BlimpTask):
DATASET_NAME = "irregular_past_participle_verbs"
class BlimpIrregularPluralSubjectVerbAgreement_1(BlimpTask):
DATASET_NAME = "irregular_plural_subject_verb_agreement_1"
class BlimpIrregularPluralSubjectVerbAgreement_2(BlimpTask):
DATASET_NAME = "irregular_plural_subject_verb_agreement_2"
class BlimpLeftBranchIslandEchoQuestion(BlimpTask):
DATASET_NAME = "left_branch_island_echo_question"
class BlimpLeftBranchIslandSimpleQuestion(BlimpTask):
DATASET_NAME = "left_branch_island_simple_question"
class BlimpMatrixQuestionNpiLicensorPresent(BlimpTask):
DATASET_NAME = "matrix_question_npi_licensor_present"
class BlimpNpiPresent_1(BlimpTask):
DATASET_NAME = "npi_present_1"
class BlimpNpiPresent_2(BlimpTask):
DATASET_NAME = "npi_present_2"
class BlimpOnlyNpiLicensorPresent(BlimpTask):
DATASET_NAME = "only_npi_licensor_present"
class BlimpOnlyNpiScope(BlimpTask):
DATASET_NAME = "only_npi_scope"
class BlimpPassive_1(BlimpTask):
DATASET_NAME = "passive_1"
class BlimpPassive_2(BlimpTask):
DATASET_NAME = "passive_2"
class BlimpPrinciple_ACCommand(BlimpTask):
DATASET_NAME = "principle_A_c_command"
class BlimpPrinciple_ACase_1(BlimpTask):
DATASET_NAME = "principle_A_case_1"
class BlimpPrinciple_ACase_2(BlimpTask):
DATASET_NAME = "principle_A_case_2"
class BlimpPrinciple_ADomain_1(BlimpTask):
DATASET_NAME = "principle_A_domain_1"
class BlimpPrinciple_ADomain_2(BlimpTask):
DATASET_NAME = "principle_A_domain_2"
class BlimpPrinciple_ADomain_3(BlimpTask):
DATASET_NAME = "principle_A_domain_3"
class BlimpPrinciple_AReconstruction(BlimpTask):
DATASET_NAME = "principle_A_reconstruction"
class BlimpRegularPluralSubjectVerbAgreement_1(BlimpTask):
DATASET_NAME = "regular_plural_subject_verb_agreement_1"
class BlimpRegularPluralSubjectVerbAgreement_2(BlimpTask):
DATASET_NAME = "regular_plural_subject_verb_agreement_2"
class BlimpSententialNegationNpiLicensorPresent(BlimpTask):
DATASET_NAME = "sentential_negation_npi_licensor_present"
class BlimpSententialNegationNpiScope(BlimpTask):
DATASET_NAME = "sentential_negation_npi_scope"
class BlimpSententialSubjectIsland(BlimpTask):
DATASET_NAME = "sentential_subject_island"
class BlimpSuperlativeQuantifiers_1(BlimpTask):
DATASET_NAME = "superlative_quantifiers_1"
class BlimpSuperlativeQuantifiers_2(BlimpTask):
DATASET_NAME = "superlative_quantifiers_2"
class BlimpToughVsRaising_1(BlimpTask):
DATASET_NAME = "tough_vs_raising_1"
class BlimpToughVsRaising_2(BlimpTask):
DATASET_NAME = "tough_vs_raising_2"
class BlimpTransitive(BlimpTask):
DATASET_NAME = "transitive"
class BlimpWhIsland(BlimpTask):
DATASET_NAME = "wh_island"
class BlimpWhQuestionsObjectGap(BlimpTask):
DATASET_NAME = "wh_questions_object_gap"
class BlimpWhQuestionsSubjectGap(BlimpTask):
DATASET_NAME = "wh_questions_subject_gap"
class BlimpWhQuestionsSubjectGapLongDistance(BlimpTask):
DATASET_NAME = "wh_questions_subject_gap_long_distance"
class BlimpWhVsThatNoGap(BlimpTask):
DATASET_NAME = "wh_vs_that_no_gap"
class BlimpWhVsThatNoGapLongDistance(BlimpTask):
DATASET_NAME = "wh_vs_that_no_gap_long_distance"
class BlimpWhVsThatWithGap(BlimpTask):
DATASET_NAME = "wh_vs_that_with_gap"
class BlimpWhVsThatWithGapLongDistance(BlimpTask):
DATASET_NAME = "wh_vs_that_with_gap_long_distance"
"""
The Children’s Book Test (CBT) from the paper:
https://research.fb.com/wp-content/uploads/2016/11/the_goldilocks_principle_reading_children_s_books_with_explicit_memory_representations.pdf
The Children's Book Test (CBT) is a test of how well language models capture
meaning in children's books. Unlike standard language modelling benchmarks,
it distinguishes the task of predicting syntactic function words from that
of predicting lower-frequency words, which carry greater semantic content.
NOTE: This evaluation is based on the (context + query) question-answering variant
used by the Recurrent Language Models described in the paper. See section 4.4.
Homepage: https://github.com/facebookresearch/ParlAI/tree/main/parlai/tasks/cbt
"""
import numpy as np
from lm_eval.base import rf, Task
from lm_eval.metrics import mean
_CITATION = """
@misc{hill2016goldilocks,
title={The Goldilocks Principle: Reading Children's Books with Explicit Memory Representations},
author={Felix Hill and Antoine Bordes and Sumit Chopra and Jason Weston},
year={2016},
eprint={1511.02301},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
"""
class CBTBase(Task):
VERSION = 0
DATASET_PATH = "cbt"
DATASET_NAME = None
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
return self._training_docs
def validation_docs(self):
return self.dataset["validation"]
def test_docs(self):
return self.dataset["test"]
def detokenize(self, text):
text = text.replace(" '", "'")
text = text.replace(" \n", "\n")
text = text.replace("\n ", "\n")
text = text.replace(" n't", "n't")
text = text.replace("`` ", '"')
text = text.replace("''", '"')
# punctuation
text = text.replace(" :", ":")
text = text.replace(" ;", ";")
text = text.replace(" !", "!")
text = text.replace(" ?", "?")
text = text.replace(" ,", ",")
text = text.replace(" .", ".")
return text
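# Worked example (hypothetical) of the detokenizer above:
# detokenize("`` Hello '' , he said .") -> '"Hello", he said.'
# The replacements undo Penn-Treebank-style tokenization (spaces before
# punctuation and clitics, backtick/apostrophe quotes) so the prompt reads as
# natural text.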
def doc_to_text(self, doc):
passage = " ".join(doc["sentences"])
text = "Passage: " + passage + "\nQuestion: " + doc["question"]
return self.detokenize(text)
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
passage = " ".join(doc["sentences"])
return passage
def doc_to_target(self, doc):
return ""
def fewshot_examples(self, k, rnd):
assert (
k == 0
), f"CBT is only implemented for the zero-shot setting. Given k={k}."
return super().fewshot_examples(k, rnd)
def construct_requests(self, doc, ctx):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
lls = []
for option in doc["options"]:
# Following Section 4.4 "Recurrent Language Models" in the CBT paper:
# "we rank candidate [option] c based on p(q1 . . . qk−1, c, qk+1 . . . ql)
# rather than simply p(q1 . . . qk−1, c)."
lls.append(rf.loglikelihood("", ctx.replace("XXXXX", option))[0])
return lls
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param results:
The results of the requests created in construct_requests.
"""
gold = doc["options"].index(doc["answer"])
pred = np.argmax(results)
return {"acc": pred == gold}
def aggregation(self):
"""
:returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return {"acc": mean}
def higher_is_better(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return {"acc": True}
class CBTCN(CBTBase):
DATASET_NAME = "CN"
class CBTNE(CBTBase):
DATASET_NAME = "NE"
"""
CoQA: A Conversational Question Answering Challenge
https://arxiv.org/pdf/1808.07042.pdf
CoQA is a large-scale dataset for building Conversational Question Answering
systems. The goal of the CoQA challenge is to measure the ability of machines to
understand a text passage and answer a series of interconnected questions that
appear in a conversation.
Homepage: https://stanfordnlp.github.io/coqa/
"""
import inspect
import transformers.data.metrics.squad_metrics as squad_metrics
import lm_eval.datasets.coqa.coqa
from lm_eval.base import Task, rf, mean
from itertools import zip_longest
_CITATION = """
@misc{reddy2018coqa,
title={CoQA: A Conversational Question Answering Challenge},
author={Siva Reddy and Danqi Chen and Christopher D. Manning},
year={2018},
eprint={1808.07042},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
"""
class CoQA(Task):
VERSION = 1
DATASET_PATH = inspect.getfile(lm_eval.datasets.coqa.coqa)
DATASET_NAME = None
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return False
def training_docs(self):
return self.dataset["train"]
def validation_docs(self):
return self.dataset["validation"]
def test_docs(self):
pass
def doc_to_text(self, doc):
# Given a passage p, the conversation history {q1, a1, . . . qi−1, ai−1}
# and a question qi, the task is to predict the answer ai
doc_text = doc["story"] + "\n\n"
for (q, a) in zip_longest(
doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]
): # omit target answer ai
question = f"Q: {q}\n\n"
answer = f"A: {a}\n\n" if a is not None else "A:"
doc_text += question + answer
return doc_text
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["story"] + " " + "\n".join(doc["questions"]["input_text"])
@classmethod
def get_answers(cls, doc, turn_id):
# Returns the unique answers and valid alternatives (some CoQA questions have multiple valid answers).
answers = []
answer_forturn = doc["answers"]["input_text"][turn_id - 1]
answers.append(answer_forturn)
additional_answers = doc.get("additional_answers")
if additional_answers:
for key in additional_answers:
additional_answer_for_turn = additional_answers[key]["input_text"][
turn_id - 1
]
if additional_answer_for_turn.lower() not in map(str.lower, answers):
answers.append(additional_answer_for_turn)
return answers
@classmethod
def get_answer_choice(cls, raw_text):
# Function maps answers to CoQA answer categories
# ~ 1/5 of the CoQA answers are Yes/No
# ~ 2/3 of the CoQA answers are span-based
# (answers overlap with the passage ignoring punctuation and case mismatch)
if raw_text == "unknown":
return "0"
if squad_metrics.normalize_answer(raw_text) == "yes":
return "1"
if squad_metrics.normalize_answer(raw_text) == "no":
return "2"
return "3" # Not a yes/no question
@staticmethod
def compute_scores(gold_list, pred):
# Tests for exact match on the normalized answer (compute_exact)
# and for token overlap (compute_f1).
f1_sum = 0.0
em_sum = 0.0
if len(gold_list) > 1:
for i in range(len(gold_list)):
gold_answers = gold_list[0:i] + gold_list[i + 1 :]
# compare the prediction against the remaining (n - 1) golds and take the maximum
em_sum += max(
squad_metrics.compute_exact(a, pred) for a in gold_answers
)
f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers)
else:
em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list)
f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_list)
return {
"em": em_sum / max(1, len(gold_list)),
"f1": f1_sum / max(1, len(gold_list)),
}
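# Worked example (hypothetical) of the leave-one-out scoring above: with
# gold_list = ["yes", "yeah"] and pred = "yes", the i=0 pass scores pred
# against ["yeah"] (em 0, f1 0) and the i=1 pass against ["yes"] (em 1,
# f1 1), giving em = f1 = (0 + 1) / 2 = 0.5.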
def doc_to_target(self, doc, turnid=None):
# Default to prediction of last turn.
if turnid is None:
turnid = len(doc["questions"]["input_text"])
raw_text = doc["answers"]["input_text"][turnid - 1]
return " " + raw_text
def construct_requests(self, doc, ctx):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
cont_request = rf.greedy_until(ctx, {"until": ["\nQ:"]})
return cont_request
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param results:
The results of the requests created in construct_requests.
"""
turn_id = len(doc["questions"]["input_text"])
gold_list = self.get_answers(doc, turn_id)
pred = results[0].strip().split("\n")[0]
scores = self.compute_scores(gold_list, pred)
return {
"f1": scores["f1"],
"em": scores["em"],
}
def higher_is_better(self):
return {
"f1": True,
"em": True,
}
def aggregation(self):
return {
"f1": mean,
"em": mean,
}
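# Illustrative note (not part of the original file): for a document with
# story S and turns (q1, a1), (q2, a2), doc_to_text above yields
# "S\n\nQ: q1\n\nA: a1\n\nQ: q2\n\nA:"
# zip_longest pairs each question with its answer, and because
# doc["answers"]["input_text"][:-1] drops the final gold answer, the last pair
# is (q2, None), leaving a bare "A:" for the model to complete; generation is
# then stopped at the next "\nQ:" by construct_requests.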