context.py 2.44 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""A module for unified context of benchmarks."""

import enum


class Enum(enum.Enum):
    """Customized Enum base class.

    Extends enum.Enum with a helper that lists all member values and makes
    str(member) produce the member's value rather than 'ClassName.MEMBER'.
    """
    @classmethod
    def get_values(cls):
        """Return the value list.

        Returns:
            list: the value of every member, in definition order.
        """
        return [item.value for item in cls]

    def __str__(self):
        """Return the member's value as a string.

        Returns:
            str: the member's value; non-string values are converted with str().
        """
        # __str__ must return a str object. Converting explicitly keeps
        # subclasses with non-string values (e.g. ints) from raising TypeError,
        # while string values are returned unchanged.
        return str(self.value)


class Platform(Enum):
    """The Enum class representing different platforms.

    Each member's value is the platform's string form, which is also what
    str(member) yields through the customized Enum base class.
    """
    CPU = 'CPU'
    CUDA = 'CUDA'
    # The value is mixed-case ('ROCm') although the member name is upper-case.
    ROCM = 'ROCm'


class Framework(Enum):
    """The Enum class representing different frameworks.

    Each member's value is a short lower-case string identifying the framework.
    """
    ONNX = 'onnx'
    PYTORCH = 'pytorch'
    # presumably TensorFlow 1.x and 2.x respectively -- confirm with callers.
    TENSORFLOW1 = 'tf1'
    TENSORFLOW2 = 'tf2'
    # Used as the default framework argument of BenchmarkContext.
    NONE = 'none'


class BenchmarkType(Enum):
    """The Enum class representing different types of benchmarks.

    Each member's value is the lower-case string form of the benchmark type.
    """
    MODEL = 'model'
    MICRO = 'micro'
    DOCKER = 'docker'


class Precision(Enum):
    """The Enum class representing different data precisions.

    Each member's value is the lower-case name of the data type.
    """
    # Floating-point types.
    FLOAT16 = 'float16'
    FLOAT32 = 'float32'
    FLOAT64 = 'float64'
    BFLOAT16 = 'bfloat16'
    # Integer types (one unsigned, the rest signed).
    UINT8 = 'uint8'
    INT8 = 'int8'
    INT16 = 'int16'
    INT32 = 'int32'
    INT64 = 'int64'


class ModelAction(Enum):
    """The Enum class representing the different actions run on a model.

    Each member's value is the lower-case string form of the action.
    """
    TRAIN = 'train'
    INFERENCE = 'inference'


class BenchmarkContext:
    """Unified context for all benchmarks.

    Bundles the information needed to launch one benchmark and exposes each
    piece through a read-only property.
    """
    def __init__(self, name, platform, parameters='', framework=Framework.NONE):
        """Create the context of one benchmark.

        Args:
            name (str): name of benchmark in config file.
            platform (Platform): Platform types like CUDA, ROCM.
            parameters (str): predefined parameters of benchmark,
                defaults to an empty string.
            framework (Framework): Framework types like ONNX, PYTORCH,
                defaults to Framework.NONE.
        """
        # Values are kept in name-mangled attributes; access goes through
        # the properties below, none of which defines a setter.
        self.__name, self.__platform = name, platform
        self.__parameters, self.__framework = parameters, framework

    @property
    def name(self):
        """Return the benchmark name given at construction."""
        return self.__name

    @property
    def platform(self):
        """Return the platform given at construction."""
        return self.__platform

    @property
    def parameters(self):
        """Return the predefined parameters given at construction."""
        return self.__parameters

    @property
    def framework(self):
        """Return the framework given at construction."""
        return self.__framework