Unverified Commit 761f0087 authored by Lintang Sutawika, committed by GitHub

Merge pull request #560 from EleutherAI/dataset-metric-log

Dataset metric log [WIP]
parents 232632c6 ae4d9ed2
......@@ -25,11 +25,11 @@ graph LR;
I[Input]
F[Filter]
M[Model]
O[Ouput]:::empty
O[Output]:::empty
P[Prompt]
Me[Metric]
R[Result]
T --- I:::empty
P --- I
I --> M
......
......@@ -6,96 +6,95 @@ import sacrebleu
import sklearn.metrics
import random
import evaluate
AGGREGATION_REGISTRY = {}
METRIC_REGISTRY = {
"acc": None,
"acc_norm": None,
"acc_mutual_info": None,
"word_perplexity": None,
"byte_perplexity": None,
}
HIGHER_IS_BETTER_REGISTRY = {
"matthews_corrcoef": True,
"f1_score": True,
"perplexity": False,
"bleu": True,
"chrf": True,
"ter": False,
"acc": True,
"acc_norm": True,
"acc_mutual_info": True,
"word_perplexity": False,
"byte_perplexity": False,
"bits_per_byte": False,
}
def register_metric(name):
# TODO: do we want to enforce a certain interface to registered metrics?
def decorate(fn):
assert (
name not in METRIC_REGISTRY
), f"metric named '{name}' conflicts with existing registered metric!"
METRIC_REGISTRY[name] = fn
return fn
return decorate
def get_metric(name):
try:
return METRIC_REGISTRY[name]
except KeyError:
# TODO: change this print to logging?
print(
f"Could not find registered metric '{name}' in lm-eval, \
searching in HF Evaluate library..."
)
try:
metric_object = evaluate.load(name)
return metric_object.compute
except Exception:
raise Warning(
"{} not found in the evaluate library!".format(name),
"Please check https://huggingface.co/evaluate-metric",
)
def register_aggregation(name):
# TODO: should we enforce a specific interface to aggregation metrics?
def decorate(fn):
assert (
name not in AGGREGATION_REGISTRY
), f"aggregation named '{name}' conflicts with existing registered aggregation!"
AGGREGATION_REGISTRY[name] = fn
return fn
return decorate
def get_aggregation(name):
try:
return AGGREGATION_REGISTRY[name]
except KeyError:
raise Warning(
"{} not a registered aggregation metric!".format(name),
)
from lm_eval.api.registry import register_metric, register_aggregation
# Register Aggregations First
@register_aggregation("mean")
def mean(arr):
return sum(arr) / len(arr)
@register_aggregation("median")
def median(arr):
return arr[len(arr) // 2]
@register_aggregation("perplexity")
def perplexity(items):
return math.exp(-mean(items))
@register_aggregation("weighted_perplexity")
def weighted_perplexity(items):
return math.exp(-weighted_mean(items))
@register_aggregation("bits_per_byte")
def bits_per_byte(items):
return -weighted_mean(items) / math.log(2)
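# A minimal illustrative sketch of how these aggregations combine per-document
# values (all numbers below are hypothetical; `math` and `weighted_mean` come
# from this module):
#
#     _lls = [-1.2, -0.7, -2.3]                # per-document loglikelihoods
#     math.exp(-mean(_lls))                    # "perplexity" aggregation
#
#     _pairs = [(-35.0, 12), (-18.0, 7)]       # (loglikelihood, word/byte count) pairs
#     math.exp(-weighted_mean(_pairs))         # "weighted_perplexity" aggregation
#     -weighted_mean(_pairs) / math.log(2)     # "bits_per_byte" aggregation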
@register_metric(
metric="acc",
higher_is_better=True,
output_type=["loglikelihood", "multiple_choice"],
aggregation="mean",
)
def acc_fn(items): # This is a passthrough function
return items
@register_metric(
metric="acc_mutual_info",
higher_is_better=True,
output_type="multiple_choice",
aggregation="mean",
)
def acc_mutual_info_fn(items): # This is a passthrough function
return items
@register_metric(
metric="perplexity",
higher_is_better=False,
output_type="loglikelihood",
aggregation="perplexity",
)
def perplexity_fn(items): # This is a passthrough function
return items
@register_metric(
metric="word_perplexity",
higher_is_better=False,
output_type="loglikelihood_rolling",
aggregation="weighted_perplexity",
)
def word_perplexity_fn(items): # This is a passthrough function
return items
@register_metric(
metric="byte_perplexity",
higher_is_better=False,
output_type="loglikelihood_rolling",
aggregation="weighted_perplexity",
)
def byte_perplexity_fn(items): # This is a passthrough function
return items
@register_metric(
metric="bits_per_byte",
higher_is_better=False,
output_type="loglikelihood_rolling",
aggregation="bits_per_byte",
)
def bits_per_byte_fn(items): # This is a passthrough function
return items
def pop_stddev(arr):
mu = mean(arr)
return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))
......@@ -110,12 +109,7 @@ def mean_stderr(arr):
return sample_stddev(arr) / math.sqrt(len(arr))
@register_aggregation("median")
def median(arr):
return arr[len(arr) // 2]
@register_metric("matthews_corrcoef")
@register_metric(metric="matthews_corrcoef", higher_is_better=True, aggregation="mean")
def matthews_corrcoef(items):
unzipped_list = list(zip(*items))
golds = unzipped_list[0]
......@@ -123,7 +117,12 @@ def matthews_corrcoef(items):
return sklearn.metrics.matthews_corrcoef(golds, preds)
@register_metric("f1_score")
@register_metric(
metric="f1",
higher_is_better=True,
output_type="multiple_choice",
aggregation="mean",
)
def f1_score(items):
unzipped_list = list(zip(*items))
golds = unzipped_list[0]
......@@ -133,6 +132,12 @@ def f1_score(items):
return np.max(fscore)
@register_metric(
metric="acc_all",
higher_is_better=True,
output_type="loglikelihood",
aggregation="mean",
)
def acc_all(items):
# Only count as correct if all answers are labeled correctly for each question
question_scoring_dict = {}
......@@ -179,30 +184,12 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
return max(scores_for_ground_truths)
@register_metric("perplexity")
@register_aggregation("perplexity")
def perplexity(items):
return math.exp(-mean(items))
def weighted_mean(items):
a, b = zip(*items)
return sum(a) / sum(b)
@register_metric("weighted_perplexity")
@register_aggregation("weighted_perplexity")
def weighted_perplexity(items):
return math.exp(-weighted_mean(items))
@register_metric("bits_per_byte")
@register_aggregation("bits_per_byte")
def bits_per_byte(items):
return -weighted_mean(items) / math.log(2)
@register_metric("bleu")
@register_metric(metric="bleu", higher_is_better=True, aggregation="mean")
def bleu(items):
"""The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
for evaluating a generated sentence to a reference sentence. It counts matching
......@@ -220,7 +207,7 @@ def bleu(items):
return sacrebleu.corpus_bleu(preds, refs).score
@register_metric("chrf")
@register_metric(metric="chrf", higher_is_better=True, aggregation="mean")
def chrf(items):
"""chrF++ is a tool for automatic evaluation of machine translation output
based on character n-gram precision and recall enhanced with word n-grams.
......@@ -235,7 +222,7 @@ def chrf(items):
return sacrebleu.corpus_chrf(preds, refs).score
@register_metric("ter")
@register_metric(metric="ter", higher_is_better=True, aggregation="mean")
def ter(items):
"""Translation Error Rate is an error metric for machine translation that
measures the number of edits required to change a system output into one
......
......@@ -4,32 +4,6 @@ from typing import Union
from lm_eval import utils
MODEL_REGISTRY = {}
def register_model(*names):
# either pass a list or a single alias.
# function receives them as a tuple of strings
def decorate(cls):
for name in names:
assert issubclass(
cls, LM
), f"Model '{name}' ({cls.__name__}) must extend LM class"
assert (
name not in MODEL_REGISTRY
), f"Model named '{name}' conflicts with existing model! Please register with a non-conflicting alias instead."
MODEL_REGISTRY[name] = cls
return cls
return decorate
def get_model(model_name):
return MODEL_REGISTRY[model_name]
class LM(abc.ABC):
def __init__(self):
......
import os
task_registry = {}
group_registry = {}
task2func_index = {}
func2task_index = {}
def register_task(name):
def wrapper(func):
task_registry[name] = func
func2task_index[func.__name__] = name
task2func_index[name] = func.__name__
return func
return wrapper
def register_group(name):
def wrapper(func):
func_name = func2task_index[func.__name__]
if name in group_registry:
group_registry[name].append(func_name)
else:
group_registry[name] = [func_name]
return func
return wrapper
import os
import evaluate
from lm_eval.api.model import LM
MODEL_REGISTRY = {}
def register_model(*names):
# either pass a list or a single alias.
# function receives them as a tuple of strings
def decorate(cls):
for name in names:
assert issubclass(
cls, LM
), f"Model '{name}' ({cls.__name__}) must extend LM class"
assert (
name not in MODEL_REGISTRY
), f"Model named '{name}' conflicts with existing model! Please register with a non-conflicting alias instead."
MODEL_REGISTRY[name] = cls
return cls
return decorate
def get_model(model_name):
return MODEL_REGISTRY[model_name]
TASK_REGISTRY = {}
GROUP_REGISTRY = {}
func2task_index = {}
def register_task(name):
def decorate(fn):
assert (
name not in TASK_REGISTRY
), f"task named '{name}' conflicts with existing registered task!"
TASK_REGISTRY[name] = fn
func2task_index[fn.__name__] = name
return fn
return decorate
def register_group(name):
def decorate(fn):
# assert (
# name not in GROUP_REGISTRY
# ), f"group named '{name}' conflicts with existing registered group!"
func_name = func2task_index[fn.__name__]
if name in GROUP_REGISTRY:
GROUP_REGISTRY[name].append(func_name)
else:
GROUP_REGISTRY[name] = [func_name]
return fn
return decorate
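# A minimal illustrative sketch of the task/group decorators (hypothetical names;
# decorators apply bottom-up, so the task is registered before the group):
#
#     @register_group("demo_group")
#     @register_task("demo_task")
#     class DemoTask(ConfigurableTask):
#         ...
#
#     # afterwards: TASK_REGISTRY["demo_task"] is DemoTask
#     #             GROUP_REGISTRY["demo_group"] == ["demo_task"]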
AGGREGATION_REGISTRY = {}
DEFAULT_AGGREGATION_REGISTRY = {}
METRIC_REGISTRY = {}
OUTPUT_TYPE_REGISTRY = {}
HIGHER_IS_BETTER_REGISTRY = {}
DEFAULT_METRIC_REGISTRY = {
"loglikelihood": [
"perplexity",
"acc",
],
"loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
"multiple_choice": [
"acc",
],
"greedy_until": ["exact_match"],
}
def register_metric(**args):
# TODO: do we want to enforce a certain interface to registered metrics?
def decorate(fn):
assert "metric" in args
name = args["metric"]
for key, registry in [
("metric", METRIC_REGISTRY),
("higher_is_better", HIGHER_IS_BETTER_REGISTRY),
# ("output_type", OUTPUT_TYPE_REGISTRY),
("aggregation", DEFAULT_AGGREGATION_REGISTRY),
]:
if key in args:
value = args[key]
assert (
value not in registry
), f"{key} named '{value}' conflicts with existing registered {key}!"
if key == "metric":
registry[name] = fn
elif key == "aggregation":
registry[name] = AGGREGATION_REGISTRY[value]
else:
registry[name] = value
return fn
return decorate
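# A minimal illustrative sketch of the keyword-style decorator ("demo_match" is a
# hypothetical metric name; it assumes a "mean" aggregation has already been
# registered via register_aggregation("mean")):
#
#     @register_metric(
#         metric="demo_match",
#         higher_is_better=True,
#         output_type="greedy_until",
#         aggregation="mean",
#     )
#     def demo_match_fn(items):  # passthrough; values are averaged at aggregation time
#         return items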
def get_metric(name):
try:
return METRIC_REGISTRY[name]
except KeyError:
# TODO: change this print to logging?
print(
f"Could not find registered metric '{name}' in lm-eval, \
searching in HF Evaluate library..."
)
try:
metric_object = evaluate.load(name)
return metric_object.compute
except Exception:
raise Warning(
"{} not found in the evaluate library!".format(name),
"Please check https://huggingface.co/evaluate-metric",
)
def register_aggregation(name):
# TODO: should we enforce a specific interface to aggregation metrics?
def decorate(fn):
assert (
name not in AGGREGATION_REGISTRY
), f"aggregation named '{name}' conflicts with existing registered aggregation!"
AGGREGATION_REGISTRY[name] = fn
return fn
return decorate
def get_aggregation(name):
try:
return AGGREGATION_REGISTRY[name]
except KeyError:
raise Warning(
"{} not a registered aggregation metric!".format(name),
)
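# A minimal illustrative sketch of the lookup helpers (assumes the built-in metrics
# and aggregations have been registered, and that the `evaluate` package can reach
# its hub or a local cache):
#
#     get_aggregation("mean")      # -> the registered mean() aggregation
#     get_metric("acc")            # -> the registered passthrough acc metric
#     get_metric("exact_match")    # -> falls back to evaluate.load("exact_match").compute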
......@@ -18,20 +18,33 @@ from collections.abc import Callable
from lm_eval import utils
from lm_eval.api import samplers
from lm_eval.api.instance import Instance
from lm_eval.api.filter import FilterEnsemble
from lm_eval.logger import eval_logger
from lm_eval.prompts import get_prompt
from lm_eval.filters import build_filter_ensemble
from lm_eval.api.metrics import (
METRIC_REGISTRY,
AGGREGATION_REGISTRY,
HIGHER_IS_BETTER_REGISTRY,
get_metric,
get_aggregation,
# get_metric,
# get_aggregation,
mean,
weighted_perplexity,
bits_per_byte,
)
from lm_eval.api.registry import (
METRIC_REGISTRY,
DEFAULT_METRIC_REGISTRY,
OUTPUT_TYPE_REGISTRY,
AGGREGATION_REGISTRY,
HIGHER_IS_BETTER_REGISTRY,
DEFAULT_AGGREGATION_REGISTRY,
)
from lm_eval.logger import eval_logger
from lm_eval.prompts import get_prompt
from lm_eval.filters import build_filter_ensemble
ALL_OUTPUT_TYPES = [
"loglikelihood",
"multiple_choice",
"loglikelihood_rolling",
"greedy_until",
]
@dataclass
......@@ -43,15 +56,16 @@ class TaskConfig(dict):
task_name: str = (
None # TODO: deprecate this, it'll be set in __post_init__ to be names[0]
)
base_task: str = None
dataset_path: str = None
dataset_name: str = None
dataset_kwargs: dict = None
training_split: str = None
validation_split: str = None
test_split: str = None
fewshot_split: str = None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
template_aliases: str = None
aliases: Union[str, list] = None
doc_to_text: Union[Callable, str] = None
doc_to_target: Union[Callable, str] = None
......@@ -87,7 +101,7 @@ class TaskConfig(dict):
if type(self.gold_alias) == str:
self.gold_alias = self.template_aliases + self.doc_to_target
if not self.generation_kwargs:
# ensure that we generate greedily in the absence of explicit generation arguments
self.generation_kwargs = {"do_sample": False, "temperature": 0.0}
......@@ -439,8 +453,12 @@ class Task(abc.ABC):
def apply_filters(self):
for f in self._filters:
f.apply(self._instances)
if hasattr(self, "_filters"):
for f in self._filters:
f.apply(self._instances)
else:
eval_logger.warning("No filter defined, passing through instances")
return self._instances
class ConfigurableTask(Task):
......@@ -469,6 +487,7 @@ class ConfigurableTask(Task):
)
if self._config.output_type is not None:
assert self._config.output_type in ALL_OUTPUT_TYPES
self.OUTPUT_TYPE = self._config.output_type
if self._config.dataset_path is not None:
......@@ -477,35 +496,42 @@ class ConfigurableTask(Task):
if self._config.dataset_name is not None:
self.DATASET_NAME = self._config.dataset_name
if self._config.metric_list is not None:
self._metric_list = {}
self._metric_kwargs = {}
self._aggregation_list = {}
self._higher_is_better = {}
for metric_config in self._config.metric_list:
self._metric_fn_list = {}
self._metric_fn_kwargs = {}
self._aggregation_list = {}
self._higher_is_better = {}
_metric_list = DEFAULT_METRIC_REGISTRY[self._config.output_type]
if self._config.metric_list is None:
for metric_name in _metric_list:
self._metric_fn_list[metric_name] = METRIC_REGISTRY[metric_name]
self._aggregation_list[metric_name] = DEFAULT_AGGREGATION_REGISTRY[
metric_name
]
self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[
metric_name
]
else:
for metric_config in self._config.metric_list:
assert "metric" in metric_config
metric_name = metric_config["metric"]
aggregation = metric_config["aggregation"]
higher_is_better = metric_config["higher_is_better"]
kwargs = {
key: metric_config[key]
for key in metric_config
if key not in ["metric", "aggregation", "higher_is_better"]
}
self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[aggregation]
if metric_name in METRIC_REGISTRY.keys():
self._metric_list[metric_name] = METRIC_REGISTRY[metric_name]
self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[
metric_name
]
if metric_name in _metric_list:
self._metric_fn_list[metric_name] = METRIC_REGISTRY[metric_name]
else:
self._higher_is_better[metric_name] = higher_is_better
eval_logger.warning(
f"Metric {metric_name} not found, "
"Searching from https://huggingface.co/evaluate-metric"
)
try:
metric_object = evaluate.load(metric_name)
self._metric_list[metric_name] = metric_object
self._metric_kwargs[metric_name] = kwargs
self._metric_fn_list[metric_name] = metric_object
self._metric_fn_kwargs[metric_name] = kwargs
except Exception:
raise Warning(
......@@ -513,12 +539,36 @@ class ConfigurableTask(Task):
"Please check https://huggingface.co/evaluate-metric",
)
self.download(data_dir, cache_dir, download_mode)
if "aggregation" in metric_config:
self._aggregation_list[metric_name] = metric_config["aggregation"]
else:
eval_logger.warning(
f"metric {metric_name} is defined, but aggregation is not"
f"using default aggregation for {metric_name}"
)
self._aggregation_list[metric_name] = DEFAULT_AGGREGATION_REGISTRY[
metric_name
]
if "higher_is_better" in metric_config:
self._higher_is_better[metric_name] = metric_config[
"higher_is_better"
]
else:
eval_logger.warning(
f"metric {metric_name} is defined, but higher_is_better is not"
f"using default higher_is_better for {metric_name}"
)
self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[
metric_name
]
self.download(self._config.dataset_kwargs)
self._training_docs = None
self._fewshot_docs = None
self._filters = []
if self._config.filter_list is not None:
self._filters = []
for filter_config in self._config.filter_list:
for filter_pipeline in filter_config:
filter_name = filter_config["name"]
......@@ -530,7 +580,7 @@ class ConfigurableTask(Task):
}
components.append([function["function"], kwargs])
filter_pipeline = build_filter_ensemble(filter_name, components)
filter_pipeline = build_filter_ensemble(filter_name, components)
self._filters.append(filter_pipeline)
else:
self._filters = [
......@@ -550,6 +600,14 @@ class ConfigurableTask(Task):
list(self.fewshot_docs()), self, rnd=random.Random()
) # TODO: pass the correct docs in here
def download(self, dataset_kwargs=None):
self.dataset = datasets.load_dataset(
path=self.DATASET_PATH,
name=self.DATASET_NAME,
**dataset_kwargs if dataset_kwargs is not None else {},
)
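# A minimal illustrative sketch of how dataset_kwargs flows through (values are
# hypothetical): a task config with
#     dataset_kwargs: {"data_dir": "data/demo"}
# makes this call roughly equivalent to
#     datasets.load_dataset(path=self.DATASET_PATH, name=self.DATASET_NAME, data_dir="data/demo")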
def has_training_docs(self):
if self._config.training_split is not None:
return True
......@@ -643,7 +701,7 @@ class ConfigurableTask(Task):
raise TypeError
def gold_alias(self, doc):
# TODO: reevaluate if we need this. implemented to have a
# TODO: reevaluate if we need this. implemented to have a
# processed version of answer to put into gsm8k exact_match scoring as ref.
if self._config.gold_alias:
doc_to_target = self._config.gold_alias
......@@ -684,7 +742,7 @@ class ConfigurableTask(Task):
for i, choice in enumerate(choices)
]
# TODO: we should raise a warning telling users this will increase runtime by up to ~2x.
if "acc_mutual_info" in self._metric_list.keys():
if "acc_mutual_info" in self._metric_fn_list.keys():
# if we are calculating multiple choice accuracy
# using mutual information instead of raw loglikelihood as metric, need unconditional lls.
......@@ -714,25 +772,44 @@ class ConfigurableTask(Task):
def process_results(self, doc, results):
# if callable(self._config.process_results):
# return self._config.process_results(doc, results)
result_dict = {}
use_metric = list(self._metric_fn_list.keys())
if self.OUTPUT_TYPE == "loglikelihood":
results = results[0]
ll, is_greedy = results
result_dict = {"perplexity": ll, "accuracy": int(is_greedy)}
return {
**({"perplexity": ll} if "perplexity" in use_metric else {}),
**({"acc": int(is_greedy)} if "acc" in use_metric else {}),
}
elif self.OUTPUT_TYPE == "loglikelihood_rolling":
(loglikelihood,) = results
words = self.count_words(self.doc_to_target(doc))
bytes_ = self.count_bytes(self.doc_to_target(doc))
_words = self.count_words(self.doc_to_target(doc))
_bytes = self.count_bytes(self.doc_to_target(doc))
return {
"word_perplexity": (loglikelihood, words),
"byte_perplexity": (loglikelihood, bytes_),
"bits_per_byte": (loglikelihood, bytes_),
**(
{"word_perplexity": (loglikelihood, _words)}
if "word_perplexity" in use_metric
else {}
),
**(
{"byte_perplexity": (loglikelihood, _bytes)}
if "byte_perplexity" in use_metric
else {}
),
**(
{"bits_per_byte": (loglikelihood, _bytes)}
if "bits_per_byte" in use_metric
else {}
),
}
elif self.OUTPUT_TYPE == "multiple_choice":
lls = [
res[0] for res in results
] # only retain loglikelihoods, discard is_greedy
lls, is_greedy = zip(*results)
gold = int(self.doc_to_target(doc))
pred = np.argmax(lls)
# retrieve choices in List[str] form, to compute choice lengths, etc.
choices = ast.literal_eval(
utils.apply_template(
......@@ -755,21 +832,18 @@ class ConfigurableTask(Task):
acc_norm = 1.0 if np.argmax(lls / completion_len) == gold else 0.0
result_dict = {
"acc": acc,
"acc_norm": acc_norm,
**({"acc": acc} if "acc" in use_metric else {}),
**({"f1": (pred, gold)} if "f1" in use_metric else {}),
**({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
}
# TODO: set which normalization metrics should be reported, and calculate them
if "exact_match" in self._metric_list.keys():
if "exact_match" in self._metric_fn_list.keys():
# TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly
is_greedy = [
res[1] for res in results
] # take only the `is_greedy` results
is_greedy = is_greedy[gold] # take value for the gold answer
result_dict["exact_match"] = int(is_greedy)
if "acc_mutual_info" in self._metric_list.keys():
if "acc_mutual_info" in use_metric:
lls_mutual_info = [
ll_c - ll_u for ll_c, ll_u in zip(lls, lls_unconditional)
]
......@@ -783,16 +857,16 @@ class ConfigurableTask(Task):
else:
gold = self.doc_to_target(doc)
for key, result in zip(self._metric_list.keys(), results):
_dict = self._metric_list[key].compute(
for key, result in zip(self._metric_fn_list.keys(), results):
_dict = self._metric_fn_list[key].compute(
references=[gold], predictions=[result], **self._metric_fn_kwargs[key]
)
result_dict[key] = _dict[key]
result_dict = {**result_dict, **_dict}
else:
raise ValueError(
f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
"'loglikelihood', 'loglikelihood_rolling', 'greedy_until'",
"'loglikelihood', 'loglikelihood_rolling', 'greedy_until', or 'multiple_choice'",
)
return result_dict
......
......@@ -7,10 +7,10 @@ import torch
import numpy as np
import lm_eval.api
import lm_eval.api.metrics
import lm_eval.tasks
import lm_eval.models
import lm_eval.api.metrics
import lm_eval.api.registry
from lm_eval.utils import (
positional_deprecated,
......@@ -72,7 +72,7 @@ def simple_evaluate(
if isinstance(model, str):
if model_args is None:
model_args = ""
lm = lm_eval.api.model.get_model(model).create_from_arg_string(
lm = lm_eval.api.registry.get_model(model).create_from_arg_string(
model_args, {"batch_size": batch_size, "device": device}
)
else:
......@@ -274,9 +274,7 @@ def evaluate(
# aggregate results ; run bootstrap CIs
for (task_name, key, metric), items in vals.items():
task = task_dict[task_name]
results[task_name][metric + "," + key] = task.aggregation()[
metric
](items)
results[task_name][metric + "," + key] = task.aggregation()[metric](items)
# hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
# so we run them less iterations. still looking for a cleaner way to do this
......@@ -289,9 +287,7 @@ def evaluate(
)
if stderr is not None:
results[task_name][metric + "_stderr" + "," + key] = stderr(
items
)
results[task_name][metric + "_stderr" + "," + key] = stderr(items)
return {"results": dict(results), "versions": dict(versions)}
......
......@@ -6,7 +6,7 @@ from . import extraction
FILTER_REGISTRY = {
"take_first": selection.TakeFirstFilter,
"regex": extraction.RegexFilter,
"majority_vote": selection.MajorityVoteFilter,
"majority_vote": selection.MajorityVoteFilter,
"take_first_k": selection.TakeKFilter,
# TODO: implement this filter. either it should take in an arbitrary "scoring"/reward function
# that takes an input and returns a scalar and then should select the max reward,
......
......@@ -15,8 +15,8 @@ class TakeFirstFilter(Filter):
"""
return map(lambda r: r[0], resps)
class TakeKFilter(Filter):
class TakeKFilter(Filter):
def __init__(self, *args, **kwargs):
self.k = kwargs.pop("k")
......@@ -25,8 +25,10 @@ class TakeKFilter(Filter):
def apply(self, resps):
# check we have at least k responses per doc, else we can't take the first k
assert len(resps[0]) >= self.k, f"Need at least {self.k} responses per doc to take first {self.k}, but got {len(resps[0])} only! Please increase TaskConfig.repeats ."
return map(lambda r: r[:self.k], resps)
assert (
len(resps[0]) >= self.k
), f"Need at least {self.k} responses per doc to take first {self.k}, but got {len(resps[0])} only! Please increase TaskConfig.repeats ."
return map(lambda r: r[: self.k], resps)
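# A minimal illustrative sketch (hypothetical responses) of taking the first k=2:
#
#     list(TakeKFilter(k=2).apply([["a", "b", "c"], ["x", "y", "z"]]))
#     # -> [["a", "b"], ["x", "y"]]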
class MajorityVoteFilter(Filter):
......@@ -37,12 +39,13 @@ class MajorityVoteFilter(Filter):
def apply(self, resps):
"""
Each entry of `resps` is a list of model responses.
Each entry of `resps` is a list of model responses.
We select the response that occurs most frequently in each entry of `resps`.
"""
def select_majority(resp):
counts = Counter(resp)
vote = counts.most_common(1)[0][0]
vote = counts.most_common(1)[0][0]
return vote
return map(lambda r: [select_majority(r)], resps)
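# A minimal illustrative sketch (hypothetical responses) of majority voting:
#
#     list(MajorityVoteFilter().apply([["7", "7", "3"], ["2", "5", "5"]]))
#     # -> [["7"], ["5"]]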
import random
from lm_eval.api.model import LM, register_model
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
@register_model("dummy")
......
......@@ -6,7 +6,8 @@ import numpy as np
from tqdm import tqdm
from lm_eval import utils
from lm_eval.api.model import LM, register_model
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
def get_result(response, ctxlen):
......
......@@ -8,7 +8,8 @@ import torch.nn.functional as F
from lm_eval import utils
from lm_eval.logger import eval_logger
from lm_eval.api.model import LM, register_model
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from accelerate import Accelerator
from itertools import islice
......@@ -38,10 +39,10 @@ class HFLM(LM):
if device not in ["cuda", "cpu"]:
device = int(device)
self._device = torch.device(device)
print(f"Using device '{device}'")
eval_logger.info(f"Using device '{device}'")
else:
print("Device not specified")
print(f"Cuda Available? {torch.cuda.is_available()}")
eval_logger.info("Device not specified")
eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
self._device = (
torch.device("cuda")
if torch.cuda.is_available()
......@@ -75,13 +76,12 @@ class HFLM(LM):
if gpus > 1:
accelerator = Accelerator()
if gpus > accelerator.num_processes:
warning = (
eval_logger.warning(
"WARNING: The number of total system GPUs does not match the number of spawned processes. "
"If you would like to use data parallelism, please launch the script "
"with 'accelerate launch *script*'. "
f"Current run will proceed with {accelerator.num_processes} devices."
)
print(warning)
self._rank = accelerator.local_process_index
self._world_size = accelerator.num_processes
else:
......@@ -90,7 +90,7 @@ class HFLM(LM):
self.accelerator = accelerator
if self.accelerator.is_local_main_process:
print(f"Using {gpus} devices with data parallelism")
eval_logger.info(f"Using {gpus} devices with data parallelism")
self._rank = self.accelerator.local_process_index
self._world_size = self.accelerator.num_processes
......@@ -154,17 +154,26 @@ class HFLM(LM):
return self.model(inps)[0]
def _model_generate(self, context, max_length, eos_token_id, **generation_kwargs):
# we require users to pass do_sample=True explicitly
# we require users to pass do_sample=True explicitly
# for non-greedy gen. This should be reevaluated when considering beam search.
if "do_sample" not in generation_kwargs.keys():
generation_kwargs["do_sample"] = False
return self.model.generate(
context,
max_length=max_length,
pad_token_id=eos_token_id,
eos_token_id=eos_token_id,
**generation_kwargs,
)
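# if the model was prepared by accelerate (e.g. wrapped in DDP), unwrap it so the
# underlying HF model's .generate() is reachable; otherwise call .generate() directly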
if hasattr(self, "accelerator"):
return self.accelerator.unwrap_model(self.model).generate(
context,
max_length=max_length,
pad_token_id=eos_token_id,
eos_token_id=eos_token_id,
**generation_kwargs,
)
else:
return self.model.generate(
context,
max_length=max_length,
pad_token_id=eos_token_id,
eos_token_id=eos_token_id,
**generation_kwargs,
)
def loglikelihood(self, requests):
new_reqs = []
......@@ -354,7 +363,7 @@ class HFLM(LM):
for context, gen_kwargs in tqdm(re_ord.get_reordered()):
if isinstance(gen_kwargs, dict):
gen_kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
gen_kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
if "until" in gen_kwargs.keys():
until = gen_kwargs.pop("until")
if isinstance(until, str):
......@@ -362,9 +371,11 @@ class HFLM(LM):
elif not isinstance(until, list):
raise ValueError(
f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {until}"
)
)
else:
raise ValueError(f"Expected `gen_kwargs` to be of type `dict` but got {gen_kwargs}")
raise ValueError(
f"Expected `gen_kwargs` to be of type `dict` but got {gen_kwargs}"
)
if not until:
until = [self.tok_decode(self.eot_token_id)]
if "max_gen_toks" in gen_kwargs.keys():
......@@ -374,7 +385,7 @@ class HFLM(LM):
try:
(primary_until,) = self.tok_encode(until[0])
except:
except Exception:
# if our primary until would be multiple tokens long, we'll have errors.
# TODO: handling this better will let us stop generating earlier + often.
primary_until = self.eot_token_id
......@@ -384,8 +395,8 @@ class HFLM(LM):
).to(self.device)
cont = self._model_generate(
context=context_enc,
max_length=context_enc.shape[1] + max_gen_toks,
context=context_enc,
max_length=context_enc.shape[1] + max_gen_toks,
eos_token_id=primary_until,
**gen_kwargs,
)
......
......@@ -16,7 +16,8 @@ import os
import requests as _requests
import time
from tqdm import tqdm
from lm_eval.api.model import LM, register_model
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
logger = logging.getLogger(__name__)
......
......@@ -64,4 +64,4 @@ Tasks added in the revamped harness that were not previously available. Again, a
- [ ] Chain of Thought
- [ ] Self-consistency ; Least-to-Most prompting, etc.
- [ ] Summarization Tasks
- [ ] Anthropic Model-Written Evals
\ No newline at end of file
- [ ] Anthropic Model-Written Evals
......@@ -7,16 +7,16 @@ from .triviaqa import *
from lm_eval import utils
from lm_eval.logger import eval_logger
from lm_eval.api.task import TaskConfig, Task, ConfigurableTask
from lm_eval.api.register import (
from lm_eval.api.registry import (
register_task,
register_group,
task_registry,
group_registry,
TASK_REGISTRY,
GROUP_REGISTRY,
)
def get_task_name_from_config(task_config):
return "configurable_{dataset_path}_{dataset_name}".format(**task_config)
return "{dataset_path}_{dataset_name}".format(**task_config)
task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
......@@ -35,24 +35,20 @@ for root, subdirs, file_list in os.walk(task_dir):
)
if "task" in config:
task_name = "{}".format(
config["task"]
)
task_name = "{}".format(config["task"])
register_task(task_name)(SubClass)
if "group" in config:
for group in config["group"]:
register_group(group)(SubClass)
except Exception as e:
raise e
except Exception as error:
eval_logger.warning(
"Failed to load config in\n"
f" {yaml_path}\n"
" Config will not be added to registry"
f" Error: {error}"
)
TASK_REGISTRY = task_registry
GROUP_REGISTRY = group_registry
ALL_TASKS = sorted(list(TASK_REGISTRY.keys()) + list(GROUP_REGISTRY.keys()))
......
"""
Think you have Solved Question Answering? Try ARC, the AI2 Reasoning Challenge
https://arxiv.org/pdf/1803.05457.pdf
The ARC dataset consists of 7,787 science exam questions drawn from a variety
of sources, including science questions provided under license by a research
partner affiliated with AI2. These are text-only, English language exam questions
that span several grade levels as indicated in the files. Each question has a
multiple choice structure (typically 4 answer options). The questions are sorted
into a Challenge Set of 2,590 “hard” questions (those that both a retrieval and
a co-occurrence method fail to answer correctly) and an Easy Set of 5,197 questions.
Homepage: https://allenai.org/data/arc
"""
from lm_eval import utils
from lm_eval.prompts import get_prompt
from lm_eval.api.task import MultipleChoiceTask
from lm_eval.api.registry import register_task, register_group
_CITATION = """
@article{Clark2018ThinkYH,
title={Think you have Solved Question Answering? Try ARC, the AI2 Reasoning Challenge},
author={Peter Clark and Isaac Cowhey and Oren Etzioni and Tushar Khot and Ashish Sabharwal and Carissa Schoenick and Oyvind Tafjord},
journal={ArXiv},
year={2018},
volume={abs/1803.05457}
}
"""
@register_group("arc")
@register_task("arc_easy")
class ARCEasy(MultipleChoiceTask):
VERSION = "2.0"
DATASET_PATH = "ai2_arc"
DATASET_NAME = "ARC-Easy"
OUTPUT_TYPE = "loglikelihood"
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(map(self._process_doc, self.dataset["train"]))
return self._training_docs
def validation_docs(self):
return map(self._process_doc, self.dataset["validation"])
def test_docs(self):
return map(self._process_doc, self.dataset["test"])
def _process_doc(self, doc):
# NOTE: Some `doc["answerKey"]`s are in numeric string format being one
# of {'1', '2', '3', '4', '5'}. We map them back to letters.
num_to_letter = {"1": "A", "2": "B", "3": "C", "4": "D", "5": "E"}
doc["answerKey"] = num_to_letter.get(doc["answerKey"], doc["answerKey"])
out_doc = {
"id": doc["id"],
"question": doc["question"],
"choices": doc["choices"]["text"],
"gold": ["A", "B", "C", "D", "E"].index(doc["answerKey"]),
}
return out_doc
def doc_to_text(self, doc):
doc_to_text = get_prompt("qa-basic:question-newline-answer")
return utils.apply_template(doc_to_text, doc)
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["query"]
@register_group("arc")
@register_task("arc_challenge")
class ARCChallenge(ARCEasy):
DATASET_PATH = "ai2_arc"
DATASET_NAME = "ARC-Challenge"
......@@ -24,7 +24,7 @@ from lm_eval.api.instance import Instance
from lm_eval.prompts import get_prompt
from lm_eval.api.register import register_task, register_group
from lm_eval.api.registry import register_task, register_group
_CITATION = """
@misc{cobbe2021training,
......
......@@ -29,4 +29,4 @@ Homepage: https://github.com/openai/grade-school-math
archivePrefix={arXiv},
primaryClass={cs.LG}
}
```
\ No newline at end of file
```
......@@ -29,4 +29,4 @@ filter_list:
- function: "regex"
regex_pattern: "The answer is (\\-?[0-9\\.\\,]*[0-9]+)"
- function: "majority_vote"
- function: "take_first"
\ No newline at end of file
- function: "take_first"