conv.py 7.35 KB
Newer Older
1
import torch
PanZezhongQY's avatar
PanZezhongQY committed
2
import ctypes
3
from ctypes import c_uint64
PanZezhongQY's avatar
PanZezhongQY committed
4

5
6
7
8
from libinfiniop import (
    LIBINFINIOP,
    TestTensor,
    get_test_devices,
PanZezhongQY's avatar
PanZezhongQY committed
9
    check_error,
10
11
12
13
14
15
16
17
18
19
    test_operator,
    get_args,
    debug,
    get_tolerance,
    profile_operation,
    TestWorkspace,
    InfiniDtype,
    InfiniDtypeNames,
    InfiniDeviceNames,
    infiniopOperatorDescriptor_t,
PanZezhongQY's avatar
PanZezhongQY committed
20
)
21
22
from enum import Enum, auto
from typing import List, Tuple
PanZezhongQY's avatar
PanZezhongQY committed
23
24
25
26
27
28
29
30
31
import math
from torch.nn import functional as F

# constant for control whether profile the pytorch and lib functions
# NOTE: need to manually add synchronization function to the lib function,
#       e.g., cudaDeviceSynchronize() for CUDA
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
32
33
34
35
36
37
38
39
40
41
42
43
44
_TEST_CASES = [
    # Each case: (x_shape, x_stride, w_shape, w_stride, pads, strides, dilations)
    # 1-D convolution
    (
        (32, 3, 4),
        (12, 4, 1),
        (32, 3, 5),
        (15, 5, 1),
        (1,),
        (1,),
        (1,),
    ),
    # 2-D convolution, small input, asymmetric strides/dilations
    (
        (1, 3, 4, 4),
        (48, 16, 4, 1),
        (2, 3, 3, 3),
        (27, 9, 3, 1),
        (1, 1),
        (1, 2),
        (2, 1),
    ),
    # 2-D convolution, batched
    (
        (32, 3, 32, 32),
        (32 * 32 * 3, 32 * 32, 32, 1),
        (64, 3, 5, 5),
        (75, 25, 5, 1),
        (2, 2),
        (2, 2),
        (1, 1),
    ),
    # 3-D convolution, single output channel (exercises the no-bias path in test())
    (
        (1, 1, 4, 4, 4),
        (64, 64, 16, 4, 1),
        (1, 1, 5, 5, 5),
        (125, 125, 25, 5, 1),
        (1, 1, 1),
        (1, 1, 1),
        (1, 1, 1),
    ),
    # 3-D convolution, batched, mixed pads/strides/dilations
    (
        (32, 3, 32, 32, 32),
        (32 * 32 * 32 * 3, 32 * 32 * 32, 32 * 32, 32, 1),
        (64, 3, 5, 5, 5),
        (375, 125, 25, 5, 1),
        (3, 2, 2),
        (4, 3, 3),
        (2, 2, 1),
    ),
]


# Data types used for testing
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.F32, InfiniDtype.BF16]

# Tolerance map for different data types (looser for the low-precision floats)
_TOLERANCE_MAP = {
    InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3},
    InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5},
    InfiniDtype.BF16: {"atol": 1e-3, "rtol": 1e-2},
}

# Runtime options; these defaults are overridden from command-line args in __main__.
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
96

97

98
def conv(x, w, stride, padding, dilation, y_tensor, bias=None):
    """Run the PyTorch reference convolution and copy the result into *y_tensor*.

    The spatial rank is taken from ``x`` (ndim of x minus batch and channel
    dims) and dispatched to ``F.conv1d``/``conv2d``/``conv3d`` accordingly.
    Unsupported ranks print an error and leave *y_tensor* untouched.
    """
    # NOTE: deliberately no `match` statement here — it breaks CI
    # (translated from the original Chinese comment).
    spatial_ndim = len(x.shape) - 2
    dispatch = {1: F.conv1d, 2: F.conv2d, 3: F.conv3d}
    conv_fn = dispatch.get(spatial_ndim)
    if conv_fn is None:
        print("Error: Pytorch -> Unsupported tensor dimension")
        return
    result = conv_fn(
        x, w, bias=bias, stride=stride, padding=padding, dilation=dilation
    )
    y_tensor.copy_(result)
PanZezhongQY's avatar
PanZezhongQY committed
120
121
122


# infer the shape of the output given the inputs for a N-ary convolution
def inferShapeStride(
    x_shape: List[int],
    w_shape: List[int],
    pads: List[int],
    strides: List[int],
    dilations: List[int],
) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    """Infer the output shape and contiguous strides of an N-D convolution.

    Args:
        x_shape: input shape (batch, in_channels, *spatial).
        w_shape: weight shape (out_channels, in_channels, *kernel).
        pads / strides / dilations: per-spatial-dim conv parameters,
            each of length len(x_shape) - 2.

    Returns:
        (output_shape, output_strides) as tuples; strides are row-major
        (contiguous) element strides for the inferred shape.
    """
    assert (
        len(x_shape)
        == len(w_shape)
        == len(pads) + 2
        == len(dilations) + 2
        == len(strides) + 2
    ), "x and w should have the same length; pads, strides, and dilations should have the same length; the length of pads should be that of x - 2"
    # Standard conv output-size formula. Integer floor division (//) replaces
    # math.floor(float division): exact for any size, no float rounding risk.
    output_dims = [
        (x_shape[i + 2] + 2 * pads[i] - dilations[i] * (w_shape[i + 2] - 1) - 1)
        // strides[i]
        + 1
        for i in range(len(pads))
    ]
    output_shape = (x_shape[0], w_shape[0]) + tuple(output_dims)
    # Build contiguous strides from the innermost dimension outwards.
    output_strides = [1]
    for s in reversed(output_shape[1:]):
        output_strides.insert(0, output_strides[0] * s)
    return output_shape, tuple(output_strides)
PanZezhongQY's avatar
PanZezhongQY committed
151
152
153
154
155
156
157
158
159
160
161


# convert a python tuple to a ctype void pointer
def tuple_to_void_p(py_tuple: Tuple):
    """Pack the tuple's values into a ctypes int64 array and return it as void*.

    ctypes.cast keeps a reference to the backing array, so the memory stays
    valid for the lifetime of the returned pointer.
    """
    c_array = (ctypes.c_int64 * len(py_tuple))(*py_tuple)
    return ctypes.cast(c_array, ctypes.c_void_p)


def test(
    handle,
    device,
    x_shape,
    x_stride,
    w_shape,
    w_stride,
    pads,
    strides,
    dilations,
    tensor_dtype=InfiniDtype.F16,
    sync=None,
):
    """Run one Conv test case: compare the library result against PyTorch.

    Builds x/w (and a bias when out_channels > 1), computes the PyTorch
    reference into y's torch tensor, then creates the library descriptor,
    runs infiniopConv, and asserts the results match within the dtype
    tolerance. Optionally profiles both implementations.
    """
    assert len(pads) == len(strides) == len(dilations)
    x = TestTensor(x_shape, x_stride, dt=tensor_dtype, device=device, scale=0.01)
    w = TestTensor(w_shape, w_stride, dt=tensor_dtype, device=device, scale=0.01)
    y_shape, y_stride = inferShapeStride(x_shape, w_shape, pads, strides, dilations)
    y = TestTensor(y_shape, y_stride, dt=tensor_dtype, device=device)

    # Bias only when there is more than one output channel, so both the
    # with-bias and no-bias library paths are exercised across the test cases.
    b = (
        TestTensor((w.shape[0],), (1,), dt=tensor_dtype, device=device, scale=0.01)
        if w.shape[0] > 1
        else None
    )
    print(
        f"Testing Conv on {InfiniDeviceNames[device]} with x_shape: {x_shape}, w_shape: {w_shape}, b_shape: {w_shape[0]}, pads: {pads}, strides: {strides}, dilations: {dilations}, x_stride: {x_stride} dtype:{InfiniDtypeNames[tensor_dtype]}"
    )
    # PyTorch reference result, written into y's torch-side tensor.
    conv(
        x.torch_tensor(),
        w.torch_tensor(),
        strides,
        pads,
        dilations,
        y.torch_tensor(),
        b.torch_tensor() if b is not None else None,
    )

    if sync is not None:
        sync()

    descriptor = infiniopOperatorDescriptor_t()
    check_error(
        LIBINFINIOP.infiniopCreateConvDescriptor(
            handle,
            ctypes.byref(descriptor),
            y.descriptor,
            x.descriptor,
            w.descriptor,
            b.descriptor if b is not None else None,
            tuple_to_void_p(pads),
            tuple_to_void_p(strides),
            tuple_to_void_p(dilations),
            len(pads),
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    for tensor in [x, y, w, b]:
        if tensor is not None:
            tensor.destroy_desc()

    workspace_size = ctypes.c_uint64(0)
    check_error(
        LIBINFINIOP.infiniopGetConvWorkspaceSize(
            descriptor, ctypes.byref(workspace_size)
        )
    )
    workspace = TestWorkspace(workspace_size.value, y.device)

    def lib_conv():
        # Invoke the library kernel; writes the result into y's actual buffer.
        check_error(
            LIBINFINIOP.infiniopConv(
                descriptor,
                workspace.data(),
                workspace_size.value,
                y.data(),
                x.data(),
                w.data(),
                b.data() if b is not None else None,
                None,
            )
        )

    lib_conv()
    atol, rtol = get_tolerance(_TOLERANCE_MAP, tensor_dtype)
    if DEBUG:
        debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)
    assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)

    # Profiling workflow
    if PROFILE:
        # fmt: off
        # BUGFIX: the PyTorch lambda previously omitted the y_tensor argument,
        # passing the bias into the y_tensor slot of conv().
        profile_operation("PyTorch", lambda: conv(x.torch_tensor(), w.torch_tensor(), strides, pads, dilations, y.torch_tensor(), b.torch_tensor() if b is not None else None), device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation("    lib", lambda: lib_conv(), device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on
    check_error(LIBINFINIOP.infiniopDestroyConvDescriptor(descriptor))
PanZezhongQY's avatar
PanZezhongQY committed
256
257
258
259
260


if __name__ == "__main__":
    args = get_args()

    # Configure testing options (override the module-level defaults
    # from the parsed command-line arguments)
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations
    # Run every test case with every dtype on each requested device
    for device in get_test_devices(args):
        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")