Commit 04aa18f6 authored by xgqdut2016's avatar xgqdut2016
Browse files

issue/66: modified format

parent ca2f34cf
......@@ -29,16 +29,14 @@ _TEST_CASES = [
((32, 5, 5), None),
((32, 20, 512), None),
((32, 20, 512), (20480, 512, 1)), # Ascend 暂不支持非连续
((32, 20, 4, 512), None),
((32, 20, 4, 512), (81920, 2048, 512, 1)),
]
]
# Data types used for testing
_TENSOR_DTYPES = [torch.float16, torch.float32]
_TENSOR_DTYPES = [torch.float16]
# Tolerance map for different data types
_TOLERANCE_MAP = {
torch.float16: {'atol': 0, 'rtol': 1e-2},
torch.float32: {'atol': 0, 'rtol': 1e-3},
torch.float16: {"atol": 0, "rtol": 1e-2},
}
DEBUG = False
......@@ -46,6 +44,7 @@ PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
class CausalSoftmaxDescriptor(Structure):
    # ctypes mirror of the opaque descriptor handle used by the infiniop
    # CausalSoftmax C API; only the device id field is declared here.
    _fields_ = [("device", c_int32)]
......@@ -61,22 +60,15 @@ def causal_softmax(x):
return torch.nn.functional.softmax(masked, dim=-1).to(type)
def test(
lib,
handle,
torch_device,
x_shape,
x_stride=None,
dtype=torch.float16
):
def test(lib, handle, torch_device, x_shape, x_stride=None, dtype=torch.float16):
print(
f"Testing CausalSoftmax on {torch_device} with x_shape:{x_shape} x_stride:{x_stride} dtype:{dtype}"
)
x = torch.rand(x_shape, dtype=dtype).to(torch_device)
ans = causal_softmax(x)
x = rearrange_if_needed(x, x_stride)
x_tensor = to_tensor(x, lib)
......@@ -84,16 +76,13 @@ def test(
descriptor = infiniopCausalSoftmaxDescriptor_t()
check_error(
lib.infiniopCreateCausalSoftmaxDescriptor(
handle,
ctypes.byref(descriptor),
x_tensor.descriptor
handle, ctypes.byref(descriptor), x_tensor.descriptor
)
)
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
x_tensor.descriptor.contents.invalidate()
workspace_size = c_uint64(0)
check_error(
lib.infiniopGetCausalSoftmaxWorkspaceSize(
......@@ -101,6 +90,7 @@ def test(
)
)
workspace = create_workspace(workspace_size.value, x.device)
def lib_causal_softmax():
check_error(
lib.infiniopCausalSoftmax(
......@@ -111,6 +101,7 @@ def test(
None,
)
)
lib_causal_softmax()
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
......@@ -128,24 +119,23 @@ def test(
check_error(lib.infiniopDestroyCausalSoftmaxDescriptor(descriptor))
if __name__ == "__main__":
args = get_args()
lib = open_lib()
lib.infiniopCreateCausalSoftmaxDescriptor.restype = c_int32
lib.infiniopCreateCausalSoftmaxDescriptor.argtypes = [
infiniopHandle_t,
POINTER(infiniopCausalSoftmaxDescriptor_t),
infiniopTensorDescriptor_t,
]
lib.infiniopGetCausalSoftmaxWorkspaceSize.restype = c_int32
lib.infiniopGetCausalSoftmaxWorkspaceSize.argtypes = [
infiniopCausalSoftmaxDescriptor_t,
POINTER(c_uint64),
]
lib.infiniopCausalSoftmax.restype = c_int32
lib.infiniopCausalSoftmax.argtypes = [
infiniopCausalSoftmaxDescriptor_t,
......@@ -154,10 +144,12 @@ if __name__ == "__main__":
c_void_p,
c_void_p,
]
lib.infiniopDestroyCausalSoftmaxDescriptor.restype = c_int32
lib.infiniopDestroyCausalSoftmaxDescriptor.argtypes = [
infiniopCausalSoftmaxDescriptor_t,
]
# Configure testing options
DEBUG = args.debug
PROFILE = args.profile
......@@ -168,4 +160,3 @@ if __name__ == "__main__":
test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)
print("\033[92mTest passed!\033[0m")
......@@ -12,15 +12,17 @@ from libinfiniop import (
create_workspace,
test_operator,
get_args,
debug,
debug_all,
get_tolerance,
profile_operation,
synchronize_device,
)
# ==============================================================================
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES = [
# voc, random_val, topp, topk, temperature
(512, 0.8, 0.8, 3, 0.5),
......@@ -36,9 +38,14 @@ _TEST_CASES = [
]
# Data types used for testing
_TENSOR_DTYPES = [torch.float16, torch.float32]
_TENSOR_DTYPES = [torch.float16]
_TOLERANCE_MAP = {
torch.float16: {"atol": 0, "rtol": 0},
}
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
......@@ -113,6 +120,7 @@ def test(
x_dtype=torch.float16,
):
print(f"Testing RandomSample on {torch_device} with voc:{voc} dtype:{x_dtype}")
data = torch.arange(voc).float() * 0.0001
_perm = torch.randperm(voc)
data = data[_perm].to(x_dtype).to(torch_device)
......@@ -122,9 +130,11 @@ def test(
)
else:
ans = random_sample_0(data)
indices = torch.zeros([1], dtype=torch.int64).to(torch_device)
x_tensor = to_tensor(data, lib)
indices_tensor = to_tensor(indices, lib)
x_tensor, indices_tensor = [to_tensor(tensor, lib) for tensor in [data, indices]]
indices_tensor.descriptor.contents.dt = U64 # treat int64 as uint64
descriptor = infiniopRandomSampleDescriptor_t()
......@@ -164,9 +174,19 @@ def test(
None,
)
)
if torch_device == "npu":
torch.npu.synchronize()
if torch_device == "npu":
synchronize_device(torch_device)
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG:
debug_all(
(indices[0].type(ans.dtype), data[indices[0]]),
(ans, data[ans]),
"or",
atol=atol,
rtol=rtol,
)
assert indices[0].type(ans.dtype) == ans or data[ans] == data[indices[0]]
# Profiling workflow
......@@ -184,23 +204,23 @@ def test(
check_error(lib.infiniopDestroyRandomSampleDescriptor(descriptor))
if __name__ == "__main__":
args = get_args()
lib = open_lib()
lib.infiniopCreateRandomSampleDescriptor.restype = c_int32
lib.infiniopCreateRandomSampleDescriptor.argtypes = [
infiniopHandle_t,
POINTER(infiniopRandomSampleDescriptor_t),
infiniopTensorDescriptor_t,
]
lib.infiniopGetRandomSampleWorkspaceSize.restype = c_int32
lib.infiniopGetRandomSampleWorkspaceSize.argtypes = [
infiniopRandomSampleDescriptor_t,
POINTER(c_uint64),
]
lib.infiniopRandomSample.restype = c_int32
lib.infiniopRandomSample.argtypes = [
infiniopRandomSampleDescriptor_t,
......@@ -214,11 +234,13 @@ if __name__ == "__main__":
c_float,
c_void_p,
]
lib.infiniopDestroyRandomSampleDescriptor.restype = c_int32
lib.infiniopDestroyRandomSampleDescriptor.argtypes = [
infiniopRandomSampleDescriptor_t,
]
DEBUG = args.debug
PROFILE = args.profile
NUM_PRERUN = args.num_prerun
NUM_ITERATIONS = args.num_iterations
......
......@@ -37,8 +37,8 @@ _TENSOR_DTYPES = [torch.float16, torch.float32]
# Tolerance map for different data types
_TOLERANCE_MAP = {
torch.float16: {"atol": 0, "rtol": 1e-3},
torch.float32: {"atol": 0, "rtol": 1e-3},
torch.float16: {"atol": 0, "rtol": 0},
torch.float32: {"atol": 0, "rtol": 0},
}
DEBUG = False
......@@ -47,7 +47,6 @@ NUM_PRERUN = 10
NUM_ITERATIONS = 1000
class RerrangeDescriptor(Structure):
    # ctypes mirror of the opaque Rearrange descriptor handle; only the
    # device id field is declared. NOTE(review): class name is misspelled
    # ("Rerrange") but is kept as-is since it is the public identifier.
    _fields_ = [("device", c_int32)]
......@@ -68,6 +67,7 @@ def test(
print(
f"Testing Rerrange on {torch_device} with x_shape:{x_shape} x_stride:{x_stride} y_shape:{y_shape} y_stride:{y_stride} x_dtype:{x_dtype}"
)
x = torch.rand(x_shape, dtype=x_dtype).to(torch_device)
y = torch.zeros(y_shape, dtype=x_dtype).to(torch_device)
......@@ -77,7 +77,6 @@ def test(
]
x_tensor, y_tensor = [to_tensor(tensor, lib) for tensor in [x, y]]
descriptor = infiniopRearrangeDescriptor_t()
check_error(
lib.infiniopCreateRearrangeDescriptor(
......@@ -91,13 +90,9 @@ def test(
def lib_rearrange():
check_error(
lib.infiniopRearrange(
descriptor,
y_tensor.data,
x_tensor.data,
None
)
lib.infiniopRearrange(descriptor, y_tensor.data, x_tensor.data, None)
)
lib_rearrange()
# Validate results
......@@ -116,8 +111,6 @@ def test(
check_error(lib.infiniopDestroyRearrangeDescriptor(descriptor))
if __name__ == "__main__":
args = get_args()
lib = open_lib()
......@@ -129,6 +122,7 @@ if __name__ == "__main__":
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
]
lib.infiniopRearrange.restype = c_int32
lib.infiniopRearrange.argtypes = [
infiniopRearrangeDescriptor_t,
......@@ -136,6 +130,7 @@ if __name__ == "__main__":
c_void_p,
c_void_p,
]
lib.infiniopDestroyRearrangeDescriptor.restype = c_int32
lib.infiniopDestroyRearrangeDescriptor.argtypes = [infiniopRearrangeDescriptor_t]
......
......@@ -23,21 +23,20 @@ from libinfiniop import (
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES = [
# y_shape, x_shape, w_shape, y_stride, x_stride, w_dtype
((16, 2048), (16, 2048), (2048,), None, None,torch.float32),
((16, 2048), (16, 2048), (2048,), None, None, torch.float32),
((16, 2048), (16, 2048), (2048,), None, None, torch.float16),
((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1),torch.float32),
((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), torch.float32),
((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), torch.float16),
]
# x types used for testing
_TENSOR_DTYPES = [torch.float16, torch.float32]
_TENSOR_DTYPES = [torch.float16]
# Tolerance map for different data types
_TOLERANCE_MAP = {
torch.float16: {"atol": 0, "rtol": 1e-2},
torch.float32: {"atol": 0, "rtol": 1e-3},
torch.float16: {"atol": 1e-3, "rtol": 1e-3},
}
DEBUG = False
......@@ -45,12 +44,14 @@ PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
class RMSNormDescriptor(Structure):
    # ctypes mirror of the opaque RMSNorm descriptor handle passed to the
    # infiniop C library; only the device id field is declared.
    _fields_ = [("device", c_int32)]


# Pointer type used for all infiniop RMSNorm C-API arguments.
infiniopRMSNormDescriptor_t = POINTER(RMSNormDescriptor)
def rms_norm(x, w, eps):
input_dtype = x.dtype
hidden_states = x.to(torch.float32)
......@@ -69,9 +70,12 @@ def test(
y_stride,
x_stride,
dtype=torch.float16,
w_dtype=torch.float16):
print(f"Testing RMS_Norm on {torch_device} with y_shape:{y_shape} x_shape:{x_shape} w_shape:{w_shape}"
f" dtype:{dtype} w_dtype:{w_dtype}")
w_dtype=torch.float16,
):
print(
f"Testing RMS_Norm on {torch_device} with y_shape:{y_shape} x_shape:{x_shape} w_shape:{w_shape}"
f" dtype:{dtype} w_dtype:{w_dtype}"
)
y = torch.zeros(y_shape, dtype=dtype).to(torch_device)
x = torch.rand(x_shape, dtype=dtype).to(torch_device)
......@@ -80,18 +84,23 @@ def test(
eps = 1e-5
ans = rms_norm(x, w, eps)
x = rearrange_if_needed(x, x_stride)
y = rearrange_if_needed(y, y_stride)
x, y = [
rearrange_if_needed(tensor, stride)
for tensor, stride in zip([x, y], [x_stride, y_stride])
]
x_tensor, y_tensor, w_tensor = [to_tensor(tensor, lib) for tensor in [x, y, w]]
descriptor = infiniopRMSNormDescriptor_t()
w_dataType = 0 if w_dtype==torch.float16 else 1
check_error(
lib.infiniopCreateRMSNormDescriptor(
handle, ctypes.byref(descriptor), y_tensor.descriptor, x_tensor.descriptor,
w_tensor.descriptor, eps
handle,
ctypes.byref(descriptor),
y_tensor.descriptor,
x_tensor.descriptor,
w_tensor.descriptor,
eps,
)
)
......@@ -101,11 +110,10 @@ def test(
workspace_size = c_uint64(0)
check_error(
lib.infiniopGetRMSNormWorkspaceSize(
descriptor, ctypes.byref(workspace_size)
)
lib.infiniopGetRMSNormWorkspaceSize(descriptor, ctypes.byref(workspace_size))
)
workspace = create_workspace(workspace_size.value, y.device)
def lib_rms_norm():
check_error(
lib.infiniopRMSNorm(
......@@ -134,12 +142,10 @@ def test(
check_error(lib.infiniopDestroyRMSNormDescriptor(descriptor))
if __name__ == "__main__":
args = get_args()
lib = open_lib()
lib.infiniopCreateRMSNormDescriptor.restype = c_int32
lib.infiniopCreateRMSNormDescriptor.argtypes = [
infiniopHandle_t,
......@@ -166,6 +172,7 @@ if __name__ == "__main__":
c_void_p,
c_void_p,
]
lib.infiniopDestroyRMSNormDescriptor.restype = c_int32
lib.infiniopDestroyRMSNormDescriptor.argtypes = [
infiniopRMSNormDescriptor_t,
......@@ -182,5 +189,3 @@ if __name__ == "__main__":
test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)
print("\033[92mTest passed!\033[0m")
......@@ -15,6 +15,7 @@ from libinfiniop import (
debug,
get_tolerance,
profile_operation,
synchronize_device,
)
# ==============================================================================
......@@ -33,12 +34,11 @@ _TEST_CASES = [
]
# Data types used for testing
_TENSOR_DTYPES = [torch.float16, torch.float32]
_TENSOR_DTYPES = [torch.float16]
# Tolerance map for different data types
_TOLERANCE_MAP = {
torch.float16: {"atol": 0, "rtol": 1e-2},
torch.float32: {"atol": 0, "rtol": 1e-3},
torch.float16: {"atol": 1e-4, "rtol": 1e-2},
}
DEBUG = False
......@@ -47,7 +47,6 @@ NUM_PRERUN = 10
NUM_ITERATIONS = 1000
class RoPEDescriptor(Structure):
    # ctypes mirror of the opaque RoPE descriptor handle used by the
    # infiniop C API; only the device id field is declared.
    _fields_ = [("device", c_int32)]
......@@ -96,14 +95,7 @@ def sin_cos_table(max_seq_len, dim, torch_device, theta):
return torch.sin(angles), torch.cos(angles)
def test(
lib,
handle,
torch_device,
shape,
strides=None,
dtype=torch.float16
):
def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
print(
f"Testing Rotary Positional Embedding on {torch_device} with shape:{shape} strides:{strides} and dtype:{dtype}"
)
......@@ -126,14 +118,15 @@ def test(
# 2x table length for test
sin_table, cos_table = sin_cos_table(t.shape[0] * 2, t.shape[2], t.device, theta)
t_tensor, sin_table_tensor, cos_table_tensor = [to_tensor(tensor, lib) for tensor in [t, sin_table, cos_table]]
t_tensor, sin_table_tensor, cos_table_tensor = [
to_tensor(tensor, lib) for tensor in [t, sin_table, cos_table]
]
pos_tensor = to_tensor(pos[: t.shape[0]], lib)
pos_tensor.descriptor.contents.dtype = InfiniDtype.U64
if torch_device == "npu":
torch.npu.synchronize()
synchronize_device(torch_device)
check_error(
lib.infiniopCreateRoPEDescriptor(
......@@ -171,6 +164,7 @@ def test(
)
lib_rope()
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG:
debug(t, ans, atol=atol, rtol=rtol)
......@@ -194,6 +188,7 @@ def test(
if __name__ == "__main__":
args = get_args()
lib = open_lib()
lib.infiniopCreateRoPEDescriptor.restype = c_int32
lib.infiniopCreateRoPEDescriptor.argtypes = [
infiniopHandle_t,
......@@ -203,11 +198,13 @@ if __name__ == "__main__":
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
]
lib.infiniopGetRoPEWorkspaceSize.restype = c_int32
lib.infiniopGetRoPEWorkspaceSize.argtypes = [
infiniopRoPEDescriptor_t,
POINTER(c_uint64),
]
lib.infiniopRoPE.restype = c_int32
lib.infiniopRoPE.argtypes = [
infiniopRoPEDescriptor_t,
......@@ -219,10 +216,12 @@ if __name__ == "__main__":
c_void_p,
c_void_p,
]
lib.infiniopDestroyRoPEDescriptor.restype = c_int32
lib.infiniopDestroyRoPEDescriptor.argtypes = [
infiniopRoPEDescriptor_t,
]
# Configure testing options
DEBUG = args.debug
PROFILE = args.profile
......
......@@ -16,29 +16,64 @@ from libinfiniop import (
get_tolerance,
profile_operation,
)
from enum import Enum, auto
# ==============================================================================
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES = [
# shape, a_stride, b_stride, c_stride
((13, 4), None, None, None),
((13, 4), (10, 1), (10, 1), (10, 1)),
((13, 4, 4), None, None, None),
((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1)),
((16, 5632), None, None, None),
((16, 5632), (13312, 1), (13312, 1), (13312, 1)),
((4, 4, 5632), None, None, None),
((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)),
# shape, a_stride, b_stride, c_stride, inplace
((13, 4), None, None, None, Inplace.OUT_OF_PLACE),
((13, 4), None, None, None, Inplace.INPLACE_A),
((13, 4), None, None, None, Inplace.INPLACE_B),
((13, 4), (10, 1), (10, 1), (10, 1), Inplace.OUT_OF_PLACE),
((13, 4), (10, 1), (10, 1), (10, 1), Inplace.INPLACE_A),
((13, 4), (10, 1), (10, 1), (10, 1), Inplace.INPLACE_B),
((13, 4, 4), None, None, None, Inplace.OUT_OF_PLACE),
((13, 4, 4), None, None, None, Inplace.INPLACE_A),
((13, 4, 4), None, None, None, Inplace.INPLACE_B),
((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1), Inplace.OUT_OF_PLACE),
((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1), Inplace.INPLACE_A),
((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1), Inplace.INPLACE_B),
((16, 5632), None, None, None, Inplace.OUT_OF_PLACE),
((16, 5632), None, None, None, Inplace.INPLACE_A),
((16, 5632), None, None, None, Inplace.INPLACE_B),
((16, 5632), (13312, 1), (13312, 1), (13312, 1), Inplace.OUT_OF_PLACE),
((16, 5632), (13312, 1), (13312, 1), (13312, 1), Inplace.INPLACE_A),
((16, 5632), (13312, 1), (13312, 1), (13312, 1), Inplace.INPLACE_B),
((4, 4, 5632), None, None, None, Inplace.OUT_OF_PLACE),
((4, 4, 5632), None, None, None, Inplace.INPLACE_A),
((4, 4, 5632), None, None, None, Inplace.INPLACE_B),
(
(4, 4, 5632),
(45056, 5632, 1),
(45056, 5632, 1),
(45056, 5632, 1),
Inplace.OUT_OF_PLACE,
),
(
(4, 4, 5632),
(45056, 5632, 1),
(45056, 5632, 1),
(45056, 5632, 1),
Inplace.INPLACE_A,
),
(
(4, 4, 5632),
(45056, 5632, 1),
(45056, 5632, 1),
(45056, 5632, 1),
Inplace.INPLACE_B,
),
]
# Data types used for testing
_TENSOR_DTYPES = [torch.float16, torch.float32]
_TENSOR_DTYPES = [torch.float16]
# Tolerance map for different data types
_TOLERANCE_MAP = {
torch.float16: {'atol': 0, 'rtol': 1e-2},
torch.float32: {'atol': 0, 'rtol': 1e-3},
torch.float16: {"atol": 1e-4, "rtol": 1e-2},
}
DEBUG = False
......@@ -46,6 +81,13 @@ PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
class Inplace(Enum):
    # Selects which operand (if any) the SwiGLU output tensor aliases
    # in a test case.
    OUT_OF_PLACE = auto()  # c is a separate output tensor
    INPLACE_A = auto()  # result is written into a
    INPLACE_B = auto()  # result is written into b
class SwiGLUDescriptor(Structure):
    # ctypes mirror of the opaque SwiGLU descriptor handle used by the
    # infiniop C API; only the device id field is declared.
    _fields_ = [("device", c_int32)]
......@@ -54,11 +96,10 @@ infiniopSwiGLUDescriptor_t = POINTER(SwiGLUDescriptor)
def swiglu(a, b):
    """PyTorch reference for SwiGLU: ``a * b * sigmoid(b)``.

    The sigmoid denominator's exponential is evaluated in float32 and cast
    back to ``b``'s dtype before the add and divide, matching the original
    reference's rounding behavior for half-precision inputs.
    """
    gate_denominator = 1 + torch.exp(-b.float()).to(b.dtype)
    return a * b / gate_denominator
def test_out_of_place(
def test(
lib,
handle,
torch_device,
......@@ -66,15 +107,21 @@ def test_out_of_place(
a_stride=None,
b_stride=None,
c_stride=None,
inplace=Inplace.OUT_OF_PLACE,
dtype=torch.float16,
sync=None,
):
print(
f"Testing SwiGLU on {torch_device} with shape:{shape} a_stride:{a_stride} b_stride:{b_stride} c_stride:{c_stride} dtype:{dtype}"
)
a = torch.rand(shape, dtype=dtype).to(torch_device)
b = torch.rand(shape, dtype=dtype).to(torch_device)
c = torch.rand(shape, dtype=dtype).to(torch_device)
c = (
torch.rand(c_shape, dtype=tensor_dtype).to(torch_device)
if inplace == Inplace.OUT_OF_PLACE
else (a if inplace == Inplace.INPLACE_A else b)
)
ans = swiglu(a, b)
......@@ -82,9 +129,12 @@ def test_out_of_place(
rearrange_if_needed(tensor, stride)
for tensor, stride in zip([a, b, c], [a_stride, b_stride, c_stride])
]
a_tensor, b_tensor, c_tensor = [to_tensor(tensor, lib) for tensor in [a, b, c]]
a_tensor, b_tensor = [to_tensor(tensor, lib) for tensor in [a, b]]
c_tensor = (
to_tensor(c, lib)
if inplace == Inplace.OUT_OF_PLACE
else (a_tensor if inplace == Inplace.INPLACE_A else b_tensor)
)
if sync is not None:
sync()
......@@ -106,13 +156,10 @@ def test_out_of_place(
def lib_swiglu():
check_error(
lib.infiniopSwiGLU(
descriptor,
c_tensor.data,
a_tensor.data,
b_tensor.data,
None
descriptor, c_tensor.data, a_tensor.data, b_tensor.data, None
)
)
lib_swiglu()
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
......@@ -130,139 +177,7 @@ def test_out_of_place(
check_error(lib.infiniopDestroySwiGLUDescriptor(descriptor))
def test_in_place1(
    lib,
    handle,
    torch_device,
    shape,
    a_stride=None,
    b_stride=None,
    dtype=torch.float16,
    sync=None,
):
    """Test the SwiGLU kernel writing its result in place into ``a``.

    Builds random ``a``/``b``, computes the PyTorch reference, then creates
    a SwiGLU descriptor whose output descriptor is ``a``'s own descriptor so
    the C kernel overwrites ``a``. Asserts the overwritten ``a`` matches the
    reference within the tolerance for ``dtype``.
    """
    a = torch.rand(shape, dtype=dtype).to(torch_device)
    b = torch.rand(shape, dtype=dtype).to(torch_device)

    # Reference result computed before a is overwritten by the kernel.
    ans = swiglu(a, b)

    if sync is not None:
        sync()

    # Apply the requested (possibly non-contiguous) strides before wrapping.
    a, b = [
        rearrange_if_needed(tensor, stride)
        for tensor, stride in zip([a, b], [a_stride, b_stride])
    ]
    a_tensor, b_tensor = [to_tensor(tensor, lib) for tensor in [a, b]]

    descriptor = infiniopSwiGLUDescriptor_t()
    check_error(
        lib.infiniopCreateSwiGLUDescriptor(
            handle,
            ctypes.byref(descriptor),
            a_tensor.descriptor,  # output descriptor aliases input a -> in place
            a_tensor.descriptor,
            b_tensor.descriptor,
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    for tensor in [a_tensor, b_tensor]:
        tensor.descriptor.contents.invalidate()

    def lib_swiglu():
        # c and a pointers are the same buffer (in-place variant 1).
        check_error(
            lib.infiniopSwiGLU(
                descriptor, a_tensor.data, a_tensor.data, b_tensor.data, None
            )
        )

    lib_swiglu()

    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(a, ans, atol=atol, rtol=rtol)
    assert torch.allclose(a, ans, atol=atol, rtol=rtol)
    print("in-place1 Test passed!")

    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: swiglu(a, b), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation(" lib", lambda: lib_swiglu(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on
    check_error(lib.infiniopDestroySwiGLUDescriptor(descriptor))
def test_in_place2(
    lib,
    handle,
    torch_device,
    shape,
    a_stride=None,
    b_stride=None,
    dtype=torch.float16,
    sync=None,
):
    """Test the SwiGLU kernel writing its result in place into ``b``.

    Mirrors ``test_in_place1`` but the output descriptor aliases ``b``'s
    descriptor, so the C kernel overwrites ``b``. Asserts the overwritten
    ``b`` matches the PyTorch reference within the tolerance for ``dtype``.
    """
    a = torch.rand(shape, dtype=dtype).to(torch_device)
    b = torch.rand(shape, dtype=dtype).to(torch_device)

    # Reference result computed before b is overwritten by the kernel.
    ans = swiglu(a, b)

    if sync is not None:
        sync()

    # Apply the requested (possibly non-contiguous) strides before wrapping.
    a, b = [
        rearrange_if_needed(tensor, stride)
        for tensor, stride in zip([a, b], [a_stride, b_stride])
    ]
    a_tensor, b_tensor = [to_tensor(tensor, lib) for tensor in [a, b]]

    descriptor = infiniopSwiGLUDescriptor_t()
    check_error(
        lib.infiniopCreateSwiGLUDescriptor(
            handle,
            ctypes.byref(descriptor),
            b_tensor.descriptor,  # output descriptor aliases input b -> in place
            a_tensor.descriptor,
            b_tensor.descriptor,
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    for tensor in [a_tensor, b_tensor]:
        tensor.descriptor.contents.invalidate()

    def lib_swiglu():
        # c and b pointers are the same buffer (in-place variant 2).
        check_error(
            lib.infiniopSwiGLU(
                descriptor, b_tensor.data, a_tensor.data, b_tensor.data, None
            )
        )

    lib_swiglu()

    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(b, ans, atol=atol, rtol=rtol)
    assert torch.allclose(b, ans, atol=atol, rtol=rtol)
    print("in-place2 Test passed!")

    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: swiglu(a, b), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation(" lib", lambda: lib_swiglu(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on
    check_error(lib.infiniopDestroySwiGLUDescriptor(descriptor))
def test(lib, handle, torch_device, shape, a_stride, b_stride, c_stride, dtype, sync=None):
    """Run all three SwiGLU variants (out-of-place, in-place into a,
    in-place into b) against the same test case."""
    common = (lib, handle, torch_device, shape, a_stride, b_stride)
    test_out_of_place(*common, c_stride, dtype, sync)
    test_in_place1(*common, dtype, sync)
    test_in_place2(*common, dtype, sync)
if __name__ == "__main__":
args = get_args()
lib = open_lib()
......@@ -288,6 +203,7 @@ if __name__ == "__main__":
lib.infiniopDestroySwiGLUDescriptor.argtypes = [
infiniopSwiGLUDescriptor_t,
]
# Configure testing options
DEBUG = args.debug
PROFILE = args.profile
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment