register.py 2.24 KB
Newer Older
1
2
3
4
5
6
7
import os

# Task name -> the object registered under that name by ``register_task``.
task_registry = {}
# Group name -> list of task names added to that group by ``register_group``.
group_registry = {}
# Task name -> ``__name__`` of the object registered for that task.
task2func_index = {}
# ``__name__`` of a registered object -> the task name it was registered under.
func2task_index = {}

lintangsutawika's avatar
lintangsutawika committed
8

9
10
11
12
13
14
15
16
17
18
def register_task(name):
    """Return a decorator that registers the decorated object as task ``name``.

    The decorator stores the object in ``task_registry`` and records the
    two-way mapping between ``name`` and the object's ``__name__`` in
    ``task2func_index`` / ``func2task_index``. The object itself is returned
    unchanged, so the decorated definition stays usable as before.
    """

    def decorate(target):
        task_registry[name] = target

        # Keep both lookup directions in sync with the registry entry.
        target_name = target.__name__
        func2task_index[target_name] = name
        task2func_index[name] = target_name
        return target

    return decorate

lintangsutawika's avatar
lintangsutawika committed
19

20
21
22
23
24
25
def register_group(name):
    """Return a decorator that adds an already-registered task to group ``name``.

    The task name is looked up in ``func2task_index`` by the decorated
    object's ``__name__``, so ``register_task`` must already have run on it.
    Stacked decorators are applied bottom-up, which means ``@register_group``
    has to be written above ``@register_task``.

    Args:
        name: Key in ``group_registry`` whose list receives the task name.

    Returns:
        A decorator that returns the decorated object unchanged.

    Raises:
        KeyError: Raised by the decorator when the decorated object's
            ``__name__`` is not present in ``func2task_index``.
    """

    def wrapper(func):
        try:
            task_name = func2task_index[func.__name__]
        except KeyError:
            # Same exception type as the plain lookup, but say what went wrong.
            raise KeyError(
                "cannot add {!r} to group {!r}: it is not a registered task; "
                "apply @register_task first (place @register_group above "
                "@register_task)".format(func.__name__, name)
            ) from None

        # setdefault creates the group's list on first use, then we append.
        group_registry.setdefault(name, []).append(task_name)
        return func

    return wrapper
lintangsutawika's avatar
lintangsutawika committed
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105


# Metric name -> the object registered under that name by ``register_metric``.
metric_registry = {}
# Aggregation name -> the object registered by ``register_aggregation``.
aggregation_registry = {}
# Metric name -> aggregation passed to ``register_default_aggregation``.
default_aggregation_registry = {}
# Metric name -> value passed to ``register_higher_is_better``.
higher_is_better_registry = {}
# Metric name -> value passed to ``register_output_type``.
output_type_registry = {}
# Metric name -> ``__name__`` of the object registered for that metric.
metric2func_index = {}
# ``__name__`` of a registered object -> the metric name it was registered under.
func2metric_index = {}
# Aggregation name -> ``__name__`` of the object registered for it.
aggregation2func_index = {}
# ``__name__`` of a registered object -> the aggregation name it was registered under.
func2aggregation_index = {}


def register_metric(name):
    """Return a decorator that registers the decorated object as metric ``name``.

    The decorator stores the object in ``metric_registry`` and records the
    two-way mapping between ``name`` and the object's ``__name__`` in
    ``metric2func_index`` / ``func2metric_index``. The object is returned
    unchanged.
    """

    def decorate(metric_fn):
        metric_registry[name] = metric_fn

        # Keep both lookup directions in sync with the registry entry.
        fn_name = metric_fn.__name__
        func2metric_index[fn_name] = name
        metric2func_index[name] = fn_name
        return metric_fn

    return decorate


def register_aggregation(name):
    """Return a decorator that registers the decorated object as aggregation ``name``.

    The decorator stores the object in ``aggregation_registry`` and records
    the two-way mapping between ``name`` and the object's ``__name__`` in
    ``aggregation2func_index`` / ``func2aggregation_index``. The object is
    returned unchanged.
    """

    def decorate(agg_fn):
        aggregation_registry[name] = agg_fn

        # Keep both lookup directions in sync with the registry entry.
        fn_name = agg_fn.__name__
        func2aggregation_index[fn_name] = name
        aggregation2func_index[name] = fn_name
        return agg_fn

    return decorate


def register_default_aggregation(aggregation):
    """Return a decorator that sets ``aggregation`` as a metric's default.

    The metric name is looked up in ``func2metric_index`` by the decorated
    object's ``__name__``, so ``register_metric`` must already have run on it
    (write this decorator above ``@register_metric``).

    Args:
        aggregation: Key expected to be present in ``aggregation_registry``.
            If it is absent, a message is printed and nothing is recorded.

    Returns:
        A decorator that returns the decorated object unchanged.

    Raises:
        KeyError: Raised by the decorator when ``aggregation`` is registered
            but the decorated object's ``__name__`` is not present in
            ``func2metric_index``.
    """

    def wrapper(func):
        if aggregation not in aggregation_registry:
            # Best-effort: report the problem and leave the registry untouched.
            print("aggregation not registered")
            return func

        try:
            metric_name = func2metric_index[func.__name__]
        except KeyError:
            # Same exception type as the plain lookup, but say what went wrong.
            raise KeyError(
                "cannot set default aggregation {!r} for {!r}: it is not a "
                "registered metric; apply @register_metric first (place this "
                "decorator above @register_metric)".format(
                    aggregation, func.__name__
                )
            ) from None

        default_aggregation_registry[metric_name] = aggregation
        return func

    return wrapper


def register_higher_is_better(higher_is_better):
    """Return a decorator that records ``higher_is_better`` for a metric.

    The metric name is looked up in ``func2metric_index`` by the decorated
    object's ``__name__``. If the object is not a registered metric, nothing
    is recorded. The object is returned unchanged in either case.
    """

    def decorate(metric_fn):
        fn_name = metric_fn.__name__

        # Unregistered functions are passed through without any bookkeeping.
        if fn_name not in func2metric_index:
            return metric_fn

        higher_is_better_registry[func2metric_index[fn_name]] = higher_is_better
        return metric_fn

    return decorate


def register_output_type(output_type):
    """Return a decorator that records ``output_type`` for a metric.

    The metric name is looked up in ``func2metric_index`` by the decorated
    object's ``__name__``. If the object is not a registered metric, nothing
    is recorded. The object is returned unchanged in either case.
    """

    def decorate(metric_fn):
        fn_name = metric_fn.__name__
        is_registered = fn_name in func2metric_index

        if is_registered:
            # Store the output type under the metric's registered name.
            registered_as = func2metric_index[fn_name]
            output_type_registry[registered_as] = output_type

        return metric_fn

    return decorate