rope.py 6.85 KB
Newer Older
xgqdut2016's avatar
xgqdut2016 committed
1
import torch
PanZezhongQY's avatar
PanZezhongQY committed
2
import ctypes
xgqdut2016's avatar
xgqdut2016 committed
3
from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
4
from libinfiniop import (
5
    InfiniDtype,
PanZezhongQY's avatar
PanZezhongQY committed
6
7
    infiniopHandle_t,
    infiniopTensorDescriptor_t,
8
9
10
    open_lib,
    to_tensor,
    get_test_devices,
PanZezhongQY's avatar
PanZezhongQY committed
11
    check_error,
12
    rearrange_if_needed,
PanZezhongQY's avatar
PanZezhongQY committed
13
    create_workspace,
14
15
16
    test_operator,
    get_args,
    debug,
xgqdut2016's avatar
xgqdut2016 committed
17
    get_tolerance,
18
    profile_operation,
xgqdut2016's avatar
xgqdut2016 committed
19
    synchronize_device,
PanZezhongQY's avatar
PanZezhongQY committed
20
)
xgqdut2016's avatar
xgqdut2016 committed
21
22
23
24
25
26
27

# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES = [
    # (t_shape, t_strides) — shape is (seq_len, n_head, head_dim);
    # strides of None means a plain contiguous tensor.
    ((1, 32, 128), None),
    ((1, 32, 64), None),
    # Ascend (NPU) cannot satisfy this case yet: a last dimension <= 32 is
    # problematic, likely related to the internal implementation of its core
    # GatherMask interface; 48, 64 and 128 are all currently supported.
    ((4, 1, 32), None),
    ((1, 32, 128), None),
    # Non-contiguous case: strides are larger than the dense layout requires.
    ((3, 32, 128), (8000, 200, 1)),
]

# Data types used for testing
_TENSOR_DTYPES = [torch.float16]

# Tolerance map for different data types
_TOLERANCE_MAP = {
    torch.float16: {"atol": 1e-4, "rtol": 1e-2},
}

# Runtime switches; overwritten from CLI arguments in the __main__ block below.
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000

PanZezhongQY's avatar
PanZezhongQY committed
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67

class RoPEDescriptor(Structure):
    """ctypes mirror of the C-side RoPE operator descriptor.

    Only the leading ``device`` field is declared; the rest of the C struct
    is opaque to Python and accessed solely through the library's API.
    """

    _fields_ = [("device", c_int32)]


# Pointer type used for all descriptor parameters of the C entry points.
infiniopRoPEDescriptor_t = POINTER(RoPEDescriptor)


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Reshape a (seq_len, head_dim) table so it broadcasts against *x*.

    Keeps the first and last dimensions of *x* and inserts singleton
    dimensions everywhere in between, e.g. (S, D) -> (S, 1, ..., 1, D).
    """
    ndim = x.ndim
    assert ndim > 1
    assert freqs_cis.shape == (x.shape[0], x.shape[-1])
    broadcast_shape = [
        size if axis in (0, ndim - 1) else 1 for axis, size in enumerate(x.shape)
    ]
    return freqs_cis.view(*broadcast_shape)


def rotary_embedding(t, pos, theta, torch_device):
    """Reference RoPE: rotate interleaved (even, odd) channel pairs of *t*.

    Args:
        t: activations of shape (seq_len, n_head, dh) with dh even.
        pos: per-token positions, shape (seq_len,).
        theta: RoPE frequency base (e.g. 1e4).
        torch_device: device on which the frequency table is built.

    Returns:
        A tensor shaped and typed like *t* where each channel pair
        (t[..., 2i], t[..., 2i+1]) is rotated by angle pos * theta**(-2i/dh).
    """
    dh = t.shape[2]
    assert dh % 2 == 0, "Embedding dimension must be even."

    # Inverse frequency for each of the dh/2 rotation planes.
    inv_freq = (1.0 / (theta ** (torch.arange(0, dh, 2).float() / dh))).to(torch_device)
    angles = torch.outer(pos, inv_freq)        # [seq_len, dh // 2]
    cos = torch.cos(angles).unsqueeze(1)       # [seq_len, 1, dh // 2]
    sin = torch.sin(angles).unsqueeze(1)       # [seq_len, 1, dh // 2]

    even = t[..., 0::2]                        # [seq_len, n_head, dh // 2]
    odd = t[..., 1::2]                         # [seq_len, n_head, dh // 2]
    rotated_even = even * cos - odd * sin
    rotated_odd = even * sin + odd * cos

    # Re-interleave the rotated halves back into (..., dh) and cast back to
    # the input dtype (the trig tables are float32).
    interleaved = torch.stack((rotated_even, rotated_odd), dim=-1).flatten(-2)
    return interleaved.to(t.dtype)

85

PanZezhongQY's avatar
PanZezhongQY committed
86
87
88
89
90
91
92
93
94
95
96
97
98
def sin_cos_table(max_seq_len, dim, torch_device, theta):
    """Build (sin, cos) RoPE lookup tables, each of shape (max_seq_len, dim).

    Each inverse frequency is duplicated so adjacent columns (2i, 2i+1)
    share the same angle, matching the interleaved channel-pair layout.
    """
    positions = torch.arange(
        0, max_seq_len, dtype=torch.float32, device=torch.device(torch_device)
    )
    inv_freq = (
        1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    ).to(torch_device)
    # (a0, a1, a2) -> (a0, a0, a1, a1, a2, a2)
    paired_freq = torch.repeat_interleave(inv_freq, repeats=2)
    # Outer product via broadcasting: angles[p, d] = positions[p] * paired_freq[d]
    angles = positions.unsqueeze(1) * paired_freq.unsqueeze(0)
    return torch.sin(angles), torch.cos(angles)


xgqdut2016's avatar
xgqdut2016 committed
99
def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
    """Run one RoPE case: compare the library kernel against the PyTorch
    reference `rotary_embedding` on a random tensor.

    Args:
        lib: ctypes handle to the infiniop shared library.
        handle: infiniop device handle for descriptor creation.
        torch_device: torch device string ("cpu", "npu", ...).
        shape: (seq_len, n_head, head_dim) of the activation tensor.
        strides: optional explicit strides; the tensor is rearranged to match
            them when given.
        dtype: torch dtype of the activation tensor.
    """
    print(
        f"Testing Rotary Positional Embedding on {torch_device} with shape:{shape} strides:{strides} and dtype:{dtype}"
    )

    t = torch.rand(shape, dtype=dtype)

    t = rearrange_if_needed(t, strides)

    # Positions 0..seq_len-1. Each position is stored as an
    # (int32 value, int32 zero) pair; relabeling the buffer as U64 further
    # below makes the library read each pair as one 64-bit position.
    # NOTE(review): this pairing trick assumes a little-endian host — confirm.
    posTmp = torch.arange(0, t.shape[0]).to(torch_device)
    pos = torch.zeros(2 * posTmp.shape[0], dtype=torch.int32)
    for i in range(posTmp.shape[0]):
        pos[2 * i] = posTmp[i]
        pos[2 * i + 1] = 0
    pos = pos.to(torch_device)
    theta = 1e4

    # PyTorch reference result, computed before the kernel overwrites t.
    ans = rotary_embedding(t, posTmp, theta, torch_device)

    descriptor = infiniopRoPEDescriptor_t()
    # 2x table length for test
    sin_table, cos_table = sin_cos_table(t.shape[0] * 2, t.shape[2], t.device, theta)

    t_tensor, sin_table_tensor, cos_table_tensor = [
        to_tensor(tensor, lib) for tensor in [t, sin_table, cos_table]
    ]

    # Slice keeps the first seq_len int32 elements; overriding the descriptor
    # dtype makes the same storage be read as seq_len U64 values (see the
    # packing loop above).
    pos_tensor = to_tensor(pos[: t.shape[0]], lib)
    pos_tensor.descriptor.contents.dtype = InfiniDtype.U64

    # Ascend NPU needs an explicit sync before descriptor creation here.
    if torch_device == "npu":
        synchronize_device(torch_device)

    check_error(
        lib.infiniopCreateRoPEDescriptor(
            handle,
            ctypes.byref(descriptor),
            t_tensor.descriptor,
            pos_tensor.descriptor,
            sin_table_tensor.descriptor,
            cos_table_tensor.descriptor,
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    for tensor in [t_tensor, pos_tensor, sin_table_tensor, cos_table_tensor]:
        tensor.descriptor.contents.invalidate()

    # Query and allocate the kernel's scratch workspace.
    workspace_size = c_uint64(0)
    check_error(
        lib.infiniopGetRoPEWorkspaceSize(descriptor, ctypes.byref(workspace_size))
    )
    workspace = create_workspace(workspace_size.value, t.device)

    def lib_rope():
        # Invoke the library kernel; the result lands in t's storage
        # (t_tensor.data), which is why t itself is compared against ans below.
        check_error(
            lib.infiniopRoPE(
                descriptor,
                workspace.data_ptr() if workspace is not None else None,
                workspace_size.value,
                t_tensor.data,
                pos_tensor.data,
                sin_table_tensor.data,
                cos_table_tensor.data,
                None,
            )
        )

    lib_rope()

    # Compare kernel output (in t) with the PyTorch reference.
    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(t, ans, atol=atol, rtol=rtol)
    assert torch.allclose(t, ans, atol=atol, rtol=rtol)

    # Optional timing of reference vs. library implementations.
    if PROFILE:
        profile_operation(
            "PyTorch",
            lambda: rotary_embedding(t, posTmp, theta, torch_device),
            torch_device,
            NUM_PRERUN,
            NUM_ITERATIONS,
        )
        profile_operation(
            "    lib", lambda: lib_rope(), torch_device, NUM_PRERUN, NUM_ITERATIONS
        )

    check_error(lib.infiniopDestroyRoPEDescriptor(descriptor))
PanZezhongQY's avatar
PanZezhongQY committed
187

188

PanZezhongQY's avatar
PanZezhongQY committed
189
190
191
if __name__ == "__main__":
    args = get_args()
    lib = open_lib()

    # Register ctypes signatures for the four RoPE C entry points so that
    # argument marshalling matches the C ABI. Argument order follows the
    # calls made in test() above.
    lib.infiniopCreateRoPEDescriptor.restype = c_int32
    lib.infiniopCreateRoPEDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopRoPEDescriptor_t),
        infiniopTensorDescriptor_t,  # t (activations)
        infiniopTensorDescriptor_t,  # pos
        infiniopTensorDescriptor_t,  # sin table
        infiniopTensorDescriptor_t,  # cos table
    ]

    lib.infiniopGetRoPEWorkspaceSize.restype = c_int32
    lib.infiniopGetRoPEWorkspaceSize.argtypes = [
        infiniopRoPEDescriptor_t,
        POINTER(c_uint64),  # out: required workspace size in bytes
    ]

    lib.infiniopRoPE.restype = c_int32
    lib.infiniopRoPE.argtypes = [
        infiniopRoPEDescriptor_t,
        c_void_p,  # workspace
        c_uint64,  # workspace size
        c_void_p,  # t data
        c_void_p,  # pos data
        c_void_p,  # sin table data
        c_void_p,  # cos table data
        c_void_p,  # stream (passed as None in these tests)
    ]

    lib.infiniopDestroyRoPEDescriptor.restype = c_int32
    lib.infiniopDestroyRoPEDescriptor.argtypes = [
        infiniopRoPEDescriptor_t,
    ]

    # Configure testing options
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    # Execute tests
    for device in get_test_devices(args):
        test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")