Commit de46fb9a authored by lintangsutawika's avatar lintangsutawika
Browse files

reformat

parent dfb036b7
import os
import logging
import evaluate
import collections import collections
import logging
from functools import partial from functools import partial
import evaluate
from lm_eval.api.model import LM from lm_eval.api.model import LM
eval_logger = logging.getLogger("lm-eval") eval_logger = logging.getLogger("lm-eval")
MODEL_REGISTRY = {} MODEL_REGISTRY = {}
...@@ -92,9 +93,9 @@ def register_metric( ...@@ -92,9 +93,9 @@ def register_metric(
): ):
# TODO: do we want to enforce a certain interface to registered metrics? # TODO: do we want to enforce a certain interface to registered metrics?
def decorate(fn): def decorate(fn):
if type(metric) == str: if isinstance(metric, str):
metric_list = [metric] metric_list = [metric]
elif type(metric) == list: elif isinstance(metric, list):
metric_list = metric metric_list = metric
for _metric in metric_list: for _metric in metric_list:
...@@ -107,9 +108,9 @@ def register_metric( ...@@ -107,9 +108,9 @@ def register_metric(
METRIC_REGISTRY[_metric]["higher_is_better"] = higher_is_better METRIC_REGISTRY[_metric]["higher_is_better"] = higher_is_better
if output_type is not None: if output_type is not None:
if type(output_type) == str: if isinstance(output_type, str):
output_type_list = [output_type] output_type_list = [output_type]
elif type(output_type) == list: elif isinstance(output_type, list):
output_type_list = output_type output_type_list = output_type
for _output_type in output_type_list: for _output_type in output_type_list:
...@@ -121,7 +122,6 @@ def register_metric( ...@@ -121,7 +122,6 @@ def register_metric(
def get_metric(name): def get_metric(name):
if name in METRIC_REGISTRY: if name in METRIC_REGISTRY:
return METRIC_REGISTRY[name] return METRIC_REGISTRY[name]
else: else:
...@@ -133,7 +133,6 @@ def get_evaluate(name, **kwargs): ...@@ -133,7 +133,6 @@ def get_evaluate(name, **kwargs):
class HFEvaluateAdaptor: class HFEvaluateAdaptor:
def __init__(self, name, **kwargs): def __init__(self, name, **kwargs):
self.name = name self.name = name
metric_object = evaluate.load(name) metric_object = evaluate.load(name)
self.hf_evaluate_fn = partial(metric_object.compute, **kwargs) self.hf_evaluate_fn = partial(metric_object.compute, **kwargs)
......
...@@ -18,31 +18,13 @@ from lm_eval.api.metrics import ( ...@@ -18,31 +18,13 @@ from lm_eval.api.metrics import (
bits_per_byte, bits_per_byte,
mean, mean,
weighted_perplexity, weighted_perplexity,
<<<<<<< HEAD
<<<<<<< HEAD
=======
>>>>>>> cda25fef4e1df2f4bc2dab3ec6668ae9f5bf7296
bits_per_byte,
)
from lm_eval.api.registry import (
get_metric,
get_evaluate,
get_aggregation,
METRIC_REGISTRY,
DEFAULT_METRIC_REGISTRY,
<<<<<<< HEAD
=======
) )
from lm_eval.api.registry import ( from lm_eval.api.registry import (
AGGREGATION_REGISTRY,
DEFAULT_METRIC_REGISTRY, DEFAULT_METRIC_REGISTRY,
METRIC_REGISTRY,
get_aggregation, get_aggregation,
get_evaluate,
get_metric, get_metric,
get_metric_aggregation,
is_higher_better,
>>>>>>> 4d10ad56b1ffe569467eee2297e2317c99313118
=======
>>>>>>> cda25fef4e1df2f4bc2dab3ec6668ae9f5bf7296
) )
from lm_eval.filters import build_filter_ensemble from lm_eval.filters import build_filter_ensemble
from lm_eval.prompts import get_prompt from lm_eval.prompts import get_prompt
...@@ -603,7 +585,7 @@ class ConfigurableTask(Task): ...@@ -603,7 +585,7 @@ class ConfigurableTask(Task):
metric_fn = metric_name.__call__ metric_fn = metric_name.__call__
metric_name = metric_name.__name__ metric_name = metric_name.__name__
else: else:
assert type(metric_name) == str assert isinstance(metric_name, str)
if use_hf_evaluate: if use_hf_evaluate:
metric_fn = get_evaluate(metric_name, **kwargs) metric_fn = get_evaluate(metric_name, **kwargs)
elif metric_name in METRIC_REGISTRY: elif metric_name in METRIC_REGISTRY:
...@@ -620,7 +602,6 @@ class ConfigurableTask(Task): ...@@ -620,7 +602,6 @@ class ConfigurableTask(Task):
self._aggregation_list[metric_name] = metric_fn self._aggregation_list[metric_name] = metric_fn
else: else:
if "aggregation" in metric_config: if "aggregation" in metric_config:
agg_name = metric_config["aggregation"] agg_name = metric_config["aggregation"]
if isinstance(agg_name, str): if isinstance(agg_name, str):
self._aggregation_list[metric_name] = get_aggregation( self._aggregation_list[metric_name] = get_aggregation(
...@@ -1028,7 +1009,6 @@ class ConfigurableTask(Task): ...@@ -1028,7 +1009,6 @@ class ConfigurableTask(Task):
) )
def process_results(self, doc, results): def process_results(self, doc, results):
# Process results returns 1 of X things per doc/results # Process results returns 1 of X things per doc/results
# 1. A score # 1. A score
# 2. Components to be processed later to obtained a score. such as gold and prediction # 2. Components to be processed later to obtained a score. such as gold and prediction
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment