import torch
import ctypes
from ctypes import c_uint64
from libinfiniop import (
    LIBINFINIOP,
    TestTensor,
    get_test_devices,
    check_error,
    test_operator,
    get_args,
    debug,
    get_tolerance,
    profile_operation,
    TestWorkspace,
    InfiniDtype,
    InfiniDtypeNames,
    InfiniDeviceNames,
    infiniopOperatorDescriptor_t,
)

# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
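# Each case checks c = alpha * (a @ b) + beta * c against a PyTorch reference,
# covering contiguous, batched (3-D), and transposed/strided layouts.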
_TEST_CASES = [
    # alpha, beta, a_shape, b_shape, c_shape, a_stride, b_stride, c_stride
    (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), None, None, None),
    (1.0, 0.0, (2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None),
    (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1)),
    (1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1)),
    (1.0 / 8.0, 0.0, (4, 8 * 6, 64), (4, 64, 6), (4, 8 * 6, 6), None, None, None),
]

# Data types used for testing
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32]

# Tolerance map for different data types
_TOLERANCE_MAP = {
    InfiniDtype.F16: {"atol": 0, "rtol": 1e-2},
    InfiniDtype.F32: {"atol": 0, "rtol": 1e-3},
    InfiniDtype.BF16: {"atol": 0, "rtol": 5e-2},
}
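# atol is zero for every dtype, so the comparisons below are purely relative.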

DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000


# PyTorch reference implementation: d = alpha * (a @ b) + beta * c
def gemm(d, _c, beta, _a, _b, alpha):
    try:
        if _c.ndim == 2:
            torch.addmm(_c, _a, _b, beta=beta, alpha=alpha, out=d)
        elif _c.ndim == 3:
            torch.baddbmm(_c, _a, _b, beta=beta, alpha=alpha, out=d)
        else:
            raise ValueError("addmm/baddbmm only handle 2-D and 3-D inputs")
    except Exception:
        # Generic fallback: d = alpha * (a @ b) + beta * c
        torch.matmul(_a, _b, out=d)
        d.mul_(alpha).add_(_c, alpha=beta)


# The argument list should be (handle, device, <param list>, dtype, sync)
# The <param list> should keep the same order as the one specified in _TEST_CASES
def test(
    handle,
    device,
    alpha,
    beta,
    a_shape,
    b_shape,
    c_shape,
    a_stride=None,
    b_stride=None,
    c_stride=None,
    dtype=InfiniDtype.F16,
    sync=None,
):
    print(
        f"Testing Gemm on {InfiniDeviceNames[device]} with alpha:{alpha}, beta:{beta},"
        f" a_shape:{a_shape}, b_shape:{b_shape}, c_shape:{c_shape},"
        f" a_stride:{a_stride}, b_stride:{b_stride}, c_stride:{c_stride}, dtype:{InfiniDtypeNames[dtype]}"
    )

    # Initialize tensors
    a = TestTensor(a_shape, a_stride, dtype, device)
    b = TestTensor(b_shape, b_stride, dtype, device)
    c = TestTensor(c_shape, c_stride, dtype, device, mode="ones")
    ans = TestTensor(c_shape, c_stride, dtype, device, mode="zeros")
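    # c starts as ones so the beta * c term contributes a known value; ans is a
    # zeroed buffer that receives the PyTorch reference result.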

    # Compute the PyTorch reference result
    def torch_gemm():
        gemm(
            ans.torch_tensor(),
            c.torch_tensor(),
            beta,
            a.torch_tensor(),
            b.torch_tensor(),
            alpha,
        )

    torch_gemm()

    if sync is not None:
        sync()

    descriptor = infiniopOperatorDescriptor_t()
    check_error(
        LIBINFINIOP.infiniopCreateGemmDescriptor(
            handle,
            ctypes.byref(descriptor),
            c.descriptor,
            a.descriptor,
            b.descriptor,
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    for tensor in [a, b, c]:
        tensor.destroy_desc()

    # Get workspace size and create workspace
    workspace_size = c_uint64(0)
    check_error(
        LIBINFINIOP.infiniopGetGemmWorkspaceSize(
            descriptor, ctypes.byref(workspace_size)
        )
    )
    workspace = TestWorkspace(workspace_size.value, device)
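    # The workspace size comes from the library's own query; the buffer is
    # reused by every infiniopGemm call below, including the profiling runs.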

    # Execute infiniop gemm operator
    def lib_gemm():
        check_error(
            LIBINFINIOP.infiniopGemm(
                descriptor,
                workspace.data(),
                workspace_size.value,
                c.data(),
                a.data(),
                b.data(),
                alpha,
                beta,
                None,
            )
        )

    lib_gemm()

    # Validate results
    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)

    if DEBUG:
        debug(c.actual_tensor(), ans.torch_tensor(), atol=atol, rtol=rtol)

    assert torch.allclose(c.actual_tensor(), ans.torch_tensor(), atol=atol, rtol=rtol)

    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: torch_gemm(), device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation("    lib", lambda: lib_gemm(), device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on
    check_error(LIBINFINIOP.infiniopDestroyGemmDescriptor(descriptor))


# ==============================================================================
#  Main Execution
# ==============================================================================
if __name__ == "__main__":
    args = get_args()
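    # get_args() is assumed to define --debug, --profile, --num_prerun and
    # --num_iterations options; the spellings are inferred from the attribute
    # names read below.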

    # Configure testing options
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    # Execute tests
    for device in get_test_devices(args):
        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")