# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""SuperBench CLI result subgroup command handler."""

from knack.util import CLIError

from superbench.analyzer import DataDiagnosis
9
from superbench.analyzer import ResultSummary
10
from superbench.analyzer import BaselineGeneration
11
12
13
14
from superbench.common.utils import create_sb_output_dir
from superbench.cli._handler import check_argument_file


def diagnosis_command_handler(
    raw_data_file,
    rule_file,
    baseline_file=None,
    output_dir=None,
    output_file_format='excel',
    output_all=False,
    decimal_place_value=2
):
    """Run data diagnosis.

    Args:
        raw_data_file (str): Path to raw data jsonl file.
        rule_file (str): Path to diagnosis rule yaml file.
        baseline_file (str): Path to baseline json file. Optional; only checked when provided.
        output_dir (str): Path to output directory.
        output_file_format (str): Format of the output file, 'excel', 'json', 'jsonl', 'md' or 'html'.
            Defaults to 'excel'.
        output_all (bool): Whether to output diagnosis results for all nodes.
        decimal_place_value (int): Number of decimal places to show in output.

    Raises:
        RuntimeError: If argument validation or the diagnosis run fails; the original error is chained as the cause.
    """
    try:
        # Check arguments first so an invalid invocation does not create an output directory
        supported_output_format = ['excel', 'json', 'md', 'html', 'jsonl']
        if output_file_format not in supported_output_format:
            raise CLIError('Output format must be in {}.'.format(str(supported_output_format)))
        check_argument_file('raw_data_file', raw_data_file)
        check_argument_file('rule_file', rule_file)
        if baseline_file:
            check_argument_file('baseline_file', baseline_file)
        # Create output directory
        sb_output_dir = create_sb_output_dir(output_dir)
        # Run data diagnosis
        DataDiagnosis().run(
            raw_data_file, rule_file, baseline_file, sb_output_dir, output_file_format, output_all, decimal_place_value
        )
    except Exception as ex:
        raise RuntimeError('Failed to run diagnosis command.') from ex


def summary_command_handler(raw_data_file, rule_file, output_dir=None, output_file_format='md', decimal_place_value=2):
    """Run result summary.

    Args:
        raw_data_file (str): Path to raw data jsonl file.
        rule_file (str): Path to summary rule yaml file.
        output_dir (str): Path to output directory.
        output_file_format (str): Format of the output file, 'excel', 'md' or 'html'. Defaults to 'md'.
        decimal_place_value (int): Number of decimal places to show in output.

    Raises:
        RuntimeError: If argument validation or the summary run fails; the original error is chained as the cause.
    """
    try:
        # Check arguments first so an invalid invocation does not create an output directory
        supported_output_format = ['excel', 'html', 'md']
        if output_file_format not in supported_output_format:
            raise CLIError('Output format must be in {}.'.format(str(supported_output_format)))
        check_argument_file('raw_data_file', raw_data_file)
        check_argument_file('rule_file', rule_file)
        # Create output directory
        sb_output_dir = create_sb_output_dir(output_dir)
        # Run result summary
        ResultSummary().run(raw_data_file, rule_file, sb_output_dir, output_file_format, decimal_place_value)
    except Exception as ex:
        raise RuntimeError('Failed to run summary command.') from ex


def generate_baseline_command_handler(
    raw_data_file,
    summary_rule_file,
    diagnosis_rule_file=None,
    baseline_file=None,
    output_dir=None,
    decimal_place_value=2
):
    """Run result generate-baseline.

    If diagnosis_rule_file is not given, use mean of the data as baseline.
    If diagnosis_rule_file is given, use the rules in diagnosis_rule_file to execute fix_threshold algorithm.

    Args:
        raw_data_file (str): Path to raw data jsonl file.
        summary_rule_file (str): the file name of the summary rule file.
        diagnosis_rule_file (str): the file name of the diagnosis rules which used in fix_threshold algorithm.
        baseline_file (str): the file name of the previous baseline file that plan to merge with current baseline.
        output_dir (str): the directory to save the baseline file.
        decimal_place_value (int): the number of digits after the decimal point.

    Raises:
        RuntimeError: If argument validation or baseline generation fails; the original error is chained as the cause.
    """
    try:
        # Check arguments first so an invalid invocation does not create an output directory.
        # Each file is checked under its own parameter name so error messages point at the right argument.
        check_argument_file('raw_data_file', raw_data_file)
        check_argument_file('summary_rule_file', summary_rule_file)
        if diagnosis_rule_file:
            check_argument_file('diagnosis_rule_file', diagnosis_rule_file)
        if baseline_file:
            check_argument_file('baseline_file', baseline_file)
        # Diagnosis rules switch baseline generation from a plain mean to the fix_threshold algorithm
        algorithm = 'fix_threshold' if diagnosis_rule_file else 'mean'
        # Create output directory
        sb_output_dir = create_sb_output_dir(output_dir)
        # Run result generate-baseline
        BaselineGeneration().run(
            raw_data_file, summary_rule_file, diagnosis_rule_file, baseline_file, algorithm, sb_output_dir,
            decimal_place_value
        )
    except Exception as ex:
        raise RuntimeError('Failed to run generate-baseline command.') from ex