"deploy/cpp/src/main_keypoint.cc" did not exist on "dcc7bf4f1a243d90d6c4f7c51551cea3f256325f"
Commit de46fb9a authored by lintangsutawika

reformat

parent dfb036b7
import os
import logging
import collections
from functools import partial

import evaluate

from lm_eval.api.model import LM

eval_logger = logging.getLogger("lm-eval")
MODEL_REGISTRY = {}
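For orientation, the dict above backs a decorator-based registry; a minimal sketch of how such a registry is typically filled (the register_model name and its exact signature are an assumption here, mirroring the register_metric decorator shown below):

def register_model(*names):
    # Hypothetical decorator: store an LM subclass in MODEL_REGISTRY under each given alias.
    def decorate(cls):
        for name in names:
            assert name not in MODEL_REGISTRY, f"model '{name}' is already registered"
            MODEL_REGISTRY[name] = cls
        return cls
    return decorate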
@@ -92,9 +93,9 @@ def register_metric(
):
# TODO: do we want to enforce a certain interface to registered metrics?
def decorate(fn):
-if type(metric) == str:
+if isinstance(metric, str):
metric_list = [metric]
-elif type(metric) == list:
+elif isinstance(metric, list):
metric_list = metric
for _metric in metric_list:
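To make the control flow of this hunk easier to follow, here is a condensed sketch of the name normalization, assuming METRIC_REGISTRY maps each metric name to a small config dict (the "function" key is illustrative):

def register_metric_sketch(metric, higher_is_better=None):
    # Accept a single metric name or a list of names, then register each one.
    def decorate(fn):
        metric_list = [metric] if isinstance(metric, str) else metric
        for _metric in metric_list:
            entry = METRIC_REGISTRY.setdefault(_metric, {})
            entry["function"] = fn  # illustrative field name
            if higher_is_better is not None:
                entry["higher_is_better"] = higher_is_better
        return fn
    return decorate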
@@ -107,9 +108,9 @@ def register_metric(
METRIC_REGISTRY[_metric]["higher_is_better"] = higher_is_better
if output_type is not None:
-if type(output_type) == str:
+if isinstance(output_type, str):
output_type_list = [output_type]
-elif type(output_type) == list:
+elif isinstance(output_type, list):
output_type_list = output_type
for _output_type in output_type_list:
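The reason type(x) == str is replaced with isinstance throughout is that isinstance also accepts subclasses and reads as the idiomatic check; a quick self-contained illustration:

class MetricName(str):
    # A str subclass, e.g. as a config loader might produce.
    pass

name = MetricName("exact_match")
print(type(name) == str)      # False: exact type comparison rejects the subclass
print(isinstance(name, str))  # True: isinstance accepts str and its subclasses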
@@ -121,7 +122,6 @@
def get_metric(name):
if name in METRIC_REGISTRY:
return METRIC_REGISTRY[name]
else:
@@ -133,7 +133,6 @@ def get_evaluate(name, **kwargs):
class HFEvaluateAdaptor:
def __init__(self, name, **kwargs):
self.name = name
metric_object = evaluate.load(name)
self.hf_evaluate_fn = partial(metric_object.compute, **kwargs)
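The adaptor above is a thin wrapper over the Hugging Face evaluate library: evaluate.load returns a metric object whose compute method gets pre-bound with any fixed keyword arguments via functools.partial. A hedged usage example (evaluate.load and compute are the real library API; calling hf_evaluate_fn directly is just for illustration):

adaptor = HFEvaluateAdaptor("accuracy")
result = adaptor.hf_evaluate_fn(predictions=[0, 1, 1], references=[0, 1, 0])
print(result)  # e.g. {'accuracy': 0.666...}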
......
@@ -18,31 +18,13 @@ from lm_eval.api.metrics import (
bits_per_byte,
mean,
weighted_perplexity,
<<<<<<< HEAD
<<<<<<< HEAD
=======
>>>>>>> cda25fef4e1df2f4bc2dab3ec6668ae9f5bf7296
bits_per_byte,
)
from lm_eval.api.registry import (
get_metric,
get_evaluate,
get_aggregation,
METRIC_REGISTRY,
DEFAULT_METRIC_REGISTRY,
<<<<<<< HEAD
=======
)
from lm_eval.api.registry import (
AGGREGATION_REGISTRY,
DEFAULT_METRIC_REGISTRY,
METRIC_REGISTRY,
get_aggregation,
get_evaluate,
get_metric,
get_metric_aggregation,
is_higher_better,
>>>>>>> 4d10ad56b1ffe569467eee2297e2317c99313118
=======
>>>>>>> cda25fef4e1df2f4bc2dab3ec6668ae9f5bf7296
)
from lm_eval.filters import build_filter_ensemble
from lm_eval.prompts import get_prompt
@@ -603,7 +585,7 @@ class ConfigurableTask(Task):
metric_fn = metric_name.__call__
metric_name = metric_name.__name__
else:
-assert type(metric_name) == str
+assert isinstance(metric_name, str)
if use_hf_evaluate:
metric_fn = get_evaluate(metric_name, **kwargs)
elif metric_name in METRIC_REGISTRY:
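A condensed sketch of the resolution logic in this hunk: a metric may be given either as a callable or as a registered name (the helper name resolve_metric_fn is illustrative, not part of the library):

def resolve_metric_fn(metric_name, use_hf_evaluate=False, **kwargs):
    if callable(metric_name):
        # A user-supplied function: use it directly and keep its name for bookkeeping.
        return metric_name.__call__, metric_name.__name__
    assert isinstance(metric_name, str)
    if use_hf_evaluate:
        # Delegate to the Hugging Face evaluate wrapper.
        return get_evaluate(metric_name, **kwargs), metric_name
    if metric_name in METRIC_REGISTRY:
        # Fall back to the locally registered metric (assumed to expose its callable).
        return get_metric(metric_name), metric_name
    raise ValueError(f"Unknown metric: {metric_name}")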
@@ -620,7 +602,6 @@
self._aggregation_list[metric_name] = metric_fn
else:
if "aggregation" in metric_config:
agg_name = metric_config["aggregation"]
if isinstance(agg_name, str):
self._aggregation_list[metric_name] = get_aggregation(
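For context, a hedged sketch of how a task's metric config entry might flow into this aggregation lookup (the dict shape shown is an assumption, not taken from the repo):

metric_config = {"metric": "acc", "aggregation": "mean", "higher_is_better": True}
agg_name = metric_config["aggregation"]
if isinstance(agg_name, str):
    # Resolve the registered aggregation function (e.g. mean) by name.
    aggregation_fn = get_aggregation(agg_name)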
@@ -1028,7 +1009,6 @@
)
def process_results(self, doc, results):
# process_results returns one of two things per doc/results pair:
# 1. A score
# 2. Components to be processed later to obtain a score, such as gold and prediction
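A minimal sketch of the two return shapes described in the comment, for a hypothetical accuracy-style task (field and metric names are illustrative):

def process_results_sketch(doc, results):
    pred = results[0]
    gold = doc["answer"]  # hypothetical document field
    # Shape 1: a finished per-document score keyed by metric name.
    return {"acc": float(pred == gold)}
    # Shape 2 (alternative): raw components reduced later by the metric function, e.g.
    # return {"exact_match": (gold, pred)}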
......