Unverified commit 89e49e31 authored by PanZezhong1725, committed by GitHub

Merge pull request #58 from PanZezhong1725/issue/48-test

issue/48/test: refactor the RoPE test script

Parents: b3941ede eb1ae658
 import ctypes
-from ctypes import c_float, POINTER, c_void_p, c_int32, c_uint64, Structure, byref
+from ctypes import POINTER, c_void_p, c_int32, c_uint64, Structure, byref
 import sys
 import os

 sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))

-from operatorspy import (
-    open_lib,
-    to_tensor,
-    DeviceEnum,
+from libinfiniop import (
     infiniopHandle_t,
     infiniopTensorDescriptor_t,
-    create_handle,
-    destroy_handle,
+    open_lib,
+    to_tensor,
+    get_test_devices,
     check_error,
-    rearrange_tensor,
+    rearrange_if_needed,
     create_workspace,
-    U64,
+    test_operator,
+    get_args,
+    debug,
+    profile_operation,
+    InfiniDtype,
 )
-from operatorspy.tests.test_utils import get_args
 import torch

+DEBUG = False
+PROFILE = False
+NUM_PRERUN = 10
+NUM_ITERATIONS = 1000

 class RoPEDescriptor(Structure):
     _fields_ = [("device", c_int32)]
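A subtlety of the new `DEBUG`/`PROFILE`/`NUM_PRERUN`/`NUM_ITERATIONS` globals: they are rebound under `__main__` (end of this diff) after `test` is defined, which works because Python resolves module globals at call time, not at definition time:

```python
FLAG = False

def report():
    print(FLAG)  # looked up in module globals when called

FLAG = True  # module-level rebinding, like DEBUG = args.debug under __main__
report()     # prints True
```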
@@ -40,15 +44,21 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):

 def rotary_embedding(t, pos, theta, torch_device):
     dh = t.shape[2]
-    freqs = (1.0 / (theta ** (torch.arange(0, dh, 2)[: (dh // 2)].float() / dh))).to(
-        torch_device
-    )
-    freqs = torch.outer(pos, freqs)
-    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
-    t_ = torch.view_as_complex(t.reshape(*t.shape[:-1], -1, 2))
-    freqs_cis = reshape_for_broadcast(freqs_cis, t_)
-    t_out = torch.view_as_real(t_ * freqs_cis).flatten(2).to(t.dtype)
+    assert dh % 2 == 0, "Embedding dimension must be even."
+    t_even = t[..., 0::2]  # [seq_len, n_head, dh // 2]
+    t_odd = t[..., 1::2]  # [seq_len, n_head, dh // 2]
+    freqs = (1.0 / (theta ** (torch.arange(0, dh, 2).float() / dh))).to(torch_device)
+    freqs = torch.outer(pos, freqs)  # [seq_len, dh // 2]
+    cos = torch.cos(freqs).unsqueeze(1)  # [seq_len, 1, dh // 2]
+    sin = torch.sin(freqs).unsqueeze(1)  # [seq_len, 1, dh // 2]
+    t_out_even = t_even * cos - t_odd * sin
+    t_out_odd = t_even * sin + t_odd * cos
+    t_out = torch.empty_like(t)
+    t_out[..., 0::2] = t_out_even
+    t_out[..., 1::2] = t_out_odd
     return t_out
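The rewritten reference rotates even/odd channel pairs explicitly rather than via complex multiplication as the deleted lines did. The two formulations agree; a standalone cross-check sketch, with `rotary_embedding_complex` re-created from the removed code (the broadcast is inlined instead of calling `reshape_for_broadcast`):

```python
import torch

def rotary_embedding_complex(t, pos, theta):
    # Removed formulation: treat (even, odd) channel pairs as complex
    # numbers and multiply by the unit complex number e^{i * pos * freq}.
    dh = t.shape[2]
    freqs = 1.0 / (theta ** (torch.arange(0, dh, 2).float() / dh))
    angles = torch.outer(pos.float(), freqs)                  # [seq_len, dh // 2]
    freqs_cis = torch.polar(torch.ones_like(angles), angles)  # unit complex numbers
    t_ = torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))
    t_out = torch.view_as_real(t_ * freqs_cis.unsqueeze(1))   # broadcast over heads
    return t_out.flatten(2).to(t.dtype)

t = torch.rand(3, 4, 8)
pos = torch.arange(3)
ref = rotary_embedding_complex(t, pos, 1e4)
# Should agree with rotary_embedding(t, pos, 1e4, "cpu") from the script above:
# torch.testing.assert_close(rotary_embedding(t, pos, 1e4, "cpu"), ref)
```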
@@ -71,29 +81,23 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
     )
     t = torch.rand(shape, dtype=dtype)
-    if strides is not None:
-        t = rearrange_tensor(t, strides)
-    posTmp = torch.arange(0, t.shape[0])
+    t = rearrange_if_needed(t, strides).to(torch_device)
+    posTmp = torch.arange(0, t.shape[0]).to(torch_device)
     pos = torch.zeros(2 * posTmp.shape[0], dtype=torch.int32)
     for i in range(posTmp.shape[0]):
         pos[2 * i] = posTmp[i]
         pos[2 * i + 1] = 0
+    pos = pos.to(torch_device)
     theta = 1e4
-    if torch_device == "mlu" or torch_device == "npu":
-        ans = rotary_embedding(t, posTmp, theta, "cpu").to(torch_device)
-        pos = pos.to(torch_device)
-        t = t.to(torch_device)
-    else:
-        t = t.to(torch_device)
-        pos = pos.to(torch_device)
-        ans = rotary_embedding(t, posTmp.to(torch_device), theta, torch_device)
+    ans = rotary_embedding(t, posTmp, theta, torch_device)
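The interleaved `pos` buffer above pairs each position with a zero, and the descriptor's dtype is relabeled to `InfiniDtype.U64` a few lines below. Our reading, an assumption the diff itself does not spell out: on a little-endian machine, two 32-bit words `[p, 0]` occupy the same bytes as one 64-bit integer `p`, so the library consumes the buffer as a plain uint64 position table. A minimal demonstration of that reinterpretation:

```python
import torch

pos32 = torch.tensor([5, 0, 9, 0, 2, 0], dtype=torch.int32)
pos64 = pos32.view(torch.int64)  # reinterpret the same bytes, no copy
print(pos64)  # tensor([5, 9, 2]) on a little-endian machine
```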
     descriptor = infiniopRoPEDescriptor_t()

     # 2x table length for test
     sin_table, cos_table = sin_cos_table(t.shape[0] * 2, t.shape[2], t.device, theta)
     t_tensor = to_tensor(t, lib)
     pos_tensor = to_tensor(pos[: t.shape[0]], lib)
-    pos_tensor.descriptor.contents.dt = U64
+    pos_tensor.descriptor.contents.dtype = InfiniDtype.U64
     sin_table_tensor = to_tensor(sin_table, lib)
     cos_table_tensor = to_tensor(cos_table, lib)
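`sin_cos_table` is defined outside this hunk. A sketch of what a helper with this call shape plausibly computes, assuming tables of shape `[max_pos, dh // 2]` matching the `rotary_embedding` reference above (an assumption, not the project's actual helper):

```python
import torch

def sin_cos_table_sketch(max_pos, dh, device, theta):
    # Hypothetical layout: one row of dh // 2 angles per position.
    freqs = 1.0 / (theta ** (torch.arange(0, dh, 2, device=device).float() / dh))
    angles = torch.outer(torch.arange(max_pos, device=device).float(), freqs)
    return torch.sin(angles), torch.cos(angles)
```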
@@ -122,69 +126,52 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
         lib.infiniopGetRoPEWorkspaceSize(descriptor, ctypes.byref(workspace_size))
     )
     workspace = create_workspace(workspace_size.value, t.device)
-    check_error(
-        lib.infiniopRoPE(
-            descriptor,
-            workspace.data_ptr() if workspace is not None else None,
-            workspace_size.value,
-            t_tensor.data,
-            pos_tensor.data,
-            sin_table_tensor.data,
-            cos_table_tensor.data,
-            None,
-        )
-    )
+
+    def lib_rope():
+        check_error(
+            lib.infiniopRoPE(
+                descriptor,
+                workspace.data_ptr() if workspace is not None else None,
+                workspace_size.value,
+                t_tensor.data,
+                pos_tensor.data,
+                sin_table_tensor.data,
+                cos_table_tensor.data,
+                None,
+            )
+        )
+
+    lib_rope()
+
+    if DEBUG:
+        debug(t, ans, atol=1e-4, rtol=1e-2)
     assert torch.allclose(t, ans, atol=1e-4, rtol=1e-2)
-    check_error(lib.infiniopDestroyRoPEDescriptor(descriptor))
-
-
-def test_cpu(lib, test_cases):
-    device = DeviceEnum.DEVICE_CPU
-    handle = create_handle(lib, device)
-    for shape, strides, dtype in test_cases:
-        test(lib, handle, "cpu", shape, strides, dtype)
-    destroy_handle(lib, handle)
-
-
-def test_cuda(lib, test_cases):
-    device = DeviceEnum.DEVICE_CUDA
-    handle = create_handle(lib, device)
-    for shape, strides, dtype in test_cases:
-        test(lib, handle, "cuda", shape, strides, dtype)
-    destroy_handle(lib, handle)
-
-
-def test_bang(lib, test_cases):
-    import torch_mlu
-
-    device = DeviceEnum.DEVICE_BANG
-    handle = create_handle(lib, device)
-    for shape, strides, dtype in test_cases:
-        test(lib, handle, "mlu", shape, strides, dtype)
-    destroy_handle(lib, handle)
-
-
-def test_ascend(lib, test_cases):
-    import torch_npu
-
-    device = DeviceEnum.DEVICE_ASCEND
-    handle = create_handle(lib, device)
-    for shape, strides, dtype in test_cases:
-        test(lib, handle, "npu", shape, strides, dtype)
-    destroy_handle(lib, handle)
+
+    if PROFILE:
+        profile_operation(
+            "PyTorch",
+            lambda: rotary_embedding(t, posTmp, theta, torch_device),
+            torch_device,
+            NUM_PRERUN,
+            NUM_ITERATIONS,
+        )
+        profile_operation(
+            " lib", lambda: lib_rope(), torch_device, NUM_PRERUN, NUM_ITERATIONS
+        )
+
+    check_error(lib.infiniopDestroyRoPEDescriptor(descriptor))
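`profile_operation` is another `libinfiniop` helper whose body is not part of this diff. A minimal sketch consistent with its call signature here (`name, fn, device, num_prerun, num_iterations`) — assumed behavior, not the actual implementation:

```python
import time
import torch

def profile_operation_sketch(name, fn, device, num_prerun, num_iterations):
    for _ in range(num_prerun):  # warm up caches and kernel launches
        fn()
    if device == "cuda":
        torch.cuda.synchronize()  # flush pending async work before timing
    start = time.perf_counter()
    for _ in range(num_iterations):
        fn()
    if device == "cuda":
        torch.cuda.synchronize()
    elapsed = (time.perf_counter() - start) / num_iterations
    print(f"{name} time: {elapsed * 1000:.3f} ms")
```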
 if __name__ == "__main__":
     test_cases = [
-        ((1, 32, 128), None, torch.float16),
-        ((1, 32, 64), None, torch.float16),
+        # (t_shape, t_strides)
+        ((1, 32, 128), None),
+        ((1, 32, 64), None),
         # Ascend does not yet pass this case: a last dimension <= 32 fails,
         # possibly due to the internal implementation of its core GatherMask
         # interface. Last dimensions of 48, 64, and 128 all work.
-        ((4, 1, 32), None, torch.float16),
-        ((1, 32, 128), None, torch.float16),
-        ((3, 32, 128), (8000, 200, 1), torch.float16),
+        ((4, 1, 32), None),
+        ((1, 32, 128), None),
+        ((3, 32, 128), (8000, 200, 1)),
     ]
+    test_dtypes = [torch.float16]
     args = get_args()
     lib = open_lib()
     lib.infiniopCreateRoPEDescriptor.restype = c_int32
@@ -216,14 +203,13 @@ if __name__ == "__main__":
     lib.infiniopDestroyRoPEDescriptor.argtypes = [
         infiniopRoPEDescriptor_t,
     ]
-    if args.cpu:
-        test_cpu(lib, test_cases)
-    if args.cuda:
-        test_cuda(lib, test_cases)
-    if args.bang:
-        test_bang(lib, test_cases)
-    if args.ascend:
-        test_ascend(lib, test_cases)
-    if not (args.cpu or args.cuda or args.bang or args.ascend):
-        test_cpu(lib, test_cases)
+    # Configure testing options
+    DEBUG = args.debug
+    PROFILE = args.profile
+    NUM_PRERUN = args.num_prerun
+    NUM_ITERATIONS = args.num_iterations
+
+    # Execute tests
+    for device in get_test_devices(args):
+        test_operator(lib, device, test, test_cases, test_dtypes)

     print("\033[92mTest passed!\033[0m")