Commit 0a9ad6ee authored by lintangsutawika

much better way to process all metrics chosen

parent 5693abc5
@@ -34,6 +34,13 @@ from lm_eval.logger import eval_logger
 from lm_eval.prompts import get_prompt
 from lm_eval.filters import build_filter_ensemble
 
+ALL_OUTPUT_TYPES = [
+    "loglikelihood",
+    "multiple_choice",
+    "loglikelihood_rolling",
+    "greedy_until",
+]
+
 
 @dataclass
 class TaskConfig(dict):
@@ -80,12 +87,12 @@ class TaskConfig(dict):
         # allow user-specified aliases so that users can
         # force prompt-compatibility for some prompt regardless of
         # field names in prompt
-        # if self.template_aliases is not None:
-        #     if type(self.doc_to_text) == str:
-        #         self.doc_to_text = self.template_aliases + self.doc_to_text
-        #     if type(self.doc_to_target) == str:
-        #         self.doc_to_target = self.template_aliases + self.doc_to_target
+        if self.template_aliases is not None:
+            if type(self.doc_to_text) == str:
+                self.doc_to_text = self.template_aliases + self.doc_to_text
+            if type(self.doc_to_target) == str:
+                self.doc_to_target = self.template_aliases + self.doc_to_target
 
         # set "task_name" metadata field based on the "primary" name set
         if self.names:
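Note: this hunk re-enables previously commented-out behaviour rather than adding new logic: template_aliases is simply concatenated in front of doc_to_text / doc_to_target before rendering. A minimal sketch of why plain string concatenation is enough, assuming Jinja2-style templates as used elsewhere in the harness (the alias block, field names, and example doc below are hypothetical, not taken from any real task config):

from jinja2 import Template

# Hypothetical alias block and prompt template.
template_aliases = "{% set answer_choices = doc.choices %}"
doc_to_text = "Question: {{ doc.question }}\nChoices: {{ answer_choices }}"

doc = {"question": "2 + 2 = ?", "choices": ["3", "4"]}

# Jinja evaluates the {% set %} block before the template body, so prepending
# the alias string makes the alias names usable without editing the template.
print(Template(template_aliases + doc_to_text).render(doc=doc))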
@@ -472,6 +479,7 @@ class ConfigurableTask(Task):
             )
 
         if self._config.output_type is not None:
+            assert self._config.output_type in ALL_OUTPUT_TYPES
             self.OUTPUT_TYPE = self._config.output_type
 
         if self._config.dataset_path is not None:
@@ -480,68 +488,71 @@ class ConfigurableTask(Task):
         if self._config.dataset_name is not None:
             self.DATASET_NAME = self._config.dataset_name
 
-        if self._config.metric_list is not None:
-            self._metric_list = {}
-            self._metric_kwargs = {}
-            self._aggregation_list = {}
-            self._higher_is_better = {}
-            if self._config.output_type == "greedy_until":
-                for metric_config in self._config.metric_list:
-                    metric_name = metric_config["metric"]
-                    aggregation = metric_config["aggregation"]
-                    higher_is_better = metric_config["higher_is_better"]
-                    kwargs = {
-                        key: metric_config[key]
-                        for key in metric_config
-                        if key not in ["metric", "aggregation", "higher_is_better"]
-                    }
-                    self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[
-                        aggregation
-                    ]
-                    if metric_name in METRIC_REGISTRY.keys():
-                        self._metric_list[metric_name] = METRIC_REGISTRY[metric_name]
-                        self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[
-                            metric_name
-                        ]
-                    else:
-                        self._higher_is_better[metric_name] = higher_is_better
-                        try:
-                            metric_object = evaluate.load(metric_name)
-                            self._metric_list[metric_name] = metric_object
-                            self._metric_kwargs[metric_name] = kwargs
-                        except Exception:
-                            raise Warning(
-                                "{} not found in the evaluate library!".format(
-                                    metric_name
-                                ),
-                                "Please check https://huggingface.co/evaluate-metric",
-                            )
-            else:
-                eval_logger.warning(
-                    f"Output Type set as {self._config.output_type} which does not use metric_list"
-                    "metric list will be unused."
-                )
-                if self._config.output_type == "loglikelihood":
-                    metric_list = ["perplexity", "acc"]
-                elif self._config.output_type == "loglikelihood_rolling":
-                    metric_list = [
-                        "word_perplexity",
-                        "byte_perplexity",
-                        "bits_per_byte",
-                    ]
-                elif self._config.output_type == "multiple_choice":
-                    metric_list = ["acc", "acc_norm"]
-                for metric_name in metric_list:
-                    self._aggregation_list[metric_name] = AGGREGATION_REGISTRY["mean"]
-                    self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[
-                        metric_name
-                    ]
+        self._metric_fn_list = {}
+        self._metric_fn_kwargs = {}
+        self._aggregation_list = {}
+        self._higher_is_better = {}
+
+        if self._config.metric_list is None:
+            eval_logger.warning(
+                f"Output Type set as {self._config.output_type} and metric_list is not set"
+                "Will default to exact_match"
+            )
+            _metric_list = METRIC_REGISTRY[self._config.output_type]
+            for metric_name, metric_params in _metric_list.items():
+                self._metric_fn_list[metric_name] = metric_params["fn"]
+                self._aggregation_list[metric_name] = metric_params["aggregation"]
+                self._higher_is_better[metric_name] = metric_params["higher_is_better"]
+        else:
+            for metric_config in self._config.metric_list:
+
+                assert "metric" in metric_config
+                metric_name = metric_config["metric"]
+                kwargs = {
+                    key: metric_config[key]
+                    for key in metric_config
+                    if key not in ["metric", "aggregation", "higher_is_better"]
+                }
+
+                if metric_name in _metric_list:
+                    self._metric_fn_list[metric_name] = metric_params["fn"]
+                else:
+                    eval_logger.warning(
+                        f"Metric {metric_name} not found, "
+                        "Searching from https://huggingface.co/evaluate-metric"
+                    )
+                    try:
+                        metric_object = evaluate.load(metric_name)
+                        self._metric_fn_list[metric_name] = metric_object
+                        self._metric_fn_kwargs[metric_name] = kwargs
+                    except Exception:
+                        raise Warning(
+                            "{} not found in the evaluate library!".format(metric_name),
+                            "Please check https://huggingface.co/evaluate-metric",
+                        )
+
+                if "aggregation" in metric_config:
+                    self._aggregation_list[metric_name] = metric_config["aggregation"]
+                else:
+                    eval_logger.warning(
+                        f"metric {metric_name} is defined, but aggregation is not"
+                        f"using default aggregation for {metric_name}"
+                    )
+                    self._aggregation_list[metric_name] = _metric_list[metric_name][
+                        "aggregation"
+                    ]
+
+                if "higher_is_better" in metric_config:
+                    self._higher_is_better[metric_name] = metric_config[
+                        "higher_is_better"
+                    ]
+                else:
+                    eval_logger.warning(
+                        f"metric {metric_name} is defined, but higher_is_better is not"
+                        f"using default higher_is_better for {metric_name}"
+                    )
+                    self._higher_is_better[metric_name] = _metric_list[metric_name][
+                        "higher_is_better"
+                    ]
 
         self.download(self._config.dataset_kwargs)
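The rewritten branch resolves metrics in two steps: with no metric_list supplied, every default metric registered for the task's output_type is pulled from METRIC_REGISTRY; otherwise each user-supplied entry is resolved by name, falling back to evaluate.load for unknown names and to the registry's default aggregation / higher_is_better (with a warning) for missing fields. The lookup METRIC_REGISTRY[self._config.output_type] implies a registry nested by output type, roughly as in the sketch below; the entries and the example metric_list are illustrative assumptions, not the library's actual contents:

import numpy as np

# Hypothetical sketch of the layout the new code appears to assume:
# METRIC_REGISTRY[output_type][metric_name] -> {"fn", "aggregation", "higher_is_better"}
METRIC_REGISTRY_SKETCH = {
    "multiple_choice": {
        "acc": {"fn": None, "aggregation": np.mean, "higher_is_better": True},
        "acc_norm": {"fn": None, "aggregation": np.mean, "higher_is_better": True},
    },
    "greedy_until": {
        "exact_match": {"fn": None, "aggregation": np.mean, "higher_is_better": True},
    },
}

# A user-supplied metric_list entry only has to name the metric; missing
# "aggregation" / "higher_is_better" fall back to the registry defaults
# (with a warning), and unknown names are fetched from the evaluate library.
example_metric_list = [
    {"metric": "exact_match", "aggregation": "mean", "higher_is_better": True},
    {"metric": "bleu"},  # hypothetical: would be resolved via evaluate.load("bleu")
]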
@@ -743,18 +754,19 @@ class ConfigurableTask(Task):
             result_dict = {"perplexity": ll, "acc": int(is_greedy)}
         elif self.OUTPUT_TYPE == "loglikelihood_rolling":
             (loglikelihood,) = results
-            words = self.count_words(self.doc_to_target(doc))
-            bytes_ = self.count_bytes(self.doc_to_target(doc))
+            _words = self.count_words(self.doc_to_target(doc))
+            _bytes = self.count_bytes(self.doc_to_target(doc))
             return {
-                "word_perplexity": (loglikelihood, words),
-                "byte_perplexity": (loglikelihood, bytes_),
-                "bits_per_byte": (loglikelihood, bytes_),
+                "word_perplexity": (loglikelihood, _words),
+                "byte_perplexity": (loglikelihood, _bytes),
+                "bits_per_byte": (loglikelihood, _bytes),
             }
         elif self.OUTPUT_TYPE == "multiple_choice":
             lls = [
                 res[0] for res in results
             ]  # only retain loglikelihoods, discard is_greedy
             gold = int(self.doc_to_target(doc))
+            pred = np.argmax(lls)
 
             # retrieve choices in List[str] form, to compute choice lengths, etc.
             choices = ast.literal_eval(
                 utils.apply_template(
@@ -778,6 +790,7 @@ class ConfigurableTask(Task):
             result_dict = {
                 "acc": acc,
+                "f1": (pred, gold),
                 "acc_norm": acc_norm,
             }
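With pred = np.argmax(lls) computed in the hunk above, the new "f1": (pred, gold) entry stores a per-document (prediction, gold) pair rather than a score, presumably so F1 can be computed over the whole split at aggregation time. A minimal sketch of that reduction, assuming scikit-learn is available (this is not the harness's own aggregation code):

from sklearn.metrics import f1_score

# Hypothetical per-document (pred, gold) pairs collected from process_results.
pairs = [(1, 1), (0, 1), (1, 0), (1, 1)]
preds, golds = zip(*pairs)

# Corpus-level F1; the averaging mode would depend on the task's label space.
print(f1_score(golds, preds))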
@@ -814,7 +827,7 @@ class ConfigurableTask(Task):
         else:
             raise ValueError(
                 f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
-                "'loglikelihood', 'loglikelihood_rolling', 'greedy_until'",
+                "'loglikelihood', 'loglikelihood_rolling', 'greedy_until', or 'multiple_choice'",
             )
 
         return result_dict