Commit 04aa18f6 authored by xgqdut2016's avatar xgqdut2016
Browse files

issue/66: modified format

parent ca2f34cf
......@@ -23,22 +23,20 @@ from libinfiniop import (
# These are not meant to be imported from other modules
_TEST_CASES = [
# x_shape, x_stride
((32, 512), None),
((32, 512), (1024, 1)),
((32, 5, 5), None),
((32, 20, 512), None),
((32, 20, 512), (20480, 512, 1)), # Ascend 暂不支持非连续
((32, 20, 4, 512), None),
((32, 20, 4, 512), (81920, 2048, 512, 1)),
]
# x_shape, x_stride
((32, 512), None),
((32, 512), (1024, 1)),
((32, 5, 5), None),
((32, 20, 512), None),
((32, 20, 512), (20480, 512, 1)), # Ascend 暂不支持非连续
]
# Data types used for testing
_TENSOR_DTYPES = [torch.float16, torch.float32]
_TENSOR_DTYPES = [torch.float16]
# Tolerance map for different data types
_TOLERANCE_MAP = {
torch.float16: {'atol': 0, 'rtol': 1e-2},
torch.float32: {'atol': 0, 'rtol': 1e-3},
torch.float16: {"atol": 0, "rtol": 1e-2},
}
DEBUG = False
......@@ -46,6 +44,7 @@ PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
class CausalSoftmaxDescriptor(Structure):
    """Opaque ctypes mirror of the C-side causal-softmax descriptor.

    Only the device id is exposed; the native library owns the rest of the
    descriptor state behind this handle.
    """

    _fields_ = [
        ("device", c_int32),
    ]
......@@ -61,39 +60,29 @@ def causal_softmax(x):
return torch.nn.functional.softmax(masked, dim=-1).to(type)
def test(
lib,
handle,
torch_device,
x_shape,
x_stride=None,
dtype=torch.float16
):
def test(lib, handle, torch_device, x_shape, x_stride=None, dtype=torch.float16):
print(
f"Testing CausalSoftmax on {torch_device} with x_shape:{x_shape} x_stride:{x_stride} dtype:{dtype}"
)
x = torch.rand(x_shape, dtype=dtype).to(torch_device)
ans = causal_softmax(x)
x = rearrange_if_needed(x, x_stride)
x_tensor = to_tensor(x, lib)
descriptor = infiniopCausalSoftmaxDescriptor_t()
check_error(
lib.infiniopCreateCausalSoftmaxDescriptor(
handle,
ctypes.byref(descriptor),
x_tensor.descriptor
handle, ctypes.byref(descriptor), x_tensor.descriptor
)
)
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
x_tensor.descriptor.contents.invalidate()
workspace_size = c_uint64(0)
check_error(
lib.infiniopGetCausalSoftmaxWorkspaceSize(
......@@ -101,6 +90,7 @@ def test(
)
)
workspace = create_workspace(workspace_size.value, x.device)
def lib_causal_softmax():
check_error(
lib.infiniopCausalSoftmax(
......@@ -111,8 +101,9 @@ def test(
None,
)
)
lib_causal_softmax()
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG:
debug(x, ans, atol=atol, rtol=rtol)
......@@ -128,24 +119,23 @@ def test(
check_error(lib.infiniopDestroyCausalSoftmaxDescriptor(descriptor))
if __name__ == "__main__":
args = get_args()
lib = open_lib()
lib.infiniopCreateCausalSoftmaxDescriptor.restype = c_int32
lib.infiniopCreateCausalSoftmaxDescriptor.argtypes = [
infiniopHandle_t,
POINTER(infiniopCausalSoftmaxDescriptor_t),
infiniopTensorDescriptor_t,
]
lib.infiniopGetCausalSoftmaxWorkspaceSize.restype = c_int32
lib.infiniopGetCausalSoftmaxWorkspaceSize.argtypes = [
infiniopCausalSoftmaxDescriptor_t,
POINTER(c_uint64),
]
lib.infiniopCausalSoftmax.restype = c_int32
lib.infiniopCausalSoftmax.argtypes = [
infiniopCausalSoftmaxDescriptor_t,
......@@ -154,18 +144,19 @@ if __name__ == "__main__":
c_void_p,
c_void_p,
]
lib.infiniopDestroyCausalSoftmaxDescriptor.restype = c_int32
lib.infiniopDestroyCausalSoftmaxDescriptor.argtypes = [
infiniopCausalSoftmaxDescriptor_t,
]
# Configure testing options
DEBUG = args.debug
PROFILE = args.profile
NUM_PRERUN = args.num_prerun
NUM_ITERATIONS = args.num_iterations
for device in get_test_devices(args):
test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)
print("\033[92mTest passed!\033[0m")
......@@ -12,33 +12,40 @@ from libinfiniop import (
create_workspace,
test_operator,
get_args,
debug,
debug_all,
get_tolerance,
profile_operation,
synchronize_device,
)
# ==============================================================================
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES = [
# voc, random_val, topp, topk, temperature
(512, 0.8, 0.8, 3, 0.5),
(4096, 0.05, 0.9, 5, 1.0),
(16384, 0.15, 0.85, 10, 2.0),
(512, 0.08, 0, 3, 0.5),
(4096, 0.5, 0.9, 1, 1.0),
(16384, 0.15, 0, 1, 2.0),
(16384, 0.15, 0, 1, 2.0),
(32000, 0.08, 0.8, 50, 1.0),
(32000, 0.08, 1.0, 25, 1.0),
# (119696, 0.01, 1.0, 100, 1.0),
(512, 0.8, 0.8, 3, 0.5),
(4096, 0.05, 0.9, 5, 1.0),
(16384, 0.15, 0.85, 10, 2.0),
(512, 0.08, 0, 3, 0.5),
(4096, 0.5, 0.9, 1, 1.0),
(16384, 0.15, 0, 1, 2.0),
(16384, 0.15, 0, 1, 2.0),
(32000, 0.08, 0.8, 50, 1.0),
(32000, 0.08, 1.0, 25, 1.0),
# (119696, 0.01, 1.0, 100, 1.0),
]
# Data types used for testing
_TENSOR_DTYPES = [torch.float16, torch.float32]
_TENSOR_DTYPES = [torch.float16]
_TOLERANCE_MAP = {
torch.float16: {"atol": 0, "rtol": 0},
}
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
......@@ -113,6 +120,7 @@ def test(
x_dtype=torch.float16,
):
print(f"Testing RandomSample on {torch_device} with voc:{voc} dtype:{x_dtype}")
data = torch.arange(voc).float() * 0.0001
_perm = torch.randperm(voc)
data = data[_perm].to(x_dtype).to(torch_device)
......@@ -122,9 +130,11 @@ def test(
)
else:
ans = random_sample_0(data)
indices = torch.zeros([1], dtype=torch.int64).to(torch_device)
x_tensor = to_tensor(data, lib)
indices_tensor = to_tensor(indices, lib)
x_tensor, indices_tensor = [to_tensor(tensor, lib) for tensor in [data, indices]]
indices_tensor.descriptor.contents.dt = U64 # treat int64 as uint64
descriptor = infiniopRandomSampleDescriptor_t()
......@@ -148,7 +158,7 @@ def test(
)
)
workspace = create_workspace(workspace_size.value, torch_device)
def lib_random_sample():
check_error(
lib.infiniopRandomSample(
......@@ -164,11 +174,21 @@ def test(
None,
)
)
if torch_device == "npu":
torch.npu.synchronize()
if torch_device == "npu":
synchronize_device(torch_device)
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG:
debug_all(
(indices[0].type(ans.dtype), data[indices[0]]),
(ans, data[ans]),
"or",
atol=atol,
rtol=rtol,
)
assert indices[0].type(ans.dtype) == ans or data[ans] == data[indices[0]]
# Profiling workflow
if PROFILE:
# fmt: off
......@@ -184,23 +204,23 @@ def test(
check_error(lib.infiniopDestroyRandomSampleDescriptor(descriptor))
if __name__ == "__main__":
args = get_args()
lib = open_lib()
lib.infiniopCreateRandomSampleDescriptor.restype = c_int32
lib.infiniopCreateRandomSampleDescriptor.argtypes = [
infiniopHandle_t,
POINTER(infiniopRandomSampleDescriptor_t),
infiniopTensorDescriptor_t,
]
lib.infiniopGetRandomSampleWorkspaceSize.restype = c_int32
lib.infiniopGetRandomSampleWorkspaceSize.argtypes = [
infiniopRandomSampleDescriptor_t,
POINTER(c_uint64),
]
lib.infiniopRandomSample.restype = c_int32
lib.infiniopRandomSample.argtypes = [
infiniopRandomSampleDescriptor_t,
......@@ -214,11 +234,13 @@ if __name__ == "__main__":
c_float,
c_void_p,
]
lib.infiniopDestroyRandomSampleDescriptor.restype = c_int32
lib.infiniopDestroyRandomSampleDescriptor.argtypes = [
infiniopRandomSampleDescriptor_t,
]
DEBUG = args.debug
PROFILE = args.profile
NUM_PRERUN = args.num_prerun
NUM_ITERATIONS = args.num_iterations
......
......@@ -23,13 +23,13 @@ from libinfiniop import (
# These are not meant to be imported from other modules
_TEST_CASES = [
# ((src_shape, src_stride), (dst_shape, dst_stride))
(((2, 4, 32), None), ((2, 4, 32), (256, 64, 1))),
(((32, 6, 64), (64, 2560, 1)), ((32, 6, 64), None)),
(((4, 6, 64), (64, 2560, 1)), ((4, 6, 64), (131072, 64, 1))),
(((1, 32, 64), (2048, 64, 1)), ((1, 32, 64), (2048, 64, 1))),
(((32, 1, 64), (64, 2560, 1)), ((32, 1, 64), (64, 64, 1))),
(((4, 1, 64), (64, 2560, 1)), ((4, 1, 64), (64, 11264, 1))),
(((64,), (1,)), ((64,), (1,))),
(((2, 4, 32), None), ((2, 4, 32), (256, 64, 1))),
(((32, 6, 64), (64, 2560, 1)), ((32, 6, 64), None)),
(((4, 6, 64), (64, 2560, 1)), ((4, 6, 64), (131072, 64, 1))),
(((1, 32, 64), (2048, 64, 1)), ((1, 32, 64), (2048, 64, 1))),
(((32, 1, 64), (64, 2560, 1)), ((32, 1, 64), (64, 64, 1))),
(((4, 1, 64), (64, 2560, 1)), ((4, 1, 64), (64, 11264, 1))),
(((64,), (1,)), ((64,), (1,))),
]
# Data types used for testing
......@@ -37,8 +37,8 @@ _TENSOR_DTYPES = [torch.float16, torch.float32]
# Tolerance map for different data types
_TOLERANCE_MAP = {
torch.float16: {"atol": 0, "rtol": 1e-3},
torch.float32: {"atol": 0, "rtol": 1e-3},
torch.float16: {"atol": 0, "rtol": 0},
torch.float32: {"atol": 0, "rtol": 0},
}
DEBUG = False
......@@ -47,7 +47,6 @@ NUM_PRERUN = 10
NUM_ITERATIONS = 1000
class RerrangeDescriptor(Structure):
    """Opaque ctypes mirror of the C-side rearrange descriptor.

    NOTE(review): the name looks like a misspelling of "Rearrange", but it is
    kept byte-for-byte because it is the public identifier of this block.
    """

    _fields_ = [
        ("device", c_int32),
    ]
......@@ -68,16 +67,16 @@ def test(
print(
f"Testing Rerrange on {torch_device} with x_shape:{x_shape} x_stride:{x_stride} y_shape:{y_shape} y_stride:{y_stride} x_dtype:{x_dtype}"
)
x = torch.rand(x_shape, dtype=x_dtype).to(torch_device)
y = torch.zeros(y_shape, dtype=x_dtype).to(torch_device)
x, y = [
rearrange_if_needed(tensor, stride)
for tensor, stride in zip([x, y], [x_stride, y_stride])
]
x_tensor, y_tensor = [to_tensor(tensor, lib) for tensor in [x, y]]
descriptor = infiniopRearrangeDescriptor_t()
check_error(
lib.infiniopCreateRearrangeDescriptor(
......@@ -91,15 +90,11 @@ def test(
def lib_rearrange():
check_error(
lib.infiniopRearrange(
descriptor,
y_tensor.data,
x_tensor.data,
None
)
lib.infiniopRearrange(descriptor, y_tensor.data, x_tensor.data, None)
)
lib_rearrange()
# Validate results
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG:
......@@ -116,8 +111,6 @@ def test(
check_error(lib.infiniopDestroyRearrangeDescriptor(descriptor))
if __name__ == "__main__":
args = get_args()
lib = open_lib()
......@@ -129,6 +122,7 @@ if __name__ == "__main__":
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
]
lib.infiniopRearrange.restype = c_int32
lib.infiniopRearrange.argtypes = [
infiniopRearrangeDescriptor_t,
......@@ -136,9 +130,10 @@ if __name__ == "__main__":
c_void_p,
c_void_p,
]
lib.infiniopDestroyRearrangeDescriptor.restype = c_int32
lib.infiniopDestroyRearrangeDescriptor.argtypes = [infiniopRearrangeDescriptor_t]
# Configure testing options
DEBUG = args.debug
PROFILE = args.profile
......
......@@ -23,21 +23,20 @@ from libinfiniop import (
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES = [
# y_shape, x_shape, w_shape, y_stride, x_stride, w_dtype
((16, 2048), (16, 2048), (2048,), None, None,torch.float32),
((16, 2048), (16, 2048), (2048,), None, None, torch.float32),
((16, 2048), (16, 2048), (2048,), None, None, torch.float16),
((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1),torch.float32),
((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), torch.float32),
((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), torch.float16),
]
# x types used for testing
_TENSOR_DTYPES = [torch.float16, torch.float32]
_TENSOR_DTYPES = [torch.float16]
# Tolerance map for different data types
_TOLERANCE_MAP = {
torch.float16: {"atol": 0, "rtol": 1e-2},
torch.float32: {"atol": 0, "rtol": 1e-3},
torch.float16: {"atol": 1e-3, "rtol": 1e-3},
}
DEBUG = False
......@@ -45,12 +44,14 @@ PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
class RMSNormDescriptor(Structure):
    """Opaque ctypes mirror of the C-side RMSNorm descriptor."""

    _fields_ = [
        ("device", c_int32),
    ]


# Pointer type passed across the C API boundary for RMSNorm calls.
infiniopRMSNormDescriptor_t = POINTER(RMSNormDescriptor)
def rms_norm(x, w, eps):
input_dtype = x.dtype
hidden_states = x.to(torch.float32)
......@@ -60,18 +61,21 @@ def rms_norm(x, w, eps):
def test(
lib,
handle,
torch_device,
y_shape,
x_shape,
w_shape,
lib,
handle,
torch_device,
y_shape,
x_shape,
w_shape,
y_stride,
x_stride,
dtype=torch.float16,
w_dtype=torch.float16):
print(f"Testing RMS_Norm on {torch_device} with y_shape:{y_shape} x_shape:{x_shape} w_shape:{w_shape}"
f" dtype:{dtype} w_dtype:{w_dtype}")
dtype=torch.float16,
w_dtype=torch.float16,
):
print(
f"Testing RMS_Norm on {torch_device} with y_shape:{y_shape} x_shape:{x_shape} w_shape:{w_shape}"
f" dtype:{dtype} w_dtype:{w_dtype}"
)
y = torch.zeros(y_shape, dtype=dtype).to(torch_device)
x = torch.rand(x_shape, dtype=dtype).to(torch_device)
......@@ -80,18 +84,23 @@ def test(
eps = 1e-5
ans = rms_norm(x, w, eps)
x = rearrange_if_needed(x, x_stride)
y = rearrange_if_needed(y, y_stride)
x, y = [
rearrange_if_needed(tensor, stride)
for tensor, stride in zip([x, y], [x_stride, y_stride])
]
x_tensor, y_tensor, w_tensor = [to_tensor(tensor, lib) for tensor in [x, y, w]]
descriptor = infiniopRMSNormDescriptor_t()
w_dataType = 0 if w_dtype==torch.float16 else 1
check_error(
lib.infiniopCreateRMSNormDescriptor(
handle, ctypes.byref(descriptor), y_tensor.descriptor, x_tensor.descriptor,
w_tensor.descriptor, eps
handle,
ctypes.byref(descriptor),
y_tensor.descriptor,
x_tensor.descriptor,
w_tensor.descriptor,
eps,
)
)
......@@ -101,11 +110,10 @@ def test(
workspace_size = c_uint64(0)
check_error(
lib.infiniopGetRMSNormWorkspaceSize(
descriptor, ctypes.byref(workspace_size)
)
lib.infiniopGetRMSNormWorkspaceSize(descriptor, ctypes.byref(workspace_size))
)
workspace = create_workspace(workspace_size.value, y.device)
def lib_rms_norm():
check_error(
lib.infiniopRMSNorm(
......@@ -134,12 +142,10 @@ def test(
check_error(lib.infiniopDestroyRMSNormDescriptor(descriptor))
if __name__ == "__main__":
args = get_args()
lib = open_lib()
lib.infiniopCreateRMSNormDescriptor.restype = c_int32
lib.infiniopCreateRMSNormDescriptor.argtypes = [
infiniopHandle_t,
......@@ -166,6 +172,7 @@ if __name__ == "__main__":
c_void_p,
c_void_p,
]
lib.infiniopDestroyRMSNormDescriptor.restype = c_int32
lib.infiniopDestroyRMSNormDescriptor.argtypes = [
infiniopRMSNormDescriptor_t,
......@@ -182,5 +189,3 @@ if __name__ == "__main__":
test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)
print("\033[92mTest passed!\033[0m")
......@@ -15,6 +15,7 @@ from libinfiniop import (
debug,
get_tolerance,
profile_operation,
synchronize_device,
)
# ==============================================================================
......@@ -23,22 +24,21 @@ from libinfiniop import (
# These are not meant to be imported from other modules
_TEST_CASES = [
# (t_shape, t_strides)
((1, 32, 128), None),
((1, 32, 64), None),
# 昇腾暂不满足这个用例,最后一维度 <=32 会有问题,可能与其核心
# 接口 GatherMask 的内部实现相关,目前 48 64 128 都可以支持
((4, 1, 32), None),
((1, 32, 128), None),
((3, 32, 128), (8000, 200, 1)),
((1, 32, 128), None),
((1, 32, 64), None),
# 昇腾暂不满足这个用例,最后一维度 <=32 会有问题,可能与其核心
# 接口 GatherMask 的内部实现相关,目前 48 64 128 都可以支持
((4, 1, 32), None),
((1, 32, 128), None),
((3, 32, 128), (8000, 200, 1)),
]
# Data types used for testing
_TENSOR_DTYPES = [torch.float16, torch.float32]
_TENSOR_DTYPES = [torch.float16]
# Tolerance map for different data types
_TOLERANCE_MAP = {
torch.float16: {"atol": 0, "rtol": 1e-2},
torch.float32: {"atol": 0, "rtol": 1e-3},
torch.float16: {"atol": 1e-4, "rtol": 1e-2},
}
DEBUG = False
......@@ -47,7 +47,6 @@ NUM_PRERUN = 10
NUM_ITERATIONS = 1000
class RoPEDescriptor(Structure):
    """Opaque ctypes mirror of the C-side rotary positional embedding descriptor."""

    _fields_ = [
        ("device", c_int32),
    ]
......@@ -96,14 +95,7 @@ def sin_cos_table(max_seq_len, dim, torch_device, theta):
return torch.sin(angles), torch.cos(angles)
def test(
lib,
handle,
torch_device,
shape,
strides=None,
dtype=torch.float16
):
def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
print(
f"Testing Rotary Positional Embedding on {torch_device} with shape:{shape} strides:{strides} and dtype:{dtype}"
)
......@@ -126,14 +118,15 @@ def test(
# 2x table length for test
sin_table, cos_table = sin_cos_table(t.shape[0] * 2, t.shape[2], t.device, theta)
t_tensor, sin_table_tensor, cos_table_tensor = [to_tensor(tensor, lib) for tensor in [t, sin_table, cos_table]]
t_tensor, sin_table_tensor, cos_table_tensor = [
to_tensor(tensor, lib) for tensor in [t, sin_table, cos_table]
]
pos_tensor = to_tensor(pos[: t.shape[0]], lib)
pos_tensor.descriptor.contents.dtype = InfiniDtype.U64
if torch_device == "npu":
torch.npu.synchronize()
synchronize_device(torch_device)
check_error(
lib.infiniopCreateRoPEDescriptor(
......@@ -171,11 +164,12 @@ def test(
)
lib_rope()
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG:
debug(t, ans, atol=atol, rtol=rtol)
assert torch.allclose(t, ans, atol=atol, rtol=rtol)
if PROFILE:
profile_operation(
"PyTorch",
......@@ -194,6 +188,7 @@ def test(
if __name__ == "__main__":
args = get_args()
lib = open_lib()
lib.infiniopCreateRoPEDescriptor.restype = c_int32
lib.infiniopCreateRoPEDescriptor.argtypes = [
infiniopHandle_t,
......@@ -203,11 +198,13 @@ if __name__ == "__main__":
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
]
lib.infiniopGetRoPEWorkspaceSize.restype = c_int32
lib.infiniopGetRoPEWorkspaceSize.argtypes = [
infiniopRoPEDescriptor_t,
POINTER(c_uint64),
]
lib.infiniopRoPE.restype = c_int32
lib.infiniopRoPE.argtypes = [
infiniopRoPEDescriptor_t,
......@@ -219,10 +216,12 @@ if __name__ == "__main__":
c_void_p,
c_void_p,
]
lib.infiniopDestroyRoPEDescriptor.restype = c_int32
lib.infiniopDestroyRoPEDescriptor.argtypes = [
infiniopRoPEDescriptor_t,
]
# Configure testing options
DEBUG = args.debug
PROFILE = args.profile
......
......@@ -16,29 +16,64 @@ from libinfiniop import (
get_tolerance,
profile_operation,
)
from enum import Enum, auto
# ==============================================================================
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES = [
# shape, a_stride, b_stride, c_stride
((13, 4), None, None, None),
((13, 4), (10, 1), (10, 1), (10, 1)),
((13, 4, 4), None, None, None),
((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1)),
((16, 5632), None, None, None),
((16, 5632), (13312, 1), (13312, 1), (13312, 1)),
((4, 4, 5632), None, None, None),
((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)),
# shape, a_stride, b_stride, c_stride, inplace
((13, 4), None, None, None, Inplace.OUT_OF_PLACE),
((13, 4), None, None, None, Inplace.INPLACE_A),
((13, 4), None, None, None, Inplace.INPLACE_B),
((13, 4), (10, 1), (10, 1), (10, 1), Inplace.OUT_OF_PLACE),
((13, 4), (10, 1), (10, 1), (10, 1), Inplace.INPLACE_A),
((13, 4), (10, 1), (10, 1), (10, 1), Inplace.INPLACE_B),
((13, 4, 4), None, None, None, Inplace.OUT_OF_PLACE),
((13, 4, 4), None, None, None, Inplace.INPLACE_A),
((13, 4, 4), None, None, None, Inplace.INPLACE_B),
((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1), Inplace.OUT_OF_PLACE),
((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1), Inplace.INPLACE_A),
((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1), Inplace.INPLACE_B),
((16, 5632), None, None, None, Inplace.OUT_OF_PLACE),
((16, 5632), None, None, None, Inplace.INPLACE_A),
((16, 5632), None, None, None, Inplace.INPLACE_B),
((16, 5632), (13312, 1), (13312, 1), (13312, 1), Inplace.OUT_OF_PLACE),
((16, 5632), (13312, 1), (13312, 1), (13312, 1), Inplace.INPLACE_A),
((16, 5632), (13312, 1), (13312, 1), (13312, 1), Inplace.INPLACE_B),
((4, 4, 5632), None, None, None, Inplace.OUT_OF_PLACE),
((4, 4, 5632), None, None, None, Inplace.INPLACE_A),
((4, 4, 5632), None, None, None, Inplace.INPLACE_B),
(
(4, 4, 5632),
(45056, 5632, 1),
(45056, 5632, 1),
(45056, 5632, 1),
Inplace.OUT_OF_PLACE,
),
(
(4, 4, 5632),
(45056, 5632, 1),
(45056, 5632, 1),
(45056, 5632, 1),
Inplace.INPLACE_A,
),
(
(4, 4, 5632),
(45056, 5632, 1),
(45056, 5632, 1),
(45056, 5632, 1),
Inplace.INPLACE_B,
),
]
# Data types used for testing
_TENSOR_DTYPES = [torch.float16, torch.float32]
_TENSOR_DTYPES = [torch.float16]
# Tolerance map for different data types
_TOLERANCE_MAP = {
torch.float16: {'atol': 0, 'rtol': 1e-2},
torch.float32: {'atol': 0, 'rtol': 1e-3},
torch.float16: {"atol": 1e-4, "rtol": 1e-2},
}
DEBUG = False
......@@ -46,6 +81,13 @@ PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
class Inplace(Enum):
    """Aliasing mode for the SwiGLU output buffer in a test case.

    Values are spelled explicitly; they match what ``auto()`` would assign
    (1, 2, 3), so member identities and values are unchanged.
    """

    OUT_OF_PLACE = 1  # c is its own tensor
    INPLACE_A = 2  # result overwrites input a
    INPLACE_B = 3  # result overwrites input b
class SwiGLUDescriptor(Structure):
    """Opaque ctypes mirror of the C-side SwiGLU descriptor."""

    _fields_ = [
        ("device", c_int32),
    ]
......@@ -54,11 +96,10 @@ infiniopSwiGLUDescriptor_t = POINTER(SwiGLUDescriptor)
def swiglu(a, b):
    """Reference SwiGLU: ``a * b * sigmoid(b)`` computed as ``a * b / (1 + e^-b)``.

    The exponential is evaluated in float32 and cast back to ``b``'s dtype
    before the add/divide, preserving the original rounding behaviour.
    """
    denom = 1 + torch.exp(-b.float()).to(b.dtype)
    return a * b / denom
def test_out_of_place(
def test(
lib,
handle,
torch_device,
......@@ -66,15 +107,21 @@ def test_out_of_place(
a_stride=None,
b_stride=None,
c_stride=None,
inplace=Inplace.OUT_OF_PLACE,
dtype=torch.float16,
sync=None,
):
print(
f"Testing SwiGLU on {torch_device} with shape:{shape} a_stride:{a_stride} b_stride:{b_stride} c_stride:{c_stride} dtype:{dtype}"
)
a = torch.rand(shape, dtype=dtype).to(torch_device)
b = torch.rand(shape, dtype=dtype).to(torch_device)
c = torch.rand(shape, dtype=dtype).to(torch_device)
c = (
torch.rand(c_shape, dtype=tensor_dtype).to(torch_device)
if inplace == Inplace.OUT_OF_PLACE
else (a if inplace == Inplace.INPLACE_A else b)
)
ans = swiglu(a, b)
......@@ -82,9 +129,12 @@ def test_out_of_place(
rearrange_if_needed(tensor, stride)
for tensor, stride in zip([a, b, c], [a_stride, b_stride, c_stride])
]
a_tensor, b_tensor, c_tensor = [to_tensor(tensor, lib) for tensor in [a, b, c]]
a_tensor, b_tensor = [to_tensor(tensor, lib) for tensor in [a, b]]
c_tensor = (
to_tensor(c, lib)
if inplace == Inplace.OUT_OF_PLACE
else (a_tensor if inplace == Inplace.INPLACE_A else b_tensor)
)
if sync is not None:
sync()
......@@ -106,13 +156,10 @@ def test_out_of_place(
def lib_swiglu():
check_error(
lib.infiniopSwiGLU(
descriptor,
c_tensor.data,
a_tensor.data,
b_tensor.data,
None
descriptor, c_tensor.data, a_tensor.data, b_tensor.data, None
)
)
lib_swiglu()
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
......@@ -130,139 +177,7 @@ def test_out_of_place(
check_error(lib.infiniopDestroySwiGLUDescriptor(descriptor))
def test_in_place1(
    lib,
    handle,
    torch_device,
    shape,
    a_stride=None,
    b_stride=None,
    dtype=torch.float16,
    sync=None,
):
    """Test SwiGLU with the output aliased onto input ``a`` (c == a).

    Computes the torch reference first, then runs the library kernel writing
    into ``a``'s buffer and asserts the result matches within _TOLERANCE_MAP.
    ``sync`` is an optional device-synchronize callback (e.g. for NPU).
    """
    a = torch.rand(shape, dtype=dtype).to(torch_device)
    b = torch.rand(shape, dtype=dtype).to(torch_device)
    # Reference answer taken BEFORE the kernel clobbers a in place.
    ans = swiglu(a, b)

    if sync is not None:
        sync()

    # Re-lay-out inputs to the requested (possibly non-contiguous) strides.
    a, b = [
        rearrange_if_needed(tensor, stride)
        for tensor, stride in zip([a, b], [a_stride, b_stride])
    ]
    a_tensor, b_tensor = [to_tensor(tensor, lib) for tensor in [a, b]]
    descriptor = infiniopSwiGLUDescriptor_t()
    # a's descriptor is passed for both output and first input -> in-place on a.
    check_error(
        lib.infiniopCreateSwiGLUDescriptor(
            handle,
            ctypes.byref(descriptor),
            a_tensor.descriptor,
            a_tensor.descriptor,
            b_tensor.descriptor,
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    for tensor in [a_tensor, b_tensor]:
        tensor.descriptor.contents.invalidate()

    def lib_swiglu():
        # Output pointer == a's data pointer: the kernel overwrites a.
        check_error(
            lib.infiniopSwiGLU(
                descriptor, a_tensor.data, a_tensor.data, b_tensor.data, None
            )
        )

    lib_swiglu()

    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(a, ans, atol=atol, rtol=rtol)
    assert torch.allclose(a, ans, atol=atol, rtol=rtol)
    print("in-place1 Test passed!")

    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: swiglu(a, b), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation(" lib", lambda: lib_swiglu(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on
    check_error(lib.infiniopDestroySwiGLUDescriptor(descriptor))
def test_in_place2(
    lib,
    handle,
    torch_device,
    shape,
    a_stride=None,
    b_stride=None,
    dtype=torch.float16,
    sync=None,
):
    """Test SwiGLU with the output aliased onto input ``b`` (c == b).

    Mirror of test_in_place1 except the library kernel writes into ``b``'s
    buffer; the result in ``b`` is compared against the torch reference.
    ``sync`` is an optional device-synchronize callback (e.g. for NPU).
    """
    a = torch.rand(shape, dtype=dtype).to(torch_device)
    b = torch.rand(shape, dtype=dtype).to(torch_device)
    # Reference answer taken BEFORE the kernel clobbers b in place.
    ans = swiglu(a, b)

    if sync is not None:
        sync()

    # Re-lay-out inputs to the requested (possibly non-contiguous) strides.
    a, b = [
        rearrange_if_needed(tensor, stride)
        for tensor, stride in zip([a, b], [a_stride, b_stride])
    ]
    a_tensor, b_tensor = [to_tensor(tensor, lib) for tensor in [a, b]]
    descriptor = infiniopSwiGLUDescriptor_t()
    # b's descriptor is passed for both output and second input -> in-place on b.
    check_error(
        lib.infiniopCreateSwiGLUDescriptor(
            handle,
            ctypes.byref(descriptor),
            b_tensor.descriptor,
            a_tensor.descriptor,
            b_tensor.descriptor,
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    for tensor in [a_tensor, b_tensor]:
        tensor.descriptor.contents.invalidate()

    def lib_swiglu():
        # Output pointer == b's data pointer: the kernel overwrites b.
        check_error(
            lib.infiniopSwiGLU(
                descriptor, b_tensor.data, a_tensor.data, b_tensor.data, None
            )
        )

    lib_swiglu()

    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(b, ans, atol=atol, rtol=rtol)
    assert torch.allclose(b, ans, atol=atol, rtol=rtol)
    print("in-place2 Test passed!")

    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: swiglu(a, b), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation(" lib", lambda: lib_swiglu(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on
    check_error(lib.infiniopDestroySwiGLUDescriptor(descriptor))
def test(lib, handle, torch_device, shape, a_stride, b_stride, c_stride, dtype, sync=None):
    """Run one SwiGLU case in all three aliasing modes: out-of-place, c==a, c==b."""
    common = (lib, handle, torch_device, shape, a_stride, b_stride)
    test_out_of_place(*common, c_stride, dtype, sync)
    test_in_place1(*common, dtype, sync)
    test_in_place2(*common, dtype, sync)
if __name__ == "__main__":
args = get_args()
lib = open_lib()
......@@ -288,12 +203,13 @@ if __name__ == "__main__":
lib.infiniopDestroySwiGLUDescriptor.argtypes = [
infiniopSwiGLUDescriptor_t,
]
# Configure testing options
DEBUG = args.debug
PROFILE = args.profile
NUM_PRERUN = args.num_prerun
NUM_ITERATIONS = args.num_iterations
for device in get_test_devices(args):
test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment