Commit 90d818da authored by lintangsutawika's avatar lintangsutawika
Browse files

fix issue with default metrics and aggregation functions

parent 6b4161c1
......@@ -135,6 +135,16 @@ def acc_mutual_info_fn(items): # This is a passthrough function
return items
@register_metric(
metric="exact_match",
higher_is_better=True,
output_type="generate_until",
aggregation="mean",
)
def exact_match_fn(items): # This is a passthrough function
return items
@register_metric(
metric="perplexity",
higher_is_better=False,
......
......@@ -68,10 +68,10 @@ def register_group(name):
return decorate
AGGREGATION_REGISTRY = {}
DEFAULT_AGGREGATION_REGISTRY = {}
METRIC_REGISTRY = {}
OUTPUT_TYPE_REGISTRY = {}
METRIC_REGISTRY = {}
METRIC_AGGREGATION_REGISTRY = {}
AGGREGATION_REGISTRY = {}
HIGHER_IS_BETTER_REGISTRY = {}
DEFAULT_METRIC_REGISTRY = {
......@@ -95,8 +95,7 @@ def register_metric(**args):
for key, registry in [
("metric", METRIC_REGISTRY),
("higher_is_better", HIGHER_IS_BETTER_REGISTRY),
# ("output_type", OUTPUT_TYPE_REGISTRY),
("aggregation", DEFAULT_AGGREGATION_REGISTRY),
("aggregation", METRIC_AGGREGATION_REGISTRY),
]:
if key in args:
......@@ -158,6 +157,16 @@ def get_aggregation(name):
)
def get_metric_aggregation(name):
try:
return METRIC_AGGREGATION_REGISTRY[name]
except KeyError:
eval_logger.warning(
"{} not a registered aggregation metric!".format(name),
)
def is_higher_better(metric_name):
try:
return HIGHER_IS_BETTER_REGISTRY[metric_name]
......
......@@ -33,6 +33,7 @@ from lm_eval.api.metrics import (
from lm_eval.api.registry import (
get_metric,
get_aggregation,
get_metric_aggregation,
is_higher_better,
DEFAULT_METRIC_REGISTRY,
OUTPUT_TYPE_REGISTRY,
......@@ -537,12 +538,15 @@ class ConfigurableTask(Task):
self._aggregation_list = {}
self._higher_is_better = {}
_metric_list = DEFAULT_METRIC_REGISTRY[self.config.output_type]
if self.config.metric_list is None:
# TODO: handle this in TaskConfig.__post_init__ ?
_metric_list = DEFAULT_METRIC_REGISTRY[self.config.output_type]
for metric_name in _metric_list:
self._metric_fn_list[metric_name] = get_metric(metric_name)
self._aggregation_list[metric_name] = get_aggregation(metric_name)
self._aggregation_list[metric_name] = get_metric_aggregation(
metric_name
)
self._higher_is_better[metric_name] = is_higher_better(metric_name)
else:
for metric_config in self.config.metric_list:
......
......@@ -663,8 +663,16 @@ class HFLM(LM):
chunks = utils.chunks(
re_ord.get_reordered(),
n=self.batch_size if self.batch_size != "auto" else override_bs if override_bs is not None else 0,
fn=self._batch_scheduler if self.batch_size == "auto" and n_reordered_requests > 0 and not override_bs else None,
n=self.batch_size
if self.batch_size != "auto"
else override_bs
if override_bs is not None
else 0,
fn=self._batch_scheduler
if self.batch_size == "auto"
and n_reordered_requests > 0
and not override_bs
else None,
)
for chunk in tqdm(chunks, disable=(disable_tqdm or (self.rank != 0))):
......@@ -840,8 +848,14 @@ class HFLM(LM):
for key, re_ord in re_ords.items():
chunks = utils.chunks(
re_ord.get_reordered(),
n=self.batch_size if self.batch_size != "auto" else adaptive_batch_size if adaptive_batch_size is not None else 0,
fn=self._batch_scheduler if self.batch_size == "auto" and not adaptive_batch_size else None,
n=self.batch_size
if self.batch_size != "auto"
else adaptive_batch_size
if adaptive_batch_size is not None
else 0,
fn=self._batch_scheduler
if self.batch_size == "auto" and not adaptive_batch_size
else None,
)
for chunk in tqdm(chunks, disable=self.rank != 0):
contexts, all_gen_kwargs = zip(*chunk)
......
......@@ -15,7 +15,8 @@ from lm_eval.api.registry import (
import logging
eval_logger = logging.getLogger('lm-eval')
eval_logger = logging.getLogger("lm-eval")
def register_configurable_task(config: Dict[str, str]) -> int:
SubClass = type(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment