"vllm/vscode:/vscode.git/clone" did not exist on "91a61da9b12c483a6688841b8f860c1a32b8918c"
rms_norm.py 5.35 KB
Newer Older
PanZezhongQY's avatar
PanZezhongQY committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p, c_float
import ctypes
import sys
import os

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
from operatorspy import (
    open_lib,
    to_tensor,
    DeviceEnum,
    infiniopHandle_t,
    infiniopTensorDescriptor_t,
    create_handle,
    destroy_handle,
    check_error,
    rearrange_tensor,
    create_workspace,
)

from operatorspy.tests.test_utils import get_args
import torch

class RMSNormDescriptor(Structure):
    """ctypes view of the library's RMSNorm descriptor.

    Only a single int32 ``device`` field is declared. This script never reads
    it; it only passes pointers to the struct back into the library.
    NOTE(review): presumably mirrors the leading field of the C-side struct.
    Confirm against the library header before relying on the layout.
    """

    _fields_ = [("device", c_int32)]


# Pointer type used as the opaque descriptor handle for the
# infiniop*RMSNorm* entry points (created, queried, executed, destroyed).
infiniopRMSNormDescriptor_t = POINTER(RMSNormDescriptor)

def rms_norm(x, w, eps):
    """PyTorch reference RMSNorm used to validate the library kernel.

    Normalizes ``x`` over its last dimension in float32::

        x / sqrt(mean(x ** 2, dim=-1) + eps)

    The result is then cast back to ``x.dtype`` and multiplied by ``w``.

    Args:
        x: input tensor; normalization is over the last dimension.
        w: weight tensor, broadcast against ``x``.
        eps: constant added to the mean square before the reciprocal sqrt.

    Returns:
        ``w * normalized``. The dtype follows PyTorch type promotion of
        ``w.dtype`` and ``x.dtype``; for example, float16 ``x`` with float32
        ``w`` yields float32.
    """
    input_dtype = x.dtype
    # Square and average in float32; in half precision, squaring loses
    # precision and can overflow for large values.
    hidden_states = x.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return w * hidden_states.to(input_dtype)


def test(
    lib,
    handle,
    torch_device,
    y_shape,
    x_shape,
    w_shape,
    dtype=torch.float16,
    w_dtype=torch.float16,
):
    """Run one RMSNorm case through the library and compare it with ``rms_norm``.

    Builds three tensors on ``torch_device``: ``y`` (zeros), ``x`` (uniform
    random) and ``w`` (ones). It computes the PyTorch reference, then creates
    an infiniop RMSNorm descriptor, queries and allocates its workspace, and
    launches the kernel into ``y``. Finally it asserts that ``y`` matches the
    reference within ``atol = rtol = 1e-3``. Once created, the descriptor is
    destroyed even if a later step raises.

    Args:
        lib: ctypes library handle with the RMSNorm signatures declared.
        handle: infiniop handle created for the target device.
        torch_device: torch device string ("cpu", "cuda", "mlu" or "npu" in
            this file).
        y_shape, x_shape, w_shape: shapes of the output, input and weight.
        dtype: dtype of ``x`` and ``y``.
        w_dtype: dtype of ``w``.
    """
    print(
        f"Testing RMS_Norm on {torch_device} with y_shape:{y_shape} x_shape:{x_shape} w_shape:{w_shape}"
        f" dtype:{dtype} w_dtype:{w_dtype}"
    )

    y = torch.zeros(y_shape, dtype=dtype).to(torch_device)
    x = torch.rand(x_shape, dtype=dtype).to(torch_device)
    w = torch.ones(w_shape, dtype=w_dtype).to(torch_device)

    eps = 1e-5
    ans = rms_norm(x, w, eps)

    y_tensor = to_tensor(y, lib)
    x_tensor = to_tensor(x, lib)
    w_tensor = to_tensor(w, lib)

    descriptor = infiniopRMSNormDescriptor_t()
    check_error(
        lib.infiniopCreateRMSNormDescriptor(
            handle,
            ctypes.byref(descriptor),
            y_tensor.descriptor,
            x_tensor.descriptor,
            w_tensor.descriptor,
            eps,
        )
    )

    try:
        # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
        x_tensor.descriptor.contents.invalidate()
        y_tensor.descriptor.contents.invalidate()
        w_tensor.descriptor.contents.invalidate()

        workspace_size = c_uint64(0)
        check_error(
            lib.infiniopGetRMSNormWorkspaceSize(descriptor, ctypes.byref(workspace_size))
        )
        workspace = create_workspace(workspace_size.value, y.device)
        check_error(
            lib.infiniopRMSNorm(
                descriptor,
                workspace.data_ptr() if workspace is not None else None,
                workspace_size.value,
                y_tensor.data,
                x_tensor.data,
                w_tensor.data,
                None,
            )
        )

        assert torch.allclose(y.to(dtype), ans.to(dtype), atol=1e-3, rtol=1e-3)
    finally:
        # Release the descriptor even when the kernel call or the check fails.
        check_error(lib.infiniopDestroyRMSNormDescriptor(descriptor))

def test_cpu(lib, test_cases):
    """Run every case in ``test_cases`` on the CPU device.

    Each entry is ``(y_shape, x_shape, w_shape, dtype, w_dtype)``. The device
    handle is destroyed even if a case fails.
    """
    device = DeviceEnum.DEVICE_CPU
    handle = create_handle(lib, device)
    try:
        for y_shape, x_shape, w_shape, dtype, w_dtype in test_cases:
            test(lib, handle, "cpu", y_shape, x_shape, w_shape, dtype, w_dtype)
    finally:
        destroy_handle(lib, handle)

def test_cuda(lib, test_cases):
    """Run every case in ``test_cases`` on the CUDA device.

    Each entry is ``(y_shape, x_shape, w_shape, dtype, w_dtype)``. The device
    handle is destroyed even if a case fails.
    """
    device = DeviceEnum.DEVICE_CUDA
    handle = create_handle(lib, device)
    try:
        for y_shape, x_shape, w_shape, dtype, w_dtype in test_cases:
            test(lib, handle, "cuda", y_shape, x_shape, w_shape, dtype, w_dtype)
    finally:
        destroy_handle(lib, handle)

def test_bang(lib, test_cases):
    """Run every case in ``test_cases`` on the Cambricon BANG device ("mlu").

    Each entry is ``(y_shape, x_shape, w_shape, dtype, w_dtype)``. The device
    handle is destroyed even if a case fails.
    """
    # Imported only for its side effects; presumably this registers the "mlu"
    # torch device used below. Confirm against the torch_mlu docs.
    import torch_mlu  # noqa: F401

    device = DeviceEnum.DEVICE_BANG
    handle = create_handle(lib, device)
    try:
        for y_shape, x_shape, w_shape, dtype, w_dtype in test_cases:
            test(lib, handle, "mlu", y_shape, x_shape, w_shape, dtype, w_dtype)
    finally:
        destroy_handle(lib, handle)

def test_ascend(lib, test_cases):
    """Run every case in ``test_cases`` on the Ascend device ("npu").

    Each entry is ``(y_shape, x_shape, w_shape, dtype, w_dtype)``. The device
    handle is destroyed even if a case fails.
    """
    # Imported only for its side effects; presumably this registers the "npu"
    # torch device used below. Confirm against the torch_npu docs.
    import torch_npu  # noqa: F401

    device = DeviceEnum.DEVICE_ASCEND
    handle = create_handle(lib, device)
    try:
        for y_shape, x_shape, w_shape, dtype, w_dtype in test_cases:
            test(lib, handle, "npu", y_shape, x_shape, w_shape, dtype, w_dtype)
    finally:
        destroy_handle(lib, handle)

if __name__ == "__main__":
    test_cases = [
        # y_shape, x_shape, w_shape, dtype, w_dtype
        ((16, 2048), (16, 2048), (2048,), torch.float16, torch.float16),
        ((16, 2048), (16, 2048), (2048,), torch.float16, torch.float32),
    ]
    args = get_args()
    lib = open_lib()

    # Declare the C signatures so ctypes converts arguments (e.g. eps to a
    # C float) and interprets the int32 status codes correctly.
    lib.infiniopCreateRMSNormDescriptor.restype = c_int32
    lib.infiniopCreateRMSNormDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopRMSNormDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        c_float,
    ]

    lib.infiniopGetRMSNormWorkspaceSize.restype = c_int32
    lib.infiniopGetRMSNormWorkspaceSize.argtypes = [
        infiniopRMSNormDescriptor_t,
        POINTER(c_uint64),
    ]

    # The attribute must be `restype`: ctypes silently accepts any other
    # attribute name (such as `restypes`) and ignores it.
    lib.infiniopRMSNorm.restype = c_int32
    lib.infiniopRMSNorm.argtypes = [
        infiniopRMSNormDescriptor_t,
        c_void_p,
        c_uint64,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
    ]
    lib.infiniopDestroyRMSNormDescriptor.restype = c_int32
    lib.infiniopDestroyRMSNormDescriptor.argtypes = [
        infiniopRMSNormDescriptor_t,
    ]

    if args.cpu:
        test_cpu(lib, test_cases)
    if args.cuda:
        test_cuda(lib, test_cases)
    if args.bang:
        test_bang(lib, test_cases)
    if args.ascend:
        test_ascend(lib, test_cases)
    # With no device flag given, fall back to the CPU backend.
    if not (args.cpu or args.cuda or args.bang or args.ascend):
        test_cpu(lib, test_cases)
    print("\033[92mTest passed!\033[0m")