rotary_embedding.py 6.9 KB
Newer Older
xgqdut2016's avatar
xgqdut2016 committed
1
import torch
PanZezhongQY's avatar
PanZezhongQY committed
2
import ctypes
xgqdut2016's avatar
xgqdut2016 committed
3
from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
4
from libinfiniop import (
PanZezhongQY's avatar
PanZezhongQY committed
5
6
    infiniopHandle_t,
    infiniopTensorDescriptor_t,
7
8
9
    open_lib,
    to_tensor,
    get_test_devices,
PanZezhongQY's avatar
PanZezhongQY committed
10
    check_error,
11
    rearrange_if_needed,
PanZezhongQY's avatar
PanZezhongQY committed
12
    create_workspace,
13
14
15
    test_operator,
    get_args,
    debug,
xgqdut2016's avatar
xgqdut2016 committed
16
    get_tolerance,
17
    profile_operation,
PanZezhongQY's avatar
PanZezhongQY committed
18
)
xgqdut2016's avatar
xgqdut2016 committed
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42

# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES = [
    # (t_shape, t_strides)
        ((1, 32, 128), None),
        ((1, 32, 64), None),
        # 昇腾暂不满足这个用例,最后一维度 <=32 会有问题,可能与其核心
        # 接口 GatherMask 的内部实现相关,目前 48 64 128 都可以支持
        ((4, 1, 32), None),
        ((1, 32, 128), None),
        ((3, 32, 128), (8000, 200, 1)),
]

# Data types used for testing
_TENSOR_DTYPES = [torch.float16, torch.float32]

# Tolerance map for different data types
_TOLERANCE_MAP = {
    torch.float16: {"atol": 0, "rtol": 1e-2},
    torch.float32: {"atol": 0, "rtol": 1e-3},
}
PanZezhongQY's avatar
PanZezhongQY committed
43

44
45
46
47
48
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000

PanZezhongQY's avatar
PanZezhongQY committed
49

xgqdut2016's avatar
xgqdut2016 committed
50

PanZezhongQY's avatar
PanZezhongQY committed
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
class RoPEDescriptor(Structure):
    """Opaque mirror of the C-side RoPE descriptor; only the device id
    field is declared here, the rest lives behind the pointer."""

    _fields_ = [("device", c_int32)]


# Pointer type passed across the ctypes boundary to the library.
infiniopRoPEDescriptor_t = POINTER(RoPEDescriptor)


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """Reshape a (seq_len, head_dim) table so it broadcasts against *x*.

    Every axis of *x* except the first and the last is collapsed to 1,
    e.g. for x of shape (s, h, d) the result has shape (s, 1, d).
    """
    rank = x.ndim
    assert rank > 1  # need at least a leading and a trailing axis
    assert freqs_cis.shape == (x.shape[0], x.shape[-1])
    target = [1] * rank
    target[0] = x.shape[0]
    target[-1] = x.shape[-1]
    return freqs_cis.view(*target)


def rotary_embedding(t, pos, theta, torch_device):
    """PyTorch reference implementation of rotary positional embedding.

    Args:
        t: input of shape (seq_len, n_head, head_dim); head_dim must be even.
        pos: positions, one per row of t, shape (seq_len,).
        theta: base of the geometric frequency progression.
        torch_device: device on which the frequency table is computed.

    Returns:
        A new tensor with the same shape as t where each (even, odd) pair
        of the last dimension is rotated by the position-dependent angle.
    """
    head_dim = t.shape[2]
    assert head_dim % 2 == 0, "Embedding dimension must be even."

    even = t[..., 0::2]  # [seq_len, n_head, head_dim // 2]
    odd = t[..., 1::2]  # [seq_len, n_head, head_dim // 2]

    # One inverse frequency per pair: theta ** (-2i / d).
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(pos, inv_freq.to(torch_device))  # [seq_len, head_dim // 2]
    cos = torch.cos(angles).unsqueeze(1)  # [seq_len, 1, head_dim // 2]
    sin = torch.sin(angles).unsqueeze(1)  # [seq_len, 1, head_dim // 2]

    rotated = torch.empty_like(t)
    rotated[..., 0::2] = even * cos - odd * sin
    rotated[..., 1::2] = even * sin + odd * cos
    return rotated

85

PanZezhongQY's avatar
PanZezhongQY committed
86
87
88
89
90
91
92
93
94
95
96
97
98
def sin_cos_table(max_seq_len, dim, torch_device, theta):
    """Build sin/cos lookup tables, each of shape (max_seq_len, dim).

    Each pair frequency is duplicated so that columns (2i, 2i+1) share the
    same angle: (a0, a1, a2) -> (a0, a0, a1, a1, a2, a2).
    """
    positions = torch.arange(
        0, max_seq_len, dtype=torch.float32, device=torch.device(torch_device)
    )
    half = torch.arange(0, dim, 2)[: (dim // 2)].float()
    inv_freq = (1.0 / (theta ** (half / dim))).to(torch_device)
    duplicated = torch.repeat_interleave(inv_freq, repeats=2)
    angles = torch.outer(positions, duplicated)
    return torch.sin(angles), torch.cos(angles)


xgqdut2016's avatar
xgqdut2016 committed
99
100
101
102
103
104
105
106
def test(
    lib,
    handle,
    torch_device,
    shape,
    strides=None,
    dtype=torch.float16,
):
    """Run the library RoPE operator on one case and compare with PyTorch.

    Args:
        lib: ctypes handle to the loaded infiniop shared library.
        handle: infiniopHandle_t for the device under test.
        torch_device: torch device string ("cpu", "cuda", "npu", ...).
        shape: input shape (seq_len, n_head, head_dim).
        strides: optional strides used to rearrange the input tensor.
        dtype: element type of the test tensor.

    Raises:
        AssertionError: if the library output does not match the reference
            within the dtype's tolerance.
    """
    print(
        f"Testing Rotary Positional Embedding on {torch_device} with shape:{shape} strides:{strides} and dtype:{dtype}"
    )

    t = torch.rand(shape, dtype=dtype)
    t = rearrange_if_needed(t, strides)

    # Build the position buffer: positions interleaved with zeros
    # (p0, 0, p1, 0, ...); only the first seq_len entries are handed to
    # the library below.
    posTmp = torch.arange(0, t.shape[0]).to(torch_device)
    pos = torch.zeros(2 * posTmp.shape[0], dtype=torch.int32)
    for i in range(posTmp.shape[0]):
        pos[2 * i] = posTmp[i]
        pos[2 * i + 1] = 0
    pos = pos.to(torch_device)
    theta = 1e4

    # PyTorch reference result.
    ans = rotary_embedding(t, posTmp, theta, torch_device)

    descriptor = infiniopRoPEDescriptor_t()
    # 2x table length for test
    sin_table, cos_table = sin_cos_table(t.shape[0] * 2, t.shape[2], t.device, theta)

    t_tensor, sin_table_tensor, cos_table_tensor = [
        to_tensor(tensor, lib) for tensor in [t, sin_table, cos_table]
    ]

    pos_tensor = to_tensor(pos[: t.shape[0]], lib)
    # NOTE(review): InfiniDtype is not among this file's visible imports —
    # confirm it is imported (presumably from libinfiniop), otherwise this
    # line raises NameError.
    pos_tensor.descriptor.contents.dtype = InfiniDtype.U64

    if torch_device == "npu":
        torch.npu.synchronize()

    check_error(
        lib.infiniopCreateRoPEDescriptor(
            handle,
            # Fixed: was a bare `byref(...)`, but `byref` is not in the
            # ctypes import list at the top of the file (NameError at
            # runtime); use the qualified name, consistent with the
            # workspace-size query below.
            ctypes.byref(descriptor),
            t_tensor.descriptor,
            pos_tensor.descriptor,
            sin_table_tensor.descriptor,
            cos_table_tensor.descriptor,
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them
    # from being directly used by the kernel.
    for tensor in [t_tensor, pos_tensor, sin_table_tensor, cos_table_tensor]:
        tensor.descriptor.contents.invalidate()

    workspace_size = c_uint64(0)
    check_error(
        lib.infiniopGetRoPEWorkspaceSize(descriptor, ctypes.byref(workspace_size))
    )
    workspace = create_workspace(workspace_size.value, t.device)

    def lib_rope():
        # The operator writes its result in place into t_tensor's buffer.
        check_error(
            lib.infiniopRoPE(
                descriptor,
                workspace.data_ptr() if workspace is not None else None,
                workspace_size.value,
                t_tensor.data,
                pos_tensor.data,
                sin_table_tensor.data,
                cos_table_tensor.data,
                None,
            )
        )

    lib_rope()

    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(t, ans, atol=atol, rtol=rtol)
    # t was modified in place by the library; compare against the reference.
    assert torch.allclose(t, ans, atol=atol, rtol=rtol)

    if PROFILE:
        profile_operation(
            "PyTorch",
            lambda: rotary_embedding(t, posTmp, theta, torch_device),
            torch_device,
            NUM_PRERUN,
            NUM_ITERATIONS,
        )
        profile_operation(
            "    lib", lambda: lib_rope(), torch_device, NUM_PRERUN, NUM_ITERATIONS
        )

    check_error(lib.infiniopDestroyRoPEDescriptor(descriptor))
PanZezhongQY's avatar
PanZezhongQY committed
192

193

PanZezhongQY's avatar
PanZezhongQY committed
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
if __name__ == "__main__":
    args = get_args()
    lib = open_lib()

    # Declare ctypes signatures for the RoPE C API so arguments are
    # marshalled with the correct widths instead of the int default.
    lib.infiniopCreateRoPEDescriptor.restype = c_int32
    lib.infiniopCreateRoPEDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopRoPEDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
    ]
    lib.infiniopGetRoPEWorkspaceSize.restype = c_int32
    lib.infiniopGetRoPEWorkspaceSize.argtypes = [
        infiniopRoPEDescriptor_t,
        POINTER(c_uint64),
    ]
    lib.infiniopRoPE.restype = c_int32
    lib.infiniopRoPE.argtypes = [
        infiniopRoPEDescriptor_t,
        c_void_p,
        c_uint64,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
    ]
    lib.infiniopDestroyRoPEDescriptor.restype = c_int32
    lib.infiniopDestroyRoPEDescriptor.argtypes = [infiniopRoPEDescriptor_t]

    # Configure testing options
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    # Execute tests
    for device in get_test_devices(args):
        test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)
    print("\033[92mTest passed!\033[0m")