import sys

import ctypes
from ctypes import (
    POINTER,
    Structure,
    c_float,
    c_int32,
    c_size_t,
    c_uint64,
    c_void_p,
)

import torch

from libinfiniop import (
    check_error,
    create_workspace,
    debug,
    get_args,
    get_test_devices,
    get_tolerance,
    infiniopHandle_t,
    infiniopTensorDescriptor_t,
    open_lib,
    profile_operation,
    rearrange_if_needed,
    test_operator,
    to_tensor,
)
# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES = [
    # alpha, beta, a_shape, b_shape, c_shape, a_stride, b_stride, c_stride
    (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), None, None, None),
    (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), None, None, None),
    (1.0, 0.0, (2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None),
    (1.0, 0.0, (2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None),
    (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1)),
    (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1)),
    (1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1)),
    (1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1)),
    (1.0/8.0, 0.0, (4, 8*6, 64), (4, 64, 6), (4, 8*6, 6), None, None, None),
    (1.0/8.0, 0.0, (4, 8*6, 64), (4, 64, 6), (4, 8*6, 6), None, None, None),
]

# Data types used for testing
_TENSOR_DTYPES = [torch.float16, torch.float32]

# Tolerance map for different data types
_TOLERANCE_MAP = {
    torch.float16: {'atol': 0, 'rtol': 1e-2},
    torch.float32: {'atol': 0, 'rtol': 1e-3},
}

DEBUG = False
PanZezhongQY's avatar
PanZezhongQY committed
39
40
41
42
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000

43
44
45
# ==============================================================================
#  Definitions
# ==============================================================================
class MatmulDescriptor(Structure):
    """Opaque ctypes handle for the C-side infiniop matmul descriptor.

    Only the device id field is mirrored here; the struct is otherwise
    treated as opaque and only ever passed back to the library by pointer.
    """

    _fields_ = [("device", c_int32)]


# Pointer type used in all infiniopMatmul* C signatures.
infiniopMatmulDescriptor_t = POINTER(MatmulDescriptor)
# PyTorch implementation for matrix multiplication
def matmul(_c, beta, _a, _b, alpha):
    """Reference GEMM: return ``alpha * (_a @ _b) + beta * _c``.

    The product is computed in float32 and cast back to ``_c``'s dtype
    before the alpha/beta combination. All inputs are cloned, so the
    caller's tensors are never modified.

    Args:
        _c: accumulator tensor; its dtype determines the result dtype.
        beta: scale applied to ``_c``.
        _a, _b: operand tensors for the matrix product.
        alpha: scale applied to the product.

    Returns:
        A new tensor with the same dtype as ``_c``.
    """
    a, b, c = _a.clone(), _b.clone(), _c.clone()
    result_dtype = c.dtype
    fp32_result = torch.matmul(a.to(torch.float32), b.to(torch.float32))
    return alpha * fp32_result.to(result_dtype) + beta * c

# The argument list should be (lib, handle, torch_device, <param list>, dtype)
# The <param list> should keep the same order as the one specified in _TEST_CASES
def test(
    lib,
    handle,
    torch_device,
    alpha,
    beta,
    a_shape,
    b_shape,
    c_shape,
    a_stride=None,
    b_stride=None,
    c_stride=None,
    dtype=torch.float16,
):
    """Run one matmul test case against the infiniop library operator.

    Computes ``c = alpha * (a @ b) + beta * c`` via the library and asserts
    the result matches the float32-accumulated PyTorch reference within the
    dtype-specific tolerance from ``_TOLERANCE_MAP``.

    Raises:
        AssertionError: if the operator result diverges from the reference.
    """
    print(
        f"Testing Matmul on {torch_device} with alpha:{alpha}, beta:{beta},"
        f" a_shape:{a_shape}, b_shape:{b_shape}, c_shape:{c_shape},"
        f" a_stride:{a_stride}, b_stride:{b_stride}, c_stride:{c_stride}, dtype:{dtype}"
    )

    # Initialize tensors
    a = torch.rand(a_shape, dtype=dtype).to(torch_device)
    b = torch.rand(b_shape, dtype=dtype).to(torch_device)
    c = torch.ones(c_shape, dtype=dtype).to(torch_device)

    # Compute the PyTorch reference result before the library writes into c
    ans = matmul(c, beta, a, b, alpha)

    # Apply the requested strides (if any), then wrap as library tensors
    a, b, c = [
        rearrange_if_needed(tensor, stride)
        for tensor, stride in zip([a, b, c], [a_stride, b_stride, c_stride])
    ]
    a_tensor, b_tensor, c_tensor = [to_tensor(tensor, lib) for tensor in [a, b, c]]

    descriptor = infiniopMatmulDescriptor_t()
    check_error(
        lib.infiniopCreateMatmulDescriptor(
            handle,
            ctypes.byref(descriptor),
            c_tensor.descriptor,
            a_tensor.descriptor,
            b_tensor.descriptor,
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them from
    # being directly used by the kernel
    for tensor in [a_tensor, b_tensor, c_tensor]:
        tensor.descriptor.contents.invalidate()

    # Get workspace size and create workspace.
    # c_size_t matches the POINTER(c_size_t) argtype declared in __main__
    # (the original used c_uint64, which only coincides on LP64 platforms).
    workspace_size = c_size_t(0)
    check_error(
        lib.infiniopGetMatmulWorkspaceSize(descriptor, ctypes.byref(workspace_size))
    )
    workspace = create_workspace(workspace_size.value, a.device)

    # Execute infiniop matmul operator
    def lib_matmul():
        check_error(
            lib.infiniopMatmul(
                descriptor,
                workspace.data_ptr() if workspace else None,
                workspace_size.value,
                c_tensor.data,
                a_tensor.data,
                b_tensor.data,
                alpha,
                beta,
                None,  # last arg; test always passes None (stream-like handle?)
            )
        )

    lib_matmul()

    # Validate results
    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(c, ans, atol=atol, rtol=rtol)
    assert torch.allclose(c, ans, atol=atol, rtol=rtol)

    # Profiling workflow
    if PROFILE:
        profile_operation(
            "PyTorch",
            lambda: matmul(c, beta, a, b, alpha),
            torch_device,
            NUM_PRERUN,
            NUM_ITERATIONS,
        )
        profile_operation(
            "    lib",
            lambda: lib_matmul(),
            torch_device,
            NUM_PRERUN,
            NUM_ITERATIONS,
        )

    check_error(lib.infiniopDestroyMatmulDescriptor(descriptor))


# ==============================================================================
#  Main Execution
# ==============================================================================
if __name__ == "__main__":
    args = get_args()
    lib = open_lib()

    # Declare C signatures so ctypes marshals/validates arguments correctly.
    lib.infiniopCreateMatmulDescriptor.restype = c_int32
    lib.infiniopCreateMatmulDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopMatmulDescriptor_t),
        infiniopTensorDescriptor_t,  # c (output) descriptor
        infiniopTensorDescriptor_t,  # a descriptor
        infiniopTensorDescriptor_t,  # b descriptor
    ]

    lib.infiniopGetMatmulWorkspaceSize.restype = c_int32
    lib.infiniopGetMatmulWorkspaceSize.argtypes = [
        infiniopMatmulDescriptor_t,
        POINTER(c_size_t),
    ]

    lib.infiniopMatmul.restype = c_int32
    lib.infiniopMatmul.argtypes = [
        infiniopMatmulDescriptor_t,
        c_void_p,   # workspace pointer (may be None)
        c_uint64,   # workspace size in bytes
        c_void_p,   # c data
        c_void_p,   # a data
        c_void_p,   # b data
        c_float,    # alpha
        c_float,    # beta
        c_void_p,   # trailing handle; test() passes None
    ]

    lib.infiniopDestroyMatmulDescriptor.restype = c_int32
    lib.infiniopDestroyMatmulDescriptor.argtypes = [
        infiniopMatmulDescriptor_t,
    ]

    # Configure testing options
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    # Execute tests
    for device in get_test_devices(args):
        test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")