Commit 1d262a59 authored by lintangsutawika
Browse files

change how metrics are registered

parent e7cd7d68
...@@ -77,42 +77,59 @@ METRIC_AGGREGATION_REGISTRY = {} ...@@ -77,42 +77,59 @@ METRIC_AGGREGATION_REGISTRY = {}
AGGREGATION_REGISTRY = {} AGGREGATION_REGISTRY = {}
HIGHER_IS_BETTER_REGISTRY = {} HIGHER_IS_BETTER_REGISTRY = {}
# DEFAULT_METRIC_REGISTRY = {
# "loglikelihood": [
# "perplexity",
# "acc",
# ],
# "loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
# "multiple_choice": ["acc", "acc_norm"],
# "generate_until": ["exact_match"],
# }
DEFAULT_METRIC_REGISTRY = { DEFAULT_METRIC_REGISTRY = {
"loglikelihood": [ "loglikelihood": [],
"perplexity", "loglikelihood_rolling": [],
"acc", "multiple_choice": [],
], "generate_until": [],
"loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
"multiple_choice": ["acc", "acc_norm"],
"generate_until": ["exact_match"],
} }
def register_metric(**args): def register_metric(
metric,
higher_is_better=None,
output_type=None,
aggregation=None,
):
# TODO: do we want to enforce a certain interface to registered metrics? # TODO: do we want to enforce a certain interface to registered metrics?
def decorate(fn): def decorate(fn):
assert "metric" in args METRIC_REGISTRY[metric] = fn(aggregation=aggregation)
name = args["metric"]
if higher_is_better is not None:
for key, registry in [ HIGHER_IS_BETTER_REGISTRY[metric] = higher_is_better
("metric", METRIC_REGISTRY), if output_type is not None:
("higher_is_better", HIGHER_IS_BETTER_REGISTRY), DEFAULT_METRIC_REGISTRY[output_type].append(metric)
("aggregation", METRIC_AGGREGATION_REGISTRY),
]: # for key, registry in [
# ("output_type", OUTPUT_TYPE_REGISTRY),
if key in args: # ("metric", METRIC_REGISTRY),
value = args[key] # ("higher_is_better", HIGHER_IS_BETTER_REGISTRY),
assert ( # ("aggregation", METRIC_AGGREGATION_REGISTRY),
value not in registry # ]:
), f"{key} named '{value}' conflicts with existing registered {key}!"
# if key in args:
if key == "metric": # value = args[key]
registry[name] = fn # assert (
elif key == "aggregation": # value not in registry
registry[name] = AGGREGATION_REGISTRY[value] # ), f"{key} named '{value}' conflicts with existing registered {key}!"
else:
registry[name] = value # if key == "metric":
# registry[name] = fn
# elif key == "aggregation":
# registry[name] = AGGREGATION_REGISTRY[value]
# else:
# registry[name] = value
return fn return fn
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment