Commit 04aa18f6 authored by xgqdut2016's avatar xgqdut2016
Browse files

issue/66: modified format

parent ca2f34cf
...@@ -23,22 +23,20 @@ from libinfiniop import ( ...@@ -23,22 +23,20 @@ from libinfiniop import (
# These are not meant to be imported from other modules # These are not meant to be imported from other modules
_TEST_CASES = [ _TEST_CASES = [
# x_shape, x_stride # x_shape, x_stride
((32, 512), None), ((32, 512), None),
((32, 512), (1024, 1)), ((32, 512), (1024, 1)),
((32, 5, 5), None), ((32, 5, 5), None),
((32, 20, 512), None), ((32, 20, 512), None),
((32, 20, 512), (20480, 512, 1)), # Ascend 暂不支持非连续 ((32, 20, 512), (20480, 512, 1)), # Ascend 暂不支持非连续
((32, 20, 4, 512), None), ]
((32, 20, 4, 512), (81920, 2048, 512, 1)),
]
# Data types used for testing # Data types used for testing
_TENSOR_DTYPES = [torch.float16, torch.float32] _TENSOR_DTYPES = [torch.float16]
# Tolerance map for different data types # Tolerance map for different data types
_TOLERANCE_MAP = { _TOLERANCE_MAP = {
torch.float16: {'atol': 0, 'rtol': 1e-2}, torch.float16: {"atol": 0, "rtol": 1e-2},
torch.float32: {'atol': 0, 'rtol': 1e-3},
} }
DEBUG = False DEBUG = False
...@@ -46,6 +44,7 @@ PROFILE = False ...@@ -46,6 +44,7 @@ PROFILE = False
NUM_PRERUN = 10 NUM_PRERUN = 10
NUM_ITERATIONS = 1000 NUM_ITERATIONS = 1000
class CausalSoftmaxDescriptor(Structure): class CausalSoftmaxDescriptor(Structure):
_fields_ = [("device", c_int32)] _fields_ = [("device", c_int32)]
...@@ -61,39 +60,29 @@ def causal_softmax(x): ...@@ -61,39 +60,29 @@ def causal_softmax(x):
return torch.nn.functional.softmax(masked, dim=-1).to(type) return torch.nn.functional.softmax(masked, dim=-1).to(type)
def test( def test(lib, handle, torch_device, x_shape, x_stride=None, dtype=torch.float16):
lib,
handle,
torch_device,
x_shape,
x_stride=None,
dtype=torch.float16
):
print( print(
f"Testing CausalSoftmax on {torch_device} with x_shape:{x_shape} x_stride:{x_stride} dtype:{dtype}" f"Testing CausalSoftmax on {torch_device} with x_shape:{x_shape} x_stride:{x_stride} dtype:{dtype}"
) )
x = torch.rand(x_shape, dtype=dtype).to(torch_device) x = torch.rand(x_shape, dtype=dtype).to(torch_device)
ans = causal_softmax(x) ans = causal_softmax(x)
x = rearrange_if_needed(x, x_stride) x = rearrange_if_needed(x, x_stride)
x_tensor = to_tensor(x, lib) x_tensor = to_tensor(x, lib)
descriptor = infiniopCausalSoftmaxDescriptor_t() descriptor = infiniopCausalSoftmaxDescriptor_t()
check_error( check_error(
lib.infiniopCreateCausalSoftmaxDescriptor( lib.infiniopCreateCausalSoftmaxDescriptor(
handle, handle, ctypes.byref(descriptor), x_tensor.descriptor
ctypes.byref(descriptor),
x_tensor.descriptor
) )
) )
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
x_tensor.descriptor.contents.invalidate() x_tensor.descriptor.contents.invalidate()
workspace_size = c_uint64(0) workspace_size = c_uint64(0)
check_error( check_error(
lib.infiniopGetCausalSoftmaxWorkspaceSize( lib.infiniopGetCausalSoftmaxWorkspaceSize(
...@@ -101,6 +90,7 @@ def test( ...@@ -101,6 +90,7 @@ def test(
) )
) )
workspace = create_workspace(workspace_size.value, x.device) workspace = create_workspace(workspace_size.value, x.device)
def lib_causal_softmax(): def lib_causal_softmax():
check_error( check_error(
lib.infiniopCausalSoftmax( lib.infiniopCausalSoftmax(
...@@ -111,8 +101,9 @@ def test( ...@@ -111,8 +101,9 @@ def test(
None, None,
) )
) )
lib_causal_softmax() lib_causal_softmax()
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG: if DEBUG:
debug(x, ans, atol=atol, rtol=rtol) debug(x, ans, atol=atol, rtol=rtol)
...@@ -128,24 +119,23 @@ def test( ...@@ -128,24 +119,23 @@ def test(
check_error(lib.infiniopDestroyCausalSoftmaxDescriptor(descriptor)) check_error(lib.infiniopDestroyCausalSoftmaxDescriptor(descriptor))
if __name__ == "__main__": if __name__ == "__main__":
args = get_args() args = get_args()
lib = open_lib() lib = open_lib()
lib.infiniopCreateCausalSoftmaxDescriptor.restype = c_int32 lib.infiniopCreateCausalSoftmaxDescriptor.restype = c_int32
lib.infiniopCreateCausalSoftmaxDescriptor.argtypes = [ lib.infiniopCreateCausalSoftmaxDescriptor.argtypes = [
infiniopHandle_t, infiniopHandle_t,
POINTER(infiniopCausalSoftmaxDescriptor_t), POINTER(infiniopCausalSoftmaxDescriptor_t),
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
] ]
lib.infiniopGetCausalSoftmaxWorkspaceSize.restype = c_int32 lib.infiniopGetCausalSoftmaxWorkspaceSize.restype = c_int32
lib.infiniopGetCausalSoftmaxWorkspaceSize.argtypes = [ lib.infiniopGetCausalSoftmaxWorkspaceSize.argtypes = [
infiniopCausalSoftmaxDescriptor_t, infiniopCausalSoftmaxDescriptor_t,
POINTER(c_uint64), POINTER(c_uint64),
] ]
lib.infiniopCausalSoftmax.restype = c_int32 lib.infiniopCausalSoftmax.restype = c_int32
lib.infiniopCausalSoftmax.argtypes = [ lib.infiniopCausalSoftmax.argtypes = [
infiniopCausalSoftmaxDescriptor_t, infiniopCausalSoftmaxDescriptor_t,
...@@ -154,18 +144,19 @@ if __name__ == "__main__": ...@@ -154,18 +144,19 @@ if __name__ == "__main__":
c_void_p, c_void_p,
c_void_p, c_void_p,
] ]
lib.infiniopDestroyCausalSoftmaxDescriptor.restype = c_int32 lib.infiniopDestroyCausalSoftmaxDescriptor.restype = c_int32
lib.infiniopDestroyCausalSoftmaxDescriptor.argtypes = [ lib.infiniopDestroyCausalSoftmaxDescriptor.argtypes = [
infiniopCausalSoftmaxDescriptor_t, infiniopCausalSoftmaxDescriptor_t,
] ]
# Configure testing options # Configure testing options
DEBUG = args.debug DEBUG = args.debug
PROFILE = args.profile PROFILE = args.profile
NUM_PRERUN = args.num_prerun NUM_PRERUN = args.num_prerun
NUM_ITERATIONS = args.num_iterations NUM_ITERATIONS = args.num_iterations
for device in get_test_devices(args): for device in get_test_devices(args):
test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES) test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)
print("\033[92mTest passed!\033[0m") print("\033[92mTest passed!\033[0m")
...@@ -12,33 +12,40 @@ from libinfiniop import ( ...@@ -12,33 +12,40 @@ from libinfiniop import (
create_workspace, create_workspace,
test_operator, test_operator,
get_args, get_args,
debug, debug_all,
get_tolerance, get_tolerance,
profile_operation, profile_operation,
synchronize_device,
) )
# ============================================================================== # ==============================================================================
# Configuration (Internal Use Only) # Configuration (Internal Use Only)
# ============================================================================== # ==============================================================================
# These are not meant to be imported from other modules # These are not meant to be imported from other modules
_TEST_CASES = [ _TEST_CASES = [
# voc, random_val, topp, topk, temperature # voc, random_val, topp, topk, temperature
(512, 0.8, 0.8, 3, 0.5), (512, 0.8, 0.8, 3, 0.5),
(4096, 0.05, 0.9, 5, 1.0), (4096, 0.05, 0.9, 5, 1.0),
(16384, 0.15, 0.85, 10, 2.0), (16384, 0.15, 0.85, 10, 2.0),
(512, 0.08, 0, 3, 0.5), (512, 0.08, 0, 3, 0.5),
(4096, 0.5, 0.9, 1, 1.0), (4096, 0.5, 0.9, 1, 1.0),
(16384, 0.15, 0, 1, 2.0), (16384, 0.15, 0, 1, 2.0),
(16384, 0.15, 0, 1, 2.0), (16384, 0.15, 0, 1, 2.0),
(32000, 0.08, 0.8, 50, 1.0), (32000, 0.08, 0.8, 50, 1.0),
(32000, 0.08, 1.0, 25, 1.0), (32000, 0.08, 1.0, 25, 1.0),
# (119696, 0.01, 1.0, 100, 1.0), # (119696, 0.01, 1.0, 100, 1.0),
] ]
# Data types used for testing # Data types used for testing
_TENSOR_DTYPES = [torch.float16, torch.float32] _TENSOR_DTYPES = [torch.float16]
_TOLERANCE_MAP = {
torch.float16: {"atol": 0, "rtol": 0},
}
DEBUG = False
PROFILE = False PROFILE = False
NUM_PRERUN = 10 NUM_PRERUN = 10
NUM_ITERATIONS = 1000 NUM_ITERATIONS = 1000
...@@ -113,6 +120,7 @@ def test( ...@@ -113,6 +120,7 @@ def test(
x_dtype=torch.float16, x_dtype=torch.float16,
): ):
print(f"Testing RandomSample on {torch_device} with voc:{voc} dtype:{x_dtype}") print(f"Testing RandomSample on {torch_device} with voc:{voc} dtype:{x_dtype}")
data = torch.arange(voc).float() * 0.0001 data = torch.arange(voc).float() * 0.0001
_perm = torch.randperm(voc) _perm = torch.randperm(voc)
data = data[_perm].to(x_dtype).to(torch_device) data = data[_perm].to(x_dtype).to(torch_device)
...@@ -122,9 +130,11 @@ def test( ...@@ -122,9 +130,11 @@ def test(
) )
else: else:
ans = random_sample_0(data) ans = random_sample_0(data)
indices = torch.zeros([1], dtype=torch.int64).to(torch_device) indices = torch.zeros([1], dtype=torch.int64).to(torch_device)
x_tensor = to_tensor(data, lib)
indices_tensor = to_tensor(indices, lib) x_tensor, indices_tensor = [to_tensor(tensor, lib) for tensor in [data, indices]]
indices_tensor.descriptor.contents.dt = U64 # treat int64 as uint64 indices_tensor.descriptor.contents.dt = U64 # treat int64 as uint64
descriptor = infiniopRandomSampleDescriptor_t() descriptor = infiniopRandomSampleDescriptor_t()
...@@ -148,7 +158,7 @@ def test( ...@@ -148,7 +158,7 @@ def test(
) )
) )
workspace = create_workspace(workspace_size.value, torch_device) workspace = create_workspace(workspace_size.value, torch_device)
def lib_random_sample(): def lib_random_sample():
check_error( check_error(
lib.infiniopRandomSample( lib.infiniopRandomSample(
...@@ -164,11 +174,21 @@ def test( ...@@ -164,11 +174,21 @@ def test(
None, None,
) )
) )
if torch_device == "npu":
torch.npu.synchronize()
if torch_device == "npu":
synchronize_device(torch_device)
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG:
debug_all(
(indices[0].type(ans.dtype), data[indices[0]]),
(ans, data[ans]),
"or",
atol=atol,
rtol=rtol,
)
assert indices[0].type(ans.dtype) == ans or data[ans] == data[indices[0]] assert indices[0].type(ans.dtype) == ans or data[ans] == data[indices[0]]
# Profiling workflow # Profiling workflow
if PROFILE: if PROFILE:
# fmt: off # fmt: off
...@@ -184,23 +204,23 @@ def test( ...@@ -184,23 +204,23 @@ def test(
check_error(lib.infiniopDestroyRandomSampleDescriptor(descriptor)) check_error(lib.infiniopDestroyRandomSampleDescriptor(descriptor))
if __name__ == "__main__": if __name__ == "__main__":
args = get_args() args = get_args()
lib = open_lib() lib = open_lib()
lib.infiniopCreateRandomSampleDescriptor.restype = c_int32 lib.infiniopCreateRandomSampleDescriptor.restype = c_int32
lib.infiniopCreateRandomSampleDescriptor.argtypes = [ lib.infiniopCreateRandomSampleDescriptor.argtypes = [
infiniopHandle_t, infiniopHandle_t,
POINTER(infiniopRandomSampleDescriptor_t), POINTER(infiniopRandomSampleDescriptor_t),
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
] ]
lib.infiniopGetRandomSampleWorkspaceSize.restype = c_int32 lib.infiniopGetRandomSampleWorkspaceSize.restype = c_int32
lib.infiniopGetRandomSampleWorkspaceSize.argtypes = [ lib.infiniopGetRandomSampleWorkspaceSize.argtypes = [
infiniopRandomSampleDescriptor_t, infiniopRandomSampleDescriptor_t,
POINTER(c_uint64), POINTER(c_uint64),
] ]
lib.infiniopRandomSample.restype = c_int32 lib.infiniopRandomSample.restype = c_int32
lib.infiniopRandomSample.argtypes = [ lib.infiniopRandomSample.argtypes = [
infiniopRandomSampleDescriptor_t, infiniopRandomSampleDescriptor_t,
...@@ -214,11 +234,13 @@ if __name__ == "__main__": ...@@ -214,11 +234,13 @@ if __name__ == "__main__":
c_float, c_float,
c_void_p, c_void_p,
] ]
lib.infiniopDestroyRandomSampleDescriptor.restype = c_int32 lib.infiniopDestroyRandomSampleDescriptor.restype = c_int32
lib.infiniopDestroyRandomSampleDescriptor.argtypes = [ lib.infiniopDestroyRandomSampleDescriptor.argtypes = [
infiniopRandomSampleDescriptor_t, infiniopRandomSampleDescriptor_t,
] ]
DEBUG = args.debug
PROFILE = args.profile PROFILE = args.profile
NUM_PRERUN = args.num_prerun NUM_PRERUN = args.num_prerun
NUM_ITERATIONS = args.num_iterations NUM_ITERATIONS = args.num_iterations
......
...@@ -23,13 +23,13 @@ from libinfiniop import ( ...@@ -23,13 +23,13 @@ from libinfiniop import (
# These are not meant to be imported from other modules # These are not meant to be imported from other modules
_TEST_CASES = [ _TEST_CASES = [
# ((src_shape, src_stride), (dst_shape, dst_stride)) # ((src_shape, src_stride), (dst_shape, dst_stride))
(((2, 4, 32), None), ((2, 4, 32), (256, 64, 1))), (((2, 4, 32), None), ((2, 4, 32), (256, 64, 1))),
(((32, 6, 64), (64, 2560, 1)), ((32, 6, 64), None)), (((32, 6, 64), (64, 2560, 1)), ((32, 6, 64), None)),
(((4, 6, 64), (64, 2560, 1)), ((4, 6, 64), (131072, 64, 1))), (((4, 6, 64), (64, 2560, 1)), ((4, 6, 64), (131072, 64, 1))),
(((1, 32, 64), (2048, 64, 1)), ((1, 32, 64), (2048, 64, 1))), (((1, 32, 64), (2048, 64, 1)), ((1, 32, 64), (2048, 64, 1))),
(((32, 1, 64), (64, 2560, 1)), ((32, 1, 64), (64, 64, 1))), (((32, 1, 64), (64, 2560, 1)), ((32, 1, 64), (64, 64, 1))),
(((4, 1, 64), (64, 2560, 1)), ((4, 1, 64), (64, 11264, 1))), (((4, 1, 64), (64, 2560, 1)), ((4, 1, 64), (64, 11264, 1))),
(((64,), (1,)), ((64,), (1,))), (((64,), (1,)), ((64,), (1,))),
] ]
# Data types used for testing # Data types used for testing
...@@ -37,8 +37,8 @@ _TENSOR_DTYPES = [torch.float16, torch.float32] ...@@ -37,8 +37,8 @@ _TENSOR_DTYPES = [torch.float16, torch.float32]
# Tolerance map for different data types # Tolerance map for different data types
_TOLERANCE_MAP = { _TOLERANCE_MAP = {
torch.float16: {"atol": 0, "rtol": 1e-3}, torch.float16: {"atol": 0, "rtol": 0},
torch.float32: {"atol": 0, "rtol": 1e-3}, torch.float32: {"atol": 0, "rtol": 0},
} }
DEBUG = False DEBUG = False
...@@ -47,7 +47,6 @@ NUM_PRERUN = 10 ...@@ -47,7 +47,6 @@ NUM_PRERUN = 10
NUM_ITERATIONS = 1000 NUM_ITERATIONS = 1000
class RerrangeDescriptor(Structure): class RerrangeDescriptor(Structure):
_fields_ = [("device", c_int32)] _fields_ = [("device", c_int32)]
...@@ -68,16 +67,16 @@ def test( ...@@ -68,16 +67,16 @@ def test(
print( print(
f"Testing Rerrange on {torch_device} with x_shape:{x_shape} x_stride:{x_stride} y_shape:{y_shape} y_stride:{y_stride} x_dtype:{x_dtype}" f"Testing Rerrange on {torch_device} with x_shape:{x_shape} x_stride:{x_stride} y_shape:{y_shape} y_stride:{y_stride} x_dtype:{x_dtype}"
) )
x = torch.rand(x_shape, dtype=x_dtype).to(torch_device) x = torch.rand(x_shape, dtype=x_dtype).to(torch_device)
y = torch.zeros(y_shape, dtype=x_dtype).to(torch_device) y = torch.zeros(y_shape, dtype=x_dtype).to(torch_device)
x, y = [ x, y = [
rearrange_if_needed(tensor, stride) rearrange_if_needed(tensor, stride)
for tensor, stride in zip([x, y], [x_stride, y_stride]) for tensor, stride in zip([x, y], [x_stride, y_stride])
] ]
x_tensor, y_tensor = [to_tensor(tensor, lib) for tensor in [x, y]] x_tensor, y_tensor = [to_tensor(tensor, lib) for tensor in [x, y]]
descriptor = infiniopRearrangeDescriptor_t() descriptor = infiniopRearrangeDescriptor_t()
check_error( check_error(
lib.infiniopCreateRearrangeDescriptor( lib.infiniopCreateRearrangeDescriptor(
...@@ -91,15 +90,11 @@ def test( ...@@ -91,15 +90,11 @@ def test(
def lib_rearrange(): def lib_rearrange():
check_error( check_error(
lib.infiniopRearrange( lib.infiniopRearrange(descriptor, y_tensor.data, x_tensor.data, None)
descriptor,
y_tensor.data,
x_tensor.data,
None
)
) )
lib_rearrange() lib_rearrange()
# Validate results # Validate results
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG: if DEBUG:
...@@ -116,8 +111,6 @@ def test( ...@@ -116,8 +111,6 @@ def test(
check_error(lib.infiniopDestroyRearrangeDescriptor(descriptor)) check_error(lib.infiniopDestroyRearrangeDescriptor(descriptor))
if __name__ == "__main__": if __name__ == "__main__":
args = get_args() args = get_args()
lib = open_lib() lib = open_lib()
...@@ -129,6 +122,7 @@ if __name__ == "__main__": ...@@ -129,6 +122,7 @@ if __name__ == "__main__":
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
] ]
lib.infiniopRearrange.restype = c_int32 lib.infiniopRearrange.restype = c_int32
lib.infiniopRearrange.argtypes = [ lib.infiniopRearrange.argtypes = [
infiniopRearrangeDescriptor_t, infiniopRearrangeDescriptor_t,
...@@ -136,9 +130,10 @@ if __name__ == "__main__": ...@@ -136,9 +130,10 @@ if __name__ == "__main__":
c_void_p, c_void_p,
c_void_p, c_void_p,
] ]
lib.infiniopDestroyRearrangeDescriptor.restype = c_int32 lib.infiniopDestroyRearrangeDescriptor.restype = c_int32
lib.infiniopDestroyRearrangeDescriptor.argtypes = [infiniopRearrangeDescriptor_t] lib.infiniopDestroyRearrangeDescriptor.argtypes = [infiniopRearrangeDescriptor_t]
# Configure testing options # Configure testing options
DEBUG = args.debug DEBUG = args.debug
PROFILE = args.profile PROFILE = args.profile
......
...@@ -23,21 +23,20 @@ from libinfiniop import ( ...@@ -23,21 +23,20 @@ from libinfiniop import (
# Configuration (Internal Use Only) # Configuration (Internal Use Only)
# ============================================================================== # ==============================================================================
# These are not meant to be imported from other modules # These are not meant to be imported from other modules
_TEST_CASES = [ _TEST_CASES = [
# y_shape, x_shape, w_shape, y_stride, x_stride, w_dtype # y_shape, x_shape, w_shape, y_stride, x_stride, w_dtype
((16, 2048), (16, 2048), (2048,), None, None,torch.float32), ((16, 2048), (16, 2048), (2048,), None, None, torch.float32),
((16, 2048), (16, 2048), (2048,), None, None, torch.float16), ((16, 2048), (16, 2048), (2048,), None, None, torch.float16),
((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1),torch.float32), ((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), torch.float32),
((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), torch.float16), ((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), torch.float16),
] ]
# x types used for testing # x types used for testing
_TENSOR_DTYPES = [torch.float16, torch.float32] _TENSOR_DTYPES = [torch.float16]
# Tolerance map for different data types # Tolerance map for different data types
_TOLERANCE_MAP = { _TOLERANCE_MAP = {
torch.float16: {"atol": 0, "rtol": 1e-2}, torch.float16: {"atol": 1e-3, "rtol": 1e-3},
torch.float32: {"atol": 0, "rtol": 1e-3},
} }
DEBUG = False DEBUG = False
...@@ -45,12 +44,14 @@ PROFILE = False ...@@ -45,12 +44,14 @@ PROFILE = False
NUM_PRERUN = 10 NUM_PRERUN = 10
NUM_ITERATIONS = 1000 NUM_ITERATIONS = 1000
class RMSNormDescriptor(Structure): class RMSNormDescriptor(Structure):
_fields_ = [("device", c_int32)] _fields_ = [("device", c_int32)]
infiniopRMSNormDescriptor_t = POINTER(RMSNormDescriptor) infiniopRMSNormDescriptor_t = POINTER(RMSNormDescriptor)
def rms_norm(x, w, eps): def rms_norm(x, w, eps):
input_dtype = x.dtype input_dtype = x.dtype
hidden_states = x.to(torch.float32) hidden_states = x.to(torch.float32)
...@@ -60,18 +61,21 @@ def rms_norm(x, w, eps): ...@@ -60,18 +61,21 @@ def rms_norm(x, w, eps):
def test( def test(
lib, lib,
handle, handle,
torch_device, torch_device,
y_shape, y_shape,
x_shape, x_shape,
w_shape, w_shape,
y_stride, y_stride,
x_stride, x_stride,
dtype=torch.float16, dtype=torch.float16,
w_dtype=torch.float16): w_dtype=torch.float16,
print(f"Testing RMS_Norm on {torch_device} with y_shape:{y_shape} x_shape:{x_shape} w_shape:{w_shape}" ):
f" dtype:{dtype} w_dtype:{w_dtype}") print(
f"Testing RMS_Norm on {torch_device} with y_shape:{y_shape} x_shape:{x_shape} w_shape:{w_shape}"
f" dtype:{dtype} w_dtype:{w_dtype}"
)
y = torch.zeros(y_shape, dtype=dtype).to(torch_device) y = torch.zeros(y_shape, dtype=dtype).to(torch_device)
x = torch.rand(x_shape, dtype=dtype).to(torch_device) x = torch.rand(x_shape, dtype=dtype).to(torch_device)
...@@ -80,18 +84,23 @@ def test( ...@@ -80,18 +84,23 @@ def test(
eps = 1e-5 eps = 1e-5
ans = rms_norm(x, w, eps) ans = rms_norm(x, w, eps)
x = rearrange_if_needed(x, x_stride) x, y = [
y = rearrange_if_needed(y, y_stride) rearrange_if_needed(tensor, stride)
for tensor, stride in zip([x, y], [x_stride, y_stride])
]
x_tensor, y_tensor, w_tensor = [to_tensor(tensor, lib) for tensor in [x, y, w]] x_tensor, y_tensor, w_tensor = [to_tensor(tensor, lib) for tensor in [x, y, w]]
descriptor = infiniopRMSNormDescriptor_t() descriptor = infiniopRMSNormDescriptor_t()
w_dataType = 0 if w_dtype==torch.float16 else 1
check_error( check_error(
lib.infiniopCreateRMSNormDescriptor( lib.infiniopCreateRMSNormDescriptor(
handle, ctypes.byref(descriptor), y_tensor.descriptor, x_tensor.descriptor, handle,
w_tensor.descriptor, eps ctypes.byref(descriptor),
y_tensor.descriptor,
x_tensor.descriptor,
w_tensor.descriptor,
eps,
) )
) )
...@@ -101,11 +110,10 @@ def test( ...@@ -101,11 +110,10 @@ def test(
workspace_size = c_uint64(0) workspace_size = c_uint64(0)
check_error( check_error(
lib.infiniopGetRMSNormWorkspaceSize( lib.infiniopGetRMSNormWorkspaceSize(descriptor, ctypes.byref(workspace_size))
descriptor, ctypes.byref(workspace_size)
)
) )
workspace = create_workspace(workspace_size.value, y.device) workspace = create_workspace(workspace_size.value, y.device)
def lib_rms_norm(): def lib_rms_norm():
check_error( check_error(
lib.infiniopRMSNorm( lib.infiniopRMSNorm(
...@@ -134,12 +142,10 @@ def test( ...@@ -134,12 +142,10 @@ def test(
check_error(lib.infiniopDestroyRMSNormDescriptor(descriptor)) check_error(lib.infiniopDestroyRMSNormDescriptor(descriptor))
if __name__ == "__main__": if __name__ == "__main__":
args = get_args() args = get_args()
lib = open_lib() lib = open_lib()
lib.infiniopCreateRMSNormDescriptor.restype = c_int32 lib.infiniopCreateRMSNormDescriptor.restype = c_int32
lib.infiniopCreateRMSNormDescriptor.argtypes = [ lib.infiniopCreateRMSNormDescriptor.argtypes = [
infiniopHandle_t, infiniopHandle_t,
...@@ -166,6 +172,7 @@ if __name__ == "__main__": ...@@ -166,6 +172,7 @@ if __name__ == "__main__":
c_void_p, c_void_p,
c_void_p, c_void_p,
] ]
lib.infiniopDestroyRMSNormDescriptor.restype = c_int32 lib.infiniopDestroyRMSNormDescriptor.restype = c_int32
lib.infiniopDestroyRMSNormDescriptor.argtypes = [ lib.infiniopDestroyRMSNormDescriptor.argtypes = [
infiniopRMSNormDescriptor_t, infiniopRMSNormDescriptor_t,
...@@ -182,5 +189,3 @@ if __name__ == "__main__": ...@@ -182,5 +189,3 @@ if __name__ == "__main__":
test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES) test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)
print("\033[92mTest passed!\033[0m") print("\033[92mTest passed!\033[0m")
...@@ -15,6 +15,7 @@ from libinfiniop import ( ...@@ -15,6 +15,7 @@ from libinfiniop import (
debug, debug,
get_tolerance, get_tolerance,
profile_operation, profile_operation,
synchronize_device,
) )
# ============================================================================== # ==============================================================================
...@@ -23,22 +24,21 @@ from libinfiniop import ( ...@@ -23,22 +24,21 @@ from libinfiniop import (
# These are not meant to be imported from other modules # These are not meant to be imported from other modules
_TEST_CASES = [ _TEST_CASES = [
# (t_shape, t_strides) # (t_shape, t_strides)
((1, 32, 128), None), ((1, 32, 128), None),
((1, 32, 64), None), ((1, 32, 64), None),
# 昇腾暂不满足这个用例,最后一维度 <=32 会有问题,可能与其核心 # 昇腾暂不满足这个用例,最后一维度 <=32 会有问题,可能与其核心
# 接口 GatherMask 的内部实现相关,目前 48 64 128 都可以支持 # 接口 GatherMask 的内部实现相关,目前 48 64 128 都可以支持
((4, 1, 32), None), ((4, 1, 32), None),
((1, 32, 128), None), ((1, 32, 128), None),
((3, 32, 128), (8000, 200, 1)), ((3, 32, 128), (8000, 200, 1)),
] ]
# Data types used for testing # Data types used for testing
_TENSOR_DTYPES = [torch.float16, torch.float32] _TENSOR_DTYPES = [torch.float16]
# Tolerance map for different data types # Tolerance map for different data types
_TOLERANCE_MAP = { _TOLERANCE_MAP = {
torch.float16: {"atol": 0, "rtol": 1e-2}, torch.float16: {"atol": 1e-4, "rtol": 1e-2},
torch.float32: {"atol": 0, "rtol": 1e-3},
} }
DEBUG = False DEBUG = False
...@@ -47,7 +47,6 @@ NUM_PRERUN = 10 ...@@ -47,7 +47,6 @@ NUM_PRERUN = 10
NUM_ITERATIONS = 1000 NUM_ITERATIONS = 1000
class RoPEDescriptor(Structure): class RoPEDescriptor(Structure):
_fields_ = [("device", c_int32)] _fields_ = [("device", c_int32)]
...@@ -96,14 +95,7 @@ def sin_cos_table(max_seq_len, dim, torch_device, theta): ...@@ -96,14 +95,7 @@ def sin_cos_table(max_seq_len, dim, torch_device, theta):
return torch.sin(angles), torch.cos(angles) return torch.sin(angles), torch.cos(angles)
def test( def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
lib,
handle,
torch_device,
shape,
strides=None,
dtype=torch.float16
):
print( print(
f"Testing Rotary Positional Embedding on {torch_device} with shape:{shape} strides:{strides} and dtype:{dtype}" f"Testing Rotary Positional Embedding on {torch_device} with shape:{shape} strides:{strides} and dtype:{dtype}"
) )
...@@ -126,14 +118,15 @@ def test( ...@@ -126,14 +118,15 @@ def test(
# 2x table length for test # 2x table length for test
sin_table, cos_table = sin_cos_table(t.shape[0] * 2, t.shape[2], t.device, theta) sin_table, cos_table = sin_cos_table(t.shape[0] * 2, t.shape[2], t.device, theta)
t_tensor, sin_table_tensor, cos_table_tensor = [to_tensor(tensor, lib) for tensor in [t, sin_table, cos_table]] t_tensor, sin_table_tensor, cos_table_tensor = [
to_tensor(tensor, lib) for tensor in [t, sin_table, cos_table]
]
pos_tensor = to_tensor(pos[: t.shape[0]], lib) pos_tensor = to_tensor(pos[: t.shape[0]], lib)
pos_tensor.descriptor.contents.dtype = InfiniDtype.U64 pos_tensor.descriptor.contents.dtype = InfiniDtype.U64
if torch_device == "npu": if torch_device == "npu":
torch.npu.synchronize() synchronize_device(torch_device)
check_error( check_error(
lib.infiniopCreateRoPEDescriptor( lib.infiniopCreateRoPEDescriptor(
...@@ -171,11 +164,12 @@ def test( ...@@ -171,11 +164,12 @@ def test(
) )
lib_rope() lib_rope()
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG: if DEBUG:
debug(t, ans, atol=atol, rtol=rtol) debug(t, ans, atol=atol, rtol=rtol)
assert torch.allclose(t, ans, atol=atol, rtol=rtol) assert torch.allclose(t, ans, atol=atol, rtol=rtol)
if PROFILE: if PROFILE:
profile_operation( profile_operation(
"PyTorch", "PyTorch",
...@@ -194,6 +188,7 @@ def test( ...@@ -194,6 +188,7 @@ def test(
if __name__ == "__main__": if __name__ == "__main__":
args = get_args() args = get_args()
lib = open_lib() lib = open_lib()
lib.infiniopCreateRoPEDescriptor.restype = c_int32 lib.infiniopCreateRoPEDescriptor.restype = c_int32
lib.infiniopCreateRoPEDescriptor.argtypes = [ lib.infiniopCreateRoPEDescriptor.argtypes = [
infiniopHandle_t, infiniopHandle_t,
...@@ -203,11 +198,13 @@ if __name__ == "__main__": ...@@ -203,11 +198,13 @@ if __name__ == "__main__":
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
] ]
lib.infiniopGetRoPEWorkspaceSize.restype = c_int32 lib.infiniopGetRoPEWorkspaceSize.restype = c_int32
lib.infiniopGetRoPEWorkspaceSize.argtypes = [ lib.infiniopGetRoPEWorkspaceSize.argtypes = [
infiniopRoPEDescriptor_t, infiniopRoPEDescriptor_t,
POINTER(c_uint64), POINTER(c_uint64),
] ]
lib.infiniopRoPE.restype = c_int32 lib.infiniopRoPE.restype = c_int32
lib.infiniopRoPE.argtypes = [ lib.infiniopRoPE.argtypes = [
infiniopRoPEDescriptor_t, infiniopRoPEDescriptor_t,
...@@ -219,10 +216,12 @@ if __name__ == "__main__": ...@@ -219,10 +216,12 @@ if __name__ == "__main__":
c_void_p, c_void_p,
c_void_p, c_void_p,
] ]
lib.infiniopDestroyRoPEDescriptor.restype = c_int32 lib.infiniopDestroyRoPEDescriptor.restype = c_int32
lib.infiniopDestroyRoPEDescriptor.argtypes = [ lib.infiniopDestroyRoPEDescriptor.argtypes = [
infiniopRoPEDescriptor_t, infiniopRoPEDescriptor_t,
] ]
# Configure testing options # Configure testing options
DEBUG = args.debug DEBUG = args.debug
PROFILE = args.profile PROFILE = args.profile
......
...@@ -16,29 +16,64 @@ from libinfiniop import ( ...@@ -16,29 +16,64 @@ from libinfiniop import (
get_tolerance, get_tolerance,
profile_operation, profile_operation,
) )
from enum import Enum, auto
# ============================================================================== # ==============================================================================
# Configuration (Internal Use Only) # Configuration (Internal Use Only)
# ============================================================================== # ==============================================================================
# These are not meant to be imported from other modules # These are not meant to be imported from other modules
_TEST_CASES = [ _TEST_CASES = [
# shape, a_stride, b_stride, c_stride # shape, a_stride, b_stride, c_stride, inplace
((13, 4), None, None, None), ((13, 4), None, None, None, Inplace.OUT_OF_PLACE),
((13, 4), (10, 1), (10, 1), (10, 1)), ((13, 4), None, None, None, Inplace.INPLACE_A),
((13, 4, 4), None, None, None), ((13, 4), None, None, None, Inplace.INPLACE_B),
((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1)), ((13, 4), (10, 1), (10, 1), (10, 1), Inplace.OUT_OF_PLACE),
((16, 5632), None, None, None), ((13, 4), (10, 1), (10, 1), (10, 1), Inplace.INPLACE_A),
((16, 5632), (13312, 1), (13312, 1), (13312, 1)), ((13, 4), (10, 1), (10, 1), (10, 1), Inplace.INPLACE_B),
((4, 4, 5632), None, None, None), ((13, 4, 4), None, None, None, Inplace.OUT_OF_PLACE),
((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)), ((13, 4, 4), None, None, None, Inplace.INPLACE_A),
((13, 4, 4), None, None, None, Inplace.INPLACE_B),
((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1), Inplace.OUT_OF_PLACE),
((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1), Inplace.INPLACE_A),
((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1), Inplace.INPLACE_B),
((16, 5632), None, None, None, Inplace.OUT_OF_PLACE),
((16, 5632), None, None, None, Inplace.INPLACE_A),
((16, 5632), None, None, None, Inplace.INPLACE_B),
((16, 5632), (13312, 1), (13312, 1), (13312, 1), Inplace.OUT_OF_PLACE),
((16, 5632), (13312, 1), (13312, 1), (13312, 1), Inplace.INPLACE_A),
((16, 5632), (13312, 1), (13312, 1), (13312, 1), Inplace.INPLACE_B),
((4, 4, 5632), None, None, None, Inplace.OUT_OF_PLACE),
((4, 4, 5632), None, None, None, Inplace.INPLACE_A),
((4, 4, 5632), None, None, None, Inplace.INPLACE_B),
(
(4, 4, 5632),
(45056, 5632, 1),
(45056, 5632, 1),
(45056, 5632, 1),
Inplace.OUT_OF_PLACE,
),
(
(4, 4, 5632),
(45056, 5632, 1),
(45056, 5632, 1),
(45056, 5632, 1),
Inplace.INPLACE_A,
),
(
(4, 4, 5632),
(45056, 5632, 1),
(45056, 5632, 1),
(45056, 5632, 1),
Inplace.INPLACE_B,
),
] ]
# Data types used for testing # Data types used for testing
_TENSOR_DTYPES = [torch.float16, torch.float32] _TENSOR_DTYPES = [torch.float16]
# Tolerance map for different data types # Tolerance map for different data types
_TOLERANCE_MAP = { _TOLERANCE_MAP = {
torch.float16: {'atol': 0, 'rtol': 1e-2}, torch.float16: {"atol": 1e-4, "rtol": 1e-2},
torch.float32: {'atol': 0, 'rtol': 1e-3},
} }
DEBUG = False DEBUG = False
...@@ -46,6 +81,13 @@ PROFILE = False ...@@ -46,6 +81,13 @@ PROFILE = False
NUM_PRERUN = 10 NUM_PRERUN = 10
NUM_ITERATIONS = 1000 NUM_ITERATIONS = 1000
class Inplace(Enum):
OUT_OF_PLACE = auto()
INPLACE_A = auto()
INPLACE_B = auto()
class SwiGLUDescriptor(Structure): class SwiGLUDescriptor(Structure):
_fields_ = [("device", c_int32)] _fields_ = [("device", c_int32)]
...@@ -54,11 +96,10 @@ infiniopSwiGLUDescriptor_t = POINTER(SwiGLUDescriptor) ...@@ -54,11 +96,10 @@ infiniopSwiGLUDescriptor_t = POINTER(SwiGLUDescriptor)
def swiglu(a, b): def swiglu(a, b):
return a * b / (1 + torch.exp(-b.float()).to(b.dtype)) return a * b / (1 + torch.exp(-b.float()).to(b.dtype))
def test_out_of_place( def test(
lib, lib,
handle, handle,
torch_device, torch_device,
...@@ -66,15 +107,21 @@ def test_out_of_place( ...@@ -66,15 +107,21 @@ def test_out_of_place(
a_stride=None, a_stride=None,
b_stride=None, b_stride=None,
c_stride=None, c_stride=None,
inplace=Inplace.OUT_OF_PLACE,
dtype=torch.float16, dtype=torch.float16,
sync=None, sync=None,
): ):
print( print(
f"Testing SwiGLU on {torch_device} with shape:{shape} a_stride:{a_stride} b_stride:{b_stride} c_stride:{c_stride} dtype:{dtype}" f"Testing SwiGLU on {torch_device} with shape:{shape} a_stride:{a_stride} b_stride:{b_stride} c_stride:{c_stride} dtype:{dtype}"
) )
a = torch.rand(shape, dtype=dtype).to(torch_device) a = torch.rand(shape, dtype=dtype).to(torch_device)
b = torch.rand(shape, dtype=dtype).to(torch_device) b = torch.rand(shape, dtype=dtype).to(torch_device)
c = torch.rand(shape, dtype=dtype).to(torch_device) c = (
torch.rand(c_shape, dtype=tensor_dtype).to(torch_device)
if inplace == Inplace.OUT_OF_PLACE
else (a if inplace == Inplace.INPLACE_A else b)
)
ans = swiglu(a, b) ans = swiglu(a, b)
...@@ -82,9 +129,12 @@ def test_out_of_place( ...@@ -82,9 +129,12 @@ def test_out_of_place(
rearrange_if_needed(tensor, stride) rearrange_if_needed(tensor, stride)
for tensor, stride in zip([a, b, c], [a_stride, b_stride, c_stride]) for tensor, stride in zip([a, b, c], [a_stride, b_stride, c_stride])
] ]
a_tensor, b_tensor, c_tensor = [to_tensor(tensor, lib) for tensor in [a, b, c]] a_tensor, b_tensor = [to_tensor(tensor, lib) for tensor in [a, b]]
c_tensor = (
to_tensor(c, lib)
if inplace == Inplace.OUT_OF_PLACE
else (a_tensor if inplace == Inplace.INPLACE_A else b_tensor)
)
if sync is not None: if sync is not None:
sync() sync()
...@@ -106,13 +156,10 @@ def test_out_of_place( ...@@ -106,13 +156,10 @@ def test_out_of_place(
def lib_swiglu(): def lib_swiglu():
check_error( check_error(
lib.infiniopSwiGLU( lib.infiniopSwiGLU(
descriptor, descriptor, c_tensor.data, a_tensor.data, b_tensor.data, None
c_tensor.data,
a_tensor.data,
b_tensor.data,
None
) )
) )
lib_swiglu() lib_swiglu()
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
...@@ -130,139 +177,7 @@ def test_out_of_place( ...@@ -130,139 +177,7 @@ def test_out_of_place(
check_error(lib.infiniopDestroySwiGLUDescriptor(descriptor)) check_error(lib.infiniopDestroySwiGLUDescriptor(descriptor))
def test_in_place1(
    lib,
    handle,
    torch_device,
    shape,
    a_stride=None,
    b_stride=None,
    dtype=torch.float16,
    sync=None,
):
    """Exercise SwiGLU with the output aliased to the first input (c is a)."""
    a = torch.rand(shape, dtype=dtype).to(torch_device)
    b = torch.rand(shape, dtype=dtype).to(torch_device)
    # Reference result is computed before the library overwrites a in place.
    ans = swiglu(a, b)

    if sync is not None:
        sync()

    a = rearrange_if_needed(a, a_stride)
    b = rearrange_if_needed(b, b_stride)
    a_tensor = to_tensor(a, lib)
    b_tensor = to_tensor(b, lib)

    descriptor = infiniopSwiGLUDescriptor_t()
    # The c descriptor is a's descriptor: output aliases the first operand.
    check_error(
        lib.infiniopCreateSwiGLUDescriptor(
            handle,
            ctypes.byref(descriptor),
            a_tensor.descriptor,
            a_tensor.descriptor,
            b_tensor.descriptor,
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    a_tensor.descriptor.contents.invalidate()
    b_tensor.descriptor.contents.invalidate()

    def lib_swiglu():
        # c and a point at the same buffer, matching the descriptor above.
        check_error(
            lib.infiniopSwiGLU(
                descriptor, a_tensor.data, a_tensor.data, b_tensor.data, None
            )
        )

    lib_swiglu()

    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(a, ans, atol=atol, rtol=rtol)
    # a now holds the library's output.
    assert torch.allclose(a, ans, atol=atol, rtol=rtol)
    print("in-place1 Test passed!")

    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: swiglu(a, b), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation(" lib", lambda: lib_swiglu(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on

    check_error(lib.infiniopDestroySwiGLUDescriptor(descriptor))
def test_in_place2(
    lib,
    handle,
    torch_device,
    shape,
    a_stride=None,
    b_stride=None,
    dtype=torch.float16,
    sync=None,
):
    """Exercise SwiGLU with the output aliased to the second input (c is b)."""
    a = torch.rand(shape, dtype=dtype).to(torch_device)
    b = torch.rand(shape, dtype=dtype).to(torch_device)
    # Reference result is computed before the library overwrites b in place.
    ans = swiglu(a, b)

    if sync is not None:
        sync()

    a = rearrange_if_needed(a, a_stride)
    b = rearrange_if_needed(b, b_stride)
    a_tensor = to_tensor(a, lib)
    b_tensor = to_tensor(b, lib)

    descriptor = infiniopSwiGLUDescriptor_t()
    # The c descriptor is b's descriptor: output aliases the second operand.
    check_error(
        lib.infiniopCreateSwiGLUDescriptor(
            handle,
            ctypes.byref(descriptor),
            b_tensor.descriptor,
            a_tensor.descriptor,
            b_tensor.descriptor,
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    a_tensor.descriptor.contents.invalidate()
    b_tensor.descriptor.contents.invalidate()

    def lib_swiglu():
        # c and b point at the same buffer, matching the descriptor above.
        check_error(
            lib.infiniopSwiGLU(
                descriptor, b_tensor.data, a_tensor.data, b_tensor.data, None
            )
        )

    lib_swiglu()

    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(b, ans, atol=atol, rtol=rtol)
    # b now holds the library's output.
    assert torch.allclose(b, ans, atol=atol, rtol=rtol)
    print("in-place2 Test passed!")

    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: swiglu(a, b), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation(" lib", lambda: lib_swiglu(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on

    check_error(lib.infiniopDestroySwiGLUDescriptor(descriptor))
def test(lib, handle, torch_device, shape, a_stride, b_stride, c_stride, dtype, sync = None):
    """Run one test case through all three variants: out-of-place, then both in-place aliasings."""
    test_out_of_place(
        lib, handle, torch_device, shape, a_stride, b_stride, c_stride, dtype, sync
    )
    # The in-place variants ignore c_stride: the output reuses a's or b's layout.
    for in_place_variant in (test_in_place1, test_in_place2):
        in_place_variant(lib, handle, torch_device, shape, a_stride, b_stride, dtype, sync)
if __name__ == "__main__": if __name__ == "__main__":
args = get_args() args = get_args()
lib = open_lib() lib = open_lib()
...@@ -288,12 +203,13 @@ if __name__ == "__main__": ...@@ -288,12 +203,13 @@ if __name__ == "__main__":
lib.infiniopDestroySwiGLUDescriptor.argtypes = [ lib.infiniopDestroySwiGLUDescriptor.argtypes = [
infiniopSwiGLUDescriptor_t, infiniopSwiGLUDescriptor_t,
] ]
# Configure testing options # Configure testing options
DEBUG = args.debug DEBUG = args.debug
PROFILE = args.profile PROFILE = args.profile
NUM_PRERUN = args.num_prerun NUM_PRERUN = args.num_prerun
NUM_ITERATIONS = args.num_iterations NUM_ITERATIONS = args.num_iterations
for device in get_test_devices(args): for device in get_test_devices(args):
test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES) test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment