import torch
import ctypes
from ctypes import c_uint64
from libinfiniop import (
    LIBINFINIOP,
    TestTensor,
    get_test_devices,
    check_error,
    test_operator,
    get_args,
    debug,
    get_tolerance,
    profile_operation,
    TestWorkspace,
    InfiniDtype,
    InfiniDtypeNames,
    InfiniDeviceEnum,
    InfiniDeviceNames,
    infiniopOperatorDescriptor_t,
)
from enum import Enum, auto

# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES_ = [
    # (shape, x_strides, y_strides)
    ((1, 32, 128), None, None),
    ((10, 32, 64), None, None),
    # Ascend does not support this case yet: a last dimension <= 32 fails,
    # likely due to the internal implementation of its core GatherMask
    # interface; last dimensions of 48, 64, and 128 currently work.
    ((4, 1, 32), (64, 64, 1), None),
    ((11, 33, 128), None, (8000, 200, 1)),
    ((3, 32, 128), (8000, 200, 1), (7000, 128, 1)),
]

# Data types used for testing
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32]

# Tolerance map for different data types
_TOLERANCE_MAP = {
    InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-2},
    InfiniDtype.BF16: {"atol": 5e-3, "rtol": 5e-2},
    InfiniDtype.F32: {"atol": 1e-4, "rtol": 1e-3},
}


class Inplace(Enum):
    OUT_OF_PLACE = auto()
    INPLACE_X = auto()


class Algorithm(Enum):
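    # Rotary layout variants, passed to the library as algo.value: GPT_J rotates
    # interleaved even/odd element pairs, GPT_NEOX rotates pairs formed from the
    # first and second halves of the head dimension.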
    GPT_J = 0
    GPT_NEOX = 1


_INPLACE = [
    Inplace.OUT_OF_PLACE,
    Inplace.INPLACE_X,
]


_ALGO = [
    Algorithm.GPT_J,
    Algorithm.GPT_NEOX,
]

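# Full test matrix: every (shape, x_strides, y_strides) case combined with
# every in-place mode and every rotary algorithm.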
_TEST_CASES = [
    test_case + (inplace_item, algo_item)
    for test_case in _TEST_CASES_
    for inplace_item in _INPLACE
    for algo_item in _ALGO
]

DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000


def rotary_embedding(ans, t, sin, cos, device, algo):
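    # PyTorch reference for RoPE: each (t1, t2) pair is rotated by its positional
    # angle, i.e. t1' = t1 * cos - t2 * sin and t2' = t1 * sin + t2 * cos.
    # GPT_J pairs interleaved even/odd lanes of the last dimension, while
    # GPT_NEOX pairs the first and second halves.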
    def _torch_rope(sin, cos, t1, t2):
        cos = cos.unsqueeze(1)  # [seq_len, 1, dh // 2]
        sin = sin.unsqueeze(1)  # [seq_len, 1, dh // 2]
        if device == InfiniDeviceEnum.CPU:
            (t1, t2, cos, sin) = (
                t1.float(),
                t2.float(),
                cos.float(),
                sin.float(),
            )

        t_out_1 = t1 * cos - t2 * sin
        t_out_2 = t1 * sin + t2 * cos

        return t_out_1, t_out_2


    dh = t.shape[-1]
    dt = t.dtype
    assert dh % 2 == 0, "Embedding dimension must be even."

    if algo == Algorithm.GPT_J:
        t_even = t[..., 0::2]  # [seq_len, n_head, dh // 2]
        t_odd = t[..., 1::2]  # [seq_len, n_head, dh // 2]

        t_out_even, t_out_odd = _torch_rope(sin, cos, t_even, t_odd)

        ans[..., 0::2] = t_out_even.to(dt)
        ans[..., 1::2] = t_out_odd.to(dt)
    else:
        half_dim = dh // 2
        t_first = t[..., :half_dim]
        t_second = t[..., half_dim:]

        t_out_first, t_out_second = _torch_rope(sin, cos, t_first, t_second)

        ans[..., :half_dim] = t_out_first.to(dt)
        ans[..., half_dim:] = t_out_second.to(dt)


def sin_cos_table(pos, dim, device, theta, dtype):
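    # Standard RoPE tables: freqs[i] = theta^(-2i / dim) for i in [0, dim / 2),
    # angles = outer(pos, freqs); returns sin(angles) and cos(angles) as
    # TestTensors on the target device.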
    assert dim % 2 == 0, "Embedding dimension must be even."
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    angles = torch.outer(pos.cpu(), freqs)
    return (
        TestTensor.from_torch(torch.sin(angles), dtype, device),
        TestTensor.from_torch(torch.cos(angles), dtype, device),
    )


def test(
    handle,
    device,
    shape,
    x_strides=None,
    y_strides=None,
    inplace=Inplace.OUT_OF_PLACE,
    algo=Algorithm.GPT_J,
    dtype=InfiniDtype.F32,
    sync=None,
):
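    # Build x (and y, unless running in place), compute the PyTorch reference
    # into y's torch tensor, run the library RoPE into y's device buffer, and
    # compare the two within the dtype-specific tolerances.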
    x = TestTensor(shape, x_strides, dtype, device)
    if inplace == Inplace.INPLACE_X:
        if x_strides != y_strides:
            return
        y = x
    else:
        y = TestTensor(shape, y_strides, dtype, device)

    print(
        f"Testing Rotary Positional Embedding on {InfiniDeviceNames[device]} with shape:{shape} x_strides:{x_strides} y_strides:{y_strides} and dtype:{InfiniDtypeNames[dtype]} inplace:{inplace} algo:{algo}"
    )
    theta = 1e5
    pos = TestTensor.from_torch(torch.arange(0, x.shape[0]), InfiniDtype.I32, device)
    sin_table, cos_table = sin_cos_table(
        pos.torch_tensor(), x.shape[2], x.device, theta, dtype
    )

    rotary_embedding(
        y.torch_tensor(),
        x.torch_tensor(),
        sin_table.torch_tensor(),
        cos_table.torch_tensor(),
        device,
        algo,
    )

    descriptor = infiniopOperatorDescriptor_t()

    if sync is not None:
        sync()

    check_error(
        LIBINFINIOP.infiniopCreateRoPEDescriptor(
            handle,
            ctypes.byref(descriptor),
            y.descriptor,
            x.descriptor,
            pos.descriptor,
            sin_table.descriptor,
            cos_table.descriptor,
            algo.value,
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    for tensor in [y, x, pos, sin_table, cos_table]:
        tensor.destroy_desc()

    workspace_size = c_uint64(0)
    check_error(
        LIBINFINIOP.infiniopGetRoPEWorkspaceSize(
            descriptor, ctypes.byref(workspace_size)
        )
    )
    workspace = TestWorkspace(workspace_size.value, x.device)

    def lib_rope():
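        # Invoke the library RoPE kernel; wrapped in a closure so the same call
        # can be reused for profiling below.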
        check_error(
            LIBINFINIOP.infiniopRoPE(
                descriptor,
                workspace.data(),
                workspace_size.value,
                y.data(),
                x.data(),
                pos.data(),
                sin_table.data(),
                cos_table.data(),
                None,
            )
        )

    lib_rope()

    if sync is not None:
        sync()

    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)
    assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)

    if PROFILE:
        profile_operation(
            "PyTorch",
            lambda: rotary_embedding(
                y.torch_tensor(),
                x.torch_tensor(),
                sin_table.torch_tensor(),
                cos_table.torch_tensor(),
                device,
                algo,
            ),
            device,
            NUM_PRERUN,
            NUM_ITERATIONS,
        )
        profile_operation(
            "    lib", lambda: lib_rope(), device, NUM_PRERUN, NUM_ITERATIONS
        )

    check_error(LIBINFINIOP.infiniopDestroyRoPEDescriptor(descriptor))


if __name__ == "__main__":
    args = get_args()

    # Configure testing options
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    # Execute tests
    for device in get_test_devices(args):
        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")