Commit 1d262a59 authored by lintangsutawika's avatar lintangsutawika
Browse files

change how metrics are registered

parent e7cd7d68
......@@ -77,42 +77,59 @@ METRIC_AGGREGATION_REGISTRY = {}
AGGREGATION_REGISTRY = {}
HIGHER_IS_BETTER_REGISTRY = {}
# DEFAULT_METRIC_REGISTRY = {
# "loglikelihood": [
# "perplexity",
# "acc",
# ],
# "loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
# "multiple_choice": ["acc", "acc_norm"],
# "generate_until": ["exact_match"],
# }
DEFAULT_METRIC_REGISTRY = {
"loglikelihood": [
"perplexity",
"acc",
],
"loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
"multiple_choice": ["acc", "acc_norm"],
"generate_until": ["exact_match"],
"loglikelihood": [],
"loglikelihood_rolling": [],
"multiple_choice": [],
"generate_until": [],
}
def register_metric(**args):
def register_metric(
metric,
higher_is_better=None,
output_type=None,
aggregation=None,
):
# TODO: do we want to enforce a certain interface to registered metrics?
def decorate(fn):
assert "metric" in args
name = args["metric"]
for key, registry in [
("metric", METRIC_REGISTRY),
("higher_is_better", HIGHER_IS_BETTER_REGISTRY),
("aggregation", METRIC_AGGREGATION_REGISTRY),
]:
if key in args:
value = args[key]
assert (
value not in registry
), f"{key} named '{value}' conflicts with existing registered {key}!"
if key == "metric":
registry[name] = fn
elif key == "aggregation":
registry[name] = AGGREGATION_REGISTRY[value]
else:
registry[name] = value
METRIC_REGISTRY[metric] = fn(aggregation=aggregation)
if higher_is_better is not None:
HIGHER_IS_BETTER_REGISTRY[metric] = higher_is_better
if output_type is not None:
DEFAULT_METRIC_REGISTRY[output_type].append(metric)
# for key, registry in [
# ("output_type", OUTPUT_TYPE_REGISTRY),
# ("metric", METRIC_REGISTRY),
# ("higher_is_better", HIGHER_IS_BETTER_REGISTRY),
# ("aggregation", METRIC_AGGREGATION_REGISTRY),
# ]:
# if key in args:
# value = args[key]
# assert (
# value not in registry
# ), f"{key} named '{value}' conflicts with existing registered {key}!"
# if key == "metric":
# registry[name] = fn
# elif key == "aggregation":
# registry[name] = AGGREGATION_REGISTRY[value]
# else:
# registry[name] = value
return fn
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment