relu.py 4 KB
Newer Older
PanZezhongQY's avatar
PanZezhongQY committed
1
import ctypes
2
from ctypes import c_uint64
PanZezhongQY's avatar
PanZezhongQY committed
3
from enum import Enum, auto
4

PanZezhongQY's avatar
PanZezhongQY committed
5
import torch
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from libinfiniop import (
    LIBINFINIOP,
    InfiniDeviceNames,
    InfiniDtype,
    InfiniDtypeNames,
    TestTensor,
    TestWorkspace,
    check_error,
    debug,
    get_args,
    get_test_devices,
    get_tolerance,
    infiniopOperatorDescriptor_t,
    profile_operation,
    test_operator,
)
PanZezhongQY's avatar
PanZezhongQY committed
22

23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES_ = [
    # Each entry is a 1-tuple: (tensor_shape,).
    # The inplace mode is appended later when _TEST_CASES is built from _INPLACE.
    # TODO: Uncomment the following line to also test the scalar (0-d) shape.
    # ((),),
    ((1, 3),),
    ((3, 3),),
    ((32, 20, 512),),
    ((33, 333, 333),),
    ((32, 256, 112, 112),),
    ((3, 3, 13, 9, 17),),
]
PanZezhongQY's avatar
PanZezhongQY committed
38
39
40
41
42
43
44


class Inplace(Enum):
    """Whether the operator writes its result into the input buffer."""

    OUT_OF_PLACE = 1  # y is a separate output tensor
    INPLACE_X = 2  # y aliases the input tensor x


45
46
47
48
49
# Inplace options applied for each test case in _TEST_CASES_.
# Enum iteration follows definition order, so this is
# [Inplace.OUT_OF_PLACE, Inplace.INPLACE_X].
_INPLACE = list(Inplace)
PanZezhongQY's avatar
PanZezhongQY committed
50

51
52
53
54
55
56
# Cross every base case in _TEST_CASES_ with every inplace option,
# producing tuples of the form (tensor_shape, inplace_mode).
_TEST_CASES = [
    (*case, mode)
    for case in _TEST_CASES_
    for mode in _INPLACE
]
PanZezhongQY's avatar
PanZezhongQY committed
57

58
59
60
61
62
63
64
65
66
67
68
69
70
71
# Data types used for testing
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.F32, InfiniDtype.BF16]

# Tolerance map for different data types.
# Half-precision types get looser tolerances than F32.
_TOLERANCE_MAP = {
    InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3},
    InfiniDtype.F32: {"atol": 1e-7, "rtol": 1e-7},
    InfiniDtype.BF16: {"atol": 1e-3, "rtol": 1e-3},
}

# Defaults below are overwritten from command-line arguments in __main__.
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
PanZezhongQY's avatar
PanZezhongQY committed
72
73
74
75
76
77
78


def relu(x):
    """Reference ReLU: elementwise max(x, 0), cast back to x's dtype."""
    activated = torch.relu(x)
    return activated.to(x.dtype)


def test(
    handle, device, shape, inplace=Inplace.OUT_OF_PLACE, dtype=torch.float16, sync=None
):
    """Run the libinfiniop Relu operator on one test case and compare against PyTorch.

    Args:
        handle: libinfiniop operator handle for the target device.
        device: InfiniDevice identifier the tensors are allocated on.
        shape: tuple of ints, the tensor shape under test.
        inplace: Inplace.OUT_OF_PLACE (separate output) or Inplace.INPLACE_X
            (output aliases the input tensor).
        dtype: InfiniDtype of the test tensors (despite the torch default in the
            signature, callers pass InfiniDtype values via test_operator).
        sync: optional callable that synchronizes the device before timing-sensitive
            steps; skipped when None.

    Raises:
        AssertionError: if the library result does not match the PyTorch reference
            within the dtype-specific tolerance.
    """
    # Random input in [-1, 1) so both negative (clamped) and positive
    # (passed-through) values are exercised.
    x_torch_tensor = torch.rand(shape) * 2 - 1

    x = TestTensor(
        shape,
        x_torch_tensor.stride(),
        dtype,
        device,
        mode="manual",
        set_tensor=x_torch_tensor,
    )

    if inplace == Inplace.INPLACE_X:
        y = x
    else:
        y = TestTensor(shape, None, dtype, device)

    # Broadcast outputs cannot be written to; nothing to test.
    if y.is_broadcast():
        return

    print(
        f"Testing Relu on {InfiniDeviceNames[device]} with shape:{shape} dtype:{InfiniDtypeNames[dtype]} inplace: {inplace}"
    )

    # Compute the reference BEFORE the library call: in the INPLACE_X case the
    # library overwrites x's buffer.
    ans = relu(x.torch_tensor())

    if sync is not None:
        sync()

    descriptor = infiniopOperatorDescriptor_t()
    check_error(
        LIBINFINIOP.infiniopCreateReluDescriptor(
            handle, ctypes.byref(descriptor), y.descriptor, x.descriptor
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    for tensor in [x, y]:
        tensor.destroy_desc()

    workspace_size = c_uint64(0)
    check_error(
        LIBINFINIOP.infiniopGetReluWorkspaceSize(
            descriptor, ctypes.byref(workspace_size)
        )
    )
    workspace = TestWorkspace(workspace_size.value, y.device)

    def lib_relu():
        # FIX: the return status was previously discarded, so a failing kernel
        # launch passed silently; wrap in check_error like every other call here.
        check_error(
            LIBINFINIOP.infiniopRelu(
                descriptor, workspace.data(), workspace.size(), y.data(), x.data(), None
            )
        )

    lib_relu()

    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(y.actual_tensor(), ans, atol=atol, rtol=rtol)
    assert torch.allclose(y.actual_tensor(), ans, atol=atol, rtol=rtol)

    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: relu(x.torch_tensor()), device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation("    lib", lambda: lib_relu(), device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on

    check_error(LIBINFINIOP.infiniopDestroyReluDescriptor(descriptor))
PanZezhongQY's avatar
PanZezhongQY committed
148
149
150
151


if __name__ == "__main__":
    args = get_args()

    # Propagate CLI options into the module-level test configuration.
    DEBUG, PROFILE = args.debug, args.profile
    NUM_PRERUN, NUM_ITERATIONS = args.num_prerun, args.num_iterations

    # Run every (shape, inplace) case for every dtype on each requested device.
    for device in get_test_devices(args):
        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")