Commit eb1ae658 authored by PanZezhong

issue/48/test: Refactor the RoPE test script

parent b3941ede
import ctypes
from ctypes import POINTER, c_void_p, c_int32, c_uint64, Structure, byref
import sys
import os

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))

from libinfiniop import (
    infiniopHandle_t,
    infiniopTensorDescriptor_t,
    create_handle,
    destroy_handle,
    open_lib,
    to_tensor,
    get_test_devices,
    check_error,
    rearrange_if_needed,
    create_workspace,
    test_operator,
    get_args,
    debug,
    profile_operation,
    InfiniDtype,
)
import torch

DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000


class RoPEDescriptor(Structure):
    _fields_ = [("device", c_int32)]
@@ -40,15 +44,21 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
def rotary_embedding(t, pos, theta, torch_device):
    dh = t.shape[2]
    assert dh % 2 == 0, "Embedding dimension must be even."
    # Split the last dimension into interleaved even/odd halves; each
    # (even, odd) pair is rotated by a position-dependent angle.
    t_even = t[..., 0::2]  # [seq_len, n_head, dh // 2]
    t_odd = t[..., 1::2]  # [seq_len, n_head, dh // 2]
    freqs = (1.0 / (theta ** (torch.arange(0, dh, 2).float() / dh))).to(torch_device)
    freqs = torch.outer(pos, freqs)  # [seq_len, dh // 2]
    cos = torch.cos(freqs).unsqueeze(1)  # [seq_len, 1, dh // 2]
    sin = torch.sin(freqs).unsqueeze(1)  # [seq_len, 1, dh // 2]
    t_out_even = t_even * cos - t_odd * sin
    t_out_odd = t_even * sin + t_odd * cos
    t_out = torch.empty_like(t)
    t_out[..., 0::2] = t_out_even
    t_out[..., 1::2] = t_out_odd
    return t_out
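
Before this commit, the reference used a complex-number formulation (torch.polar with view_as_complex); the refactor computes the same rotation with explicit even/odd slices. A minimal standalone sketch (independent of the test script) checking that the two formulations agree:

import torch

seq_len, n_head, dh, theta = 4, 2, 8, 1e4
t = torch.rand(seq_len, n_head, dh)
pos = torch.arange(seq_len).float()
freqs = torch.outer(pos, 1.0 / (theta ** (torch.arange(0, dh, 2).float() / dh)))

# Complex form: view adjacent (even, odd) pairs as complex numbers and
# multiply by e^{i * freqs}.
t_c = torch.view_as_complex(t.reshape(seq_len, n_head, dh // 2, 2))
rot = torch.polar(torch.ones_like(freqs), freqs).unsqueeze(1)  # [seq_len, 1, dh // 2]
ref = torch.view_as_real(t_c * rot).flatten(2)

# Explicit form, as in rotary_embedding above.
cos, sin = torch.cos(freqs).unsqueeze(1), torch.sin(freqs).unsqueeze(1)
out = torch.empty_like(t)
out[..., 0::2] = t[..., 0::2] * cos - t[..., 1::2] * sin
out[..., 1::2] = t[..., 0::2] * sin + t[..., 1::2] * cos

assert torch.allclose(out, ref, atol=1e-5)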
@@ -71,29 +81,23 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
    )
    t = torch.rand(shape, dtype=dtype)
    t = rearrange_if_needed(t, strides).to(torch_device)
    posTmp = torch.arange(0, t.shape[0]).to(torch_device)
    # Pack each position as a little-endian (value, 0) int32 pair, so that
    # the 2n-element int32 buffer holds the same bytes as n uint64 positions.
    pos = torch.zeros(2 * posTmp.shape[0], dtype=torch.int32)
    for i in range(posTmp.shape[0]):
        pos[2 * i] = posTmp[i]
        pos[2 * i + 1] = 0
    pos = pos.to(torch_device)
    theta = 1e4
    ans = rotary_embedding(t, posTmp, theta, torch_device)

    descriptor = infiniopRoPEDescriptor_t()

    # 2x table length for test
    sin_table, cos_table = sin_cos_table(t.shape[0] * 2, t.shape[2], t.device, theta)

    t_tensor = to_tensor(t, lib)
    # Shape-(n,) int32 view over the packed buffer; overriding the descriptor
    # dtype to U64 makes the library read it as n uint64 positions.
    pos_tensor = to_tensor(pos[: t.shape[0]], lib)
    pos_tensor.descriptor.contents.dtype = InfiniDtype.U64
    sin_table_tensor = to_tensor(sin_table, lib)
    cos_table_tensor = to_tensor(cos_table, lib)
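
The packed-position trick above depends only on byte layout, not on any libinfiniop API, so it can be sanity-checked in isolation. A small sketch, assuming a little-endian host and using int64 as a stand-in for the descriptor's U64 dtype:

import torch

# Positions 5 and 7 packed as (value, 0) int32 pairs.
pos32 = torch.tensor([5, 0, 7, 0], dtype=torch.int32)

# Reinterpret the same 16 bytes as two 64-bit integers: on a little-endian
# machine the low word is the position and the high word is zero.
pos64 = pos32.view(torch.int64)
assert pos64.tolist() == [5, 7]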
@@ -122,69 +126,52 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
        lib.infiniopGetRoPEWorkspaceSize(descriptor, ctypes.byref(workspace_size))
    )
    workspace = create_workspace(workspace_size.value, t.device)

    def lib_rope():
        check_error(
            lib.infiniopRoPE(
                descriptor,
                workspace.data_ptr() if workspace is not None else None,
                workspace_size.value,
                t_tensor.data,
                pos_tensor.data,
                sin_table_tensor.data,
                cos_table_tensor.data,
                None,
            )
        )

    lib_rope()

    if DEBUG:
        debug(t, ans, atol=1e-4, rtol=1e-2)
    assert torch.allclose(t, ans, atol=1e-4, rtol=1e-2)

    if PROFILE:
        profile_operation(
            "PyTorch",
            lambda: rotary_embedding(t, posTmp, theta, torch_device),
            torch_device,
            NUM_PRERUN,
            NUM_ITERATIONS,
        )
        profile_operation(
            "    lib", lambda: lib_rope(), torch_device, NUM_PRERUN, NUM_ITERATIONS
        )
    check_error(lib.infiniopDestroyRoPEDescriptor(descriptor))
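
profile_operation is imported from libinfiniop and its body is not part of this diff. Purely as a hypothetical sketch of the warmup-then-time pattern its call sites above imply, with the same (name, fn, device, num_prerun, num_iterations) parameters, one could picture something like:

import time
import torch

def profile_operation_sketch(name, fn, device, num_prerun, num_iterations):
    for _ in range(num_prerun):  # warm up allocator, caches, device queues
        fn()
    if device == "cuda":
        torch.cuda.synchronize()  # drain async work before timing
    start = time.perf_counter()
    for _ in range(num_iterations):
        fn()
    if device == "cuda":
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    print(f"{name}: {elapsed / num_iterations * 1e6:.3f} us per iteration")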
if __name__ == "__main__":
    test_cases = [
        # (t_shape, t_strides)
        ((1, 32, 128), None),
        ((1, 32, 64), None),
        # Ascend cannot handle this case yet: a last dimension <= 32 fails,
        # possibly due to the internal implementation of its core GatherMask
        # interface. Last dimensions of 48, 64, and 128 are all supported.
        ((4, 1, 32), None),
        ((1, 32, 128), None),
        ((3, 32, 128), (8000, 200, 1)),
    ]
    test_dtypes = [torch.float16]
    args = get_args()
    lib = open_lib()
    lib.infiniopCreateRoPEDescriptor.restype = c_int32
@@ -216,14 +203,13 @@ if __name__ == "__main__":
    lib.infiniopDestroyRoPEDescriptor.argtypes = [
        infiniopRoPEDescriptor_t,
    ]

    # Configure testing options
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    # Execute tests
    for device in get_test_devices(args):
        test_operator(lib, device, test, test_cases, test_dtypes)

    print("\033[92mTest passed!\033[0m")