# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""A base module for rule-related module."""

import re

from superbench.common.utils import logger
from superbench.analyzer import file_handler


class RuleBase():
    """Base class for rule-related modules.

    Holds the state shared by rule-driven analyzers:
      - _sb_rules: rule name -> formatted rule dict (filled in by subclasses).
      - _benchmark_metrics_dict: benchmark name -> metrics present in raw data.
      - _enable_metrics: every metric matched by at least one rule.
    """
    def __init__(self):
        """Init function."""
        self._sb_rules = {}
        self._benchmark_metrics_dict = {}
        self._enable_metrics = set()

    def _get_metrics_by_benchmarks(self, metrics_list):
        """Get mappings of benchmarks:metrics from metrics_list.

        A metric full name is expected to look like 'benchmark/metric'; the
        benchmark part may additionally carry an annotation, e.g.
        'benchmark:annotation/metric'. Metrics without a '/' separator are
        skipped with a warning.

        Args:
            metrics_list (list): list of metrics

        Returns:
            dict: metrics organized by benchmarks
        """
        benchmarks_metrics = {}
        for metric in metrics_list:
            if '/' not in metric:
                logger.warning('RuleBase: get_metrics_by_benchmarks - {} does not have benchmark_name'.format(metric))
            else:
                benchmark = metric.split('/')[0]
                # support annotations in benchmark naming:
                # 'bench:anno/metric' maps to benchmark 'bench'
                if ':' in benchmark:
                    benchmark = metric.split(':')[0]
                benchmarks_metrics.setdefault(benchmark, set()).add(metric)
        return benchmarks_metrics

    def _check_and_format_rules(self, rule, name):
        """Check the rule of the metric whether the format is valid.

        Raises (via logger.log_and_raise) when the rule has no 'categories'
        key; normalizes a scalar 'metrics' entry into a one-element list.

        Args:
            rule (dict): the rule
            name (str): the rule name

        Returns:
            dict: the rule for the metric
        """
        # check if rule is supported
        if 'categories' not in rule:
            logger.log_and_raise(exception=Exception, msg='{} lack of category'.format(name))
        if 'metrics' in rule:
            if isinstance(rule['metrics'], str):
                rule['metrics'] = [rule['metrics']]
        return rule

    def _get_metrics(self, rule, benchmark_rules):
        """Get metrics in the rule.

        Parse metric regex in the rule, and store the (metric, -1) pair
        in _sb_rules[rule]['metrics']. Metrics referenced by a rule but whose
        benchmark is absent from the raw data are skipped with a warning.

        Args:
            rule (str): the name of the rule
            benchmark_rules (dict): the dict of rules
        """
        metrics_in_rule = benchmark_rules[rule]['metrics']
        benchmark_metrics_dict_in_rule = self._get_metrics_by_benchmarks(metrics_in_rule)
        for benchmark_name in benchmark_metrics_dict_in_rule:
            if benchmark_name not in self._benchmark_metrics_dict:
                logger.warning('RuleBase: get metrics failed - {}'.format(benchmark_name))
                continue
            # get rules and criteria for each metric
            for metric in self._benchmark_metrics_dict[benchmark_name]:
                # metric full name in baseline: exact match wins first
                if metric in metrics_in_rule:
                    self._sb_rules[rule]['metrics'][metric] = -1
                    self._enable_metrics.add(metric)
                    continue
                # metric full name not in baseline, use regex to match
                for metric_regex in benchmark_metrics_dict_in_rule[benchmark_name]:
                    if re.search(metric_regex, metric):
                        self._sb_rules[rule]['metrics'][metric] = -1
                        self._enable_metrics.add(metric)

    def _preprocess(self, raw_data_file, rule_file):
        """Preprocess/preparation operations for the rules.

        Reads the raw data, organizes its metrics by benchmark name, and
        loads the rule file. Returns None if the raw data is empty.

        Args:
            raw_data_file (str): the path of raw data file
            rule_file (str): the path of rule file

        Returns:
            dict: dict of rules, or None if raw data is empty
        """
        # read raw data from file
        self._raw_data_df = file_handler.read_raw_data(raw_data_file)
        # re-organize metrics by benchmark names
        # NOTE(review): this runs even when the data frame turns out empty;
        # an empty frame may still carry columns, preserved behavior.
        self._benchmark_metrics_dict = self._get_metrics_by_benchmarks(list(self._raw_data_df.columns))
        # check raw data whether empty
        if len(self._raw_data_df) == 0:
            logger.error('RuleBase: empty raw data')
            return None
        # read rules
        rules = file_handler.read_rules(rule_file)
        return rules