"src/targets/gpu/device/erf.cpp" did not exist on "0d6a3ee97bdd163ad575887264c475d8d2579bc2"
Commit 3888193d authored by lintangsutawika's avatar lintangsutawika
Browse files

simplify registry

parent 9d6bc929
......@@ -71,22 +71,9 @@ def register_group(name):
return decorate
OUTPUT_TYPE_REGISTRY = {}
METRIC_REGISTRY = {}
METRIC_AGGREGATION_REGISTRY = {}
AGGREGATION_REGISTRY = {}
METRIC_FUNCTION_REGISTRY = {}
HIGHER_IS_BETTER_REGISTRY = {}
# DEFAULT_METRIC_REGISTRY = {
# "loglikelihood": [
# "perplexity",
# "acc",
# ],
# "loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
# "multiple_choice": ["acc", "acc_norm"],
# "generate_until": ["exact_match"],
# }
DEFAULT_METRIC_REGISTRY = {
"loglikelihood": [],
"loglikelihood_rolling": [],
......@@ -96,40 +83,53 @@ DEFAULT_METRIC_REGISTRY = {
def register_metric(
metric,
metric=None,
higher_is_better=None,
output_type=None,
aggregation=None,
# aggregation=None,
):
# TODO: do we want to enforce a certain interface to registered metrics?
def decorate(fn):
METRIC_REGISTRY[metric] = fn(aggregation=aggregation)
if higher_is_better is not None:
HIGHER_IS_BETTER_REGISTRY[metric] = higher_is_better
if output_type is not None:
DEFAULT_METRIC_REGISTRY[output_type].append(metric)
# for key, registry in [
# ("output_type", OUTPUT_TYPE_REGISTRY),
# ("metric", METRIC_REGISTRY),
# ("higher_is_better", HIGHER_IS_BETTER_REGISTRY),
# ("aggregation", METRIC_AGGREGATION_REGISTRY),
# ]:
# if key in args:
# value = args[key]
# assert (
# value not in registry
# ), f"{key} named '{value}' conflicts with existing registered {key}!"
# if key == "metric":
# registry[name] = fn
# elif key == "aggregation":
# registry[name] = AGGREGATION_REGISTRY[value]
# else:
# registry[name] = value
if type(metric) == str:
metric_list = [metric]
elif type(metric) == list:
metric_list = metric
for _metric in metric_list:
METRIC_FUNCTION_REGISTRY[_metric] = fn
if higher_is_better is not None:
HIGHER_IS_BETTER_REGISTRY[_metric] = higher_is_better
if output_type is not None:
if type(output_type) == str:
output_type_list = [output_type]
elif type(output_type) == list:
output_type_list = output_type
for _output_type in output_type_list:
DEFAULT_METRIC_REGISTRY[_output_type].append(_metric)
# # for key, registry in [
# # ("output_type", OUTPUT_TYPE_REGISTRY),
# # ("metric", METRIC_REGISTRY),
# # ("higher_is_better", HIGHER_IS_BETTER_REGISTRY),
# # ("aggregation", METRIC_AGGREGATION_REGISTRY),
# # ]:
# # if key in args:
# # value = args[key]
# # assert (
# # value not in registry
# # ), f"{key} named '{value}' conflicts with existing registered {key}!"
# # if key == "metric":
# # registry[name] = fn
# # elif key == "aggregation":
# # registry[name] = AGGREGATION_REGISTRY[value]
# # else:
# # registry[name] = value
return fn
......@@ -139,8 +139,8 @@ def register_metric(
def get_metric(name, hf_evaluate_metric=False):
if not hf_evaluate_metric:
if name in METRIC_REGISTRY:
return METRIC_REGISTRY[name]
if name in METRIC_FUNCTION_REGISTRY:
return METRIC_FUNCTION_REGISTRY[name]
else:
eval_logger.warning(
f"Could not find registered metric '{name}' in lm-eval, searching in HF Evaluate library..."
......@@ -155,36 +155,36 @@ def get_metric(name, hf_evaluate_metric=False):
)
def register_aggregation(name):
def decorate(fn):
assert (
name not in AGGREGATION_REGISTRY
), f"aggregation named '{name}' conflicts with existing registered aggregation!"
# def register_aggregation(name):
# def decorate(fn):
# assert (
# name not in AGGREGATION_REGISTRY
# ), f"aggregation named '{name}' conflicts with existing registered aggregation!"
AGGREGATION_REGISTRY[name] = fn
return fn
# AGGREGATION_REGISTRY[name] = fn
# return fn
return decorate
# return decorate
def get_aggregation(name):
# def get_aggregation(name):
try:
return AGGREGATION_REGISTRY[name]
except KeyError:
eval_logger.warning(
"{} not a registered aggregation metric!".format(name),
)
# try:
# return AGGREGATION_REGISTRY[name]
# except KeyError:
# eval_logger.warning(
# "{} not a registered aggregation metric!".format(name),
# )
def get_metric_aggregation(name):
# def get_metric_aggregation(name):
try:
return METRIC_AGGREGATION_REGISTRY[name]
except KeyError:
eval_logger.warning(
"{} metric is not assigned a default aggregation!".format(name),
)
# try:
# return METRIC_AGGREGATION_REGISTRY[name]
# except KeyError:
# eval_logger.warning(
# "{} metric is not assigned a default aggregation!".format(name),
# )
def is_higher_better(metric_name):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment