Commit de46fb9a authored by lintangsutawika

reformat

parent dfb036b7
 import os
-import logging
-import evaluate
 import collections
+import logging
 from functools import partial
+import evaluate
 from lm_eval.api.model import LM

 eval_logger = logging.getLogger("lm-eval")

 MODEL_REGISTRY = {}
@@ -92,9 +93,9 @@ def register_metric(
 ):
     # TODO: do we want to enforce a certain interface to registered metrics?
     def decorate(fn):
-        if type(metric) == str:
+        if isinstance(metric, str):
             metric_list = [metric]
-        elif type(metric) == list:
+        elif isinstance(metric, list):
             metric_list = metric

         for _metric in metric_list:
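The change from type(x) == str to isinstance(x, str) is the idiomatic form: isinstance also accepts subclasses and can test several types at once. A minimal standalone illustration (the str subclass below is hypothetical, only to show the behavioural difference):

class MetricName(str):
    pass

name = MetricName("acc")
print(type(name) == str)              # False: exact-type comparison rejects subclasses
print(isinstance(name, str))          # True: isinstance accepts subclasses
print(isinstance(name, (str, list)))  # True: isinstance can check several types at once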
@@ -107,9 +108,9 @@ def register_metric(
                 METRIC_REGISTRY[_metric]["higher_is_better"] = higher_is_better

             if output_type is not None:
-                if type(output_type) == str:
+                if isinstance(output_type, str):
                     output_type_list = [output_type]
-                elif type(output_type) == list:
+                elif isinstance(output_type, list):
                     output_type_list = output_type

                 for _output_type in output_type_list:
@@ -121,7 +122,6 @@
 def get_metric(name):
     if name in METRIC_REGISTRY:
         return METRIC_REGISTRY[name]
     else:
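register_metric and get_metric follow a standard registry-decorator pattern: the decorator stores each function in METRIC_REGISTRY under one or more names, and get_metric looks the entry up by name. A simplified sketch under that assumption (the metadata fields and exact_match example are illustrative, not taken from this diff):

METRIC_REGISTRY = {}

def register_metric(metric, higher_is_better=True, **kwargs):
    # Accept a single name or a list of names, as in the hunks above.
    def decorate(fn):
        names = [metric] if isinstance(metric, str) else metric
        for name in names:
            METRIC_REGISTRY[name] = {
                "function": fn,
                "higher_is_better": higher_is_better,
                **kwargs,
            }
        return fn
    return decorate

@register_metric(metric="exact_match", higher_is_better=True)
def exact_match(gold, pred):
    return float(gold == pred)

def get_metric(name):
    if name in METRIC_REGISTRY:
        return METRIC_REGISTRY[name]
    raise KeyError(f"metric '{name}' is not registered")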
@@ -133,7 +133,6 @@ def get_evaluate(name, **kwargs):
 class HFEvaluateAdaptor:
     def __init__(self, name, **kwargs):
         self.name = name
         metric_object = evaluate.load(name)
         self.hf_evaluate_fn = partial(metric_object.compute, **kwargs)
......
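HFEvaluateAdaptor wraps Hugging Face's evaluate library: evaluate.load(name) returns a metric object whose compute() does the scoring, and functools.partial freezes any extra keyword arguments passed at construction time. A small usage sketch, assuming the "exact_match" metric and its ignore_case keyword:

import evaluate
from functools import partial

# Load a metric and pre-bind a keyword argument, mirroring the adaptor above.
metric_object = evaluate.load("exact_match")
compute_fn = partial(metric_object.compute, ignore_case=True)  # kwarg assumed available

result = compute_fn(predictions=["Paris"], references=["paris"])
print(result["exact_match"])  # 1.0 with case ignored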
@@ -18,31 +18,13 @@ from lm_eval.api.metrics import (
     bits_per_byte,
     mean,
     weighted_perplexity,
-<<<<<<< HEAD
-<<<<<<< HEAD
-=======
->>>>>>> cda25fef4e1df2f4bc2dab3ec6668ae9f5bf7296
-    bits_per_byte,
-)
-from lm_eval.api.registry import (
-    get_metric,
-    get_evaluate,
-    get_aggregation,
-    METRIC_REGISTRY,
-    DEFAULT_METRIC_REGISTRY,
-<<<<<<< HEAD
-=======
 )
 from lm_eval.api.registry import (
     AGGREGATION_REGISTRY,
     DEFAULT_METRIC_REGISTRY,
     METRIC_REGISTRY,
     get_aggregation,
     get_evaluate,
     get_metric,
     get_metric_aggregation,
     is_higher_better,
->>>>>>> 4d10ad56b1ffe569467eee2297e2317c99313118
-=======
->>>>>>> cda25fef4e1df2f4bc2dab3ec6668ae9f5bf7296
 )
 from lm_eval.filters import build_filter_ensemble
 from lm_eval.prompts import get_prompt
@@ -603,7 +585,7 @@ class ConfigurableTask(Task):
                 metric_fn = metric_name.__call__
                 metric_name = metric_name.__name__
             else:
-                assert type(metric_name) == str
+                assert isinstance(metric_name, str)
                 if use_hf_evaluate:
                     metric_fn = get_evaluate(metric_name, **kwargs)
                 elif metric_name in METRIC_REGISTRY:
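The branch above dispatches on the type of the metric entry: a callable is used directly, while a string is resolved either through the HF evaluate adaptor or through METRIC_REGISTRY. A hedged, self-contained condensation of that dispatch (resolve_metric and metric_registry are names invented here for illustration):

def resolve_metric(metric_name, metric_registry, use_hf_evaluate=False, **kwargs):
    # Hypothetical condensed view of the dispatch above; metric_registry stands in
    # for METRIC_REGISTRY, and the HF path for get_evaluate()/HFEvaluateAdaptor.
    if callable(metric_name):
        return metric_name.__name__, metric_name.__call__
    assert isinstance(metric_name, str)
    if use_hf_evaluate:
        import evaluate
        from functools import partial
        return metric_name, partial(evaluate.load(metric_name).compute, **kwargs)
    return metric_name, metric_registry[metric_name]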
@@ -620,7 +602,6 @@
                     self._aggregation_list[metric_name] = metric_fn
                 else:
                     if "aggregation" in metric_config:
                         agg_name = metric_config["aggregation"]
                         if isinstance(agg_name, str):
                             self._aggregation_list[metric_name] = get_aggregation(
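Aggregations are resolved the same way: a string such as "mean" in the metric config is mapped to a callable via get_aggregation. A minimal sketch of that mapping (the registry contents shown here are assumed, not read from this diff; mean is among the functions imported from lm_eval.api.metrics above):

def mean(arr):
    return sum(arr) / len(arr)

AGGREGATION_REGISTRY = {"mean": mean}

def get_aggregation(name):
    # Look the aggregation callable up by its registered name.
    return AGGREGATION_REGISTRY[name]

per_doc_scores = [1.0, 0.0, 1.0]
print(get_aggregation("mean")(per_doc_scores))  # 0.666...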
@@ -1028,7 +1009,6 @@
         )

     def process_results(self, doc, results):
         # Process results returns 1 of X things per doc/results
         # 1. A score
         # 2. Components to be processed later to obtain a score, such as gold and prediction
......
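As the comments describe, process_results returns per-document values keyed by metric name, each value being either a final score or the intermediate components a later step turns into a score. A hedged sketch of that shape (self omitted; the doc keys and the exact_match/bleu metric names are illustrative, not taken from this diff):

def process_results(doc, results):
    # Sketch only: map each metric name to either a score or raw components.
    prediction = results[0].strip()
    gold = doc["answer"]
    return {
        "exact_match": float(prediction == gold),  # case 1: a score, aggregated directly
        "bleu": (gold, prediction),                # case 2: components scored at aggregation time
    }

print(process_results({"answer": "Paris"}, ["Paris\n"]))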