Commit 9d6bc929 authored by lintangsutawika

aggregation to compute_metric

parent 4d49dd03
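For orientation, this change renames the task-level `aggregation()` hook to `compute_metric()` and has `process_results()` return a dict keyed by metric name instead of a bare tuple. Below is a minimal sketch of a task subclass under the new interface; the class name, the `lm_eval.api.task` import path, and the single `acc` metric are illustrative assumptions rather than code from this commit, and a real subclass would also need to implement the remaining abstract methods.

```python
from lm_eval.api.metrics import mean
from lm_eval.api.task import Task  # assumed module path for the Task base class


class ExampleLoglikelihoodTask(Task):
    """Sketch only: the other abstract methods of Task are omitted here."""

    OUTPUT_TYPE = "loglikelihood"

    def process_results(self, doc, results):
        # New style: return a dict keyed by metric name.
        ll, is_greedy = results[0]
        return {"acc": int(is_greedy)}

    def compute_metric(self):
        # Formerly aggregation(): metric name -> function that reduces
        # a list of per-document scores to a single number.
        return {"acc": mean}

    def higher_is_better(self):
        return {"acc": True}
```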
@@ -29,16 +29,11 @@ from lm_eval.api.metrics import (
mean,
weighted_perplexity,
bits_per_byte,
metric_max_over_ground_truths,
)
from lm_eval.api.registry import (
get_metric,
get_aggregation,
get_metric_aggregation,
is_higher_better,
DEFAULT_METRIC_REGISTRY,
OUTPUT_TYPE_REGISTRY,
AGGREGATION_REGISTRY,
)
ALL_OUTPUT_TYPES = [
@@ -415,7 +410,7 @@ class Task(abc.ABC):
pass
@abc.abstractmethod
def aggregation(self):
def compute_metric(self):
"""
:returns: {str: [metric_score] -> float}
A dictionary where keys are the names of submetrics and values are
@@ -569,13 +564,6 @@ class ConfigurableTask(Task):
metric = get_metric(metric_name)
self._metric_fn_list[metric_name] = metric
self._metric_fn_kwargs[metric_name] = {}
self._aggregation_list[metric_name] = metric.aggregation
# try:
# self._aggregation_list[metric_name] = metric.aggregation
# except:
# self._aggregation_list[metric_name] = get_metric_aggregation(
# metric_name
# )
self._higher_is_better[metric_name] = is_higher_better(metric_name)
else:
for metric_config in self.config.metric_list:
@@ -606,36 +594,6 @@
)
self._metric_fn_kwargs[metric_name] = kwargs
# if "aggregation" in metric_config:
# agg_name = metric_config["aggregation"]
# if type(agg_name) == str:
# self._aggregation_list[metric_name] = get_aggregation(agg_name)
# elif callable(agg_name):
# self._aggregation_list[metric_name] = metric_config[
# "aggregation"
# ]
# else:
# INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
# metric_agg = get_metric_aggregation(metric_name)
# eval_logger.warning(
# f"[Task: {self._config.task}] metric {metric_name} is defined, but aggregation is not. "
# f"using default "
# f"aggregation={INV_AGG_REGISTRY[metric_agg]}"
# )
# self._aggregation_list[metric_name] = metric_agg
# if "higher_is_better" in metric_config:
# self._higher_is_better[metric_name] = metric_config[
# "higher_is_better"
# ]
# else:
# eval_logger.warning(
# f"[Task: {self._config.task}] metric {metric_name} is defined, but higher_is_better is not. "
# f"using default "
# f"higher_is_better={is_higher_better(metric_name)}"
# )
# self._higher_is_better[metric_name] = is_higher_better(metric_name)
self.download(self.config.dataset_kwargs)
self._training_docs = None
self._fewshot_docs = None
@@ -1023,19 +981,43 @@
)
def process_results(self, doc, results):
# Process results returns one of two things per doc/results:
# 1. A score
# 2. Components to be processed later to obtain a score, such as gold and prediction
if callable(self.config.process_results):
return self.config.process_results(doc, results)
result_dict = {}
use_metric = list(self._metric_fn_list.keys())
if self.OUTPUT_TYPE == "loglikelihood":
results = results[0]
ll, is_greedy = results
return ll, is_greedy
return {
**({"perplexity": ll} if "perplexity" in use_metric else {}),
**({"acc": int(is_greedy)} if "acc" in use_metric else {}),
}
elif self.OUTPUT_TYPE == "loglikelihood_rolling":
(loglikelihood,) = results
_words = self.count_words(self.doc_to_target(doc))
_bytes = self.count_bytes(self.doc_to_target(doc))
return loglikelihood, _words, _bytes
return {
**(
{"word_perplexity": (loglikelihood, _words)}
if "word_perplexity" in use_metric
else {}
),
**(
{"byte_perplexity": (loglikelihood, _bytes)}
if "byte_perplexity" in use_metric
else {}
),
**(
{"bits_per_byte": (loglikelihood, _bytes)}
if "bits_per_byte" in use_metric
else {}
),
}
elif self.OUTPUT_TYPE == "multiple_choice":
lls, is_greedy = zip(*results)
@@ -1063,14 +1045,14 @@
gold = self.doc_to_target(doc)
gold_index_error = False
if type(gold) is list:
if isinstance(gold, list):
gold = [i if i < len(choices) else -100 for i in gold]
if -100 in gold:
gold_index_error = True
else:
if type(gold) is int:
if isinstance(gold, int):
gold = gold if gold < len(choices) else -100
elif type(gold) is str:
elif isinstance(gold, str):
gold = choices.index(gold) if gold in choices else -100
if gold == -100:
@@ -1092,12 +1074,13 @@
# TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly
exact_match = int(is_greedy[gold]) if gold != -100 else 0
# gold, lls, is_greedy, completion_len
result_dict = {
**({"acc": acc} if "acc" in use_metric else {}),
**({"f1": (gold, pred)} if "f1" in use_metric else {}),
**({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
**({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
**({"exact_match": exact_match} if "exact_match" in use_metric else {}),
**({"f1": (gold, pred)} if "f1" in use_metric else {}),
**({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
}
if "acc_mutual_info" in use_metric:
@@ -1160,9 +1143,7 @@
predictions=[result],
**self._metric_fn_kwargs[metric],
)
except (
TypeError
): # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
except TypeError: # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
result_score = self._metric_fn_list[metric]([gold, result])
if isinstance(result_score, dict):
# TODO: this handles the case where HF evaluate returns a dict.
@@ -1176,8 +1157,7 @@
return result_dict
def aggregation(self):
# return self._aggregation_list
def compute_metric(self):
return self._metric_fn_list
def higher_is_better(self):
@@ -1224,7 +1204,7 @@ class MultipleChoiceTask(Task):
"acc_norm": True,
}
def aggregation(self) -> dict:
def compute_metric(self) -> dict:
return {
"acc": mean,
"acc_norm": mean,
@@ -1285,7 +1265,7 @@ class PerplexityTask(Task):
"bits_per_byte": (loglikelihood, bytes_),
}
def aggregation(self) -> dict:
def compute_metric(self) -> dict:
return {
"word_perplexity": weighted_perplexity,
"byte_perplexity": weighted_perplexity,
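To show how the pieces fit together after this change, here is a rough driver that pairs the dict-style `process_results()` output with the reducers returned by `compute_metric()`. The `aggregate_scores` helper and its arguments are hypothetical illustrations, not code from the harness.

```python
def aggregate_scores(task, docs_and_results):
    # Hypothetical helper: collect the per-document dicts that
    # process_results() now returns, grouped by metric name.
    per_metric = {}
    for doc, results in docs_and_results:
        for metric, score in task.process_results(doc, results).items():
            per_metric.setdefault(metric, []).append(score)

    # compute_metric() (formerly aggregation()) maps each metric name to its
    # reducer, e.g. {"acc": mean, "word_perplexity": weighted_perplexity}.
    reducers = task.compute_metric()
    return {name: reducers[name](scores) for name, scores in per_metric.items()}
```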
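For the perplexity-style metrics above, `process_results()` emits `(loglikelihood, weight)` pairs (word or byte counts) that the reducers then combine across documents. The functions below are illustrative definitions consistent with that shape, not copied from `lm_eval.api.metrics`.

```python
import math


def weighted_perplexity(pairs):
    # pairs: [(summed_loglikelihood, n_units), ...]; perplexity per word/byte.
    lls, weights = zip(*pairs)
    return math.exp(-sum(lls) / sum(weights))


def bits_per_byte(pairs):
    lls, n_bytes = zip(*pairs)
    return -sum(lls) / (sum(n_bytes) * math.log(2))


# Example with two documents, each contributing (loglikelihood, byte count):
print(weighted_perplexity([(-120.0, 50), (-80.0, 30)]))  # exp(200/80) ~= 12.18
print(bits_per_byte([(-120.0, 50), (-80.0, 30)]))        # 200/(80*ln 2) ~= 3.61
```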