import torch
import ctypes
from ctypes import c_uint64
from libinfiniop import (
    LIBINFINIOP,
    TestTensor,
    get_test_devices,
    check_error,
    test_operator,
    get_args,
    debug,
    get_tolerance,
    profile_operation,
    TestWorkspace,
    InfiniDtype,
    InfiniDtypeNames,
    InfiniDeviceEnum,
    InfiniDeviceNames,
    infiniopOperatorDescriptor_t,
)
from enum import Enum, auto

# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES_ = [
    # (shape, x_strides, y_strides)
    ((1, 32, 128), None, None),
    ((10, 32, 64), None, None),
    # Ascend does not yet support the case below: a last dimension <= 32 fails,
    # possibly due to the internal implementation of its core GatherMask
    # interface; last dimensions of 48, 64, and 128 currently work.
    ((4, 1, 32), (64, 64, 1), None),
    ((11, 33, 128), None, (8000, 200, 1)),
    ((3, 32, 128), (8000, 200, 1), (7000, 128, 1)),
    ((8, 1, 32, 128), None, None),
    ((8, 10, 32, 64), None, None),
    ((8, 20, 32, 64), (40960, 64, 1280, 1), (40960, 64, 1280, 1)),
    ((8, 20, 4, 64), (1048576, 64, 262144, 1), (1048576, 64, 262144, 1)),
]

# Data types used for testing
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32]

# Tolerance map for different data types
_TOLERANCE_MAP = {
    InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-2},
    InfiniDtype.BF16: {"atol": 5e-3, "rtol": 5e-2},
    InfiniDtype.F32: {"atol": 1e-4, "rtol": 1e-3},
}


class Inplace(Enum):
    OUT_OF_PLACE = auto()
    INPLACE_X = auto()


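# Algorithm codes are passed directly to the C API (the algo.value argument of
# infiniopCreateRoPEDescriptor below): GPT_J rotates interleaved element pairs,
# GPT_NEOX rotates the two halves of the head dimension.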
class Algorithm(Enum):
    GPT_J = 0
    GPT_NEOX = 1


_INPLACE = [
    Inplace.OUT_OF_PLACE,
    Inplace.INPLACE_X,
]


_ALGO = [
    Algorithm.GPT_J,
    Algorithm.GPT_NEOX,
]

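# Expand every base case with each (inplace, algorithm) combination, so each
# entry in _TEST_CASES_ yields four concrete test cases.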
_TEST_CASES = [
    test_case + (inplace_item, algo_item)
    for test_case in _TEST_CASES_
    for inplace_item in _INPLACE
    for algo_item in _ALGO
]

DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000


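# Reference RoPE implementation in PyTorch. Each (x1, x2) pair is rotated by the
# positional angle:
#   y1 = x1 * cos - x2 * sin
#   y2 = x1 * sin + x2 * cos
# GPT-J pairs interleaved (even/odd) elements; GPT-NeoX pairs the first and
# second halves of the last dimension.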
def rotary_embedding(ans, t, sin, cos, device, algo):
    def _torch_rope(sin, cos, t1, t2):
        cos = cos.unsqueeze(1)  # [seq_len, 1, dh // 2]
        sin = sin.unsqueeze(1)  # [seq_len, 1, dh // 2]
        if device == InfiniDeviceEnum.CPU:
            (t1, t2, cos, sin) = (
                t1.float(),
                t2.float(),
                cos.float(),
                sin.float(),
            )

        t_out_1 = t1 * cos - t2 * sin
        t_out_2 = t1 * sin + t2 * cos

        return t_out_1, t_out_2

    dh = t.shape[-1]
    dt = t.dtype
    assert dh % 2 == 0, "Embedding dimension must be even."

    if algo == Algorithm.GPT_J:
        t_even = t[..., 0::2]  # [seq_len, n_head, dh // 2]
        t_odd = t[..., 1::2]  # [seq_len, n_head, dh // 2]

        t_out_even, t_out_odd = _torch_rope(sin, cos, t_even, t_odd)

        ans[..., 0::2] = t_out_even.to(dt)
        ans[..., 1::2] = t_out_odd.to(dt)
    else:
        half_dim = dh // 2
        t_first = t[..., :half_dim]
        t_second = t[..., half_dim:]

        t_out_first, t_out_second = _torch_rope(sin, cos, t_first, t_second)

        ans[..., :half_dim] = t_out_first.to(dt)
        ans[..., half_dim:] = t_out_second.to(dt)


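# Build sin/cos lookup tables of shape [seq_len, dim // 2], where
# angle[p, i] = p / theta**(2 * i / dim) for position p and frequency index i.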
def sin_cos_table(pos, dim, device, theta, dtype):
    assert dim % 2 == 0, "Embedding dimension must be even."
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    angles = torch.outer(pos.cpu(), freqs)
    return (
        TestTensor.from_torch(torch.sin(angles), dtype, device),
        TestTensor.from_torch(torch.cos(angles), dtype, device),
    )


def test(
    handle,
    device,
    shape,
    x_strides=None,
    y_strides=None,
    inplace=Inplace.OUT_OF_PLACE,
    algo=Algorithm.GPT_J,
    dtype=InfiniDtype.F32,
    sync=None,
):
    x = TestTensor(shape, x_strides, dtype, device)
    if inplace == Inplace.INPLACE_X:
        if x_strides != y_strides:
            return
        y = x
    else:
        y = TestTensor(shape, y_strides, dtype, device)

    print(
        f"Testing Rotary Positional Embedding on {InfiniDeviceNames[device]} with shape:{shape} x_strides:{x_strides} y_strides:{y_strides} and dtype:{InfiniDtypeNames[dtype]} inplace:{inplace} algo:{algo}"
    )
    theta = 1e5
    pos = TestTensor.from_torch(torch.arange(0, x.shape[-3]), InfiniDtype.I32, device)
    sin_table, cos_table = sin_cos_table(
        pos.torch_tensor(), x.shape[-1], x.device, theta, dtype
    )

    rotary_embedding(
        y.torch_tensor(),
        x.torch_tensor(),
        sin_table.torch_tensor(),
        cos_table.torch_tensor(),
        device,
        algo,
    )

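    # Create the RoPE operator descriptor through the C API; its required
    # workspace size is queried below before the kernel is launched.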
    descriptor = infiniopOperatorDescriptor_t()

    if sync is not None:
        sync()

    check_error(
        LIBINFINIOP.infiniopCreateRoPEDescriptor(
            handle,
            ctypes.byref(descriptor),
            y.descriptor,
            x.descriptor,
            pos.descriptor,
            sin_table.descriptor,
            cos_table.descriptor,
            algo.value,
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    for tensor in [y, x, pos, sin_table, cos_table]:
        tensor.destroy_desc()

    workspace_size = c_uint64(0)
    check_error(
        LIBINFINIOP.infiniopGetRoPEWorkspaceSize(
            descriptor, ctypes.byref(workspace_size)
        )
    )
    workspace = TestWorkspace(workspace_size.value, x.device)

    def lib_rope():
        check_error(
            LIBINFINIOP.infiniopRoPE(
                descriptor,
                workspace.data(),
                workspace_size.value,
                y.data(),
                x.data(),
                pos.data(),
                sin_table.data(),
                cos_table.data(),
                None,
            )
        )

    lib_rope()

    if sync is not None:
        sync()

    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)
    assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)

    if PROFILE:
        profile_operation(
            "PyTorch",
            lambda: rotary_embedding(
                y.torch_tensor(),
                x.torch_tensor(),
                sin_table.torch_tensor(),
                cos_table.torch_tensor(),
                device,
                algo,
            ),
            device,
            NUM_PRERUN,
            NUM_ITERATIONS,
        )
        profile_operation(
            "    lib", lambda: lib_rope(), device, NUM_PRERUN, NUM_ITERATIONS
        )

    check_error(LIBINFINIOP.infiniopDestroyRoPEDescriptor(descriptor))


if __name__ == "__main__":
    args = get_args()

    # Configure testing options
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    # Execute tests
    for device in get_test_devices(args):
        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")