Commit a22d8ffa authored by lintangsutawika's avatar lintangsutawika
Browse files

modified import origin

parent f2166089
......@@ -23,18 +23,20 @@ from lm_eval.api.filter import FilterEnsemble
from lm_eval.logger import eval_logger
from lm_eval.prompts import get_prompt
from lm_eval.filters import build_filter_ensemble
from lm_eval.metrics import (
from lm_eval.api.metrics import (
# get_metric,
# get_aggregation,
mean,
weighted_perplexity,
bits_per_byte,
)
from lm_eval.api.registry import (
METRIC_REGISTRY,
DEFAULT_METRIC_REGISTRY,
OUTPUT_TYPE_REGISTRY,
AGGREGATION_REGISTRY,
HIGHER_IS_BETTER_REGISTRY,
DEFAULT_AGGREGATION_REGISTRY,
# get_metric,
# get_aggregation,
mean,
weighted_perplexity,
bits_per_byte,
)
ALL_OUTPUT_TYPES = [
......@@ -504,8 +506,9 @@ class ConfigurableTask(Task):
)
for metric_name in _metric_list:
self._metric_fn_list[metric_name] = METRIC_REGISTRY[metric_name]
aggregation = DEFAULT_AGGREGATION_REGISTRY[metric_name]
self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[aggregation]
self._aggregation_list[metric_name] = DEFAULT_AGGREGATION_REGISTRY[
metric_name
]
self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[
metric_name
]
......@@ -754,6 +757,9 @@ class ConfigurableTask(Task):
def process_results(self, doc, results):
# if callable(self._config.process_results):
# return self._config.process_results(doc, results)
result_dict = {}
use_metric = list(self._metric_fn_list.keys())
if self.OUTPUT_TYPE == "loglikelihood":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment