Commit 7bd656b7 authored by YdrMaster

issue/52: format all Python files and mark the regions excluded from formatting


Signed-off-by: YdrMaster <ydrml@hotmail.com>
parent ec0ff893
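
Note: the `# fmt: off` / `# fmt: on` pairs added throughout this diff are black's standard markers for excluding a region from formatting; black leaves every line between them untouched, which keeps the long one-line `test(...)` invocations and the hand-aligned test-case tables below readable. A minimal sketch of the mechanism (the `CASES` list is a hypothetical stand-in, not code from this repo):

# fmt: off
CASES = [
    # hand-aligned rows; without the guard black would reflow these tuples
    ((2, 4, 3), (2, 1, 3), (4, 3)),
    ((1, 1, 10), (3,), (1,)),
]
# fmt: on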
@@ -56,13 +56,21 @@ def test(
     a = torch.rand(a_shape, dtype=tensor_dtype).to(torch_device)
     b = torch.rand(b_shape, dtype=tensor_dtype).to(torch_device)
-    c = torch.rand(c_shape, dtype=tensor_dtype).to(torch_device) if inplace == Inplace.OUT_OF_PLACE else (a if inplace == Inplace.INPLACE_A else b)
+    c = (
+        torch.rand(c_shape, dtype=tensor_dtype).to(torch_device)
+        if inplace == Inplace.OUT_OF_PLACE
+        else (a if inplace == Inplace.INPLACE_A else b)
+    )
     ans = add(a, b)
     a_tensor = to_tensor(a, lib)
     b_tensor = to_tensor(b, lib)
-    c_tensor = to_tensor(c, lib) if inplace == Inplace.OUT_OF_PLACE else (a_tensor if inplace == Inplace.INPLACE_A else b_tensor)
+    c_tensor = (
+        to_tensor(c, lib)
+        if inplace == Inplace.OUT_OF_PLACE
+        else (a_tensor if inplace == Inplace.INPLACE_A else b_tensor)
+    )
     descriptor = infiniopAddDescriptor_t()
     check_error(
@@ -91,8 +99,10 @@ def test_cpu(lib, test_cases):
     device = DeviceEnum.DEVICE_CPU
     handle = create_handle(lib, device)
     for c_shape, a_shape, b_shape, inplace in test_cases:
+        # fmt: off
         test(lib, handle, "cpu", c_shape, a_shape, b_shape, tensor_dtype=torch.float16, inplace=inplace)
         test(lib, handle, "cpu", c_shape, a_shape, b_shape, tensor_dtype=torch.float32, inplace=inplace)
+        # fmt: on
     destroy_handle(lib, handle)
@@ -100,8 +110,10 @@ def test_cuda(lib, test_cases):
     device = DeviceEnum.DEVICE_CUDA
     handle = create_handle(lib, device)
     for c_shape, a_shape, b_shape, inplace in test_cases:
+        # fmt: off
         test(lib, handle, "cuda", c_shape, a_shape, b_shape, tensor_dtype=torch.float16, inplace=inplace)
         test(lib, handle, "cuda", c_shape, a_shape, b_shape, tensor_dtype=torch.float32, inplace=inplace)
+        # fmt: on
     destroy_handle(lib, handle)
@@ -111,13 +123,16 @@ def test_bang(lib, test_cases):
     device = DeviceEnum.DEVICE_BANG
     handle = create_handle(lib, device)
     for c_shape, a_shape, b_shape, inplace in test_cases:
+        # fmt: off
         test(lib, handle, "mlu", c_shape, a_shape, b_shape, tensor_dtype=torch.float16, inplace=inplace)
         test(lib, handle, "mlu", c_shape, a_shape, b_shape, tensor_dtype=torch.float32, inplace=inplace)
+        # fmt: on
     destroy_handle(lib, handle)


 if __name__ == "__main__":
     test_cases = [
+        # fmt: off
         # c_shape, a_shape, b_shape, inplace
         # ((32, 150, 512000), (32, 150, 512000), (32, 150, 512000), Inplace.OUT_OF_PLACE),
         # ((32, 150, 51200), (32, 150, 51200), (32, 150, 1), Inplace.OUT_OF_PLACE),
@@ -133,6 +148,7 @@ if __name__ == "__main__":
         ((2, 4, 3), (2, 1, 3), (4, 3), Inplace.OUT_OF_PLACE),
         ((2, 3, 4, 5), (2, 3, 4, 5), (5,), Inplace.OUT_OF_PLACE),
         ((3, 2, 4, 5), (4, 5), (3, 2, 1, 1), Inplace.OUT_OF_PLACE),
+        # fmt: on
     ]
     args = get_args()
     lib = open_lib()
...
@@ -35,7 +35,7 @@ class AvgPoolDescriptor(Structure):
 infiniopAvgPoolDescriptor_t = POINTER(AvgPoolDescriptor)


-def pool(x, k, padding, stride, dilation = 1):
+def pool(x, k, padding, stride, dilation=1):
     pooling_layers = {
         1: torch.nn.AvgPool1d,
         2: torch.nn.AvgPool2d,
@@ -48,7 +48,9 @@ def pool(x, k, padding, stride, dilation = 1):
         return None
     if ndim == 3 and x.dtype == torch.float16:
-        ans = pooling_layers[ndim](k, stride=stride, padding=padding)(x.to(torch.float32)).to(torch.float16)
+        ans = pooling_layers[ndim](k, stride=stride, padding=padding)(
+            x.to(torch.float32)
+        ).to(torch.float16)
     else:
         ans = pooling_layers[ndim](k, stride=stride, padding=padding)(x)
     if PROFILE:
@@ -69,12 +71,14 @@ def inferShape(x_shape, kernel_shape, padding, strides):
     return x_shape[:2] + tuple(output_shape)

+
 # convert a python tuple to a ctype void pointer
 def tuple_to_void_p(py_tuple: Tuple):
     array = ctypes.c_int64 * len(py_tuple)
     data_array = array(*py_tuple)
     return ctypes.cast(data_array, ctypes.c_void_p)

+
 def test(
     lib,
     handle,
@@ -90,7 +94,9 @@ def test(
     )
     x = torch.rand(x_shape, dtype=tensor_dtype).to(torch_device)
-    y = torch.rand(inferShape(x_shape, k_shape, padding, strides), dtype=tensor_dtype).to(torch_device)
+    y = torch.rand(
+        inferShape(x_shape, k_shape, padding, strides), dtype=tensor_dtype
+    ).to(torch_device)
     for i in range(NUM_PRERUN if PROFILE else 1):
         ans = pool(x, k_shape, padding, strides)
@@ -126,7 +132,9 @@ def test(
     check_error(
         lib.infiniopGetAvgPoolWorkspaceSize(descriptor, ctypes.byref(workspaceSize))
     )
-    workspace = torch.zeros(int(workspaceSize.value), dtype=torch.uint8).to(torch_device)
+    workspace = torch.zeros(int(workspaceSize.value), dtype=torch.uint8).to(
+        torch_device
+    )
     workspace_ptr = ctypes.cast(workspace.data_ptr(), ctypes.POINTER(ctypes.c_uint8))
     for i in range(NUM_PRERUN if PROFILE else 1):
@@ -164,8 +172,10 @@ def test_cpu(lib, test_cases):
     device = DeviceEnum.DEVICE_CPU
     handle = create_handle(lib, device)
     for x_shape, kernel_shape, padding, strides in test_cases:
+        # fmt: off
         test(lib, handle, "cpu", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float16)
         test(lib, handle, "cpu", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float32)
+        # fmt: on
     destroy_handle(lib, handle)
@@ -173,8 +183,10 @@ def test_cuda(lib, test_cases):
     device = DeviceEnum.DEVICE_CUDA
     handle = create_handle(lib, device)
     for x_shape, kernel_shape, padding, strides in test_cases:
+        # fmt: off
         test(lib, handle, "cuda", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float16)
         test(lib, handle, "cuda", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float32)
+        # fmt: on
     destroy_handle(lib, handle)
@@ -184,17 +196,21 @@ def test_bang(lib, test_cases):
     device = DeviceEnum.DEVICE_BANG
     handle = create_handle(lib, device)
     for x_shape, kernel_shape, padding, strides in test_cases:
+        # fmt: off
         test(lib, handle, "mlu", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float16)
         test(lib, handle, "mlu", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float32)
+        # fmt: on
     destroy_handle(lib, handle)


 if __name__ == "__main__":
     test_cases = [
+        # fmt: off
         # x_shape, kernel_shape, padding, strides
         ((1, 1, 10), (3,), (1,), (1,)),
         ((32, 3, 224, 224), (3, 3), (1, 1), (2, 2)),
         ((1, 1, 16, 16, 16), (5, 5, 5), (2, 2, 2), (2, 2, 2)),
+        # fmt: on
     ]
     args = get_args()
     lib = open_lib()
...
@@ -101,6 +101,7 @@ def test_bang(lib, test_cases):
         test(lib, handle, "mlu", x_shape, x_stride)
     destroy_handle(lib, handle)

+
 def test_ascend(lib, test_cases):
     import torch_npu
@@ -111,6 +112,7 @@ def test_ascend(lib, test_cases):
     destroy_handle(lib, handle)

+
 if __name__ == "__main__":
     test_cases = [
         # x_shape, x_stride
...
@@ -41,17 +41,11 @@ infiniopConvDescriptor_t = POINTER(ConvDescriptor)
 def conv(x, w, stride, padding, dilation):
     match len(x.shape) - 2:
         case 1:
-            return F.conv1d(
-                x, w, stride=stride, padding=padding, dilation=dilation
-            )
+            return F.conv1d(x, w, stride=stride, padding=padding, dilation=dilation)
         case 2:
-            return F.conv2d(
-                x, w, stride=stride, padding=padding, dilation=dilation
-            )
+            return F.conv2d(x, w, stride=stride, padding=padding, dilation=dilation)
         case 3:
-            return F.conv3d(
-                x, w, stride=stride, padding=padding, dilation=dilation
-            )
+            return F.conv3d(x, w, stride=stride, padding=padding, dilation=dilation)
         case _:
             print("Error: Pytorch -> Unsupported tensor dimension")
             return None
@@ -66,11 +60,15 @@ def inferShape(
     dilations: List[int],
 ) -> Tuple[int, ...]:
     assert (
-        len(x_shape) == len(w_shape) == len(pads) + 2 == len(dilations) + 2 == len(strides) + 2
+        len(x_shape)
+        == len(w_shape)
+        == len(pads) + 2
+        == len(dilations) + 2
+        == len(strides) + 2
     ), "x and w should have the same length; pads, strides, and dilatinos should have the same length; the length of pads should be that of x - 2"
     output_dims = [
         math.floor(
-            (x_shape[i+2] + 2 * pads[i] - dilations[i] * (w_shape[i+2] - 1) - 1)
+            (x_shape[i + 2] + 2 * pads[i] - dilations[i] * (w_shape[i + 2] - 1) - 1)
             / strides[i]
             + 1
         )
@@ -145,7 +143,9 @@ def test(
     check_error(
         lib.infiniopGetConvWorkspaceSize(descriptor, ctypes.byref(workspaceSize))
     )
-    workspace = torch.zeros(int(workspaceSize.value), dtype=torch.uint8).to(torch_device)
+    workspace = torch.zeros(int(workspaceSize.value), dtype=torch.uint8).to(
+        torch_device
+    )
     workspace_ptr = ctypes.cast(workspace.data_ptr(), ctypes.POINTER(ctypes.c_uint8))
     for i in range(NUM_PRERUN if PROFILE else 1):
@@ -177,7 +177,7 @@ def test(
         elapsed = (time.time() - start_time) / NUM_ITERATIONS
         print(f" lib time: {elapsed :6f}")
-    if (tensor_dtype == torch.float16):
+    if tensor_dtype == torch.float16:
         assert torch.allclose(y, ans, atol=0, rtol=1e-2)
     else:
         assert torch.allclose(y, ans, atol=0, rtol=1e-3)
@@ -188,8 +188,10 @@ def test_cpu(lib, test_cases):
     device = DeviceEnum.DEVICE_CPU
     handle = create_handle(lib, device)
     for x_shape, w_shape, pads, strides, dilations, x_strides in test_cases:
+        # fmt: off
         test(lib, handle, "cpu", x_shape, w_shape, pads, strides, dilations, x_strides, tensor_dtype=torch.float16)
         test(lib, handle, "cpu", x_shape, w_shape, pads, strides, dilations, x_strides, tensor_dtype=torch.float32)
+        # fmt: on
     destroy_handle(lib, handle)
@@ -197,8 +199,10 @@ def test_cuda(lib, test_cases):
     device = DeviceEnum.DEVICE_CUDA
     handle = create_handle(lib, device)
     for x_shape, w_shape, pads, strides, dilations, x_strides in test_cases:
+        # fmt: off
         test(lib, handle, "cuda", x_shape, w_shape, pads, strides, dilations, x_strides, tensor_dtype=torch.float16)
         test(lib, handle, "cuda", x_shape, w_shape, pads, strides, dilations, x_strides, tensor_dtype=torch.float32)
+        # fmt: on
     destroy_handle(lib, handle)
@@ -208,8 +212,10 @@ def test_bang(lib, test_cases):
     device = DeviceEnum.DEVICE_BANG
     handle = create_handle(lib, device)
     for x_shape, w_shape, pads, strides, dilations, x_strides in test_cases:
+        # fmt: off
         test(lib, handle, "mlu", x_shape, w_shape, pads, strides, dilations, x_strides, tensor_dtype=torch.float16)
         test(lib, handle, "mlu", x_shape, w_shape, pads, strides, dilations, x_strides, tensor_dtype=torch.float32)
+        # fmt: on
     destroy_handle(lib, handle)
...
@@ -109,8 +109,10 @@ def test_cpu(lib, test_cases):
     device = DeviceEnum.DEVICE_CPU
     handle = create_handle(lib, device)
     for y_shape, x_shape, y_stride, x_stride in test_cases:
+        # fmt: off
         test(lib, handle, "cpu", y_shape, x_shape, y_stride, x_stride, tensor_dtype=torch.float16)
         test(lib, handle, "cpu", y_shape, x_shape, y_stride, x_stride, tensor_dtype=torch.float32)
+        # fmt: on
     destroy_handle(lib, handle)
@@ -118,8 +120,10 @@ def test_cuda(lib, test_cases):
     device = DeviceEnum.DEVICE_CUDA
     handle = create_handle(lib, device)
     for y_shape, x_shape, y_stride, x_stride in test_cases:
+        # fmt: off
         test(lib, handle, "cuda", y_shape, x_shape, y_stride, x_stride, tensor_dtype=torch.float16)
         test(lib, handle, "cuda", y_shape, x_shape, y_stride, x_stride, tensor_dtype=torch.float32)
+        # fmt: on
     destroy_handle(lib, handle)
@@ -129,13 +133,16 @@ def test_bang(lib, test_cases):
     device = DeviceEnum.DEVICE_BANG
     handle = create_handle(lib, device)
     for y_shape, x_shape, y_stride, x_stride in test_cases:
+        # fmt: off
         test(lib, handle, "mlu", y_shape, x_shape, y_stride, x_stride, tensor_dtype=torch.float16)
         test(lib, handle, "mlu", y_shape, x_shape, y_stride, x_stride, tensor_dtype=torch.float32)
+        # fmt: on
     destroy_handle(lib, handle)


 if __name__ == "__main__":
     test_cases = [
+        # fmt: off
         # y_shape, x_shape, y_stride, x_stride
         ((), (), None, None),
         ((3, 3), (1,), None, None),
@@ -146,6 +153,7 @@ if __name__ == "__main__":
         ((2, 3, 4, 5), (5,), None, None),
         ((3, 2, 4, 5), (3, 2, 1, 1), None, None),
         ((32, 256, 112, 112), (32, 256, 112, 1), None, None),
+        # fmt: on
     ]
     args = get_args()
     lib = open_lib()
...
@@ -27,6 +27,7 @@ PROFILE = False
 NUM_PRERUN = 10
 NUM_ITERATIONS = 1000

+
 class GEMMDescriptor(Structure):
     _fields_ = [("device", c_int32)]
@@ -34,10 +35,15 @@ class GEMMDescriptor(Structure):
 infiniopGEMMDescriptor_t = POINTER(GEMMDescriptor)


-def gemm(A, B, C=None, transA=False, transB=False, alpha=1.0, beta=0.0, dtype=torch.float32):
+def gemm(
+    A, B, C=None, transA=False, transB=False, alpha=1.0, beta=0.0, dtype=torch.float32
+):
     A = A.T if transA else A
     B = B.T if transB else B
-    result = alpha * torch.matmul(A if dtype != torch.float16 else A.to(torch.float32), B if dtype != torch.float16 else B.to(torch.float32)).to(dtype)
+    result = alpha * torch.matmul(
+        A if dtype != torch.float16 else A.to(torch.float32),
+        B if dtype != torch.float16 else B.to(torch.float32),
+    ).to(dtype)
     if C is not None:
         result += beta * C if dtype != torch.float16 else C.to(torch.float32)
     if PROFILE:
@@ -121,9 +127,7 @@ def test(
     workspace_size = ctypes.c_uint64(0)
     check_error(
-        lib.infiniopGetGEMMWorkspaceSize(
-            descriptor, ctypes.byref(workspace_size)
-        )
+        lib.infiniopGetGEMMWorkspaceSize(descriptor, ctypes.byref(workspace_size))
     )
     workspace = torch.zeros(int(workspace_size.value), dtype=torch.uint8).to(
         torch_device
@@ -182,8 +186,10 @@ def test_cpu(lib, test_cases):
         c_stride,
         y_stride,
     ) in test_cases:
+        # fmt: off
         test(lib, handle, "cpu", alpha, beta, transA, transB, a_shape, b_shape, c_shape, y_shape, a_stride, b_stride, c_stride, y_stride, dtype=torch.float16)
         test(lib, handle, "cpu", alpha, beta, transA, transB, a_shape, b_shape, c_shape, y_shape, a_stride, b_stride, c_stride, y_stride, dtype=torch.float32)
+        # fmt: on
     destroy_handle(lib, handle)
@@ -204,8 +210,10 @@ def test_cuda(lib, test_cases):
         c_stride,
         y_stride,
     ) in test_cases:
+        # fmt: off
         test(lib, handle, "cuda", alpha, beta, transA, transB, a_shape, b_shape, c_shape, y_shape, a_stride, b_stride, c_stride, y_stride, dtype=torch.float16)
         test(lib, handle, "cuda", alpha, beta, transA, transB, a_shape, b_shape, c_shape, y_shape, a_stride, b_stride, c_stride, y_stride, dtype=torch.float32)
+        # fmt: on
     destroy_handle(lib, handle)
@@ -229,9 +237,10 @@ def test_bang(lib, test_cases):
         c_stride,
         y_stride,
     ) in test_cases:
+        # fmt: off
         test(lib, handle, "mlu", alpha, beta, transA, transB, a_shape, b_shape, c_shape, y_shape, a_stride, b_stride, c_stride, y_stride, dtype=torch.float16)
         test(lib, handle, "mlu", alpha, beta, transA, transB, a_shape, b_shape, c_shape, y_shape, a_stride, b_stride, c_stride, y_stride, dtype=torch.float32)
+        # fmt: on
     destroy_handle(lib, handle)
...
@@ -99,7 +99,12 @@ def test(
     for i in range(NUM_PRERUN if PROFILE else 1):
         check_error(
             lib.infiniopGlobalAvgPool(
-                descriptor, workspace_ptr, workspaceSize, y_tensor.data, x_tensor.data, None
+                descriptor,
+                workspace_ptr,
+                workspaceSize,
+                y_tensor.data,
+                x_tensor.data,
+                None,
             )
         )
     if PROFILE:
...
 import os
 import sys

-sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '.')))
-from .liboperators import open_lib, CTensor, infiniopHandle_t, infiniopTensorDescriptor_t
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".")))
+from .liboperators import (
+    open_lib,
+    CTensor,
+    infiniopHandle_t,
+    infiniopTensorDescriptor_t,
+)
 from .devices import *
 from .utils import *
 from .datatypes import *
@@ -54,6 +54,7 @@ def create_workspace(size, torch_device):
     if size == 0:
         return None
     import torch
+
     return torch.zeros(size=(size,), dtype=torch.uint8, device=torch_device)
@@ -172,6 +173,7 @@ def get_args():
 def synchronize_device(torch_device):
     import torch
+
     if torch_device == "cuda":
         torch.cuda.synchronize()
     elif torch_device == "npu":
@@ -197,11 +199,22 @@ def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
         If True, the function will print detailed information about any discrepancies between the tensors.
     """
     import numpy as np
+
     print_discrepancy(actual, desired, atol, rtol, verbose)
-    np.testing.assert_allclose(actual.cpu(), desired.cpu(), rtol, atol, equal_nan, verbose=True, strict=True)
+    np.testing.assert_allclose(
+        actual.cpu(), desired.cpu(), rtol, atol, equal_nan, verbose=True, strict=True
+    )


-def debug_all(actual_vals: Sequence, desired_vals: Sequence, condition: str, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
+def debug_all(
+    actual_vals: Sequence,
+    desired_vals: Sequence,
+    condition: str,
+    atol=0,
+    rtol=1e-2,
+    equal_nan=False,
+    verbose=True,
+):
     """
     Debugging function to compare two sequences of values (actual and desired) pair by pair, results
     are linked by the given logical condition, and prints discrepancies
@@ -223,7 +236,10 @@ def debug_all(actual_vals: Sequence, desired_vals: Sequence, condition: str, ato
     - AssertionError: If the specified `condition` is not 'or' or 'and'.
     """
     assert len(actual_vals) == len(desired_vals), "Invalid Length"
-    assert condition in {"or", "and"}, "Invalid condition: should be either 'or' or 'and'"
+    assert condition in {
+        "or",
+        "and",
+    }, "Invalid condition: should be either 'or' or 'and'"
     import numpy as np
     passed = False if condition == "or" else True
@@ -237,14 +253,22 @@ def debug_all(actual_vals: Sequence, desired_vals: Sequence, condition: str, ato
         elif condition == "and":
             if passed and len(indices) != 0:
                 passed = False
-            print(f"\033[31mThe condition has not been satisfied: Condition #{index + 1}\033[0m")
-        np.testing.assert_allclose(actual.cpu(), desired.cpu(), rtol, atol, equal_nan, verbose=True, strict=True)
+            print(
+                f"\033[31mThe condition has not been satisfied: Condition #{index + 1}\033[0m"
+            )
+        np.testing.assert_allclose(
+            actual.cpu(),
+            desired.cpu(),
+            rtol,
+            atol,
+            equal_nan,
+            verbose=True,
+            strict=True,
+        )
     assert passed, "\033[31mThe condition has not been satisfied\033[0m"


-def print_discrepancy(
-    actual, expected, atol=0, rtol=1e-3, verbose=True
-):
+def print_discrepancy(actual, expected, atol=0, rtol=1e-3, verbose=True):
     if actual.shape != expected.shape:
         raise ValueError("Tensors must have the same shape to compare.")
@@ -273,7 +297,9 @@ def print_discrepancy(
     for idx in diff_indices:
         index_tuple = tuple(idx.tolist())
         actual_str = f"{actual[index_tuple]:<{col_width[1]}.{decimal_places[1]}f}"
-        expected_str = f"{expected[index_tuple]:<{col_width[2]}.{decimal_places[2]}f}"
+        expected_str = (
+            f"{expected[index_tuple]:<{col_width[2]}.{decimal_places[2]}f}"
+        )
         delta_str = f"{delta[index_tuple]:<{col_width[3]}.{decimal_places[3]}f}"
         print(
             f" > Index: {str(index_tuple):<{col_width[0]}}"
@@ -287,10 +313,18 @@ def print_discrepancy(
     print(f" - Desired dtype: {expected.dtype}")
     print(f" - Atol: {atol}")
     print(f" - Rtol: {rtol}")
-    print(f" - Mismatched elements: {len(diff_indices)} / {actual.numel()} ({len(diff_indices) / actual.numel() * 100}%)")
-    print(f" - Min(actual) : {torch.min(actual):<{col_width[1]}} | Max(actual) : {torch.max(actual):<{col_width[2]}}")
-    print(f" - Min(desired): {torch.min(expected):<{col_width[1]}} | Max(desired): {torch.max(expected):<{col_width[2]}}")
-    print(f" - Min(delta) : {torch.min(delta):<{col_width[1]}} | Max(delta) : {torch.max(delta):<{col_width[2]}}")
+    print(
+        f" - Mismatched elements: {len(diff_indices)} / {actual.numel()} ({len(diff_indices) / actual.numel() * 100}%)"
+    )
+    print(
+        f" - Min(actual) : {torch.min(actual):<{col_width[1]}} | Max(actual) : {torch.max(actual):<{col_width[2]}}"
+    )
+    print(
+        f" - Min(desired): {torch.min(expected):<{col_width[1]}} | Max(desired): {torch.max(expected):<{col_width[2]}}"
+    )
+    print(
+        f" - Min(delta) : {torch.min(delta):<{col_width[1]}} | Max(delta) : {torch.max(delta):<{col_width[2]}}"
+    )
     print("-" * total_width + "\n")
     return diff_indices
@@ -301,11 +335,14 @@ def get_tolerance(tolerance_map, tensor_dtype, default_atol=0, default_rtol=1e-3
     Returns the atol and rtol for a given tensor data type in the tolerance_map.
     If the given data type is not found, it returns the provided default tolerance values.
     """
-    return tolerance_map.get(tensor_dtype, {'atol': default_atol, 'rtol': default_rtol}).values()
+    return tolerance_map.get(
+        tensor_dtype, {"atol": default_atol, "rtol": default_rtol}
+    ).values()


 def timed_op(func, num_iterations, device):
     import time
+
     """ Function for timing operations with synchronization. """
     synchronize_device(device)
     start = time.time()
@@ -355,7 +392,13 @@ def test_operator(lib, device, test_func, test_cases, tensor_dtypes):
     try:
         for test_case in test_cases:
             for tensor_dtype in tensor_dtypes:
-                test_func(lib, handle, infiniDeviceEnum_str_map[device], *test_case, tensor_dtype)
+                test_func(
+                    lib,
+                    handle,
+                    infiniDeviceEnum_str_map[device],
+                    *test_case,
+                    tensor_dtype,
+                )
     finally:
         destroy_handle(lib, handle)
@@ -372,14 +415,18 @@ def get_test_devices(args):
     """
     devices_to_test = []
-    if args.cpu: devices_to_test.append(InfiniDeviceEnum.CPU)
-    if args.nvidia: devices_to_test.append(InfiniDeviceEnum.NVIDIA)
+    if args.cpu:
+        devices_to_test.append(InfiniDeviceEnum.CPU)
+    if args.nvidia:
+        devices_to_test.append(InfiniDeviceEnum.NVIDIA)
     if args.cambricon:
         import torch_mlu
+
         devices_to_test.append(InfiniDeviceEnum.CAMBRICON)
     if args.ascend:
         import torch
         import torch_npu
+
         torch.npu.set_device(0)  # Ascend NPU needs explicit device initialization
         devices_to_test.append(InfiniDeviceEnum.ASCEND)
     if not devices_to_test:
...
@@ -2,9 +2,19 @@ import torch
 import ctypes
 from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
 from libinfiniop import (
-    infiniopHandle_t, infiniopTensorDescriptor_t, open_lib, to_tensor, get_test_devices,
-    check_error, rearrange_if_needed, create_workspace, test_operator, get_args,
-    debug, get_tolerance, profile_operation,
+    infiniopHandle_t,
+    infiniopTensorDescriptor_t,
+    open_lib,
+    to_tensor,
+    get_test_devices,
+    check_error,
+    rearrange_if_needed,
+    create_workspace,
+    test_operator,
+    get_args,
+    debug,
+    get_tolerance,
+    profile_operation,
 )

 # ==============================================================================
@@ -21,8 +31,8 @@ _TEST_CASES = [
     (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1)),
     (1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1)),
     (1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1)),
-    (1.0/8.0, 0.0, (4, 8*6, 64), (4, 64, 6), (4, 8*6, 6), None, None, None),
-    (1.0/8.0, 0.0, (4, 8*6, 64), (4, 64, 6), (4, 8*6, 6), None, None, None),
+    (1.0 / 8.0, 0.0, (4, 8 * 6, 64), (4, 64, 6), (4, 8 * 6, 6), None, None, None),
+    (1.0 / 8.0, 0.0, (4, 8 * 6, 64), (4, 64, 6), (4, 8 * 6, 6), None, None, None),
 ]

 # Data types used for testing
@@ -30,8 +40,8 @@ _TENSOR_DTYPES = [torch.float16, torch.float32]

 # Tolerance map for different data types
 _TOLERANCE_MAP = {
-    torch.float16: {'atol': 0, 'rtol': 1e-2},
-    torch.float32: {'atol': 0, 'rtol': 1e-3},
+    torch.float16: {"atol": 0, "rtol": 1e-2},
+    torch.float32: {"atol": 0, "rtol": 1e-3},
 }

 DEBUG = False
@@ -39,6 +49,7 @@ PROFILE = False
 NUM_PRERUN = 10
 NUM_ITERATIONS = 1000

+
 # ==============================================================================
 # Definitions
 # ==============================================================================
@@ -48,6 +59,7 @@ class MatmulDescriptor(Structure):
 infiniopMatmulDescriptor_t = POINTER(MatmulDescriptor)

+
 # PyTorch implementation for matrix multiplication
 def matmul(_c, beta, _a, _b, alpha):
     a, b, c = _a.clone(), _b.clone(), _c.clone()
@@ -55,6 +67,7 @@ def matmul(_c, beta, _a, _b, alpha):
     fp32_result = torch.matmul(a.to(torch.float32), b.to(torch.float32))
     return alpha * fp32_result.to(result_dtype) + beta * c

+
 # The argument list should be (lib, handle, torch_device, <param list>, dtype)
 # The <param list> should keep the same order as the one specified in _TEST_CASES
 def test(
@@ -85,7 +98,10 @@ def test(
     # Compute the PyTorch reference result
     ans = matmul(c, beta, a, b, alpha)
-    a, b, c = [rearrange_if_needed(tensor, stride) for tensor, stride in zip([a, b, c], [a_stride, b_stride, c_stride])]
+    a, b, c = [
+        rearrange_if_needed(tensor, stride)
+        for tensor, stride in zip([a, b, c], [a_stride, b_stride, c_stride])
+    ]
     a_tensor, b_tensor, c_tensor = [to_tensor(tensor, lib) for tensor in [a, b, c]]

     descriptor = infiniopMatmulDescriptor_t()
@@ -95,7 +111,7 @@ def test(
             ctypes.byref(descriptor),
             c_tensor.descriptor,
             a_tensor.descriptor,
-            b_tensor.descriptor
+            b_tensor.descriptor,
         )
     )
@@ -105,12 +121,15 @@ def test(
     # Get workspace size and create workspace
     workspace_size = c_uint64(0)
-    check_error(lib.infiniopGetMatmulWorkspaceSize(descriptor, ctypes.byref(workspace_size)))
+    check_error(
+        lib.infiniopGetMatmulWorkspaceSize(descriptor, ctypes.byref(workspace_size))
+    )
     workspace = create_workspace(workspace_size.value, a.device)

     # Execute infiniop matmul operator
     def lib_matmul():
-        check_error(lib.infiniopMatmul(
+        check_error(
+            lib.infiniopMatmul(
                 descriptor,
                 workspace.data_ptr() if workspace is not None else None,
                 workspace_size.value,
@@ -120,7 +139,9 @@ def test(
                 alpha,
                 beta,
                 None,
-        ))
+            )
+        )
     lib_matmul()

     # Validate results
@@ -131,9 +152,10 @@ def test(
     # Profiling workflow
     if PROFILE:
+        # fmt: off
         profile_operation("PyTorch", lambda: matmul(c, beta, a, b, alpha), torch_device, NUM_PRERUN, NUM_ITERATIONS)
         profile_operation(" lib", lambda: lib_matmul(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
+        # fmt: on
     check_error(lib.infiniopDestroyMatmulDescriptor(descriptor))
@@ -150,7 +172,7 @@ if __name__ == "__main__":
         POINTER(infiniopMatmulDescriptor_t),
         infiniopTensorDescriptor_t,
         infiniopTensorDescriptor_t,
-        infiniopTensorDescriptor_t
+        infiniopTensorDescriptor_t,
     ]
     lib.infiniopGetMatmulWorkspaceSize.restype = c_int32
...
@@ -35,7 +35,7 @@ class MaxPoolDescriptor(Structure):
 infiniopMaxPoolDescriptor_t = POINTER(MaxPoolDescriptor)


-def pool(x, k, padding, stride, dilation = 1):
+def pool(x, k, padding, stride, dilation=1):
     pooling_layers = {
         1: torch.nn.MaxPool1d,
         2: torch.nn.MaxPool2d,
@@ -66,12 +66,14 @@ def inferShape(x_shape, kernel_shape, padding, strides):
     return x_shape[:2] + tuple(output_shape)

+
 # convert a python tuple to a ctype void pointer
 def tuple_to_void_p(py_tuple: Tuple):
     array = ctypes.c_int64 * len(py_tuple)
     data_array = array(*py_tuple)
     return ctypes.cast(data_array, ctypes.c_void_p)

+
 def test(
     lib,
     handle,
@@ -87,7 +89,9 @@ def test(
     )
     x = torch.rand(x_shape, dtype=tensor_dtype).to(torch_device)
-    y = torch.rand(inferShape(x_shape, k_shape, padding, strides), dtype=tensor_dtype).to(torch_device)
+    y = torch.rand(
+        inferShape(x_shape, k_shape, padding, strides), dtype=tensor_dtype
+    ).to(torch_device)
     for i in range(NUM_PRERUN if PROFILE else 1):
         ans = pool(x, k_shape, padding, strides)
@@ -123,7 +127,9 @@ def test(
     check_error(
         lib.infiniopGetMaxPoolWorkspaceSize(descriptor, ctypes.byref(workspaceSize))
     )
-    workspace = torch.zeros(int(workspaceSize.value), dtype=torch.uint8).to(torch_device)
+    workspace = torch.zeros(int(workspaceSize.value), dtype=torch.uint8).to(
+        torch_device
+    )
     workspace_ptr = ctypes.cast(workspace.data_ptr(), ctypes.POINTER(ctypes.c_uint8))
     for i in range(NUM_PRERUN if PROFILE else 1):
@@ -161,8 +167,10 @@ def test_cpu(lib, test_cases):
     device = DeviceEnum.DEVICE_CPU
     handle = create_handle(lib, device)
     for x_shape, kernel_shape, padding, strides in test_cases:
+        # fmt: off
         test(lib, handle, "cpu", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float16)
         test(lib, handle, "cpu", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float32)
+        # fmt: on
     destroy_handle(lib, handle)
@@ -170,8 +178,10 @@ def test_cuda(lib, test_cases):
     device = DeviceEnum.DEVICE_CUDA
     handle = create_handle(lib, device)
     for x_shape, kernel_shape, padding, strides in test_cases:
+        # fmt: off
         test(lib, handle, "cuda", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float16)
         test(lib, handle, "cuda", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float32)
+        # fmt: on
     destroy_handle(lib, handle)
@@ -181,8 +191,10 @@ def test_bang(lib, test_cases):
     device = DeviceEnum.DEVICE_BANG
     handle = create_handle(lib, device)
     for x_shape, kernel_shape, padding, strides in test_cases:
+        # fmt: off
         test(lib, handle, "mlu", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float16)
         test(lib, handle, "mlu", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float32)
+        # fmt: on
     destroy_handle(lib, handle)
...
@@ -30,13 +30,13 @@ infiniopRandomSampleDescriptor_t = POINTER(RandomSampleDescriptor)
 def random_sample(data, random_val, topp, topk, voc, temperature, torch_device):
-    indices = torch.zeros([topk], dtype = torch.int64)
+    indices = torch.zeros([topk], dtype=torch.int64)
     dataNp = data.clone().detach()
     sorted_indices = torch.arange(voc)

     for i in range(topk):
         for j in range(i + 1, voc):
-            if(dataNp[i] < dataNp[j]):
+            if dataNp[i] < dataNp[j]:
                 tmp = dataNp[i].clone().detach()
                 dataNp[i] = dataNp[j].clone().detach()
                 dataNp[j] = tmp
@@ -45,20 +45,20 @@ def random_sample(data, random_val, topp, topk, voc, temperature, torch_device):
                 sorted_indices[i] = sorted_indices[j].clone().detach()
                 sorted_indices[j] = tmpInd

-    #sorted_indices = torch.argsort(dataNp, descending=True)
+    # sorted_indices = torch.argsort(dataNp, descending=True)
     indices = sorted_indices[:topk]
     dataNp = dataNp[sorted_indices]

     globalM = dataNp[0]
     dataNp = (dataNp - globalM) / temperature
-    dataNp = torch.softmax(dataNp.float(), dim = 0)
+    dataNp = torch.softmax(dataNp.float(), dim=0)
     sum_s = 0
     for end in range(topk):
         sum_s += dataNp[end]
-        if(sum_s >= topp):
+        if sum_s >= topp:
             break
-    if(end < topk - 1):
+    if end < topk - 1:
         end += 1
     else:
         end = topk
@@ -71,21 +71,33 @@ def random_sample(data, random_val, topp, topk, voc, temperature, torch_device):
     sum_s = 0
     for i in range(end):
         sum_s += dataNp[i]
-        if(random_val < sum_s):
+        if random_val < sum_s:
             return indices[i]

+
 def random_sample_0(data):
     return torch.argmax(data)

+
-def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_dtype=torch.float16):
-    print(
-        f"Testing RandomSample on {torch_device} with voc:{voc} dtype:{x_dtype}"
-    )
+def test(
+    lib,
+    handle,
+    torch_device,
+    voc,
+    random_val,
+    topp,
+    topk,
+    temperature,
+    x_dtype=torch.float16,
+):
+    print(f"Testing RandomSample on {torch_device} with voc:{voc} dtype:{x_dtype}")
     data = torch.arange(voc).float() * 0.0001
     _perm = torch.randperm(voc)
     data = data[_perm].to(x_dtype).to(torch_device)
-    if(topp > 0 and topk > 1):
-        ans = random_sample(data.to("cpu"), random_val, topp, topk, voc, temperature, "cpu")
+    if topp > 0 and topk > 1:
+        ans = random_sample(
+            data.to("cpu"), random_val, topp, topk, voc, temperature, "cpu"
+        )
     else:
         ans = random_sample_0(data)
     indices = torch.zeros([1], dtype=torch.int64).to(torch_device)
@@ -96,7 +108,10 @@ def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_
     descriptor = infiniopRandomSampleDescriptor_t()
     check_error(
         lib.infiniopCreateRandomSampleDescriptor(
-            handle, ctypes.byref(descriptor), indices_tensor.descriptor, x_tensor.descriptor
+            handle,
+            ctypes.byref(descriptor),
+            indices_tensor.descriptor,
+            x_tensor.descriptor,
         )
     )
@@ -131,10 +146,11 @@ def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_
     assert indices[0].type(ans.dtype) == ans or data[ans] == data[indices[0]]
     check_error(lib.infiniopDestroyRandomSampleDescriptor(descriptor))

+
 def test_cpu(lib, test_cases):
     device = DeviceEnum.DEVICE_CPU
     handle = create_handle(lib, device)
-    for (voc, random_val, topp, topk, temperature) in test_cases:
+    for voc, random_val, topp, topk, temperature in test_cases:
         test(lib, handle, "cpu", voc, random_val, topp, topk, temperature)
     destroy_handle(lib, handle)
@@ -142,7 +158,7 @@ def test_cpu(lib, test_cases):
 def test_cuda(lib, test_cases):
     device = DeviceEnum.DEVICE_CUDA
     handle = create_handle(lib, device)
-    for (voc, random_val, topp, topk, temperature) in test_cases:
+    for voc, random_val, topp, topk, temperature in test_cases:
         test(lib, handle, "cuda", voc, random_val, topp, topk, temperature)
     destroy_handle(lib, handle)
@@ -152,16 +168,17 @@ def test_bang(lib, test_cases):
     device = DeviceEnum.DEVICE_BANG
     handle = create_handle(lib, device)
-    for (voc, random_val, topp, topk, temperature) in test_cases:
+    for voc, random_val, topp, topk, temperature in test_cases:
         test(lib, handle, "mlu", voc, random_val, topp, topk, temperature)
     destroy_handle(lib, handle)

+
 def test_ascend(lib, test_cases):
     import torch_npu

     device = DeviceEnum.DEVICE_ASCEND
     handle = create_handle(lib, device)
-    for (voc, random_val, topp, topk, temperature) in test_cases:
+    for voc, random_val, topp, topk, temperature in test_cases:
         test(lib, handle, "npu", voc, random_val, topp, topk, temperature)
     destroy_handle(lib, handle)
...
@@ -61,9 +61,7 @@ def test(
     x_tensor.descriptor.contents.invalidate()
     y_tensor.descriptor.contents.invalidate()
-    check_error(
-        lib.infiniopRearrange(descriptor, y_tensor.data, x_tensor.data, None)
-    )
+    check_error(lib.infiniopRearrange(descriptor, y_tensor.data, x_tensor.data, None))
     assert torch.allclose(x, y, atol=0, rtol=1e-3)
     check_error(lib.infiniopDestroyRearrangeDescriptor(descriptor))
@@ -87,8 +85,10 @@ def test_cuda(lib, test_cases):
         test(lib, handle, "cuda", x_shape, x_stride, y_shape, y_stride)
     destroy_handle(lib, handle)

+
 def test_bang(lib, test_cases):
     import torch_mlu
+
     device = DeviceEnum.DEVICE_BANG
     handle = create_handle(lib, device)
     for test_case in test_cases:
@@ -97,6 +97,7 @@ def test_bang(lib, test_cases):
         test(lib, handle, "mlu", x_shape, x_stride, y_shape, y_stride)
     destroy_handle(lib, handle)

+
 def test_ascend(lib, test_cases):
     import torch_npu
@@ -108,6 +109,7 @@ def test_ascend(lib, test_cases):
         test(lib, handle, "npu", x_shape, x_stride, y_shape, y_stride)
     destroy_handle(lib, handle)

+
 if __name__ == "__main__":
     args = get_args()
     test_cases = [
...
...@@ -61,7 +61,11 @@ def test( ...@@ -61,7 +61,11 @@ def test(
) )
x = torch.rand(tensor_shape, dtype=tensor_dtype).to(torch_device) * 2 - 1 x = torch.rand(tensor_shape, dtype=tensor_dtype).to(torch_device) * 2 - 1
y = torch.rand(tensor_shape, dtype=tensor_dtype).to(torch_device) if inplace == Inplace.OUT_OF_PLACE else x y = (
torch.rand(tensor_shape, dtype=tensor_dtype).to(torch_device)
if inplace == Inplace.OUT_OF_PLACE
else x
)
for i in range(NUM_PRERUN if PROFILE else 1): for i in range(NUM_PRERUN if PROFILE else 1):
ans = relu(x) ans = relu(x)
...@@ -108,17 +112,22 @@ def test_cpu(lib, test_cases): ...@@ -108,17 +112,22 @@ def test_cpu(lib, test_cases):
device = DeviceEnum.DEVICE_CPU device = DeviceEnum.DEVICE_CPU
handle = create_handle(lib, device) handle = create_handle(lib, device)
for tensor_shape, inplace in test_cases: for tensor_shape, inplace in test_cases:
# fmt: off
test(lib, handle, "cpu", tensor_shape, tensor_dtype=torch.float16, inplace=inplace) test(lib, handle, "cpu", tensor_shape, tensor_dtype=torch.float16, inplace=inplace)
test(lib, handle, "cpu", tensor_shape, tensor_dtype=torch.float32, inplace=inplace) test(lib, handle, "cpu", tensor_shape, tensor_dtype=torch.float32, inplace=inplace)
# fmt: on
destroy_handle(lib, handle) destroy_handle(lib, handle)
def test_cuda(lib, test_cases): def test_cuda(lib, test_cases):
device = DeviceEnum.DEVICE_CUDA device = DeviceEnum.DEVICE_CUDA
handle = create_handle(lib, device) handle = create_handle(lib, device)
for tensor_shape, inplace in test_cases: for tensor_shape, inplace in test_cases:
# fmt: off
test(lib, handle, "cuda", tensor_shape, tensor_dtype=torch.float16, inplace=inplace) test(lib, handle, "cuda", tensor_shape, tensor_dtype=torch.float16, inplace=inplace)
test(lib, handle, "cuda", tensor_shape, tensor_dtype=torch.float32, inplace=inplace) test(lib, handle, "cuda", tensor_shape, tensor_dtype=torch.float32, inplace=inplace)
# fmt: on
destroy_handle(lib, handle) destroy_handle(lib, handle)
...@@ -128,8 +137,10 @@ def test_bang(lib, test_cases): ...@@ -128,8 +137,10 @@ def test_bang(lib, test_cases):
device = DeviceEnum.DEVICE_BANG device = DeviceEnum.DEVICE_BANG
handle = create_handle(lib, device) handle = create_handle(lib, device)
for tensor_shape, inplace in test_cases: for tensor_shape, inplace in test_cases:
# fmt: off
test(lib, handle, "mlu", tensor_shape, tensor_dtype=torch.float16, inplace=inplace) test(lib, handle, "mlu", tensor_shape, tensor_dtype=torch.float16, inplace=inplace)
test(lib, handle, "mlu", tensor_shape, tensor_dtype=torch.float32, inplace=inplace) test(lib, handle, "mlu", tensor_shape, tensor_dtype=torch.float32, inplace=inplace)
# fmt: on
destroy_handle(lib, handle) destroy_handle(lib, handle)
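A note on the # fmt: off / # fmt: on pairs added throughout these loops: they are the formatter's (presumably Black's) standard escape hatch. Everything between the two markers is left exactly as written, which here keeps each parameter-heavy test(...) call on one line instead of being exploded across ten. A minimal illustration with hypothetical data:
# fmt: off
identity = [
    1, 0, 0,
    0, 1, 0,
    0, 0, 1,
]
# fmt: on
# Without the markers, Black would reflow the hand-aligned rows above.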
......
...@@ -20,12 +20,14 @@ from operatorspy import ( ...@@ -20,12 +20,14 @@ from operatorspy import (
from operatorspy.tests.test_utils import get_args from operatorspy.tests.test_utils import get_args
import torch import torch
class RMSNormDescriptor(Structure): class RMSNormDescriptor(Structure):
_fields_ = [("device", c_int32)] _fields_ = [("device", c_int32)]
infiniopRMSNormDescriptor_t = POINTER(RMSNormDescriptor) infiniopRMSNormDescriptor_t = POINTER(RMSNormDescriptor)
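The ctypes pattern above (repeated in the other test files) mirrors only the leading device field of the library's opaque descriptor struct; the POINTER typedef is what gets passed across the C API boundary. A minimal sketch of the same pattern for a hypothetical operator Foo:
import ctypes
from ctypes import POINTER, Structure, c_int32

class FooDescriptor(Structure):               # hypothetical operator
    _fields_ = [("device", c_int32)]          # only the leading field is mirrored in Python

infiniopFooDescriptor_t = POINTER(FooDescriptor)

descriptor = infiniopFooDescriptor_t()        # starts as a NULL pointer
# lib.infiniopCreateFooDescriptor(handle, ctypes.byref(descriptor), ...) fills it in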
def rms_norm(x, w, eps): def rms_norm(x, w, eps):
input_dtype = x.dtype input_dtype = x.dtype
hidden_states = x.to(torch.float32) hidden_states = x.to(torch.float32)
...@@ -34,9 +36,20 @@ def rms_norm(x, w, eps): ...@@ -34,9 +36,20 @@ def rms_norm(x, w, eps):
return w * hidden_states.to(input_dtype) return w * hidden_states.to(input_dtype)
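The hunk header above elides the middle of rms_norm; from the visible first and last lines and the standard formulation, the full reference is presumably:
def rms_norm_ref(x, w, eps):
    # y = w * x / sqrt(mean(x^2, dim=-1) + eps), accumulated in float32 for stability
    input_dtype = x.dtype
    hidden_states = x.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return w * hidden_states.to(input_dtype)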
def test(lib, handle, torch_device, y_shape, x_shape, w_shape, dtype=torch.float16, w_dtype=torch.float16): def test(
print(f"Testing RMS_Norm on {torch_device} with y_shape:{y_shape} x_shape:{x_shape} w_shape:{w_shape}" lib,
f" dtype:{dtype} w_dtype:{w_dtype}") handle,
torch_device,
y_shape,
x_shape,
w_shape,
dtype=torch.float16,
w_dtype=torch.float16,
):
print(
f"Testing RMS_Norm on {torch_device} with y_shape:{y_shape} x_shape:{x_shape} w_shape:{w_shape}"
f" dtype:{dtype} w_dtype:{w_dtype}"
)
y = torch.zeros(y_shape, dtype=dtype).to(torch_device) y = torch.zeros(y_shape, dtype=dtype).to(torch_device)
x = torch.rand(x_shape, dtype=dtype).to(torch_device) x = torch.rand(x_shape, dtype=dtype).to(torch_device)
...@@ -50,12 +63,16 @@ def test(lib, handle, torch_device, y_shape, x_shape, w_shape, dtype=torch.float ...@@ -50,12 +63,16 @@ def test(lib, handle, torch_device, y_shape, x_shape, w_shape, dtype=torch.float
w_tensor = to_tensor(w, lib) w_tensor = to_tensor(w, lib)
descriptor = infiniopRMSNormDescriptor_t() descriptor = infiniopRMSNormDescriptor_t()
w_dataType = 0 if w_dtype==torch.float16 else 1 w_dataType = 0 if w_dtype == torch.float16 else 1
check_error( check_error(
lib.infiniopCreateRMSNormDescriptor( lib.infiniopCreateRMSNormDescriptor(
handle, ctypes.byref(descriptor), y_tensor.descriptor, x_tensor.descriptor, handle,
w_tensor.descriptor, eps ctypes.byref(descriptor),
y_tensor.descriptor,
x_tensor.descriptor,
w_tensor.descriptor,
eps,
) )
) )
...@@ -66,9 +83,7 @@ def test(lib, handle, torch_device, y_shape, x_shape, w_shape, dtype=torch.float ...@@ -66,9 +83,7 @@ def test(lib, handle, torch_device, y_shape, x_shape, w_shape, dtype=torch.float
workspace_size = c_uint64(0) workspace_size = c_uint64(0)
check_error( check_error(
lib.infiniopGetRMSNormWorkspaceSize( lib.infiniopGetRMSNormWorkspaceSize(descriptor, ctypes.byref(workspace_size))
descriptor, ctypes.byref(workspace_size)
)
) )
workspace = create_workspace(workspace_size.value, y.device) workspace = create_workspace(workspace_size.value, y.device)
check_error( check_error(
...@@ -86,37 +101,44 @@ def test(lib, handle, torch_device, y_shape, x_shape, w_shape, dtype=torch.float ...@@ -86,37 +101,44 @@ def test(lib, handle, torch_device, y_shape, x_shape, w_shape, dtype=torch.float
assert torch.allclose(y.to(dtype), ans.to(dtype), atol=1e-3, rtol=1e-3) assert torch.allclose(y.to(dtype), ans.to(dtype), atol=1e-3, rtol=1e-3)
check_error(lib.infiniopDestroyRMSNormDescriptor(descriptor)) check_error(lib.infiniopDestroyRMSNormDescriptor(descriptor))
def test_cpu(lib, test_cases): def test_cpu(lib, test_cases):
device = DeviceEnum.DEVICE_CPU device = DeviceEnum.DEVICE_CPU
handle = create_handle(lib, device) handle = create_handle(lib, device)
for (y_shape, x_shape, w_shape, dtype, w_dtype) in test_cases: for y_shape, x_shape, w_shape, dtype, w_dtype in test_cases:
test(lib, handle, "cpu", y_shape, x_shape, w_shape, dtype, w_dtype) test(lib, handle, "cpu", y_shape, x_shape, w_shape, dtype, w_dtype)
destroy_handle(lib, handle) destroy_handle(lib, handle)
def test_cuda(lib, test_cases): def test_cuda(lib, test_cases):
device = DeviceEnum.DEVICE_CUDA device = DeviceEnum.DEVICE_CUDA
handle = create_handle(lib, device) handle = create_handle(lib, device)
for (y_shape, x_shape, w_shape, dtype, w_dtype) in test_cases: for y_shape, x_shape, w_shape, dtype, w_dtype in test_cases:
test(lib, handle, "cuda", y_shape, x_shape, w_shape, dtype, w_dtype) test(lib, handle, "cuda", y_shape, x_shape, w_shape, dtype, w_dtype)
destroy_handle(lib, handle) destroy_handle(lib, handle)
def test_bang(lib, test_cases): def test_bang(lib, test_cases):
import torch_mlu import torch_mlu
device = DeviceEnum.DEVICE_BANG device = DeviceEnum.DEVICE_BANG
handle = create_handle(lib, device) handle = create_handle(lib, device)
for (y_shape, x_shape, w_shape, dtype, w_dtype) in test_cases: for y_shape, x_shape, w_shape, dtype, w_dtype in test_cases:
test(lib, handle, "mlu", y_shape, x_shape, w_shape, dtype, w_dtype) test(lib, handle, "mlu", y_shape, x_shape, w_shape, dtype, w_dtype)
destroy_handle(lib, handle) destroy_handle(lib, handle)
def test_ascend(lib, test_cases): def test_ascend(lib, test_cases):
import torch_npu import torch_npu
device = DeviceEnum.DEVICE_ASCEND device = DeviceEnum.DEVICE_ASCEND
handle = create_handle(lib, device) handle = create_handle(lib, device)
for (y_shape, x_shape, w_shape, dtype, w_dtype) in test_cases: for y_shape, x_shape, w_shape, dtype, w_dtype in test_cases:
test(lib, handle, "npu", y_shape, x_shape, w_shape, dtype, w_dtype) test(lib, handle, "npu", y_shape, x_shape, w_shape, dtype, w_dtype)
destroy_handle(lib, handle) destroy_handle(lib, handle)
if __name__ == "__main__": if __name__ == "__main__":
test_cases = [ test_cases = [
# y_shape, x_shape, w_shape, dtype, w_dtype # y_shape, x_shape, w_shape, dtype, w_dtype
......
...@@ -51,6 +51,7 @@ def rotary_embedding(t, pos, theta, torch_device): ...@@ -51,6 +51,7 @@ def rotary_embedding(t, pos, theta, torch_device):
t_out = torch.view_as_real(t_ * freqs_cis).flatten(2).to(t.dtype) t_out = torch.view_as_real(t_ * freqs_cis).flatten(2).to(t.dtype)
return t_out return t_out
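Only the tail of rotary_embedding is visible in this hunk; a self-contained reference consistent with it, assuming t has shape (seq_len, n_heads, head_dim) and standard RoPE over the last dimension, would be:
def rotary_embedding_ref(t, pos, theta, torch_device):
    # Pair adjacent elements of the last dim as complex numbers, rotate by position-dependent angles.
    dim = t.shape[-1]
    freqs = 1.0 / (
        theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=torch_device) / dim)
    )
    angles = torch.outer(pos.to(torch_device).float(), freqs)  # (seq_len, dim // 2)
    freqs_cis = torch.polar(torch.ones_like(angles), angles)   # complex64 rotations
    t_ = torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))
    t_out = torch.view_as_real(t_ * freqs_cis.unsqueeze(1)).flatten(2).to(t.dtype)
    return t_out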
def sin_cos_table(max_seq_len, dim, torch_device, theta): def sin_cos_table(max_seq_len, dim, torch_device, theta):
pos = torch.arange( pos = torch.arange(
0, max_seq_len, dtype=torch.float32, device=torch.device(torch_device) 0, max_seq_len, dtype=torch.float32, device=torch.device(torch_device)
...@@ -73,12 +74,12 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16): ...@@ -73,12 +74,12 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
if strides is not None: if strides is not None:
t = rearrange_tensor(t, strides) t = rearrange_tensor(t, strides)
posTmp = torch.arange(0, t.shape[0]) posTmp = torch.arange(0, t.shape[0])
pos = torch.zeros(2 * posTmp.shape[0], dtype = torch.int32) pos = torch.zeros(2 * posTmp.shape[0], dtype=torch.int32)
for i in range(posTmp.shape[0]): for i in range(posTmp.shape[0]):
pos[2 * i] = posTmp[i] pos[2 * i] = posTmp[i]
pos[2 * i + 1] = 0 pos[2 * i + 1] = 0
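Incidentally, the interleaving loop above has a vectorized equivalent using strided slice assignment:
pos = torch.zeros(2 * posTmp.shape[0], dtype=torch.int32)
pos[0::2] = posTmp.to(torch.int32)  # even slots carry the positions
# odd slots stay zero, matching the loop above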
theta = 1e4 theta = 1e4
if torch_device == 'mlu' or torch_device == 'npu': if torch_device == "mlu" or torch_device == "npu":
ans = rotary_embedding(t, posTmp, theta, "cpu").to(torch_device) ans = rotary_embedding(t, posTmp, theta, "cpu").to(torch_device)
pos = pos.to(torch_device) pos = pos.to(torch_device)
t = t.to(torch_device) t = t.to(torch_device)
...@@ -156,6 +157,7 @@ def test_cuda(lib, test_cases): ...@@ -156,6 +157,7 @@ def test_cuda(lib, test_cases):
def test_bang(lib, test_cases): def test_bang(lib, test_cases):
import torch_mlu import torch_mlu
device = DeviceEnum.DEVICE_BANG device = DeviceEnum.DEVICE_BANG
handle = create_handle(lib, device) handle = create_handle(lib, device)
for shape, strides, dtype in test_cases: for shape, strides, dtype in test_cases:
...@@ -163,7 +165,7 @@ def test_bang(lib, test_cases): ...@@ -163,7 +165,7 @@ def test_bang(lib, test_cases):
destroy_handle(lib, handle) destroy_handle(lib, handle)
def test_ascend(lib, test_cases) : def test_ascend(lib, test_cases):
import torch_npu import torch_npu
device = DeviceEnum.DEVICE_ASCEND device = DeviceEnum.DEVICE_ASCEND
...@@ -172,6 +174,7 @@ def test_ascend(lib, test_cases) : ...@@ -172,6 +174,7 @@ def test_ascend(lib, test_cases) :
test(lib, handle, "npu", shape, strides, dtype) test(lib, handle, "npu", shape, strides, dtype)
destroy_handle(lib, handle) destroy_handle(lib, handle)
if __name__ == "__main__": if __name__ == "__main__":
test_cases = [ test_cases = [
((1, 32, 128), None, torch.float16), ((1, 32, 128), None, torch.float16),
...@@ -180,7 +183,6 @@ if __name__ == "__main__": ...@@ -180,7 +183,6 @@ if __name__ == "__main__":
# Related to the internal implementation of the GatherMask interface; 48, 64, and 128 are all currently supported # Related to the internal implementation of the GatherMask interface; 48, 64, and 128 are all currently supported
((4, 1, 32), None, torch.float16), ((4, 1, 32), None, torch.float16),
((1, 32, 128), None, torch.float16), ((1, 32, 128), None, torch.float16),
((3, 32, 128), (8000, 200, 1), torch.float16), ((3, 32, 128), (8000, 200, 1), torch.float16),
] ]
args = get_args() args = get_args()
......
...@@ -32,6 +32,7 @@ def swiglu(a, b): ...@@ -32,6 +32,7 @@ def swiglu(a, b):
return a * b / (1 + torch.exp(-b.float()).to(b.dtype)) return a * b / (1 + torch.exp(-b.float()).to(b.dtype))
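The expression above is a * b * sigmoid(b), i.e. a times SiLU(b). An equivalent sketch via torch.nn.functional, identical up to where the cast back to b.dtype happens:
import torch.nn.functional as F

def swiglu_ref(a, b):
    # a * silu(b) == a * b * sigmoid(b) == a * b / (1 + exp(-b))
    return a * F.silu(b.float()).to(b.dtype)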
def test_out_of_place( def test_out_of_place(
lib, lib,
handle, handle,
...@@ -223,6 +224,7 @@ def test_cuda(lib, test_cases): ...@@ -223,6 +224,7 @@ def test_cuda(lib, test_cases):
def test_bang(lib, test_cases): def test_bang(lib, test_cases):
import torch_mlu import torch_mlu
device = DeviceEnum.DEVICE_BANG device = DeviceEnum.DEVICE_BANG
handle = create_handle(lib, device) handle = create_handle(lib, device)
...@@ -238,15 +240,28 @@ def test_bang(lib, test_cases): ...@@ -238,15 +240,28 @@ def test_bang(lib, test_cases):
def test_ascend(lib, test_cases): def test_ascend(lib, test_cases):
import torch_npu import torch_npu
device = DeviceEnum.DEVICE_ASCEND device = DeviceEnum.DEVICE_ASCEND
handle = create_handle(lib, device) handle = create_handle(lib, device)
for shape, a_stride, b_stride, c_stride, dtype in test_cases: for shape, a_stride, b_stride, c_stride, dtype in test_cases:
test_out_of_place( test_out_of_place(
lib, handle, "npu", shape, a_stride, b_stride, c_stride, dtype, torch.npu.synchronize lib,
handle,
"npu",
shape,
a_stride,
b_stride,
c_stride,
dtype,
torch.npu.synchronize,
)
test_in_place1(
lib, handle, "npu", shape, a_stride, b_stride, dtype, torch.npu.synchronize
)
test_in_place2(
lib, handle, "npu", shape, a_stride, b_stride, dtype, torch.npu.synchronize
) )
test_in_place1(lib, handle, "npu", shape, a_stride, b_stride, dtype, torch.npu.synchronize)
test_in_place2(lib, handle, "npu", shape, a_stride, b_stride, dtype, torch.npu.synchronize)
destroy_handle(lib, handle) destroy_handle(lib, handle)
......