Unverified Commit 3c31dc6c authored by PanZezhong1725, committed by GitHub

Merge pull request #45 from YdrMaster/main

issue/52 Code formatting: mechanism and effect
parents 16dad776 e5ed9fa1
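The diff below is one mechanical formatting pass over the Python test suite: imports split one per line, single quotes normalized to double quotes, spaces added around binary operators, long calls wrapped one argument per line, and `# fmt: off` / `# fmt: on` guards placed around lines that should stay long. Those guards are Black's skip directives, so the formatter is presumably Black; a minimal sketch of the mechanism (input string hypothetical):

```python
# A minimal sketch, assuming Black is the formatter behind this PR
# (the "# fmt: off/on" guards in the diff are Black directives).
import black

src = "x = {'atol': 0, 'rtol': 1e-2}\ny = 1.0/8.0\n"
# format_str applies the same rewrites seen throughout the diff:
# quote normalization and spaces around binary operators.
print(black.format_str(src, mode=black.Mode()))
# x = {"atol": 0, "rtol": 1e-2}
# y = 1.0 / 8.0
```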
@@ -2,9 +2,19 @@ import torch
import ctypes
from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
from libinfiniop import (
-    infiniopHandle_t, infiniopTensorDescriptor_t, open_lib, to_tensor, get_test_devices,
-    check_error, rearrange_if_needed, create_workspace, test_operator, get_args,
-    debug, get_tolerance, profile_operation,
+    infiniopHandle_t,
+    infiniopTensorDescriptor_t,
+    open_lib,
+    to_tensor,
+    get_test_devices,
+    check_error,
+    rearrange_if_needed,
+    create_workspace,
+    test_operator,
+    get_args,
+    debug,
+    get_tolerance,
+    profile_operation,
)
# ==============================================================================
@@ -21,8 +31,8 @@ _TEST_CASES = [
(1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1)),
(1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1)),
(1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1)),
-    (1.0/8.0, 0.0, (4, 8*6, 64), (4, 64, 6), (4, 8*6, 6), None, None, None),
-    (1.0/8.0, 0.0, (4, 8*6, 64), (4, 64, 6), (4, 8*6, 6), None, None, None),
+    (1.0 / 8.0, 0.0, (4, 8 * 6, 64), (4, 64, 6), (4, 8 * 6, 6), None, None, None),
+    (1.0 / 8.0, 0.0, (4, 8 * 6, 64), (4, 64, 6), (4, 8 * 6, 6), None, None, None),
]
# Data types used for testing
@@ -30,8 +40,8 @@ _TENSOR_DTYPES = [torch.float16, torch.float32]
# Tolerance map for different data types
_TOLERANCE_MAP = {
-    torch.float16: {'atol': 0, 'rtol': 1e-2},
-    torch.float32: {'atol': 0, 'rtol': 1e-3},
+    torch.float16: {"atol": 0, "rtol": 1e-2},
+    torch.float32: {"atol": 0, "rtol": 1e-3},
}
DEBUG = False
@@ -39,6 +49,7 @@ PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
# ==============================================================================
# Definitions
# ==============================================================================
@@ -48,6 +59,7 @@ class MatmulDescriptor(Structure):
infiniopMatmulDescriptor_t = POINTER(MatmulDescriptor)
# PyTorch implementation for matrix multiplication
def matmul(_c, beta, _a, _b, alpha):
a, b, c = _a.clone(), _b.clone(), _c.clone()
@@ -55,6 +67,7 @@ def matmul(_c, beta, _a, _b, alpha):
fp32_result = torch.matmul(a.to(torch.float32), b.to(torch.float32))
return alpha * fp32_result.to(result_dtype) + beta * c
# The argument list should be (lib, handle, torch_device, <param list>, dtype)
# The <param list> should keep the same order as the one specified in _TEST_CASES
def test(
@@ -85,7 +98,10 @@ def test(
# Compute the PyTorch reference result
ans = matmul(c, beta, a, b, alpha)
-    a, b, c = [rearrange_if_needed(tensor, stride) for tensor, stride in zip([a, b, c], [a_stride, b_stride, c_stride])]
+    a, b, c = [
+        rearrange_if_needed(tensor, stride)
+        for tensor, stride in zip([a, b, c], [a_stride, b_stride, c_stride])
+    ]
a_tensor, b_tensor, c_tensor = [to_tensor(tensor, lib) for tensor in [a, b, c]]
descriptor = infiniopMatmulDescriptor_t()
@@ -95,7 +111,7 @@ def test(
ctypes.byref(descriptor),
c_tensor.descriptor,
a_tensor.descriptor,
-            b_tensor.descriptor
+            b_tensor.descriptor,
)
)
@@ -105,22 +121,27 @@ def test(
# Get workspace size and create workspace
workspace_size = c_uint64(0)
-    check_error(lib.infiniopGetMatmulWorkspaceSize(descriptor, ctypes.byref(workspace_size)))
+    check_error(
+        lib.infiniopGetMatmulWorkspaceSize(descriptor, ctypes.byref(workspace_size))
+    )
workspace = create_workspace(workspace_size.value, a.device)
# Execute infiniop matmul operator
def lib_matmul():
-        check_error(lib.infiniopMatmul(
-            descriptor,
-            workspace.data_ptr() if workspace is not None else None,
-            workspace_size.value,
-            c_tensor.data,
-            a_tensor.data,
-            b_tensor.data,
-            alpha,
-            beta,
-            None,
-        ))
+        check_error(
+            lib.infiniopMatmul(
+                descriptor,
+                workspace.data_ptr() if workspace is not None else None,
+                workspace_size.value,
+                c_tensor.data,
+                a_tensor.data,
+                b_tensor.data,
+                alpha,
+                beta,
+                None,
+            )
+        )
lib_matmul()
# Validate results
@@ -131,9 +152,10 @@ def test(
# Profiling workflow
if PROFILE:
+        # fmt: off
profile_operation("PyTorch", lambda: matmul(c, beta, a, b, alpha), torch_device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lambda: lib_matmul(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
+        # fmt: on
check_error(lib.infiniopDestroyMatmulDescriptor(descriptor))
@@ -150,7 +172,7 @@ if __name__ == "__main__":
POINTER(infiniopMatmulDescriptor_t),
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
-        infiniopTensorDescriptor_t
+        infiniopTensorDescriptor_t,
]
lib.infiniopGetMatmulWorkspaceSize.restype = c_int32
......
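For reference, the semantics this matmul test validates, as a standalone sketch (shapes taken from the test cases above; `matmul_ref` is a hypothetical name):

```python
# Standalone sketch of the reference the test compares against:
# C_out = alpha * (A @ B) + beta * C, accumulated in float32.
import torch

def matmul_ref(c, beta, a, b, alpha):
    fp32 = torch.matmul(a.to(torch.float32), b.to(torch.float32))
    return alpha * fp32.to(c.dtype) + beta * c

a = torch.rand(6, 2048, dtype=torch.float16)
b = torch.rand(2048, 2560, dtype=torch.float16)
c = torch.rand(6, 2560, dtype=torch.float16)
out = matmul_ref(c, 1.0, a, b, 1.0)  # the (6, 2048) x (2048, 2560) case above
```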
@@ -35,7 +35,7 @@ class MaxPoolDescriptor(Structure):
infiniopMaxPoolDescriptor_t = POINTER(MaxPoolDescriptor)
-def pool(x, k, padding, stride, dilation = 1):
+def pool(x, k, padding, stride, dilation=1):
pooling_layers = {
1: torch.nn.MaxPool1d,
2: torch.nn.MaxPool2d,
@@ -66,18 +66,20 @@ def inferShape(x_shape, kernel_shape, padding, strides):
return x_shape[:2] + tuple(output_shape)
# convert a python tuple to a ctype void pointer
def tuple_to_void_p(py_tuple: Tuple):
array = ctypes.c_int64 * len(py_tuple)
data_array = array(*py_tuple)
return ctypes.cast(data_array, ctypes.c_void_p)
def test(
lib,
handle,
torch_device,
-    x_shape,
-    k_shape,
+    x_shape,
+    k_shape,
padding,
strides,
tensor_dtype=torch.float16,
@@ -87,7 +89,9 @@ def test(
)
x = torch.rand(x_shape, dtype=tensor_dtype).to(torch_device)
-    y = torch.rand(inferShape(x_shape, k_shape, padding, strides), dtype=tensor_dtype).to(torch_device)
+    y = torch.rand(
+        inferShape(x_shape, k_shape, padding, strides), dtype=tensor_dtype
+    ).to(torch_device)
for i in range(NUM_PRERUN if PROFILE else 1):
ans = pool(x, k_shape, padding, strides)
@@ -123,7 +127,9 @@ def test(
check_error(
lib.infiniopGetMaxPoolWorkspaceSize(descriptor, ctypes.byref(workspaceSize))
)
-    workspace = torch.zeros(int(workspaceSize.value), dtype=torch.uint8).to(torch_device)
+    workspace = torch.zeros(int(workspaceSize.value), dtype=torch.uint8).to(
+        torch_device
+    )
workspace_ptr = ctypes.cast(workspace.data_ptr(), ctypes.POINTER(ctypes.c_uint8))
for i in range(NUM_PRERUN if PROFILE else 1):
@@ -161,8 +167,10 @@ def test_cpu(lib, test_cases):
device = DeviceEnum.DEVICE_CPU
handle = create_handle(lib, device)
for x_shape, kernel_shape, padding, strides in test_cases:
+        # fmt: off
test(lib, handle, "cpu", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float16)
test(lib, handle, "cpu", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float32)
+        # fmt: on
destroy_handle(lib, handle)
@@ -170,8 +178,10 @@ def test_cuda(lib, test_cases):
device = DeviceEnum.DEVICE_CUDA
handle = create_handle(lib, device)
for x_shape, kernel_shape, padding, strides in test_cases:
+        # fmt: off
test(lib, handle, "cuda", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float16)
test(lib, handle, "cuda", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float32)
+        # fmt: on
destroy_handle(lib, handle)
@@ -181,8 +191,10 @@ def test_bang(lib, test_cases):
device = DeviceEnum.DEVICE_BANG
handle = create_handle(lib, device)
for x_shape, kernel_shape, padding, strides in test_cases:
+        # fmt: off
test(lib, handle, "mlu", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float16)
test(lib, handle, "mlu", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float32)
+        # fmt: on
destroy_handle(lib, handle)
......
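`inferShape` in this file keeps the batch and channel dims and applies the standard pooling arithmetic per spatial dim; a self-contained sketch of that rule (helper name hypothetical):

```python
# Sketch of the shape rule inferShape relies on (dilation = 1):
# out = floor((n + 2 * pad - k) / stride) + 1 per spatial dimension.
def infer_pool_shape(x_shape, kernel_shape, padding, strides):
    spatial = [
        (n + 2 * p - k) // s + 1
        for n, k, p, s in zip(x_shape[2:], kernel_shape, padding, strides)
    ]
    return x_shape[:2] + tuple(spatial)

assert infer_pool_shape((1, 3, 32, 32), (2, 2), (0, 0), (2, 2)) == (1, 3, 16, 16)
```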
@@ -30,13 +30,13 @@ infiniopRandomSampleDescriptor_t = POINTER(RandomSampleDescriptor)
def random_sample(data, random_val, topp, topk, voc, temperature, torch_device):
-    indices = torch.zeros([topk], dtype = torch.int64)
+    indices = torch.zeros([topk], dtype=torch.int64)
dataNp = data.clone().detach()
sorted_indices = torch.arange(voc)
for i in range(topk):
for j in range(i + 1, voc):
-            if(dataNp[i] < dataNp[j]):
+            if dataNp[i] < dataNp[j]:
tmp = dataNp[i].clone().detach()
dataNp[i] = dataNp[j].clone().detach()
dataNp[j] = tmp
@@ -44,48 +44,60 @@ def random_sample(data, random_val, topp, topk, voc, temperature, torch_device):
tmpInd = sorted_indices[i].clone().detach()
sorted_indices[i] = sorted_indices[j].clone().detach()
sorted_indices[j] = tmpInd
-    #sorted_indices = torch.argsort(dataNp, descending=True)
-    indices = sorted_indices[:topk]
+    # sorted_indices = torch.argsort(dataNp, descending=True)
+    indices = sorted_indices[:topk]
dataNp = dataNp[sorted_indices]
globalM = dataNp[0]
dataNp = (dataNp - globalM) / temperature
-    dataNp = torch.softmax(dataNp.float(), dim = 0)
+    dataNp = torch.softmax(dataNp.float(), dim=0)
sum_s = 0
for end in range(topk):
sum_s += dataNp[end]
-        if(sum_s >= topp):
+        if sum_s >= topp:
break
-    if(end < topk - 1):
+    if end < topk - 1:
end += 1
else:
end = topk
sum_s = 0
for i in range(end):
sum_s += dataNp[i]
random_val *= sum_s
sum_s = 0
for i in range(end):
sum_s += dataNp[i]
-        if(random_val < sum_s):
+        if random_val < sum_s:
return indices[i]
def random_sample_0(data):
return torch.argmax(data)
-def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_dtype=torch.float16):
-    print(
-        f"Testing RandomSample on {torch_device} with voc:{voc} dtype:{x_dtype}"
-    )
+def test(
+    lib,
+    handle,
+    torch_device,
+    voc,
+    random_val,
+    topp,
+    topk,
+    temperature,
+    x_dtype=torch.float16,
+):
+    print(f"Testing RandomSample on {torch_device} with voc:{voc} dtype:{x_dtype}")
data = torch.arange(voc).float() * 0.0001
_perm = torch.randperm(voc)
data = data[_perm].to(x_dtype).to(torch_device)
-    if(topp > 0 and topk > 1):
-        ans = random_sample(data.to("cpu"), random_val, topp, topk, voc, temperature, "cpu")
+    if topp > 0 and topk > 1:
+        ans = random_sample(
+            data.to("cpu"), random_val, topp, topk, voc, temperature, "cpu"
+        )
else:
ans = random_sample_0(data)
indices = torch.zeros([1], dtype=torch.int64).to(torch_device)
@@ -96,7 +108,10 @@ def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_
descriptor = infiniopRandomSampleDescriptor_t()
check_error(
lib.infiniopCreateRandomSampleDescriptor(
-            handle, ctypes.byref(descriptor), indices_tensor.descriptor, x_tensor.descriptor
+            handle,
+            ctypes.byref(descriptor),
+            indices_tensor.descriptor,
+            x_tensor.descriptor,
)
)
@@ -110,7 +125,7 @@ def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_
descriptor, ctypes.byref(workspace_size)
)
)
-    workspace = create_workspace(workspace_size.value, torch_device)
+    workspace = create_workspace(workspace_size.value, torch_device)
check_error(
lib.infiniopRandomSample(
descriptor,
@@ -131,10 +146,11 @@ def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_
assert indices[0].type(ans.dtype) == ans or data[ans] == data[indices[0]]
check_error(lib.infiniopDestroyRandomSampleDescriptor(descriptor))
def test_cpu(lib, test_cases):
device = DeviceEnum.DEVICE_CPU
handle = create_handle(lib, device)
-    for (voc, random_val, topp, topk, temperature) in test_cases:
+    for voc, random_val, topp, topk, temperature in test_cases:
test(lib, handle, "cpu", voc, random_val, topp, topk, temperature)
destroy_handle(lib, handle)
@@ -142,7 +158,7 @@ def test_cpu(lib, test_cases):
def test_cuda(lib, test_cases):
device = DeviceEnum.DEVICE_CUDA
handle = create_handle(lib, device)
-    for (voc, random_val, topp, topk, temperature) in test_cases:
+    for voc, random_val, topp, topk, temperature in test_cases:
test(lib, handle, "cuda", voc, random_val, topp, topk, temperature)
destroy_handle(lib, handle)
@@ -152,16 +168,17 @@ def test_bang(lib, test_cases):
device = DeviceEnum.DEVICE_BANG
handle = create_handle(lib, device)
-    for (voc, random_val, topp, topk, temperature) in test_cases:
+    for voc, random_val, topp, topk, temperature in test_cases:
test(lib, handle, "mlu", voc, random_val, topp, topk, temperature)
destroy_handle(lib, handle)
def test_ascend(lib, test_cases):
import torch_npu
device = DeviceEnum.DEVICE_ASCEND
handle = create_handle(lib, device)
-    for (voc, random_val, topp, topk, temperature) in test_cases:
+    for voc, random_val, topp, topk, temperature in test_cases:
test(lib, handle, "npu", voc, random_val, topp, topk, temperature)
destroy_handle(lib, handle)
@@ -180,7 +197,7 @@ if __name__ == "__main__":
(32000, 0.08, 1.0, 25, 1.0),
# (119696, 0.01, 1.0, 100, 1.0),
]
args = get_args()
lib = open_lib()
lib.infiniopCreateRandomSampleDescriptor.restype = c_int32
......
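The reference above selects the top-k with an O(voc · topk) selection sort; the commented-out `torch.argsort` line hints at the vectorized form. A sketch of the same top-k/top-p logic using a single sort (function name hypothetical, not the library's algorithm):

```python
import torch

def random_sample_ref(data, random_val, topp, topk, temperature):
    # one sort replaces the selection sort in the reference above
    sorted_vals, sorted_idx = torch.sort(data, descending=True)
    probs = torch.softmax((sorted_vals - sorted_vals[0]).float() / temperature, dim=0)
    cum = torch.cumsum(probs, dim=0)
    # smallest prefix of the top-k whose cumulative mass reaches topp
    end = 1
    while end < topk and cum[end - 1] < topp:
        end += 1
    r = random_val * cum[end - 1]  # rescale the draw to the kept mass
    for i in range(end):
        if r < cum[i]:
            return sorted_idx[i]
    return sorted_idx[end - 1]  # guard against float round-off at the edge
```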
@@ -61,9 +61,7 @@ def test(
x_tensor.descriptor.contents.invalidate()
y_tensor.descriptor.contents.invalidate()
-    check_error(
-        lib.infiniopRearrange(descriptor, y_tensor.data, x_tensor.data, None)
-    )
+    check_error(lib.infiniopRearrange(descriptor, y_tensor.data, x_tensor.data, None))
assert torch.allclose(x, y, atol=0, rtol=1e-3)
check_error(lib.infiniopDestroyRearrangeDescriptor(descriptor))
@@ -87,8 +85,10 @@ def test_cuda(lib, test_cases):
test(lib, handle, "cuda", x_shape, x_stride, y_shape, y_stride)
destroy_handle(lib, handle)
def test_bang(lib, test_cases):
import torch_mlu
device = DeviceEnum.DEVICE_BANG
handle = create_handle(lib, device)
for test_case in test_cases:
@@ -97,6 +97,7 @@ def test_bang(lib, test_cases):
test(lib, handle, "mlu", x_shape, x_stride, y_shape, y_stride)
destroy_handle(lib, handle)
def test_ascend(lib, test_cases):
import torch_npu
@@ -106,7 +107,8 @@ def test_ascend(lib, test_cases):
x_shape, x_stride = test_case[0]
y_shape, y_stride = test_case[1]
test(lib, handle, "npu", x_shape, x_stride, y_shape, y_stride)
-        destroy_handle(lib, handle)
+    destroy_handle(lib, handle)
if __name__ == "__main__":
args = get_args()
@@ -119,7 +121,7 @@ if __name__ == "__main__":
(((32, 1, 64), (64, 2560, 1)), ((32, 1, 64), (64, 64, 1))),
(((4, 1, 64), (64, 2560, 1)), ((4, 1, 64), (64, 11264, 1))),
(((64,), (1,)), ((64,), (1,))),
-    ]
+    ]
lib = open_lib()
lib.infiniopCreateRearrangeDescriptor.restype = c_int32
lib.infiniopCreateRearrangeDescriptor.argtypes = [
......
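In plain PyTorch terms, the rearrange test copies a tensor into a destination with different strides; a sketch using `as_strided` (buffer size hypothetical, shape and strides from the test cases above):

```python
import torch

x = torch.rand(4, 1, 64)                       # contiguous source
buf = torch.zeros(4 * 64)                      # backing storage for the strided view
y = buf.as_strided((4, 1, 64), (64, 2560, 1))  # destination with custom strides
y.copy_(x)                                     # strided write of the same elements
assert torch.equal(x, y)
```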
@@ -52,7 +52,7 @@ def test(
lib,
handle,
torch_device,
-    tensor_shape,
+    tensor_shape,
tensor_dtype=torch.float16,
inplace=Inplace.OUT_OF_PLACE,
):
@@ -61,7 +61,11 @@ def test(
)
x = torch.rand(tensor_shape, dtype=tensor_dtype).to(torch_device) * 2 - 1
-    y = torch.rand(tensor_shape, dtype=tensor_dtype).to(torch_device) if inplace == Inplace.OUT_OF_PLACE else x
+    y = (
+        torch.rand(tensor_shape, dtype=tensor_dtype).to(torch_device)
+        if inplace == Inplace.OUT_OF_PLACE
+        else x
+    )
for i in range(NUM_PRERUN if PROFILE else 1):
ans = relu(x)
@@ -108,17 +112,22 @@ def test_cpu(lib, test_cases):
device = DeviceEnum.DEVICE_CPU
handle = create_handle(lib, device)
for tensor_shape, inplace in test_cases:
+        # fmt: off
test(lib, handle, "cpu", tensor_shape, tensor_dtype=torch.float16, inplace=inplace)
test(lib, handle, "cpu", tensor_shape, tensor_dtype=torch.float32, inplace=inplace)
+        # fmt: on
destroy_handle(lib, handle)
def test_cuda(lib, test_cases):
device = DeviceEnum.DEVICE_CUDA
handle = create_handle(lib, device)
for tensor_shape, inplace in test_cases:
+        # fmt: off
test(lib, handle, "cuda", tensor_shape, tensor_dtype=torch.float16, inplace=inplace)
test(lib, handle, "cuda", tensor_shape, tensor_dtype=torch.float32, inplace=inplace)
+        # fmt: on
destroy_handle(lib, handle)
@@ -128,8 +137,10 @@ def test_bang(lib, test_cases):
device = DeviceEnum.DEVICE_BANG
handle = create_handle(lib, device)
for tensor_shape, inplace in test_cases:
+        # fmt: off
test(lib, handle, "mlu", tensor_shape, tensor_dtype=torch.float16, inplace=inplace)
test(lib, handle, "mlu", tensor_shape, tensor_dtype=torch.float32, inplace=inplace)
+        # fmt: on
destroy_handle(lib, handle)
......
@@ -20,12 +20,14 @@ from operatorspy import (
from operatorspy.tests.test_utils import get_args
import torch
class RMSNormDescriptor(Structure):
_fields_ = [("device", c_int32)]
infiniopRMSNormDescriptor_t = POINTER(RMSNormDescriptor)
def rms_norm(x, w, eps):
input_dtype = x.dtype
hidden_states = x.to(torch.float32)
@@ -34,9 +36,20 @@ def rms_norm(x, w, eps):
return w * hidden_states.to(input_dtype)
-def test(lib, handle, torch_device, y_shape, x_shape, w_shape, dtype=torch.float16, w_dtype=torch.float16):
-    print(f"Testing RMS_Norm on {torch_device} with y_shape:{y_shape} x_shape:{x_shape} w_shape:{w_shape}"
-          f" dtype:{dtype} w_dtype:{w_dtype}")
+def test(
+    lib,
+    handle,
+    torch_device,
+    y_shape,
+    x_shape,
+    w_shape,
+    dtype=torch.float16,
+    w_dtype=torch.float16,
+):
+    print(
+        f"Testing RMS_Norm on {torch_device} with y_shape:{y_shape} x_shape:{x_shape} w_shape:{w_shape}"
+        f" dtype:{dtype} w_dtype:{w_dtype}"
+    )
y = torch.zeros(y_shape, dtype=dtype).to(torch_device)
x = torch.rand(x_shape, dtype=dtype).to(torch_device)
@@ -50,12 +63,16 @@ def test(lib, handle, torch_device, y_shape, x_shape, w_shape, dtype=torch.float
w_tensor = to_tensor(w, lib)
descriptor = infiniopRMSNormDescriptor_t()
-    w_dataType = 0 if w_dtype==torch.float16 else 1
+    w_dataType = 0 if w_dtype == torch.float16 else 1
check_error(
lib.infiniopCreateRMSNormDescriptor(
-            handle, ctypes.byref(descriptor), y_tensor.descriptor, x_tensor.descriptor,
-            w_tensor.descriptor, eps
+            handle,
+            ctypes.byref(descriptor),
+            y_tensor.descriptor,
+            x_tensor.descriptor,
+            w_tensor.descriptor,
+            eps,
)
)
@@ -66,9 +83,7 @@ def test(lib, handle, torch_device, y_shape, x_shape, w_shape, dtype=torch.float
workspace_size = c_uint64(0)
check_error(
-        lib.infiniopGetRMSNormWorkspaceSize(
-            descriptor, ctypes.byref(workspace_size)
-        )
+        lib.infiniopGetRMSNormWorkspaceSize(descriptor, ctypes.byref(workspace_size))
)
workspace = create_workspace(workspace_size.value, y.device)
check_error(
@@ -86,37 +101,44 @@ def test(lib, handle, torch_device, y_shape, x_shape, w_shape, dtype=torch.float
assert torch.allclose(y.to(dtype), ans.to(dtype), atol=1e-3, rtol=1e-3)
check_error(lib.infiniopDestroyRMSNormDescriptor(descriptor))
def test_cpu(lib, test_cases):
device = DeviceEnum.DEVICE_CPU
handle = create_handle(lib, device)
-    for (y_shape, x_shape, w_shape, dtype, w_dtype) in test_cases:
+    for y_shape, x_shape, w_shape, dtype, w_dtype in test_cases:
test(lib, handle, "cpu", y_shape, x_shape, w_shape, dtype, w_dtype)
destroy_handle(lib, handle)
def test_cuda(lib, test_cases):
device = DeviceEnum.DEVICE_CUDA
handle = create_handle(lib, device)
-    for (y_shape, x_shape, w_shape, dtype, w_dtype) in test_cases:
+    for y_shape, x_shape, w_shape, dtype, w_dtype in test_cases:
test(lib, handle, "cuda", y_shape, x_shape, w_shape, dtype, w_dtype)
destroy_handle(lib, handle)
def test_bang(lib, test_cases):
import torch_mlu
device = DeviceEnum.DEVICE_BANG
handle = create_handle(lib, device)
-    for (y_shape, x_shape, w_shape, dtype, w_dtype) in test_cases:
+    for y_shape, x_shape, w_shape, dtype, w_dtype in test_cases:
test(lib, handle, "mlu", y_shape, x_shape, w_shape, dtype, w_dtype)
destroy_handle(lib, handle)
def test_ascend(lib, test_cases):
import torch_npu
device = DeviceEnum.DEVICE_ASCEND
handle = create_handle(lib, device)
-    for (y_shape, x_shape, w_shape, dtype, w_dtype) in test_cases:
+    for y_shape, x_shape, w_shape, dtype, w_dtype in test_cases:
test(lib, handle, "npu", y_shape, x_shape, w_shape, dtype, w_dtype)
destroy_handle(lib, handle)
if __name__ == "__main__":
test_cases = [
# y_shape, x_shape, w_shape, dtype, w_dtype
......
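The middle of `rms_norm` is elided in this diff; it presumably computes the usual RMS normalization over the last dimension. A self-contained sketch (eps value hypothetical):

```python
import torch

def rms_norm_ref(x, w, eps=1e-5):
    # y = w * x / sqrt(mean(x^2, last dim) + eps), computed in float32
    h = x.to(torch.float32)
    h = h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + eps)
    return w * h.to(x.dtype)

x = torch.rand(16, 2048, dtype=torch.float16)
w = torch.rand(2048, dtype=torch.float16)
y = rms_norm_ref(x, w)  # same shape and dtype as x
```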
@@ -45,12 +45,13 @@ def rotary_embedding(t, pos, theta, torch_device):
)
freqs = torch.outer(pos, freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
t_ = torch.view_as_complex(t.reshape(*t.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, t_)
t_out = torch.view_as_real(t_ * freqs_cis).flatten(2).to(t.dtype)
return t_out
def sin_cos_table(max_seq_len, dim, torch_device, theta):
pos = torch.arange(
0, max_seq_len, dtype=torch.float32, device=torch.device(torch_device)
@@ -73,12 +74,12 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
if strides is not None:
t = rearrange_tensor(t, strides)
posTmp = torch.arange(0, t.shape[0])
-    pos = torch.zeros(2 * posTmp.shape[0], dtype = torch.int32)
+    pos = torch.zeros(2 * posTmp.shape[0], dtype=torch.int32)
for i in range(posTmp.shape[0]):
pos[2 * i] = posTmp[i]
pos[2 * i + 1] = 0
theta = 1e4
-    if torch_device == 'mlu' or torch_device == 'npu':
+    if torch_device == "mlu" or torch_device == "npu":
ans = rotary_embedding(t, posTmp, theta, "cpu").to(torch_device)
pos = pos.to(torch_device)
t = t.to(torch_device)
@@ -97,7 +98,7 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
cos_table_tensor = to_tensor(cos_table, lib)
if torch_device == "npu":
-        torch.npu.synchronize()
+        torch.npu.synchronize()
check_error(
lib.infiniopCreateRoPEDescriptor(
@@ -156,6 +157,7 @@ def test_cuda(lib, test_cases):
def test_bang(lib, test_cases):
import torch_mlu
device = DeviceEnum.DEVICE_BANG
handle = create_handle(lib, device)
for shape, strides, dtype in test_cases:
@@ -163,7 +165,7 @@ def test_bang(lib, test_cases):
destroy_handle(lib, handle)
-def test_ascend(lib, test_cases) :
+def test_ascend(lib, test_cases):
import torch_npu
device = DeviceEnum.DEVICE_ASCEND
@@ -172,6 +174,7 @@ def test_ascend(lib, test_cases) :
test(lib, handle, "npu", shape, strides, dtype)
destroy_handle(lib, handle)
if __name__ == "__main__":
test_cases = [
((1, 32, 128), None, torch.float16),
@@ -180,7 +183,6 @@ if __name__ == "__main__":
        # related to the internal implementation of the GatherMask interface; 48, 64, and 128 are currently all supported
((4, 1, 32), None, torch.float16),
((1, 32, 128), None, torch.float16),
((3, 32, 128), (8000, 200, 1), torch.float16),
]
args = get_args()
......
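The `rotary_embedding` reference treats consecutive (even, odd) feature pairs as complex numbers and rotates them by position-dependent angles; that is what the `view_as_complex` / `torch.polar` lines above do. A condensed sketch of the same computation (function name hypothetical):

```python
import torch

def rope_ref(t, pos, theta=1e4):
    d = t.shape[-1]
    freqs = 1.0 / (theta ** (torch.arange(0, d, 2).float() / d))
    angles = torch.outer(pos.float(), freqs)            # (seq, d/2)
    rot = torch.polar(torch.ones_like(angles), angles)  # unit complex rotations
    t_c = torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))
    return torch.view_as_real(t_c * rot.unsqueeze(1)).flatten(-2).to(t.dtype)

t = torch.rand(3, 32, 128, dtype=torch.float16)  # (seq, heads, head_dim)
out = rope_ref(t, torch.arange(3))
```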
@@ -29,9 +29,10 @@ infiniopSwiGLUDescriptor_t = POINTER(SwiGLUDescriptor)
def swiglu(a, b):
return a * b / (1 + torch.exp(-b.float()).to(b.dtype))
def test_out_of_place(
lib,
handle,
@@ -223,6 +224,7 @@ def test_cuda(lib, test_cases):
def test_bang(lib, test_cases):
import torch_mlu
device = DeviceEnum.DEVICE_BANG
handle = create_handle(lib, device)
@@ -238,17 +240,30 @@ def test_bang(lib, test_cases):
def test_ascend(lib, test_cases):
import torch_npu
device = DeviceEnum.DEVICE_ASCEND
handle = create_handle(lib, device)
for shape, a_stride, b_stride, c_stride, dtype in test_cases:
test_out_of_place(
-            lib, handle, "npu", shape, a_stride, b_stride, c_stride, dtype, torch.npu.synchronize
+            lib,
+            handle,
+            "npu",
+            shape,
+            a_stride,
+            b_stride,
+            c_stride,
+            dtype,
+            torch.npu.synchronize,
)
-        test_in_place1(lib, handle, "npu", shape, a_stride, b_stride, dtype, torch.npu.synchronize)
-        test_in_place2(lib, handle, "npu", shape, a_stride, b_stride, dtype, torch.npu.synchronize)
-    destroy_handle(lib, handle)
+        test_in_place1(
+            lib, handle, "npu", shape, a_stride, b_stride, dtype, torch.npu.synchronize
+        )
+        test_in_place2(
+            lib, handle, "npu", shape, a_stride, b_stride, dtype, torch.npu.synchronize
+        )
+    destroy_handle(lib, handle)
if __name__ == "__main__":
......
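The `swiglu` reference above, `a * b / (1 + exp(-b))`, is exactly `a * silu(b)` (silu(b) = b · sigmoid(b)); an equivalent sketch using the built-in:

```python
import torch
import torch.nn.functional as F

def swiglu_ref(a, b):
    # a * silu(b) == a * b * sigmoid(b) == a * b / (1 + exp(-b))
    return a * F.silu(b.float()).to(b.dtype)

a, b = torch.rand(4, 64), torch.rand(4, 64)
assert torch.allclose(swiglu_ref(a, b), a * b / (1 + torch.exp(-b)), atol=1e-6)
```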