Unverified Commit 761f0087 authored by Lintang Sutawika, committed by GitHub

Merge pull request #560 from EleutherAI/dataset-metric-log

Dataset metric log [WIP]
parents 232632c6 ae4d9ed2
...@@ -25,11 +25,11 @@ graph LR; ...@@ -25,11 +25,11 @@ graph LR;
I[Input] I[Input]
F[Filter] F[Filter]
M[Model] M[Model]
O[Ouput]:::empty O[Output]:::empty
P[Prompt] P[Prompt]
Me[Metric] Me[Metric]
R[Result] R[Result]
T --- I:::empty T --- I:::empty
P --- I P --- I
I --> M I --> M
......
...@@ -6,96 +6,95 @@ import sacrebleu ...@@ -6,96 +6,95 @@ import sacrebleu
import sklearn.metrics import sklearn.metrics
import random import random
import evaluate from lm_eval.api.registry import register_metric, register_aggregation
AGGREGATION_REGISTRY = {}
METRIC_REGISTRY = {
"acc": None,
"acc_norm": None,
"acc_mutual_info": None,
"word_perplexity": None,
"byte_perplexity": None,
}
HIGHER_IS_BETTER_REGISTRY = {
"matthews_corrcoef": True,
"f1_score": True,
"perplexity": False,
"bleu": True,
"chrf": True,
"ter": False,
"acc": True,
"acc_norm": True,
"acc_mutual_info": True,
"word_perplexity": False,
"byte_perplexity": False,
"bits_per_byte": False,
}
def register_metric(name):
# TODO: do we want to enforce a certain interface to registered metrics?
def decorate(fn):
assert (
name not in METRIC_REGISTRY
), f"metric named '{name}' conflicts with existing registered metric!"
METRIC_REGISTRY[name] = fn
return fn
return decorate
def get_metric(name):
try:
return METRIC_REGISTRY[name]
except KeyError:
# TODO: change this print to logging?
print(
f"Could not find registered metric '{name}' in lm-eval, \
searching in HF Evaluate library..."
)
try:
metric_object = evaluate.load(name)
return metric_object.compute
except Exception:
raise Warning(
"{} not found in the evaluate library!".format(name),
"Please check https://huggingface.co/evaluate-metric",
)
def register_aggregation(name):
# TODO: should we enforce a specific interface to aggregation metrics?
def decorate(fn):
assert (
name not in AGGREGATION_REGISTRY
), f"aggregation named '{name}' conflicts with existing registered aggregation!"
AGGREGATION_REGISTRY[name] = fn
return fn
return decorate
def get_aggregation(name):
try:
return AGGREGATION_REGISTRY[name]
except KeyError:
raise Warning(
"{} not a registered aggregation metric!".format(name),
)
# Register Aggregations First
@register_aggregation("mean") @register_aggregation("mean")
def mean(arr): def mean(arr):
return sum(arr) / len(arr) return sum(arr) / len(arr)
@register_aggregation("median")
def median(arr):
return arr[len(arr) // 2]
@register_aggregation("perplexity")
def perplexity(items):
return math.exp(-mean(items))
@register_aggregation("weighted_perplexity")
def weighted_perplexity(items):
return math.exp(-weighted_mean(items))
@register_aggregation("bits_per_byte")
def bits_per_byte(items):
return -weighted_mean(items) / math.log(2)
@register_metric(
metric="acc",
higher_is_better=True,
output_type=["loglikelihood", "multiple_choice"],
aggregation="mean",
)
def acc_fn(items): # This is a passthrough function
return items
@register_metric(
metric="acc_mutual_info",
higher_is_better=True,
output_type="multiple_choice",
aggregation="mean",
)
def acc_mutual_info_fn(items): # This is a passthrough function
return items
@register_metric(
metric="perplexity",
higher_is_better=False,
output_type="loglikelihood",
aggregation="perplexity",
)
def perplexity_fn(items): # This is a passthrough function
return items
@register_metric(
metric="word_perplexity",
higher_is_better=False,
output_type="loglikelihood_rolling",
aggregation="weighted_perplexity",
)
def word_perplexity_fn(items): # This is a passthrough function
return items
@register_metric(
metric="byte_perplexity",
higher_is_better=False,
output_type="loglikelihood_rolling",
aggregation="weighted_perplexity",
)
def byte_perplexity_fn(items): # This is a passthrough function
return items
@register_metric(
metric="bits_per_byte",
higher_is_better=False,
output_type="loglikelihood_rolling",
aggregation="bits_per_byte",
)
def bits_per_byte_fn(items): # This is a passthrough function
return items
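The metrics above are passthrough functions: they return per-document values untouched, and the aggregation registered alongside them does the actual reduction. A minimal sketch of that division of labor (the input values are invented for illustration):

```python
# Illustrative only: a passthrough metric returns per-document scores as-is;
# the registered aggregation (here "mean") reduces them afterwards.
per_doc_acc = acc_fn([1, 0, 1, 1])   # passthrough -> [1, 0, 1, 1]
print(mean(per_doc_acc))             # 0.75, the aggregated "acc" score
```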
def pop_stddev(arr): def pop_stddev(arr):
mu = mean(arr) mu = mean(arr)
return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr)) return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))
...@@ -110,12 +109,7 @@ def mean_stderr(arr): ...@@ -110,12 +109,7 @@ def mean_stderr(arr):
return sample_stddev(arr) / math.sqrt(len(arr)) return sample_stddev(arr) / math.sqrt(len(arr))
@register_aggregation("median") @register_metric(metric="matthews_corrcoef", higher_is_better=True, aggregation="mean")
def median(arr):
return arr[len(arr) // 2]
@register_metric("matthews_corrcoef")
def matthews_corrcoef(items): def matthews_corrcoef(items):
unzipped_list = list(zip(*items)) unzipped_list = list(zip(*items))
golds = unzipped_list[0] golds = unzipped_list[0]
...@@ -123,7 +117,12 @@ def matthews_corrcoef(items): ...@@ -123,7 +117,12 @@ def matthews_corrcoef(items):
return sklearn.metrics.matthews_corrcoef(golds, preds) return sklearn.metrics.matthews_corrcoef(golds, preds)
@register_metric("f1_score") @register_metric(
metric="f1",
higher_is_better=True,
output_type="multiple_choice",
aggregation="mean",
)
def f1_score(items): def f1_score(items):
unzipped_list = list(zip(*items)) unzipped_list = list(zip(*items))
golds = unzipped_list[0] golds = unzipped_list[0]
...@@ -133,6 +132,12 @@ def f1_score(items): ...@@ -133,6 +132,12 @@ def f1_score(items):
return np.max(fscore) return np.max(fscore)
@register_metric(
metric="acc_all",
higher_is_better=True,
output_type="loglikelihood",
aggregation="mean",
)
def acc_all(items): def acc_all(items):
# Only count as correct if all answers are labeled correctly for each question # Only count as correct if all answers are labeled correctly for each question
question_scoring_dict = {} question_scoring_dict = {}
...@@ -179,30 +184,12 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): ...@@ -179,30 +184,12 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
return max(scores_for_ground_truths) return max(scores_for_ground_truths)
@register_metric("perplexity")
@register_aggregation("perplexity")
def perplexity(items):
return math.exp(-mean(items))
def weighted_mean(items): def weighted_mean(items):
a, b = zip(*items) a, b = zip(*items)
return sum(a) / sum(b) return sum(a) / sum(b)
@register_metric("weighted_perplexity") @register_metric(metric="bleu", higher_is_better=True, aggregation="mean")
@register_aggregation("weighted_perplexity")
def weighted_perplexity(items):
return math.exp(-weighted_mean(items))
@register_metric("bits_per_byte")
@register_aggregation("bits_per_byte")
def bits_per_byte(items):
return -weighted_mean(items) / math.log(2)
@register_metric("bleu")
def bleu(items): def bleu(items):
"""The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
for evaluating a generated sentence to a reference sentence. It counts matching for evaluating a generated sentence to a reference sentence. It counts matching
...@@ -220,7 +207,7 @@ def bleu(items): ...@@ -220,7 +207,7 @@ def bleu(items):
return sacrebleu.corpus_bleu(preds, refs).score return sacrebleu.corpus_bleu(preds, refs).score
@register_metric("chrf") @register_metric(metric="chrf", higher_is_better=True, aggregation="mean")
def chrf(items): def chrf(items):
"""chrF++ is a tool for automatic evaluation of machine translation output """chrF++ is a tool for automatic evaluation of machine translation output
based on character n-gram precision and recall enhanced with word n-grams. based on character n-gram precision and recall enhanced with word n-grams.
...@@ -235,7 +222,7 @@ def chrf(items): ...@@ -235,7 +222,7 @@ def chrf(items):
return sacrebleu.corpus_chrf(preds, refs).score return sacrebleu.corpus_chrf(preds, refs).score
@register_metric("ter") @register_metric(metric="ter", higher_is_better=True, aggregation="mean")
def ter(items): def ter(items):
"""Translation Error Rate is an error metric for machine translation that """Translation Error Rate is an error metric for machine translation that
measures the number of edits required to change a system output into one measures the number of edits required to change a system output into one
......
...@@ -4,32 +4,6 @@ from typing import Union ...@@ -4,32 +4,6 @@ from typing import Union
from lm_eval import utils from lm_eval import utils
MODEL_REGISTRY = {}
def register_model(*names):
# either pass a list or a single alias.
# function receives them as a tuple of strings
def decorate(cls):
for name in names:
assert issubclass(
cls, LM
), f"Model '{name}' ({cls.__name__}) must extend LM class"
assert (
name not in MODEL_REGISTRY
), f"Model named '{name}' conflicts with existing model! Please register with a non-conflicting alias instead."
MODEL_REGISTRY[name] = cls
return cls
return decorate
def get_model(model_name):
return MODEL_REGISTRY[model_name]
class LM(abc.ABC): class LM(abc.ABC):
def __init__(self): def __init__(self):
......
import os
task_registry = {}
group_registry = {}
task2func_index = {}
func2task_index = {}
def register_task(name):
def wrapper(func):
task_registry[name] = func
func2task_index[func.__name__] = name
task2func_index[name] = func.__name__
return func
return wrapper
def register_group(name):
def wrapper(func):
func_name = func2task_index[func.__name__]
if name in group_registry:
group_registry[name].append(func_name)
else:
group_registry[name] = [func_name]
return func
return wrapper
import os
import evaluate
from lm_eval.api.model import LM
MODEL_REGISTRY = {}
def register_model(*names):
# either pass a list or a single alias.
# function receives them as a tuple of strings
def decorate(cls):
for name in names:
assert issubclass(
cls, LM
), f"Model '{name}' ({cls.__name__}) must extend LM class"
assert (
name not in MODEL_REGISTRY
), f"Model named '{name}' conflicts with existing model! Please register with a non-conflicting alias instead."
MODEL_REGISTRY[name] = cls
return cls
return decorate
def get_model(model_name):
return MODEL_REGISTRY[model_name]
TASK_REGISTRY = {}
GROUP_REGISTRY = {}
func2task_index = {}
def register_task(name):
def decorate(fn):
assert (
name not in TASK_REGISTRY
), f"task named '{name}' conflicts with existing registered task!"
TASK_REGISTRY[name] = fn
func2task_index[fn.__name__] = name
return fn
return decorate
def register_group(name):
def decorate(fn):
# assert (
# name not in GROUP_REGISTRY
# ), f"group named '{name}' conflicts with existing registered group!"
func_name = func2task_index[fn.__name__]
if name in GROUP_REGISTRY:
GROUP_REGISTRY[name].append(func_name)
else:
GROUP_REGISTRY[name] = [func_name]
return fn
return decorate
AGGREGATION_REGISTRY = {}
DEFAULT_AGGREGATION_REGISTRY = {}
METRIC_REGISTRY = {}
OUTPUT_TYPE_REGISTRY = {}
HIGHER_IS_BETTER_REGISTRY = {}
DEFAULT_METRIC_REGISTRY = {
"loglikelihood": [
"perplexity",
"acc",
],
"loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
"multiple_choice": [
"acc",
],
"greedy_until": ["exact_match"],
}
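A small sketch of how this fallback table is consulted when a task config omits an explicit metric_list (the output type below is just an example):

```python
# Illustrative only: tasks without a metric_list fall back to the defaults
# keyed by their output_type.
default_metrics = DEFAULT_METRIC_REGISTRY["multiple_choice"]  # -> ["acc"]
```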
def register_metric(**args):
# TODO: do we want to enforce a certain interface to registered metrics?
def decorate(fn):
assert "metric" in args
name = args["metric"]
for key, registry in [
("metric", METRIC_REGISTRY),
("higher_is_better", HIGHER_IS_BETTER_REGISTRY),
# ("output_type", OUTPUT_TYPE_REGISTRY),
("aggregation", DEFAULT_AGGREGATION_REGISTRY),
]:
if key in args:
value = args[key]
assert (
value not in registry
), f"{key} named '{value}' conflicts with existing registered {key}!"
if key == "metric":
registry[name] = fn
elif key == "aggregation":
registry[name] = AGGREGATION_REGISTRY[value]
else:
registry[name] = value
return fn
return decorate
def get_metric(name):
try:
return METRIC_REGISTRY[name]
except KeyError:
# TODO: change this print to logging?
print(
f"Could not find registered metric '{name}' in lm-eval, \
searching in HF Evaluate library..."
)
try:
metric_object = evaluate.load(name)
return metric_object.compute
except Exception:
raise Warning(
"{} not found in the evaluate library!".format(name),
"Please check https://huggingface.co/evaluate-metric",
)
def register_aggregation(name):
# TODO: should we enforce a specific interface to aggregation metrics?
def decorate(fn):
assert (
name not in AGGREGATION_REGISTRY
), f"aggregation named '{name}' conflicts with existing registered aggregation!"
AGGREGATION_REGISTRY[name] = fn
return fn
return decorate
def get_aggregation(name):
try:
return AGGREGATION_REGISTRY[name]
except KeyError:
raise Warning(
"{} not a registered aggregation metric!".format(name),
)
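For orientation, a minimal usage sketch of the keyword-based `register_metric` interface added here; the metric name and body below are invented for illustration and are not part of this PR:

```python
# Illustrative only: registering a custom passthrough metric. The aggregation
# name must already exist in AGGREGATION_REGISTRY (e.g. "mean").
@register_metric(
    metric="my_acc",                 # hypothetical metric name
    higher_is_better=True,
    output_type="multiple_choice",
    aggregation="mean",
)
def my_acc_fn(items):
    # per-document values; the registered "mean" aggregation reduces them later
    return items
```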
...@@ -18,20 +18,33 @@ from collections.abc import Callable ...@@ -18,20 +18,33 @@ from collections.abc import Callable
from lm_eval import utils from lm_eval import utils
from lm_eval.api import samplers from lm_eval.api import samplers
from lm_eval.api.instance import Instance from lm_eval.api.instance import Instance
from lm_eval.api.filter import FilterEnsemble
from lm_eval.logger import eval_logger
from lm_eval.prompts import get_prompt
from lm_eval.filters import build_filter_ensemble
from lm_eval.api.metrics import ( from lm_eval.api.metrics import (
METRIC_REGISTRY, # get_metric,
AGGREGATION_REGISTRY, # get_aggregation,
HIGHER_IS_BETTER_REGISTRY,
get_metric,
get_aggregation,
mean, mean,
weighted_perplexity, weighted_perplexity,
bits_per_byte, bits_per_byte,
) )
from lm_eval.api.registry import (
METRIC_REGISTRY,
DEFAULT_METRIC_REGISTRY,
OUTPUT_TYPE_REGISTRY,
AGGREGATION_REGISTRY,
HIGHER_IS_BETTER_REGISTRY,
DEFAULT_AGGREGATION_REGISTRY,
)
from lm_eval.logger import eval_logger ALL_OUTPUT_TYPES = [
from lm_eval.prompts import get_prompt "loglikelihood",
from lm_eval.filters import build_filter_ensemble "multiple_choice",
"loglikelihood_rolling",
"greedy_until",
]
@dataclass @dataclass
...@@ -43,15 +56,16 @@ class TaskConfig(dict): ...@@ -43,15 +56,16 @@ class TaskConfig(dict):
task_name: str = ( task_name: str = (
None # TODO: deprecate this, it'll be set in __post_init__ to be names[0] None # TODO: deprecate this, it'll be set in __post_init__ to be names[0]
) )
base_task: str = None
dataset_path: str = None dataset_path: str = None
dataset_name: str = None dataset_name: str = None
dataset_kwargs: dict = None
training_split: str = None training_split: str = None
validation_split: str = None validation_split: str = None
test_split: str = None test_split: str = None
fewshot_split: str = None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?) fewshot_split: str = None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
template_aliases: str = None template_aliases: str = None
aliases: Union[str, list] = None
doc_to_text: Union[Callable, str] = None doc_to_text: Union[Callable, str] = None
doc_to_target: Union[Callable, str] = None doc_to_target: Union[Callable, str] = None
...@@ -87,7 +101,7 @@ class TaskConfig(dict): ...@@ -87,7 +101,7 @@ class TaskConfig(dict):
if type(self.gold_alias) == str: if type(self.gold_alias) == str:
self.gold_alias = self.template_aliases + self.doc_to_target self.gold_alias = self.template_aliases + self.doc_to_target
if not self.generation_kwargs: if not self.generation_kwargs:
# ensure that we greedily generate in absence of explicit arguments otherwise # ensure that we greedily generate in absence of explicit arguments otherwise
self.generation_kwargs = {"do_sample": False, "temperature": 0.0} self.generation_kwargs = {"do_sample": False, "temperature": 0.0}
...@@ -439,8 +453,12 @@ class Task(abc.ABC): ...@@ -439,8 +453,12 @@ class Task(abc.ABC):
def apply_filters(self): def apply_filters(self):
for f in self._filters: if hasattr(self, "_filters"):
f.apply(self._instances) for f in self._filters:
f.apply(self._instances)
else:
eval_logger.warning("No filter defined, passing through instances")
return self._instances
class ConfigurableTask(Task): class ConfigurableTask(Task):
...@@ -469,6 +487,7 @@ class ConfigurableTask(Task): ...@@ -469,6 +487,7 @@ class ConfigurableTask(Task):
) )
if self._config.output_type is not None: if self._config.output_type is not None:
assert self._config.output_type in ALL_OUTPUT_TYPES
self.OUTPUT_TYPE = self._config.output_type self.OUTPUT_TYPE = self._config.output_type
if self._config.dataset_path is not None: if self._config.dataset_path is not None:
...@@ -477,35 +496,42 @@ class ConfigurableTask(Task): ...@@ -477,35 +496,42 @@ class ConfigurableTask(Task):
if self._config.dataset_name is not None: if self._config.dataset_name is not None:
self.DATASET_NAME = self._config.dataset_name self.DATASET_NAME = self._config.dataset_name
if self._config.metric_list is not None: self._metric_fn_list = {}
self._metric_list = {} self._metric_fn_kwargs = {}
self._metric_kwargs = {} self._aggregation_list = {}
self._aggregation_list = {} self._higher_is_better = {}
self._higher_is_better = {}
for metric_config in self._config.metric_list: _metric_list = DEFAULT_METRIC_REGISTRY[self._config.output_type]
if self._config.metric_list is None:
for metric_name in _metric_list:
self._metric_fn_list[metric_name] = METRIC_REGISTRY[metric_name]
self._aggregation_list[metric_name] = DEFAULT_AGGREGATION_REGISTRY[
metric_name
]
self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[
metric_name
]
else:
for metric_config in self._config.metric_list:
assert "metric" in metric_config
metric_name = metric_config["metric"] metric_name = metric_config["metric"]
aggregation = metric_config["aggregation"]
higher_is_better = metric_config["higher_is_better"]
kwargs = { kwargs = {
key: metric_config[key] key: metric_config[key]
for key in metric_config for key in metric_config
if key not in ["metric", "aggregation", "higher_is_better"] if key not in ["metric", "aggregation", "higher_is_better"]
} }
if metric_name in _metric_list:
self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[aggregation] self._metric_fn_list[metric_name] = METRIC_REGISTRY[metric_name]
if metric_name in METRIC_REGISTRY.keys():
self._metric_list[metric_name] = METRIC_REGISTRY[metric_name]
self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[
metric_name
]
else: else:
self._higher_is_better[metric_name] = higher_is_better eval_logger.warning(
f"Metric {metric_name} not found, "
"Searching from https://huggingface.co/evaluate-metric"
)
try: try:
metric_object = evaluate.load(metric_name) metric_object = evaluate.load(metric_name)
self._metric_list[metric_name] = metric_object self._metric_fn_list[metric_name] = metric_object
self._metric_kwargs[metric_name] = kwargs self._metric_fn_kwargs[metric_name] = kwargs
except Exception: except Exception:
raise Warning( raise Warning(
...@@ -513,12 +539,36 @@ class ConfigurableTask(Task): ...@@ -513,12 +539,36 @@ class ConfigurableTask(Task):
"Please check https://huggingface.co/evaluate-metric", "Please check https://huggingface.co/evaluate-metric",
) )
self.download(data_dir, cache_dir, download_mode) if "aggregation" in metric_config:
self._aggregation_list[metric_name] = metric_config["aggregation"]
else:
eval_logger.warning(
f"metric {metric_name} is defined, but aggregation is not"
f"using default aggregation for {metric_name}"
)
self._aggregation_list[metric_name] = DEFAULT_AGGREGATION_REGISTRY[
metric_name
]
if "higher_is_better" in metric_config:
self._higher_is_better[metric_name] = metric_config[
"higher_is_better"
]
else:
eval_logger.warning(
f"metric {metric_name} is defined, but higher_is_better is not"
f"using default higher_is_better for {metric_name}"
)
self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[
metric_name
]
self.download(self._config.dataset_kwargs)
self._training_docs = None self._training_docs = None
self._fewshot_docs = None self._fewshot_docs = None
self._filters = []
if self._config.filter_list is not None: if self._config.filter_list is not None:
self._filters = []
for filter_config in self._config.filter_list: for filter_config in self._config.filter_list:
for filter_pipeline in filter_config: for filter_pipeline in filter_config:
filter_name = filter_config["name"] filter_name = filter_config["name"]
...@@ -530,7 +580,7 @@ class ConfigurableTask(Task): ...@@ -530,7 +580,7 @@ class ConfigurableTask(Task):
} }
components.append([function["function"], kwargs]) components.append([function["function"], kwargs])
filter_pipeline = build_filter_ensemble(filter_name, components) filter_pipeline = build_filter_ensemble(filter_name, components)
self._filters.append(filter_pipeline) self._filters.append(filter_pipeline)
else: else:
self._filters = [ self._filters = [
...@@ -550,6 +600,14 @@ class ConfigurableTask(Task): ...@@ -550,6 +600,14 @@ class ConfigurableTask(Task):
list(self.fewshot_docs()), self, rnd=random.Random() list(self.fewshot_docs()), self, rnd=random.Random()
) # TODO: pass the correct docs in here ) # TODO: pass the correct docs in here
def download(self, dataset_kwargs=None):
self.dataset = datasets.load_dataset(
path=self.DATASET_PATH,
name=self.DATASET_NAME,
**dataset_kwargs if dataset_kwargs is not None else {},
)
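For illustration, the new `download` hook amounts to forwarding `TaskConfig.dataset_kwargs` verbatim to `datasets.load_dataset`; the dataset and kwargs below are invented:

```python
# Illustrative only: what download() does for a task with dataset_kwargs set.
import datasets

dataset = datasets.load_dataset(
    path="ai2_arc",            # DATASET_PATH from the task config
    name="ARC-Easy",           # DATASET_NAME from the task config
    **{"revision": "main"},    # forwarded from TaskConfig.dataset_kwargs (hypothetical)
)
```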
def has_training_docs(self): def has_training_docs(self):
if self._config.training_split is not None: if self._config.training_split is not None:
return True return True
...@@ -643,7 +701,7 @@ class ConfigurableTask(Task): ...@@ -643,7 +701,7 @@ class ConfigurableTask(Task):
raise TypeError raise TypeError
def gold_alias(self, doc): def gold_alias(self, doc):
# TODO: reevaluate if we need this. implemented to have a # TODO: reevaluate if we need this. implemented to have a
# processed version of answer to put into gsm8k exact_match scoring as ref. # processed version of answer to put into gsm8k exact_match scoring as ref.
if self._config.gold_alias: if self._config.gold_alias:
doc_to_target = self._config.gold_alias doc_to_target = self._config.gold_alias
...@@ -684,7 +742,7 @@ class ConfigurableTask(Task): ...@@ -684,7 +742,7 @@ class ConfigurableTask(Task):
for i, choice in enumerate(choices) for i, choice in enumerate(choices)
] ]
# TODO: we should raise a warning telling users this will at most ~2x runtime. # TODO: we should raise a warning telling users this will at most ~2x runtime.
if "acc_mutual_info" in self._metric_list.keys(): if "acc_mutual_info" in self._metric_fn_list.keys():
# if we are calculating multiple choice accuracy # if we are calculating multiple choice accuracy
# using mutual information instead of raw loglikelihood as metric, need unconditional lls. # using mutual information instead of raw loglikelihood as metric, need unconditional lls.
...@@ -714,25 +772,44 @@ class ConfigurableTask(Task): ...@@ -714,25 +772,44 @@ class ConfigurableTask(Task):
def process_results(self, doc, results): def process_results(self, doc, results):
# if callable(self._config.process_results):
# return self._config.process_results(doc, results)
result_dict = {} result_dict = {}
use_metric = list(self._metric_fn_list.keys())
if self.OUTPUT_TYPE == "loglikelihood": if self.OUTPUT_TYPE == "loglikelihood":
results = results[0] results = results[0]
ll, is_greedy = results ll, is_greedy = results
result_dict = {"perplexity": ll, "accuracy": int(is_greedy)} return {
**({"perplexity": ll} if "perplexity" in use_metric else {}),
**({"acc": int(is_greedy)} if "acc" in use_metric else {}),
}
elif self.OUTPUT_TYPE == "loglikelihood_rolling": elif self.OUTPUT_TYPE == "loglikelihood_rolling":
(loglikelihood,) = results (loglikelihood,) = results
words = self.count_words(self.doc_to_target(doc)) _words = self.count_words(self.doc_to_target(doc))
bytes_ = self.count_bytes(self.doc_to_target(doc)) _bytes = self.count_bytes(self.doc_to_target(doc))
return { return {
"word_perplexity": (loglikelihood, words), **(
"byte_perplexity": (loglikelihood, bytes_), {"word_perplexity": (loglikelihood, _words)}
"bits_per_byte": (loglikelihood, bytes_), if "word_perplexity" in use_metric
else {}
),
**(
{"byte_perplexity": (loglikelihood, _bytes)}
if "byte_perplexity" in use_metric
else {}
),
**(
{"bits_per_byte": (loglikelihood, _bytes)}
if "bits_per_byte" in use_metric
else {}
),
} }
elif self.OUTPUT_TYPE == "multiple_choice": elif self.OUTPUT_TYPE == "multiple_choice":
lls = [
res[0] for res in results lls, is_greedy = zip(*results)
] # only retain loglikelihoods, discard is_greedy
gold = int(self.doc_to_target(doc)) gold = int(self.doc_to_target(doc))
pred = np.argmax(lls)
# retrieve choices in List[str] form, to compute choice lengths, etc. # retrieve choices in List[str] form, to compute choice lengths, etc.
choices = ast.literal_eval( choices = ast.literal_eval(
utils.apply_template( utils.apply_template(
...@@ -755,21 +832,18 @@ class ConfigurableTask(Task): ...@@ -755,21 +832,18 @@ class ConfigurableTask(Task):
acc_norm = 1.0 if np.argmax(lls / completion_len) == gold else 0.0 acc_norm = 1.0 if np.argmax(lls / completion_len) == gold else 0.0
result_dict = { result_dict = {
"acc": acc, **({"acc": acc} if "acc" in use_metric else {}),
"acc_norm": acc_norm, **({"f1": (pred, gold)} if "f1" in use_metric else {}),
**({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
} }
# TODO: set which normalization metrics should be reported, and calculate them # TODO: set which normalization metrics should be reported, and calculate them
if "exact_match" in self._metric_fn_list.keys():
if "exact_match" in self._metric_list.keys():
# TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly # TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly
is_greedy = [
res[1] for res in results
] # take only the `is_greedy` results
is_greedy = is_greedy[gold] # take value for the gold answer is_greedy = is_greedy[gold] # take value for the gold answer
result_dict["exact_match"] = int(is_greedy) result_dict["exact_match"] = int(is_greedy)
if "acc_mutual_info" in self._metric_list.keys(): if "acc_mutual_info" in use_metric:
lls_mutual_info = [ lls_mutual_info = [
ll_c - ll_u for ll_c, ll_u in zip(lls, lls_unconditional) ll_c - ll_u for ll_c, ll_u in zip(lls, lls_unconditional)
] ]
...@@ -783,16 +857,16 @@ class ConfigurableTask(Task): ...@@ -783,16 +857,16 @@ class ConfigurableTask(Task):
else: else:
gold = self.doc_to_target(doc) gold = self.doc_to_target(doc)
for key, result in zip(self._metric_list.keys(), results): for key, result in zip(self._metric_fn_list.keys(), results):
_dict = self._metric_list[key].compute( _dict = self._metric_fn_list[key].compute(
references=[gold], predictions=[result], **self._metric_kwargs[key] references=[gold], predictions=[result], **self._metric_kwargs[key]
) )
result_dict[key] = _dict[key] result_dict = {**result_dict, **_dict}
else: else:
raise ValueError( raise ValueError(
f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ", f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
"'loglikelihood', 'loglikelihood_rolling', 'greedy_until'", "'loglikelihood', 'loglikelihood_rolling', 'greedy_until', or 'multiple_choice'",
) )
return result_dict return result_dict
......
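To tie the metric_list handling above together, a single entry as it arrives from a task config might look like the following; the keys mirror the code, the values are invented:

```python
# Illustrative only: one metric_list entry consumed by ConfigurableTask.
# "metric" is required; "aggregation" and "higher_is_better" fall back to
# registry defaults when omitted; remaining keys are forwarded as metric kwargs.
metric_config = {
    "metric": "exact_match",
    "aggregation": "mean",
    "higher_is_better": True,
    "ignore_case": True,   # example of an extra key passed through as a kwarg
}
```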
...@@ -7,10 +7,10 @@ import torch ...@@ -7,10 +7,10 @@ import torch
import numpy as np import numpy as np
import lm_eval.api import lm_eval.api
import lm_eval.api.metrics
import lm_eval.tasks import lm_eval.tasks
import lm_eval.models import lm_eval.models
import lm_eval.api.metrics
import lm_eval.api.registry
from lm_eval.utils import ( from lm_eval.utils import (
positional_deprecated, positional_deprecated,
...@@ -72,7 +72,7 @@ def simple_evaluate( ...@@ -72,7 +72,7 @@ def simple_evaluate(
if isinstance(model, str): if isinstance(model, str):
if model_args is None: if model_args is None:
model_args = "" model_args = ""
lm = lm_eval.api.model.get_model(model).create_from_arg_string( lm = lm_eval.api.registry.get_model(model).create_from_arg_string(
model_args, {"batch_size": batch_size, "device": device} model_args, {"batch_size": batch_size, "device": device}
) )
else: else:
...@@ -274,9 +274,7 @@ def evaluate( ...@@ -274,9 +274,7 @@ def evaluate(
# aggregate results ; run bootstrap CIs # aggregate results ; run bootstrap CIs
for (task_name, key, metric), items in vals.items(): for (task_name, key, metric), items in vals.items():
task = task_dict[task_name] task = task_dict[task_name]
results[task_name][metric + "," + key] = task.aggregation()[ results[task_name][metric + "," + key] = task.aggregation()[metric](items)
metric
](items)
# hotfix: bleu, chrf, ter seem to be really expensive to bootstrap # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
# so we run them less iterations. still looking for a cleaner way to do this # so we run them less iterations. still looking for a cleaner way to do this
...@@ -289,9 +287,7 @@ def evaluate( ...@@ -289,9 +287,7 @@ def evaluate(
) )
if stderr is not None: if stderr is not None:
results[task_name][metric + "_stderr" + "," + key] = stderr( results[task_name][metric + "_stderr" + "," + key] = stderr(items)
items
)
return {"results": dict(results), "versions": dict(versions)} return {"results": dict(results), "versions": dict(versions)}
......
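For reference, the aggregation loop above keys results as `<metric>,<filter_key>`; a sketch of the resulting shape, with invented task, filter, and score values:

```python
# Illustrative only: shape of the aggregated results after this loop runs.
results = {
    "gsm8k_yaml": {                          # hypothetical task name
        "exact_match,score-first": 0.12,     # "<metric>,<filter key>"
        "exact_match_stderr,score-first": 0.01,
    }
}
```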
...@@ -6,7 +6,7 @@ from . import extraction ...@@ -6,7 +6,7 @@ from . import extraction
FILTER_REGISTRY = { FILTER_REGISTRY = {
"take_first": selection.TakeFirstFilter, "take_first": selection.TakeFirstFilter,
"regex": extraction.RegexFilter, "regex": extraction.RegexFilter,
"majority_vote": selection.MajorityVoteFilter, "majority_vote": selection.MajorityVoteFilter,
"take_first_k": selection.TakeKFilter, "take_first_k": selection.TakeKFilter,
# TODO: implement this filter. either it should take in an arbitrary "scoring"/reward function # TODO: implement this filter. either it should take in an arbitrary "scoring"/reward function
# that takes an input and returns a scalar and then should select the max reward, # that takes an input and returns a scalar and then should select the max reward,
......
...@@ -15,8 +15,8 @@ class TakeFirstFilter(Filter): ...@@ -15,8 +15,8 @@ class TakeFirstFilter(Filter):
""" """
return map(lambda r: r[0], resps) return map(lambda r: r[0], resps)
class TakeKFilter(Filter):
class TakeKFilter(Filter):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.k = kwargs.pop("k") self.k = kwargs.pop("k")
...@@ -25,8 +25,10 @@ class TakeKFilter(Filter): ...@@ -25,8 +25,10 @@ class TakeKFilter(Filter):
def apply(self, resps): def apply(self, resps):
# check we have at least k responses per doc, else we can't take the first k # check we have at least k responses per doc, else we can't take the first k
assert len(resps[0]) >= self.k, f"Need at least {self.k} responses per doc to take first {self.k}, but got {len(resps[0])} only! Please increase TaskConfig.repeats ." assert (
return map(lambda r: r[:self.k], resps) len(resps[0]) >= self.k
), f"Need at least {self.k} responses per doc to take first {self.k}, but got {len(resps[0])} only! Please increase TaskConfig.repeats ."
return map(lambda r: r[: self.k], resps)
class MajorityVoteFilter(Filter): class MajorityVoteFilter(Filter):
...@@ -37,12 +39,13 @@ class MajorityVoteFilter(Filter): ...@@ -37,12 +39,13 @@ class MajorityVoteFilter(Filter):
def apply(self, resps): def apply(self, resps):
""" """
Each entry of `resps` is a list of model responses. Each entry of `resps` is a list of model responses.
We select the response that occurs most frequently in each entry of `resps`. We select the response that occurs most frequently in each entry of `resps`.
""" """
def select_majority(resp): def select_majority(resp):
counts = Counter(resp) counts = Counter(resp)
vote = counts.most_common(1)[0][0] vote = counts.most_common(1)[0][0]
return vote return vote
return map(lambda r: [select_majority(r)], resps) return map(lambda r: [select_majority(r)], resps)
import random import random
from lm_eval.api.model import LM, register_model from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
@register_model("dummy") @register_model("dummy")
......
...@@ -6,7 +6,8 @@ import numpy as np ...@@ -6,7 +6,8 @@ import numpy as np
from tqdm import tqdm from tqdm import tqdm
from lm_eval import utils from lm_eval import utils
from lm_eval.api.model import LM, register_model from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
def get_result(response, ctxlen): def get_result(response, ctxlen):
......
...@@ -8,7 +8,8 @@ import torch.nn.functional as F ...@@ -8,7 +8,8 @@ import torch.nn.functional as F
from lm_eval import utils from lm_eval import utils
from lm_eval.logger import eval_logger from lm_eval.logger import eval_logger
from lm_eval.api.model import LM, register_model from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from accelerate import Accelerator from accelerate import Accelerator
from itertools import islice from itertools import islice
...@@ -38,10 +39,10 @@ class HFLM(LM): ...@@ -38,10 +39,10 @@ class HFLM(LM):
if device not in ["cuda", "cpu"]: if device not in ["cuda", "cpu"]:
device = int(device) device = int(device)
self._device = torch.device(device) self._device = torch.device(device)
print(f"Using device '{device}'") eval_logger.info(f"Using device '{device}'")
else: else:
print("Device not specified") eval_logger.info("Device not specified")
print(f"Cuda Available? {torch.cuda.is_available()}") eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
self._device = ( self._device = (
torch.device("cuda") torch.device("cuda")
if torch.cuda.is_available() if torch.cuda.is_available()
...@@ -75,13 +76,12 @@ class HFLM(LM): ...@@ -75,13 +76,12 @@ class HFLM(LM):
if gpus > 1: if gpus > 1:
accelerator = Accelerator() accelerator = Accelerator()
if gpus > accelerator.num_processes: if gpus > accelerator.num_processes:
warning = ( eval_logger.warning(
"WARNING: The number of total system GPUs does not match the number of spawned processes. " "WARNING: The number of total system GPUs does not match the number of spawned processes. "
"If you would like to use data parallelism, please launch the script " "If you would like to use data parallelism, please launch the script "
"with 'accelerate launch *script*'. " "with 'accelerate launch *script*'. "
f"Current run will proceed with {accelerator.num_processes} devices." f"Current run will proceed with {accelerator.num_processes} devices."
) )
print(warning)
self._rank = accelerator.local_process_index self._rank = accelerator.local_process_index
self._world_size = accelerator.num_processes self._world_size = accelerator.num_processes
else: else:
...@@ -90,7 +90,7 @@ class HFLM(LM): ...@@ -90,7 +90,7 @@ class HFLM(LM):
self.accelerator = accelerator self.accelerator = accelerator
if self.accelerator.is_local_main_process: if self.accelerator.is_local_main_process:
print(f"Using {gpus} devices with data parallelism") eval_logger.info(f"Using {gpus} devices with data parallelism")
self._rank = self.accelerator.local_process_index self._rank = self.accelerator.local_process_index
self._world_size = self.accelerator.num_processes self._world_size = self.accelerator.num_processes
...@@ -154,17 +154,26 @@ class HFLM(LM): ...@@ -154,17 +154,26 @@ class HFLM(LM):
return self.model(inps)[0] return self.model(inps)[0]
def _model_generate(self, context, max_length, eos_token_id, **generation_kwargs): def _model_generate(self, context, max_length, eos_token_id, **generation_kwargs):
# we require users to pass do_sample=True explicitly # we require users to pass do_sample=True explicitly
# for non-greedy gen. This should be reevaluated when considering beam search. # for non-greedy gen. This should be reevaluated when considering beam search.
if "do_sample" not in generation_kwargs.keys(): if "do_sample" not in generation_kwargs.keys():
generation_kwargs["do_sample"] = False generation_kwargs["do_sample"] = False
return self.model.generate( if hasattr(self, "accelerator"):
context, return self.accelerator.unwrap_model(self.model).generate(
max_length=max_length, context,
pad_token_id=eos_token_id, max_length=max_length,
eos_token_id=eos_token_id, pad_token_id=eos_token_id,
**generation_kwargs, eos_token_id=eos_token_id,
) **generation_kwargs,
)
else:
return self.model.generate(
context,
max_length=max_length,
pad_token_id=eos_token_id,
eos_token_id=eos_token_id,
**generation_kwargs,
)
def loglikelihood(self, requests): def loglikelihood(self, requests):
new_reqs = [] new_reqs = []
...@@ -354,7 +363,7 @@ class HFLM(LM): ...@@ -354,7 +363,7 @@ class HFLM(LM):
for context, gen_kwargs in tqdm(re_ord.get_reordered()): for context, gen_kwargs in tqdm(re_ord.get_reordered()):
if isinstance(gen_kwargs, dict): if isinstance(gen_kwargs, dict):
gen_kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1 gen_kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
if "until" in gen_kwargs.keys(): if "until" in gen_kwargs.keys():
until = gen_kwargs.pop("until") until = gen_kwargs.pop("until")
if isinstance(until, str): if isinstance(until, str):
...@@ -362,9 +371,11 @@ class HFLM(LM): ...@@ -362,9 +371,11 @@ class HFLM(LM):
elif not isinstance(until, list): elif not isinstance(until, list):
raise ValueError( raise ValueError(
f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {until}" f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {until}"
) )
else: else:
raise ValueError(f"Expected `gen_kwargs` to be of type `dict` but got {gen_kwargs}") raise ValueError(
f"Expected `gen_kwargs` to be of type `dict` but got {gen_kwargs}"
)
if not until: if not until:
until = [self.tok_decode(self.eot_token_id)] until = [self.tok_decode(self.eot_token_id)]
if "max_gen_toks" in gen_kwargs.keys(): if "max_gen_toks" in gen_kwargs.keys():
...@@ -374,7 +385,7 @@ class HFLM(LM): ...@@ -374,7 +385,7 @@ class HFLM(LM):
try: try:
(primary_until,) = self.tok_encode(until[0]) (primary_until,) = self.tok_encode(until[0])
except: except Exception:
# if our primary until would be multiple tokens long, we'll have errors. # if our primary until would be multiple tokens long, we'll have errors.
# TODO: handling this better will let us stop generating earlier + often. # TODO: handling this better will let us stop generating earlier + often.
primary_until = self.eot_token_id primary_until = self.eot_token_id
...@@ -384,8 +395,8 @@ class HFLM(LM): ...@@ -384,8 +395,8 @@ class HFLM(LM):
).to(self.device) ).to(self.device)
cont = self._model_generate( cont = self._model_generate(
context=context_enc, context=context_enc,
max_length=context_enc.shape[1] + max_gen_toks, max_length=context_enc.shape[1] + max_gen_toks,
eos_token_id=primary_until, eos_token_id=primary_until,
**gen_kwargs, **gen_kwargs,
) )
......
...@@ -16,7 +16,8 @@ import os ...@@ -16,7 +16,8 @@ import os
import requests as _requests import requests as _requests
import time import time
from tqdm import tqdm from tqdm import tqdm
from lm_eval.api.model import LM, register_model from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
...@@ -64,4 +64,4 @@ Tasks added in the revamped harness that were not previously available. Again, a ...@@ -64,4 +64,4 @@ Tasks added in the revamped harness that were not previously available. Again, a
- [ ] Chain of Thought - [ ] Chain of Thought
- [ ] Self-consistency ; Least-to-Most prompting, etc. - [ ] Self-consistency ; Least-to-Most prompting, etc.
- [ ] Summarization Tasks - [ ] Summarization Tasks
- [ ] Anthropic Model-Written Evals - [ ] Anthropic Model-Written Evals
\ No newline at end of file
...@@ -7,16 +7,16 @@ from .triviaqa import * ...@@ -7,16 +7,16 @@ from .triviaqa import *
from lm_eval import utils from lm_eval import utils
from lm_eval.logger import eval_logger from lm_eval.logger import eval_logger
from lm_eval.api.task import TaskConfig, Task, ConfigurableTask from lm_eval.api.task import TaskConfig, Task, ConfigurableTask
from lm_eval.api.register import ( from lm_eval.api.registry import (
register_task, register_task,
register_group, register_group,
task_registry, TASK_REGISTRY,
group_registry, GROUP_REGISTRY,
) )
def get_task_name_from_config(task_config): def get_task_name_from_config(task_config):
return "configurable_{dataset_path}_{dataset_name}".format(**task_config) return "{dataset_path}_{dataset_name}".format(**task_config)
task_dir = os.path.dirname(os.path.abspath(__file__)) + "/" task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
...@@ -35,24 +35,20 @@ for root, subdirs, file_list in os.walk(task_dir): ...@@ -35,24 +35,20 @@ for root, subdirs, file_list in os.walk(task_dir):
) )
if "task" in config: if "task" in config:
task_name = "{}".format( task_name = "{}".format(config["task"])
config["task"]
)
register_task(task_name)(SubClass) register_task(task_name)(SubClass)
if "group" in config: if "group" in config:
for group in config["group"]: for group in config["group"]:
register_group(group)(SubClass) register_group(group)(SubClass)
except Exception as e: except Exception as error:
raise e
eval_logger.warning( eval_logger.warning(
"Failed to load config in\n" "Failed to load config in\n"
f" {yaml_path}\n" f" {yaml_path}\n"
" Config will not be added to registry" " Config will not be added to registry"
f" Error: {error}"
) )
TASK_REGISTRY = task_registry
GROUP_REGISTRY = group_registry
ALL_TASKS = sorted(list(TASK_REGISTRY.keys()) + list(GROUP_REGISTRY.keys())) ALL_TASKS = sorted(list(TASK_REGISTRY.keys()) + list(GROUP_REGISTRY.keys()))
......
"""
Think you have Solved Question Answering? Try ARC, the AI2 Reasoning Challenge
https://arxiv.org/pdf/1803.05457.pdf
The ARC dataset consists of 7,787 science exam questions drawn from a variety
of sources, including science questions provided under license by a research
partner affiliated with AI2. These are text-only, English language exam questions
that span several grade levels as indicated in the files. Each question has a
multiple choice structure (typically 4 answer options). The questions are sorted
into a Challenge Set of 2,590 “hard” questions (those that both a retrieval and
a co-occurrence method fail to answer correctly) and an Easy Set of 5,197 questions.
Homepage: https://allenai.org/data/arc
"""
from lm_eval import utils
from lm_eval.prompts import get_prompt
from lm_eval.api.task import MultipleChoiceTask
from lm_eval.api.registry import register_task, register_group
_CITATION = """
@article{Clark2018ThinkYH,
title={Think you have Solved Question Answering? Try ARC, the AI2 Reasoning Challenge},
author={Peter Clark and Isaac Cowhey and Oren Etzioni and Tushar Khot and Ashish Sabharwal and Carissa Schoenick and Oyvind Tafjord},
journal={ArXiv},
year={2018},
volume={abs/1803.05457}
}
"""
@register_group("arc")
@register_task("arc_easy")
class ARCEasy(MultipleChoiceTask):
VERSION = "2.0"
DATASET_PATH = "ai2_arc"
DATASET_NAME = "ARC-Easy"
OUTPUT_TYPE = "loglikelihood"
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(map(self._process_doc, self.dataset["train"]))
return self._training_docs
def validation_docs(self):
return map(self._process_doc, self.dataset["validation"])
def test_docs(self):
return map(self._process_doc, self.dataset["test"])
def _process_doc(self, doc):
# NOTE: Some `doc["answerKey"]`s are in numeric string format being one
# of {'1', '2', '3', '4', '5'}. We map them back to letters.
num_to_letter = {"1": "A", "2": "B", "3": "C", "4": "D", "5": "E"}
doc["answerKey"] = num_to_letter.get(doc["answerKey"], doc["answerKey"])
out_doc = {
"id": doc["id"],
"question": doc["question"],
"choices": doc["choices"]["text"],
"gold": ["A", "B", "C", "D", "E"].index(doc["answerKey"]),
}
return out_doc
def doc_to_text(self, doc):
doc_to_text = get_prompt("qa-basic:question-newline-answer")
return utils.apply_template(doc_to_text, doc)
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["query"]
@register_group("arc")
@register_task("arc_challenge")
class ARCChallenge(ARCEasy):
DATASET_PATH = "ai2_arc"
DATASET_NAME = "ARC-Challenge"
...@@ -24,7 +24,7 @@ from lm_eval.api.instance import Instance ...@@ -24,7 +24,7 @@ from lm_eval.api.instance import Instance
from lm_eval.prompts import get_prompt from lm_eval.prompts import get_prompt
from lm_eval.api.register import register_task, register_group from lm_eval.api.registry import register_task, register_group
_CITATION = """ _CITATION = """
@misc{cobbe2021training, @misc{cobbe2021training,
......
...@@ -29,4 +29,4 @@ Homepage: https://github.com/openai/grade-school-math ...@@ -29,4 +29,4 @@ Homepage: https://github.com/openai/grade-school-math
archivePrefix={arXiv}, archivePrefix={arXiv},
primaryClass={cs.LG} primaryClass={cs.LG}
} }
``` ```
\ No newline at end of file
...@@ -29,4 +29,4 @@ filter_list: ...@@ -29,4 +29,4 @@ filter_list:
- function: "regex" - function: "regex"
regex_pattern: "The answer is (\\-?[0-9\\.\\,]*[0-9]+)" regex_pattern: "The answer is (\\-?[0-9\\.\\,]*[0-9]+)"
- function: "majority_vote" - function: "majority_vote"
- function: "take_first" - function: "take_first"
\ No newline at end of file
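As a rough sketch, the filter_list above is assembled into a filter ensemble at task init roughly as follows; the ensemble name and the empty-kwargs handling are assumptions, not part of the diff:

```python
# Illustrative only: how the YAML filter_list maps onto build_filter_ensemble,
# following the ConfigurableTask code shown earlier in this commit.
from lm_eval.filters import build_filter_ensemble

components = [
    ["regex", {"regex_pattern": r"The answer is (\-?[0-9\.\,]*[0-9]+)"}],
    ["majority_vote", {}],
    ["take_first", {}],
]
pipeline = build_filter_ensemble("score-first", components)  # name is hypothetical
```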