# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions # are met: # * Redistributions of source code must retain the above copyright # notice, this list of conditions and the following disclaimer. # * Redistributions in binary form must reproduce the above copyright # notice, this list of conditions and the following disclaimer in the # documentation and/or other materials provided with the distribution. # * Neither the name of NVIDIA CORPORATION nor the names of its # contributors may be used to endorse or promote products derived # from this software without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
################################################################################
# This file contains the DLPack API wrapped in Python style (see
# 'dlpack.h' for detail) and the utilities for Triton client to interact
# with DLPack
#
# Ref:
# https://github.com/dmlc/dlpack/blob/main/include/dlpack/dlpack.h
# https://github.com/dmlc/dlpack/blob/main/apps/numpy_dlpack/dlpack/from_numpy.py
################################################################################

import ctypes

# The result / argument types of the pythonapi functions must be set
# explicitly: ctypes otherwise assumes C 'int', which would truncate
# pointers on 64-bit platforms.
ctypes.pythonapi.PyMem_RawMalloc.restype = ctypes.c_void_p
ctypes.pythonapi.PyMem_RawFree.argtypes = [ctypes.c_void_p]
ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
ctypes.pythonapi.PyCapsule_New.argtypes = [
    ctypes.c_void_p,
    ctypes.c_char_p,
    ctypes.c_void_p,
]
ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p
ctypes.pythonapi.PyCapsule_GetPointer.argtypes = [ctypes.py_object, ctypes.c_char_p]

# Name carried by a PyCapsule that wraps a not-yet-consumed DLManagedTensor.
c_str_dltensor = b"dltensor"


class DLDeviceType(ctypes.c_int):
    """Device type codes; values mirror 'DLDeviceType' in 'dlpack.h'."""

    kDLCPU = 1
    kDLCUDA = 2
    kDLCUDAHost = 3
    kDLOpenCL = 4
    kDLVulkan = 7
    kDLMetal = 8
    kDLVPI = 9
    kDLROCM = 10
    kDLROCMHost = 11
    kDLExtDev = 12
    kDLCUDAManaged = 13
    kDLOneAPI = 14
    kDLWebGPU = 15
    kDLHexagon = 16


class DLDevice(ctypes.Structure):
    """ctypes layout of 'DLDevice': a device type plus a device index."""

    _fields_ = [
        ("device_type", DLDeviceType),
        ("device_id", ctypes.c_int),
    ]


class DLDataTypeCode(ctypes.c_uint8):
    """Element type codes; values mirror 'DLDataTypeCode' in 'dlpack.h'."""

    kDLInt = 0
    kDLUInt = 1
    kDLFloat = 2
    kDLOpaquePointer = 3
    kDLBfloat = 4
    kDLComplex = 5
    kDLBool = 6


class DLDataType(ctypes.Structure):
    """ctypes layout of 'DLDataType': type code, bit width and lane count."""

    _fields_ = [
        ("type_code", DLDataTypeCode),
        ("bits", ctypes.c_uint8),
        ("lanes", ctypes.c_uint16),
    ]


class DLTensor(ctypes.Structure):
    """ctypes layout of 'DLTensor'."""

    _fields_ = [
        ("data", ctypes.c_void_p),
        ("device", DLDevice),
        ("ndim", ctypes.c_int),
        ("dtype", DLDataType),
        ("shape", ctypes.POINTER(ctypes.c_int64)),
        ("strides", ctypes.POINTER(ctypes.c_int64)),
        ("byte_offset", ctypes.c_uint64),
    ]


class DLManagedTensor(ctypes.Structure):
    """ctypes layout of 'DLManagedTensor': a DLTensor, an opaque manager
    context and the deleter callback that receives the struct address."""

    _fields_ = [
        ("dl_tensor", DLTensor),
        ("manager_ctx", ctypes.c_void_p),
        ("deleter", ctypes.CFUNCTYPE(None, ctypes.c_void_p)),
    ]


# Utilities


def _raise_error(msg):
    """
    Raise error with the provided message
    """
    # 'Exception' does not accept keyword arguments; passing 'msg=' would
    # raise a TypeError and lose the intended message.
    raise Exception(msg) from None


# Use as managed context in DLPack that doesn't hold ownership of the
# data content.
class DataViewContext:
    """Holds the ctypes shape / strides arrays referenced by a DLTensor."""

    def __init__(self, shape) -> None:
        # Convert the Python object to ctypes objects expected by
        # DLPack
        self._shape = (ctypes.c_int64 * len(shape))(*shape)
        # No strides: compact and row-major
        self._strides = ctypes.POINTER(ctypes.c_int64)()

    def as_manager_ctx(self) -> ctypes.c_void_p:
        """Return a void pointer to a 'py_object' referring to this context.

        The reference counts of both this object and the ctypes pointer
        object are incremented here; 'managed_tensor_deleter' performs the
        matching decrements.
        """
        py_obj = ctypes.py_object(self)
        py_obj_ptr = ctypes.pointer(py_obj)
        ctypes.pythonapi.Py_IncRef(py_obj)
        ctypes.pythonapi.Py_IncRef(ctypes.py_object(py_obj_ptr))
        return ctypes.cast(py_obj_ptr, ctypes.c_void_p)


@ctypes.CFUNCTYPE(None, ctypes.c_void_p)
def managed_tensor_deleter(handle: ctypes.c_void_p) -> None:
    """Deleter for a DLManagedTensor located at address 'handle'.

    Drops the two references taken by 'DataViewContext.as_manager_ctx' and
    then releases the struct memory with 'PyMem_RawFree'.
    """
    dl_managed_tensor = DLManagedTensor.from_address(handle)
    py_obj_ptr = ctypes.cast(
        dl_managed_tensor.manager_ctx, ctypes.POINTER(ctypes.py_object)
    )
    py_obj = py_obj_ptr.contents
    ctypes.pythonapi.Py_DecRef(py_obj)
    ctypes.pythonapi.Py_DecRef(ctypes.py_object(py_obj_ptr))
    ctypes.pythonapi.PyMem_RawFree(handle)


@ctypes.CFUNCTYPE(None, ctypes.c_void_p)
def pycapsule_deleter(handle: ctypes.c_void_p) -> None:
    """PyCapsule destructor for capsules wrapping a DLManagedTensor.

    Only acts when the capsule is still valid under the name 'dltensor';
    in that case the managed tensor is deleted and the capsule destructor
    is cleared.
    """
    pycapsule: ctypes.py_object = ctypes.cast(handle, ctypes.py_object)
    if ctypes.pythonapi.PyCapsule_IsValid(pycapsule, c_str_dltensor):
        dl_managed_tensor = ctypes.pythonapi.PyCapsule_GetPointer(
            pycapsule, c_str_dltensor
        )
        managed_tensor_deleter(dl_managed_tensor)
        ctypes.pythonapi.PyCapsule_SetDestructor(pycapsule, None)


# Triton data type name -> (DLPack type code, bits per element).
# BOOL uses 8 bits: 'dlpack.h' defines bool as (kDLBool, 8, 1).
_TRITON_TO_DLPACK_DTYPE = {
    "BOOL": (DLDataTypeCode.kDLBool, 8),
    "INT8": (DLDataTypeCode.kDLInt, 8),
    "INT16": (DLDataTypeCode.kDLInt, 16),
    "INT32": (DLDataTypeCode.kDLInt, 32),
    "INT64": (DLDataTypeCode.kDLInt, 64),
    "UINT8": (DLDataTypeCode.kDLUInt, 8),
    "UINT16": (DLDataTypeCode.kDLUInt, 16),
    "UINT32": (DLDataTypeCode.kDLUInt, 32),
    "UINT64": (DLDataTypeCode.kDLUInt, 64),
    "FP16": (DLDataTypeCode.kDLFloat, 16),
    "FP32": (DLDataTypeCode.kDLFloat, 32),
    "FP64": (DLDataTypeCode.kDLFloat, 64),
    "BF16": (DLDataTypeCode.kDLBfloat, 16),
}


def triton_to_dlpack_dtype(dtype):
    """Convert a Triton data type name to a single-lane 'DLDataType'.

    Parameters
    ----------
    dtype : str
        Triton data type name, e.g. "FP32".

    Returns
    -------
    DLDataType
        The matching DLPack data type with 'lanes' set to 1.

    Raises
    ------
    Exception
        If 'dtype' is "BYTES" or is not a known data type name.
    """
    if dtype == "BYTES":
        _raise_error("DLPack currently doesn't suppose BYTES type")
    try:
        type_code, bits = _TRITON_TO_DLPACK_DTYPE[dtype]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs, which are unknown types too.
        _raise_error(
            "Can not covert unknown data type '{}' to DLPack data type".format(dtype)
        )
    return DLDataType(type_code, bits, 1)


def is_contiguous_data(
    ndim: int,
    shape: ctypes.POINTER(ctypes.c_int64),
    stride: ctypes.POINTER(ctypes.c_int64),
):
    """Return True if 'stride' describes a compact row-major layout.

    Parameters
    ----------
    ndim : int
        Number of dimensions.
    shape : indexable of int
        Extent of each dimension.
    stride : indexable of int or None
        Stride of each dimension in elements. None, a NULL pointer or an
        empty sequence is treated as compact row-major.
    """
    # If 'stride' doesn't capture valid value
    if (stride is None) or (not bool(stride)):
        return True
    calculated_stride = 1
    # iterate stride in reverse order [ndim-1, -1)
    for i in reversed(range(ndim)):
        # The only valid index of a size-1 dimension is 0, so its stride
        # never contributes to an element address and may hold any value.
        if shape[i] == 1:
            continue
        if stride[i] != calculated_stride:
            return False
        calculated_stride *= shape[i]
    return True


def get_byte_size(
    dtype: DLDataType, ndim: int, shape: ctypes.POINTER(ctypes.c_int64)
):
    """Return the number of bytes needed for a compact tensor.

    The per-element size is rounded up to whole bytes, following the size
    computation documented in 'dlpack.h'; for bit widths that are multiples
    of 8 this equals 'bits * lanes / 8'.
    """
    byte_size = (dtype.bits * dtype.lanes + 7) // 8  # Assume 8 bits in a byte
    for i in range(ndim):
        byte_size *= shape[i]
    return byte_size


def get_dlpack_capsule(dlpack_obj, stream=None):
    """Extract the PyCapsule of a DLPack object.

    Parameters
    ----------
    dlpack_obj : object
        Either an object implementing '__dlpack__' / '__dlpack_device__' or,
        for the old interface, the PyCapsule itself.
    stream : optional
        Forwarded as the 'stream' keyword to '__dlpack__' only when the
        object reports the kDLCUDA device type.

    Raises
    ------
    Exception
        If '__dlpack__' is defined without '__dlpack_device__'.
    """
    if not hasattr(dlpack_obj, "__dlpack__"):
        # Old interface where PyCapsule object is passed directly
        return dlpack_obj
    if not hasattr(dlpack_obj, "__dlpack_device__"):
        _raise_error(
            "DLPack expects '__dlpack_device__' if '__dlpack__' has been defined"
        )
    device = dlpack_obj.__dlpack_device__()
    # '__dlpack_device__' returns (device_type, device_id); comparing the
    # whole tuple against the type code would never match.
    device_type = device[0] if isinstance(device, (tuple, list)) else device
    # Unwrap enum members so they compare against the plain int code.
    device_type = getattr(device_type, "value", device_type)
    # Have to condition on the device type as, using numpy as example,
    # some DLPack implementation doesn't accept 'stream' as arguments
    if device_type != DLDeviceType.kDLCUDA:
        return dlpack_obj.__dlpack__()
    return dlpack_obj.__dlpack__(stream=stream)


def get_dlpack_device(dlpack_obj):
    """Return 'dlpack_obj.__dlpack_device__()', or None if not defined."""
    if hasattr(dlpack_obj, "__dlpack_device__"):
        return dlpack_obj.__dlpack_device__()
    return None


def get_managed_tensor(dlcapsule):
    """Return the DLManagedTensor stored in a capsule named 'dltensor'."""
    ptr = ctypes.pythonapi.PyCapsule_GetPointer(dlcapsule, c_str_dltensor)
    return DLManagedTensor.from_address(ptr)