context.py 2.94 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""A module for unified context of benchmarks."""

import enum


class Enum(enum.Enum):
    """Customized Enum base class shared by the benchmark context enums.

    Adds a helper that lists every member's value, and renders a member as
    its plain value when converted with ``str()``.
    """
    @classmethod
    def get_values(cls):
        """Return the value list.

        Returns:
            list: The ``value`` of every member, in iteration (definition)
            order of the enum class.
        """
        return [item.value for item in cls]

    def __str__(self):
        """Value as the string.

        Returns:
            str: ``str(self.value)`` rather than the default
            ``ClassName.MEMBER`` form.
        """
        return str(self.value)


class Platform(Enum):
    """Enumeration of the hardware platforms a benchmark can target.

    Members: ``CPU``, ``CUDA`` and ``ROCM`` (whose value is spelled ``'ROCm'``).
    """
    CPU = 'CPU'
    CUDA = 'CUDA'
    ROCM = 'ROCm'


class Framework(Enum):
    """The Enum class representing different frameworks.

    ``NONE`` (value ``'none'``) is the default framework used by
    ``BenchmarkContext`` when none is specified.
    """
    ONNXRUNTIME = 'onnxruntime'
    PYTORCH = 'pytorch'
    TENSORFLOW1 = 'tf1'
    TENSORFLOW2 = 'tf2'
    NONE = 'none'


class BenchmarkType(Enum):
    """Enumeration of the benchmark categories.

    Members: ``MODEL``, ``MICRO`` and ``DOCKER``; each value is the
    lower-case member name.
    """
    MODEL = 'model'
    MICRO = 'micro'
    DOCKER = 'docker'


class Precision(Enum):
    """The Enum class representing different data precisions.

    Member order is significant: it is the order in which members are
    iterated (e.g. by value-listing helpers).
    """
    # FP8 variants.
    FP8_HYBRID = 'fp8_hybrid'
    FP8_E4M3 = 'fp8_e4m3'
    FP8_E5M2 = 'fp8_e5m2'
    # Floating-point types.
    FLOAT16 = 'float16'
    FLOAT32 = 'float32'
    FLOAT64 = 'float64'
    BFLOAT16 = 'bfloat16'
    # Integer types.
    UINT8 = 'uint8'
    INT8 = 'int8'
    INT16 = 'int16'
    INT32 = 'int32'
    INT64 = 'int64'


class ModelAction(Enum):
    """Enumeration of the actions a model benchmark can perform.

    Members: ``TRAIN`` and ``INFERENCE``.
    """
    TRAIN = 'train'
    INFERENCE = 'inference'


class DistributedImpl(Enum):
    """The Enum class representing different distributed implementations.

    Members: ``DDP``, ``MIRRORED``, ``MW_MIRRORED``, ``PS`` and ``HOROVOD``.
    """
    DDP = 'ddp'
    MIRRORED = 'mirrored'
    MW_MIRRORED = 'multiworkermirrored'
    PS = 'parameterserver'
    HOROVOD = 'horovod'


class DistributedBackend(Enum):
    """Enumeration of the communication backends for distributed runs.

    Members: ``NCCL``, ``MPI`` and ``GLOO``; each value is the lower-case
    member name.
    """
    NCCL = 'nccl'
    MPI = 'mpi'
    GLOO = 'gloo'


class BenchmarkContext():
    """Context class of all benchmarks.

    Holds the information needed to launch one benchmark. All fields are set
    once in the constructor and exposed through read-only properties.
    """
    def __init__(self, name, platform, parameters='', framework=Framework.NONE):
        """Constructor.

        Args:
            name (str): name of benchmark in config file.
            platform (Platform): Platform types like CUDA, ROCM.
            parameters (str): predefined parameters of benchmark.
                Defaults to the empty string.
            framework (Framework): Framework types like ONNXRUNTIME, PYTORCH.
                Defaults to ``Framework.NONE``.
        """
        self.__name = name
        self.__platform = platform
        self.__parameters = parameters
        self.__framework = framework

    @property
    def name(self):
        """str: Name of the benchmark in the config file (read-only)."""
        return self.__name

    @property
    def platform(self):
        """Platform: Platform the benchmark targets (read-only)."""
        return self.__platform

    @property
    def parameters(self):
        """str: Predefined parameters of the benchmark (read-only)."""
        return self.__parameters

    @property
    def framework(self):
        """Framework: Framework used by the benchmark (read-only)."""
        return self.__framework