from ctypes import POINTER, Structure, c_int32, c_void_p, c_uint64
import ctypes
import sys
import os
import time

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
from operatorspy import (
    open_lib,
    to_tensor,
    DeviceEnum,
    infiniopHandle_t,
    infiniopTensorDescriptor_t,
    create_handle,
    destroy_handle,
    check_error,
)

from operatorspy.tests.test_utils import get_args
import torch
from typing import Tuple

# Constants controlling whether to profile the PyTorch and lib functions.
# NOTE: the lib function needs to synchronize manually,
# e.g., via cudaDeviceSynchronize() for CUDA.
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000


class MaxPoolDescriptor(Structure):
    _fields_ = [("device", c_int32)]


infiniopMaxPoolDescriptor_t = POINTER(MaxPoolDescriptor)


# Reference max pooling: pick the torch.nn.MaxPool{1,2,3}d layer matching the
# spatial rank of x (total rank minus the batch and channel dims).
def pool(x, k, padding, stride, dilation=1):
    pooling_layers = {
        1: torch.nn.MaxPool1d,
        2: torch.nn.MaxPool2d,
        3: torch.nn.MaxPool3d,
    }
    ndim = len(x.shape) - 2
    if ndim not in pooling_layers:
        print("Error: PyTorch -> Unsupported tensor dimension")
        return None
    ans = pooling_layers[ndim](k, stride=stride, padding=padding, dilation=dilation)(x)
    if PROFILE and x.is_cuda:
        # Wait for the asynchronous CUDA kernel to finish so timings are meaningful
        torch.cuda.synchronize()
    return ans


# Infer the pooled output shape: for each spatial dim,
# output_dim = (dim + 2 * padding - kernel) // stride + 1.
def inferShape(x_shape, kernel_shape, padding, strides):
    assert (
        len(x_shape) - 2 == len(kernel_shape) == len(padding) == len(strides)
    ), "kernel_shape, padding, and strides must have the same length; x_shape must have exactly 2 more dimensions than kernel_shape"
    input_shape = x_shape[2:]
    output_shape = []
    for dim, k, p, s in zip(input_shape, kernel_shape, padding, strides):
        output_dim = (dim + 2 * p - k) // s + 1
        output_shape.append(output_dim)
    return x_shape[:2] + tuple(output_shape)
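# Worked example: for the (32, 3, 224, 224) case in test_cases below, with a
# (3, 3) kernel, (1, 1) padding, and (2, 2) strides, each spatial dimension
# maps to (224 + 2 * 1 - 3) // 2 + 1 = 112, so inferShape returns
# (32, 3, 112, 112); the batch and channel dims pass through unchanged.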
# Convert a Python tuple to a ctypes void pointer (backed by a c_int64 array)
def tuple_to_void_p(py_tuple: Tuple):
    array = ctypes.c_int64 * len(py_tuple)
    data_array = array(*py_tuple)
    return ctypes.cast(data_array, ctypes.c_void_p)


def test(
    lib,
    handle,
    torch_device,
    x_shape,
    k_shape,
    padding,
    strides,
    tensor_dtype=torch.float16,
):
    print(
        f"Testing MaxPool on {torch_device} with x_shape:{x_shape} kernel_shape:{k_shape} padding:{padding} strides:{strides} dtype:{tensor_dtype}"
    )

    x = torch.rand(x_shape, dtype=tensor_dtype).to(torch_device)
    y = torch.rand(
        inferShape(x_shape, k_shape, padding, strides), dtype=tensor_dtype
    ).to(torch_device)

    # Reference result from PyTorch (with optional profiling)
    for _ in range(NUM_PRERUN if PROFILE else 1):
        ans = pool(x, k_shape, padding, strides)
    if PROFILE:
        start_time = time.time()
        for _ in range(NUM_ITERATIONS):
            _ = pool(x, k_shape, padding, strides)
        elapsed = (time.time() - start_time) / NUM_ITERATIONS
        print(f"pytorch time: {elapsed:.6f}")

    x_tensor = to_tensor(x, lib)
    y_tensor = to_tensor(y, lib)
    descriptor = infiniopMaxPoolDescriptor_t()

    check_error(
        lib.infiniopCreateMaxPoolDescriptor(
            handle,
            ctypes.byref(descriptor),
            y_tensor.descriptor,
            x_tensor.descriptor,
            tuple_to_void_p(k_shape),
            tuple_to_void_p(padding),
            tuple_to_void_p(strides),
            len(k_shape),
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    x_tensor.descriptor.contents.invalidate()
    y_tensor.descriptor.contents.invalidate()

    # Query and allocate the workspace required by the lib implementation
    workspaceSize = ctypes.c_uint64(0)
    check_error(
        lib.infiniopGetMaxPoolWorkspaceSize(descriptor, ctypes.byref(workspaceSize))
    )
    workspace = torch.zeros(int(workspaceSize.value), dtype=torch.uint8).to(torch_device)
    workspace_ptr = ctypes.cast(workspace.data_ptr(), ctypes.POINTER(ctypes.c_uint8))

    # Lib result (with optional profiling)
    for _ in range(NUM_PRERUN if PROFILE else 1):
        check_error(
            lib.infiniopMaxPool(
                descriptor,
                workspace_ptr,
                workspaceSize,
                y_tensor.data,
                x_tensor.data,
                None,
            )
        )
    if PROFILE:
        start_time = time.time()
        for _ in range(NUM_ITERATIONS):
            check_error(
                lib.infiniopMaxPool(
                    descriptor,
                    workspace_ptr,
                    workspaceSize,
                    y_tensor.data,
                    x_tensor.data,
                    None,
                )
            )
        elapsed = (time.time() - start_time) / NUM_ITERATIONS
        print(f"    lib time: {elapsed:.6f}")

    assert torch.allclose(y, ans, atol=0, rtol=1e-3)
    check_error(lib.infiniopDestroyMaxPoolDescriptor(descriptor))


def test_cpu(lib, test_cases):
    device = DeviceEnum.DEVICE_CPU
    handle = create_handle(lib, device)
    for x_shape, kernel_shape, padding, strides in test_cases:
        test(lib, handle, "cpu", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float16)
        test(lib, handle, "cpu", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float32)
    destroy_handle(lib, handle)


def test_cuda(lib, test_cases):
    device = DeviceEnum.DEVICE_CUDA
    handle = create_handle(lib, device)
    for x_shape, kernel_shape, padding, strides in test_cases:
        test(lib, handle, "cuda", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float16)
        test(lib, handle, "cuda", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float32)
    destroy_handle(lib, handle)


def test_bang(lib, test_cases):
    import torch_mlu

    device = DeviceEnum.DEVICE_BANG
    handle = create_handle(lib, device)
    for x_shape, kernel_shape, padding, strides in test_cases:
        test(lib, handle, "mlu", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float16)
        test(lib, handle, "mlu", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float32)
    destroy_handle(lib, handle)


if __name__ == "__main__":
    test_cases = [
        # x_shape, kernel_shape, padding, strides
        ((1, 1, 10), (3,), (1,), (1,)),
        ((32, 3, 224, 224), (3, 3), (1, 1), (2, 2)),
        ((1, 1, 16, 16, 16), (5, 5, 5), (2, 2, 2), (2, 2, 2)),
    ]
    args = get_args()
    lib = open_lib()

    lib.infiniopCreateMaxPoolDescriptor.restype = c_int32
    lib.infiniopCreateMaxPoolDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopMaxPoolDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        c_void_p,
        c_void_p,
        c_void_p,
        c_uint64,
    ]
    lib.infiniopGetMaxPoolWorkspaceSize.restype = c_int32
    lib.infiniopGetMaxPoolWorkspaceSize.argtypes = [
        infiniopMaxPoolDescriptor_t,
        POINTER(c_uint64),
    ]
    lib.infiniopMaxPool.restype = c_int32
    lib.infiniopMaxPool.argtypes = [
        infiniopMaxPoolDescriptor_t,
        c_void_p,
        c_uint64,
        c_void_p,
        c_void_p,
        c_void_p,
    ]
    lib.infiniopDestroyMaxPoolDescriptor.restype = c_int32
    lib.infiniopDestroyMaxPoolDescriptor.argtypes = [
        infiniopMaxPoolDescriptor_t,
    ]

    if args.cpu:
        test_cpu(lib, test_cases)
    if args.cuda:
        test_cuda(lib, test_cases)
    if args.bang:
        test_bang(lib, test_cases)
    if not (args.cpu or args.cuda or args.bang):
        test_cpu(lib, test_cases)
    print("\033[92mTest passed!\033[0m")
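# Usage sketch (assumption: get_args in operatorspy.tests.test_utils defines
# --cpu, --cuda, and --bang flags, as the args.cpu/args.cuda/args.bang
# accesses above suggest; the script name here is hypothetical):
#   python maxpool.py --cpu    # test the CPU backend
#   python maxpool.py --cuda   # test the CUDA backend
#   python maxpool.py          # no flag: defaults to the CPU tests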