Commit 9d6bc929 authored by lintangsutawika

Rename aggregation to compute_metric

parent 4d49dd03
@@ -29,16 +29,11 @@ from lm_eval.api.metrics import (
     mean,
     weighted_perplexity,
     bits_per_byte,
-    metric_max_over_ground_truths,
 )
 from lm_eval.api.registry import (
     get_metric,
-    get_aggregation,
-    get_metric_aggregation,
     is_higher_better,
     DEFAULT_METRIC_REGISTRY,
-    OUTPUT_TYPE_REGISTRY,
-    AGGREGATION_REGISTRY,
 )
 
 ALL_OUTPUT_TYPES = [
@@ -415,7 +410,7 @@ class Task(abc.ABC):
         pass
 
     @abc.abstractmethod
-    def aggregation(self):
+    def compute_metric(self):
         """
         :returns: {str: [metric_score] -> float}
             A dictionary where keys are the names of submetrics and values are
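For orientation, a minimal sketch of how a hand-written Task subclass might implement the renamed abstract method; the class name is illustrative, the other abstract methods are omitted, and the import path assumes the Task class shown above lives at lm_eval.api.task:

from lm_eval.api.metrics import mean
from lm_eval.api.task import Task

class ExampleTask(Task):  # hypothetical subclass, for illustration only
    def compute_metric(self):
        # Per the docstring above: map each submetric name to a callable
        # that reduces a list of per-document scores to a single float.
        return {
            "acc": mean,
            "acc_norm": mean,
        }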
@@ -569,13 +564,6 @@ class ConfigurableTask(Task):
                     metric = get_metric(metric_name)
                     self._metric_fn_list[metric_name] = metric
                     self._metric_fn_kwargs[metric_name] = {}
-                    self._aggregation_list[metric_name] = metric.aggregation
-                    # try:
-                    #     self._aggregation_list[metric_name] = metric.aggregation
-                    # except:
-                    #     self._aggregation_list[metric_name] = get_metric_aggregation(
-                    #         metric_name
-                    #     )
                     self._higher_is_better[metric_name] = is_higher_better(metric_name)
             else:
                 for metric_config in self.config.metric_list:
@@ -606,36 +594,6 @@ class ConfigurableTask(Task):
                     )
                     self._metric_fn_kwargs[metric_name] = kwargs
-
-                    # if "aggregation" in metric_config:
-                    #     agg_name = metric_config["aggregation"]
-                    #     if type(agg_name) == str:
-                    #         self._aggregation_list[metric_name] = get_aggregation(agg_name)
-                    #     elif callable(agg_name):
-                    #         self._aggregation_list[metric_name] = metric_config[
-                    #             "aggregation"
-                    #         ]
-                    #     else:
-                    #         INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
-                    #         metric_agg = get_metric_aggregation(metric_name)
-                    #         eval_logger.warning(
-                    #             f"[Task: {self._config.task}] metric {metric_name} is defined, but aggregation is not. "
-                    #             f"using default "
-                    #             f"aggregation={INV_AGG_REGISTRY[metric_agg]}"
-                    #         )
-                    #         self._aggregation_list[metric_name] = metric_agg
-
-                    # if "higher_is_better" in metric_config:
-                    #     self._higher_is_better[metric_name] = metric_config[
-                    #         "higher_is_better"
-                    #     ]
-                    # else:
-                    #     eval_logger.warning(
-                    #         f"[Task: {self._config.task}] metric {metric_name} is defined, but higher_is_better is not. "
-                    #         f"using default "
-                    #         f"higher_is_better={is_higher_better(metric_name)}"
-                    #     )
-                    #     self._higher_is_better[metric_name] = is_higher_better(metric_name)
 
         self.download(self.config.dataset_kwargs)
         self._training_docs = None
         self._fewshot_docs = None
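The commented-out block deleted here used to read optional "aggregation" and "higher_is_better" keys from each metric_config; after this change the loop keeps only the metric function and its kwargs. A rough, hypothetical sketch of the shape involved (the extra key names are made up, and the kwargs derivation below is a guess, not shown in this diff):

metric_config = {
    "metric": "exact_match",     # looked up via get_metric(...)
    "ignore_case": True,         # illustrative extra keys that would end up
    "ignore_punctuation": True,  # in self._metric_fn_kwargs[metric_name]
}
kwargs = {k: v for k, v in metric_config.items() if k != "metric"}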
@@ -1023,19 +981,43 @@ class ConfigurableTask(Task):
         )
 
     def process_results(self, doc, results):
+        # process_results returns one of two things per doc/results:
+        # 1. A score
+        # 2. Components to be processed later to obtain a score, such as gold and prediction
         if callable(self.config.process_results):
             return self.config.process_results(doc, results)
 
         result_dict = {}
+        use_metric = list(self._metric_fn_list.keys())
         if self.OUTPUT_TYPE == "loglikelihood":
             results = results[0]
             ll, is_greedy = results
-            return ll, is_greedy
+            return {
+                **({"perplexity": ll} if "perplexity" in use_metric else {}),
+                **({"acc": int(is_greedy)} if "acc" in use_metric else {}),
+            }
         elif self.OUTPUT_TYPE == "loglikelihood_rolling":
             (loglikelihood,) = results
             _words = self.count_words(self.doc_to_target(doc))
             _bytes = self.count_bytes(self.doc_to_target(doc))
-            return loglikelihood, _words, _bytes
+            return {
+                **(
+                    {"word_perplexity": (loglikelihood, _words)}
+                    if "word_perplexity" in use_metric
+                    else {}
+                ),
+                **(
+                    {"byte_perplexity": (loglikelihood, _bytes)}
+                    if "byte_perplexity" in use_metric
+                    else {}
+                ),
+                **(
+                    {"bits_per_byte": (loglikelihood, _bytes)}
+                    if "bits_per_byte" in use_metric
+                    else {}
+                ),
+            }
         elif self.OUTPUT_TYPE == "multiple_choice":
             lls, is_greedy = zip(*results)
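In the rolling-loglikelihood branch each document now contributes a (loglikelihood, length) tuple per metric rather than a finished score. A rough, self-contained sketch of how such tuples reduce to a corpus-level perplexity, written out by hand here and assumed to mirror what weighted_perplexity does (the document values are made up):

import math

# Hypothetical per-document outputs for a "loglikelihood_rolling" task.
docs = [
    {"word_perplexity": (-42.0, 18), "byte_perplexity": (-42.0, 95)},
    {"word_perplexity": (-37.5, 15), "byte_perplexity": (-37.5, 81)},
]

def weighted_ppl(pairs):
    # exp(-sum(loglikelihoods) / sum(weights)): corpus-level perplexity
    # weighted by word or byte counts.
    lls, weights = zip(*pairs)
    return math.exp(-sum(lls) / sum(weights))

word_ppl = weighted_ppl([d["word_perplexity"] for d in docs])
byte_ppl = weighted_ppl([d["byte_perplexity"] for d in docs])
print(f"word_perplexity={word_ppl:.2f} byte_perplexity={byte_ppl:.2f}")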
@@ -1063,14 +1045,14 @@ class ConfigurableTask(Task):
             gold = self.doc_to_target(doc)
             gold_index_error = False
-            if type(gold) is list:
+            if isinstance(gold, list):
                 gold = [i if i < len(choices) else -100 for i in gold]
                 if -100 in gold:
                     gold_index_error = True
             else:
-                if type(gold) is int:
+                if isinstance(gold, int):
                     gold = gold if gold < len(choices) else -100
-                elif type(gold) is str:
+                elif isinstance(gold, str):
                     gold = choices.index(gold) if gold in choices else -100
 
             if gold == -100:
@@ -1092,12 +1074,13 @@ class ConfigurableTask(Task):
             # TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly
             exact_match = int(is_greedy[gold]) if gold != -100 else 0
 
-            # gold, lls, is_greedy, completion_len
             result_dict = {
                 **({"acc": acc} if "acc" in use_metric else {}),
+                **({"f1": (gold, pred)} if "f1" in use_metric else {}),
+                **({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
                 **({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
                 **({"exact_match": exact_match} if "exact_match" in use_metric else {}),
-                **({"f1": (gold, pred)} if "f1" in use_metric else {}),
-                **({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
             }
 
             if "acc_mutual_info" in use_metric:
@@ -1160,9 +1143,7 @@ class ConfigurableTask(Task):
                         predictions=[result],
                         **self._metric_fn_kwargs[metric],
                     )
-                except (
-                    TypeError
-                ):  # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
+                except TypeError:  # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
                     result_score = self._metric_fn_list[metric]([gold, result])
                 if isinstance(result_score, dict):
                     # TODO: this handles the case where HF evaluate returns a dict.
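The try/except above bridges two calling conventions: HF Evaluate-style metrics take references=/predictions= keyword lists, while the harness's own metrics take a single positional [gold, result] item. A toy sketch of that fallback with made-up stub metrics:

def hf_style_metric(references, predictions):
    # HF Evaluate-style interface: parallel lists of references and predictions.
    return sum(r == p for r, p in zip(references, predictions)) / len(references)

def own_style_metric(items):
    # lm-eval-style interface: a single [gold, result] pair.
    gold, result = items
    return float(gold == result)

metric_fn = own_style_metric  # whichever callable the registry handed back
try:
    score = metric_fn(references=["yes"], predictions=["yes"])
except TypeError:  # own-style metrics reject the keyword interface
    score = metric_fn(["yes", "yes"])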
@@ -1176,8 +1157,7 @@ class ConfigurableTask(Task):
         return result_dict
 
-    def aggregation(self):
-        # return self._aggregation_list
+    def compute_metric(self):
         return self._metric_fn_list
 
     def higher_is_better(self):
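A rough sketch, not taken from the evaluator, of how a caller could consume the renamed method, following the abstract docstring's {str: [metric_score] -> float} contract (the function and variable names below are illustrative):

from collections import defaultdict

def aggregate_results(task, per_doc_results):
    # per_doc_results: list of the dicts returned by process_results above.
    by_metric = defaultdict(list)
    for result_dict in per_doc_results:
        for name, value in result_dict.items():
            by_metric[name].append(value)
    reducers = task.compute_metric()  # assumed: {metric_name: callable over a list}
    return {name: reducers[name](values) for name, values in by_metric.items()}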
@@ -1224,7 +1204,7 @@ class MultipleChoiceTask(Task):
             "acc_norm": True,
         }
 
-    def aggregation(self) -> dict:
+    def compute_metric(self) -> dict:
         return {
             "acc": mean,
             "acc_norm": mean,
@@ -1285,7 +1265,7 @@ class PerplexityTask(Task):
             "bits_per_byte": (loglikelihood, bytes_),
         }
 
-    def aggregation(self) -> dict:
+    def compute_metric(self) -> dict:
         return {
             "word_perplexity": weighted_perplexity,
             "byte_perplexity": weighted_perplexity,
...
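byte_perplexity and bits_per_byte both reduce the same (loglikelihood, byte_count) pairs. A quick hand-written check of the usual relationship between the two, assuming natural-log likelihoods; the values are made up and this is not the library's own code:

import math

pairs = [(-42.0, 95), (-37.5, 81)]  # hypothetical (loglikelihood, n_bytes) per doc
lls, nbytes = zip(*pairs)
byte_ppl = math.exp(-sum(lls) / sum(nbytes))   # weighted perplexity over bytes
bpb = -sum(lls) / (sum(nbytes) * math.log(2))  # bits per byte
assert abs(math.log2(byte_ppl) - bpb) < 1e-9   # bpb == log2(byte_perplexity)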