"vscode:/vscode.git/clone" did not exist on "49df1dc595734d20ecdf9dfe11933e527fea84f1"
Commit ba41ab7c authored by lintangsutawika

Major revision to how metrics are properly loaded for all output_types

parent b3591562
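Note on the change below: metric defaults are now resolved per output_type through name-keyed registries (DEFAULT_METRIC_REGISTRY, METRIC_REGISTRY, DEFAULT_AGGREGATION_REGISTRY, AGGREGATION_REGISTRY, HIGHER_IS_BETTER_REGISTRY) instead of per-output-type dicts of metric params. The sketch below mirrors that resolution flow; the registry contents shown are illustrative assumptions, not the real values in lm_eval.api.metrics.

# Minimal sketch of the new default-metric resolution (registry contents are made up).
DEFAULT_METRIC_REGISTRY = {"multiple_choice": ["acc", "acc_norm"]}   # assumed shape: output_type -> metric names
METRIC_REGISTRY = {"acc": lambda items: items, "acc_norm": lambda items: items}  # assumed: name -> metric fn
DEFAULT_AGGREGATION_REGISTRY = {"acc": "mean", "acc_norm": "mean"}   # assumed: name -> aggregation name
AGGREGATION_REGISTRY = {"mean": lambda arr: sum(arr) / len(arr)}     # assumed: aggregation name -> fn
HIGHER_IS_BETTER_REGISTRY = {"acc": True, "acc_norm": True}          # assumed: name -> bool


def resolve_default_metrics(output_type: str):
    """Mirrors the default branch added in ConfigurableTask.__init__ (sketch only)."""
    metric_fns, aggregations, higher_is_better = {}, {}, {}
    for metric_name in DEFAULT_METRIC_REGISTRY[output_type]:
        metric_fns[metric_name] = METRIC_REGISTRY[metric_name]
        aggregation = DEFAULT_AGGREGATION_REGISTRY[metric_name]
        aggregations[metric_name] = AGGREGATION_REGISTRY[aggregation]
        higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[metric_name]
    return metric_fns, aggregations, higher_is_better


print(resolve_default_metrics("multiple_choice")[2])  # -> {'acc': True, 'acc_norm': True}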
@@ -19,21 +19,24 @@
 from lm_eval import utils
 from lm_eval.api import samplers
 from lm_eval.api.instance import Instance
 from lm_eval.api.filter import FilterEnsemble
-from lm_eval.logger import eval_logger
-from lm_eval.prompts import get_prompt
-from lm_eval.filters import build_filter_ensemble
-from lm_eval.metrics import (
+from lm_eval.api.metrics import (
     METRIC_REGISTRY,
+    DEFAULT_METRIC_REGISTRY,
+    OUTPUT_TYPE_REGISTRY,
     AGGREGATION_REGISTRY,
     HIGHER_IS_BETTER_REGISTRY,
-    get_metric,
-    get_aggregation,
+    DEFAULT_AGGREGATION_REGISTRY,
+    # get_metric,
+    # get_aggregation,
     mean,
     weighted_perplexity,
     bits_per_byte,
 )
+from lm_eval.logger import eval_logger
+from lm_eval.prompts import get_prompt
+from lm_eval.filters import build_filter_ensemble

 ALL_OUTPUT_TYPES = [
     "loglikelihood",
     "multiple_choice",
@@ -493,16 +496,19 @@ class ConfigurableTask(Task):
         self._aggregation_list = {}
         self._higher_is_better = {}

+        _metric_list = DEFAULT_METRIC_REGISTRY[self._config.output_type]
         if self._config.metric_list is None:
             eval_logger.warning(
                 f"Output Type set as {self._config.output_type} and metric_list is not set"
                 "Will default to exact_match"
             )
-            _metric_list = METRIC_REGISTRY[self._config.output_type]
-            for metric_name, metric_params in _metric_list.items():
-                self._metric_fn_list[metric_name] = metric_params["fn"]
-                self._aggregation_list[metric_name] = metric_params["aggregation"]
-                self._higher_is_better[metric_name] = metric_params["higher_is_better"]
+            for metric_name in _metric_list:
+                self._metric_fn_list[metric_name] = METRIC_REGISTRY[metric_name]
+                aggregation = DEFAULT_AGGREGATION_REGISTRY[metric_name]
+                self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[aggregation]
+                self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[
+                    metric_name
+                ]
         else:
             for metric_config in self._config.metric_list:
@@ -514,7 +520,7 @@ class ConfigurableTask(Task):
                     if key not in ["metric", "aggregation", "higher_is_better"]
                 }
                 if metric_name in _metric_list:
-                    self._metric_fn_list[metric_name] = metric_params["fn"]
+                    self._metric_fn_list[metric_name] = METRIC_REGISTRY[metric_name]
                 else:
                     eval_logger.warning(
                         f"Metric {metric_name} not found, "
@@ -538,8 +544,9 @@ class ConfigurableTask(Task):
                         f"metric {metric_name} is defined, but aggregation is not"
                         f"using default aggregation for {metric_name}"
                     )
-                    self._aggregation_list[metric_name] = _metric_list[metric_name][
-                        "aggregation"
+                    aggregation = DEFAULT_AGGREGATION_REGISTRY[metric_name]
+                    self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[
+                        aggregation
                     ]

                 if "higher_is_better" in metric_config:
@@ -551,8 +558,8 @@ class ConfigurableTask(Task):
                         f"metric {metric_name} is defined, but higher_is_better is not"
                         f"using default higher_is_better for {metric_name}"
                     )
-                    self._higher_is_better[metric_name] = _metric_list[metric_name][
-                        "higher_is_better"
+                    self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[
+                        metric_name
                     ]

         self.download(self._config.dataset_kwargs)
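Note on the metric_list branch just above: when a task config supplies metric_list explicitly, each entry names a metric, may override aggregation and higher_is_better, and any remaining keys become metric kwargs; missing fields fall back to the same default registries. The sketch below paraphrases that fallback logic with assumed registry contents and an assumed config shape; it is not the actual implementation.

# Sketch of the explicit-config path (assumed registry contents, assumed config shape).
METRIC_REGISTRY = {"acc": lambda items: items}                    # assumed: name -> metric fn
DEFAULT_AGGREGATION_REGISTRY = {"acc": "mean"}                    # assumed: name -> default aggregation name
AGGREGATION_REGISTRY = {"mean": lambda arr: sum(arr) / len(arr)}  # assumed: aggregation name -> fn
HIGHER_IS_BETTER_REGISTRY = {"acc": True}                         # assumed: name -> bool


def resolve_configured_metric(metric_config: dict):
    metric_name = metric_config["metric"]
    # Everything except the three reserved keys is forwarded as metric kwargs.
    metric_kwargs = {
        k: v
        for k, v in metric_config.items()
        if k not in ["metric", "aggregation", "higher_is_better"]
    }
    fn = METRIC_REGISTRY[metric_name]
    # Fall back to the per-metric defaults when a field is not overridden.
    agg_name = metric_config.get("aggregation", DEFAULT_AGGREGATION_REGISTRY[metric_name])
    agg_fn = AGGREGATION_REGISTRY[agg_name]
    higher_is_better = metric_config.get("higher_is_better", HIGHER_IS_BETTER_REGISTRY[metric_name])
    return fn, agg_fn, higher_is_better, metric_kwargs


# e.g. a hypothetical entry such as {"metric": "acc", "ignore_case": True} from a task YAML
print(resolve_configured_metric({"metric": "acc", "ignore_case": True})[3])  # -> {'ignore_case': True}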
@@ -717,7 +724,7 @@ class ConfigurableTask(Task):
                 for i, choice in enumerate(choices)
             ]
             # TODO: we should raise a warning telling users this will at most ~2x runtime.
-            if "acc_mutual_info" in self._metric_list.keys():
+            if "acc_mutual_info" in self._metric_fn_list.keys():
                 # if we are calculating multiple choice accuracy
                 # using mutual information instead of raw loglikelihood as metric, need unconditional lls.
@@ -748,23 +755,38 @@ class ConfigurableTask(Task):
     def process_results(self, doc, results):
         result_dict = {}
+        use_metric = list(self._metric_fn_list.keys())
         if self.OUTPUT_TYPE == "loglikelihood":
             results = results[0]
             ll, is_greedy = results
-            result_dict = {"perplexity": ll, "acc": int(is_greedy)}
+            return {
+                **({"perplexity": ll} if "perplexity" in use_metric else {}),
+                **({"acc": int(is_greedy)} if "acc" in use_metric else {}),
+            }
         elif self.OUTPUT_TYPE == "loglikelihood_rolling":
             (loglikelihood,) = results
             _words = self.count_words(self.doc_to_target(doc))
             _bytes = self.count_bytes(self.doc_to_target(doc))
             return {
-                "word_perplexity": (loglikelihood, _words),
-                "byte_perplexity": (loglikelihood, _bytes),
-                "bits_per_byte": (loglikelihood, _bytes),
+                **(
+                    {"word_perplexity": (loglikelihood, _words)}
+                    if "word_perplexity" in use_metric
+                    else {}
+                ),
+                **(
+                    {"byte_perplexity": (loglikelihood, _bytes)}
+                    if "byte_perplexity" in use_metric
+                    else {}
+                ),
+                **(
+                    {"bits_per_byte": (loglikelihood, _bytes)}
+                    if "bits_per_byte" in use_metric
+                    else {}
+                ),
             }
         elif self.OUTPUT_TYPE == "multiple_choice":
-            lls = [
-                res[0] for res in results
-            ]  # only retain loglikelihoods, discard is_greedy
+            lls, is_greedy = zip(*results)

             gold = int(self.doc_to_target(doc))
             pred = np.argmax(lls)
             # retrieve choices in List[str] form, to compute choice lengths, etc.
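Note on process_results above: the **({...} if ... else {}) spreads mean each branch now only emits entries for metrics the task actually requested (use_metric), instead of always reporting every metric. A tiny standalone illustration of the idiom, with made-up values:

# Standalone illustration of the conditional dict-spread gating (values are made up).
use_metric = ["perplexity"]        # hypothetical: metric names requested by the task config
ll, is_greedy = -3.2, True         # hypothetical loglikelihood result

result = {
    **({"perplexity": ll} if "perplexity" in use_metric else {}),
    **({"acc": int(is_greedy)} if "acc" in use_metric else {}),
}
print(result)  # -> {'perplexity': -3.2}; "acc" is skipped because it was not requested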
@@ -773,38 +795,32 @@ class ConfigurableTask(Task):
                     self._config.template_aliases + "{{answer_choices}}", doc
                 )
             )
-            if (
-                2 * len(choices) == len(lls)
-                and "acc_mutual_info" in self._metric_list.keys()
-            ):
-                # then we are doing mutual info.
-                # this stores the "dryrun" / unconditional answer loglikelihoods
-                lls_unconditional = lls[1::2]
-                assert len(lls_unconditional) == len(choices)
-                # and this stores our "regular" conditional loglikelihoods
-                lls = lls[::2]

             acc = 1.0 if np.argmax(lls) == gold else 0.0
             completion_len = np.array([float(len(i)) for i in choices])
             acc_norm = 1.0 if np.argmax(lls / completion_len) == gold else 0.0

             result_dict = {
-                "acc": acc,
-                "f1": (pred, gold),
-                "acc_norm": acc_norm,
+                **({"acc": acc} if "acc" in use_metric else {}),
+                **({"f1": (pred, gold)} if "f1" in use_metric else {}),
+                **({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
             }

             # TODO: set which normalization metrics should be reported, and calculate them
-            if "exact_match" in self._metric_list.keys():
+            if "exact_match" in self._metric_fn_list.keys():
                 # TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly
-                is_greedy = [
-                    res[1] for res in results
-                ]  # take only the `is_greedy` results
                 is_greedy = is_greedy[gold]  # take value for the gold answer
                 result_dict["exact_match"] = int(is_greedy)

-            if "acc_mutual_info" in self._metric_list.keys():
+            if "acc_mutual_info" in use_metric:
+                if 2 * len(choices) == len(lls):
+                    # then we are doing mutual info.
+                    # this stores the "dryrun" / unconditional answer loglikelihoods
+                    lls_unconditional = lls[1::2]
+                    assert len(lls_unconditional) == len(choices)
+                    # and this stores our "regular" conditional loglikelihoods
+                    lls = lls[::2]
                 lls_mutual_info = [
                     ll_c - ll_u for ll_c, ll_u in zip(lls, lls_unconditional)
                 ]
@@ -818,8 +834,8 @@ class ConfigurableTask(Task):
         else:
             gold = self.doc_to_target(doc)

-            for key, result in zip(self._metric_list.keys(), results):
-                _dict = self._metric_list[key].compute(
+            for key, result in zip(self._metric_fn_list.keys(), results):
+                _dict = self._metric_fn_list[key].compute(
                     references=[gold], predictions=[result], **self._metric_kwargs[key]
                 )
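Note on the acc_mutual_info branch above: when the request list was doubled (conditional plus unconditional "dryrun" loglikelihoods, interleaved), the pairs are split apart and accuracy is taken over their differences. A standalone sketch with made-up numbers:

import numpy as np

# Hypothetical interleaved loglikelihoods for 3 choices:
# even indices = conditional lls, odd indices = unconditional ("dryrun") lls.
lls = [-4.0, -6.5, -2.5, -2.4, -5.0, -7.0]
choices = ["A", "B", "C"]
gold = 0

assert 2 * len(choices) == len(lls)
lls_unconditional = lls[1::2]   # unconditional answer loglikelihoods
lls_conditional = lls[::2]      # regular conditional loglikelihoods

lls_mutual_info = [
    ll_c - ll_u for ll_c, ll_u in zip(lls_conditional, lls_unconditional)
]
acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0
print(acc_mutual_info)  # -> 1.0: choice "A" has the largest ll_c - ll_u margin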