Commit ba41ab7c authored by lintangsutawika

Major revision of how metrics are loaded for all output_types

parent b3591562
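In brief, the diff replaces the old per-output-type metric dicts (each carrying its own fn, aggregation, and higher_is_better) with lookups against separate registries: METRIC_REGISTRY for the metric function, DEFAULT_METRIC_REGISTRY for an output type's default metric names, DEFAULT_AGGREGATION_REGISTRY plus AGGREGATION_REGISTRY for the aggregation function, and HIGHER_IS_BETTER_REGISTRY for the comparison direction. The standalone Python sketch below mirrors that lookup chain with toy registries; the registry contents and the resolve_metrics helper are illustrative only, not lm_eval code.

# Toy registries mirroring the names used in the diff; contents are illustrative.
def mean(items):
    return sum(items) / len(items)

def exact_match(references, predictions):
    return float(references[0] == predictions[0])

METRIC_REGISTRY = {"exact_match": exact_match}                # metric name -> metric fn
DEFAULT_METRIC_REGISTRY = {"greedy_until": ["exact_match"]}   # output_type -> default metric names
AGGREGATION_REGISTRY = {"mean": mean}                         # aggregation name -> aggregation fn
DEFAULT_AGGREGATION_REGISTRY = {"exact_match": "mean"}        # metric name -> default aggregation name
HIGHER_IS_BETTER_REGISTRY = {"exact_match": True}             # metric name -> direction

def resolve_metrics(output_type):
    # Mirrors the branch taken when a task config has no metric_list:
    # default metric names come from the output type, everything else from the registries.
    metric_fn_list, aggregation_list, higher_is_better = {}, {}, {}
    for metric_name in DEFAULT_METRIC_REGISTRY[output_type]:
        metric_fn_list[metric_name] = METRIC_REGISTRY[metric_name]
        aggregation = DEFAULT_AGGREGATION_REGISTRY[metric_name]
        aggregation_list[metric_name] = AGGREGATION_REGISTRY[aggregation]
        higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[metric_name]
    return metric_fn_list, aggregation_list, higher_is_better

fns, aggs, hib = resolve_metrics("greedy_until")
print(list(fns), aggs["exact_match"]([1.0, 0.0]), hib["exact_match"])  # ['exact_match'] 0.5 True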
@@ -19,21 +19,24 @@ from lm_eval import utils
from lm_eval.api import samplers
from lm_eval.api.instance import Instance
from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.metrics import (
from lm_eval.logger import eval_logger
from lm_eval.prompts import get_prompt
from lm_eval.filters import build_filter_ensemble
from lm_eval.metrics import (
METRIC_REGISTRY,
DEFAULT_METRIC_REGISTRY,
OUTPUT_TYPE_REGISTRY,
AGGREGATION_REGISTRY,
HIGHER_IS_BETTER_REGISTRY,
get_metric,
get_aggregation,
DEFAULT_AGGREGATION_REGISTRY,
# get_metric,
# get_aggregation,
mean,
weighted_perplexity,
bits_per_byte,
)
from lm_eval.logger import eval_logger
from lm_eval.prompts import get_prompt
from lm_eval.filters import build_filter_ensemble
ALL_OUTPUT_TYPES = [
"loglikelihood",
"multiple_choice",
@@ -493,16 +496,19 @@ class ConfigurableTask(Task):
self._aggregation_list = {}
self._higher_is_better = {}
_metric_list = DEFAULT_METRIC_REGISTRY[self._config.output_type]
if self._config.metric_list is None:
eval_logger.warning(
f"Output Type set as {self._config.output_type} and metric_list is not set"
"Will default to exact_match"
)
_metric_list = METRIC_REGISTRY[self._config.output_type]
for metric_name, metric_params in _metric_list.items():
self._metric_fn_list[metric_name] = metric_params["fn"]
self._aggregation_list[metric_name] = metric_params["aggregation"]
self._higher_is_better[metric_name] = metric_params["higher_is_better"]
for metric_name in _metric_list:
self._metric_fn_list[metric_name] = METRIC_REGISTRY[metric_name]
aggregation = DEFAULT_AGGREGATION_REGISTRY[metric_name]
self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[aggregation]
self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[
metric_name
]
else:
for metric_config in self._config.metric_list:
@@ -514,7 +520,7 @@ class ConfigurableTask(Task):
if key not in ["metric", "aggregation", "higher_is_better"]
}
if metric_name in _metric_list:
self._metric_fn_list[metric_name] = metric_params["fn"]
self._metric_fn_list[metric_name] = METRIC_REGISTRY[metric_name]
else:
eval_logger.warning(
f"Metric {metric_name} not found, "
@@ -538,8 +544,9 @@ class ConfigurableTask(Task):
f"metric {metric_name} is defined, but aggregation is not"
f"using default aggregation for {metric_name}"
)
self._aggregation_list[metric_name] = _metric_list[metric_name][
"aggregation"
aggregation = DEFAULT_AGGREGATION_REGISTRY[metric_name]
self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[
aggregation
]
if "higher_is_better" in metric_config:
@@ -551,8 +558,8 @@ class ConfigurableTask(Task):
f"metric {metric_name} is defined, but higher_is_better is not"
f"using default higher_is_better for {metric_name}"
)
self._higher_is_better[metric_name] = _metric_list[metric_name][
"higher_is_better"
self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[
metric_name
]
self.download(self._config.dataset_kwargs)
@@ -717,7 +724,7 @@ class ConfigurableTask(Task):
for i, choice in enumerate(choices)
]
# TODO: we should raise a warning telling users this will at most ~2x runtime.
if "acc_mutual_info" in self._metric_list.keys():
if "acc_mutual_info" in self._metric_fn_list.keys():
# if we are calculating multiple choice accuracy
# using mutual information instead of raw loglikelihood as metric, need unconditional lls.
@@ -748,23 +755,38 @@ class ConfigurableTask(Task):
def process_results(self, doc, results):
result_dict = {}
use_metric = list(self._metric_fn_list.keys())
if self.OUTPUT_TYPE == "loglikelihood":
results = results[0]
ll, is_greedy = results
result_dict = {"perplexity": ll, "acc": int(is_greedy)}
return {
**({"perplexity": ll} if "perplexity" in use_metric else {}),
**({"acc": int(is_greedy)} if "acc" in use_metric else {}),
}
elif self.OUTPUT_TYPE == "loglikelihood_rolling":
(loglikelihood,) = results
_words = self.count_words(self.doc_to_target(doc))
_bytes = self.count_bytes(self.doc_to_target(doc))
return {
"word_perplexity": (loglikelihood, _words),
"byte_perplexity": (loglikelihood, _bytes),
"bits_per_byte": (loglikelihood, _bytes),
**(
{"word_perplexity": (loglikelihood, _words)}
if "word_perplexity" in use_metric
else {}
),
**(
{"byte_perplexity": (loglikelihood, _bytes)}
if "byte_perplexity" in use_metric
else {}
),
**(
{"bits_per_byte": (loglikelihood, _bytes)}
if "bits_per_byte" in use_metric
else {}
),
}
elif self.OUTPUT_TYPE == "multiple_choice":
lls = [
res[0] for res in results
] # only retain loglikelihoods, discard is_greedy
lls, is_greedy = zip(*results)
gold = int(self.doc_to_target(doc))
pred = np.argmax(lls)
# retrieve choices in List[str] form, to compute choice lengths, etc.
@@ -773,38 +795,32 @@ class ConfigurableTask(Task):
self._config.template_aliases + "{{answer_choices}}", doc
)
)
if (
2 * len(choices) == len(lls)
and "acc_mutual_info" in self._metric_list.keys()
):
# then we are doing mutual info.
# this stores the "dryrun" / unconditional answer loglikelihoods
lls_unconditional = lls[1::2]
assert len(lls_unconditional) == len(choices)
# and this stores our "regular" conditional loglikelihoods
lls = lls[::2]
acc = 1.0 if np.argmax(lls) == gold else 0.0
completion_len = np.array([float(len(i)) for i in choices])
acc_norm = 1.0 if np.argmax(lls / completion_len) == gold else 0.0
result_dict = {
"acc": acc,
"f1": (pred, gold),
"acc_norm": acc_norm,
**({"acc": acc} if "acc" in use_metric else {}),
**({"f1": (pred, gold)} if "f1" in use_metric else {}),
**({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
}
# TODO: set which normalization metrics should be reported, and calculate them
if "exact_match" in self._metric_list.keys():
if "exact_match" in self._metric_fn_list.keys():
# TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly
is_greedy = [
res[1] for res in results
] # take only the `is_greedy` results
is_greedy = is_greedy[gold] # take value for the gold answer
result_dict["exact_match"] = int(is_greedy)
if "acc_mutual_info" in self._metric_list.keys():
if "acc_mutual_info" in use_metric:
if 2 * len(choices) == len(lls):
# then we are doing mutual info.
# this stores the "dryrun" / unconditional answer loglikelihoods
lls_unconditional = lls[1::2]
assert len(lls_unconditional) == len(choices)
# and this stores our "regular" conditional loglikelihoods
lls = lls[::2]
lls_mutual_info = [
ll_c - ll_u for ll_c, ll_u in zip(lls, lls_unconditional)
]
@@ -818,8 +834,8 @@ class ConfigurableTask(Task):
else:
gold = self.doc_to_target(doc)
for key, result in zip(self._metric_list.keys(), results):
_dict = self._metric_list[key].compute(
for key, result in zip(self._metric_fn_list.keys(), results):
_dict = self._metric_fn_list[key].compute(
references=[gold], predictions=[result], **self._metric_kwargs[key]
)
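The other half of the change, visible in the process_results hunks above, is that the result dict is now built conditionally: an entry is emitted only if the corresponding metric is in use_metric (the keys of self._metric_fn_list), and multiple_choice results are unpacked as (loglikelihood, is_greedy) pairs via zip. A standalone sketch of that pattern, with a hypothetical helper and toy inputs rather than harness code:

import numpy as np

def process_multiple_choice(results, gold, choices, use_metric):
    # results: one (loglikelihood, is_greedy) pair per answer choice.
    lls, is_greedy = zip(*results)
    pred = int(np.argmax(lls))
    acc = 1.0 if pred == gold else 0.0
    completion_len = np.array([float(len(c)) for c in choices])
    acc_norm = 1.0 if int(np.argmax(np.array(lls) / completion_len)) == gold else 0.0
    # Only report the metrics this task actually configured.
    return {
        **({"acc": acc} if "acc" in use_metric else {}),
        **({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
        **({"exact_match": int(is_greedy[gold])} if "exact_match" in use_metric else {}),
    }

print(process_multiple_choice(
    results=[(-1.2, False), (-0.3, True), (-2.5, False)],
    gold=1,
    choices=["cat", "dog", "fish"],
    use_metric={"acc", "acc_norm"},
))  # {'acc': 1.0, 'acc_norm': 1.0}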