# test_rulebase.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Tests for RuleBase module."""

import unittest
from pathlib import Path

from superbench.analyzer import RuleBase
import superbench.analyzer.file_handler as file_handler


class TestRuleBase(unittest.TestCase):
    """Test for RuleBase class.

    The fixtures 'test_results.jsonl' and 'test_rules.yaml' are looked up in
    the directory that contains this test module.
    """
    def setUp(self):
        """Record the directory of this module so fixture paths can be built from it."""
        self.parent_path = Path(__file__).parent

    def test_rule_base(self):
        """Test for rule-based functions.

        Covers, in order: reading raw data, grouping metrics by benchmark,
        ``_preprocess``, ``_check_and_format_rules`` and ``_get_metrics``.
        """
        # Fixture paths; the '*_fake' files are not expected to exist.
        test_raw_data = str(self.parent_path / 'test_results.jsonl')
        test_rule_file = str(self.parent_path / 'test_rules.yaml')
        test_raw_data_fake = str(self.parent_path / 'test_results_fake.jsonl')
        test_rule_file_fake = str(self.parent_path / 'test_rules_fake.yaml')

        # Test - read_raw_data and get_metrics_from_raw_data
        # Positive case
        rulebase1 = RuleBase()
        rulebase1._raw_data_df = file_handler.read_raw_data(test_raw_data)
        rulebase1._benchmark_metrics_dict = rulebase1._get_metrics_by_benchmarks(list(rulebase1._raw_data_df))
        # NOTE(review): 3 is presumably the number of records in test_results.jsonl - confirm against the fixture.
        self.assertEqual(len(rulebase1._raw_data_df), 3)

        # Negative case
        rulebase2 = RuleBase()
        self.assertRaises(FileNotFoundError, file_handler.read_raw_data, test_raw_data_fake)
        rulebase2._benchmark_metrics_dict = rulebase2._get_metrics_by_benchmarks([])
        self.assertEqual(len(rulebase2._benchmark_metrics_dict), 0)

        # The first two names contain no '/' and are expected to be absent from
        # the result; the others are expected under the text before the first '/'.
        metric_list = [
            'gpu_temperature', 'gpu_power_limit', 'gemm-flops/FP64',
            'bert_models/pytorch-bert-base/steptime_train_float32'
        ]
        self.assertDictEqual(
            rulebase2._get_metrics_by_benchmarks(metric_list), {
                'gemm-flops': {'gemm-flops/FP64'},
                'bert_models': {'bert_models/pytorch-bert-base/steptime_train_float32'}
            }
        )

        # Test - _preprocess
        # Negative cases: a missing raw data file or a missing rule file must raise.
        self.assertRaises(Exception, rulebase1._preprocess, test_raw_data_fake, test_rule_file)
        self.assertRaises(Exception, rulebase1._preprocess, test_raw_data, test_rule_file_fake)
        # Positive case
        rules = rulebase1._preprocess(test_raw_data, test_rule_file)
        self.assertTrue(rules)

        # Test - _check_and_format_rules
        # Negative case
        # NOTE(review): unlike true_rule below, this rule has no 'categories' key,
        # which is presumably why it is rejected - confirm against RuleBase.
        false_rule = {
            'criteria': 'lambda x:x>0',
            'function': 'variance',
            'metrics': ['kernel-launch/event_overhead:\\d+']
        }
        metric = 'kernel-launch/event_overhead:0'
        self.assertRaises(Exception, rulebase1._check_and_format_rules, false_rule, metric)
        # Positive case: a 'metrics' value given as a single string comes back wrapped in a list.
        true_rule = {
            'categories': 'KernelLaunch',
            'criteria': 'lambda x:x<-0.05',
            'function': 'variance',
            'metrics': 'kernel-launch/event_overhead:\\d+'
        }
        true_rule = rulebase1._check_and_format_rules(true_rule, metric)
        self.assertTrue(true_rule)
        self.assertListEqual(true_rule['metrics'], ['kernel-launch/event_overhead:\\d+'])

        # Test - _get_metrics
        rules = rules['superbench']['rules']
        for rule in ['rule0', 'rule1']:
            # Start each rule from an empty metrics mapping before _get_metrics fills it.
            rulebase1._sb_rules[rule] = {}
            rulebase1._sb_rules[rule]['metrics'] = {}
            rulebase1._get_metrics(rule, rules)
            # NOTE(review): 16 depends on the fixture files - confirm against
            # test_rules.yaml and test_results.jsonl.
            self.assertEqual(len(rulebase1._sb_rules[rule]['metrics']), 16)