"""Test script for the infiniop RandomSample operator.

Builds a shuffled logits vector, computes a reference token index in
Python/PyTorch, runs the library kernel through ctypes and checks that both
select the same token (or two tokens that hold the same logit value).
"""
from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p, c_float
import ctypes
import sys
import os

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
from operatorspy import (
    open_lib,
    to_tensor,
    DeviceEnum,
    infiniopHandle_t,
    infiniopTensorDescriptor_t,
    create_handle,
    destroy_handle,
    check_error,
    rearrange_tensor,
    create_workspace,
    U64,
)

from operatorspy.tests.test_utils import get_args
import torch


class RandomSampleDescriptor(Structure):
    """ctypes mirror of the library descriptor; only `device` is declared here."""

    _fields_ = [("device", c_int32)]


infiniopRandomSampleDescriptor_t = POINTER(RandomSampleDescriptor)


def random_sample(data, random_val, topp, topk, voc, temperature, torch_device):
    """Reference top-k / top-p sampling used to validate the kernel.

    Args:
        data: 1-D tensor holding `voc` logits.
        random_val: Number that selects a point inside the kept probability
            mass (the test cases use values in [0, 1)).
        topp: Cumulative-probability cutoff applied inside the top-k set.
        topk: Number of highest-logit candidates considered.
        voc: Number of entries of `data` to consider.
        temperature: Divisor applied to the shifted logits before softmax.
        torch_device: Unused; kept so existing callers keep working.

    Returns:
        A 0-dim int64 tensor with the index (into `data`) of the sampled entry.
    """
    # Partial selection sort on plain Python lists: only the first `topk`
    # positions need to end up ordered. Lists give exactly the same ordering
    # and tie-breaking as element-wise tensor swaps, without creating a
    # tensor object for every comparison.
    values = data.detach().tolist()
    order = list(range(voc))
    for i in range(topk):
        for j in range(i + 1, voc):
            if values[i] < values[j]:
                values[i], values[j] = values[j], values[i]
                order[i], order[j] = order[j], order[i]

    indices = torch.tensor(order[:topk], dtype=torch.int64)

    # `values` is already in sorted order, so it must NOT be re-indexed by
    # `order` (doing so would scramble the probabilities relative to `indices`).
    sorted_vals = torch.tensor(values, dtype=data.dtype)
    scaled = (sorted_vals - sorted_vals[0]) / temperature
    probs = torch.softmax(scaled.float(), dim=0)

    # Grow the candidate set until its cumulative probability reaches `topp`
    # or all `topk` candidates are included.
    kept_mass = 0
    for end in range(topk):
        kept_mass += probs[end]
        if kept_mass >= topp:
            break
    end = min(end + 1, topk)

    # Scale the random value into the kept mass and pick the first candidate
    # whose running sum exceeds it.
    threshold = random_val * kept_mass
    running = 0
    for i in range(end):
        running += probs[i]
        if threshold < running:
            return indices[i]

    # Only reachable through rounding or random_val >= 1: fall back to the
    # last kept candidate instead of implicitly returning None.
    return indices[end - 1]


def random_sample_0(data):
    """Greedy reference: return the index of the largest entry of `data`."""
    return torch.argmax(data)


def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_dtype=torch.float16):
    """Run one RandomSample case through the library and check the result.

    Args:
        lib: Loaded operator library with the RandomSample entry points.
        handle: Library handle created for the target device.
        torch_device: Torch device string the tensors are placed on.
        voc: Length of the logits vector.
        random_val, topp, topk, temperature: Sampling parameters forwarded to
            both the reference and the kernel.
        x_dtype: Torch dtype the logits are cast to.

    Raises:
        AssertionError: If the kernel's index differs from the reference and
            the two indices do not hold equal logits.
    """
    print(
        f"Testing RandomSample on {torch_device} with voc:{voc} dtype:{x_dtype}"
    )
    data = torch.arange(voc).float() * 0.0001
    _perm = torch.randperm(voc)
    data = data[_perm].to(x_dtype).to(torch_device)

    # With topp == 0 or topk <= 1 the expected result is the plain argmax.
    if topp > 0 and topk > 1:
        ans = random_sample(data.to("cpu"), random_val, topp, topk, voc, temperature, "cpu")
    else:
        ans = random_sample_0(data)

    indices = torch.zeros([1], dtype=torch.int64).to(torch_device)
    x_tensor = to_tensor(data, lib)
    indices_tensor = to_tensor(indices, lib)
    indices_tensor.descriptor.contents.dt = U64  # treat int64 as uint64

    descriptor = infiniopRandomSampleDescriptor_t()
    check_error(
        lib.infiniopCreateRandomSampleDescriptor(
            handle, ctypes.byref(descriptor), indices_tensor.descriptor, x_tensor.descriptor
        )
    )

    try:
        # Invalidate the shape and strides in the descriptor to prevent them
        # from being directly used by the kernel
        x_tensor.descriptor.contents.invalidate()
        indices_tensor.descriptor.contents.invalidate()

        workspace_size = c_uint64(0)
        check_error(
            lib.infiniopGetRandomSampleWorkspaceSize(
                descriptor, ctypes.byref(workspace_size)
            )
        )
        workspace = create_workspace(workspace_size.value, torch_device)
        check_error(
            lib.infiniopRandomSample(
                descriptor,
                workspace.data_ptr() if workspace is not None else None,
                workspace_size.value,
                indices_tensor.data,
                x_tensor.data,
                random_val,
                topp,
                topk,
                temperature,
                None,
            )
        )
        if torch_device == "npu":
            torch.npu.synchronize()

        # Equal logits at two different indices are accepted as a match.
        assert indices[0].type(ans.dtype) == ans or data[ans] == data[indices[0]]
    finally:
        # Release the descriptor even when the case fails.
        check_error(lib.infiniopDestroyRandomSampleDescriptor(descriptor))


def _run_test_cases(lib, device, torch_device, test_cases):
    """Create a handle for `device`, run every case on `torch_device`, then destroy the handle."""
    handle = create_handle(lib, device)
    try:
        for voc, random_val, topp, topk, temperature in test_cases:
            test(lib, handle, torch_device, voc, random_val, topp, topk, temperature)
    finally:
        destroy_handle(lib, handle)


def test_cpu(lib, test_cases):
    """Run all cases with a CPU handle on torch device "cpu"."""
    _run_test_cases(lib, DeviceEnum.DEVICE_CPU, "cpu", test_cases)


def test_cuda(lib, test_cases):
    """Run all cases with a CUDA handle on torch device "cuda"."""
    _run_test_cases(lib, DeviceEnum.DEVICE_CUDA, "cuda", test_cases)


def test_bang(lib, test_cases):
    """Run all cases with a BANG handle on torch device "mlu"."""
    import torch_mlu  # imported before any "mlu" tensor is created

    _run_test_cases(lib, DeviceEnum.DEVICE_BANG, "mlu", test_cases)


def test_ascend(lib, test_cases):
    """Run all cases with an Ascend handle on torch device "npu"."""
    import torch_npu  # imported before any "npu" tensor is created

    _run_test_cases(lib, DeviceEnum.DEVICE_ASCEND, "npu", test_cases)


if __name__ == "__main__":
    test_cases = [
        # voc, random_val, topp, topk, temperature
        (512, 0.8, 0.8, 3, 0.5),
        (4096, 0.05, 0.9, 5, 1.0),
        (16384, 0.15, 0.85, 10, 2.0),
        (512, 0.08, 0, 3, 0.5),
        (4096, 0.5, 0.9, 1, 1.0),
        (16384, 0.15, 0, 1, 2.0),
        (16384, 0.15, 0, 1, 2.0),
        (32000, 0.08, 0.8, 50, 1.0),
        (32000, 0.08, 1.0, 25, 1.0),
        # (119696, 0.01, 1.0, 100, 1.0),
    ]

    args = get_args()
    lib = open_lib()

    lib.infiniopCreateRandomSampleDescriptor.restype = c_int32
    # The create call passes two tensor descriptors (result indices and
    # logits), so both are declared.
    lib.infiniopCreateRandomSampleDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopRandomSampleDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
    ]
    lib.infiniopGetRandomSampleWorkspaceSize.restype = c_int32
    lib.infiniopGetRandomSampleWorkspaceSize.argtypes = [
        infiniopRandomSampleDescriptor_t,
        POINTER(c_uint64),
    ]
    lib.infiniopRandomSample.restype = c_int32
    lib.infiniopRandomSample.argtypes = [
        infiniopRandomSampleDescriptor_t,
        c_void_p,
        c_uint64,
        c_uint64,
        c_void_p,
        c_float,
        c_float,
        c_int32,
        c_float,
        c_void_p,
    ]
    lib.infiniopDestroyRandomSampleDescriptor.restype = c_int32
    lib.infiniopDestroyRandomSampleDescriptor.argtypes = [
        infiniopRandomSampleDescriptor_t,
    ]

    if args.cpu:
        test_cpu(lib, test_cases)
    if args.cuda:
        test_cuda(lib, test_cases)
    if args.bang:
        test_bang(lib, test_cases)
    if args.ascend:
        test_ascend(lib, test_cases)
    # Default to the CPU run when no device flag was given.
    if not (args.cpu or args.cuda or args.bang or args.ascend):
        test_cpu(lib, test_cases)
    print("\033[92mTest passed!\033[0m")