rms_norm.py 4.84 KB
Newer Older
xgqdut2016's avatar
xgqdut2016 committed
1
2
import torch
import ctypes
3
from ctypes import c_uint64
xgqdut2016's avatar
xgqdut2016 committed
4
from libinfiniop import (
5
6
    LIBINFINIOP,
    TestTensor,
xgqdut2016's avatar
xgqdut2016 committed
7
    get_test_devices,
PanZezhongQY's avatar
PanZezhongQY committed
8
    check_error,
xgqdut2016's avatar
xgqdut2016 committed
9
10
11
12
13
    test_operator,
    get_args,
    debug,
    get_tolerance,
    profile_operation,
14
15
16
17
18
    TestWorkspace,
    InfiniDtype,
    InfiniDtypeNames,
    InfiniDeviceNames,
    infiniopOperatorDescriptor_t,
PanZezhongQY's avatar
PanZezhongQY committed
19
20
)

xgqdut2016's avatar
xgqdut2016 committed
21
22
23
24
# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES_ = [
    # y_shape, x_shape, w_shape, y_stride, x_stride
    ((1, 4), (1, 4), (4,), None, None),
    ((2, 4), (2, 4), (4,), None, None),
    ((2, 2, 4), (2, 2, 4), (4,), None, None),
    ((2, 2, 4), (2, 2, 4), (4,), (12, 8, 1), (12, 8, 1)),
    ((16, 2048), (16, 2048), (2048,), None, None),
    ((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1)),
    ((15, 3584), (15, 3584), (3584,), None, None),
    ((4, 4, 2048), (4, 4, 2048), (2048,), None, None),
    ((4, 4, 2048), (4, 4, 2048), (2048,), (2048, 8192, 1), (2048, 8192, 1)),
    ((4, 4, 2048), (4, 4, 2048), (2048,), (16384, 4096, 1), (16384, 4096, 1)),
]

# w (weight) dtypes to pair with every case.
# Note: 'None' means the same as input dtype
_WEIGHT_DTYPES = [None, InfiniDtype.F32, InfiniDtype.F16, InfiniDtype.BF16]
# x types used for testing
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16]

# Cross-product: every base case is extended with each weight dtype appended
# as the final tuple element.
_TEST_CASES = [
    (*case, w_dtype) for case in _TEST_CASES_ for w_dtype in _WEIGHT_DTYPES
]

# Per-dtype comparison tolerances used by the final allclose check.
_TOLERANCE_MAP = {
    InfiniDtype.F16: {"atol": 2e-3, "rtol": 2e-3},
    InfiniDtype.BF16: {"atol": 1e-2, "rtol": 1e-2},
}

# Runtime options; overwritten from CLI args in the __main__ block.
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
60

xgqdut2016's avatar
xgqdut2016 committed
61

62
def rms_norm(ans, x, w, eps):
    """Reference RMSNorm: ans <- (x / sqrt(mean(x^2, dim=-1) + eps)) * w.

    The computation is carried out in float32 for accuracy and cast back to
    the input dtype; the result is written into ``ans`` by rebinding its
    storage via ``Tensor.set_`` (so ``ans`` adopts the result's shape/strides).

    Args:
        ans: output tensor, overwritten with the normalized result.
        x: input activations, normalized over the last dimension.
        w: per-channel scale weights (broadcast over the last dimension).
        eps: small constant added to the mean square for numerical stability.
    """
    input_dtype = x.dtype
    # NOTE: `.to(torch.float32)` returns `x` itself when x is already float32.
    # The previous implementation then applied in-place ops (mul_) to it,
    # silently mutating the caller's input in that case. Use out-of-place
    # arithmetic so the input is never clobbered; numerics are unchanged.
    hidden_states = x.to(torch.float32)
    inv_rms = torch.rsqrt(hidden_states.pow(2).mean(-1, keepdim=True) + eps)
    ans.set_((hidden_states * inv_rms * w).to(input_dtype))
PanZezhongQY's avatar
PanZezhongQY committed
67
68


69
def test(
    handle,
    device,
    y_shape,
    x_shape,
    w_shape,
    y_stride,
    x_stride,
    w_dtype=InfiniDtype.F32,
    dtype=InfiniDtype.F16,
    sync=None,
):
    """Run one RMSNorm case: compare the library kernel against the PyTorch reference.

    Builds y/x/w test tensors, computes the reference result into y's torch
    tensor, then executes the library operator into y's actual buffer and
    asserts the two agree within the dtype's tolerance. Optionally profiles
    both implementations when PROFILE is set.

    Args:
        handle: library handle passed to descriptor creation.
        device: device id; used for tensor placement and log output.
        y_shape, x_shape, w_shape: tensor shapes for output, input, weight.
        y_stride, x_stride: explicit strides (None = contiguous).
        w_dtype: weight dtype; falsy (None) means "same as `dtype`".
        dtype: activation dtype for x and y.
        sync: optional callable to synchronize the device before library calls.
    """
    # None/falsy weight dtype collapses to the activation dtype.
    w_dtype = w_dtype if w_dtype else dtype
    print(
        f"Testing RMS_Norm on {InfiniDeviceNames[device]} with y_shape:{y_shape} x_shape:{x_shape} w_shape:{w_shape}"
        f" y_stride:{y_stride} x_stride:{x_stride} w_dtype:{InfiniDtypeNames[w_dtype]} dtype:{InfiniDtypeNames[dtype]}"
    )

    # y starts as ones so stale output data cannot masquerade as a pass;
    # x is scaled small (presumably to keep fp16/bf16 sums in range — TODO confirm).
    y = TestTensor(y_shape, y_stride, dtype, device, mode="ones")
    x = TestTensor(x_shape, x_stride, dtype, device, scale=0.01)
    w = TestTensor(w_shape, None, w_dtype, device)

    # Reference result is written into y's torch-side tensor.
    eps = 1e-6
    rms_norm(y.torch_tensor(), x.torch_tensor(), w.torch_tensor(), eps)

    if sync is not None:
        sync()

    descriptor = infiniopOperatorDescriptor_t()
    check_error(
        LIBINFINIOP.infiniopCreateRMSNormDescriptor(
            handle,
            ctypes.byref(descriptor),
            y.descriptor,
            x.descriptor,
            w.descriptor,
            eps,
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    for tensor in [x, y, w]:
        tensor.destroy_desc()

    # Query and allocate the scratch workspace the kernel requires.
    workspace_size = c_uint64(0)
    check_error(
        LIBINFINIOP.infiniopGetRMSNormWorkspaceSize(
            descriptor, ctypes.byref(workspace_size)
        )
    )
    workspace = TestWorkspace(workspace_size.value, y.device)

    def lib_rms_norm():
        # Execute the library kernel; last arg is the stream (None = default).
        check_error(
            LIBINFINIOP.infiniopRMSNorm(
                descriptor,
                workspace.data(),
                workspace_size.value,
                y.data(),
                x.data(),
                w.data(),
                None,
            )
        )

    lib_rms_norm()

    # Tolerance is keyed on the activation dtype, not the weight dtype.
    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)
    assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)

    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: rms_norm(y.torch_tensor(), x.torch_tensor(), w.torch_tensor(), eps), device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation("    lib", lambda: lib_rms_norm(), device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on
    check_error(LIBINFINIOP.infiniopDestroyRMSNormDescriptor(descriptor))
PanZezhongQY's avatar
PanZezhongQY committed
149

150

PanZezhongQY's avatar
PanZezhongQY committed
151
152
153
if __name__ == "__main__":
    args = get_args()

    # Pull runtime knobs from the CLI into the module-level switches.
    DEBUG, PROFILE = args.debug, args.profile
    NUM_PRERUN, NUM_ITERATIONS = args.num_prerun, args.num_iterations

    # Run every test case on every requested device.
    for dev in get_test_devices(args):
        test_operator(dev, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")