rms_norm.py 4.53 KB
Newer Older
xgqdut2016's avatar
xgqdut2016 committed
1
2
import torch
import ctypes
3
from ctypes import c_uint64
xgqdut2016's avatar
xgqdut2016 committed
4
from libinfiniop import (
5
6
    LIBINFINIOP,
    TestTensor,
xgqdut2016's avatar
xgqdut2016 committed
7
    get_test_devices,
PanZezhongQY's avatar
PanZezhongQY committed
8
    check_error,
xgqdut2016's avatar
xgqdut2016 committed
9
10
11
12
13
    test_operator,
    get_args,
    debug,
    get_tolerance,
    profile_operation,
14
15
16
17
18
    TestWorkspace,
    InfiniDtype,
    InfiniDtypeNames,
    InfiniDeviceNames,
    infiniopOperatorDescriptor_t,
PanZezhongQY's avatar
PanZezhongQY committed
19
20
)

xgqdut2016's avatar
xgqdut2016 committed
21
22
23
24
# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
25
26
27
28
29
30
31
32
# Base shape/stride configurations; the weight dtype is appended later.
# Every configuration is listed exactly once: repeating an identical tuple
# only re-runs the same case.
_TEST_CASES_ = [
    # y_shape, x_shape, w_shape, y_stride, x_stride
    ((1, 4), (1, 4), (4,), None, None),
    ((16, 2048), (16, 2048), (2048,), None, None),
    ((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1)),
]
xgqdut2016's avatar
xgqdut2016 committed
34

35
# w (weight) types
# Note: 'None' means the same as input dtype
_WEIGHT_DTYPES = [None, InfiniDtype.F32]
# x types used for testing
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16]

# Form the test cases by appending each element of _WEIGHT_DTYPES to each tuple in _TEST_CASES_.
# The base case is the outer loop, so all weight dtypes of one shape run back to back.
_TEST_CASES = [
    test_case + (w_dtype,) for test_case in _TEST_CASES_ for w_dtype in _WEIGHT_DTYPES
]
xgqdut2016's avatar
xgqdut2016 committed
45
46
47

# Tolerance map for different data types
# Keyed by the input/output dtype of the case; looked up through get_tolerance().
_TOLERANCE_MAP = {
    InfiniDtype.F16: {"atol": 2e-3, "rtol": 2e-3},
    InfiniDtype.BF16: {"atol": 1e-2, "rtol": 1e-2},
}

# Module-level run options (defaults).
# DEBUG gates the debug() comparison dump, PROFILE gates the profiling pass.
DEBUG, PROFILE = False, False
# Warm-up and measured iteration counts handed to profile_operation().
NUM_PRERUN, NUM_ITERATIONS = 10, 1_000
56

xgqdut2016's avatar
xgqdut2016 committed
57

58
59
60
61
62
63
64
def rms_norm(ans, x, w, eps):
    """Write the RMS-normalised, weight-scaled ``x`` into ``ans`` in place.

    The result is ``x * rsqrt(mean(x**2, dim=-1) + eps) * w``. ``ans`` is also
    used as scratch space for the squared input, so its previous contents are
    discarded. Nothing is returned.
    """
    # Stage x**2 in the output buffer instead of allocating a full-size temporary.
    torch.pow(x, 2, out=ans)
    # One scaling factor per slice along the last dimension: 1 / sqrt(mean + eps).
    inv_rms = ans.mean(dim=-1, keepdim=True).add_(eps).rsqrt_()
    torch.mul(x, inv_rms, out=ans)
    ans.mul_(w)
PanZezhongQY's avatar
PanZezhongQY committed
65
66


67
def test(
    handle,
    device,
    y_shape,
    x_shape,
    w_shape,
    y_stride,
    x_stride,
    w_dtype=InfiniDtype.F32,
    dtype=InfiniDtype.F16,
    sync=None,
):
    """Run one RMSNorm case through the library and check it against PyTorch.

    Args:
        handle: Library handle passed to ``infiniopCreateRMSNormDescriptor``.
        device: Device identifier; used for the device-name lookup and for
            creating the test tensors.
        y_shape: Shape of the output tensor.
        x_shape: Shape of the input tensor.
        w_shape: Shape of the weight tensor.
        y_stride: Strides for the output tensor, or ``None``.
        x_stride: Strides for the input tensor, or ``None``.
        w_dtype: Weight dtype; ``None`` selects the same dtype as ``dtype``.
        dtype: Dtype of the input and output tensors.
        sync: Optional callable invoked once after the PyTorch reference
            result has been computed.

    Raises:
        AssertionError: If the library output is not within the tolerance
            configured for ``dtype``.

    Library status codes are routed through ``check_error`` (defined outside
    this file).
    """
    # Compare against None explicitly so that only the "same as input dtype"
    # marker triggers the fallback, not any dtype value that happens to be falsy.
    if w_dtype is None:
        w_dtype = dtype
    print(
        f"Testing RMS_Norm on {InfiniDeviceNames[device]} with y_shape:{y_shape} x_shape:{x_shape} w_shape:{w_shape}"
        f" y_stride:{y_stride} x_stride:{x_stride} w_dtype:{InfiniDtypeNames[w_dtype]} dtype:{InfiniDtypeNames[dtype]}"
    )

    y = TestTensor(y_shape, y_stride, dtype, device, mode="ones")
    x = TestTensor(x_shape, x_stride, dtype, device, scale=0.01)
    w = TestTensor(w_shape, None, w_dtype, device)

    # Reference result, written in place into y's torch-side tensor.
    eps = 1e-6
    rms_norm(y.torch_tensor(), x.torch_tensor(), w.torch_tensor(), eps)

    if sync is not None:
        sync()

    descriptor = infiniopOperatorDescriptor_t()
    check_error(
        LIBINFINIOP.infiniopCreateRMSNormDescriptor(
            handle,
            ctypes.byref(descriptor),
            y.descriptor,
            x.descriptor,
            w.descriptor,
            eps,
        )
    )

    # From here on the operator descriptor exists; release it even when a later
    # step (workspace query, kernel launch, comparison) raises.
    try:
        # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
        for tensor in [x, y, w]:
            tensor.destroy_desc()

        workspace_size = c_uint64(0)
        check_error(
            LIBINFINIOP.infiniopGetRMSNormWorkspaceSize(
                descriptor, ctypes.byref(workspace_size)
            )
        )
        workspace = TestWorkspace(workspace_size.value, y.device)

        def lib_rms_norm():
            check_error(
                LIBINFINIOP.infiniopRMSNorm(
                    descriptor,
                    workspace.data(),
                    workspace_size.value,
                    y.data(),
                    x.data(),
                    w.data(),
                    None,
                )
            )

        lib_rms_norm()

        atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
        if DEBUG:
            debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)
        assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)

        # Profiling workflow
        if PROFILE:
            # fmt: off
            profile_operation("PyTorch", lambda: rms_norm(y.torch_tensor(), x.torch_tensor(), w.torch_tensor(), eps), device, NUM_PRERUN, NUM_ITERATIONS)
            profile_operation("    lib", lib_rms_norm, device, NUM_PRERUN, NUM_ITERATIONS)
            # fmt: on
    finally:
        check_error(LIBINFINIOP.infiniopDestroyRMSNormDescriptor(descriptor))
PanZezhongQY's avatar
PanZezhongQY committed
147

148

PanZezhongQY's avatar
PanZezhongQY committed
149
150
151
if __name__ == "__main__":
    args = get_args()

    # Configure testing options
    # (module-level rebinding: test() reads these names as globals)
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    # Execute tests
    for device in get_test_devices(args):
        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")