rms_norm.py 5.33 KB
Newer Older
PanZezhongQY's avatar
PanZezhongQY committed
1
2
from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p, c_float
import ctypes
xgqdut2016's avatar
xgqdut2016 committed
3
4
5
6
import torch
import ctypes
from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
from libinfiniop import (
PanZezhongQY's avatar
PanZezhongQY committed
7
8
    infiniopHandle_t,
    infiniopTensorDescriptor_t,
xgqdut2016's avatar
xgqdut2016 committed
9
10
11
    open_lib,
    to_tensor,
    get_test_devices,
PanZezhongQY's avatar
PanZezhongQY committed
12
    check_error,
xgqdut2016's avatar
xgqdut2016 committed
13
    rearrange_if_needed,
PanZezhongQY's avatar
PanZezhongQY committed
14
    create_workspace,
xgqdut2016's avatar
xgqdut2016 committed
15
16
17
18
19
    test_operator,
    get_args,
    debug,
    get_tolerance,
    profile_operation,
PanZezhongQY's avatar
PanZezhongQY committed
20
21
)

xgqdut2016's avatar
xgqdut2016 committed
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules

_TEST_CASES = [
    # (y_shape, x_shape, w_shape, y_stride, x_stride, w_dtype)
    # All cases share the same shapes and use the same stride for y and x.
    # The contiguous layout (None) and the (4096, 1) layout, whose rows are
    # 4096 elements apart for a 2048-wide row, are each paired with a
    # float32 and a float16 weight.
    ((16, 2048), (16, 2048), (2048,), stride, stride, w_dtype)
    for stride in (None, (4096, 1))
    for w_dtype in (torch.float32, torch.float16)
]
# Data types exercised for the input x (the output y is created with the same dtype)
_TENSOR_DTYPES = [torch.float16, torch.float32]

# Absolute / relative tolerances used when comparing the library result with the reference
_TOLERANCE_MAP = {
    torch.float16: dict(atol=0, rtol=1e-2),
    torch.float32: dict(atol=0, rtol=1e-3),
}

# Defaults; the __main__ block overrides them from the command-line arguments
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
47

PanZezhongQY's avatar
PanZezhongQY committed
48
49
50
51
52
53
54
55
56
57
58
59
60
61
class RMSNormDescriptor(Structure):
    """ctypes view of the library's RMSNorm descriptor.

    Only a single int32 ``device`` field is declared on the Python side.
    Presumably the C struct carries more state; this script only ever
    handles descriptors through the pointer type below.
    """

    _fields_ = [
        ("device", c_int32),
    ]


# Handle type passed to / filled in by the infiniop RMSNorm C functions
infiniopRMSNormDescriptor_t = POINTER(RMSNormDescriptor)

def rms_norm(x, w, eps):
    """PyTorch reference RMSNorm over the last dimension.

    ``x`` is normalised in float32 as ``x / sqrt(mean(x**2, dim=-1) + eps)``,
    cast back to ``x.dtype`` and then scaled by ``w``. The final product
    follows PyTorch type promotion, so the result dtype is
    ``torch.promote_types(w.dtype, x.dtype)``, not necessarily ``x.dtype``.
    """
    original_dtype = x.dtype
    x_fp32 = x.to(torch.float32)
    # Accumulate the mean square in float32 for accuracy with half inputs
    mean_square = x_fp32.pow(2).mean(dim=-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_square + eps)
    normalized = (x_fp32 * inv_rms).to(original_dtype)
    return w * normalized


62
def test(
    lib,
    handle,
    torch_device,
    y_shape,
    x_shape,
    w_shape,
    y_stride,
    x_stride,
    dtype=torch.float16,
    w_dtype=torch.float16,
):
    """Run one RMSNorm case through infiniop and check it against PyTorch.

    Builds a random input ``x``, an all-ones weight ``w`` and a zeroed output
    ``y``. It optionally re-lays out ``x``/``y`` with the given strides, runs
    ``infiniopRMSNorm`` and compares ``y`` with the :func:`rms_norm` reference.

    Args:
        lib: ctypes library handle (from ``open_lib``) with the RMSNorm
            signatures declared.
        handle: infiniop handle for the device under test.
        torch_device: torch device string the tensors are created on.
        y_shape, x_shape, w_shape: shapes of the output, input and weight.
        y_stride, x_stride: optional strides for ``y``/``x``; ``None`` keeps
            the default layout.
        dtype: dtype of ``x`` and ``y``; it also selects the tolerance.
        w_dtype: dtype of the weight ``w``.

    Raises:
        AssertionError: if the library result is not close to the reference.

    NOTE(review): ``_TEST_CASES`` tuples end with ``w_dtype``, but this
    signature lists ``dtype`` before ``w_dtype``. If the test driver passes
    ``*case, tensor_dtype`` positionally, the two bind swapped. Confirm the
    intended order against ``test_operator``.
    """
    print(
        f"Testing RMS_Norm on {torch_device} with y_shape:{y_shape} x_shape:{x_shape} w_shape:{w_shape}"
        f" dtype:{dtype} w_dtype:{w_dtype}"
    )

    y = torch.zeros(y_shape, dtype=dtype).to(torch_device)
    x = torch.rand(x_shape, dtype=dtype).to(torch_device)
    w = torch.ones(w_shape, dtype=w_dtype).to(torch_device)

    eps = 1e-5
    # rms_norm promotes its result when w is float32 and x is float16.
    # The library writes into y (dtype), and torch.allclose rejects
    # mismatched dtypes, so compare in y's dtype.
    ans = rms_norm(x, w, eps).to(dtype)

    x = rearrange_if_needed(x, x_stride)
    y = rearrange_if_needed(y, y_stride)

    x_tensor, y_tensor, w_tensor = [to_tensor(tensor, lib) for tensor in [x, y, w]]

    descriptor = infiniopRMSNormDescriptor_t()
    check_error(
        lib.infiniopCreateRMSNormDescriptor(
            handle,
            ctypes.byref(descriptor),
            y_tensor.descriptor,
            x_tensor.descriptor,
            w_tensor.descriptor,
            eps,
        )
    )

    # Release the descriptor even if a later library call or the check fails.
    try:
        # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
        for tensor in [x_tensor, y_tensor, w_tensor]:
            tensor.descriptor.contents.invalidate()

        workspace_size = c_uint64(0)
        check_error(
            lib.infiniopGetRMSNormWorkspaceSize(descriptor, ctypes.byref(workspace_size))
        )
        workspace = create_workspace(workspace_size.value, y.device)

        def lib_rms_norm():
            check_error(
                lib.infiniopRMSNorm(
                    descriptor,
                    workspace.data_ptr() if workspace is not None else None,
                    workspace_size.value,
                    y_tensor.data,
                    x_tensor.data,
                    w_tensor.data,
                    None,  # presumably the stream (None = default) - confirm against the C API
                )
            )

        lib_rms_norm()

        atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
        if DEBUG:
            debug(y, ans, atol=atol, rtol=rtol)
        assert torch.allclose(y, ans, atol=atol, rtol=rtol)

        # Profiling workflow
        if PROFILE:
            # fmt: off
            profile_operation("PyTorch", lambda: rms_norm(x, w, eps), torch_device, NUM_PRERUN, NUM_ITERATIONS)
            profile_operation("    lib", lib_rms_norm, torch_device, NUM_PRERUN, NUM_ITERATIONS)
            # fmt: on
    finally:
        check_error(lib.infiniopDestroyRMSNormDescriptor(descriptor))

136

PanZezhongQY's avatar
PanZezhongQY committed
137

138

PanZezhongQY's avatar
PanZezhongQY committed
139
if __name__ == "__main__":
    args = get_args()
    lib = open_lib()

    # Declare the C signatures so ctypes marshals the arguments and the
    # int32 status codes (checked with check_error) correctly.
    lib.infiniopCreateRMSNormDescriptor.restype = c_int32
    lib.infiniopCreateRMSNormDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopRMSNormDescriptor_t),  # out: new descriptor
        infiniopTensorDescriptor_t,  # y
        infiniopTensorDescriptor_t,  # x
        infiniopTensorDescriptor_t,  # w
        c_float,  # eps
    ]

    lib.infiniopGetRMSNormWorkspaceSize.restype = c_int32
    lib.infiniopGetRMSNormWorkspaceSize.argtypes = [
        infiniopRMSNormDescriptor_t,
        POINTER(c_uint64),  # out: workspace size
    ]

    # Was misspelled as `restypes`, which ctypes silently ignores.
    lib.infiniopRMSNorm.restype = c_int32
    lib.infiniopRMSNorm.argtypes = [
        infiniopRMSNormDescriptor_t,
        c_void_p,  # workspace
        c_uint64,  # workspace size
        c_void_p,  # y data
        c_void_p,  # x data
        c_void_p,  # w data
        c_void_p,  # presumably the stream (the test passes None) - confirm
    ]

    lib.infiniopDestroyRMSNormDescriptor.restype = c_int32
    lib.infiniopDestroyRMSNormDescriptor.argtypes = [
        infiniopRMSNormDescriptor_t,
    ]

    # Configure testing options
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    # Execute tests
    for device in get_test_devices(args):
        test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")
xgqdut2016's avatar
xgqdut2016 committed
185
186