Unverified Commit b7893d65 authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #67 from PanZezhong1725/issue/66

issue/66: 重构7个算子的测试脚本
parents 3165aba0 642e8de0
from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p import torch
import ctypes import ctypes
import sys from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
import os from libinfiniop import (
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
from operatorspy import (
open_lib,
to_tensor,
DeviceEnum,
infiniopHandle_t, infiniopHandle_t,
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
create_handle, open_lib,
destroy_handle, to_tensor,
get_test_devices,
check_error, check_error,
rearrange_tensor, rearrange_if_needed,
create_workspace, create_workspace,
test_operator,
get_args,
debug,
get_tolerance,
profile_operation,
) )
from operatorspy.tests.test_utils import get_args # ==============================================================================
import torch # Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
# Test cases for the CausalSoftmax test; each entry is (x_shape, x_stride).
# x_stride is forwarded to rearrange_if_needed() by test(); None presumably
# leaves the tensor layout untouched - TODO confirm against that helper.
_TEST_CASES = [
# x_shape, x_stride
((32, 512), None),
((32, 512), (1024, 1)),
((32, 5, 5), None),
((32, 20, 512), None),
((32, 20, 512), (20480, 512, 1)), # Ascend does not support non-contiguous tensors yet
]
# Data types used for testing
_TENSOR_DTYPES = [torch.float16]
# Tolerance map for different data types (looked up via get_tolerance() in test())
_TOLERANCE_MAP = {
torch.float16: {"atol": 0, "rtol": 1e-2},
}
# Default run options; the __main__ section overwrites them from the parsed
# command-line arguments (args.debug, args.profile, args.num_prerun,
# args.num_iterations).
# DEBUG: when true, test() calls debug() on the result before asserting.
DEBUG = False
# PROFILE: when true, test() runs profile_operation() for both the PyTorch
# reference and the library call.
PROFILE = False
# Both values are passed to profile_operation(); presumably the number of
# warm-up runs and of timed iterations - TODO confirm against that helper.
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
class CausalSoftmaxDescriptor(Structure): class CausalSoftmaxDescriptor(Structure):
...@@ -37,32 +59,38 @@ def causal_softmax(x): ...@@ -37,32 +59,38 @@ def causal_softmax(x):
return torch.nn.functional.softmax(masked, dim=-1).to(type) return torch.nn.functional.softmax(masked, dim=-1).to(type)
def test(lib, handle, torch_device, x_shape, x_stride=None, x_dtype=torch.float16): def test(lib, handle, torch_device, x_shape, x_stride=None, dtype=torch.float16):
print( print(
f"Testing CausalSoftmax on {torch_device} with x_shape:{x_shape} x_stride:{x_stride} dtype:{x_dtype}" f"Testing CausalSoftmax on {torch_device} with x_shape:{x_shape} x_stride:{x_stride} dtype:{dtype}"
) )
x = torch.rand(x_shape, dtype=x_dtype).to(torch_device)
if x_stride is not None: x = torch.rand(x_shape, dtype=dtype).to(torch_device)
x = rearrange_tensor(x, x_stride)
ans = causal_softmax(x) ans = causal_softmax(x)
x = rearrange_if_needed(x, x_stride)
x_tensor = to_tensor(x, lib) x_tensor = to_tensor(x, lib)
descriptor = infiniopCausalSoftmaxDescriptor_t() descriptor = infiniopCausalSoftmaxDescriptor_t()
check_error( check_error(
lib.infiniopCreateCausalSoftmaxDescriptor( lib.infiniopCreateCausalSoftmaxDescriptor(
handle, ctypes.byref(descriptor), x_tensor.descriptor handle, ctypes.byref(descriptor), x_tensor.descriptor
) )
) )
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
x_tensor.descriptor.contents.invalidate()
workspace_size = c_uint64(0) workspace_size = c_uint64(0)
check_error( check_error(
lib.infiniopGetCausalSoftmaxWorkspaceSize( lib.infiniopGetCausalSoftmaxWorkspaceSize(
descriptor, ctypes.byref(workspace_size) descriptor, ctypes.byref(workspace_size)
) )
) )
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
x_tensor.descriptor.contents.invalidate()
workspace = create_workspace(workspace_size.value, x.device) workspace = create_workspace(workspace_size.value, x.device)
def lib_causal_softmax():
check_error( check_error(
lib.infiniopCausalSoftmax( lib.infiniopCausalSoftmax(
descriptor, descriptor,
...@@ -72,66 +100,41 @@ def test(lib, handle, torch_device, x_shape, x_stride=None, x_dtype=torch.float1 ...@@ -72,66 +100,41 @@ def test(lib, handle, torch_device, x_shape, x_stride=None, x_dtype=torch.float1
None, None,
) )
) )
assert torch.allclose(x, ans, atol=0, rtol=1e-2)
check_error(lib.infiniopDestroyCausalSoftmaxDescriptor(descriptor))
def test_cpu(lib, test_cases):
device = DeviceEnum.DEVICE_CPU
handle = create_handle(lib, device)
for x_shape, x_stride in test_cases:
test(lib, handle, "cpu", x_shape, x_stride)
destroy_handle(lib, handle)
def test_cuda(lib, test_cases):
device = DeviceEnum.DEVICE_CUDA
handle = create_handle(lib, device)
for x_shape, x_stride in test_cases:
test(lib, handle, "cuda", x_shape, x_stride)
destroy_handle(lib, handle)
def test_bang(lib, test_cases):
import torch_mlu
device = DeviceEnum.DEVICE_BANG
handle = create_handle(lib, device)
for x_shape, x_stride in test_cases:
test(lib, handle, "mlu", x_shape, x_stride)
destroy_handle(lib, handle)
lib_causal_softmax()
def test_ascend(lib, test_cases): atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
import torch_npu if DEBUG:
debug(x, ans, atol=atol, rtol=rtol)
assert torch.allclose(x, ans, atol=atol, rtol=rtol)
device = DeviceEnum.DEVICE_ASCEND # Profiling workflow
handle = create_handle(lib, device) if PROFILE:
for x_shape, x_stride in test_cases: # fmt: off
test(lib, handle, "npu", x_shape, x_stride) profile_operation("PyTorch", lambda: causal_softmax(x), torch_device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lambda: lib_causal_softmax(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
# fmt: on
destroy_handle(lib, handle) check_error(lib.infiniopDestroyCausalSoftmaxDescriptor(descriptor))
if __name__ == "__main__": if __name__ == "__main__":
test_cases = [
# x_shape, x_stride
((32, 20, 512), None),
((32, 20, 512), (20480, 512, 1)), # Ascend 暂不支持非连续
]
args = get_args() args = get_args()
lib = open_lib() lib = open_lib()
lib.infiniopCreateCausalSoftmaxDescriptor.restype = c_int32 lib.infiniopCreateCausalSoftmaxDescriptor.restype = c_int32
lib.infiniopCreateCausalSoftmaxDescriptor.argtypes = [ lib.infiniopCreateCausalSoftmaxDescriptor.argtypes = [
infiniopHandle_t, infiniopHandle_t,
POINTER(infiniopCausalSoftmaxDescriptor_t), POINTER(infiniopCausalSoftmaxDescriptor_t),
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
] ]
lib.infiniopGetCausalSoftmaxWorkspaceSize.restype = c_int32 lib.infiniopGetCausalSoftmaxWorkspaceSize.restype = c_int32
lib.infiniopGetCausalSoftmaxWorkspaceSize.argtypes = [ lib.infiniopGetCausalSoftmaxWorkspaceSize.argtypes = [
infiniopCausalSoftmaxDescriptor_t, infiniopCausalSoftmaxDescriptor_t,
POINTER(c_uint64), POINTER(c_uint64),
] ]
lib.infiniopCausalSoftmax.restype = c_int32 lib.infiniopCausalSoftmax.restype = c_int32
lib.infiniopCausalSoftmax.argtypes = [ lib.infiniopCausalSoftmax.argtypes = [
infiniopCausalSoftmaxDescriptor_t, infiniopCausalSoftmaxDescriptor_t,
...@@ -140,19 +143,19 @@ if __name__ == "__main__": ...@@ -140,19 +143,19 @@ if __name__ == "__main__":
c_void_p, c_void_p,
c_void_p, c_void_p,
] ]
lib.infiniopDestroyCausalSoftmaxDescriptor.restype = c_int32 lib.infiniopDestroyCausalSoftmaxDescriptor.restype = c_int32
lib.infiniopDestroyCausalSoftmaxDescriptor.argtypes = [ lib.infiniopDestroyCausalSoftmaxDescriptor.argtypes = [
infiniopCausalSoftmaxDescriptor_t, infiniopCausalSoftmaxDescriptor_t,
] ]
if args.cpu: # Configure testing options
test_cpu(lib, test_cases) DEBUG = args.debug
if args.cuda: PROFILE = args.profile
test_cuda(lib, test_cases) NUM_PRERUN = args.num_prerun
if args.bang: NUM_ITERATIONS = args.num_iterations
test_bang(lib, test_cases)
if args.ascend: for device in get_test_devices(args):
test_ascend(lib, test_cases) test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)
if not (args.cpu or args.cuda or args.bang or args.ascend):
test_cpu(lib, test_cases)
print("\033[92mTest passed!\033[0m") print("\033[92mTest passed!\033[0m")
from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p, c_float import torch
import ctypes import ctypes
import sys from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
import os from libinfiniop import (
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
from operatorspy import (
open_lib,
to_tensor,
DeviceEnum,
infiniopHandle_t, infiniopHandle_t,
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
create_handle, open_lib,
destroy_handle, to_tensor,
get_test_devices,
check_error, check_error,
rearrange_tensor, rearrange_if_needed,
create_workspace, create_workspace,
U64, test_operator,
get_args,
debug_all,
get_tolerance,
profile_operation,
synchronize_device,
) )
from operatorspy.tests.test_utils import get_args # ==============================================================================
import torch # Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
# Test cases for the RandomSample test; each entry is
# (voc, random_val, topp, topk, temperature).
# Entries with topp == 0 or topk <= 1 take the argmax branch of the reference
# random_sample() implementation.
# NOTE(review): (16384, 0.15, 0, 1, 2.0) is listed twice - confirm whether the
# repetition is intentional.
_TEST_CASES = [
# voc, random_val, topp, topk, temperature
(512, 0.8, 0.8, 3, 0.5),
(4096, 0.05, 0.9, 5, 1.0),
(16384, 0.15, 0.85, 10, 2.0),
(512, 0.08, 0, 3, 0.5),
(4096, 0.5, 0.9, 1, 1.0),
(16384, 0.15, 0, 1, 2.0),
(16384, 0.15, 0, 1, 2.0),
(32000, 0.08, 0.8, 50, 1.0),
(32000, 0.08, 1.0, 25, 1.0),
# (119696, 0.01, 1.0, 100, 1.0),
]
# Data types used for testing
_TENSOR_DTYPES = [torch.float16]
# Tolerance map for different data types. The values are only forwarded to
# debug_all(); the pass/fail assertion in test() compares indices/values with ==.
_TOLERANCE_MAP = {
torch.float16: {"atol": 0, "rtol": 0},
}
# Default run options; the __main__ section overwrites them from the parsed
# command-line arguments.
# DEBUG: when true, test() calls debug_all() before asserting.
DEBUG = False
# PROFILE: when true, test() runs profile_operation() for both the PyTorch
# reference and the library call.
PROFILE = False
# Both values are passed to profile_operation(); presumably the number of
# warm-up runs and of timed iterations - TODO confirm against that helper.
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
class RandomSampleDescriptor(Structure): class RandomSampleDescriptor(Structure):
...@@ -30,6 +58,7 @@ infiniopRandomSampleDescriptor_t = POINTER(RandomSampleDescriptor) ...@@ -30,6 +58,7 @@ infiniopRandomSampleDescriptor_t = POINTER(RandomSampleDescriptor)
def random_sample(data, random_val, topp, topk, voc, temperature, torch_device): def random_sample(data, random_val, topp, topk, voc, temperature, torch_device):
if topp > 0 and topk > 1:
indices = torch.zeros([topk], dtype=torch.int64) indices = torch.zeros([topk], dtype=torch.int64)
dataNp = data.clone().detach() dataNp = data.clone().detach()
sorted_indices = torch.arange(voc) sorted_indices = torch.arange(voc)
...@@ -73,9 +102,7 @@ def random_sample(data, random_val, topp, topk, voc, temperature, torch_device): ...@@ -73,9 +102,7 @@ def random_sample(data, random_val, topp, topk, voc, temperature, torch_device):
sum_s += dataNp[i] sum_s += dataNp[i]
if random_val < sum_s: if random_val < sum_s:
return indices[i] return indices[i]
else:
def random_sample_0(data):
return torch.argmax(data) return torch.argmax(data)
...@@ -91,18 +118,19 @@ def test( ...@@ -91,18 +118,19 @@ def test(
x_dtype=torch.float16, x_dtype=torch.float16,
): ):
print(f"Testing RandomSample on {torch_device} with voc:{voc} dtype:{x_dtype}") print(f"Testing RandomSample on {torch_device} with voc:{voc} dtype:{x_dtype}")
data = torch.arange(voc).float() * 0.0001 data = torch.arange(voc).float() * 0.0001
_perm = torch.randperm(voc) _perm = torch.randperm(voc)
data = data[_perm].to(x_dtype).to(torch_device) data = data[_perm].to(x_dtype).to(torch_device)
if topp > 0 and topk > 1:
ans = random_sample( ans = random_sample(
data.to("cpu"), random_val, topp, topk, voc, temperature, "cpu" data, random_val, topp, topk, voc, temperature, torch_device
) ) # 这个函数在device速度可能会很慢,可以通过data.to("cpu")方式加快计算过程
else:
ans = random_sample_0(data)
indices = torch.zeros([1], dtype=torch.int64).to(torch_device) indices = torch.zeros([1], dtype=torch.int64).to(torch_device)
x_tensor = to_tensor(data, lib)
indices_tensor = to_tensor(indices, lib) x_tensor, indices_tensor = [to_tensor(tensor, lib) for tensor in [data, indices]]
indices_tensor.descriptor.contents.dt = U64 # treat int64 as uint64 indices_tensor.descriptor.contents.dt = U64 # treat int64 as uint64
descriptor = infiniopRandomSampleDescriptor_t() descriptor = infiniopRandomSampleDescriptor_t()
...@@ -116,8 +144,8 @@ def test( ...@@ -116,8 +144,8 @@ def test(
) )
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
x_tensor.descriptor.contents.invalidate() for tensor in [x_tensor, indices_tensor]:
indices_tensor.descriptor.contents.invalidate() tensor.descriptor.contents.invalidate()
workspace_size = c_uint64(0) workspace_size = c_uint64(0)
check_error( check_error(
...@@ -126,6 +154,8 @@ def test( ...@@ -126,6 +154,8 @@ def test(
) )
) )
workspace = create_workspace(workspace_size.value, torch_device) workspace = create_workspace(workspace_size.value, torch_device)
def lib_random_sample():
check_error( check_error(
lib.infiniopRandomSample( lib.infiniopRandomSample(
descriptor, descriptor,
...@@ -140,77 +170,51 @@ def test( ...@@ -140,77 +170,51 @@ def test(
None, None,
) )
) )
if torch_device == "npu":
torch.npu.synchronize()
assert indices[0].type(ans.dtype) == ans or data[ans] == data[indices[0]]
check_error(lib.infiniopDestroyRandomSampleDescriptor(descriptor))
def test_cpu(lib, test_cases):
device = DeviceEnum.DEVICE_CPU
handle = create_handle(lib, device)
for voc, random_val, topp, topk, temperature in test_cases:
test(lib, handle, "cpu", voc, random_val, topp, topk, temperature)
destroy_handle(lib, handle)
def test_cuda(lib, test_cases):
device = DeviceEnum.DEVICE_CUDA
handle = create_handle(lib, device)
for voc, random_val, topp, topk, temperature in test_cases:
test(lib, handle, "cuda", voc, random_val, topp, topk, temperature)
destroy_handle(lib, handle)
def test_bang(lib, test_cases): lib_random_sample()
import torch_mlu
device = DeviceEnum.DEVICE_BANG if torch_device == "npu":
handle = create_handle(lib, device) synchronize_device(torch_device)
for voc, random_val, topp, topk, temperature in test_cases:
test(lib, handle, "mlu", voc, random_val, topp, topk, temperature) atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
destroy_handle(lib, handle) if DEBUG:
debug_all(
(indices[0].type(ans.dtype), data[indices[0]]),
def test_ascend(lib, test_cases): (ans, data[ans]),
import torch_npu "or",
atol=atol,
rtol=rtol,
)
assert indices[0].type(ans.dtype) == ans or data[ans] == data[indices[0]]
device = DeviceEnum.DEVICE_ASCEND # Profiling workflow
handle = create_handle(lib, device) if PROFILE:
for voc, random_val, topp, topk, temperature in test_cases: # fmt: off
test(lib, handle, "npu", voc, random_val, topp, topk, temperature) profile_operation("PyTorch", lambda: random_sample(
destroy_handle(lib, handle) data, random_val, topp, topk, voc, temperature, torch_device
), torch_device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lambda: lib_random_sample(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
# fmt: on
check_error(lib.infiniopDestroyRandomSampleDescriptor(descriptor))
if __name__ == "__main__": if __name__ == "__main__":
test_cases = [
# voc, random_val, topp, topk, temperature
(512, 0.8, 0.8, 3, 0.5),
(4096, 0.05, 0.9, 5, 1.0),
(16384, 0.15, 0.85, 10, 2.0),
(512, 0.08, 0, 3, 0.5),
(4096, 0.5, 0.9, 1, 1.0),
(16384, 0.15, 0, 1, 2.0),
(16384, 0.15, 0, 1, 2.0),
(32000, 0.08, 0.8, 50, 1.0),
(32000, 0.08, 1.0, 25, 1.0),
# (119696, 0.01, 1.0, 100, 1.0),
]
args = get_args() args = get_args()
lib = open_lib() lib = open_lib()
lib.infiniopCreateRandomSampleDescriptor.restype = c_int32 lib.infiniopCreateRandomSampleDescriptor.restype = c_int32
lib.infiniopCreateRandomSampleDescriptor.argtypes = [ lib.infiniopCreateRandomSampleDescriptor.argtypes = [
infiniopHandle_t, infiniopHandle_t,
POINTER(infiniopRandomSampleDescriptor_t), POINTER(infiniopRandomSampleDescriptor_t),
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
] ]
lib.infiniopGetRandomSampleWorkspaceSize.restype = c_int32 lib.infiniopGetRandomSampleWorkspaceSize.restype = c_int32
lib.infiniopGetRandomSampleWorkspaceSize.argtypes = [ lib.infiniopGetRandomSampleWorkspaceSize.argtypes = [
infiniopRandomSampleDescriptor_t, infiniopRandomSampleDescriptor_t,
POINTER(c_uint64), POINTER(c_uint64),
] ]
lib.infiniopRandomSample.restype = c_int32 lib.infiniopRandomSample.restype = c_int32
lib.infiniopRandomSample.argtypes = [ lib.infiniopRandomSample.argtypes = [
infiniopRandomSampleDescriptor_t, infiniopRandomSampleDescriptor_t,
...@@ -224,19 +228,19 @@ if __name__ == "__main__": ...@@ -224,19 +228,19 @@ if __name__ == "__main__":
c_float, c_float,
c_void_p, c_void_p,
] ]
lib.infiniopDestroyRandomSampleDescriptor.restype = c_int32 lib.infiniopDestroyRandomSampleDescriptor.restype = c_int32
lib.infiniopDestroyRandomSampleDescriptor.argtypes = [ lib.infiniopDestroyRandomSampleDescriptor.argtypes = [
infiniopRandomSampleDescriptor_t, infiniopRandomSampleDescriptor_t,
] ]
if args.cpu: DEBUG = args.debug
test_cpu(lib, test_cases) PROFILE = args.profile
if args.cuda: NUM_PRERUN = args.num_prerun
test_cuda(lib, test_cases) NUM_ITERATIONS = args.num_iterations
if args.bang:
test_bang(lib, test_cases) # Execute tests
if args.ascend: for device in get_test_devices(args):
test_ascend(lib, test_cases) test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)
if not (args.cpu or args.cuda or args.bang or args.ascend):
test_cpu(lib, test_cases)
print("\033[92mTest passed!\033[0m") print("\033[92mTest passed!\033[0m")
import torch
import ctypes import ctypes
from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
import sys from libinfiniop import (
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
from operatorspy import (
open_lib,
to_tensor,
CTensor,
DeviceEnum,
infiniopHandle_t, infiniopHandle_t,
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
create_handle, open_lib,
destroy_handle, to_tensor,
get_test_devices,
check_error, check_error,
rearrange_tensor, rearrange_if_needed,
create_workspace,
test_operator,
get_args,
debug,
get_tolerance,
profile_operation,
) )
from operatorspy.tests.test_utils import get_args # ==============================================================================
import torch # Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
# Test cases for the Rearrange test; each entry pairs the source tensor layout
# with the destination tensor layout. A stride of None is forwarded to
# rearrange_if_needed(); presumably it leaves the layout untouched - TODO confirm.
_TEST_CASES = [
# ((src_shape, src_stride), (dst_shape, dst_stride))
(((2, 4, 32), None), ((2, 4, 32), (256, 64, 1))),
(((32, 6, 64), (64, 2560, 1)), ((32, 6, 64), None)),
(((4, 6, 64), (64, 2560, 1)), ((4, 6, 64), (131072, 64, 1))),
(((1, 32, 64), (2048, 64, 1)), ((1, 32, 64), (2048, 64, 1))),
(((32, 1, 64), (64, 2560, 1)), ((32, 1, 64), (64, 64, 1))),
(((4, 1, 64), (64, 2560, 1)), ((4, 1, 64), (64, 11264, 1))),
(((64,), (1,)), ((64,), (1,))),
]
# Data types used for testing
_TENSOR_DTYPES = [torch.float16, torch.float32]
# Tolerance map for different data types. Zero tolerance: test() checks that the
# destination equals the source via torch.allclose (presumably because the
# operator is a pure copy - confirm).
_TOLERANCE_MAP = {
torch.float16: {"atol": 0, "rtol": 0},
torch.float32: {"atol": 0, "rtol": 0},
}
# Default run options; the __main__ section overwrites them from the parsed
# command-line arguments.
# DEBUG: when true, test() calls debug() on the tensors before asserting.
DEBUG = False
# PROFILE: when true, test() runs profile_operation() for both the PyTorch
# reference and the library call.
PROFILE = False
# Both values are passed to profile_operation(); presumably the number of
# warm-up runs and of timed iterations - TODO confirm against that helper.
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
class RerrangeDescriptor(Structure): class RerrangeDescriptor(Structure):
...@@ -41,14 +67,15 @@ def test( ...@@ -41,14 +67,15 @@ def test(
print( print(
f"Testing Rerrange on {torch_device} with x_shape:{x_shape} x_stride:{x_stride} y_shape:{y_shape} y_stride:{y_stride} x_dtype:{x_dtype}" f"Testing Rerrange on {torch_device} with x_shape:{x_shape} x_stride:{x_stride} y_shape:{y_shape} y_stride:{y_stride} x_dtype:{x_dtype}"
) )
x = torch.rand(x_shape, dtype=x_dtype).to(torch_device) x = torch.rand(x_shape, dtype=x_dtype).to(torch_device)
y = torch.zeros(y_shape, dtype=x_dtype).to(torch_device) y = torch.zeros(y_shape, dtype=x_dtype).to(torch_device)
if x_stride is not None:
x = rearrange_tensor(x, x_stride) x, y = [
if y_stride is not None: rearrange_if_needed(tensor, stride)
y = rearrange_tensor(y, y_stride) for tensor, stride in zip([x, y], [x_stride, y_stride])
x_tensor = to_tensor(x, lib) ]
y_tensor = to_tensor(y, lib) x_tensor, y_tensor = [to_tensor(tensor, lib) for tensor in [x, y]]
descriptor = infiniopRearrangeDescriptor_t() descriptor = infiniopRearrangeDescriptor_t()
check_error( check_error(
...@@ -58,71 +85,36 @@ def test( ...@@ -58,71 +85,36 @@ def test(
) )
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
x_tensor.descriptor.contents.invalidate() for tensor in [x_tensor, y_tensor]:
y_tensor.descriptor.contents.invalidate() tensor.descriptor.contents.invalidate()
check_error(lib.infiniopRearrange(descriptor, y_tensor.data, x_tensor.data, None))
assert torch.allclose(x, y, atol=0, rtol=1e-3)
check_error(lib.infiniopDestroyRearrangeDescriptor(descriptor))
def lib_rearrange():
check_error(
lib.infiniopRearrange(descriptor, y_tensor.data, x_tensor.data, None)
)
def test_cpu(lib, test_cases): lib_rearrange()
device = DeviceEnum.DEVICE_CPU
handle = create_handle(lib, device)
for test_case in test_cases:
x_shape, x_stride = test_case[0]
y_shape, y_stride = test_case[1]
test(lib, handle, "cpu", x_shape, x_stride, y_shape, y_stride)
destroy_handle(lib, handle)
def test_cuda(lib, test_cases):
device = DeviceEnum.DEVICE_CUDA
handle = create_handle(lib, device)
for test_case in test_cases:
x_shape, x_stride = test_case[0]
y_shape, y_stride = test_case[1]
test(lib, handle, "cuda", x_shape, x_stride, y_shape, y_stride)
destroy_handle(lib, handle)
def test_bang(lib, test_cases):
import torch_mlu
device = DeviceEnum.DEVICE_BANG
handle = create_handle(lib, device)
for test_case in test_cases:
x_shape, x_stride = test_case[0]
y_shape, y_stride = test_case[1]
test(lib, handle, "mlu", x_shape, x_stride, y_shape, y_stride)
destroy_handle(lib, handle)
# Validate results
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG:
debug(x, y, atol=atol, rtol=rtol)
assert torch.allclose(x, y, atol=atol, rtol=rtol)
def test_ascend(lib, test_cases): # Profiling workflow
import torch_npu if PROFILE:
# fmt: off
profile_operation("PyTorch", lambda: rearrange_tensor(y, y_stride), torch_device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lambda: lib_rearrange(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
# fmt: on
device = DeviceEnum.DEVICE_ASCEND check_error(lib.infiniopDestroyRearrangeDescriptor(descriptor))
handle = create_handle(lib, device)
for test_case in test_cases:
x_shape, x_stride = test_case[0]
y_shape, y_stride = test_case[1]
test(lib, handle, "npu", x_shape, x_stride, y_shape, y_stride)
destroy_handle(lib, handle)
if __name__ == "__main__": if __name__ == "__main__":
args = get_args() args = get_args()
test_cases = [
# ((src_shape, src_stride), (dst_shape, dst_stride))
(((2, 4, 32), None), ((2, 4, 32), (256, 64, 1))),
(((32, 6, 64), (64, 2560, 1)), ((32, 6, 64), None)),
(((4, 6, 64), (64, 2560, 1)), ((4, 6, 64), (131072, 64, 1))),
(((1, 32, 64), (2048, 64, 1)), ((1, 32, 64), (2048, 64, 1))),
(((32, 1, 64), (64, 2560, 1)), ((32, 1, 64), (64, 64, 1))),
(((4, 1, 64), (64, 2560, 1)), ((4, 1, 64), (64, 11264, 1))),
(((64,), (1,)), ((64,), (1,))),
]
lib = open_lib() lib = open_lib()
lib.infiniopCreateRearrangeDescriptor.restype = c_int32 lib.infiniopCreateRearrangeDescriptor.restype = c_int32
lib.infiniopCreateRearrangeDescriptor.argtypes = [ lib.infiniopCreateRearrangeDescriptor.argtypes = [
infiniopHandle_t, infiniopHandle_t,
...@@ -130,6 +122,7 @@ if __name__ == "__main__": ...@@ -130,6 +122,7 @@ if __name__ == "__main__":
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
] ]
lib.infiniopRearrange.restype = c_int32 lib.infiniopRearrange.restype = c_int32
lib.infiniopRearrange.argtypes = [ lib.infiniopRearrange.argtypes = [
infiniopRearrangeDescriptor_t, infiniopRearrangeDescriptor_t,
...@@ -137,14 +130,18 @@ if __name__ == "__main__": ...@@ -137,14 +130,18 @@ if __name__ == "__main__":
c_void_p, c_void_p,
c_void_p, c_void_p,
] ]
lib.infiniopDestroyRearrangeDescriptor.restype = c_int32 lib.infiniopDestroyRearrangeDescriptor.restype = c_int32
lib.infiniopDestroyRearrangeDescriptor.argtypes = [infiniopRearrangeDescriptor_t] lib.infiniopDestroyRearrangeDescriptor.argtypes = [infiniopRearrangeDescriptor_t]
if args.cpu:
test_cpu(lib, test_cases) # Configure testing options
if args.cuda: DEBUG = args.debug
test_cuda(lib, test_cases) PROFILE = args.profile
if args.bang: NUM_PRERUN = args.num_prerun
test_bang(lib, test_cases) NUM_ITERATIONS = args.num_iterations
if args.ascend:
test_ascend(lib, test_cases) # Execute tests
for device in get_test_devices(args):
test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)
print("\033[92mTest passed!\033[0m") print("\033[92mTest passed!\033[0m")
from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p, c_float from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p, c_float
import ctypes import ctypes
import sys import torch
import os import ctypes
from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) from libinfiniop import (
from operatorspy import (
open_lib,
to_tensor,
DeviceEnum,
infiniopHandle_t, infiniopHandle_t,
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
create_handle, open_lib,
destroy_handle, to_tensor,
get_test_devices,
check_error, check_error,
rearrange_tensor, rearrange_if_needed,
create_workspace, create_workspace,
test_operator,
get_args,
debug,
get_tolerance,
profile_operation,
) )
from operatorspy.tests.test_utils import get_args # ==============================================================================
import torch # Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
# Test cases for the RMSNorm test; each entry is
# (y_shape, x_shape, w_shape, y_stride, x_stride, w_dtype).
# The strides are forwarded to rearrange_if_needed() by test(); None presumably
# leaves the tensor layout untouched - TODO confirm against that helper.
_TEST_CASES = [
# y_shape, x_shape, w_shape, y_stride, x_stride, w_dtype
((16, 2048), (16, 2048), (2048,), None, None, torch.float32),
((16, 2048), (16, 2048), (2048,), None, None, torch.float16),
((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), torch.float32),
((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), torch.float16),
]
# Data types of the input x used for testing (the weight dtype comes from each
# test case's w_dtype field)
_TENSOR_DTYPES = [torch.float16]
# Tolerance map for different data types (looked up via get_tolerance() in test())
_TOLERANCE_MAP = {
torch.float16: {"atol": 1e-3, "rtol": 1e-3},
}
# Default run options; the __main__ section overwrites them from the parsed
# command-line arguments.
# DEBUG: when true, test() calls debug() on the result before asserting.
DEBUG = False
# PROFILE: when true, test() runs profile_operation() for both the PyTorch
# reference and the library call.
PROFILE = False
# Both values are passed to profile_operation(); presumably the number of
# warm-up runs and of timed iterations - TODO confirm against that helper.
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
class RMSNormDescriptor(Structure): class RMSNormDescriptor(Structure):
...@@ -43,6 +67,8 @@ def test( ...@@ -43,6 +67,8 @@ def test(
y_shape, y_shape,
x_shape, x_shape,
w_shape, w_shape,
y_stride,
x_stride,
dtype=torch.float16, dtype=torch.float16,
w_dtype=torch.float16, w_dtype=torch.float16,
): ):
...@@ -58,12 +84,14 @@ def test( ...@@ -58,12 +84,14 @@ def test(
eps = 1e-5 eps = 1e-5
ans = rms_norm(x, w, eps) ans = rms_norm(x, w, eps)
y_tensor = to_tensor(y, lib) x, y = [
x_tensor = to_tensor(x, lib) rearrange_if_needed(tensor, stride)
w_tensor = to_tensor(w, lib) for tensor, stride in zip([x, y], [x_stride, y_stride])
]
x_tensor, y_tensor, w_tensor = [to_tensor(tensor, lib) for tensor in [x, y, w]]
descriptor = infiniopRMSNormDescriptor_t() descriptor = infiniopRMSNormDescriptor_t()
w_dataType = 0 if w_dtype == torch.float16 else 1
check_error( check_error(
lib.infiniopCreateRMSNormDescriptor( lib.infiniopCreateRMSNormDescriptor(
...@@ -77,15 +105,16 @@ def test( ...@@ -77,15 +105,16 @@ def test(
) )
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
x_tensor.descriptor.contents.invalidate() for tensor in [x_tensor, y_tensor, w_tensor]:
y_tensor.descriptor.contents.invalidate() tensor.descriptor.contents.invalidate()
w_tensor.descriptor.contents.invalidate()
workspace_size = c_uint64(0) workspace_size = c_uint64(0)
check_error( check_error(
lib.infiniopGetRMSNormWorkspaceSize(descriptor, ctypes.byref(workspace_size)) lib.infiniopGetRMSNormWorkspaceSize(descriptor, ctypes.byref(workspace_size))
) )
workspace = create_workspace(workspace_size.value, y.device) workspace = create_workspace(workspace_size.value, y.device)
def lib_rms_norm():
check_error( check_error(
lib.infiniopRMSNorm( lib.infiniopRMSNorm(
descriptor, descriptor,
...@@ -98,55 +127,26 @@ def test( ...@@ -98,55 +127,26 @@ def test(
) )
) )
assert torch.allclose(y.to(dtype), ans.to(dtype), atol=1e-3, rtol=1e-3) lib_rms_norm()
check_error(lib.infiniopDestroyRMSNormDescriptor(descriptor))
def test_cpu(lib, test_cases):
device = DeviceEnum.DEVICE_CPU
handle = create_handle(lib, device)
for y_shape, x_shape, w_shape, dtype, w_dtype in test_cases:
test(lib, handle, "cpu", y_shape, x_shape, w_shape, dtype, w_dtype)
destroy_handle(lib, handle)
def test_cuda(lib, test_cases):
device = DeviceEnum.DEVICE_CUDA
handle = create_handle(lib, device)
for y_shape, x_shape, w_shape, dtype, w_dtype in test_cases:
test(lib, handle, "cuda", y_shape, x_shape, w_shape, dtype, w_dtype)
destroy_handle(lib, handle)
def test_bang(lib, test_cases):
import torch_mlu
device = DeviceEnum.DEVICE_BANG
handle = create_handle(lib, device)
for y_shape, x_shape, w_shape, dtype, w_dtype in test_cases:
test(lib, handle, "mlu", y_shape, x_shape, w_shape, dtype, w_dtype)
destroy_handle(lib, handle)
def test_ascend(lib, test_cases): atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
import torch_npu if DEBUG:
debug(y, ans, atol=atol, rtol=rtol)
assert torch.allclose(y, ans, atol=atol, rtol=rtol)
device = DeviceEnum.DEVICE_ASCEND # Profiling workflow
handle = create_handle(lib, device) if PROFILE:
for y_shape, x_shape, w_shape, dtype, w_dtype in test_cases: # fmt: off
test(lib, handle, "npu", y_shape, x_shape, w_shape, dtype, w_dtype) profile_operation("PyTorch", lambda: rms_norm(x, w, eps), torch_device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lambda: lib_rms_norm(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
destroy_handle(lib, handle) # fmt: on
check_error(lib.infiniopDestroyRMSNormDescriptor(descriptor))
if __name__ == "__main__": if __name__ == "__main__":
test_cases = [
# y_shape, x_shape, w_shape, dtype, w_dtype
((16, 2048), (16, 2048), (2048,), torch.float16, torch.float16),
((16, 2048), (16, 2048), (2048,), torch.float16, torch.float32),
]
args = get_args() args = get_args()
lib = open_lib() lib = open_lib()
lib.infiniopCreateRMSNormDescriptor.restype = c_int32 lib.infiniopCreateRMSNormDescriptor.restype = c_int32
lib.infiniopCreateRMSNormDescriptor.argtypes = [ lib.infiniopCreateRMSNormDescriptor.argtypes = [
infiniopHandle_t, infiniopHandle_t,
...@@ -173,19 +173,20 @@ if __name__ == "__main__": ...@@ -173,19 +173,20 @@ if __name__ == "__main__":
c_void_p, c_void_p,
c_void_p, c_void_p,
] ]
lib.infiniopDestroyRMSNormDescriptor.restype = c_int32 lib.infiniopDestroyRMSNormDescriptor.restype = c_int32
lib.infiniopDestroyRMSNormDescriptor.argtypes = [ lib.infiniopDestroyRMSNormDescriptor.argtypes = [
infiniopRMSNormDescriptor_t, infiniopRMSNormDescriptor_t,
] ]
if args.cpu: # Configure testing options
test_cpu(lib, test_cases) DEBUG = args.debug
if args.cuda: PROFILE = args.profile
test_cuda(lib, test_cases) NUM_PRERUN = args.num_prerun
if args.bang: NUM_ITERATIONS = args.num_iterations
test_bang(lib, test_cases)
if args.ascend: # Execute tests
test_ascend(lib, test_cases) for device in get_test_devices(args):
if not (args.cpu or args.cuda or args.bang or args.ascend): test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)
test_cpu(lib, test_cases)
print("\033[92mTest passed!\033[0m") print("\033[92mTest passed!\033[0m")
import torch
import ctypes import ctypes
from ctypes import POINTER, c_void_p, c_int32, c_uint64, Structure, byref from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
from libinfiniop import ( from libinfiniop import (
infiniopHandle_t, infiniopHandle_t,
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
...@@ -16,10 +13,33 @@ from libinfiniop import ( ...@@ -16,10 +13,33 @@ from libinfiniop import (
test_operator, test_operator,
get_args, get_args,
debug, debug,
get_tolerance,
profile_operation, profile_operation,
InfiniDtype, synchronize_device,
) )
import torch
# ==============================================================================
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
# Test matrix for the RoPE operator test. Each entry is (t_shape, t_strides).
# Presumably a stride of None means the tensor is not re-strided by
# rearrange_if_needed — TODO confirm against that helper.
# NOTE(review): ((1, 32, 128), None) is listed twice below; confirm whether the
# second occurrence is intentional.
_TEST_CASES = [
# (t_shape, t_strides)
((1, 32, 128), None),
((1, 32, 64), None),
# Ascend currently does not handle this case: a last dimension <= 32 is
# problematic, possibly because of the internal implementation of its core
# GatherMask interface. Last dimensions of 48, 64 and 128 are supported.
((4, 1, 32), None),
((1, 32, 128), None),
((3, 32, 128), (8000, 200, 1)),
]
# Data types used for testing
_TENSOR_DTYPES = [torch.float16]
# Tolerance map for different data types. test() looks up the entry for the
# current dtype with get_tolerance(_TOLERANCE_MAP, dtype) and passes the
# resulting atol/rtol to torch.allclose.
_TOLERANCE_MAP = {
torch.float16: {"atol": 1e-4, "rtol": 1e-2},
}
DEBUG = False DEBUG = False
PROFILE = False PROFILE = False
...@@ -81,7 +101,9 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16): ...@@ -81,7 +101,9 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
) )
t = torch.rand(shape, dtype=dtype) t = torch.rand(shape, dtype=dtype)
t = rearrange_if_needed(t, strides).to(torch_device)
t = rearrange_if_needed(t, strides)
posTmp = torch.arange(0, t.shape[0]).to(torch_device) posTmp = torch.arange(0, t.shape[0]).to(torch_device)
pos = torch.zeros(2 * posTmp.shape[0], dtype=torch.int32) pos = torch.zeros(2 * posTmp.shape[0], dtype=torch.int32)
for i in range(posTmp.shape[0]): for i in range(posTmp.shape[0]):
...@@ -95,14 +117,16 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16): ...@@ -95,14 +117,16 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
descriptor = infiniopRoPEDescriptor_t() descriptor = infiniopRoPEDescriptor_t()
# 2x table length for test # 2x table length for test
sin_table, cos_table = sin_cos_table(t.shape[0] * 2, t.shape[2], t.device, theta) sin_table, cos_table = sin_cos_table(t.shape[0] * 2, t.shape[2], t.device, theta)
t_tensor = to_tensor(t, lib)
t_tensor, sin_table_tensor, cos_table_tensor = [
to_tensor(tensor, lib) for tensor in [t, sin_table, cos_table]
]
pos_tensor = to_tensor(pos[: t.shape[0]], lib) pos_tensor = to_tensor(pos[: t.shape[0]], lib)
pos_tensor.descriptor.contents.dtype = InfiniDtype.U64 pos_tensor.descriptor.contents.dtype = InfiniDtype.U64
sin_table_tensor = to_tensor(sin_table, lib)
cos_table_tensor = to_tensor(cos_table, lib)
if torch_device == "npu": if torch_device == "npu":
torch.npu.synchronize() synchronize_device(torch_device)
check_error( check_error(
lib.infiniopCreateRoPEDescriptor( lib.infiniopCreateRoPEDescriptor(
...@@ -116,10 +140,8 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16): ...@@ -116,10 +140,8 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
) )
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
t_tensor.descriptor.contents.invalidate() for tensor in [t_tensor, pos_tensor, sin_table_tensor, cos_table_tensor]:
pos_tensor.descriptor.contents.invalidate() tensor.descriptor.contents.invalidate()
sin_table_tensor.descriptor.contents.invalidate()
cos_table_tensor.descriptor.contents.invalidate()
workspace_size = c_uint64(0) workspace_size = c_uint64(0)
check_error( check_error(
...@@ -142,9 +164,12 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16): ...@@ -142,9 +164,12 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
) )
lib_rope() lib_rope()
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG: if DEBUG:
debug(t, ans, atol=1e-4, rtol=1e-2) debug(t, ans, atol=atol, rtol=rtol)
assert torch.allclose(t, ans, atol=1e-4, rtol=1e-2) assert torch.allclose(t, ans, atol=atol, rtol=rtol)
if PROFILE: if PROFILE:
profile_operation( profile_operation(
"PyTorch", "PyTorch",
...@@ -161,19 +186,9 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16): ...@@ -161,19 +186,9 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
if __name__ == "__main__": if __name__ == "__main__":
test_cases = [
# (t_shape, t_strides)
((1, 32, 128), None),
((1, 32, 64), None),
# 昇腾暂不满足这个用例,最后一维度 <=32 会有问题,可能与其核心
# 接口 GatherMask 的内部实现相关,目前 48 64 128 都可以支持
((4, 1, 32), None),
((1, 32, 128), None),
((3, 32, 128), (8000, 200, 1)),
]
test_dtypes = [torch.float16]
args = get_args() args = get_args()
lib = open_lib() lib = open_lib()
lib.infiniopCreateRoPEDescriptor.restype = c_int32 lib.infiniopCreateRoPEDescriptor.restype = c_int32
lib.infiniopCreateRoPEDescriptor.argtypes = [ lib.infiniopCreateRoPEDescriptor.argtypes = [
infiniopHandle_t, infiniopHandle_t,
...@@ -183,11 +198,13 @@ if __name__ == "__main__": ...@@ -183,11 +198,13 @@ if __name__ == "__main__":
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
] ]
lib.infiniopGetRoPEWorkspaceSize.restype = c_int32 lib.infiniopGetRoPEWorkspaceSize.restype = c_int32
lib.infiniopGetRoPEWorkspaceSize.argtypes = [ lib.infiniopGetRoPEWorkspaceSize.argtypes = [
infiniopRoPEDescriptor_t, infiniopRoPEDescriptor_t,
POINTER(c_uint64), POINTER(c_uint64),
] ]
lib.infiniopRoPE.restype = c_int32 lib.infiniopRoPE.restype = c_int32
lib.infiniopRoPE.argtypes = [ lib.infiniopRoPE.argtypes = [
infiniopRoPEDescriptor_t, infiniopRoPEDescriptor_t,
...@@ -199,10 +216,12 @@ if __name__ == "__main__": ...@@ -199,10 +216,12 @@ if __name__ == "__main__":
c_void_p, c_void_p,
c_void_p, c_void_p,
] ]
lib.infiniopDestroyRoPEDescriptor.restype = c_int32 lib.infiniopDestroyRoPEDescriptor.restype = c_int32
lib.infiniopDestroyRoPEDescriptor.argtypes = [ lib.infiniopDestroyRoPEDescriptor.argtypes = [
infiniopRoPEDescriptor_t, infiniopRoPEDescriptor_t,
] ]
# Configure testing options # Configure testing options
DEBUG = args.debug DEBUG = args.debug
PROFILE = args.profile PROFILE = args.profile
...@@ -211,5 +230,5 @@ if __name__ == "__main__": ...@@ -211,5 +230,5 @@ if __name__ == "__main__":
# Execute tests # Execute tests
for device in get_test_devices(args): for device in get_test_devices(args):
test_operator(lib, device, test, test_cases, test_dtypes) test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)
print("\033[92mTest passed!\033[0m") print("\033[92mTest passed!\033[0m")
from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p import torch
import ctypes import ctypes
import sys from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
import os from libinfiniop import (
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
from operatorspy import (
open_lib,
to_tensor,
CTensor,
DeviceEnum,
infiniopHandle_t, infiniopHandle_t,
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
create_handle, open_lib,
destroy_handle, to_tensor,
get_test_devices,
check_error, check_error,
rearrange_tensor, rearrange_if_needed,
create_workspace,
test_operator,
get_args,
debug,
get_tolerance,
profile_operation,
) )
from enum import Enum, auto
from operatorspy.tests.test_utils import get_args
import torch # ==============================================================================
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
# Base test matrix for the SwiGLU operator test, before the in-place mode is
# appended. Presumably each entry is (shape, a_stride, b_stride, c_stride),
# matching the parameter order of test() — TODO confirm.
_TEST_CASES_ = [
((13, 4), None, None, None),
((13, 4), (10, 1), (10, 1), (10, 1)),
((13, 4, 4), None, None, None),
((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1)),
((16, 5632), None, None, None),
((16, 5632), (13312, 1), (13312, 1), (13312, 1)),
((4, 4, 5632), None, None, None),
((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)),
]
# Inplace options applied for each test case in _TEST_CASES_
# NOTE(review): these entries are plain strings, not Inplace members. test()
# compares its `inplace` argument against Inplace.OUT_OF_PLACE and
# Inplace.INPLACE_A with `==`, and a str never equals a plain Enum member. If
# the test driver forwards these values unchanged, every case takes the final
# `else` (b as output) branch, so the out-of-place and in-place-A paths are
# never exercised. Switching to real members also requires Inplace to be
# defined before this list (it is currently defined further down) — TODO fix.
_INPLACE = [
"Inplace.OUT_OF_PLACE",
"Inplace.INPLACE_A",
"Inplace.INPLACE_B",
]
# Form the test cases by appending each element of _INPLACE to each tuple in _TEST_CASES_
# The result has len(_TEST_CASES_) * len(_INPLACE) entries, grouped by base case.
_TEST_CASES = [
test_case + (inplace_item,)
for test_case in _TEST_CASES_
for inplace_item in _INPLACE
]
# Data types used for testing
_TENSOR_DTYPES = [torch.float16]
# Tolerance map for different data types. test() looks up the entry for the
# current dtype with get_tolerance(_TOLERANCE_MAP, dtype) and passes the
# resulting atol/rtol to torch.allclose.
_TOLERANCE_MAP = {
torch.float16: {"atol": 1e-4, "rtol": 1e-2},
}
# Defaults for the run-time switches; the __main__ section overwrites them from
# the parsed command-line arguments (args.debug, args.profile, args.num_prerun,
# args.num_iterations).
DEBUG = False
PROFILE = False
# Passed to profile_operation together with NUM_ITERATIONS when PROFILE is set.
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
# Selects which tensor receives the SwiGLU result in test(): a separately
# allocated c (OUT_OF_PLACE), the input a (INPLACE_A), or the input b
# (INPLACE_B).
class Inplace(Enum):
OUT_OF_PLACE = auto()
INPLACE_A = auto()
INPLACE_B = auto()
class SwiGLUDescriptor(Structure): class SwiGLUDescriptor(Structure):
...@@ -29,11 +75,10 @@ infiniopSwiGLUDescriptor_t = POINTER(SwiGLUDescriptor) ...@@ -29,11 +75,10 @@ infiniopSwiGLUDescriptor_t = POINTER(SwiGLUDescriptor)
def swiglu(a, b): def swiglu(a, b):
return a * b / (1 + torch.exp(-b.float()).to(b.dtype)) return a * b / (1 + torch.exp(-b.float()).to(b.dtype))
def test_out_of_place( def test(
lib, lib,
handle, handle,
torch_device, torch_device,
...@@ -41,239 +86,76 @@ def test_out_of_place( ...@@ -41,239 +86,76 @@ def test_out_of_place(
a_stride=None, a_stride=None,
b_stride=None, b_stride=None,
c_stride=None, c_stride=None,
inplace=Inplace.OUT_OF_PLACE,
dtype=torch.float16, dtype=torch.float16,
sync=None, sync=None,
): ):
print( print(
f"Testing SwiGLU on {torch_device} with shape:{shape} a_stride:{a_stride} b_stride:{b_stride} c_stride:{c_stride} dtype:{dtype}" f"Testing SwiGLU on {torch_device} with shape:{shape} a_stride:{a_stride} b_stride:{b_stride} c_stride:{c_stride} dtype:{dtype}"
) )
a = torch.rand(shape, dtype=dtype).to(torch_device)
b = torch.rand(shape, dtype=dtype).to(torch_device)
c = torch.rand(shape, dtype=dtype).to(torch_device)
if a_stride is not None:
a = rearrange_tensor(a, a_stride)
if b_stride is not None:
b = rearrange_tensor(b, b_stride)
if c_stride is not None:
c = rearrange_tensor(c, c_stride)
ans = swiglu(a, b)
if sync is not None:
sync()
a_tensor = to_tensor(a, lib)
b_tensor = to_tensor(b, lib)
c_tensor = to_tensor(c, lib)
descriptor = infiniopSwiGLUDescriptor_t()
check_error(
lib.infiniopCreateSwiGLUDescriptor(
handle,
ctypes.byref(descriptor),
c_tensor.descriptor,
a_tensor.descriptor,
b_tensor.descriptor,
)
)
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
a_tensor.descriptor.contents.invalidate()
b_tensor.descriptor.contents.invalidate()
c_tensor.descriptor.contents.invalidate()
check_error(
lib.infiniopSwiGLU(
descriptor, c_tensor.data, a_tensor.data, b_tensor.data, None
)
)
assert torch.allclose(c, ans, atol=1e-4, rtol=1e-2)
print("out-of-place Test passed!")
check_error(lib.infiniopDestroySwiGLUDescriptor(descriptor))
def test_in_place1(
lib,
handle,
torch_device,
shape,
a_stride=None,
b_stride=None,
dtype=torch.float16,
sync=None,
):
a = torch.rand(shape, dtype=dtype).to(torch_device) a = torch.rand(shape, dtype=dtype).to(torch_device)
b = torch.rand(shape, dtype=dtype).to(torch_device) b = torch.rand(shape, dtype=dtype).to(torch_device)
c = (
if a_stride is not None: torch.rand(c_shape, dtype=tensor_dtype).to(torch_device)
a = rearrange_tensor(a, a_stride) if inplace == Inplace.OUT_OF_PLACE
if b_stride is not None: else (a if inplace == Inplace.INPLACE_A else b)
b = rearrange_tensor(b, b_stride)
ans = swiglu(a, b)
if sync is not None:
sync()
a_tensor = to_tensor(a, lib)
b_tensor = to_tensor(b, lib)
descriptor = infiniopSwiGLUDescriptor_t()
check_error(
lib.infiniopCreateSwiGLUDescriptor(
handle,
ctypes.byref(descriptor),
a_tensor.descriptor,
a_tensor.descriptor,
b_tensor.descriptor,
) )
)
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
a_tensor.descriptor.contents.invalidate()
b_tensor.descriptor.contents.invalidate()
check_error(
lib.infiniopSwiGLU(
descriptor, a_tensor.data, a_tensor.data, b_tensor.data, None
)
)
assert torch.allclose(a, ans, atol=1e-4, rtol=1e-2)
print("in-place1 Test passed!")
check_error(lib.infiniopDestroySwiGLUDescriptor(descriptor))
def test_in_place2(
lib,
handle,
torch_device,
shape,
a_stride=None,
b_stride=None,
dtype=torch.float16,
sync=None,
):
a = torch.rand(shape, dtype=dtype).to(torch_device)
b = torch.rand(shape, dtype=dtype).to(torch_device)
if a_stride is not None:
a = rearrange_tensor(a, a_stride)
if b_stride is not None:
b = rearrange_tensor(b, b_stride)
ans = swiglu(a, b) ans = swiglu(a, b)
a, b, c = [
rearrange_if_needed(tensor, stride)
for tensor, stride in zip([a, b, c], [a_stride, b_stride, c_stride])
]
a_tensor, b_tensor = [to_tensor(tensor, lib) for tensor in [a, b]]
c_tensor = (
to_tensor(c, lib)
if inplace == Inplace.OUT_OF_PLACE
else (a_tensor if inplace == Inplace.INPLACE_A else b_tensor)
)
if sync is not None: if sync is not None:
sync() sync()
a_tensor = to_tensor(a, lib)
b_tensor = to_tensor(b, lib)
descriptor = infiniopSwiGLUDescriptor_t() descriptor = infiniopSwiGLUDescriptor_t()
check_error( check_error(
lib.infiniopCreateSwiGLUDescriptor( lib.infiniopCreateSwiGLUDescriptor(
handle, handle,
ctypes.byref(descriptor), ctypes.byref(descriptor),
b_tensor.descriptor, c_tensor.descriptor,
a_tensor.descriptor, a_tensor.descriptor,
b_tensor.descriptor, b_tensor.descriptor,
) )
) )
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
a_tensor.descriptor.contents.invalidate() for tensor in [a_tensor, b_tensor, c_tensor]:
b_tensor.descriptor.contents.invalidate() tensor.descriptor.contents.invalidate()
def lib_swiglu():
check_error( check_error(
lib.infiniopSwiGLU( lib.infiniopSwiGLU(
descriptor, b_tensor.data, a_tensor.data, b_tensor.data, None descriptor, c_tensor.data, a_tensor.data, b_tensor.data, None
)
)
assert torch.allclose(b, ans, atol=1e-4, rtol=1e-2)
check_error(lib.infiniopDestroySwiGLUDescriptor(descriptor))
def test_cpu(lib, test_cases):
device = DeviceEnum.DEVICE_CPU
handle = create_handle(lib, device)
for shape, a_stride, b_stride, c_stride, dtype in test_cases:
test_out_of_place(
lib, handle, "cpu", shape, a_stride, b_stride, c_stride, dtype
)
test_in_place1(lib, handle, "cpu", shape, a_stride, b_stride, dtype)
test_in_place2(lib, handle, "cpu", shape, a_stride, b_stride, dtype)
destroy_handle(lib, handle)
def test_cuda(lib, test_cases):
device = DeviceEnum.DEVICE_CUDA
handle = create_handle(lib, device)
for shape, a_stride, b_stride, c_stride, dtype in test_cases:
test_out_of_place(
lib, handle, "cuda", shape, a_stride, b_stride, c_stride, dtype
) )
test_in_place1(lib, handle, "cuda", shape, a_stride, b_stride, dtype)
test_in_place2(lib, handle, "cuda", shape, a_stride, b_stride, dtype)
destroy_handle(lib, handle)
def test_bang(lib, test_cases):
import torch_mlu
device = DeviceEnum.DEVICE_BANG
handle = create_handle(lib, device)
for shape, a_stride, b_stride, c_stride, dtype in test_cases:
test_out_of_place(
lib, handle, "mlu", shape, a_stride, b_stride, c_stride, dtype
) )
test_in_place1(lib, handle, "mlu", shape, a_stride, b_stride, dtype)
test_in_place2(lib, handle, "mlu", shape, a_stride, b_stride, dtype)
destroy_handle(lib, handle)
def test_ascend(lib, test_cases): lib_swiglu()
import torch_npu
device = DeviceEnum.DEVICE_ASCEND atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
handle = create_handle(lib, device) if DEBUG:
debug(c, ans, atol=atol, rtol=rtol)
assert torch.allclose(c, ans, atol=atol, rtol=rtol)
for shape, a_stride, b_stride, c_stride, dtype in test_cases: # Profiling workflow
test_out_of_place( if PROFILE:
lib, # fmt: off
handle, profile_operation("PyTorch", lambda: swiglu(a, b), torch_device, NUM_PRERUN, NUM_ITERATIONS)
"npu", profile_operation(" lib", lambda: lib_swiglu(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
shape, # fmt: on
a_stride, check_error(lib.infiniopDestroySwiGLUDescriptor(descriptor))
b_stride,
c_stride,
dtype,
torch.npu.synchronize,
)
test_in_place1(
lib, handle, "npu", shape, a_stride, b_stride, dtype, torch.npu.synchronize
)
test_in_place2(
lib, handle, "npu", shape, a_stride, b_stride, dtype, torch.npu.synchronize
)
destroy_handle(lib, handle)
if __name__ == "__main__": if __name__ == "__main__":
test_cases = [
# shape, a_stride, b_stride, c_stride, dtype
((13, 4), None, None, None, torch.float16),
((13, 4), (10, 1), (10, 1), (10, 1), torch.float16),
((16, 5632), None, None, None, torch.float16),
((16, 5632), (13312, 1), (13312, 1), (13312, 1), torch.float16),
]
args = get_args() args = get_args()
lib = open_lib() lib = open_lib()
...@@ -300,12 +182,13 @@ if __name__ == "__main__": ...@@ -300,12 +182,13 @@ if __name__ == "__main__":
infiniopSwiGLUDescriptor_t, infiniopSwiGLUDescriptor_t,
] ]
if args.cpu: # Configure testing options
test_cpu(lib, test_cases) DEBUG = args.debug
if args.cuda: PROFILE = args.profile
test_cuda(lib, test_cases) NUM_PRERUN = args.num_prerun
if args.bang: NUM_ITERATIONS = args.num_iterations
test_bang(lib, test_cases)
if args.ascend: for device in get_test_devices(args):
test_ascend(lib, test_cases) test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)
print("\033[92mTest passed!\033[0m") print("\033[92mTest passed!\033[0m")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment