import ctypes
from ctypes import c_uint64

import torch

from libinfiniop import (
    LIBINFINIOP,
    TestTensor,
    TestWorkspace,
    InfiniDtype,
    InfiniDtypeNames,
    InfiniDeviceNames,
    infiniopOperatorDescriptor_t,
    check_error,
    debug,
    get_args,
    get_test_devices,
    get_tolerance,
    profile_operation,
    test_operator,
)

xgqdut2016's avatar
xgqdut2016 committed
21
22
23
24
# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules

# Each entry: (y_shape, x_shape, w_shape, y_stride, x_stride).
# A stride of None means the tensor is contiguous.
_TEST_CASES_ = [
    ((1, 4), (1, 4), (4,), None, None),
    ((2, 4), (2, 4), (4,), None, None),
    ((2, 2, 4), (2, 2, 4), (4,), None, None),
    ((2, 2, 4), (2, 2, 4), (4,), (12, 8, 1), (12, 8, 1)),
    ((16, 2048), (16, 2048), (2048,), None, None),
    ((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1)),
    ((4, 4, 2048), (4, 4, 2048), (2048,), None, None),
    ((4, 4, 2048), (4, 4, 2048), (2048,), (2048, 8192, 1), (2048, 8192, 1)),
    ((4, 4, 2048), (4, 4, 2048), (2048,), (16384, 4096, 1), (16384, 4096, 1)),
]

# w (weight) dtypes tried for every case; None means "same as the input dtype".
_WEIGHT_DTYPES = [None, InfiniDtype.F32, InfiniDtype.F16, InfiniDtype.BF16]
# x dtypes used for testing
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16]

# Cross product: every base case is run once per weight dtype, with the
# weight dtype appended as the final tuple element.
_TEST_CASES = [
    test_case + (w_dtype,) for test_case in _TEST_CASES_ for w_dtype in _WEIGHT_DTYPES
]

# Tolerance map for different data types
_TOLERANCE_MAP = {
    InfiniDtype.F16: {"atol": 2e-3, "rtol": 2e-3},
    InfiniDtype.BF16: {"atol": 1e-2, "rtol": 1e-2},
}

# Defaults; overridden from the CLI arguments in the __main__ section below.
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000

def rms_norm(ans, x, w, eps):
    """PyTorch reference RMSNorm, written into *ans* in place.

    Computes ans = x * rsqrt(mean(x**2, dim=-1) + eps) * w, reusing *ans*
    as scratch space for the squared values so no extra buffer is allocated.
    """
    # ans <- x^2 (scratch use of the output buffer)
    torch.mul(x, x, out=ans)
    # Per-row inverse RMS: rsqrt(mean(x^2) + eps), all in place.
    inv_rms = ans.mean(dim=-1, keepdim=True).add_(eps).rsqrt_()
    # ans <- x * inv_rms, then apply the weight elementwise.
    torch.mul(x, inv_rms, out=ans)
    ans.mul_(w)
def test(
    handle,
    device,
    y_shape,
    x_shape,
    w_shape,
    y_stride,
    x_stride,
    w_dtype=InfiniDtype.F32,
    dtype=InfiniDtype.F16,
    sync=None,
):
    """Run one RMSNorm case: build tensors, compute the PyTorch reference,
    run the library kernel, and assert both results agree within tolerance.

    Optionally profiles both implementations when PROFILE is set.
    """
    # A falsy weight dtype (None) means "match the input dtype".
    w_dtype = w_dtype if w_dtype else dtype
    print(
        f"Testing RMS_Norm on {InfiniDeviceNames[device]} with y_shape:{y_shape} x_shape:{x_shape} w_shape:{w_shape}"
        f" y_stride:{y_stride} x_stride:{x_stride} w_dtype:{InfiniDtypeNames[w_dtype]} dtype:{InfiniDtypeNames[dtype]}"
    )

    y = TestTensor(y_shape, y_stride, dtype, device, mode="ones")
    x = TestTensor(x_shape, x_stride, dtype, device, scale=0.01)
    w = TestTensor(w_shape, None, w_dtype, device)

    # Reference result computed into y's torch-side buffer.
    eps = 1e-6
    rms_norm(y.torch_tensor(), x.torch_tensor(), w.torch_tensor(), eps)

    if sync is not None:
        sync()

    desc = infiniopOperatorDescriptor_t()
    check_error(
        LIBINFINIOP.infiniopCreateRMSNormDescriptor(
            handle,
            ctypes.byref(desc),
            y.descriptor,
            x.descriptor,
            w.descriptor,
            eps,
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them
    # from being directly used by the kernel.
    for t in (x, y, w):
        t.destroy_desc()

    # Query and allocate the kernel workspace.
    ws_size = c_uint64(0)
    check_error(
        LIBINFINIOP.infiniopGetRMSNormWorkspaceSize(desc, ctypes.byref(ws_size))
    )
    workspace = TestWorkspace(ws_size.value, y.device)

    def lib_rms_norm():
        # Execute the library kernel on the device buffers.
        check_error(
            LIBINFINIOP.infiniopRMSNorm(
                desc,
                workspace.data(),
                ws_size.value,
                y.data(),
                x.data(),
                w.data(),
                None,
            )
        )

    lib_rms_norm()

    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)
    assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)

    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: rms_norm(y.torch_tensor(), x.torch_tensor(), w.torch_tensor(), eps), device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation("    lib", lambda: lib_rms_norm(), device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on
    check_error(LIBINFINIOP.infiniopDestroyRMSNormDescriptor(desc))
if __name__ == "__main__":
    args = get_args()

    # Configure testing options from the command line.
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    # Run every test case on every requested device.
    for device in get_test_devices(args):
        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")