rms_norm.py 5.44 KB
Newer Older
PanZezhongQY's avatar
PanZezhongQY committed
1
2
from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p, c_float
import ctypes
xgqdut2016's avatar
xgqdut2016 committed
3
4
import torch
import ctypes
5
from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p, c_float
xgqdut2016's avatar
xgqdut2016 committed
6
from libinfiniop import (
PanZezhongQY's avatar
PanZezhongQY committed
7
8
    infiniopHandle_t,
    infiniopTensorDescriptor_t,
xgqdut2016's avatar
xgqdut2016 committed
9
10
11
    open_lib,
    to_tensor,
    get_test_devices,
PanZezhongQY's avatar
PanZezhongQY committed
12
    check_error,
xgqdut2016's avatar
xgqdut2016 committed
13
    rearrange_if_needed,
PanZezhongQY's avatar
PanZezhongQY committed
14
    create_workspace,
xgqdut2016's avatar
xgqdut2016 committed
15
16
17
18
19
    test_operator,
    get_args,
    debug,
    get_tolerance,
    profile_operation,
PanZezhongQY's avatar
PanZezhongQY committed
20
21
)

xgqdut2016's avatar
xgqdut2016 committed
22
23
24
25
26
27
# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules.
#
# Each entry supplies, in order, the positional arguments that `test` receives
# after (lib, handle, torch_device):
#   y_shape, x_shape, w_shape, y_stride, x_stride, w_dtype
# Strides are forwarded as-is to `rearrange_if_needed`; the (4096, 1) cases use
# a row stride larger than the 2048-element row length.
_TEST_CASES = [
    # y_shape, x_shape, w_shape, y_stride, x_stride, w_dtype
    ((1, 4), (1, 4), (4,), None, None, torch.float32),
    ((16, 2048), (16, 2048), (2048,), None, None, torch.float32),
    ((16, 2048), (16, 2048), (2048,), None, None, torch.float16),
    ((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), torch.float32),
    ((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), torch.float16),
]

# x types used for testing
_TENSOR_DTYPES = [torch.float16]

# Tolerance map for different data types (looked up via `get_tolerance`)
_TOLERANCE_MAP = {
    torch.float16: {"atol": 1e-3, "rtol": 1e-3},
}

# Runtime switches; overwritten from the parsed command-line arguments when the
# file is executed as a script.
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
47

xgqdut2016's avatar
xgqdut2016 committed
48

PanZezhongQY's avatar
PanZezhongQY committed
49
50
51
52
53
54
class RMSNormDescriptor(Structure):
    """ctypes structure standing in for the native RMSNorm descriptor.

    Only a single 32-bit integer field, ``device``, is declared here.
    """

    _fields_ = [
        ("device", c_int32),
    ]


# Pointer type used as the descriptor argument of the infiniop RMSNorm
# functions whose signatures are declared in this file.
infiniopRMSNormDescriptor_t = POINTER(RMSNormDescriptor)

xgqdut2016's avatar
xgqdut2016 committed
55

56
57
58
59
60
61
62
def rms_norm(ans, x, w, eps):
    """Reference RMS normalisation, written in place into ``ans``.

    Computes ``x * rsqrt(mean(x**2, dim=-1) + eps) * w`` and stores the result
    in ``ans``. Nothing is returned.

    Args:
        ans: Output tensor; also used as scratch space for ``x**2``.
        x: Input tensor.
        w: Weight tensor multiplied into the normalised result.
        eps: Scalar added to the mean of squares before the reciprocal
            square root.
    """
    # Borrow the output buffer to hold the element-wise squares.
    torch.pow(x, 2, out=ans)

    # Per-row scale factor: 1 / sqrt(mean(x^2) + eps), updated in place.
    scale = ans.mean(dim=-1, keepdim=True)
    scale += eps
    scale.rsqrt_()

    # Overwrite the scratch contents with the normalised, weighted values.
    torch.mul(x, scale, out=ans)
    ans *= w
PanZezhongQY's avatar
PanZezhongQY committed
63
64


65
def test(
    lib,
    handle,
    torch_device,
    y_shape,
    x_shape,
    w_shape,
    y_stride,
    x_stride,
    w_dtype=torch.float16,
    dtype=torch.float16,
    sync=None,
):
    """Run one RMSNorm case through the library and compare with PyTorch.

    Args:
        lib: Loaded library object exposing the ``infiniop*RMSNorm*`` functions.
        handle: Handle passed to ``infiniopCreateRMSNormDescriptor``.
        torch_device: Device the torch tensors are moved to.
        y_shape: Shape of the output tensor.
        x_shape: Shape of the input tensor.
        w_shape: Shape of the weight tensor.
        y_stride: Stride forwarded to ``rearrange_if_needed`` for the output.
        x_stride: Stride forwarded to ``rearrange_if_needed`` for the input.
        w_dtype: dtype of the weight tensor.
        dtype: dtype of the input/output tensors; also selects the tolerance.
        sync: Optional callable invoked once before the descriptor is created.

    Raises:
        AssertionError: If the library output is not within tolerance of the
            PyTorch reference.
    """
    print(
        f"Testing RMS_Norm on {torch_device} with y_shape:{y_shape} x_shape:{x_shape} w_shape:{w_shape}"
        f" y_stride:{y_stride} x_stride:{x_stride} w_dtype:{w_dtype} dtype:{dtype}"
    )

    y = torch.zeros(y_shape, dtype=dtype).to(torch_device)
    x = torch.rand(x_shape, dtype=dtype).to(torch_device)
    w = torch.rand(w_shape, dtype=w_dtype).to(torch_device)
    ans = torch.zeros(y_shape, dtype=dtype).to(torch_device)

    # The reference result is computed before x/y are given custom strides.
    eps = 1e-5
    rms_norm(ans, x, w, eps)

    x, y = [
        rearrange_if_needed(tensor, stride)
        for tensor, stride in zip([x, y], [x_stride, y_stride])
    ]
    x_tensor, y_tensor, w_tensor = [to_tensor(tensor, lib) for tensor in [x, y, w]]

    if sync is not None:
        sync()

    descriptor = infiniopRMSNormDescriptor_t()
    check_error(
        lib.infiniopCreateRMSNormDescriptor(
            handle,
            ctypes.byref(descriptor),
            y_tensor.descriptor,
            x_tensor.descriptor,
            w_tensor.descriptor,
            eps,
        )
    )

    # From here on the native descriptor exists, so make sure it is destroyed
    # even when the accuracy check (or any library call) raises.
    try:
        # Invalidate the shape and strides in the descriptor to prevent them
        # from being directly used by the kernel
        for tensor in [x_tensor, y_tensor, w_tensor]:
            tensor.destroyDesc(lib)

        workspace_size = c_uint64(0)
        check_error(
            lib.infiniopGetRMSNormWorkspaceSize(
                descriptor, ctypes.byref(workspace_size)
            )
        )
        workspace = create_workspace(workspace_size.value, y.device)

        def lib_rms_norm():
            check_error(
                lib.infiniopRMSNorm(
                    descriptor,
                    workspace.data_ptr() if workspace is not None else None,
                    workspace_size.value,
                    y_tensor.data,
                    x_tensor.data,
                    w_tensor.data,
                    None,
                )
            )

        lib_rms_norm()

        atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
        if DEBUG:
            debug(y, ans, atol=atol, rtol=rtol)
        assert torch.allclose(y, ans, atol=atol, rtol=rtol)

        # Profiling workflow
        if PROFILE:
            # fmt: off
            profile_operation("PyTorch", lambda: rms_norm(ans, x, w, eps), torch_device, NUM_PRERUN, NUM_ITERATIONS)
            profile_operation("    lib", lambda: lib_rms_norm(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
            # fmt: on
    finally:
        check_error(lib.infiniopDestroyRMSNormDescriptor(descriptor))

151

PanZezhongQY's avatar
PanZezhongQY committed
152
153
154
if __name__ == "__main__":
    # Script entry point: declare the C signatures of the RMSNorm functions,
    # apply the command-line options, then run every case on every device.
    args = get_args()
    lib = open_lib()

    lib.infiniopCreateRMSNormDescriptor.restype = c_int32
    lib.infiniopCreateRMSNormDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopRMSNormDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        c_float,
    ]

    lib.infiniopGetRMSNormWorkspaceSize.restype = c_int32
    lib.infiniopGetRMSNormWorkspaceSize.argtypes = [
        infiniopRMSNormDescriptor_t,
        POINTER(c_uint64),
    ]

    # The attribute ctypes consults is `restype` (singular); any other
    # spelling is silently stored and never used.
    lib.infiniopRMSNorm.restype = c_int32
    lib.infiniopRMSNorm.argtypes = [
        infiniopRMSNormDescriptor_t,
        c_void_p,
        c_uint64,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
    ]

    lib.infiniopDestroyRMSNormDescriptor.restype = c_int32
    lib.infiniopDestroyRMSNormDescriptor.argtypes = [
        infiniopRMSNormDescriptor_t,
    ]

    # Configure testing options
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    # Execute tests
    for device in get_test_devices(args):
        test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")