rms_norm.py 5.81 KB
Newer Older
PanZezhongQY's avatar
PanZezhongQY committed
1
2
from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p, c_float
import ctypes
xgqdut2016's avatar
xgqdut2016 committed
3
4
import torch
import ctypes
5
from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p, c_float
xgqdut2016's avatar
xgqdut2016 committed
6
from libinfiniop import (
PanZezhongQY's avatar
PanZezhongQY committed
7
8
    infiniopHandle_t,
    infiniopTensorDescriptor_t,
xgqdut2016's avatar
xgqdut2016 committed
9
10
11
    open_lib,
    to_tensor,
    get_test_devices,
PanZezhongQY's avatar
PanZezhongQY committed
12
    check_error,
xgqdut2016's avatar
xgqdut2016 committed
13
    rearrange_if_needed,
PanZezhongQY's avatar
PanZezhongQY committed
14
    create_workspace,
xgqdut2016's avatar
xgqdut2016 committed
15
16
17
18
19
    test_operator,
    get_args,
    debug,
    get_tolerance,
    profile_operation,
PanZezhongQY's avatar
PanZezhongQY committed
20
21
)

xgqdut2016's avatar
xgqdut2016 committed
22
23
24
25
# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
26
27
28
29
30
31
32
33
_TEST_CASES_ = [
    # y_shape, x_shape, w_shape, y_stride, x_stride
    # NOTE(review): every configuration below is listed twice. The duplicates
    # are kept so the number of generated test runs stays the same -- confirm
    # whether they are intentional.
    ((1, 4), (1, 4), (4,), None, None),
    ((1, 4), (1, 4), (4,), None, None),
    ((16, 2048), (16, 2048), (2048,), None, None),
    ((16, 2048), (16, 2048), (2048,), None, None),
    ((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1)),
    ((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1)),
]
xgqdut2016's avatar
xgqdut2016 committed
35

36
37
38
# Weight (w) dtypes to sweep over.
_WEIGHT_DTYPES = [
    None,  # None means "use the same dtype as the input"
    torch.float32,
]
xgqdut2016's avatar
xgqdut2016 committed
39
# x types used for testing
40
41
42
43
44
45
46
47
_TENSOR_DTYPES = [
    torch.float16,
    torch.bfloat16,
]

# Cross every base case with every weight dtype: each tuple from _TEST_CASES_
# gains one trailing w_dtype element, with the weight dtype varying fastest.
_TEST_CASES = [
    (*base_case, weight_dtype)
    for base_case in _TEST_CASES_
    for weight_dtype in _WEIGHT_DTYPES
]
xgqdut2016's avatar
xgqdut2016 committed
48
49
50

# Tolerance map for different data types.
# test() looks up the (atol, rtol) pair for the input dtype through get_tolerance.
_TOLERANCE_MAP = {
    torch.float16: {"atol": 2e-3, "rtol": 2e-3},
    torch.bfloat16: {"atol": 8e-3, "rtol": 8e-3},
}

# Runtime switches and profiling parameters. These are only defaults: the
# ``__main__`` section reassigns all four from the parsed command-line arguments.
DEBUG: bool = False
PROFILE: bool = False
NUM_PRERUN: int = 10
NUM_ITERATIONS: int = 1000
59

xgqdut2016's avatar
xgqdut2016 committed
60

PanZezhongQY's avatar
PanZezhongQY committed
61
62
63
64
65
66
class RMSNormDescriptor(Structure):
    """ctypes structure used as the target type of the RMSNorm descriptor pointer.

    Only a single int32 ``device`` field is declared on the Python side.
    """

    _fields_ = [
        ("device", c_int32),
    ]


# Pointer type passed to the library's RMSNorm descriptor functions.
infiniopRMSNormDescriptor_t = POINTER(RMSNormDescriptor)

xgqdut2016's avatar
xgqdut2016 committed
67

68
69
70
71
72
73
74
def rms_norm(ans, x, w, eps):
    """Write the PyTorch reference RMSNorm of ``x`` into ``ans`` in place.

    Computes ``x * rsqrt(mean(x ** 2, dim=-1) + eps) * w``. ``ans`` doubles as
    scratch space for the squared input before it receives the final result.
    Nothing is returned.
    """
    # Stage x^2 in the output buffer, then reduce over the last dimension.
    torch.pow(x, 2, out=ans)
    inv_rms = ans.mean(dim=-1, keepdim=True)

    # Turn the mean of squares into 1 / sqrt(mean + eps), reusing the same tensor.
    inv_rms += eps
    inv_rms.rsqrt_()

    # Normalise the input, then scale by the weight.
    torch.mul(x, inv_rms, out=ans)
    ans *= w
PanZezhongQY's avatar
PanZezhongQY committed
75
76


77
def test(
    lib,
    handle,
    torch_device,
    y_shape,
    x_shape,
    w_shape,
    y_stride,
    x_stride,
    w_dtype=torch.float16,
    dtype=torch.float16,
    sync=None,
):
    """Run one RMSNorm case through the library and compare it with PyTorch.

    Args:
        lib: Loaded operator library exposing the ``infiniop*RMSNorm*`` functions.
        handle: Library handle passed to descriptor creation.
        torch_device: Device the test tensors are moved to.
        y_shape: Shape of the output tensor (and of the reference result).
        x_shape: Shape of the input tensor.
        w_shape: Shape of the weight tensor.
        y_stride: Stride forwarded to ``rearrange_if_needed`` for ``y``.
        x_stride: Stride forwarded to ``rearrange_if_needed`` for ``x``.
        w_dtype: Weight dtype; ``None`` falls back to ``dtype``.
        dtype: dtype of ``x``, ``y`` and the reference output.
        sync: Optional zero-argument callable, invoked once after the tensors
            are wrapped and before the descriptor is created.

    Raises:
        AssertionError: If the library output is not within the tolerance
            looked up for ``dtype``.
    """
    print(
        f"Testing RMS_Norm on {torch_device} with y_shape:{y_shape} x_shape:{x_shape} w_shape:{w_shape}"
        f" y_stride:{y_stride} x_stride:{x_stride} w_dtype:{w_dtype} dtype:{dtype}"
    )

    # None selects "same dtype as the input" for the weight tensor.
    w_dtype = dtype if w_dtype is None else w_dtype

    y = torch.zeros(y_shape, dtype=dtype).to(torch_device)
    x = torch.rand(x_shape, dtype=dtype).to(torch_device)
    w = torch.rand(w_shape, dtype=w_dtype).to(torch_device)
    ans = torch.zeros(y_shape, dtype=dtype).to(torch_device)

    eps = 1e-5
    # The reference result is computed before x and y are passed through
    # rearrange_if_needed.
    rms_norm(ans, x, w, eps)

    x, y = [
        rearrange_if_needed(tensor, stride)
        for tensor, stride in zip([x, y], [x_stride, y_stride])
    ]
    x_tensor, y_tensor, w_tensor = [to_tensor(tensor, lib) for tensor in [x, y, w]]

    if sync is not None:
        sync()

    descriptor = infiniopRMSNormDescriptor_t()
    check_error(
        lib.infiniopCreateRMSNormDescriptor(
            handle,
            ctypes.byref(descriptor),
            y_tensor.descriptor,
            x_tensor.descriptor,
            w_tensor.descriptor,
            eps,
        )
    )

    try:
        # Invalidate the shape and strides in the descriptor to prevent them from
        # being directly used by the kernel
        for tensor in [x_tensor, y_tensor, w_tensor]:
            tensor.destroyDesc(lib)

        workspace_size = c_uint64(0)
        check_error(
            lib.infiniopGetRMSNormWorkspaceSize(
                descriptor, ctypes.byref(workspace_size)
            )
        )
        workspace = create_workspace(workspace_size.value, y.device)

        def lib_rms_norm():
            check_error(
                lib.infiniopRMSNorm(
                    descriptor,
                    workspace.data_ptr() if workspace is not None else None,
                    workspace_size.value,
                    y_tensor.data,
                    x_tensor.data,
                    w_tensor.data,
                    None,
                )
            )

        lib_rms_norm()

        atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
        if DEBUG:
            debug(y, ans, atol=atol, rtol=rtol)
        assert torch.allclose(y, ans, atol=atol, rtol=rtol)

        # Profiling workflow
        if PROFILE:
            # fmt: off
            profile_operation("PyTorch", lambda: rms_norm(ans, x, w, eps), torch_device, NUM_PRERUN, NUM_ITERATIONS)
            profile_operation("    lib", lib_rms_norm, torch_device, NUM_PRERUN, NUM_ITERATIONS)
            # fmt: on
    finally:
        # Release the descriptor even when the kernel call or the comparison fails.
        check_error(lib.infiniopDestroyRMSNormDescriptor(descriptor))

164

PanZezhongQY's avatar
PanZezhongQY committed
165
166
167
if __name__ == "__main__":
    args = get_args()
    lib = open_lib()

    # Declare the C signatures of the RMSNorm entry points for ctypes.
    # Create: handle, out-pointer to the descriptor, y/x/w tensor descriptors, epsilon.
    lib.infiniopCreateRMSNormDescriptor.restype = c_int32
    lib.infiniopCreateRMSNormDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopRMSNormDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        c_float,
    ]

    # Workspace query: descriptor, out-pointer receiving the size.
    lib.infiniopGetRMSNormWorkspaceSize.restype = c_int32
    lib.infiniopGetRMSNormWorkspaceSize.argtypes = [
        infiniopRMSNormDescriptor_t,
        POINTER(c_uint64),
    ]

    # Execute: descriptor, workspace pointer, workspace size, y, x, w data
    # pointers, and a final pointer argument for which test() passes None.
    lib.infiniopRMSNorm.restype = c_int32
    lib.infiniopRMSNorm.argtypes = [
        infiniopRMSNormDescriptor_t,
        c_void_p,
        c_uint64,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
    ]

    lib.infiniopDestroyRMSNormDescriptor.restype = c_int32
    lib.infiniopDestroyRMSNormDescriptor.argtypes = [
        infiniopRMSNormDescriptor_t,
    ]

    # Configure testing options
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    # Execute tests
    for device in get_test_devices(args):
        test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")