import torch
import ctypes
from ctypes import c_uint64
from libinfiniop import (
    LIBINFINIOP,
    TestTensor,
    get_test_devices,
    check_error,
    test_operator,
    get_args,
    debug,
    get_tolerance,
    profile_operation,
    TestWorkspace,
    InfiniDtype,
    InfiniDtypeNames,
    InfiniDeviceEnum,
    InfiniDeviceNames,
    infiniopOperatorDescriptor_t,
)
from enum import Enum, auto

# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES_ = [
    # (shape, x_strides, y_strides)
    ((1, 32, 128), None, None),
    ((10, 32, 64), None, None),
    # Ascend does not handle this case yet: a last dimension <= 32 fails,
    # possibly due to the internal implementation of its core GatherMask
    # interface. Last dimensions of 48, 64, and 128 currently all work.
    ((4, 1, 32), (64, 64, 1), None),
    ((11, 33, 128), None, (8000, 200, 1)),
    ((3, 32, 128), (8000, 200, 1), (7000, 128, 1)),
]
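# Note: strides are element offsets into the underlying storage (the usual
# torch convention), so with strides (8000, 200, 1) element (i, j, k) lives at
# flat offset 8000 * i + 200 * j + k and the buffer may be padded well beyond
# the logical shape; None requests contiguous (row-major) strides.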

# Data types used for testing
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32]

# Tolerance map for different data types
_TOLERANCE_MAP = {
    InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-2},
    InfiniDtype.BF16: {"atol": 5e-3, "rtol": 5e-2},
    InfiniDtype.F32: {"atol": 1e-4, "rtol": 1e-3},
}
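# BF16 trades mantissa bits for dynamic range (8 significand bits vs. FP16's
# 11), hence the looser tolerances above.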


class Inplace(Enum):
    OUT_OF_PLACE = auto()
    INPLACE_X = auto()


class Algorithm(Enum):
    GPT_J = 0
    GPT_NEOX = 1


_INPLACE = [
    Inplace.OUT_OF_PLACE,
    Inplace.INPLACE_X,
]


_ALGO = [
    Algorithm.GPT_J,
    Algorithm.GPT_NEOX,
]

_TEST_CASES = [
    test_case + (inplace_item, algo_item)
    for test_case in _TEST_CASES_
    for inplace_item in _INPLACE
    for algo_item in _ALGO
]
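# Each case is a 5-tuple; e.g. the first one expands to
# ((1, 32, 128), None, None, Inplace.OUT_OF_PLACE, Algorithm.GPT_J), which
# test_operator forwards to test() as positional arguments.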

DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000


def rotary_embedding(ans, t, sin, cos, device, algo):
    def _torch_rope(sin, cos, t1, t2):
        cos = cos.unsqueeze(1)  # [seq_len, 1, dh // 2]
        sin = sin.unsqueeze(1)  # [seq_len, 1, dh // 2]
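        # On CPU, compute the reference in float32 (presumably so the
        # reference is not limited by half-precision CPU kernels).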
        if device == InfiniDeviceEnum.CPU:
            (t1, t2, cos, sin) = (
                t1.float(),
                t2.float(),
                cos.float(),
                sin.float(),
            )

        t_out_1 = t1 * cos - t2 * sin
        t_out_2 = t1 * sin + t2 * cos

        return t_out_1, t_out_2

    dh = t.shape[-1]
    dt = t.dtype
    assert dh % 2 == 0, "Embedding dimension must be even."

    if algo == Algorithm.GPT_J:
        t_even = t[..., 0::2]  # [seq_len, n_head, dh // 2]
        t_odd = t[..., 1::2]  # [seq_len, n_head, dh // 2]

        t_out_even, t_out_odd = _torch_rope(sin, cos, t_even, t_odd)

        ans[..., 0::2] = t_out_even.to(dt)
        ans[..., 1::2] = t_out_odd.to(dt)
    else:
        half_dim = dh // 2
        t_first = t[..., :half_dim]
        t_second = t[..., half_dim:]

        t_out_first, t_out_second = _torch_rope(sin, cos, t_first, t_second)

        ans[..., :half_dim] = t_out_first.to(dt)
        ans[..., half_dim:] = t_out_second.to(dt)
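

def _rope_pairing_sketch():
    # Illustrative sketch only (plain PyTorch, hypothetical dh = 8, never
    # called by the harness): shows how the two algorithms pair feature
    # dimensions before the same 2-D rotation
    # (t1, t2) -> (t1 * cos - t2 * sin, t1 * sin + t2 * cos) is applied.
    t = torch.arange(8)
    gpt_j_pairs = list(zip(t[0::2].tolist(), t[1::2].tolist()))
    # interleaved pairs: [(0, 1), (2, 3), (4, 5), (6, 7)]
    gpt_neox_pairs = list(zip(t[:4].tolist(), t[4:].tolist()))
    # half-split pairs: [(0, 4), (1, 5), (2, 6), (3, 7)]
    return gpt_j_pairs, gpt_neox_pairs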


def sin_cos_table(pos, dim, device, theta, dtype):
    assert dim % 2 == 0, "Embedding dimension must be even."
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    angles = torch.outer(pos.cpu(), freqs)
    return (
        TestTensor.from_torch(torch.sin(angles), dtype, device),
        TestTensor.from_torch(torch.cos(angles), dtype, device),
    )
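

def _sin_cos_table_sketch():
    # Minimal standalone sketch (plain PyTorch, hypothetical theta and dim,
    # never called by the harness): sin_cos_table follows the standard RoPE
    # frequencies freqs[i] = theta ** (-2 * i / dim) with
    # angles[p, i] = p * freqs[i], so both tables have shape
    # [len(pos), dim // 2].
    theta, dim = 1e4, 8
    pos = torch.arange(4)
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    angles = torch.outer(pos.float(), freqs)
    assert angles.shape == (4, dim // 2)
    return torch.sin(angles), torch.cos(angles)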


def test(
    handle,
    device,
    shape,
    x_strides=None,
    y_strides=None,
    inplace=Inplace.OUT_OF_PLACE,
    algo=Algorithm.GPT_J,
    dtype=InfiniDtype.F32,
    sync=None,
):
    x = TestTensor(shape, x_strides, dtype, device)
    if inplace == Inplace.INPLACE_X:
        if x_strides != y_strides:
            return
        y = x
    else:
        y = TestTensor(shape, y_strides, dtype, device)

    print(
        f"Testing Rotary Positional Embedding on {InfiniDeviceNames[device]} with shape:{shape} x_strides:{x_strides} y_strides:{y_strides} and dtype:{InfiniDtypeNames[dtype]} inplace:{inplace} algo:{algo}"
    )
    theta = 1e5
    pos = TestTensor.from_torch(torch.arange(0, x.shape[0]), InfiniDtype.I32, device)
    sin_table, cos_table = sin_cos_table(
        pos.torch_tensor(), x.shape[2], x.device, theta, dtype
    )

    rotary_embedding(
        y.torch_tensor(),
        x.torch_tensor(),
        sin_table.torch_tensor(),
        cos_table.torch_tensor(),
        device,
        algo,
    )

    descriptor = infiniopOperatorDescriptor_t()

    if sync is not None:
        sync()

    check_error(
        LIBINFINIOP.infiniopCreateRoPEDescriptor(
            handle,
            ctypes.byref(descriptor),
            y.descriptor,
            x.descriptor,
            pos.descriptor,
            sin_table.descriptor,
            cos_table.descriptor,
            algo.value,
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    for tensor in [y, x, pos, sin_table, cos_table]:
        tensor.destroy_desc()

    workspace_size = c_uint64(0)
    check_error(
        LIBINFINIOP.infiniopGetRoPEWorkspaceSize(
            descriptor, ctypes.byref(workspace_size)
        )
    )
    workspace = TestWorkspace(workspace_size.value, x.device)

    def lib_rope():
        check_error(
            LIBINFINIOP.infiniopRoPE(
                descriptor,
                workspace.data(),
                workspace_size.value,
                y.data(),
                x.data(),
                pos.data(),
                sin_table.data(),
                cos_table.data(),
                None,
            )
        )

    lib_rope()

    if sync is not None:
        sync()

    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)
    assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)

    if PROFILE:
        profile_operation(
            "PyTorch",
            lambda: rotary_embedding(
                y.torch_tensor(),
                x.torch_tensor(),
                sin_table.torch_tensor(),
                cos_table.torch_tensor(),
                device,
                algo,
            ),
            device,
            NUM_PRERUN,
            NUM_ITERATIONS,
        )
        profile_operation(
            "    lib", lambda: lib_rope(), device, NUM_PRERUN, NUM_ITERATIONS
        )

    check_error(LIBINFINIOP.infiniopDestroyRoPEDescriptor(descriptor))


if __name__ == "__main__":
    args = get_args()

    # Configure testing options
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    # Execute tests
    for device in get_test_devices(args):
        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")