rotary_embedding.py 6.82 KB
Newer Older
xgqdut2016's avatar
xgqdut2016 committed
1
import torch
PanZezhongQY's avatar
PanZezhongQY committed
2
import ctypes
xgqdut2016's avatar
xgqdut2016 committed
3
from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
4
from libinfiniop import (
PanZezhongQY's avatar
PanZezhongQY committed
5
6
    infiniopHandle_t,
    infiniopTensorDescriptor_t,
7
8
9
    open_lib,
    to_tensor,
    get_test_devices,
PanZezhongQY's avatar
PanZezhongQY committed
10
    check_error,
11
    rearrange_if_needed,
PanZezhongQY's avatar
PanZezhongQY committed
12
    create_workspace,
13
14
15
    test_operator,
    get_args,
    debug,
xgqdut2016's avatar
xgqdut2016 committed
16
    get_tolerance,
17
    profile_operation,
xgqdut2016's avatar
xgqdut2016 committed
18
    synchronize_device,
PanZezhongQY's avatar
PanZezhongQY committed
19
)
# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES = [
    # (t_shape, t_strides); t_shape is indexed as [seq_len, n_head, head_dim]
    # by test(). A strides value of None presumably keeps the default
    # contiguous layout (it is forwarded to rearrange_if_needed).
    ((1, 32, 128), None),
    ((1, 32, 64), None),
    # Ascend (NPU) does not pass this case yet: a last dimension <= 32 is
    # problematic, possibly because of the internal implementation of its
    # core GatherMask interface. Last dimensions of 48, 64 and 128 all work.
    ((4, 1, 32), None),
    ((3, 32, 128), (8000, 200, 1)),
]

# Data types used for testing
_TENSOR_DTYPES = [torch.float16]

# Tolerance map for different data types
_TOLERANCE_MAP = {
    torch.float16: {"atol": 1e-4, "rtol": 1e-2},
}

# Defaults; the __main__ section overwrites these from the parsed CLI args.
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000


class RoPEDescriptor(Structure):
    """ctypes stand-in for the library's RoPE descriptor struct.

    Only a single int32 ``device`` field is declared. In this file the
    descriptor is never dereferenced; it is only created, passed around and
    destroyed through the pointer type ``infiniopRoPEDescriptor_t``.
    """

    # NOTE(review): presumably mirrors the leading field of the C-side
    # descriptor — confirm against the library header before relying on it.
    _fields_ = [("device", c_int32)]


# Pointer type used as the descriptor handle in the infiniopRoPE* prototypes.
infiniopRoPEDescriptor_t = POINTER(RoPEDescriptor)


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """Return a view of ``freqs_cis`` that broadcasts against ``x``.

    ``freqs_cis`` must have shape ``(x.shape[0], x.shape[-1])``. The result
    keeps those two extents and inserts size-1 axes for every dimension of
    ``x`` in between, e.g. ``(S, D)`` -> ``(S, 1, D)`` for a 3-D ``x``.

    Raises:
        AssertionError: if ``x`` has fewer than two dimensions or the shape
            of ``freqs_cis`` does not match.
    """
    rank = x.ndim
    assert rank > 1
    leading, trailing = x.shape[0], x.shape[-1]
    assert freqs_cis.shape == (leading, trailing)
    target_shape = [leading] + [1] * (rank - 2) + [trailing]
    return freqs_cis.view(*target_shape)


def rotary_embedding(t, pos, theta, torch_device):
    """PyTorch reference for rotary positional embedding (interleaved pairs).

    Every channel pair ``(t[..., 2k], t[..., 2k + 1])`` of row ``s`` is
    rotated by the angle ``pos[s] * theta ** (-2k / dh)``.

    Args:
        t: input tensor indexed as ``[seq_len, n_head, dh]``; ``dh`` is read
            from ``t.shape[2]`` and must be even.
        pos: 1-D tensor of positions, one per row along ``t``'s first axis.
        theta: RoPE frequency base.
        torch_device: device the frequency vector is moved to.

    Returns:
        A new tensor with the same shape and dtype as ``t``; ``t`` itself is
        not modified.
    """
    head_dim = t.shape[2]
    assert head_dim % 2 == 0, "Embedding dimension must be even."

    # One inverse frequency per channel pair: theta ** (-2k / dh).
    exponent = torch.arange(0, head_dim, 2).float() / head_dim
    inv_freq = (1.0 / (theta ** exponent)).to(torch_device)

    angles = torch.outer(pos, inv_freq)  # [seq_len, dh // 2]
    # Insert a head axis so the tables broadcast over n_head.
    cos_a = torch.cos(angles).unsqueeze(1)  # [seq_len, 1, dh // 2]
    sin_a = torch.sin(angles).unsqueeze(1)  # [seq_len, 1, dh // 2]

    x_even = t[..., 0::2]  # [seq_len, n_head, dh // 2]
    x_odd = t[..., 1::2]  # [seq_len, n_head, dh // 2]

    rotated = torch.empty_like(t)
    rotated[..., 0::2] = x_even * cos_a - x_odd * sin_a
    rotated[..., 1::2] = x_even * sin_a + x_odd * cos_a
    return rotated


def sin_cos_table(max_seq_len, dim, torch_device, theta):
    """Build full-width sine/cosine lookup tables for RoPE.

    Row ``p`` holds ``sin``/``cos`` of ``p * f_k`` with
    ``f_k = theta ** (-2k / dim)``. Each ``f_k`` occupies two adjacent
    columns so the tables line up with interleaved (even, odd) channel
    pairs: ``(f0, f1, f2) -> (f0, f0, f1, f1, f2, f2)``.

    Args:
        max_seq_len: number of rows (positions ``0 .. max_seq_len - 1``).
        dim: head dimension; the tables have ``2 * (dim // 2)`` columns.
        torch_device: device (string or ``torch.device``) for the outputs.
        theta: RoPE frequency base.

    Returns:
        ``(sin_table, cos_table)``, each float32 with shape
        ``[max_seq_len, 2 * (dim // 2)]``.
    """
    half = dim // 2
    exponents = torch.arange(0, dim, 2)[:half].float() / dim
    inv_freq = (1.0 / (theta ** exponents)).to(torch_device)
    # Duplicate every frequency for its (even, odd) channel pair.
    inv_freq = inv_freq.repeat_interleave(2)

    positions = torch.arange(
        max_seq_len, dtype=torch.float32, device=torch.device(torch_device)
    )
    angles = torch.outer(positions, inv_freq)
    return angles.sin(), angles.cos()


def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
    """Run one RoPE case through the library and compare it with PyTorch.

    Args:
        lib: the loaded operator library (from ``open_lib()``).
        handle: infiniop handle passed to ``infiniopCreateRoPEDescriptor``.
        torch_device: torch device string the case runs on.
        shape: shape of the input ``t``; axis 0 is used as the sequence
            length and axis 2 as the head dimension.
        strides: optional custom strides for ``t``, forwarded to
            ``rearrange_if_needed`` (None presumably keeps the default layout).
        dtype: element type of ``t``.

    Raises:
        AssertionError: if the library result differs from the reference
            beyond the tolerance in ``_TOLERANCE_MAP``.
    """
    # Function-scope import: InfiniDtype is used below but is not part of the
    # module-level libinfiniop import list.
    from libinfiniop import InfiniDtype

    print(
        f"Testing Rotary Positional Embedding on {torch_device} with shape:{shape} strides:{strides} and dtype:{dtype}"
    )

    # Move to the target device *before* re-striding. Otherwise t stays on
    # the CPU: the reference would mix devices with `positions`, and the
    # library would be handed host pointers.
    t = torch.rand(shape, dtype=dtype).to(torch_device)
    t = rearrange_if_needed(t, strides)

    seq_len = t.shape[0]
    positions = torch.arange(0, seq_len).to(torch_device)

    # The library reads positions as uint64 (see the U64 relabel below).
    # Store them as int32 pairs (low word = position, high word = 0), which is
    # the uint64 bit pattern on a little-endian target.
    pos = torch.zeros(2 * seq_len, dtype=torch.int32)
    pos[0::2] = torch.arange(0, seq_len, dtype=torch.int32)
    pos = pos.to(torch_device)
    theta = 1e4

    # Reference result; must be computed before lib_rope overwrites t in place.
    ans = rotary_embedding(t, positions, theta, torch_device)

    descriptor = infiniopRoPEDescriptor_t()
    # 2x table length for test
    sin_table, cos_table = sin_cos_table(seq_len * 2, t.shape[2], t.device, theta)

    t_tensor, sin_table_tensor, cos_table_tensor = [
        to_tensor(tensor, lib) for tensor in [t, sin_table, cos_table]
    ]

    # Describe seq_len elements and relabel them as U64: the data pointer
    # covers the whole 2 * seq_len int32 buffer, i.e. seq_len uint64 values.
    pos_tensor = to_tensor(pos[:seq_len], lib)
    pos_tensor.descriptor.contents.dtype = InfiniDtype.U64

    if torch_device == "npu":
        synchronize_device(torch_device)

    check_error(
        lib.infiniopCreateRoPEDescriptor(
            handle,
            ctypes.byref(descriptor),
            t_tensor.descriptor,
            pos_tensor.descriptor,
            sin_table_tensor.descriptor,
            cos_table_tensor.descriptor,
        )
    )

    # From here on the descriptor exists; always release it, even when the
    # comparison below fails.
    try:
        # Invalidate the shape and strides in the descriptor to prevent them
        # from being directly used by the kernel
        for tensor in [t_tensor, pos_tensor, sin_table_tensor, cos_table_tensor]:
            tensor.descriptor.contents.invalidate()

        workspace_size = c_uint64(0)
        check_error(
            lib.infiniopGetRoPEWorkspaceSize(descriptor, ctypes.byref(workspace_size))
        )
        workspace = create_workspace(workspace_size.value, t.device)

        def lib_rope():
            # Applies RoPE to t in place through the library.
            check_error(
                lib.infiniopRoPE(
                    descriptor,
                    workspace.data_ptr() if workspace is not None else None,
                    workspace_size.value,
                    t_tensor.data,
                    pos_tensor.data,
                    sin_table_tensor.data,
                    cos_table_tensor.data,
                    # NOTE(review): trailing argument, presumably the
                    # stream; None here — confirm against the C API.
                    None,
                )
            )

        lib_rope()

        atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
        if DEBUG:
            debug(t, ans, atol=atol, rtol=rtol)
        assert torch.allclose(t, ans, atol=atol, rtol=rtol)

        if PROFILE:
            profile_operation(
                "PyTorch",
                lambda: rotary_embedding(t, positions, theta, torch_device),
                torch_device,
                NUM_PRERUN,
                NUM_ITERATIONS,
            )
            profile_operation(
                "    lib", lambda: lib_rope(), torch_device, NUM_PRERUN, NUM_ITERATIONS
            )
    finally:
        check_error(lib.infiniopDestroyRoPEDescriptor(descriptor))


if __name__ == "__main__":
    args = get_args()
    lib = open_lib()

    # ctypes prototypes of the RoPE entry points. Every one of them returns
    # an int32 status code, which the test checks with check_error().
    rope_prototypes = {
        "infiniopCreateRoPEDescriptor": [
            infiniopHandle_t,
            POINTER(infiniopRoPEDescriptor_t),  # receives the new descriptor
            infiniopTensorDescriptor_t,  # t
            infiniopTensorDescriptor_t,  # pos
            infiniopTensorDescriptor_t,  # sin table
            infiniopTensorDescriptor_t,  # cos table
        ],
        "infiniopGetRoPEWorkspaceSize": [
            infiniopRoPEDescriptor_t,
            POINTER(c_uint64),  # receives the workspace size
        ],
        "infiniopRoPE": [
            infiniopRoPEDescriptor_t,
            c_void_p,  # workspace pointer
            c_uint64,  # workspace size
            c_void_p,  # t data
            c_void_p,  # pos data
            c_void_p,  # sin table data
            c_void_p,  # cos table data
            c_void_p,  # trailing argument, passed as None by test()
        ],
        "infiniopDestroyRoPEDescriptor": [
            infiniopRoPEDescriptor_t,
        ],
    }
    for fn_name, fn_argtypes in rope_prototypes.items():
        c_fn = getattr(lib, fn_name)
        c_fn.restype = c_int32
        c_fn.argtypes = fn_argtypes

    # Configure testing options (module-level knobs read by test()).
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    # Execute tests
    for device in get_test_devices(args):
        test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)
    print("\033[92mTest passed!\033[0m")