import torch
import ctypes
from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
from libinfiniop import (
    infiniopHandle_t, infiniopTensorDescriptor_t, open_lib, to_tensor, get_test_devices, 
    check_error, rearrange_if_needed, create_workspace, test_operator, get_args, 
    debug, get_tolerance, profile_operation,
)

# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES = [
    # alpha, beta, a_shape, b_shape, c_shape, a_stride, b_stride, c_stride
    # Covers 2-D and batched (3-D) GEMM, custom (padded / transposed) strides,
    # and a scaled attention-style case (alpha = 1/8).
    # NOTE(review): every case is listed twice — presumably intentional (e.g. to
    # exercise repeated runs of identical descriptors); confirm before deduplicating.
    (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), None, None, None),
    (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), None, None, None),
    (1.0, 0.0, (2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None),
    (1.0, 0.0, (2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None),
    (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1)),
    (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1)),
    (1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1)),
    (1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1)),
    (1.0/8.0, 0.0, (4, 8*6, 64), (4, 64, 6), (4, 8*6, 6), None, None, None),
    (1.0/8.0, 0.0, (4, 8*6, 64), (4, 64, 6), (4, 8*6, 6), None, None, None),
]

# Data types used for testing
_TENSOR_DTYPES = [torch.float16, torch.float32]

# Tolerance map for different data types
# atol is 0 on purpose: results are compared purely by relative error.
_TOLERANCE_MAP = {
    torch.float16: {'atol': 0, 'rtol': 1e-2},
    torch.float32: {'atol': 0, 'rtol': 1e-3},
}

# Runtime options; all four are reassigned from CLI arguments in __main__.
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
# ==============================================================================
#  Definitions
# ==============================================================================
class MatmulDescriptor(Structure):
    """ctypes stand-in for the library's matmul descriptor.

    Only the leading ``device`` field is declared; the descriptor is otherwise
    treated as an opaque handle that is created, queried, and destroyed through
    the library's C API — this struct is never read field-by-field here.
    """
    _fields_ = [("device", c_int32)]


# Pointer type passed to/returned from the C entry points below.
infiniopMatmulDescriptor_t = POINTER(MatmulDescriptor)
# PyTorch implementation for matrix multiplication
def matmul(_c, beta, _a, _b, alpha):
    """Reference GEMM: return ``alpha * (_a @ _b) + beta * _c``.

    The product is accumulated in float32 regardless of the input dtype and
    cast back to ``_c``'s dtype before the alpha/beta combination, so the
    reference matches the mixed-precision behavior expected of the kernel
    under test.

    No input tensor is mutated; the original defensively cloned all three
    tensors, which was pure overhead and has been removed.
    """
    result_dtype = _c.dtype
    fp32_result = torch.matmul(_a.to(torch.float32), _b.to(torch.float32))
    return alpha * fp32_result.to(result_dtype) + beta * _c
# The argument list should be (lib, handle, torch_device, <param list>, dtype)
# The <param list> should keep the same order as the one specified in _TEST_CASES
def test(
    lib,
    handle,
    torch_device,
    alpha,
    beta,
    a_shape,
    b_shape,
    c_shape,
    a_stride=None,
    b_stride=None,
    c_stride=None,
    dtype=torch.float16,
):
    """Run one matmul case through the infiniop library and compare against
    the PyTorch reference computed by ``matmul``.

    Raises AssertionError when the library result drifts outside the
    dtype-specific tolerance from _TOLERANCE_MAP.
    """
    print(
        f"Testing Matmul on {torch_device} with alpha:{alpha}, beta:{beta},"
        f" a_shape:{a_shape}, b_shape:{b_shape}, c_shape:{c_shape},"
        f" a_stride:{a_stride}, b_stride:{b_stride}, c_stride:{c_stride}, dtype:{dtype}"
    )

    # Initialize tensors (c starts at ones so the beta term is observable)
    a = torch.rand(a_shape, dtype=dtype).to(torch_device)
    b = torch.rand(b_shape, dtype=dtype).to(torch_device)
    c = torch.ones(c_shape, dtype=dtype).to(torch_device)

    # Compute the PyTorch reference result BEFORE any re-striding below.
    ans = matmul(c, beta, a, b, alpha)

    # Apply the requested custom strides, then wrap as library tensors.
    a, b, c = [rearrange_if_needed(tensor, stride) for tensor, stride in zip([a, b, c], [a_stride, b_stride, c_stride])]
    a_tensor, b_tensor, c_tensor = [to_tensor(tensor, lib) for tensor in [a, b, c]]

    # Create the operator descriptor from the three tensor descriptors.
    descriptor = infiniopMatmulDescriptor_t()
    check_error(
        lib.infiniopCreateMatmulDescriptor(
            handle,
            ctypes.byref(descriptor),
            c_tensor.descriptor,
            a_tensor.descriptor,
            b_tensor.descriptor
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    for tensor in [a_tensor, b_tensor, c_tensor]:
        tensor.descriptor.contents.invalidate()

    # Get workspace size and create workspace
    # NOTE(review): size is declared c_uint64 while the argtype in __main__ is
    # POINTER(c_size_t) — identical on 64-bit platforms; confirm if 32-bit matters.
    workspace_size = c_uint64(0)
    check_error(lib.infiniopGetMatmulWorkspaceSize(descriptor, ctypes.byref(workspace_size)))
    workspace = create_workspace(workspace_size.value, a.device)

    # Execute infiniop matmul operator (wrapped so profiling can re-invoke it).
    # The library writes its result in place into c's buffer via c_tensor.data.
    def lib_matmul():
        check_error(lib.infiniopMatmul(
            descriptor, 
            workspace.data_ptr() if workspace else None,
            workspace_size.value,
            c_tensor.data,
            a_tensor.data,
            b_tensor.data,
            alpha,
            beta,
            None,
        ))
    lib_matmul()

    # Validate results
    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(c, ans, atol=atol, rtol=rtol)
    assert torch.allclose(c, ans, atol=atol, rtol=rtol)

    # Profiling workflow (optional; compares PyTorch vs library timings)
    if PROFILE:
        profile_operation("PyTorch", lambda: matmul(c, beta, a, b, alpha), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation("    lib", lambda: lib_matmul(), torch_device, NUM_PRERUN, NUM_ITERATIONS)

    check_error(lib.infiniopDestroyMatmulDescriptor(descriptor))


140
141
142
# ==============================================================================
#  Main Execution
# ==============================================================================
if __name__ == "__main__":
    args = get_args()
    lib = open_lib()

    # Declare the C signatures of the four matmul entry points so ctypes
    # marshals arguments correctly. All of them return a c_int32 status code,
    # so only the argtypes differ per function.
    _MATMUL_SIGNATURES = {
        "infiniopCreateMatmulDescriptor": [
            infiniopHandle_t,
            POINTER(infiniopMatmulDescriptor_t),
            infiniopTensorDescriptor_t,
            infiniopTensorDescriptor_t,
            infiniopTensorDescriptor_t,
        ],
        "infiniopGetMatmulWorkspaceSize": [
            infiniopMatmulDescriptor_t,
            POINTER(c_size_t),
        ],
        "infiniopMatmul": [
            infiniopMatmulDescriptor_t,
            c_void_p,      # workspace pointer
            c_uint64,      # workspace size in bytes
            c_void_p,      # c (output) data
            c_void_p,      # a data
            c_void_p,      # b data
            c_float,       # alpha
            c_float,       # beta
            c_void_p,      # stream (None -> default)
        ],
        "infiniopDestroyMatmulDescriptor": [
            infiniopMatmulDescriptor_t,
        ],
    }
    for _name, _argtypes in _MATMUL_SIGNATURES.items():
        _fn = getattr(lib, _name)
        _fn.restype = c_int32
        _fn.argtypes = _argtypes

    # Configure testing options
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    # Execute tests
    for device in get_test_devices(args):
        test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")