Commit a07d05f7 authored by baberabb

Merge remote-tracking branch 'origin/big-refactor' into nqopen_baber

# Conflicts:
#	lm_eval/api/task.py
parents b1d468f2 6ba2a2b0
 import re
 import string
+import timeit
 import pickle
 import traceback
 from pprint import pprint
+from typing import Iterator, Sequence, TypeVar

 # This is a cpp module. Compile janitor_util.cpp with:
 # c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup

@@ -16,10 +16,12 @@ except Exception:
     traceback.print_exc()
     JANITOR_CPP = False

+T = TypeVar("T")

 # Implementation from nltk source
 # https://www.nltk.org/_modules/nltk/util.html
-def form_ngrams(sequence, n):
+def form_ngrams(sequence: Iterator[T], n: int) -> Iterator[tuple[T, ...]]:
     history = []
     while n > 1:
         # PEP 479, prevent RuntimeError from being raised when StopIteration bubbles out of generator

@@ -36,7 +38,7 @@ def form_ngrams(sequence, n):
         del history[0]

-def word_ngrams(s, n):
+def word_ngrams(s: str, n: int) -> Iterator[str]:
     """Splits a string into ngram words"""
     tokens = s.split()  # not a generator :(
     ngram_seqs = form_ngrams(iter(tokens), n)

@@ -68,14 +70,14 @@ def word_ngrams(s, n):
 # https://stackoverflow.com/questions/13734451/string-split-with-indices-in-python
-def split_indices(s):
+def split_indices(s: str) -> Iterator[tuple[str, tuple[int, int]]]:
     """Splits a string on whitespaces and records the indices of each in the original string.
     @:return generator((word, (start_idx, end_idx)), ...)
     """
     return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r"\S+", s))

-def word_ngrams_indices(s, n):
+def word_ngrams_indices(s: str, n: int) -> Iterator[tuple[str, tuple[int, int]]]:
     """Splits a string into pairs of (ngram words, their start/end indices)"""
     tokens_with_indices = split_indices(s)

@@ -104,16 +106,15 @@ def word_ngrams_indices(s, n):
 class Janitor:
     # FIXME delete_chars: Should anything else go here? Special chars?
     def __init__(
         self,
-        ngram_n=13,
-        window_to_remove=200,
-        too_dirty_cutoff=10,
-        minimum_slice_length=200,
-        delete_chars=string.punctuation,
-    ):
+        ngram_n: int = 13,
+        window_to_remove: int = 200,
+        too_dirty_cutoff: int = 10,
+        minimum_slice_length: int = 200,
+        delete_chars: str = string.punctuation,
+    ) -> None:
         self.ngram_n = ngram_n
         self.window_to_remove = window_to_remove
         self.too_dirty_cutoff = too_dirty_cutoff

@@ -135,11 +136,11 @@ class Janitor:
     # I/O for saving contamination ngrams
     ##############

-    def save_contamination_ngrams(self, filename):
+    def save_contamination_ngrams(self, filename: str) -> None:
         with open(filename, "wb") as fp:
             pickle.dump(filename, fp)

-    def load_contamination_ngrams(self, filename):
+    def load_contamination_ngrams(self, filename: str) -> None:
         with open(filename, "rb") as fp:
             self.dirt_ngrams = pickle.load(fp)

@@ -147,7 +148,7 @@ class Janitor:
     # Call these :)
     ##############

-    def register_contaminant(self, dirt_string):
+    def register_contaminant(self, dirt_string: str) -> None:
         """Register a string as contamination to be removed, e.g. a test set
         This breaks the dirt_string into ngrams to store for future cleaning"""
         if JANITOR_CPP:

@@ -156,7 +157,7 @@ class Janitor:
             print("WARNING: Janitor running in python mode")
             return self.register_contaminant_python(dirt_string)

-    def clean(self, dirty_string):
+    def clean(self, dirty_string: str) -> list[str]:
         """Clean a string (e.g. a training set) by removing all ngrams previously
         registered as contaminants. Returns a list of clean chunks, or empty if
         the string was too dirty"""

@@ -166,7 +167,9 @@ class Janitor:
             print("WARNING: Janitor running in python mode")
             return self.clean_python(dirty_string)

-    def _split_chunks(self, dirty_string, dirty_parts):
+    def _split_chunks(
+        self, dirty_string: str, dirty_parts: Sequence[tuple]
+    ) -> list[str]:
         clean_chunks = []
         splice_idx = 0
         end = -1

@@ -189,12 +192,12 @@ class Janitor:
     # Fast C++
     ##############

-    def register_contaminant_cpp(self, dirt_string):
+    def register_contaminant_cpp(self, dirt_string) -> None:
         self.dirt_ngrams.update(
             janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n)
         )

-    def clean_cpp(self, dirty_string):
+    def clean_cpp(self, dirty_string: str) -> list[str]:
         contamination_indices = janitor_util.clean_ngram_with_indices(
             dirty_string, self.delete_chars, self.ngram_n
         )

@@ -204,15 +207,15 @@ class Janitor:
     # Slow python
     ##############

-    def normalize_string(self, s):
+    def normalize_string(self, s: str) -> str:
         return s.translate(self.translation_table)

-    def register_contaminant_python(self, dirt_string):
+    def register_contaminant_python(self, dirt_string: str) -> None:
         self.dirt_ngrams.update(
             word_ngrams(self.normalize_string(dirt_string), self.ngram_n)
         )

-    def clean_python(self, dirty_string):
+    def clean_python(self, dirty_string: str) -> list[str]:
         contamination_indices = (
             (None, *idx_pair)
             for dirty_ngram, idx_pair in word_ngrams_indices(dirty_string, self.ngram_n)
......
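For orientation, a minimal usage sketch of the `Janitor` API touched above (hypothetical strings; assumes the module path used on the main branch, `lm_eval.decontamination.janitor`, and the pure-Python fallback when the compiled `janitor_util` extension is absent):

```python
from lm_eval.decontamination.janitor import Janitor

# Toy n-gram size so the short example strings actually overlap;
# the production default is ngram_n=13.
janitor = Janitor(ngram_n=3)

# Register a (hypothetical) benchmark passage as contamination ...
janitor.register_contaminant("the quick brown fox jumps over the lazy dog")

# ... then clean a training document that quotes it.
chunks = janitor.clean(
    "some web text copies the quick brown fox jumps over the lazy dog verbatim"
)
print(chunks)  # remaining clean chunks; likely [] for toy strings, since chunks
               # shorter than minimum_slice_length (default 200 chars) are discarded
```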
@@ -42,11 +42,11 @@ def simple_evaluate(
     device=None,
     use_cache=None,
     limit=None,
-    bootstrap_iters=100000,
-    check_integrity=False,
+    bootstrap_iters: int = 100000,
+    check_integrity: bool = False,
     decontamination_ngrams_path=None,
-    write_out=False,
-    log_samples=True,
+    write_out: bool = False,
+    log_samples: bool = True,
 ):
     """Instantiate and evaluate a model on a list of tasks.

@@ -117,10 +117,11 @@ def simple_evaluate(
     task_dict = lm_eval.tasks.get_task_dict(tasks)
     for task_name in task_dict.keys():
         task_obj = task_dict[task_name]
         if type(task_obj) == tuple:
             group, task_obj = task_obj
+            if task_obj is None:
+                continue

         config = task_obj._config
         if num_fewshot is not None:

@@ -175,17 +176,17 @@ def evaluate(
     lm,
     task_dict,
     limit=None,
-    bootstrap_iters=100000,
+    bootstrap_iters: int = 100000,
     decontamination_ngrams_path=None,
-    write_out=False,
-    log_samples=True,
+    write_out: bool = False,
+    log_samples: bool = True,
 ):
     """Instantiate and evaluate a model on a list of tasks.

     :param lm: obj
         Language Model
     :param task_dict: dict[str, Task]
-        Dictionary of tasks. Tasks will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
+        Dictionary of tasks. Tasks will be taken to have name type(task).config.task .
     :param limit: int, optional
         Limit the number of examples per task (only use this for testing)
     :param bootstrap_iters:

@@ -210,24 +211,30 @@ def evaluate(
     samples = collections.defaultdict(list)
     # tracks all Instances/requests a model must generate output on.
     requests = collections.defaultdict(list)
-    # Stores task scores based on task grouping.
-    aggregate = collections.defaultdict(dict)
-    # tracks if a task was chosen via user selecting a group containing it
-    task_groups = collections.defaultdict(dict)
+    # Aggregated task scores presented with groups
+    results_agg = collections.defaultdict(dict)
+    # Aggregated groups scores only
+    groups_agg = collections.defaultdict(dict)
     # stores the amount to pad out reqs per req. type so that
     # number of fwd passes per distributed rank is equal
     padding_requests = collections.defaultdict(int)
-    # Stores group related keys and values for group-aggregation
-    aggregate = collections.defaultdict(dict)
-    task_groups = collections.defaultdict(dict)
+    # store the hierarchy to do proper ordering
+    task_hierarchy = collections.defaultdict(list)
+    # store the ordering of tasks and groups
+    task_order = collections.defaultdict(int)
+    # store the aggregation for aggregating across tasks in the same group
+    sample_agg_fn = collections.defaultdict(dict)

     # get lists of each type of request
     for task_name, task in task_dict.items():
         if type(task) == tuple:
-            group, task = task
-            task_groups[task_name] = group
+            group_name, task = task
+            task_hierarchy[group_name].append(task_name)
+        else:
+            task_hierarchy[task_name] = []
+
+        if task is None:
+            continue

         versions[task_name] = task.VERSION
         configs[task_name] = dict(task.dump_config())

@@ -252,7 +259,8 @@ def evaluate(
             # print the prompt for the first few documents
             if inst.doc_id < 1:
                 eval_logger.info(
-                    f"Task: {task_name}; document {inst.doc_id}; context prompt (starting on next line):\n{inst.args[0]}\n(end of prompt on previous line)"
+                    f"Task: {task_name}; document {inst.doc_id}; context prompt (starting on next line):\
+\n{inst.args[0]}\n(end of prompt on previous line)\ntarget string or answer choice index (starting on next line):\n{task.doc_to_target(inst.doc)}\n(end of target on previous line)"
                 )
                 eval_logger.info(f"Request: {str(inst)}")

@@ -302,6 +310,8 @@ def evaluate(
     for task_name, task in task_dict.items():
         if type(task) == tuple:
             group, task = task
+            if task is None:
+                continue
         task.apply_filters()

     ### Collect values of metrics on all datapoints ###

@@ -311,6 +321,8 @@ def evaluate(
     for task_name, task in task_dict.items():
         if type(task) == tuple:
             group, task = task
+            if task is None:
+                continue
         # TODO: make it possible to use a different metric per filter
         # iterate over different filters used
         for key in task.instances[0].filtered_resps.keys():

@@ -349,7 +361,6 @@ def evaluate(
         # if multigpu, then gather data across all ranks
         # first gather logged samples across all ranks
         for task_name, task_samples in list(samples.items()):
             full_samples = [None] * lm.world_size
             torch.distributed.all_gather_object(full_samples, task_samples)

@@ -358,11 +369,17 @@ def evaluate(
         # then collect metrics across all ranks
         vals_torch = collections.defaultdict(list)
         for (task_name, key, metric), items in vals.items():
             numitem = 0
             if type(items[0]) == tuple:
                 numitem = len(items[0])

+            if isinstance(items[0], (str, list)):
+                # handle the string case
+                gathered_items = [None] * lm.accelerator.num_processes
+                torch.distributed.all_gather_object(gathered_items, items)
+
+                gathered_item = list(itertools.chain.from_iterable(gathered_items))
+            else:
                 # distributed gather requires all ranks to have same dimensions
                 # so we pad out with float32 min value
                 pad_value = torch.finfo(torch.float32).min

@@ -392,31 +409,68 @@ def evaluate(
         vals = vals_torch

     if lm.rank == 0:
+        ### Get task ordering for correct sample-wide aggregation
+        group_to_task = {}
+        for group in task_hierarchy.keys():
+            if group not in task_order:
+                task_order[group] = 0
+
+            if len(task_hierarchy[group]) > 0:
+                group_to_task[group] = task_hierarchy[group].copy()
+
+            for task in task_hierarchy[group]:
+                if task in task_order:
+                    task_order[task] += 1
+                else:
+                    task_order[task] = 1 + task_order[group]
+
+                if task in task_hierarchy:
+                    group_to_task[group].remove(task)
+                    group_to_task[group].extend(task_hierarchy[task])
+
+        task_to_group = {}
+        for group in group_to_task:
+            for task in group_to_task[group]:
+                if task in task_to_group:
+                    task_to_group[task].append(group)
+                else:
+                    task_to_group[task] = [group]

         ### Aggregate results over all datapoints ###
         # aggregate results ; run bootstrap CIs
         for (task_name, key, metric), items in vals.items():
             task = task_dict[task_name]
+            metric_key = metric + "," + key
             if type(task) == tuple:
-                group, task = task
+                group_name, task = task
+            else:
+                group_name = None

-            task_score = task.aggregation()[metric](items)
-            results[task_name][metric + "," + key] = task_score
-            # Need to put back in results
-            # pythia | acc
-            #        | perplexity
-            #        | word_perplexity
-            #        | byte_perplexity
-            #        | bits_per_byte
-            if bool(task_groups):
-                group_name = task_groups[task_name]
-                if metric not in aggregate[group_name]:
-                    aggregate[group_name][metric] = [task_score]
-                else:
-                    aggregate[group_name][metric].append(task_score)
+            agg_fn = task.aggregation()[metric]
+            task_score = agg_fn(items)
+
+            if group_name is not None:
+                sample_metric_key = metric + "(sample agg)," + key
+                for grouping in task_to_group[task_name]:
+                    if metric_key in results[grouping]:
+                        results[grouping][metric_key].append(task_score)
+                    else:
+                        results[grouping][metric_key] = [task_score]
+
+                    if sample_metric_key in results[grouping]:
+                        results[grouping][sample_metric_key] += items
+                    else:
+                        results[grouping][sample_metric_key] = items.copy()
+                    sample_agg_fn[grouping][sample_metric_key] = agg_fn
+
+            results[task_name][metric_key] = task_score

             # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
             # so we run them less iterations. still looking for a cleaner way to do this
-            if bootstrap_iters > 0:
+            if False:  # bootstrap_iters > 0:
                 stderr = lm_eval.api.metrics.stderr_for_metric(
                     metric=task.aggregation()[metric],
                     bootstrap_iters=min(bootstrap_iters, 1000)

@@ -427,19 +481,38 @@ def evaluate(
                 if stderr is not None:
                     results[task_name][metric + "_stderr" + "," + key] = stderr(items)

-        if bool(aggregate):
-            for group in aggregate.keys():
-                for metric in aggregate[group].keys():
-                    aggregate[group][metric] = np.average(aggregate[group][metric])
-            versions[group] = "N/A"
+        if bool(results):
+            for task_or_group in results.keys():
+                for metric in results[task_or_group].keys():
+                    if type(results[task_or_group][metric]) == list:
+                        if "(sample agg)" in metric:
+                            results[task_or_group][metric] = sample_agg_fn[
+                                task_or_group
+                            ][metric](results[task_or_group][metric])
+                        else:
+                            results[task_or_group][metric] = np.average(
+                                results[task_or_group][metric]
+                            )
+                    versions[task_or_group] = "N/A"
+
+        for task_name, task in task_dict.items():
+            if type(task) == tuple:
+                group_name, task = task
+                order = task_order[group_name]
+                tabbed_name = "-" * order + group_name
+                results_agg[tabbed_name] = results[group_name]
+                versions[tabbed_name] = versions[group_name]
+                if order == 0:
+                    groups_agg[group_name] = results[group_name]
+
+            order = task_order[task_name]
+            tabbed_name = "-" * order + task_name
+            results_agg[tabbed_name] = results[task_name]
+            versions[tabbed_name] = versions[task_name]

         results_dict = {
-            "results": dict(sorted(results.items())),
-            **(
-                {"aggregate": dict(sorted(aggregate.items()))}
-                if bool(aggregate)
-                else {}
-            ),
+            "results": dict(results_agg.items()),
+            **({"groups": dict(groups_agg.items())} if bool(groups_agg) else {}),
             "configs": dict(sorted(configs.items())),
             "versions": dict(sorted(versions.items())),
         }
......
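A toy illustration (hypothetical task and group names) of how the `task_order` depth and the `"-" * order` prefix in the block above produce the indented group/task rows collected into `results_agg`:

```python
import collections

# Hypothetical hierarchy: the group "mmlu_stem" contains two tasks.
task_hierarchy = {"mmlu_stem": ["mmlu_physics", "mmlu_chemistry"]}
task_order = collections.defaultdict(int)

for group, subtasks in task_hierarchy.items():
    task_order.setdefault(group, 0)
    for task in subtasks:
        # member tasks sit one level below their group, mirroring the diff above
        task_order[task] = 1 + task_order[group]

for name in ["mmlu_stem", "mmlu_physics", "mmlu_chemistry"]:
    print("-" * task_order[name] + name)
# mmlu_stem
# -mmlu_physics
# -mmlu_chemistry
```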
@@ -17,14 +17,16 @@ FILTER_REGISTRY = {

 def get_filter(filter_name):
+    if filter_name in FILTER_REGISTRY:
         return FILTER_REGISTRY[filter_name]
+    else:
+        return filter_name


 def build_filter_ensemble(filter_name, components):
     """
     Create a filtering pipeline.
     """
     filters = []
     for (function, kwargs) in components:
         if kwargs is None:
......
@@ -9,7 +9,7 @@ class DecontaminationFilter(Filter):

     name = "track_decontamination"

-    def __init__(self, path):
+    def __init__(self, path) -> None:
         """
         TODO: make sure only ever run one time on the train set (should this be cached as a class var? keyed by value for "path").

@@ -17,7 +17,7 @@ class DecontaminationFilter(Filter):
         """
         self._decontam_results = None

-    def apply(self, reps):
+    def apply(self, resps, docs) -> None:
         """
         Return {"no_contamination", "only_contamination"} keys for the 2 different subsets
         """
......
@@ -6,7 +6,9 @@ from lm_eval.api.filter import Filter

 class RegexFilter(Filter):
     """ """

-    def __init__(self, regex_pattern=r"#### (\-?[0-9\.\,]+)", fallback="[invalid]"):
+    def __init__(
+        self, regex_pattern: str = r"#### (\-?[0-9\.\,]+)", fallback: str = "[invalid]"
+    ) -> None:
         """
         pass a string `regex` to run `re.compile(r"regex")` on.
         `fallback` defines the output returned if no matches for the regex are located.

@@ -15,7 +17,7 @@ class RegexFilter(Filter):
         self.regex = re.compile(regex_pattern)
         self.fallback = fallback

-    def apply(self, resps):
+    def apply(self, resps, docs):
         # here, we assume we have a list, in which each element is
         # a list of model responses for some particular input/target pair.
         # so we process each of these (same input/target response sets)

@@ -41,12 +43,11 @@ class RegexFilter(Filter):
 class WhitespaceFilter(Filter):
     """ """

-    def __init__(self):
+    def __init__(self) -> None:
         pass

-    def apply(self, resps):
+    def apply(self, resps, docs):
         def filter_set(inst):
             filtered_resp = []
             for resp in inst:
                 if resp.startswith(" "):
......
@@ -4,12 +4,12 @@ from lm_eval.api.filter import Filter

 class TakeFirstFilter(Filter):
-    def __init__(self):
+    def __init__(self) -> None:
         """
         Can define custom behavior here, if an individual instantiation of a Filter class should have state.
         """

-    def apply(self, resps):
+    def apply(self, resps, docs):
         """
         Assuming each entry of `resps` is a list of model responses, we discard all but the first response.
         """

@@ -17,13 +17,12 @@ class TakeFirstFilter(Filter):
 class TakeKFilter(Filter):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, **kwargs) -> None:
         self.k = kwargs.pop("k")

         super().__init__(*args, **kwargs)

-    def apply(self, resps):
+    def apply(self, resps, docs):
         # check we have at least k responses per doc, else we can't take the first k
         assert (
             len(resps[0]) >= self.k

@@ -32,12 +31,12 @@ class TakeKFilter(Filter):
 class MajorityVoteFilter(Filter):
-    def __init__(self):
+    def __init__(self) -> None:
         """
         Can define custom behavior here, if an individual instantiation of a Filter class should have state.
         """

-    def apply(self, resps):
+    def apply(self, resps, docs):
         """
         Each entry of `resps` is a list of model responses.
         We select the response that occurs most frequently in each entry of `resps`.
......
@@ -6,3 +6,5 @@ logging.basicConfig(
     level=logging.INFO,
 )
 eval_logger = logging.getLogger("lm-eval")
+
+SPACING = " " * 47
@@ -76,7 +76,7 @@ class AnthropicLM(LM):
         max_tokens_to_sample: int = 256,
         temperature: float = 0,  # defaults to 1
         **kwargs,  # top_p, top_k, etc.
-    ):
+    ) -> None:
         """Anthropic API wrapper.

         :param model: str

@@ -135,11 +135,10 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e
     def tok_decode(self, tokens: List[int]) -> str:
         return self.tokenizer.decode(tokens)

-    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
+    def _loglikelihood_tokens(self, requests, disable_tqdm: bool = False):
         raise NotImplementedError("No support for logits.")

     def greedy_until(self, requests) -> List[str]:
         if not requests:
             return []
......
@@ -5,7 +5,7 @@ from lm_eval.api.registry import register_model

 @register_model("dummy")
 class DummyLM(LM):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()

     @classmethod
......
+import os
 import torch
 import transformers
 from transformers.models.auto.modeling_auto import (

@@ -20,7 +22,7 @@ from lm_eval.api.registry import register_model
 from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria

-from accelerate import Accelerator, find_executable_batch_size
+from accelerate import Accelerator, find_executable_batch_size, DistributedType
 from typing import List, Optional, Union

@@ -67,6 +69,7 @@ class HFLM(LM):
         revision: Optional[str] = "main",
         subfolder: Optional[str] = None,
         tokenizer: Optional[str] = None,
+        truncation: Optional[bool] = False,
         max_length: Optional[int] = None,
         device: Optional[str] = "cuda",
         dtype: Optional[Union[str, torch.dtype]] = "auto",

@@ -75,6 +78,7 @@ class HFLM(LM):
         low_cpu_mem_usage: Optional[bool] = True,
         trust_remote_code: Optional[bool] = False,
         use_fast_tokenizer: Optional[bool] = True,
+        cache_dir: Optional[Union[str, os.PathLike]] = None,
         # arguments used for splitting a model across GPUs naively.
         # only used if `parallelize=True`.
         parallelize: Optional[bool] = False,

@@ -90,7 +94,7 @@ class HFLM(LM):
         bnb_4bit_compute_dtype: Optional[Union[str, torch.dtype]] = None,
         gptq: Optional[Union[bool, str]] = False,
         gptq_use_triton: Optional[bool] = False,
-    ):
+    ) -> None:
         super().__init__()

         assert isinstance(device, str)

@@ -103,17 +107,20 @@ class HFLM(LM):
         if not (parallelize or accelerator.num_processes > 1):
             # use user-passed device
             device_list = set(
-                ["cuda", "cpu", "mps"]
+                ["cuda", "cpu"]
                 + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
+                + ["mps", "mps:0"]
             )
             if device:
                 if device not in device_list:
                     device = int(device)
                 self._device = torch.device(device)
                 eval_logger.info(f"Using device '{device}'")
-                if device == "mps":
+                if device in ("mps", "mps:0") and "dev" not in torch.__version__:
                     eval_logger.info(
-                        "MPS is still in beta and only supports float32; setting dtype to float32."
+                        "MPS: Setting dtype to float32. To use float16 with MPS, please install a nightly build of "
+                        "PyTorch: pip3 install --pre torch torchvision torchaudio --index-url "
+                        "https://download.pytorch.org/whl/nightly/cpu"
                     )
             else:
                 eval_logger.info("Device not specified")

@@ -240,6 +247,8 @@ class HFLM(LM):
             use_fast=use_fast_tokenizer,
         )

+        self.truncation = truncation
+
         self.vocab_size = self.tokenizer.vocab_size
         self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

@@ -288,6 +297,13 @@ class HFLM(LM):
                     eval_logger.info(
                         "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes`. If the desired GPU is being used, this message is safe to ignore."
                     )
+            else:
+                assert accelerator.distributed_type in [
+                    DistributedType.FSDP,
+                    DistributedType.MULTI_GPU,
+                ], "Unsupported distributed type provided. Only DDP and FSDP are supported."
+                if accelerator.distributed_type == DistributedType.FSDP:
+                    self._model = accelerator.prepare(self.model)
                 else:
                     self._model = accelerator.prepare_model(
                         self.model, evaluation_mode=True

@@ -334,7 +350,7 @@ class HFLM(LM):
         return self._DEFAULT_MAX_LENGTH

     @property
-    def max_gen_toks(self):
+    def max_gen_toks(self) -> int:
         return 256

     @property

@@ -353,7 +369,7 @@ class HFLM(LM):
     def world_size(self):
         return self._world_size

-    def _detect_batch_size(self, requests=None, pos=0):
+    def _detect_batch_size(self, requests=None, pos: int = 0):
         if requests:
             _, context_enc, continuation_enc = requests[pos]
             max_length = len(

@@ -419,7 +435,11 @@ class HFLM(LM):
         return encoding

     def tok_batch_encode(
-        self, strings: List[str], padding_side="left", left_truncate_len=None
+        self,
+        strings: List[str],
+        padding_side: str = "left",
+        left_truncate_len: int = None,
+        truncation: bool = False,
     ):
         # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
         old_padding_side = self.tokenizer.padding_side

@@ -432,6 +452,7 @@ class HFLM(LM):
         encoding = self.tokenizer(
             strings,
+            truncation=truncation,
             padding="longest",
             return_tensors="pt",
             add_special_tokens=add_special_tokens,

@@ -595,7 +616,9 @@ class HFLM(LM):
         return loglikelihoods

-    def _loglikelihood_tokens(self, requests, disable_tqdm=False, override_bs=None):
+    def _loglikelihood_tokens(
+        self, requests, disable_tqdm: bool = False, override_bs=None
+    ):
         # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
         res = []

@@ -856,7 +879,9 @@ class HFLM(LM):
             # encode, pad, and truncate contexts for this batch
             context_enc, attn_masks = self.tok_batch_encode(
-                contexts, left_truncate_len=max_ctx_len
+                contexts,
+                left_truncate_len=max_ctx_len,
+                truncation=self.truncation,
             )
             context_enc = context_enc.to(self.device)
             attn_masks = attn_masks.to(self.device)
......
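For readers unfamiliar with the tokenizer settings used in `tok_batch_encode` above, a standalone sketch (hypothetical model name and inputs) of left-sided padding plus the optional truncation flag with a Hugging Face tokenizer:

```python
from transformers import AutoTokenizer

# Hypothetical example; any causal-LM tokenizer behaves the same way.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
# pad on the left so generation continues directly from the real prompt end
tokenizer.padding_side = "left"

encoding = tokenizer(
    ["a short prompt", "a considerably longer prompt with more tokens"],
    truncation=False,   # mirrors the new `truncation` argument, off by default
    padding="longest",
    return_tensors="pt",
)
print(encoding["input_ids"].shape, encoding["attention_mask"].shape)
```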
@@ -69,7 +69,7 @@ class OpenaiCompletionsLM(LM):
         engine: str = "text-davinci-003",
         truncate: bool = False,
         batch_size: int = 1,
-    ):
+    ) -> None:
         """
         :param engine: str

@@ -99,12 +99,12 @@ class OpenaiCompletionsLM(LM):
         return self.end_of_text_token_id

     @property
-    def max_length(self):
+    def max_length(self) -> int:
         # Note: the OpenAI API supports up to 2049 tokens, with the first token being the first input token
         return 2048

     @property
-    def max_gen_toks(self):
+    def max_gen_toks(self) -> int:
         return 256

     @property

@@ -152,7 +152,7 @@ class OpenaiCompletionsLM(LM):
         return self._loglikelihood_tokens(new_reqs)

     def _loglikelihood_tokens(
-        self, requests, disable_tqdm=False
+        self, requests, disable_tqdm: bool = False
     ) -> List[Tuple[float, bool]]:
         res = []
......
@@ -41,7 +41,7 @@ def textsynth_completion(**kwargs):

 @register_model("textsynth")
 class TextSynthLM(LM):
-    def __init__(self, engine, truncate=False):
+    def __init__(self, engine, truncate: bool = False) -> None:
         """
         :param engine: str
             TextSynth API engine (e.g. `gptj_6B`)

@@ -62,12 +62,12 @@ class TextSynthLM(LM):
         raise NotImplementedError()

     @property
-    def max_length(self):
+    def max_length(self) -> int:
         # NOTE: Turn on truncation to avoid errors on long inputs.
         return 2048

     @property
-    def max_gen_toks(self):
+    def max_gen_toks(self) -> int:
         return 256

     @property
......
+import ast
 from lm_eval import utils
 from lm_eval.logger import eval_logger

@@ -5,7 +7,7 @@ from lm_eval.logger import eval_logger
 # Stores prompts in a dictionary indexed by 2 levels:
 # prompt category name, and prompt name.
 # This allows us to access prompts
-PROMPT_REGISTRY = {
+PROMPT_REGISTRY: dict[str, dict[str, str]] = {
     "qa-basic": {
         "question-newline-answer": "Question: {{question}}\nAnswer:",
         "q-newline-a": "Q: {{question}}\nA:",

@@ -13,7 +15,7 @@ PROMPT_REGISTRY = {
 }

-def get_prompt(prompt_id: str, dataset_name=None, subset_name=None):
+def get_prompt(prompt_id: str, dataset_name: str = None, subset_name: str = None):
     # unpack prompt name
     category_name, prompt_name = prompt_id.split(":")
     if subset_name is None:

@@ -63,6 +65,12 @@ def load_prompt_list(use_prompt: str, dataset_name=None, subset_name=None, **kwa
     else:
         prompts = DatasetTemplates(dataset_name=dataset_name, subset_name=subset_name)

-    category_name, prompt_name = use_prompt.split(":")
+    category_name, *prompt_name = use_prompt.split(":")
+    # TODO allow to multiple prompt naming
+    # if len(prompt_name) > 1:
+    #     prompt_list = []
+    #     for prompt in prompt_name:
+    #         prompt_list.append(utils.pattern_match(prompt_name, prompts.all_template_names))
+    # else:
     prompt_list = utils.pattern_match(prompt_name, prompts.all_template_names)

     return [":".join([category_name, prompt]) for prompt in prompt_list]
@@ -5,8 +5,8 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for

 - [x] Glue
 - [x] SuperGlue
-- [ ] CoQA (Lintang)
-- [ ] DROP (Lintang)
+- [x] CoQA
+- [x] DROP
 - [x] ~~Lambada~~
 - [x] Lambada (Cloze variants)
 - [x] ~~Lambada (Multilingual)~~

@@ -29,7 +29,7 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
 - [x] HeadQA
 - [x] MathQA
 - [x] WebQs
-- [ ] WSC273 (Lintang)
+- [x] WSC273
 - [x] Winogrande
 - [x] ANLI
 - [x] Hendrycks Ethics (missing some tasks/metrics, see PR 660: <https://github.com/EleutherAI/lm-evaluation-harness/pull/660> for more info)

@@ -38,11 +38,11 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
 - [x] TruthfulQA (gen)
 - [ ] MuTual
 - [ ] Hendrycks Math (Hailey)
-- [ ] Asdiv
+- [x] Asdiv
 - [ ] GSM8k
 - [x] Arithmetic
 - [ ] MMMLU (Hailey)
-- [ ] Translation (WMT) suite (Hailey)
+- [x] Translation (WMT) suite
 - [x] Unscramble
 - [x] ~~Pile (perplexity)~~
 - [x] BLiMP

@@ -56,7 +56,7 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
 - [x] XWinograd
 - [x] PAWS-X
 - [x] XNLI
-- [ ] MGSM (Lintang)
+- [x] MGSM
 - [ ] SCROLLS
 - [x] Babi
......
@@ -15,7 +15,7 @@ from lm_eval.api.registry import (
 )

-def register_configurable_task(config):
+def register_configurable_task(config: dict[str, str]) -> int:
     SubClass = type(
         config["task"] + "ConfigurableTask",
         (ConfigurableTask,),

@@ -38,7 +38,7 @@ def register_configurable_task(config):
     return 0

-def check_prompt_config(config):
+def check_prompt_config(config: dict[str, str]) -> List[dict[str, str]]:
     all_configs = []
     if "use_prompt" in config:
         prompt_list = prompts.load_prompt_list(

@@ -69,14 +69,14 @@ def check_prompt_config(config):
     return all_configs

-def get_task_name_from_config(task_config):
+def get_task_name_from_config(task_config: dict[str, str]) -> str:
     if "dataset_name" in task_config:
         return "{dataset_path}_{dataset_name}".format(**task_config)
     else:
         return "{dataset_path}".format(**task_config)

-def include_task_folder(task_dir):
+def include_task_folder(task_dir: str) -> None:
     """
     Calling this function
     """

@@ -136,6 +136,9 @@ def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs):
     task_name_from_config_dict = {}
     task_name_from_object_dict = {}

+    if type(task_name_list) != list:
+        task_name_list = [task_name_list]
+
     for task_element in task_name_list:
         if isinstance(task_element, str):

@@ -143,12 +146,20 @@ def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs):
             group_name = task_element
             for task_name in GROUP_REGISTRY[task_element]:
                 if task_name not in task_name_from_registry_dict:
+                    task_obj = get_task_dict(task_name)
+                    if task_name in task_obj.keys():
+                        task_dict = {
+                            task_name: (group_name, task_obj[task_name]),
+                        }
+                    else:
+                        task_dict = {
+                            task_name: (group_name, None),
+                            **task_obj,
+                        }
+
                     task_name_from_registry_dict = {
                         **task_name_from_registry_dict,
-                        task_name: (
-                            group_name,
-                            get_task(task_name=task_name, config=config),
-                        ),
+                        **task_dict,
                     }
         else:
             task_name = task_element
......
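To make the group handling above concrete, a hypothetical illustration of the shape `get_task_dict` now returns when given a group name: member tasks are keyed by their own names and paired with the group, while a nested sub-group that carries no task object of its own maps to `(group_name, None)`:

```python
# Hypothetical names standing in for real registered tasks and groups.
group_name = "my_suite"
member_tasks = {"task_a": "<task_a object>", "task_b": "<task_b object>"}

# Shape of the entries get_task_dict() builds for a group element:
task_dict = {name: (group_name, obj) for name, obj in member_tasks.items()}
# A nested sub-group would instead appear as
# {"sub_group": ("my_suite", None), **<entries for the sub-group's own tasks>}.

for task_name, entry in task_dict.items():
    group, task = entry  # the unpacking evaluate() performs; None entries are skipped there
    print(task_name, group, task)
```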
task: asdiv
dataset_path: EleutherAI/asdiv
output_type: loglikelihood
validation_split: validation
doc_to_text: "{{body}}\nQuestion:{{question}}\nAnswer:"
doc_to_target: "{{answer.split(' (')[0]}}"
should_decontaminate: true
doc_to_decontamination_query: "{{body}} {{question}}"
metric_list:
- metric: acc
aggregation: mean
higher_is_better: true
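The `doc_to_target` template above keeps only the value in front of the bracketed unit that ASDiv answer fields carry; a tiny illustration with a hypothetical answer string:

```python
# Hypothetical ASDiv-style answer field; {{answer.split(' (')[0]}} drops the unit.
answer = "72 (packs)"
print(answer.split(" (")[0])  # -> 72
```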
# CoQA
### Paper
Title: `CoQA: A Conversational Question Answering Challenge`
Abstract: https://arxiv.org/pdf/1808.07042.pdf
CoQA is a large-scale dataset for building Conversational Question Answering
systems. The goal of the CoQA challenge is to measure the ability of machines to
understand a text passage and answer a series of interconnected questions that
appear in a conversation.
Homepage: https://stanfordnlp.github.io/coqa/
### Citation
```
BibTeX-formatted citation goes here
```
### Groups and Tasks
#### Groups
* Not part of a group yet
#### Tasks
* `coqa`
### Checklist
For adding novel benchmarks/datasets to the library:
* [ ] Is the task an existing benchmark in the literature?
* [ ] Have you referenced the original paper that introduced the task?
* [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?
If other tasks on this dataset are already supported:
* [ ] Is the "Main" variant of this task clearly denoted?
* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
task: coqa
dataset_path: EleutherAI/coqa
output_type: greedy_until
training_split: train
validation_split: validation
doc_to_text: !function utils.doc_to_text
doc_to_target: !function utils.doc_to_target
process_results: !function utils.process_results
should_decontaminate: true
doc_to_decontamination_query: "{{story}} {{question.input_text|join('\n')}}"
generation_kwargs:
until:
- "\nQ:"
metric_list:
- metric: em
aggregation: mean
higher_is_better: true
- metric: f1
aggregation: mean
higher_is_better: true
from itertools import zip_longest

import transformers.data.metrics.squad_metrics as squad_metrics


def doc_to_text(doc):
    # Given a passage p, the conversation history {q1, a1, . . . qi−1, ai−1}
    # and a question qi, the task is to predict the answer ai
    doc_text = doc["story"] + "\n\n"
    for (q, a) in zip_longest(
        doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]
    ):  # omit target answer ai
        question = f"Q: {q}\n\n"
        answer = f"A: {a}\n\n" if a is not None else "A:"
        doc_text += question + answer
    return doc_text


def doc_to_target(doc):
    turn_id = len(doc["questions"]["input_text"])
    # Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers).
    answers = []
    answer_forturn = doc["answers"]["input_text"][turn_id - 1]
    answers.append(answer_forturn)

    additional_answers = doc.get("additional_answers")
    if additional_answers:
        for key in additional_answers:
            additional_answer_for_turn = additional_answers[key]["input_text"][
                turn_id - 1
            ]
            if additional_answer_for_turn.lower() not in map(str.lower, answers):
                answers.append(additional_answer_for_turn)
    return answers


def em(gold_list, pred):
    # tests for exact match and on the normalised answer (compute_exact)
    em_sum = 0.0
    if len(gold_list) > 1:
        for i in range(len(gold_list)):
            gold_answers = gold_list[0:i] + gold_list[i + 1 :]
            # predictions compared against (n) golds and take maximum
            em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_answers)
    else:
        em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list)

    return em_sum / max(1, len(gold_list))


def compute_scores(gold_list, pred):
    # tests for exact match and on the normalised answer (compute_exact)
    # test for overlap (compute_f1)
    f1_sum = 0.0
    em_sum = 0.0
    if len(gold_list) > 1:
        for i in range(len(gold_list)):
            gold_answers = gold_list[0:i] + gold_list[i + 1 :]
            # predictions compared against (n) golds and take maximum
            em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_answers)
            f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers)
    else:
        em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list)
        f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_list)

    return {
        "em": em_sum / max(1, len(gold_list)),
        "f1": f1_sum / max(1, len(gold_list)),
    }


def process_results(doc, results):
    gold_list = doc_to_target(doc)
    pred = results[0].strip().split("\n")[0]

    scores = compute_scores(gold_list, pred)
    return scores
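A quick sanity check of `compute_scores` above, assuming the functions in this file are in scope (hypothetical answer strings; `squad_metrics` comes from `transformers` as imported at the top):

```python
# Exact match against a single gold answer: both metrics are 1.0.
print(compute_scores(["Ragnar"], "Ragnar"))            # {'em': 1.0, 'f1': 1.0}

# Partial token overlap: EM stays 0.0 while F1 is fractional (~0.67 here).
print(compute_scores(["Ragnar Lothbrok"], "Ragnar"))   # {'em': 0.0, 'f1': 0.666...}
```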
# DROP
### Paper
Title: `DROP: A Reading Comprehension Benchmark Requiring Discrete Reasoning Over Paragraphs`
Abstract: https://aclanthology.org/attachments/N19-1246.Supplementary.pdf
DROP is a QA dataset which tests comprehensive understanding of paragraphs. In
this crowdsourced, adversarially-created, 96k question-answering benchmark, a
system must resolve multiple references in a question, map them onto a paragraph,
and perform discrete operations over them (such as addition, counting, or sorting).
Homepage: https://allenai.org/data/drop
Acknowledgement: This implementation is based on the official evaluation for `DROP`:
https://github.com/allenai/allennlp-reading-comprehension/blob/master/allennlp_rc/eval/drop_eval.py
### Citation
```
@misc{dua2019drop,
title={DROP: A Reading Comprehension Benchmark Requiring Discrete Reasoning Over Paragraphs},
author={Dheeru Dua and Yizhong Wang and Pradeep Dasigi and Gabriel Stanovsky and Sameer Singh and Matt Gardner},
year={2019},
eprint={1903.00161},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```
### Groups and Tasks
#### Groups
* Not part of a group yet.
#### Tasks
* `drop`
### Checklist
For adding novel benchmarks/datasets to the library:
* [ ] Is the task an existing benchmark in the literature?
* [ ] Have you referenced the original paper that introduced the task?
* [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?
If other tasks on this dataset are already supported:
* [ ] Is the "Main" variant of this task clearly denoted?
* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [ ] Have you noted which, if any, published evaluation setups are matched by this variant?