"cacheflow/core/block_manager.py" did not exist on "46958cf941997264ff36c23c42a71d30674b61f0"
rope.py 6.08 KB
Newer Older
xgqdut2016's avatar
xgqdut2016 committed
1
import ctypes
from ctypes import c_uint64
from enum import Enum, auto

import torch

from libinfiniop import (
    LIBINFINIOP,
    TestTensor,
    TestWorkspace,
    InfiniDtype,
    InfiniDtypeNames,
    InfiniDeviceEnum,
    InfiniDeviceNames,
    infiniopOperatorDescriptor_t,
    check_error,
    debug,
    get_args,
    get_test_devices,
    get_tolerance,
    profile_operation,
    test_operator,
)

# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES_ = [
    # (shape, x_strides, y_strides)
    ((1, 32, 128), None, None),
    ((10, 32, 64), None, None),
    # Ascend does not support this case for now: a last dimension <= 32 is
    # problematic, possibly related to the internal implementation of its core
    # GatherMask interface; 48, 64 and 128 are all currently supported.
    ((4, 1, 32), (64, 64, 1), None),
    ((11, 33, 128), None, (8000, 200, 1)),
    ((3, 32, 128), (8000, 200, 1), (7000, 128, 1)),
]

# Data types used for testing
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32]

# Tolerance map for different data types; lower-precision types get looser
# atol/rtol bounds for the allclose comparison against the torch reference.
_TOLERANCE_MAP = {
    InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-2},
    InfiniDtype.BF16: {"atol": 5e-3, "rtol": 5e-2},
    InfiniDtype.F32: {"atol": 1e-4, "rtol": 1e-3},
}
PanZezhongQY's avatar
PanZezhongQY committed
47

PanZezhong's avatar
PanZezhong committed
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64

class Inplace(Enum):
    """Output addressing mode: write into a fresh tensor, or alias onto x."""

    OUT_OF_PLACE = auto()
    INPLACE_X = auto()

# In-place modes every base case is run under.
_INPLACE = [
    Inplace.OUT_OF_PLACE,
    Inplace.INPLACE_X,
]

# Cross product: each base (shape, x_strides, y_strides) case is extended
# with each in-place mode, yielding 4-tuples consumed by test().
_TEST_CASES = [
    base_case + (mode,)
    for base_case in _TEST_CASES_
    for mode in _INPLACE
]

65
66
67
68
69
# Runtime options; overwritten from the parsed CLI arguments in __main__.
DEBUG = False          # dump per-element diffs when comparing results
PROFILE = False        # time the torch reference against the library kernel
NUM_PRERUN = 10        # warm-up runs before profiling
NUM_ITERATIONS = 1000  # timed runs when profiling

PanZezhongQY's avatar
PanZezhongQY committed
70

71
def rotary_embedding(ans, t, sin, cos, device):
    """Reference RoPE: write the rotated values of `t` into `ans` in place.

    Interleaved layout: even-indexed lanes of the last dimension hold the
    "real" half, odd-indexed lanes the "imaginary" half of each rotated pair.

    Args:
        ans: output tensor, same shape/dtype as `t`; may alias `t`.
        t: input of shape (seq_len, n_head, dh) — dh must be even.
        sin, cos: angle tables of shape (seq_len, dh // 2).
        device: device enum; on CPU the math is done in float32.
    """
    head_dim = t.shape[2]
    out_dtype = t.dtype
    assert head_dim % 2 == 0, "Embedding dimension must be even."

    # Split the last dimension into interleaved even/odd halves
    # and broadcast the tables over the head axis.
    even = t[..., 0::2]  # [seq_len, n_head, dh // 2]
    odd = t[..., 1::2]   # [seq_len, n_head, dh // 2]
    cos_b = cos.unsqueeze(1)  # [seq_len, 1, dh // 2]
    sin_b = sin.unsqueeze(1)  # [seq_len, 1, dh // 2]

    if device == InfiniDeviceEnum.CPU:
        # Upcast so low-precision dtypes are compared against a float32 reference.
        even = even.float()
        odd = odd.float()
        cos_b = cos_b.float()
        sin_b = sin_b.float()

    # Compute BOTH halves before writing: `ans` may alias `t`, and on the
    # non-CPU path `even`/`odd` are views into `t`, so an early store would
    # corrupt the operands of the second expression.
    rotated_even = even * cos_b - odd * sin_b
    rotated_odd = even * sin_b + odd * cos_b

    ans[..., 0::2] = rotated_even.to(out_dtype)
    ans[..., 1::2] = rotated_odd.to(out_dtype)
PanZezhongQY's avatar
PanZezhongQY committed
92

93

94
def sin_cos_table(pos, dim, device, theta, dtype):
    """Build the RoPE sin/cos lookup tables for the given positions.

    Args:
        pos: 1-D tensor of positions (length seq_len).
        dim: embedding dimension dh; must be even.
        device: device the resulting TestTensors live on.
        theta: RoPE frequency base (e.g. 1e5).
        dtype: InfiniDtype of the resulting tables.

    Returns:
        (sin_table, cos_table) as TestTensors of shape (seq_len, dim // 2).
    """
    assert dim % 2 == 0, "Embedding dimension must be even."
    # Inverse frequencies theta^(-2i/dim) for i in [0, dim/2).
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    freqs = 1.0 / (theta ** exponents)
    # Outer product position x frequency -> one angle per (pos, pair).
    angles = torch.outer(pos, freqs)
    sin_table = TestTensor.from_torch(torch.sin(angles), dtype, device)
    cos_table = TestTensor.from_torch(torch.cos(angles), dtype, device)
    return (sin_table, cos_table)
PanZezhong's avatar
PanZezhong committed
102
103
104
105


def test(
    handle,
    device,
    shape,
    x_strides=None,
    y_strides=None,
    inplace=Inplace.OUT_OF_PLACE,
    dtype=torch.float32,  # NOTE(review): default is a torch dtype, but call sites pass InfiniDtype values — confirm the default is never used
    sync=None,
):
    """Run one RoPE case: library kernel vs. the torch reference, then compare.

    Args:
        handle: infiniop library handle used to create the operator descriptor.
        device: device enum the tensors are allocated on.
        shape: (seq_len, n_head, head_dim) of the input tensor.
        x_strides: optional strides for the input tensor x.
        y_strides: optional strides for the output tensor y.
        inplace: whether y is a separate tensor or aliases x.
        dtype: tensor data type under test.
        sync: optional callable that synchronizes the device.
    """
    x = TestTensor(shape, x_strides, dtype, device)
    if inplace == Inplace.INPLACE_X:
        # In-place only makes sense when input and output share a layout.
        if x_strides != y_strides:
            return
        y = x
    else:
        y = TestTensor(shape, y_strides, dtype, device)

    print(
        f"Testing Rotary Positional Embedding on {InfiniDeviceNames[device]} with shape:{shape} x_strides:{x_strides} y_strides:{y_strides} and dtype:{InfiniDtypeNames[dtype]} inplace:{inplace}"
    )
    # RoPE frequency base.
    theta = 1e5
    # Positions 0..seq_len-1 as int32, plus the matching sin/cos tables.
    pos = TestTensor.from_torch(torch.arange(0, x.shape[0]), InfiniDtype.I32, device)
    sin_table, cos_table = sin_cos_table(
        pos.torch_tensor(), x.shape[2], x.device, theta, dtype
    )

    # Compute the reference answer into y's torch-side tensor.
    rotary_embedding(
        y.torch_tensor(),
        x.torch_tensor(),
        sin_table.torch_tensor(),
        cos_table.torch_tensor(),
        device,
    )

    descriptor = infiniopOperatorDescriptor_t()

    if sync is not None:
        sync()

    # Create the operator descriptor from the tensor descriptors.
    check_error(
        LIBINFINIOP.infiniopCreateRoPEDescriptor(
            handle,
            ctypes.byref(descriptor),
            y.descriptor,
            x.descriptor,
            pos.descriptor,
            sin_table.descriptor,
            cos_table.descriptor,
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    for tensor in [y, x, pos, sin_table, cos_table]:
        tensor.destroy_desc()

    # Query and allocate the scratch workspace required by this descriptor.
    workspace_size = c_uint64(0)
    check_error(
        LIBINFINIOP.infiniopGetRoPEWorkspaceSize(
            descriptor, ctypes.byref(workspace_size)
        )
    )
    workspace = TestWorkspace(workspace_size.value, x.device)

    def lib_rope():
        # Execute the library kernel (last arg: stream handle, None = default).
        check_error(
            LIBINFINIOP.infiniopRoPE(
                descriptor,
                workspace.data(),
                workspace_size.value,
                y.data(),
                x.data(),
                pos.data(),
                sin_table.data(),
                cos_table.data(),
                None,
            )
        )

    lib_rope()

    if sync is not None:
        sync()

    # Compare kernel output against the torch reference within dtype tolerance.
    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)
    assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)

    # Optional profiling of both implementations.
    if PROFILE:
        profile_operation(
            "PyTorch",
            lambda: rotary_embedding(
                y.torch_tensor(),
                x.torch_tensor(),
                sin_table.torch_tensor(),
                cos_table.torch_tensor(),
                device,
            ),
            device,
            NUM_PRERUN,
            NUM_ITERATIONS,
        )
        profile_operation(
            "    lib", lambda: lib_rope(), device, NUM_PRERUN, NUM_ITERATIONS
        )

    check_error(LIBINFINIOP.infiniopDestroyRoPEDescriptor(descriptor))
if __name__ == "__main__":
    args = get_args()
xgqdut2016's avatar
xgqdut2016 committed
216

217
218
219
220
221
222
223
224
    # Configure testing options
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    # Execute tests
    for device in get_test_devices(args):
225
        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)
226

PanZezhongQY's avatar
PanZezhongQY committed
227
    print("\033[92mTest passed!\033[0m")