Commit ca2f34cf authored by xgqdut2016

issue/66: modified test py

parent 87d10975
==== causal_softmax test ====

import torch
import ctypes
from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
from libinfiniop import (
    infiniopHandle_t,
    infiniopTensorDescriptor_t,
    open_lib,
    to_tensor,
    get_test_devices,
    check_error,
    rearrange_if_needed,
    create_workspace,
    test_operator,
    get_args,
    debug,
    get_tolerance,
    profile_operation,
)

# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES = [
    # x_shape, x_stride
    ((32, 512), None),
    ((32, 512), (1024, 1)),
    ((32, 5, 5), None),
    ((32, 20, 512), None),
    ((32, 20, 512), (20480, 512, 1)),  # Ascend does not yet support non-contiguous layouts
    ((32, 20, 4, 512), None),
    ((32, 20, 4, 512), (81920, 2048, 512, 1)),
]

# Data types used for testing
_TENSOR_DTYPES = [torch.float16, torch.float32]

# Tolerance map for different data types
_TOLERANCE_MAP = {
    torch.float16: {'atol': 0, 'rtol': 1e-2},
    torch.float32: {'atol': 0, 'rtol': 1e-3},
}

DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000


class CausalSoftmaxDescriptor(Structure):
    _fields_ = [("device", c_int32)]

@@ -37,88 +61,78 @@ def causal_softmax(x):
    return torch.nn.functional.softmax(masked, dim=-1).to(type)
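
# The hunk above elides most of causal_softmax; only its return statement is
# visible. The sketch below is a compatible reference implementation, assuming
# the usual bottom-right-aligned causal mask over the last two dimensions
# (the masking details are this sketch's assumption, not the diff's content):
def causal_softmax_sketch(x):
    type = x.dtype
    # ones strictly below the diagonal, flipped so the masked ("future")
    # region is aligned to the bottom-right corner of the last two dims
    mask = torch.tril(torch.ones_like(x), diagonal=-1).flip(dims=[-2, -1])
    masked = x.to(torch.float32).masked_fill(mask.bool(), float("-inf"))
    return torch.nn.functional.softmax(masked, dim=-1).to(type)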

def test(
    lib,
    handle,
    torch_device,
    x_shape,
    x_stride=None,
    dtype=torch.float16,
):
    print(
        f"Testing CausalSoftmax on {torch_device} with x_shape:{x_shape} x_stride:{x_stride} dtype:{dtype}"
    )

    x = torch.rand(x_shape, dtype=dtype).to(torch_device)
    ans = causal_softmax(x)
    x = rearrange_if_needed(x, x_stride)
    x_tensor = to_tensor(x, lib)

    descriptor = infiniopCausalSoftmaxDescriptor_t()
    check_error(
        lib.infiniopCreateCausalSoftmaxDescriptor(
            handle,
            ctypes.byref(descriptor),
            x_tensor.descriptor,
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    x_tensor.descriptor.contents.invalidate()

    workspace_size = c_uint64(0)
    check_error(
        lib.infiniopGetCausalSoftmaxWorkspaceSize(
            descriptor, ctypes.byref(workspace_size)
        )
    )
    workspace = create_workspace(workspace_size.value, x.device)

    def lib_causal_softmax():
        check_error(
            lib.infiniopCausalSoftmax(
                descriptor,
                workspace.data_ptr() if workspace is not None else None,
                workspace_size.value,
                x_tensor.data,
                None,
            )
        )

    lib_causal_softmax()

    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(x, ans, atol=atol, rtol=rtol)
    assert torch.allclose(x, ans, atol=atol, rtol=rtol)

    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: causal_softmax(x), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation("    lib", lambda: lib_causal_softmax(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on

    check_error(lib.infiniopDestroyCausalSoftmaxDescriptor(descriptor))


if __name__ == "__main__":
    args = get_args()
    lib = open_lib()

    lib.infiniopCreateCausalSoftmaxDescriptor.restype = c_int32
@@ -144,15 +158,14 @@ if __name__ == "__main__":
    lib.infiniopDestroyCausalSoftmaxDescriptor.argtypes = [
        infiniopCausalSoftmaxDescriptor_t,
    ]

    # Configure testing options
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    for device in get_test_devices(args):
        test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")

==== random_sample test ====

import torch
import ctypes
from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
from libinfiniop import (
    infiniopHandle_t,
    infiniopTensorDescriptor_t,
    open_lib,
    to_tensor,
    get_test_devices,
    check_error,
    rearrange_if_needed,
    create_workspace,
    test_operator,
    get_args,
    debug,
    get_tolerance,
    profile_operation,
)

# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES = [
    # voc, random_val, topp, topk, temperature
    (512, 0.8, 0.8, 3, 0.5),
    (4096, 0.05, 0.9, 5, 1.0),
    (16384, 0.15, 0.85, 10, 2.0),
    (512, 0.08, 0, 3, 0.5),
    (4096, 0.5, 0.9, 1, 1.0),
    (16384, 0.15, 0, 1, 2.0),
    (16384, 0.15, 0, 1, 2.0),
    (32000, 0.08, 0.8, 50, 1.0),
    (32000, 0.08, 1.0, 25, 1.0),
    # (119696, 0.01, 1.0, 100, 1.0),
]

# Data types used for testing
_TENSOR_DTYPES = [torch.float16, torch.float32]

PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000


class RandomSampleDescriptor(Structure):
@@ -116,8 +138,8 @@ def test(
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    for tensor in [x_tensor, indices_tensor]:
        tensor.descriptor.contents.invalidate()

    workspace_size = c_uint64(0)
    check_error(
@@ -126,77 +148,45 @@ def test(
        )
    )
    workspace = create_workspace(workspace_size.value, torch_device)

    def lib_random_sample():
        check_error(
            lib.infiniopRandomSample(
                descriptor,
                workspace.data_ptr() if workspace is not None else None,
                workspace_size.value,
                indices_tensor.data,
                x_tensor.data,
                random_val,
                topp,
                topk,
                temperature,
                None,
            )
        )

    lib_random_sample()
    if torch_device == "npu":
        torch.npu.synchronize()

    assert indices[0].type(ans.dtype) == ans or data[ans] == data[indices[0]]

    # Profiling workflow
    if PROFILE:
        # fmt: off
        if topp > 0 and topk > 1:
            profile_operation("PyTorch", lambda: random_sample(
                data.to("cpu"), random_val, topp, topk, voc, temperature, "cpu"
            ), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        else:
            profile_operation("PyTorch", lambda: random_sample_0(data), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation("    lib", lambda: lib_random_sample(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on

    check_error(lib.infiniopDestroyRandomSampleDescriptor(descriptor))


if __name__ == "__main__":
    args = get_args()
    lib = open_lib()
@@ -229,14 +219,12 @@ if __name__ == "__main__":
        infiniopRandomSampleDescriptor_t,
    ]

    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    # Execute tests
    for device in get_test_devices(args):
        test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")

==== rearrange test ====

import torch
import ctypes
from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
from libinfiniop import (
    infiniopHandle_t,
    infiniopTensorDescriptor_t,
    open_lib,
    to_tensor,
    get_test_devices,
    check_error,
    rearrange_if_needed,
    rearrange_tensor,  # still used as the PyTorch baseline when profiling
    test_operator,
    get_args,
    debug,
    get_tolerance,
    profile_operation,
)

# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES = [
    # ((src_shape, src_stride), (dst_shape, dst_stride))
    (((2, 4, 32), None), ((2, 4, 32), (256, 64, 1))),
    (((32, 6, 64), (64, 2560, 1)), ((32, 6, 64), None)),
    (((4, 6, 64), (64, 2560, 1)), ((4, 6, 64), (131072, 64, 1))),
    (((1, 32, 64), (2048, 64, 1)), ((1, 32, 64), (2048, 64, 1))),
    (((32, 1, 64), (64, 2560, 1)), ((32, 1, 64), (64, 64, 1))),
    (((4, 1, 64), (64, 2560, 1)), ((4, 1, 64), (64, 11264, 1))),
    (((64,), (1,)), ((64,), (1,))),
]

# Data types used for testing
_TENSOR_DTYPES = [torch.float16, torch.float32]

# Tolerance map for different data types
_TOLERANCE_MAP = {
    torch.float16: {"atol": 0, "rtol": 1e-3},
    torch.float32: {"atol": 0, "rtol": 1e-3},
}

DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000


class RearrangeDescriptor(Structure):
@@ -43,12 +70,13 @@ def test(
    )
    x = torch.rand(x_shape, dtype=x_dtype).to(torch_device)
    y = torch.zeros(y_shape, dtype=x_dtype).to(torch_device)

    x, y = [
        rearrange_if_needed(tensor, stride)
        for tensor, stride in zip([x, y], [x_stride, y_stride])
    ]
    x_tensor, y_tensor = [to_tensor(tensor, lib) for tensor in [x, y]]

    descriptor = infiniopRearrangeDescriptor_t()
    check_error(
@@ -58,71 +86,42 @@ def test(
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    for tensor in [x_tensor, y_tensor]:
        tensor.descriptor.contents.invalidate()

    def lib_rearrange():
        check_error(
            lib.infiniopRearrange(
                descriptor,
                y_tensor.data,
                x_tensor.data,
                None,
            )
        )

    lib_rearrange()

    # Validate results
    atol, rtol = get_tolerance(_TOLERANCE_MAP, x_dtype)
    if DEBUG:
        debug(x, y, atol=atol, rtol=rtol)
    assert torch.allclose(x, y, atol=atol, rtol=rtol)

    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: rearrange_tensor(y, y_stride), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation("    lib", lambda: lib_rearrange(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on

    check_error(lib.infiniopDestroyRearrangeDescriptor(descriptor))


if __name__ == "__main__":
    args = get_args()
    lib = open_lib()

    lib.infiniopCreateRearrangeDescriptor.restype = c_int32
    lib.infiniopCreateRearrangeDescriptor.argtypes = [
        infiniopHandle_t,
@@ -139,12 +138,15 @@ if __name__ == "__main__":
    ]
    lib.infiniopDestroyRearrangeDescriptor.restype = c_int32
    lib.infiniopDestroyRearrangeDescriptor.argtypes = [infiniopRearrangeDescriptor_t]

    # Configure testing options
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    # Execute tests
    for device in get_test_devices(args):
        test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")

==== rms_norm test ====

import torch
import ctypes
from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
from libinfiniop import (
    infiniopHandle_t,
    infiniopTensorDescriptor_t,
    open_lib,
    to_tensor,
    get_test_devices,
    check_error,
    rearrange_if_needed,
    create_workspace,
    test_operator,
    get_args,
    debug,
    get_tolerance,
    profile_operation,
)

# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES = [
    # y_shape, x_shape, w_shape, y_stride, x_stride, w_dtype
    ((16, 2048), (16, 2048), (2048,), None, None, torch.float32),
    ((16, 2048), (16, 2048), (2048,), None, None, torch.float16),
    ((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), torch.float32),
    ((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), torch.float16),
]

# x types used for testing
_TENSOR_DTYPES = [torch.float16, torch.float32]

# Tolerance map for different data types
_TOLERANCE_MAP = {
    torch.float16: {"atol": 0, "rtol": 1e-2},
    torch.float32: {"atol": 0, "rtol": 1e-3},
}

DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000


class RMSNormDescriptor(Structure):
    _fields_ = [("device", c_int32)]
@@ -27,7 +51,6 @@ class RMSNormDescriptor(Structure):
infiniopRMSNormDescriptor_t = POINTER(RMSNormDescriptor)


def rms_norm(x, w, eps):
    input_dtype = x.dtype
    hidden_states = x.to(torch.float32)
@@ -37,19 +60,18 @@ def rms_norm(x, w, eps):
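
# The remainder of rms_norm is elided by the hunk above. The standard
# formulation consistent with the two visible lines would be the following
# sketch (not necessarily the file's exact code):
def rms_norm_sketch(x, w, eps):
    input_dtype = x.dtype
    hidden_states = x.to(torch.float32)
    # normalize by the root-mean-square over the last dimension
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return (w * hidden_states).to(input_dtype)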

def test(
    lib,
    handle,
    torch_device,
    y_shape,
    x_shape,
    w_shape,
    y_stride,
    x_stride,
    dtype=torch.float16,
    w_dtype=torch.float16,
):
    print(
        f"Testing RMS_Norm on {torch_device} with y_shape:{y_shape} x_shape:{x_shape} w_shape:{w_shape}"
        f" dtype:{dtype} w_dtype:{w_dtype}"
    )

    y = torch.zeros(y_shape, dtype=dtype).to(torch_device)
    x = torch.rand(x_shape, dtype=dtype).to(torch_device)
@@ -58,93 +80,64 @@ def test(
    eps = 1e-5
    ans = rms_norm(x, w, eps)

    x = rearrange_if_needed(x, x_stride)
    y = rearrange_if_needed(y, y_stride)

    x_tensor, y_tensor, w_tensor = [to_tensor(tensor, lib) for tensor in [x, y, w]]

    descriptor = infiniopRMSNormDescriptor_t()
    w_dataType = 0 if w_dtype == torch.float16 else 1
    check_error(
        lib.infiniopCreateRMSNormDescriptor(
            handle,
            ctypes.byref(descriptor),
            y_tensor.descriptor,
            x_tensor.descriptor,
            w_tensor.descriptor,
            eps,
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    for tensor in [x_tensor, y_tensor, w_tensor]:
        tensor.descriptor.contents.invalidate()

    workspace_size = c_uint64(0)
    check_error(
        lib.infiniopGetRMSNormWorkspaceSize(
            descriptor, ctypes.byref(workspace_size)
        )
    )
    workspace = create_workspace(workspace_size.value, y.device)

    def lib_rms_norm():
        check_error(
            lib.infiniopRMSNorm(
                descriptor,
                workspace.data_ptr() if workspace is not None else None,
                workspace_size.value,
                y_tensor.data,
                x_tensor.data,
                w_tensor.data,
                None,
            )
        )

    lib_rms_norm()

    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(y, ans, atol=atol, rtol=rtol)
    assert torch.allclose(y, ans, atol=atol, rtol=rtol)

    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: rms_norm(x, w, eps), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation("    lib", lambda: lib_rms_norm(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on

    check_error(lib.infiniopDestroyRMSNormDescriptor(descriptor))


if __name__ == "__main__":
    args = get_args()
    lib = open_lib()

    lib.infiniopCreateRMSNormDescriptor.restype = c_int32
@@ -178,14 +171,16 @@ if __name__ == "__main__":
        infiniopRMSNormDescriptor_t,
    ]

    # Configure testing options
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    # Execute tests
    for device in get_test_devices(args):
        test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")

==== rope test ====

import torch
import ctypes
from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
from libinfiniop import (
    infiniopHandle_t,
    infiniopTensorDescriptor_t,
@@ -16,10 +13,33 @@ from libinfiniop import (
    test_operator,
    get_args,
    debug,
    get_tolerance,
    profile_operation,
    InfiniDtype,
)

# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES = [
    # (t_shape, t_strides)
    ((1, 32, 128), None),
    ((1, 32, 64), None),
    # Ascend does not yet pass this case: a last dimension <= 32 fails,
    # possibly because of the internal implementation of its core GatherMask
    # interface; 48, 64 and 128 all work at present.
    ((4, 1, 32), None),
    ((1, 32, 128), None),
    ((3, 32, 128), (8000, 200, 1)),
]

# Data types used for testing
_TENSOR_DTYPES = [torch.float16, torch.float32]

# Tolerance map for different data types
_TOLERANCE_MAP = {
    torch.float16: {"atol": 0, "rtol": 1e-2},
    torch.float32: {"atol": 0, "rtol": 1e-3},
}

DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000


class RoPEDescriptor(Structure):
    _fields_ = [("device", c_int32)]

@@ -75,13 +96,22 @@ def sin_cos_table(max_seq_len, dim, torch_device, theta):
    return torch.sin(angles), torch.cos(angles)
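
# sin_cos_table's angle computation is elided by the hunk above; a conventional
# RoPE table consistent with the visible signature and return would be the
# following sketch (the file's exact code may differ):
def sin_cos_table_sketch(max_seq_len, dim, torch_device, theta):
    # one frequency per pair of channels: theta ** (-2i / dim)
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    pos = torch.arange(max_seq_len).float()
    angles = torch.outer(pos, freqs).to(torch_device)
    return torch.sin(angles), torch.cos(angles)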

def test(
    lib,
    handle,
    torch_device,
    shape,
    strides=None,
    dtype=torch.float16,
):
    print(
        f"Testing Rotary Positional Embedding on {torch_device} with shape:{shape} strides:{strides} and dtype:{dtype}"
    )

    t = torch.rand(shape, dtype=dtype)
    t = rearrange_if_needed(t, strides).to(torch_device)
    posTmp = torch.arange(0, t.shape[0]).to(torch_device)
    pos = torch.zeros(2 * posTmp.shape[0], dtype=torch.int32)
    for i in range(posTmp.shape[0]):
@@ -95,11 +125,12 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
    descriptor = infiniopRoPEDescriptor_t()
    # 2x table length for test
    sin_table, cos_table = sin_cos_table(t.shape[0] * 2, t.shape[2], t.device, theta)

    t_tensor, sin_table_tensor, cos_table_tensor = [
        to_tensor(tensor, lib) for tensor in [t, sin_table, cos_table]
    ]
    pos_tensor = to_tensor(pos[: t.shape[0]], lib)
    pos_tensor.descriptor.contents.dtype = InfiniDtype.U64

    if torch_device == "npu":
        torch.npu.synchronize()
@@ -116,10 +147,8 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    for tensor in [t_tensor, pos_tensor, sin_table_tensor, cos_table_tensor]:
        tensor.descriptor.contents.invalidate()

    workspace_size = c_uint64(0)
    check_error(
@@ -142,9 +171,11 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
    )
    lib_rope()

    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(t, ans, atol=atol, rtol=rtol)
    assert torch.allclose(t, ans, atol=atol, rtol=rtol)

    if PROFILE:
        profile_operation(
            "PyTorch",
@@ -161,17 +192,6 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):

if __name__ == "__main__":
    args = get_args()
    lib = open_lib()

    lib.infiniopCreateRoPEDescriptor.restype = c_int32
@@ -211,5 +231,5 @@ if __name__ == "__main__":
    # Execute tests
    for device in get_test_devices(args):
        test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")

==== swiglu test ====

import torch
import ctypes
from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
from libinfiniop import (
    infiniopHandle_t,
    infiniopTensorDescriptor_t,
    open_lib,
    to_tensor,
    get_test_devices,
    check_error,
    rearrange_if_needed,
    test_operator,
    get_args,
    debug,
    get_tolerance,
    profile_operation,
)

# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES = [
    # shape, a_stride, b_stride, c_stride
    ((13, 4), None, None, None),
    ((13, 4), (10, 1), (10, 1), (10, 1)),
    ((13, 4, 4), None, None, None),
    ((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1)),
    ((16, 5632), None, None, None),
    ((16, 5632), (13312, 1), (13312, 1), (13312, 1)),
    ((4, 4, 5632), None, None, None),
    ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)),
]

# Data types used for testing
_TENSOR_DTYPES = [torch.float16, torch.float32]

# Tolerance map for different data types
_TOLERANCE_MAP = {
    torch.float16: {'atol': 0, 'rtol': 1e-2},
    torch.float32: {'atol': 0, 'rtol': 1e-3},
}

DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000


class SwiGLUDescriptor(Structure):
    _fields_ = [("device", c_int32)]

@@ -51,20 +76,18 @@ def test_out_of_place(
    b = torch.rand(shape, dtype=dtype).to(torch_device)
    c = torch.rand(shape, dtype=dtype).to(torch_device)

    ans = swiglu(a, b)

    a, b, c = [
        rearrange_if_needed(tensor, stride)
        for tensor, stride in zip([a, b, c], [a_stride, b_stride, c_stride])
    ]
    a_tensor, b_tensor, c_tensor = [to_tensor(tensor, lib) for tensor in [a, b, c]]

    if sync is not None:
        sync()

    descriptor = infiniopSwiGLUDescriptor_t()
    check_error(
        lib.infiniopCreateSwiGLUDescriptor(
@@ -77,19 +100,33 @@ def test_out_of_place(
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    for tensor in [a_tensor, b_tensor, c_tensor]:
        tensor.descriptor.contents.invalidate()

    def lib_swiglu():
        check_error(
            lib.infiniopSwiGLU(
                descriptor,
                c_tensor.data,
                a_tensor.data,
                b_tensor.data,
                None,
            )
        )

    lib_swiglu()

    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(c, ans, atol=atol, rtol=rtol)
    assert torch.allclose(c, ans, atol=atol, rtol=rtol)
    print("out-of-place Test passed!")

    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: swiglu(a, b), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation("    lib", lambda: lib_swiglu(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on

    check_error(lib.infiniopDestroySwiGLUDescriptor(descriptor))

@@ -106,18 +143,19 @@ def test_in_place1(
    a = torch.rand(shape, dtype=dtype).to(torch_device)
    b = torch.rand(shape, dtype=dtype).to(torch_device)

    ans = swiglu(a, b)

    if sync is not None:
        sync()

    a, b = [
        rearrange_if_needed(tensor, stride)
        for tensor, stride in zip([a, b], [a_stride, b_stride])
    ]
    a_tensor, b_tensor = [to_tensor(tensor, lib) for tensor in [a, b]]

    descriptor = infiniopSwiGLUDescriptor_t()
    check_error(
        lib.infiniopCreateSwiGLUDescriptor(
            handle,
@@ -129,18 +167,27 @@ def test_in_place1(
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    for tensor in [a_tensor, b_tensor]:
        tensor.descriptor.contents.invalidate()

    def lib_swiglu():
        check_error(
            lib.infiniopSwiGLU(
                descriptor, a_tensor.data, a_tensor.data, b_tensor.data, None
            )
        )

    lib_swiglu()

    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(a, ans, atol=atol, rtol=rtol)
    assert torch.allclose(a, ans, atol=atol, rtol=rtol)
    print("in-place1 Test passed!")

    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: swiglu(a, b), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation("    lib", lambda: lib_swiglu(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on

    check_error(lib.infiniopDestroySwiGLUDescriptor(descriptor))

@@ -157,17 +204,17 @@ def test_in_place2(
    a = torch.rand(shape, dtype=dtype).to(torch_device)
    b = torch.rand(shape, dtype=dtype).to(torch_device)

    ans = swiglu(a, b)

    if sync is not None:
        sync()

    a, b = [
        rearrange_if_needed(tensor, stride)
        for tensor, stride in zip([a, b], [a_stride, b_stride])
    ]
    a_tensor, b_tensor = [to_tensor(tensor, lib) for tensor in [a, b]]

    descriptor = infiniopSwiGLUDescriptor_t()
    check_error(
        lib.infiniopCreateSwiGLUDescriptor(
@@ -180,100 +227,42 @@ def test_in_place2(
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    for tensor in [a_tensor, b_tensor]:
        tensor.descriptor.contents.invalidate()

    def lib_swiglu():
        check_error(
            lib.infiniopSwiGLU(
                descriptor, b_tensor.data, a_tensor.data, b_tensor.data, None
            )
        )

    lib_swiglu()

    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(b, ans, atol=atol, rtol=rtol)
    assert torch.allclose(b, ans, atol=atol, rtol=rtol)
    print("in-place2 Test passed!")

    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: swiglu(a, b), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation("    lib", lambda: lib_swiglu(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on

    check_error(lib.infiniopDestroySwiGLUDescriptor(descriptor))


def test(lib, handle, torch_device, shape, a_stride, b_stride, c_stride, dtype, sync=None):
    test_out_of_place(
        lib, handle, torch_device, shape, a_stride, b_stride, c_stride, dtype, sync
    )
    test_in_place1(lib, handle, torch_device, shape, a_stride, b_stride, dtype, sync)
    test_in_place2(lib, handle, torch_device, shape, a_stride, b_stride, dtype, sync)


if __name__ == "__main__":
    args = get_args()
    lib = open_lib()
@@ -299,13 +288,13 @@ if __name__ == "__main__":
    lib.infiniopDestroySwiGLUDescriptor.argtypes = [
        infiniopSwiGLUDescriptor_t,
    ]

    # Configure testing options
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    for device in get_test_devices(args):
        test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")