Unverified Commit 3c31dc6c authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #45 from YdrMaster/main

issue/52 代码格式化:机制和效果
parents 16dad776 e5ed9fa1
...@@ -2,9 +2,19 @@ import torch ...@@ -2,9 +2,19 @@ import torch
import ctypes import ctypes
from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
from libinfiniop import ( from libinfiniop import (
infiniopHandle_t, infiniopTensorDescriptor_t, open_lib, to_tensor, get_test_devices, infiniopHandle_t,
check_error, rearrange_if_needed, create_workspace, test_operator, get_args, infiniopTensorDescriptor_t,
debug, get_tolerance, profile_operation, open_lib,
to_tensor,
get_test_devices,
check_error,
rearrange_if_needed,
create_workspace,
test_operator,
get_args,
debug,
get_tolerance,
profile_operation,
) )
# ============================================================================== # ==============================================================================
...@@ -21,8 +31,8 @@ _TEST_CASES = [ ...@@ -21,8 +31,8 @@ _TEST_CASES = [
(1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1)), (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1)),
(1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1)), (1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1)),
(1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1)), (1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1)),
(1.0/8.0, 0.0, (4, 8*6, 64), (4, 64, 6), (4, 8*6, 6), None, None, None), (1.0 / 8.0, 0.0, (4, 8 * 6, 64), (4, 64, 6), (4, 8 * 6, 6), None, None, None),
(1.0/8.0, 0.0, (4, 8*6, 64), (4, 64, 6), (4, 8*6, 6), None, None, None), (1.0 / 8.0, 0.0, (4, 8 * 6, 64), (4, 64, 6), (4, 8 * 6, 6), None, None, None),
] ]
# Data types used for testing # Data types used for testing
...@@ -30,8 +40,8 @@ _TENSOR_DTYPES = [torch.float16, torch.float32] ...@@ -30,8 +40,8 @@ _TENSOR_DTYPES = [torch.float16, torch.float32]
# Tolerance map for different data types # Tolerance map for different data types
_TOLERANCE_MAP = { _TOLERANCE_MAP = {
torch.float16: {'atol': 0, 'rtol': 1e-2}, torch.float16: {"atol": 0, "rtol": 1e-2},
torch.float32: {'atol': 0, 'rtol': 1e-3}, torch.float32: {"atol": 0, "rtol": 1e-3},
} }
DEBUG = False DEBUG = False
...@@ -39,6 +49,7 @@ PROFILE = False ...@@ -39,6 +49,7 @@ PROFILE = False
NUM_PRERUN = 10 NUM_PRERUN = 10
NUM_ITERATIONS = 1000 NUM_ITERATIONS = 1000
# ============================================================================== # ==============================================================================
# Definitions # Definitions
# ============================================================================== # ==============================================================================
...@@ -48,6 +59,7 @@ class MatmulDescriptor(Structure): ...@@ -48,6 +59,7 @@ class MatmulDescriptor(Structure):
infiniopMatmulDescriptor_t = POINTER(MatmulDescriptor) infiniopMatmulDescriptor_t = POINTER(MatmulDescriptor)
# PyTorch implementation for matrix multiplication # PyTorch implementation for matrix multiplication
def matmul(_c, beta, _a, _b, alpha): def matmul(_c, beta, _a, _b, alpha):
a, b, c = _a.clone(), _b.clone(), _c.clone() a, b, c = _a.clone(), _b.clone(), _c.clone()
...@@ -55,6 +67,7 @@ def matmul(_c, beta, _a, _b, alpha): ...@@ -55,6 +67,7 @@ def matmul(_c, beta, _a, _b, alpha):
fp32_result = torch.matmul(a.to(torch.float32), b.to(torch.float32)) fp32_result = torch.matmul(a.to(torch.float32), b.to(torch.float32))
return alpha * fp32_result.to(result_dtype) + beta * c return alpha * fp32_result.to(result_dtype) + beta * c
# The argument list should be (lib, handle, torch_device, <param list>, dtype) # The argument list should be (lib, handle, torch_device, <param list>, dtype)
# The <param list> should keep the same order as the one specified in _TEST_CASES # The <param list> should keep the same order as the one specified in _TEST_CASES
def test( def test(
...@@ -85,7 +98,10 @@ def test( ...@@ -85,7 +98,10 @@ def test(
# Compute the PyTorch reference result # Compute the PyTorch reference result
ans = matmul(c, beta, a, b, alpha) ans = matmul(c, beta, a, b, alpha)
a, b, c = [rearrange_if_needed(tensor, stride) for tensor, stride in zip([a, b, c], [a_stride, b_stride, c_stride])] a, b, c = [
rearrange_if_needed(tensor, stride)
for tensor, stride in zip([a, b, c], [a_stride, b_stride, c_stride])
]
a_tensor, b_tensor, c_tensor = [to_tensor(tensor, lib) for tensor in [a, b, c]] a_tensor, b_tensor, c_tensor = [to_tensor(tensor, lib) for tensor in [a, b, c]]
descriptor = infiniopMatmulDescriptor_t() descriptor = infiniopMatmulDescriptor_t()
...@@ -95,7 +111,7 @@ def test( ...@@ -95,7 +111,7 @@ def test(
ctypes.byref(descriptor), ctypes.byref(descriptor),
c_tensor.descriptor, c_tensor.descriptor,
a_tensor.descriptor, a_tensor.descriptor,
b_tensor.descriptor b_tensor.descriptor,
) )
) )
...@@ -105,22 +121,27 @@ def test( ...@@ -105,22 +121,27 @@ def test(
# Get workspace size and create workspace # Get workspace size and create workspace
workspace_size = c_uint64(0) workspace_size = c_uint64(0)
check_error(lib.infiniopGetMatmulWorkspaceSize(descriptor, ctypes.byref(workspace_size))) check_error(
lib.infiniopGetMatmulWorkspaceSize(descriptor, ctypes.byref(workspace_size))
)
workspace = create_workspace(workspace_size.value, a.device) workspace = create_workspace(workspace_size.value, a.device)
# Execute infiniop matmul operator # Execute infiniop matmul operator
def lib_matmul(): def lib_matmul():
check_error(lib.infiniopMatmul( check_error(
descriptor, lib.infiniopMatmul(
workspace.data_ptr() if workspace is not None else None, descriptor,
workspace_size.value, workspace.data_ptr() if workspace is not None else None,
c_tensor.data, workspace_size.value,
a_tensor.data, c_tensor.data,
b_tensor.data, a_tensor.data,
alpha, b_tensor.data,
beta, alpha,
None, beta,
)) None,
)
)
lib_matmul() lib_matmul()
# Validate results # Validate results
...@@ -131,9 +152,10 @@ def test( ...@@ -131,9 +152,10 @@ def test(
# Profiling workflow # Profiling workflow
if PROFILE: if PROFILE:
# fmt: off
profile_operation("PyTorch", lambda: matmul(c, beta, a, b, alpha), torch_device, NUM_PRERUN, NUM_ITERATIONS) profile_operation("PyTorch", lambda: matmul(c, beta, a, b, alpha), torch_device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lambda: lib_matmul(), torch_device, NUM_PRERUN, NUM_ITERATIONS) profile_operation(" lib", lambda: lib_matmul(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
# fmt: on
check_error(lib.infiniopDestroyMatmulDescriptor(descriptor)) check_error(lib.infiniopDestroyMatmulDescriptor(descriptor))
...@@ -150,7 +172,7 @@ if __name__ == "__main__": ...@@ -150,7 +172,7 @@ if __name__ == "__main__":
POINTER(infiniopMatmulDescriptor_t), POINTER(infiniopMatmulDescriptor_t),
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t infiniopTensorDescriptor_t,
] ]
lib.infiniopGetMatmulWorkspaceSize.restype = c_int32 lib.infiniopGetMatmulWorkspaceSize.restype = c_int32
......
...@@ -35,7 +35,7 @@ class MaxPoolDescriptor(Structure): ...@@ -35,7 +35,7 @@ class MaxPoolDescriptor(Structure):
infiniopMaxPoolDescriptor_t = POINTER(MaxPoolDescriptor) infiniopMaxPoolDescriptor_t = POINTER(MaxPoolDescriptor)
def pool(x, k, padding, stride, dilation = 1): def pool(x, k, padding, stride, dilation=1):
pooling_layers = { pooling_layers = {
1: torch.nn.MaxPool1d, 1: torch.nn.MaxPool1d,
2: torch.nn.MaxPool2d, 2: torch.nn.MaxPool2d,
...@@ -66,18 +66,20 @@ def inferShape(x_shape, kernel_shape, padding, strides): ...@@ -66,18 +66,20 @@ def inferShape(x_shape, kernel_shape, padding, strides):
return x_shape[:2] + tuple(output_shape) return x_shape[:2] + tuple(output_shape)
# convert a python tuple to a ctype void pointer # convert a python tuple to a ctype void pointer
def tuple_to_void_p(py_tuple: Tuple): def tuple_to_void_p(py_tuple: Tuple):
array = ctypes.c_int64 * len(py_tuple) array = ctypes.c_int64 * len(py_tuple)
data_array = array(*py_tuple) data_array = array(*py_tuple)
return ctypes.cast(data_array, ctypes.c_void_p) return ctypes.cast(data_array, ctypes.c_void_p)
def test( def test(
lib, lib,
handle, handle,
torch_device, torch_device,
x_shape, x_shape,
k_shape, k_shape,
padding, padding,
strides, strides,
tensor_dtype=torch.float16, tensor_dtype=torch.float16,
...@@ -87,7 +89,9 @@ def test( ...@@ -87,7 +89,9 @@ def test(
) )
x = torch.rand(x_shape, dtype=tensor_dtype).to(torch_device) x = torch.rand(x_shape, dtype=tensor_dtype).to(torch_device)
y = torch.rand(inferShape(x_shape, k_shape, padding, strides), dtype=tensor_dtype).to(torch_device) y = torch.rand(
inferShape(x_shape, k_shape, padding, strides), dtype=tensor_dtype
).to(torch_device)
for i in range(NUM_PRERUN if PROFILE else 1): for i in range(NUM_PRERUN if PROFILE else 1):
ans = pool(x, k_shape, padding, strides) ans = pool(x, k_shape, padding, strides)
...@@ -123,7 +127,9 @@ def test( ...@@ -123,7 +127,9 @@ def test(
check_error( check_error(
lib.infiniopGetMaxPoolWorkspaceSize(descriptor, ctypes.byref(workspaceSize)) lib.infiniopGetMaxPoolWorkspaceSize(descriptor, ctypes.byref(workspaceSize))
) )
workspace = torch.zeros(int(workspaceSize.value), dtype=torch.uint8).to(torch_device) workspace = torch.zeros(int(workspaceSize.value), dtype=torch.uint8).to(
torch_device
)
workspace_ptr = ctypes.cast(workspace.data_ptr(), ctypes.POINTER(ctypes.c_uint8)) workspace_ptr = ctypes.cast(workspace.data_ptr(), ctypes.POINTER(ctypes.c_uint8))
for i in range(NUM_PRERUN if PROFILE else 1): for i in range(NUM_PRERUN if PROFILE else 1):
...@@ -161,8 +167,10 @@ def test_cpu(lib, test_cases): ...@@ -161,8 +167,10 @@ def test_cpu(lib, test_cases):
device = DeviceEnum.DEVICE_CPU device = DeviceEnum.DEVICE_CPU
handle = create_handle(lib, device) handle = create_handle(lib, device)
for x_shape, kernel_shape, padding, strides in test_cases: for x_shape, kernel_shape, padding, strides in test_cases:
# fmt: off
test(lib, handle, "cpu", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float16) test(lib, handle, "cpu", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float16)
test(lib, handle, "cpu", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float32) test(lib, handle, "cpu", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float32)
# fmt: on
destroy_handle(lib, handle) destroy_handle(lib, handle)
...@@ -170,8 +178,10 @@ def test_cuda(lib, test_cases): ...@@ -170,8 +178,10 @@ def test_cuda(lib, test_cases):
device = DeviceEnum.DEVICE_CUDA device = DeviceEnum.DEVICE_CUDA
handle = create_handle(lib, device) handle = create_handle(lib, device)
for x_shape, kernel_shape, padding, strides in test_cases: for x_shape, kernel_shape, padding, strides in test_cases:
# fmt: off
test(lib, handle, "cuda", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float16) test(lib, handle, "cuda", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float16)
test(lib, handle, "cuda", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float32) test(lib, handle, "cuda", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float32)
# fmt: on
destroy_handle(lib, handle) destroy_handle(lib, handle)
...@@ -181,8 +191,10 @@ def test_bang(lib, test_cases): ...@@ -181,8 +191,10 @@ def test_bang(lib, test_cases):
device = DeviceEnum.DEVICE_BANG device = DeviceEnum.DEVICE_BANG
handle = create_handle(lib, device) handle = create_handle(lib, device)
for x_shape, kernel_shape, padding, strides in test_cases: for x_shape, kernel_shape, padding, strides in test_cases:
# fmt: off
test(lib, handle, "mlu", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float16) test(lib, handle, "mlu", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float16)
test(lib, handle, "mlu", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float32) test(lib, handle, "mlu", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float32)
# fmt: on
destroy_handle(lib, handle) destroy_handle(lib, handle)
......
...@@ -30,13 +30,13 @@ infiniopRandomSampleDescriptor_t = POINTER(RandomSampleDescriptor) ...@@ -30,13 +30,13 @@ infiniopRandomSampleDescriptor_t = POINTER(RandomSampleDescriptor)
def random_sample(data, random_val, topp, topk, voc, temperature, torch_device): def random_sample(data, random_val, topp, topk, voc, temperature, torch_device):
indices = torch.zeros([topk], dtype = torch.int64) indices = torch.zeros([topk], dtype=torch.int64)
dataNp = data.clone().detach() dataNp = data.clone().detach()
sorted_indices = torch.arange(voc) sorted_indices = torch.arange(voc)
for i in range(topk): for i in range(topk):
for j in range(i + 1, voc): for j in range(i + 1, voc):
if(dataNp[i] < dataNp[j]): if dataNp[i] < dataNp[j]:
tmp = dataNp[i].clone().detach() tmp = dataNp[i].clone().detach()
dataNp[i] = dataNp[j].clone().detach() dataNp[i] = dataNp[j].clone().detach()
dataNp[j] = tmp dataNp[j] = tmp
...@@ -44,48 +44,60 @@ def random_sample(data, random_val, topp, topk, voc, temperature, torch_device): ...@@ -44,48 +44,60 @@ def random_sample(data, random_val, topp, topk, voc, temperature, torch_device):
tmpInd = sorted_indices[i].clone().detach() tmpInd = sorted_indices[i].clone().detach()
sorted_indices[i] = sorted_indices[j].clone().detach() sorted_indices[i] = sorted_indices[j].clone().detach()
sorted_indices[j] = tmpInd sorted_indices[j] = tmpInd
#sorted_indices = torch.argsort(dataNp, descending=True) # sorted_indices = torch.argsort(dataNp, descending=True)
indices = sorted_indices[:topk] indices = sorted_indices[:topk]
dataNp = dataNp[sorted_indices] dataNp = dataNp[sorted_indices]
globalM = dataNp[0] globalM = dataNp[0]
dataNp = (dataNp - globalM) / temperature dataNp = (dataNp - globalM) / temperature
dataNp = torch.softmax(dataNp.float(), dim = 0) dataNp = torch.softmax(dataNp.float(), dim=0)
sum_s = 0 sum_s = 0
for end in range(topk): for end in range(topk):
sum_s += dataNp[end] sum_s += dataNp[end]
if(sum_s >= topp): if sum_s >= topp:
break break
if(end < topk - 1): if end < topk - 1:
end += 1 end += 1
else: else:
end = topk end = topk
sum_s = 0 sum_s = 0
for i in range(end): for i in range(end):
sum_s += dataNp[i] sum_s += dataNp[i]
random_val *= sum_s random_val *= sum_s
sum_s = 0 sum_s = 0
for i in range(end): for i in range(end):
sum_s += dataNp[i] sum_s += dataNp[i]
if(random_val < sum_s): if random_val < sum_s:
return indices[i] return indices[i]
def random_sample_0(data): def random_sample_0(data):
return torch.argmax(data) return torch.argmax(data)
def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_dtype=torch.float16):
print( def test(
f"Testing RandomSample on {torch_device} with voc:{voc} dtype:{x_dtype}" lib,
) handle,
torch_device,
voc,
random_val,
topp,
topk,
temperature,
x_dtype=torch.float16,
):
print(f"Testing RandomSample on {torch_device} with voc:{voc} dtype:{x_dtype}")
data = torch.arange(voc).float() * 0.0001 data = torch.arange(voc).float() * 0.0001
_perm = torch.randperm(voc) _perm = torch.randperm(voc)
data = data[_perm].to(x_dtype).to(torch_device) data = data[_perm].to(x_dtype).to(torch_device)
if(topp > 0 and topk > 1): if topp > 0 and topk > 1:
ans = random_sample(data.to("cpu"), random_val, topp, topk, voc, temperature, "cpu") ans = random_sample(
data.to("cpu"), random_val, topp, topk, voc, temperature, "cpu"
)
else: else:
ans = random_sample_0(data) ans = random_sample_0(data)
indices = torch.zeros([1], dtype=torch.int64).to(torch_device) indices = torch.zeros([1], dtype=torch.int64).to(torch_device)
...@@ -96,7 +108,10 @@ def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_ ...@@ -96,7 +108,10 @@ def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_
descriptor = infiniopRandomSampleDescriptor_t() descriptor = infiniopRandomSampleDescriptor_t()
check_error( check_error(
lib.infiniopCreateRandomSampleDescriptor( lib.infiniopCreateRandomSampleDescriptor(
handle, ctypes.byref(descriptor), indices_tensor.descriptor, x_tensor.descriptor handle,
ctypes.byref(descriptor),
indices_tensor.descriptor,
x_tensor.descriptor,
) )
) )
...@@ -110,7 +125,7 @@ def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_ ...@@ -110,7 +125,7 @@ def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_
descriptor, ctypes.byref(workspace_size) descriptor, ctypes.byref(workspace_size)
) )
) )
workspace = create_workspace(workspace_size.value, torch_device) workspace = create_workspace(workspace_size.value, torch_device)
check_error( check_error(
lib.infiniopRandomSample( lib.infiniopRandomSample(
descriptor, descriptor,
...@@ -131,10 +146,11 @@ def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_ ...@@ -131,10 +146,11 @@ def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_
assert indices[0].type(ans.dtype) == ans or data[ans] == data[indices[0]] assert indices[0].type(ans.dtype) == ans or data[ans] == data[indices[0]]
check_error(lib.infiniopDestroyRandomSampleDescriptor(descriptor)) check_error(lib.infiniopDestroyRandomSampleDescriptor(descriptor))
def test_cpu(lib, test_cases): def test_cpu(lib, test_cases):
device = DeviceEnum.DEVICE_CPU device = DeviceEnum.DEVICE_CPU
handle = create_handle(lib, device) handle = create_handle(lib, device)
for (voc, random_val, topp, topk, temperature) in test_cases: for voc, random_val, topp, topk, temperature in test_cases:
test(lib, handle, "cpu", voc, random_val, topp, topk, temperature) test(lib, handle, "cpu", voc, random_val, topp, topk, temperature)
destroy_handle(lib, handle) destroy_handle(lib, handle)
...@@ -142,7 +158,7 @@ def test_cpu(lib, test_cases): ...@@ -142,7 +158,7 @@ def test_cpu(lib, test_cases):
def test_cuda(lib, test_cases): def test_cuda(lib, test_cases):
device = DeviceEnum.DEVICE_CUDA device = DeviceEnum.DEVICE_CUDA
handle = create_handle(lib, device) handle = create_handle(lib, device)
for (voc, random_val, topp, topk, temperature) in test_cases: for voc, random_val, topp, topk, temperature in test_cases:
test(lib, handle, "cuda", voc, random_val, topp, topk, temperature) test(lib, handle, "cuda", voc, random_val, topp, topk, temperature)
destroy_handle(lib, handle) destroy_handle(lib, handle)
...@@ -152,16 +168,17 @@ def test_bang(lib, test_cases): ...@@ -152,16 +168,17 @@ def test_bang(lib, test_cases):
device = DeviceEnum.DEVICE_BANG device = DeviceEnum.DEVICE_BANG
handle = create_handle(lib, device) handle = create_handle(lib, device)
for (voc, random_val, topp, topk, temperature) in test_cases: for voc, random_val, topp, topk, temperature in test_cases:
test(lib, handle, "mlu", voc, random_val, topp, topk, temperature) test(lib, handle, "mlu", voc, random_val, topp, topk, temperature)
destroy_handle(lib, handle) destroy_handle(lib, handle)
def test_ascend(lib, test_cases): def test_ascend(lib, test_cases):
import torch_npu import torch_npu
device = DeviceEnum.DEVICE_ASCEND device = DeviceEnum.DEVICE_ASCEND
handle = create_handle(lib, device) handle = create_handle(lib, device)
for (voc, random_val, topp, topk, temperature) in test_cases: for voc, random_val, topp, topk, temperature in test_cases:
test(lib, handle, "npu", voc, random_val, topp, topk, temperature) test(lib, handle, "npu", voc, random_val, topp, topk, temperature)
destroy_handle(lib, handle) destroy_handle(lib, handle)
...@@ -180,7 +197,7 @@ if __name__ == "__main__": ...@@ -180,7 +197,7 @@ if __name__ == "__main__":
(32000, 0.08, 1.0, 25, 1.0), (32000, 0.08, 1.0, 25, 1.0),
# (119696, 0.01, 1.0, 100, 1.0), # (119696, 0.01, 1.0, 100, 1.0),
] ]
args = get_args() args = get_args()
lib = open_lib() lib = open_lib()
lib.infiniopCreateRandomSampleDescriptor.restype = c_int32 lib.infiniopCreateRandomSampleDescriptor.restype = c_int32
......
...@@ -61,9 +61,7 @@ def test( ...@@ -61,9 +61,7 @@ def test(
x_tensor.descriptor.contents.invalidate() x_tensor.descriptor.contents.invalidate()
y_tensor.descriptor.contents.invalidate() y_tensor.descriptor.contents.invalidate()
check_error( check_error(lib.infiniopRearrange(descriptor, y_tensor.data, x_tensor.data, None))
lib.infiniopRearrange(descriptor, y_tensor.data, x_tensor.data, None)
)
assert torch.allclose(x, y, atol=0, rtol=1e-3) assert torch.allclose(x, y, atol=0, rtol=1e-3)
check_error(lib.infiniopDestroyRearrangeDescriptor(descriptor)) check_error(lib.infiniopDestroyRearrangeDescriptor(descriptor))
...@@ -87,8 +85,10 @@ def test_cuda(lib, test_cases): ...@@ -87,8 +85,10 @@ def test_cuda(lib, test_cases):
test(lib, handle, "cuda", x_shape, x_stride, y_shape, y_stride) test(lib, handle, "cuda", x_shape, x_stride, y_shape, y_stride)
destroy_handle(lib, handle) destroy_handle(lib, handle)
def test_bang(lib, test_cases): def test_bang(lib, test_cases):
import torch_mlu import torch_mlu
device = DeviceEnum.DEVICE_BANG device = DeviceEnum.DEVICE_BANG
handle = create_handle(lib, device) handle = create_handle(lib, device)
for test_case in test_cases: for test_case in test_cases:
...@@ -97,6 +97,7 @@ def test_bang(lib, test_cases): ...@@ -97,6 +97,7 @@ def test_bang(lib, test_cases):
test(lib, handle, "mlu", x_shape, x_stride, y_shape, y_stride) test(lib, handle, "mlu", x_shape, x_stride, y_shape, y_stride)
destroy_handle(lib, handle) destroy_handle(lib, handle)
def test_ascend(lib, test_cases): def test_ascend(lib, test_cases):
import torch_npu import torch_npu
...@@ -106,7 +107,8 @@ def test_ascend(lib, test_cases): ...@@ -106,7 +107,8 @@ def test_ascend(lib, test_cases):
x_shape, x_stride = test_case[0] x_shape, x_stride = test_case[0]
y_shape, y_stride = test_case[1] y_shape, y_stride = test_case[1]
test(lib, handle, "npu", x_shape, x_stride, y_shape, y_stride) test(lib, handle, "npu", x_shape, x_stride, y_shape, y_stride)
destroy_handle(lib, handle) destroy_handle(lib, handle)
if __name__ == "__main__": if __name__ == "__main__":
args = get_args() args = get_args()
...@@ -119,7 +121,7 @@ if __name__ == "__main__": ...@@ -119,7 +121,7 @@ if __name__ == "__main__":
(((32, 1, 64), (64, 2560, 1)), ((32, 1, 64), (64, 64, 1))), (((32, 1, 64), (64, 2560, 1)), ((32, 1, 64), (64, 64, 1))),
(((4, 1, 64), (64, 2560, 1)), ((4, 1, 64), (64, 11264, 1))), (((4, 1, 64), (64, 2560, 1)), ((4, 1, 64), (64, 11264, 1))),
(((64,), (1,)), ((64,), (1,))), (((64,), (1,)), ((64,), (1,))),
] ]
lib = open_lib() lib = open_lib()
lib.infiniopCreateRearrangeDescriptor.restype = c_int32 lib.infiniopCreateRearrangeDescriptor.restype = c_int32
lib.infiniopCreateRearrangeDescriptor.argtypes = [ lib.infiniopCreateRearrangeDescriptor.argtypes = [
......
...@@ -52,7 +52,7 @@ def test( ...@@ -52,7 +52,7 @@ def test(
lib, lib,
handle, handle,
torch_device, torch_device,
tensor_shape, tensor_shape,
tensor_dtype=torch.float16, tensor_dtype=torch.float16,
inplace=Inplace.OUT_OF_PLACE, inplace=Inplace.OUT_OF_PLACE,
): ):
...@@ -61,7 +61,11 @@ def test( ...@@ -61,7 +61,11 @@ def test(
) )
x = torch.rand(tensor_shape, dtype=tensor_dtype).to(torch_device) * 2 - 1 x = torch.rand(tensor_shape, dtype=tensor_dtype).to(torch_device) * 2 - 1
y = torch.rand(tensor_shape, dtype=tensor_dtype).to(torch_device) if inplace == Inplace.OUT_OF_PLACE else x y = (
torch.rand(tensor_shape, dtype=tensor_dtype).to(torch_device)
if inplace == Inplace.OUT_OF_PLACE
else x
)
for i in range(NUM_PRERUN if PROFILE else 1): for i in range(NUM_PRERUN if PROFILE else 1):
ans = relu(x) ans = relu(x)
...@@ -108,17 +112,22 @@ def test_cpu(lib, test_cases): ...@@ -108,17 +112,22 @@ def test_cpu(lib, test_cases):
device = DeviceEnum.DEVICE_CPU device = DeviceEnum.DEVICE_CPU
handle = create_handle(lib, device) handle = create_handle(lib, device)
for tensor_shape, inplace in test_cases: for tensor_shape, inplace in test_cases:
# fmt: off
test(lib, handle, "cpu", tensor_shape, tensor_dtype=torch.float16, inplace=inplace) test(lib, handle, "cpu", tensor_shape, tensor_dtype=torch.float16, inplace=inplace)
test(lib, handle, "cpu", tensor_shape, tensor_dtype=torch.float32, inplace=inplace) test(lib, handle, "cpu", tensor_shape, tensor_dtype=torch.float32, inplace=inplace)
# fmt: on
destroy_handle(lib, handle) destroy_handle(lib, handle)
def test_cuda(lib, test_cases): def test_cuda(lib, test_cases):
device = DeviceEnum.DEVICE_CUDA device = DeviceEnum.DEVICE_CUDA
handle = create_handle(lib, device) handle = create_handle(lib, device)
for tensor_shape, inplace in test_cases: for tensor_shape, inplace in test_cases:
# fmt: off
test(lib, handle, "cuda", tensor_shape, tensor_dtype=torch.float16, inplace=inplace) test(lib, handle, "cuda", tensor_shape, tensor_dtype=torch.float16, inplace=inplace)
test(lib, handle, "cuda", tensor_shape, tensor_dtype=torch.float32, inplace=inplace) test(lib, handle, "cuda", tensor_shape, tensor_dtype=torch.float32, inplace=inplace)
# fmt: on
destroy_handle(lib, handle) destroy_handle(lib, handle)
...@@ -128,8 +137,10 @@ def test_bang(lib, test_cases): ...@@ -128,8 +137,10 @@ def test_bang(lib, test_cases):
device = DeviceEnum.DEVICE_BANG device = DeviceEnum.DEVICE_BANG
handle = create_handle(lib, device) handle = create_handle(lib, device)
for tensor_shape, inplace in test_cases: for tensor_shape, inplace in test_cases:
# fmt: off
test(lib, handle, "mlu", tensor_shape, tensor_dtype=torch.float16, inplace=inplace) test(lib, handle, "mlu", tensor_shape, tensor_dtype=torch.float16, inplace=inplace)
test(lib, handle, "mlu", tensor_shape, tensor_dtype=torch.float32, inplace=inplace) test(lib, handle, "mlu", tensor_shape, tensor_dtype=torch.float32, inplace=inplace)
# fmt: on
destroy_handle(lib, handle) destroy_handle(lib, handle)
......
...@@ -20,12 +20,14 @@ from operatorspy import ( ...@@ -20,12 +20,14 @@ from operatorspy import (
from operatorspy.tests.test_utils import get_args from operatorspy.tests.test_utils import get_args
import torch import torch
class RMSNormDescriptor(Structure): class RMSNormDescriptor(Structure):
_fields_ = [("device", c_int32)] _fields_ = [("device", c_int32)]
infiniopRMSNormDescriptor_t = POINTER(RMSNormDescriptor) infiniopRMSNormDescriptor_t = POINTER(RMSNormDescriptor)
def rms_norm(x, w, eps): def rms_norm(x, w, eps):
input_dtype = x.dtype input_dtype = x.dtype
hidden_states = x.to(torch.float32) hidden_states = x.to(torch.float32)
...@@ -34,9 +36,20 @@ def rms_norm(x, w, eps): ...@@ -34,9 +36,20 @@ def rms_norm(x, w, eps):
return w * hidden_states.to(input_dtype) return w * hidden_states.to(input_dtype)
def test(lib, handle, torch_device, y_shape, x_shape, w_shape, dtype=torch.float16, w_dtype=torch.float16): def test(
print(f"Testing RMS_Norm on {torch_device} with y_shape:{y_shape} x_shape:{x_shape} w_shape:{w_shape}" lib,
f" dtype:{dtype} w_dtype:{w_dtype}") handle,
torch_device,
y_shape,
x_shape,
w_shape,
dtype=torch.float16,
w_dtype=torch.float16,
):
print(
f"Testing RMS_Norm on {torch_device} with y_shape:{y_shape} x_shape:{x_shape} w_shape:{w_shape}"
f" dtype:{dtype} w_dtype:{w_dtype}"
)
y = torch.zeros(y_shape, dtype=dtype).to(torch_device) y = torch.zeros(y_shape, dtype=dtype).to(torch_device)
x = torch.rand(x_shape, dtype=dtype).to(torch_device) x = torch.rand(x_shape, dtype=dtype).to(torch_device)
...@@ -50,12 +63,16 @@ def test(lib, handle, torch_device, y_shape, x_shape, w_shape, dtype=torch.float ...@@ -50,12 +63,16 @@ def test(lib, handle, torch_device, y_shape, x_shape, w_shape, dtype=torch.float
w_tensor = to_tensor(w, lib) w_tensor = to_tensor(w, lib)
descriptor = infiniopRMSNormDescriptor_t() descriptor = infiniopRMSNormDescriptor_t()
w_dataType = 0 if w_dtype==torch.float16 else 1 w_dataType = 0 if w_dtype == torch.float16 else 1
check_error( check_error(
lib.infiniopCreateRMSNormDescriptor( lib.infiniopCreateRMSNormDescriptor(
handle, ctypes.byref(descriptor), y_tensor.descriptor, x_tensor.descriptor, handle,
w_tensor.descriptor, eps ctypes.byref(descriptor),
y_tensor.descriptor,
x_tensor.descriptor,
w_tensor.descriptor,
eps,
) )
) )
...@@ -66,9 +83,7 @@ def test(lib, handle, torch_device, y_shape, x_shape, w_shape, dtype=torch.float ...@@ -66,9 +83,7 @@ def test(lib, handle, torch_device, y_shape, x_shape, w_shape, dtype=torch.float
workspace_size = c_uint64(0) workspace_size = c_uint64(0)
check_error( check_error(
lib.infiniopGetRMSNormWorkspaceSize( lib.infiniopGetRMSNormWorkspaceSize(descriptor, ctypes.byref(workspace_size))
descriptor, ctypes.byref(workspace_size)
)
) )
workspace = create_workspace(workspace_size.value, y.device) workspace = create_workspace(workspace_size.value, y.device)
check_error( check_error(
...@@ -86,37 +101,44 @@ def test(lib, handle, torch_device, y_shape, x_shape, w_shape, dtype=torch.float ...@@ -86,37 +101,44 @@ def test(lib, handle, torch_device, y_shape, x_shape, w_shape, dtype=torch.float
assert torch.allclose(y.to(dtype), ans.to(dtype), atol=1e-3, rtol=1e-3) assert torch.allclose(y.to(dtype), ans.to(dtype), atol=1e-3, rtol=1e-3)
check_error(lib.infiniopDestroyRMSNormDescriptor(descriptor)) check_error(lib.infiniopDestroyRMSNormDescriptor(descriptor))
def test_cpu(lib, test_cases): def test_cpu(lib, test_cases):
device = DeviceEnum.DEVICE_CPU device = DeviceEnum.DEVICE_CPU
handle = create_handle(lib, device) handle = create_handle(lib, device)
for (y_shape, x_shape, w_shape, dtype, w_dtype) in test_cases: for y_shape, x_shape, w_shape, dtype, w_dtype in test_cases:
test(lib, handle, "cpu", y_shape, x_shape, w_shape, dtype, w_dtype) test(lib, handle, "cpu", y_shape, x_shape, w_shape, dtype, w_dtype)
destroy_handle(lib, handle) destroy_handle(lib, handle)
def test_cuda(lib, test_cases): def test_cuda(lib, test_cases):
device = DeviceEnum.DEVICE_CUDA device = DeviceEnum.DEVICE_CUDA
handle = create_handle(lib, device) handle = create_handle(lib, device)
for (y_shape, x_shape, w_shape, dtype, w_dtype) in test_cases: for y_shape, x_shape, w_shape, dtype, w_dtype in test_cases:
test(lib, handle, "cuda", y_shape, x_shape, w_shape, dtype, w_dtype) test(lib, handle, "cuda", y_shape, x_shape, w_shape, dtype, w_dtype)
destroy_handle(lib, handle) destroy_handle(lib, handle)
def test_bang(lib, test_cases): def test_bang(lib, test_cases):
import torch_mlu import torch_mlu
device = DeviceEnum.DEVICE_BANG device = DeviceEnum.DEVICE_BANG
handle = create_handle(lib, device) handle = create_handle(lib, device)
for (y_shape, x_shape, w_shape, dtype, w_dtype) in test_cases: for y_shape, x_shape, w_shape, dtype, w_dtype in test_cases:
test(lib, handle, "mlu", y_shape, x_shape, w_shape, dtype, w_dtype) test(lib, handle, "mlu", y_shape, x_shape, w_shape, dtype, w_dtype)
destroy_handle(lib, handle) destroy_handle(lib, handle)
def test_ascend(lib, test_cases): def test_ascend(lib, test_cases):
import torch_npu import torch_npu
device = DeviceEnum.DEVICE_ASCEND device = DeviceEnum.DEVICE_ASCEND
handle = create_handle(lib, device) handle = create_handle(lib, device)
for (y_shape, x_shape, w_shape, dtype, w_dtype) in test_cases: for y_shape, x_shape, w_shape, dtype, w_dtype in test_cases:
test(lib, handle, "npu", y_shape, x_shape, w_shape, dtype, w_dtype) test(lib, handle, "npu", y_shape, x_shape, w_shape, dtype, w_dtype)
destroy_handle(lib, handle) destroy_handle(lib, handle)
if __name__ == "__main__": if __name__ == "__main__":
test_cases = [ test_cases = [
# y_shape, x_shape, w_shape, dtype, w_dtype # y_shape, x_shape, w_shape, dtype, w_dtype
......
...@@ -45,12 +45,13 @@ def rotary_embedding(t, pos, theta, torch_device): ...@@ -45,12 +45,13 @@ def rotary_embedding(t, pos, theta, torch_device):
) )
freqs = torch.outer(pos, freqs) freqs = torch.outer(pos, freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
t_ = torch.view_as_complex(t.reshape(*t.shape[:-1], -1, 2)) t_ = torch.view_as_complex(t.reshape(*t.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, t_) freqs_cis = reshape_for_broadcast(freqs_cis, t_)
t_out = torch.view_as_real(t_ * freqs_cis).flatten(2).to(t.dtype) t_out = torch.view_as_real(t_ * freqs_cis).flatten(2).to(t.dtype)
return t_out return t_out
def sin_cos_table(max_seq_len, dim, torch_device, theta): def sin_cos_table(max_seq_len, dim, torch_device, theta):
pos = torch.arange( pos = torch.arange(
0, max_seq_len, dtype=torch.float32, device=torch.device(torch_device) 0, max_seq_len, dtype=torch.float32, device=torch.device(torch_device)
...@@ -73,12 +74,12 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16): ...@@ -73,12 +74,12 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
if strides is not None: if strides is not None:
t = rearrange_tensor(t, strides) t = rearrange_tensor(t, strides)
posTmp = torch.arange(0, t.shape[0]) posTmp = torch.arange(0, t.shape[0])
pos = torch.zeros(2 * posTmp.shape[0], dtype = torch.int32) pos = torch.zeros(2 * posTmp.shape[0], dtype=torch.int32)
for i in range(posTmp.shape[0]): for i in range(posTmp.shape[0]):
pos[2 * i] = posTmp[i] pos[2 * i] = posTmp[i]
pos[2 * i + 1] = 0 pos[2 * i + 1] = 0
theta = 1e4 theta = 1e4
if torch_device == 'mlu' or torch_device == 'npu': if torch_device == "mlu" or torch_device == "npu":
ans = rotary_embedding(t, posTmp, theta, "cpu").to(torch_device) ans = rotary_embedding(t, posTmp, theta, "cpu").to(torch_device)
pos = pos.to(torch_device) pos = pos.to(torch_device)
t = t.to(torch_device) t = t.to(torch_device)
...@@ -97,7 +98,7 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16): ...@@ -97,7 +98,7 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
cos_table_tensor = to_tensor(cos_table, lib) cos_table_tensor = to_tensor(cos_table, lib)
if torch_device == "npu": if torch_device == "npu":
torch.npu.synchronize() torch.npu.synchronize()
check_error( check_error(
lib.infiniopCreateRoPEDescriptor( lib.infiniopCreateRoPEDescriptor(
...@@ -156,6 +157,7 @@ def test_cuda(lib, test_cases): ...@@ -156,6 +157,7 @@ def test_cuda(lib, test_cases):
def test_bang(lib, test_cases): def test_bang(lib, test_cases):
import torch_mlu import torch_mlu
device = DeviceEnum.DEVICE_BANG device = DeviceEnum.DEVICE_BANG
handle = create_handle(lib, device) handle = create_handle(lib, device)
for shape, strides, dtype in test_cases: for shape, strides, dtype in test_cases:
...@@ -163,7 +165,7 @@ def test_bang(lib, test_cases): ...@@ -163,7 +165,7 @@ def test_bang(lib, test_cases):
destroy_handle(lib, handle) destroy_handle(lib, handle)
def test_ascend(lib, test_cases) : def test_ascend(lib, test_cases):
import torch_npu import torch_npu
device = DeviceEnum.DEVICE_ASCEND device = DeviceEnum.DEVICE_ASCEND
...@@ -172,6 +174,7 @@ def test_ascend(lib, test_cases) : ...@@ -172,6 +174,7 @@ def test_ascend(lib, test_cases) :
test(lib, handle, "npu", shape, strides, dtype) test(lib, handle, "npu", shape, strides, dtype)
destroy_handle(lib, handle) destroy_handle(lib, handle)
if __name__ == "__main__": if __name__ == "__main__":
test_cases = [ test_cases = [
((1, 32, 128), None, torch.float16), ((1, 32, 128), None, torch.float16),
...@@ -180,7 +183,6 @@ if __name__ == "__main__": ...@@ -180,7 +183,6 @@ if __name__ == "__main__":
# 接口 GatherMask 的内部实现相关,目前 48 64 128 都可以支持 # 接口 GatherMask 的内部实现相关,目前 48 64 128 都可以支持
((4, 1, 32), None, torch.float16), ((4, 1, 32), None, torch.float16),
((1, 32, 128), None, torch.float16), ((1, 32, 128), None, torch.float16),
((3, 32, 128), (8000, 200, 1), torch.float16), ((3, 32, 128), (8000, 200, 1), torch.float16),
] ]
args = get_args() args = get_args()
......
...@@ -29,9 +29,10 @@ infiniopSwiGLUDescriptor_t = POINTER(SwiGLUDescriptor) ...@@ -29,9 +29,10 @@ infiniopSwiGLUDescriptor_t = POINTER(SwiGLUDescriptor)
def swiglu(a, b): def swiglu(a, b):
return a * b / (1 + torch.exp(-b.float()).to(b.dtype)) return a * b / (1 + torch.exp(-b.float()).to(b.dtype))
def test_out_of_place( def test_out_of_place(
lib, lib,
handle, handle,
...@@ -223,6 +224,7 @@ def test_cuda(lib, test_cases): ...@@ -223,6 +224,7 @@ def test_cuda(lib, test_cases):
def test_bang(lib, test_cases): def test_bang(lib, test_cases):
import torch_mlu import torch_mlu
device = DeviceEnum.DEVICE_BANG device = DeviceEnum.DEVICE_BANG
handle = create_handle(lib, device) handle = create_handle(lib, device)
...@@ -238,17 +240,30 @@ def test_bang(lib, test_cases): ...@@ -238,17 +240,30 @@ def test_bang(lib, test_cases):
def test_ascend(lib, test_cases): def test_ascend(lib, test_cases):
import torch_npu import torch_npu
device = DeviceEnum.DEVICE_ASCEND device = DeviceEnum.DEVICE_ASCEND
handle = create_handle(lib, device) handle = create_handle(lib, device)
for shape, a_stride, b_stride, c_stride, dtype in test_cases: for shape, a_stride, b_stride, c_stride, dtype in test_cases:
test_out_of_place( test_out_of_place(
lib, handle, "npu", shape, a_stride, b_stride, c_stride, dtype, torch.npu.synchronize lib,
handle,
"npu",
shape,
a_stride,
b_stride,
c_stride,
dtype,
torch.npu.synchronize,
)
test_in_place1(
lib, handle, "npu", shape, a_stride, b_stride, dtype, torch.npu.synchronize
)
test_in_place2(
lib, handle, "npu", shape, a_stride, b_stride, dtype, torch.npu.synchronize
) )
test_in_place1(lib, handle, "npu", shape, a_stride, b_stride, dtype, torch.npu.synchronize)
test_in_place2(lib, handle, "npu", shape, a_stride, b_stride, dtype, torch.npu.synchronize)
destroy_handle(lib, handle) destroy_handle(lib, handle)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment