# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Tests for BenchmarkResult module."""

import os

from superbench.benchmarks import BenchmarkType, ReturnCode, ReduceType
from superbench.benchmarks.result import BenchmarkResult


def test_add_raw_data():
    """Test interface BenchmarkResult.add_raw_data()."""
    # Case 1: string raw data on a micro benchmark, file logging disabled.
    res = BenchmarkResult('micro', BenchmarkType.MICRO, ReturnCode.SUCCESS)
    for log in ('raw log 1', 'raw log 2'):
        res.add_raw_data('metric1', log, False)
    assert res.raw_data['metric1'][0] == 'raw log 1'
    assert res.raw_data['metric1'][1] == 'raw log 2'
    assert res.type == BenchmarkType.MICRO
    assert res.return_code == ReturnCode.SUCCESS

    # Case 2: list raw data on a model benchmark, file logging disabled.
    res = BenchmarkResult('model', BenchmarkType.MODEL, ReturnCode.SUCCESS)
    for sample in ([1, 2, 3], [4, 5, 6]):
        res.add_raw_data('metric1', sample, False)
    assert res.raw_data['metric1'][0] == [1, 2, 3]
    assert res.raw_data['metric1'][1] == [4, 5, 6]
    assert res.type == BenchmarkType.MODEL
    assert res.return_code == ReturnCode.SUCCESS

    # Case 3: log_raw_data=True — raw data should land in rawdata.log on disk.
    res = BenchmarkResult('micro', BenchmarkType.MICRO, ReturnCode.SUCCESS)
    for log in ('raw log 1', 'raw log 2'):
        res.add_raw_data('metric1', log, True)
    assert res.type == BenchmarkType.MICRO
    assert res.return_code == ReturnCode.SUCCESS
    raw_data_file = os.path.join(os.getcwd(), 'rawdata.log')
    assert os.path.isfile(raw_data_file)
    os.remove(raw_data_file)

def test_add_result():
    """Test interface BenchmarkResult.add_result()."""
    res = BenchmarkResult('micro', BenchmarkType.MICRO, ReturnCode.SUCCESS)
    # Repeated values for one metric accumulate in insertion order.
    for value in (300, 200):
        res.add_result('metric1', value)
    assert res.result['metric1'][0] == 300
    assert res.result['metric1'][1] == 200


def test_set_timestamp():
    """Test interface BenchmarkResult.set_timestamp()."""
    res = BenchmarkResult('micro', BenchmarkType.MICRO, ReturnCode.SUCCESS)
    begin = '2021-02-03 16:59:49'
    finish = '2021-02-03 17:00:08'
    # set_timestamp() should store both endpoints verbatim.
    res.set_timestamp(begin, finish)
    assert res.start_time == begin
    assert res.end_time == finish


def test_set_benchmark_type():
    """Test interface BenchmarkResult.set_benchmark_type()."""
    res = BenchmarkResult('micro', BenchmarkType.MICRO, ReturnCode.SUCCESS)
    # Re-setting the type should be reflected by the `type` attribute.
    res.set_benchmark_type(BenchmarkType.MICRO)
    assert res.type == BenchmarkType.MICRO


def test_set_return_code():
    """Test interface BenchmarkResult.set_return_code()."""
    res = BenchmarkResult('micro', BenchmarkType.MICRO, ReturnCode.SUCCESS)
    # The constructor seeds both the attribute and the 'return_code' metric.
    assert res.return_code == ReturnCode.SUCCESS
    assert res.result['return_code'] == [ReturnCode.SUCCESS.value]
    # Each set_return_code() call should overwrite both views consistently.
    for code in (ReturnCode.INVALID_ARGUMENT, ReturnCode.INVALID_BENCHMARK_RESULT):
        res.set_return_code(code)
        assert res.return_code == code
        assert res.result['return_code'] == [code.value]


def test_serialize_deserialize():
    """Test serialization/deserialization and compare the results."""
    # Build a result carrying reduced metrics, raw data, and timestamps.
    res = BenchmarkResult('pytorch-bert-base1', BenchmarkType.MICRO, ReturnCode.SUCCESS, run_count=2)
    res.add_result('metric1', 300, ReduceType.MAX)
    res.add_result('metric1', 200, ReduceType.MAX)
    res.add_result('metric2', 100, ReduceType.AVG)
    for sample in ([1, 2, 3], [4, 5, 6], [7, 8, 9]):
        res.add_raw_data('metric1', sample, False)
    res.set_timestamp('2021-02-03 16:59:49', '2021-02-03 17:00:08')
    res.set_benchmark_type(BenchmarkType.MICRO)

    # to_string() must emit a single JSON object with this exact key order.
    expected = (
        '{"name": "pytorch-bert-base1", "type": "micro", "run_count": 2, "return_code": 0, '
        '"start_time": "2021-02-03 16:59:49", "end_time": "2021-02-03 17:00:08", '
        '"raw_data": {"metric1": [[1, 2, 3], [4, 5, 6], [7, 8, 9]]}, '
        '"result": {"return_code": [0], "metric1": [300, 200], "metric2": [100]}, '
        '"reduce_op": {"return_code": null, "metric1": "max", "metric2": "avg"}}'
    )
    assert res.to_string() == expected