Commit 0a9ad6ee authored by lintangsutawika

A much better way to process all chosen metrics

parent 5693abc5
@@ -34,6 +34,13 @@ from lm_eval.logger import eval_logger
from lm_eval.prompts import get_prompt
from lm_eval.filters import build_filter_ensemble
ALL_OUTPUT_TYPES = [
"loglikelihood",
"multiple_choice",
"loglikelihood_rolling",
"greedy_until",
]
@dataclass
class TaskConfig(dict):
@@ -80,12 +87,12 @@ class TaskConfig(dict):
# allow user-specified aliases so that users can
# force prompt-compatibility for some prompt regardless of
# field names in prompt
# if self.template_aliases is not None:
# if type(self.doc_to_text) == str:
# self.doc_to_text = self.template_aliases + self.doc_to_text
if self.template_aliases is not None:
if type(self.doc_to_text) == str:
self.doc_to_text = self.template_aliases + self.doc_to_text
# if type(self.doc_to_target) == str:
# self.doc_to_target = self.template_aliases + self.doc_to_target
if type(self.doc_to_target) == str:
self.doc_to_target = self.template_aliases + self.doc_to_target
# set "task_name" metadata field based on the "primary" name set
if self.names:
@@ -472,6 +479,7 @@ class ConfigurableTask(Task):
)
if self._config.output_type is not None:
assert self._config.output_type in ALL_OUTPUT_TYPES
self.OUTPUT_TYPE = self._config.output_type
if self._config.dataset_path is not None:
@@ -480,68 +488,71 @@ class ConfigurableTask(Task):
if self._config.dataset_name is not None:
self.DATASET_NAME = self._config.dataset_name
if self._config.metric_list is not None:
self._metric_list = {}
self._metric_kwargs = {}
self._aggregation_list = {}
self._higher_is_better = {}
if self._config.output_type == "greedy_until":
for metric_config in self._config.metric_list:
metric_name = metric_config["metric"]
aggregation = metric_config["aggregation"]
higher_is_better = metric_config["higher_is_better"]
kwargs = {
key: metric_config[key]
for key in metric_config
if key not in ["metric", "aggregation", "higher_is_better"]
}
self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[
aggregation
]
self._metric_fn_list = {}
self._metric_fn_kwargs = {}
self._aggregation_list = {}
self._higher_is_better = {}
if metric_name in METRIC_REGISTRY.keys():
self._metric_list[metric_name] = METRIC_REGISTRY[metric_name]
self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[
metric_name
]
else:
self._higher_is_better[metric_name] = higher_is_better
try:
metric_object = evaluate.load(metric_name)
self._metric_list[metric_name] = metric_object
self._metric_kwargs[metric_name] = kwargs
except Exception:
raise Warning(
"{} not found in the evaluate library!".format(
metric_name
),
"Please check https://huggingface.co/evaluate-metric",
)
else:
eval_logger.warning(
f"Output Type set as {self._config.output_type} which does not use metric_list"
"metric list will be unused."
)
if self._config.metric_list is None:
eval_logger.warning(
f"Output Type set as {self._config.output_type} and metric_list is not set"
"Will default to exact_match"
)
_metric_list = METRIC_REGISTRY[self._config.output_type]
for metric_name, metric_params in _metric_list.items():
self._metric_fn_list[metric_name] = metric_params["fn"]
self._aggregation_list[metric_name] = metric_params["aggregation"]
self._higher_is_better[metric_name] = metric_params["higher_is_better"]
else:
for metric_config in self._config.metric_list:
assert "metric" in metric_config
metric_name = metric_config["metric"]
kwargs = {
key: metric_config[key]
for key in metric_config
if key not in ["metric", "aggregation", "higher_is_better"]
}
if metric_name in _metric_list:
self._metric_fn_list[metric_name] = _metric_list[metric_name]["fn"]
else:
eval_logger.warning(
f"Metric {metric_name} not found, "
"Searching from https://huggingface.co/evaluate-metric"
)
try:
metric_object = evaluate.load(metric_name)
self._metric_fn_list[metric_name] = metric_object
self._metric_fn_kwargs[metric_name] = kwargs
except Exception:
raise Warning(
"{} not found in the evaluate library!".format(metric_name),
"Please check https://huggingface.co/evaluate-metric",
)
if self._config.output_type == "loglikelihood":
metric_list = ["perplexity", "acc"]
elif self._config.output_type == "loglikelihood_rolling":
metric_list = [
"word_perplexity",
"byte_perplexity",
"bits_per_byte",
if "aggregation" in metric_config:
self._aggregation_list[metric_name] = metric_config["aggregation"]
else:
eval_logger.warning(
f"metric {metric_name} is defined, but aggregation is not"
f"using default aggregation for {metric_name}"
)
self._aggregation_list[metric_name] = _metric_list[metric_name][
"aggregation"
]
elif self._config.output_type == "multiple_choice":
metric_list = ["acc", "acc_norm"]
for metric_name in metric_list:
self._aggregation_list[metric_name] = AGGREGATION_REGISTRY["mean"]
self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[
metric_name
if "higher_is_better" in metric_config:
self._higher_is_better[metric_name] = metric_config[
"higher_is_better"
]
else:
eval_logger.warning(
f"metric {metric_name} is defined, but higher_is_better is not"
f"using default higher_is_better for {metric_name}"
)
self._higher_is_better[metric_name] = _metric_list[metric_name][
"higher_is_better"
]
self.download(self._config.dataset_kwargs)
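Note on the new flow above: each `metric_list` entry is split into the metric name, its aggregation, its `higher_is_better` flag, and a passthrough kwargs dict, with registry defaults filling any missing fields. A minimal, self-contained sketch of that shape, where `toy_registry` and the `ignore_case` key are illustrative stand-ins and not the harness's real `METRIC_REGISTRY` or a real metric option:

```python
# Illustrative sketch only -- toy_registry stands in for METRIC_REGISTRY.
from typing import Any, Dict, List

toy_registry: Dict[str, Dict[str, Any]] = {
    "exact_match": {
        "fn": lambda ref, pred: float(ref == pred),
        "aggregation": "mean",
        "higher_is_better": True,
    },
}

metric_list: List[Dict[str, Any]] = [
    # `ignore_case` is a made-up extra key: anything beyond the three
    # reserved keys is collected into per-metric kwargs.
    {"metric": "exact_match", "higher_is_better": True, "ignore_case": True},
]

metric_fns, aggregations, higher_is_better, metric_kwargs = {}, {}, {}, {}
for cfg in metric_list:
    name = cfg["metric"]
    metric_kwargs[name] = {
        k: v for k, v in cfg.items()
        if k not in ("metric", "aggregation", "higher_is_better")
    }
    metric_fns[name] = toy_registry[name]["fn"]
    # Fall back to registry defaults when a field is missing from the config,
    # mirroring the warnings-plus-defaults logic in the diff above.
    aggregations[name] = cfg.get("aggregation", toy_registry[name]["aggregation"])
    higher_is_better[name] = cfg.get(
        "higher_is_better", toy_registry[name]["higher_is_better"]
    )

print(metric_fns, aggregations, higher_is_better, metric_kwargs)
```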
@@ -743,18 +754,19 @@ class ConfigurableTask(Task):
result_dict = {"perplexity": ll, "acc": int(is_greedy)}
elif self.OUTPUT_TYPE == "loglikelihood_rolling":
(loglikelihood,) = results
words = self.count_words(self.doc_to_target(doc))
bytes_ = self.count_bytes(self.doc_to_target(doc))
_words = self.count_words(self.doc_to_target(doc))
_bytes = self.count_bytes(self.doc_to_target(doc))
return {
"word_perplexity": (loglikelihood, words),
"byte_perplexity": (loglikelihood, bytes_),
"bits_per_byte": (loglikelihood, bytes_),
"word_perplexity": (loglikelihood, _words),
"byte_perplexity": (loglikelihood, _bytes),
"bits_per_byte": (loglikelihood, _bytes),
}
elif self.OUTPUT_TYPE == "multiple_choice":
lls = [
res[0] for res in results
] # only retain loglikelihoods, discard is_greedy
gold = int(self.doc_to_target(doc))
pred = np.argmax(lls)
# retrieve choices in List[str] form, to compute choice lengths, etc.
choices = ast.literal_eval(
utils.apply_template(
@@ -778,6 +790,7 @@ class ConfigurableTask(Task):
result_dict = {
"acc": acc,
"f1": (pred, gold),
"acc_norm": acc_norm,
}
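The diff elides the lines between the two hunks above that actually compute `acc` and `acc_norm` from `lls`, `gold`, and `choices`. For orientation only, a hedged sketch of length-normalized multiple-choice scoring, one common convention rather than a verbatim copy of the elided code:

```python
import numpy as np

def score_multiple_choice(lls, choices, gold):
    """Toy sketch: accuracy and length-normalized accuracy from per-choice loglikelihoods."""
    lls = np.asarray(lls, dtype=float)
    acc = float(int(np.argmax(lls)) == gold)
    # Divide each loglikelihood by its choice's byte length so longer answers
    # are not penalized merely for containing more tokens.
    lengths = np.array([len(c.encode("utf-8")) for c in choices], dtype=float)
    acc_norm = float(int(np.argmax(lls / lengths)) == gold)
    return {"acc": acc, "acc_norm": acc_norm}

print(score_multiple_choice([-4.2, -1.3, -7.9], ["yes", "no", "maybe"], gold=1))
```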
@@ -814,7 +827,7 @@ class ConfigurableTask(Task):
else:
raise ValueError(
f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
"'loglikelihood', 'loglikelihood_rolling', 'greedy_until'",
"'loglikelihood', 'loglikelihood_rolling', 'greedy_until', or 'multiple_choice'",
)
return result_dict
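For the `loglikelihood_rolling` branch earlier in the diff, each metric value is returned as a `(loglikelihood, count)` pair and the corpus-level number only emerges at aggregation time. A sketch of the usual formulas, assumed here rather than quoted from the harness: word perplexity is `exp(-sum(ll) / sum(words))`, and bits per byte is `-sum(ll) / (sum(bytes) * ln(2))`.

```python
import math
from typing import Iterable, Tuple

def weighted_perplexity(items: Iterable[Tuple[float, int]]) -> float:
    # items: (loglikelihood, unit_count) pairs, e.g. (ll, n_words) per document.
    lls, counts = zip(*items)
    return math.exp(-sum(lls) / sum(counts))

def bits_per_byte(items: Iterable[Tuple[float, int]]) -> float:
    # Convert total negative loglikelihood (in nats) per byte into bits per byte.
    lls, n_bytes = zip(*items)
    return -sum(lls) / (sum(n_bytes) * math.log(2))

docs = [(-120.0, 40), (-75.5, 25)]  # (loglikelihood, word or byte count)
print(weighted_perplexity(docs), bits_per_byte(docs))
```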