Commit 2694b127 authored by Minjie Wang's avatar Minjie Wang
Browse files

import ffi solution from TVM

parent 61fa3c6c
# pylint: disable=invalid-name
"""Runtime NDArray api"""
from __future__ import absolute_import
import ctypes
from ..base import _LIB, check_call, c_str
from ..runtime_ctypes import TVMArrayHandle
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _return_handle
TVMPyCapsuleDestructor = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
_c_str_dltensor = c_str('dltensor')
_c_str_used_dltensor = c_str('used_dltensor')
# used for PyCapsule manipulation
if hasattr(ctypes, 'pythonapi'):
ctypes.pythonapi.PyCapsule_GetName.restype = ctypes.c_char_p
ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p
ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
def _from_dlpack(dltensor):
dltensor = ctypes.py_object(dltensor)
if ctypes.pythonapi.PyCapsule_IsValid(dltensor, _c_str_dltensor):
ptr = ctypes.pythonapi.PyCapsule_GetPointer(dltensor, _c_str_dltensor)
handle = TVMArrayHandle()
check_call(_LIB.TVMArrayFromDLPack(ptr, ctypes.byref(handle)))
ctypes.pythonapi.PyCapsule_SetName(dltensor, _c_str_used_dltensor)
ctypes.pythonapi.PyCapsule_SetDestructor(dltensor, TVMPyCapsuleDestructor(0))
return _make_array(handle, False)
raise ValueError("Expect a dltensor field, PyCapsule can only be consumed once")
def _dlpack_deleter(pycapsule):
pycapsule = ctypes.cast(pycapsule, ctypes.py_object)
if ctypes.pythonapi.PyCapsule_IsValid(pycapsule, _c_str_dltensor):
ptr = ctypes.pythonapi.PyCapsule_GetPointer(pycapsule, _c_str_dltensor)
_LIB.TVMDLManagedTensorCallDeleter(ptr)
ctypes.pythonapi.PyCapsule_SetDestructor(dltensor, TVMPyCapsuleDestructor(0))
_c_dlpack_deleter = TVMPyCapsuleDestructor(_dlpack_deleter)
class NDArrayBase(object):
"""A simple Device/CPU Array object in runtime."""
__slots__ = ["handle", "is_view"]
# pylint: disable=no-member
def __init__(self, handle, is_view=False):
"""Initialize the function with handle
Parameters
----------
handle : TVMArrayHandle
the handle to the underlying C++ TVMArray
"""
self.handle = handle
self.is_view = is_view
def __del__(self):
if not self.is_view and _LIB:
check_call(_LIB.TVMArrayFree(self.handle))
@property
def _tvm_handle(self):
return ctypes.cast(self.handle, ctypes.c_void_p).value
def to_dlpack(self):
"""Produce an array from a DLPack Tensor without copying memory
Returns
-------
dlpack : DLPack tensor view of the array data
"""
handle = ctypes.c_void_p()
check_call(_LIB.TVMArrayToDLPack(self.handle, ctypes.byref(handle)))
return ctypes.pythonapi.PyCapsule_New(handle, _c_str_dltensor, _c_dlpack_deleter)
def _make_array(handle, is_view):
handle = ctypes.cast(handle, TVMArrayHandle)
return _CLASS_NDARRAY(handle, is_view)
_TVM_COMPATS = ()
def _reg_extension(cls, fcreate):
global _TVM_COMPATS
_TVM_COMPATS += (cls,)
if fcreate:
fret = lambda x: fcreate(_return_handle(x))
RETURN_SWITCH[cls._tvm_tcode] = fret
C_TO_PY_ARG_SWITCH[cls._tvm_tcode] = _wrap_arg_func(fret, cls._tvm_tcode)
_CLASS_NDARRAY = None
def _set_class_ndarray(cls):
global _CLASS_NDARRAY
_CLASS_NDARRAY = cls
"""The C Types used in API."""
# pylint: disable=invalid-name
from __future__ import absolute_import as _abs
import ctypes
from ..base import py_str, check_call, _LIB
from ..runtime_ctypes import TVMByteArray, TypeCode
class TVMValue(ctypes.Union):
"""TVMValue in C API"""
_fields_ = [("v_int64", ctypes.c_int64),
("v_float64", ctypes.c_double),
("v_handle", ctypes.c_void_p),
("v_str", ctypes.c_char_p)]
TVMPackedCFunc = ctypes.CFUNCTYPE(
ctypes.c_int,
ctypes.POINTER(TVMValue),
ctypes.POINTER(ctypes.c_int),
ctypes.c_int,
ctypes.c_void_p,
ctypes.c_void_p)
TVMCFuncFinalizer = ctypes.CFUNCTYPE(
None,
ctypes.c_void_p)
def _return_handle(x):
"""return handle"""
handle = x.v_handle
if not isinstance(handle, ctypes.c_void_p):
handle = ctypes.c_void_p(handle)
return handle
def _return_bytes(x):
"""return handle"""
handle = x.v_handle
if not isinstance(handle, ctypes.c_void_p):
handle = ctypes.c_void_p(handle)
arr = ctypes.cast(handle, ctypes.POINTER(TVMByteArray))[0]
size = arr.size
res = bytearray(size)
rptr = (ctypes.c_byte * size).from_buffer(res)
if not ctypes.memmove(rptr, arr.data, size):
raise RuntimeError('memmove failed')
return res
def _wrap_arg_func(return_f, type_code):
tcode = ctypes.c_int(type_code)
def _wrap_func(x):
check_call(_LIB.TVMCbArgToReturn(ctypes.byref(x), tcode))
return return_f(x)
return _wrap_func
RETURN_SWITCH = {
TypeCode.INT: lambda x: x.v_int64,
TypeCode.FLOAT: lambda x: x.v_float64,
TypeCode.HANDLE: _return_handle,
TypeCode.NULL: lambda x: None,
TypeCode.STR: lambda x: py_str(x.v_str),
TypeCode.BYTES: _return_bytes
}
C_TO_PY_ARG_SWITCH = {
TypeCode.INT: lambda x: x.v_int64,
TypeCode.FLOAT: lambda x: x.v_float64,
TypeCode.HANDLE: _return_handle,
TypeCode.NULL: lambda x: None,
TypeCode.STR: lambda x: py_str(x.v_str),
TypeCode.BYTES: _return_bytes
}
"""cython2 namespace"""
"""cython3 namespace"""
from ..base import DGLError
from libcpp.vector cimport vector
from cpython.version cimport PY_MAJOR_VERSION
from cpython cimport pycapsule
from libc.stdint cimport int64_t, uint64_t, uint8_t, uint16_t
import ctypes
cdef enum TVMTypeCode:
kInt = 0
kUInt = 1
kFloat = 2
kHandle = 3
kNull = 4
kTVMType = 5
kTVMContext = 6
kArrayHandle = 7
kNodeHandle = 8
kModuleHandle = 9
kFuncHandle = 10
kStr = 11
kBytes = 12
kNDArrayContainer = 13
kExtBegin = 15
cdef extern from "tvm/runtime/c_runtime_api.h":
ctypedef struct DLDataType:
uint8_t code
uint8_t bits
uint16_t lanes
ctypedef struct DLContext:
int device_type
int device_id
ctypedef struct DLTensor:
void* data
DLContext ctx
int ndim
DLDataType dtype
int64_t* shape
int64_t* strides
uint64_t byte_offset
ctypedef struct DLManagedTensor:
DLTensor dl_tensor
void* manager_ctx
void (*deleter)(DLManagedTensor* self)
ctypedef struct TVMValue:
int64_t v_int64
double v_float64
void* v_handle
const char* v_str
DLDataType v_type
DLContext v_ctx
ctypedef int64_t tvm_index_t
ctypedef DLTensor* DLTensorHandle
ctypedef void* TVMStreamHandle
ctypedef void* TVMRetValueHandle
ctypedef void* TVMFunctionHandle
ctypedef void* NodeHandle
ctypedef int (*TVMPackedCFunc)(
TVMValue* args,
int* type_codes,
int num_args,
TVMRetValueHandle ret,
void* resource_handle)
ctypedef void (*TVMPackedCFuncFinalizer)(void* resource_handle)
cdef extern from "tvm/runtime/c_runtime_api.h":
void TVMAPISetLastError(const char* msg)
const char *TVMGetLastError()
int TVMFuncCall(TVMFunctionHandle func,
TVMValue* arg_values,
int* type_codes,
int num_args,
TVMValue* ret_val,
int* ret_type_code)
int TVMFuncFree(TVMFunctionHandle func)
int TVMCFuncSetReturn(TVMRetValueHandle ret,
TVMValue* value,
int* type_code,
int num_ret)
int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
void* resource_handle,
TVMPackedCFuncFinalizer fin,
TVMFunctionHandle *out)
int TVMCbArgToReturn(TVMValue* value, int code)
int TVMArrayAlloc(tvm_index_t* shape,
tvm_index_t ndim,
DLDataType dtype,
DLContext ctx,
DLTensorHandle* out)
int TVMArrayFree(DLTensorHandle handle)
int TVMArrayCopyFromTo(DLTensorHandle src,
DLTensorHandle to,
TVMStreamHandle stream)
int TVMArrayFromDLPack(DLManagedTensor* arr_from,
DLTensorHandle* out)
int TVMArrayToDLPack(DLTensorHandle arr_from,
DLManagedTensor** out)
void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor)
cdef extern from "tvm/c_dsl_api.h":
int TVMNodeFree(NodeHandle handle)
int TVMNodeTypeKey2Index(const char* type_key,
int* out_index)
int TVMNodeGetTypeIndex(NodeHandle handle,
int* out_index)
int TVMNodeGetAttr(NodeHandle handle,
const char* key,
TVMValue* out_value,
int* out_type_code,
int* out_success)
cdef inline py_str(const char* x):
if PY_MAJOR_VERSION < 3:
return x
else:
return x.decode("utf-8")
cdef inline c_str(pystr):
"""Create ctypes char * from a python string
Parameters
----------
string : string type
python string
Returns
-------
str : c_char_p
A char pointer that can be passed to C API
"""
return pystr.encode("utf-8")
cdef inline CALL(int ret):
if ret != 0:
raise DGLError(py_str(TVMGetLastError()))
cdef inline object ctypes_handle(void* chandle):
"""Cast C handle to ctypes handle."""
return ctypes.cast(<unsigned long long>chandle, ctypes.c_void_p)
cdef inline void* c_handle(object handle):
"""Cast C types handle to c handle."""
cdef unsigned long long v_ptr
v_ptr = handle.value
return <void*>(v_ptr)
include "./base.pxi"
include "./node.pxi"
include "./function.pxi"
include "./ndarray.pxi"
import ctypes
import traceback
from cpython cimport Py_INCREF, Py_DECREF
from numbers import Number, Integral
from ..base import string_types
from ..node_generic import convert_to_node, NodeGeneric
from ..runtime_ctypes import TVMType, TVMContext, TVMByteArray
cdef void tvm_callback_finalize(void* fhandle):
local_pyfunc = <object>(fhandle)
Py_DECREF(local_pyfunc)
cdef int tvm_callback(TVMValue* args,
int* type_codes,
int num_args,
TVMRetValueHandle ret,
void* fhandle) with gil:
cdef list pyargs
cdef TVMValue value
cdef int tcode
local_pyfunc = <object>(fhandle)
pyargs = []
for i in range(num_args):
value = args[i]
tcode = type_codes[i]
if (tcode == kNodeHandle or
tcode == kFuncHandle or
tcode == kModuleHandle or
tcode > kExtBegin):
CALL(TVMCbArgToReturn(&value, tcode))
if tcode != kArrayHandle:
pyargs.append(make_ret(value, tcode))
else:
pyargs.append(c_make_array(value.v_handle, True))
try:
rv = local_pyfunc(*pyargs)
except Exception:
msg = traceback.format_exc()
TVMAPISetLastError(c_str(msg))
return -1
if rv is not None:
if isinstance(rv, tuple):
raise ValueError("PackedFunction can only support one return value")
temp_args = []
make_arg(rv, &value, &tcode, temp_args)
CALL(TVMCFuncSetReturn(ret, &value, &tcode, 1))
return 0
def convert_to_tvm_func(object pyfunc):
"""Convert a python function to TVM function
Parameters
----------
pyfunc : python function
The python function to be converted.
Returns
-------
tvmfunc: tvm.Function
The converted tvm function.
"""
cdef TVMFunctionHandle chandle
Py_INCREF(pyfunc)
CALL(TVMFuncCreateFromCFunc(tvm_callback,
<void*>(pyfunc),
tvm_callback_finalize,
&chandle))
ret = _CLASS_FUNCTION(None, False)
(<FunctionBase>ret).chandle = chandle
return ret
cdef inline int make_arg(object arg,
TVMValue* value,
int* tcode,
list temp_args) except -1:
"""Pack arguments into c args tvm call accept"""
cdef unsigned long long ptr
if isinstance(arg, NodeBase):
value[0].v_handle = (<NodeBase>arg).chandle
tcode[0] = kNodeHandle
elif isinstance(arg, NDArrayBase):
value[0].v_handle = (<NDArrayBase>arg).chandle
tcode[0] = (kNDArrayContainer if
not (<NDArrayBase>arg).c_is_view else kArrayHandle)
elif isinstance(arg, _TVM_COMPATS):
ptr = arg._tvm_handle
value[0].v_handle = (<void*>ptr)
tcode[0] = arg.__class__._tvm_tcode
elif isinstance(arg, (int, long)):
value[0].v_int64 = arg
tcode[0] = kInt
elif isinstance(arg, float):
value[0].v_float64 = arg
tcode[0] = kFloat
elif isinstance(arg, str):
tstr = c_str(arg)
value[0].v_str = tstr
tcode[0] = kStr
temp_args.append(tstr)
elif arg is None:
value[0].v_handle = NULL
tcode[0] = kNull
elif isinstance(arg, Number):
value[0].v_float64 = arg
tcode[0] = kFloat
elif isinstance(arg, TVMType):
tstr = c_str(str(arg))
value[0].v_str = tstr
tcode[0] = kStr
temp_args.append(tstr)
elif isinstance(arg, TVMContext):
value[0].v_ctx = (<DLContext*>(
<unsigned long long>ctypes.addressof(arg)))[0]
tcode[0] = kTVMContext
elif isinstance(arg, bytearray):
arr = TVMByteArray()
arr.data = ctypes.cast(
(ctypes.c_byte * len(arg)).from_buffer(arg),
ctypes.POINTER(ctypes.c_byte))
arr.size = len(arg)
value[0].v_handle = <void*>(
<unsigned long long>ctypes.addressof(arr))
tcode[0] = kBytes
temp_args.append(arr)
elif isinstance(arg, string_types):
tstr = c_str(arg)
value[0].v_str = tstr
tcode[0] = kStr
temp_args.append(tstr)
elif isinstance(arg, (list, tuple, dict, NodeGeneric)):
arg = convert_to_node(arg)
value[0].v_handle = (<NodeBase>arg).chandle
tcode[0] = kNodeHandle
temp_args.append(arg)
elif isinstance(arg, _CLASS_MODULE):
value[0].v_handle = c_handle(arg.handle)
tcode[0] = kModuleHandle
elif isinstance(arg, FunctionBase):
value[0].v_handle = (<FunctionBase>arg).chandle
tcode[0] = kFuncHandle
elif isinstance(arg, ctypes.c_void_p):
value[0].v_handle = c_handle(arg)
tcode[0] = kHandle
elif callable(arg):
arg = convert_to_tvm_func(arg)
value[0].v_handle = (<FunctionBase>arg).chandle
tcode[0] = kFuncHandle
temp_args.append(arg)
else:
raise TypeError("Don't know how to handle type %s" % type(arg))
return 0
cdef inline bytearray make_ret_bytes(void* chandle):
handle = ctypes_handle(chandle)
arr = ctypes.cast(handle, ctypes.POINTER(TVMByteArray))[0]
size = arr.size
res = bytearray(size)
rptr = (ctypes.c_byte * size).from_buffer(res)
if not ctypes.memmove(rptr, arr.data, size):
raise RuntimeError('memmove failed')
return res
cdef inline object make_ret(TVMValue value, int tcode):
"""convert result to return value."""
if tcode == kNodeHandle:
return make_ret_node(value.v_handle)
elif tcode == kNull:
return None
elif tcode == kInt:
return value.v_int64
elif tcode == kFloat:
return value.v_float64
elif tcode == kNDArrayContainer:
return c_make_array(value.v_handle, False)
elif tcode == kStr:
return py_str(value.v_str)
elif tcode == kBytes:
return make_ret_bytes(value.v_handle)
elif tcode == kHandle:
return ctypes_handle(value.v_handle)
elif tcode == kTVMContext:
return TVMContext(value.v_ctx.device_type, value.v_ctx.device_id)
elif tcode == kModuleHandle:
return _CLASS_MODULE(ctypes_handle(value.v_handle))
elif tcode == kFuncHandle:
fobj = _CLASS_FUNCTION(None, False)
(<FunctionBase>fobj).chandle = value.v_handle
return fobj
elif tcode in _TVM_EXT_RET:
return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle))
raise ValueError("Unhandled type code %d" % tcode)
cdef inline int FuncCall3(void* chandle,
tuple args,
int nargs,
TVMValue* ret_val,
int* ret_tcode) except -1:
cdef TVMValue[3] values
cdef int[3] tcodes
nargs = len(args)
temp_args = []
for i in range(nargs):
make_arg(args[i], &values[i], &tcodes[i], temp_args)
CALL(TVMFuncCall(chandle, &values[0], &tcodes[0],
nargs, ret_val, ret_tcode))
return 0
cdef inline int FuncCall(void* chandle,
tuple args,
TVMValue* ret_val,
int* ret_tcode) except -1:
cdef int nargs
nargs = len(args)
if nargs <= 3:
FuncCall3(chandle, args, nargs, ret_val, ret_tcode)
return 0
cdef vector[TVMValue] values
cdef vector[int] tcodes
values.resize(max(nargs, 1))
tcodes.resize(max(nargs, 1))
temp_args = []
for i in range(nargs):
make_arg(args[i], &values[i], &tcodes[i], temp_args)
CALL(TVMFuncCall(chandle, &values[0], &tcodes[0],
nargs, ret_val, ret_tcode))
return 0
cdef inline int ConstructorCall(void* constructor_handle,
int type_code,
tuple args,
void** handle) except -1:
"""Call contructor of a handle function"""
cdef TVMValue ret_val
cdef int ret_tcode
FuncCall(constructor_handle, args, &ret_val, &ret_tcode)
assert ret_tcode == type_code
handle[0] = ret_val.v_handle
return 0
cdef class FunctionBase:
cdef TVMFunctionHandle chandle
cdef int is_global
cdef inline _set_handle(self, handle):
if handle is None:
self.chandle = NULL
else:
self.chandle = c_handle(handle)
property is_global:
def __get__(self):
return self.c_is_global != 0
def __set__(self, value):
self.c_is_global = value
property handle:
def __get__(self):
if self.chandle == NULL:
return None
else:
return ctypes.cast(<unsigned long long>self.chandle, ctypes.c_void_p)
def __set__(self, value):
self._set_handle(value)
def __init__(self, handle, is_global):
self._set_handle(handle)
self.c_is_global = is_global
def __dealloc__(self):
if self.is_global == 0:
CALL(TVMFuncFree(self.chandle))
def __call__(self, *args):
cdef TVMValue ret_val
cdef int ret_tcode
FuncCall(self.chandle, args, &ret_val, &ret_tcode)
return make_ret(ret_val, ret_tcode)
_CLASS_FUNCTION = None
_CLASS_MODULE = None
def _set_class_module(module_class):
"""Initialize the module."""
global _CLASS_MODULE
_CLASS_MODULE = module_class
def _set_class_function(func_class):
global _CLASS_FUNCTION
_CLASS_FUNCTION = func_class
from ..runtime_ctypes import TVMArrayHandle
cdef const char* _c_str_dltensor = "dltensor"
cdef const char* _c_str_used_dltensor = "used_dltensor"
cdef void _c_dlpack_deleter(object pycaps):
cdef DLManagedTensor* dltensor
if pycapsule.PyCapsule_IsValid(pycaps, _c_str_dltensor):
dltensor = <DLManagedTensor*>pycapsule.PyCapsule_GetPointer(pycaps, _c_str_dltensor)
TVMDLManagedTensorCallDeleter(dltensor)
def _from_dlpack(object dltensor):
cdef DLManagedTensor* ptr
cdef DLTensorHandle chandle
if pycapsule.PyCapsule_IsValid(dltensor, _c_str_dltensor):
ptr = <DLManagedTensor*>pycapsule.PyCapsule_GetPointer(dltensor, _c_str_dltensor)
CALL(TVMArrayFromDLPack(ptr, &chandle))
# set name and destructor to be empty
pycapsule.PyCapsule_SetDestructor(dltensor, NULL)
pycapsule.PyCapsule_SetName(dltensor, _c_str_used_dltensor)
return c_make_array(chandle, 0)
raise ValueError("Expect a dltensor field, pycapsule.PyCapsule can only be consumed once")
cdef class NDArrayBase:
cdef DLTensor* chandle
cdef int c_is_view
cdef inline _set_handle(self, handle):
cdef unsigned long long ptr
if handle is None:
self.chandle = NULL
else:
ptr = ctypes.cast(handle, ctypes.c_void_p).value
self.chandle = <DLTensor*>(ptr)
property _tvm_handle:
def __get__(self):
return <unsigned long long>self.chandle
property handle:
def __get__(self):
if self.chandle == NULL:
return None
else:
return ctypes.cast(
<unsigned long long>self.chandle, TVMArrayHandle)
def __set__(self, value):
self._set_handle(value)
def __init__(self, handle, is_view):
self._set_handle(handle)
self.c_is_view = is_view
def __dealloc__(self):
if self.c_is_view == 0:
CALL(TVMArrayFree(self.chandle))
def to_dlpack(self):
"""Produce an array from a DLPack Tensor without copying memory
Returns
-------
dlpack : DLPack tensor view of the array data
"""
cdef DLManagedTensor* dltensor
if self.c_is_view != 0:
raise ValueError("to_dlpack do not work with memory views")
CALL(TVMArrayToDLPack(self.chandle, &dltensor))
return pycapsule.PyCapsule_New(dltensor, _c_str_dltensor, _c_dlpack_deleter)
cdef c_make_array(void* chandle, is_view):
ret = _CLASS_NDARRAY(None, is_view)
(<NDArrayBase>ret).chandle = <DLTensor*>chandle
return ret
cdef _TVM_COMPATS = ()
cdef _TVM_EXT_RET = {}
def _reg_extension(cls, fcreate):
global _TVM_COMPATS
_TVM_COMPATS += (cls,)
if fcreate:
_TVM_EXT_RET[cls._tvm_tcode] = fcreate
def _make_array(handle, is_view):
cdef unsigned long long ptr
ptr = ctypes.cast(handle, ctypes.c_void_p).value
return c_make_array(<void*>ptr, is_view)
cdef object _CLASS_NDARRAY = None
def _set_class_ndarray(cls):
global _CLASS_NDARRAY
_CLASS_NDARRAY = cls
from ... import _api_internal
from ..base import string_types
from ..node_generic import _set_class_node_base
"""Maps node type to its constructor"""
NODE_TYPE = []
def _register_node(int index, object cls):
"""register node class"""
while len(NODE_TYPE) <= index:
NODE_TYPE.append(None)
NODE_TYPE[index] = cls
cdef inline object make_ret_node(void* chandle):
global NODE_TYPE
cdef int tindex
cdef list node_type
cdef object cls
node_type = NODE_TYPE
CALL(TVMNodeGetTypeIndex(chandle, &tindex))
if tindex < len(node_type):
cls = node_type[tindex]
if cls is not None:
obj = cls.__new__(cls)
else:
obj = NodeBase.__new__(NodeBase)
else:
obj = NodeBase.__new__(NodeBase)
(<NodeBase>obj).chandle = chandle
return obj
cdef class NodeBase:
cdef void* chandle
cdef _set_handle(self, handle):
cdef unsigned long long ptr
if handle is None:
self.chandle = NULL
else:
ptr = handle.value
self.chandle = <void*>(ptr)
property handle:
def __get__(self):
if self.chandle == NULL:
return None
else:
return ctypes_handle(self.chandle)
def __set__(self, value):
self._set_handle(value)
def __dealloc__(self):
CALL(TVMNodeFree(self.chandle))
def __getattr__(self, name):
cdef TVMValue ret_val
cdef int ret_type_code, ret_succ
CALL(TVMNodeGetAttr(self.chandle, c_str(name),
&ret_val, &ret_type_code, &ret_succ))
if ret_succ == 0:
raise AttributeError(
"'%s' object has no attribute '%s'" % (type(self), name))
return make_ret(ret_val, ret_type_code)
def __init_handle_by_constructor__(self, fconstructor, *args):
"""Initialize the handle by calling constructor function.
Parameters
----------
fconstructor : Function
Constructor function.
args: list of objects
The arguments to the constructor
Note
----
We have a special calling convention to call constructor functions.
So the return handle is directly set into the Node object
instead of creating a new Node.
"""
cdef void* chandle
ConstructorCall(
(<FunctionBase>fconstructor).chandle,
kNodeHandle, args, &chandle)
self.chandle = chandle
_set_class_node_base(NodeBase)
# coding: utf-8
# pylint: disable=invalid-name
"""ctypes library and helper functions """
from __future__ import absolute_import
import sys
import os
import ctypes
import numpy as np
from . import libinfo
#----------------------------
# library loading
#----------------------------
if sys.version_info[0] == 3:
string_types = (str,)
numeric_types = (float, int, np.float32, np.int32)
# this function is needed for python3
# to convert ctypes.char_p .value back to python str
py_str = lambda x: x.decode('utf-8')
else:
string_types = (basestring,)
numeric_types = (float, int, long, np.float32, np.int32)
py_str = lambda x: x
class DGLError(Exception):
"""Error thrown by DGL function"""
pass
def _load_lib():
"""Load libary by searching possible path."""
lib_path = libinfo.find_lib_path()
lib = ctypes.CDLL(lib_path[0], ctypes.RTLD_GLOBAL)
# DMatrix functions
lib.TVMGetLastError.restype = ctypes.c_char_p
return lib, os.path.basename(lib_path[0])
# version number
__version__ = libinfo.__version__
# library instance of nnvm
_LIB, _LIB_NAME = _load_lib()
# The FFI mode of DGL
_FFI_MODE = os.environ.get("DGL_FFI", "auto")
#----------------------------
# helper function in ctypes.
#----------------------------
def check_call(ret):
"""Check the return value of C API call
This function will raise exception when error occurs.
Wrap every API call with this function
Parameters
----------
ret : int
return value from API calls
"""
if ret != 0:
raise DGLError(py_str(_LIB.TVMGetLastError()))
def c_str(string):
"""Create ctypes char * from a python string
Parameters
----------
string : string type
python string
Returns
-------
str : c_char_p
A char pointer that can be passed to C API
"""
return ctypes.c_char_p(string.encode('utf-8'))
def c_array(ctype, values):
"""Create ctypes array from a python array
Parameters
----------
ctype : ctypes data type
data type of the array we want to convert to
values : tuple or list
data content
Returns
-------
out : ctypes array
Created ctypes array
"""
return (ctype * len(values))(*values)
def decorate(func, fwrapped):
"""A wrapper call of decorator package, differs to call time
Parameters
----------
func : function
The original function
fwrapped : function
The wrapped function
"""
import decorator
return decorator.decorate(func, fwrapped)
# pylint: disable=invalid-name, unused-import
"""Function namespace."""
from __future__ import absolute_import
import sys
import ctypes
from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
try:
# pylint: disable=wrong-import-position
if _FFI_MODE == "ctypes":
raise ImportError()
if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_function, _set_class_module
from ._cy3.core import FunctionBase as _FunctionBase
from ._cy3.core import convert_to_tvm_func
else:
from ._cy2.core import _set_class_function, _set_class_module
from ._cy2.core import FunctionBase as _FunctionBase
from ._cy2.core import convert_to_tvm_func
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.function import _set_class_function, _set_class_module
from ._ctypes.function import FunctionBase as _FunctionBase
from ._ctypes.function import convert_to_tvm_func
FunctionHandle = ctypes.c_void_p
class Function(_FunctionBase):
"""The PackedFunc object used in TVM.
Function plays an key role to bridge front and backend in TVM.
Function provide a type-erased interface, you can call function with positional arguments.
The compiled module returns Function.
TVM backend also registers and exposes its API as Functions.
For example, the developer function exposed in tvm.ir_pass are actually
C++ functions that are registered as PackedFunc
The following are list of common usage scenario of tvm.Function.
- Automatic exposure of C++ API into python
- To call PackedFunc from python side
- To call python callbacks to inspect results in generated code
- Bring python hook into C++ backend
See Also
--------
tvm.register_func: How to register global function.
tvm.get_global_func: How to get global function.
"""
pass
class ModuleBase(object):
"""Base class for module"""
__slots__ = ["handle", "_entry", "entry_name"]
def __init__(self, handle):
self.handle = handle
self._entry = None
self.entry_name = "__tvm_main__"
def __del__(self):
check_call(_LIB.TVMModFree(self.handle))
@property
def entry_func(self):
"""Get the entry function
Returns
-------
f : Function
The entry function if exist
"""
if self._entry:
return self._entry
self._entry = self.get_function(self.entry_name)
return self._entry
def get_function(self, name, query_imports=False):
"""Get function from the module.
Parameters
----------
name : str
The name of the function
query_imports : bool
Whether also query modules imported by this module.
Returns
-------
f : Function
The result function.
"""
ret_handle = FunctionHandle()
check_call(_LIB.TVMModGetFunction(
self.handle, c_str(name),
ctypes.c_int(query_imports),
ctypes.byref(ret_handle)))
if not ret_handle.value:
raise AttributeError(
"Module has no function '%s'" % name)
return Function(ret_handle, False)
def import_module(self, module):
"""Add module to the import list of current one.
Parameters
----------
module : Module
The other module.
"""
check_call(_LIB.TVMModImport(self.handle, module.handle))
def __getitem__(self, name):
if not isinstance(name, string_types):
raise ValueError("Can only take string as function name")
return self.get_function(name)
def __call__(self, *args):
if self._entry:
return self._entry(*args)
f = self.entry_func
return f(*args)
def register_func(func_name, f=None, override=False):
"""Register global function
Parameters
----------
func_name : str or function
The function name
f : function, optional
The function to be registered.
override: boolean optional
Whether override existing entry.
Returns
-------
fregister : function
Register function if f is not specified.
Examples
--------
The following code registers my_packed_func as global function.
Note that we simply get it back from global function table to invoke
it from python side. However, we can also invoke the same function
from C++ backend, or in the compiled TVM code.
.. code-block:: python
targs = (10, 10.0, "hello")
@tvm.register_func
def my_packed_func(*args):
assert(tuple(args) == targs)
return 10
# Get it out from global function table
f = tvm.get_global_func("my_packed_func")
assert isinstance(f, tvm.nd.Function)
y = f(*targs)
assert y == 10
"""
if callable(func_name):
f = func_name
func_name = f.__name__
if not isinstance(func_name, str):
raise ValueError("expect string function name")
ioverride = ctypes.c_int(override)
def register(myf):
"""internal register function"""
if not isinstance(myf, Function):
myf = convert_to_tvm_func(myf)
check_call(_LIB.TVMFuncRegisterGlobal(
c_str(func_name), myf.handle, ioverride))
return myf
if f:
return register(f)
return register
def get_global_func(name, allow_missing=False):
"""Get a global function by name
Parameters
----------
name : str
The name of the global function
allow_missing : bool
Whether allow missing function or raise an error.
Returns
-------
func : tvm.Function
The function to be returned, None if function is missing.
"""
handle = FunctionHandle()
check_call(_LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(handle)))
if handle.value:
return Function(handle, False)
else:
if allow_missing:
return None
else:
raise ValueError("Cannot find global function %s" % name)
def list_global_func_names():
"""Get list of global functions registered.
Returns
-------
names : list
List of global functions names.
"""
plist = ctypes.POINTER(ctypes.c_char_p)()
size = ctypes.c_uint()
check_call(_LIB.TVMFuncListGlobalNames(ctypes.byref(size),
ctypes.byref(plist)))
fnames = []
for i in range(size.value):
fnames.append(py_str(plist[i]))
return fnames
def extract_ext_funcs(finit):
"""
Extract the extension PackedFuncs from a C module.
Parameters
----------
finit : ctypes function
a ctypes that takes signature of TVMExtensionDeclarer
Returns
-------
fdict : dict of str to Function
The extracted functions
"""
fdict = {}
def _list(name, func):
fdict[name] = func
myf = convert_to_tvm_func(_list)
ret = finit(myf.handle)
_ = myf
if ret != 0:
raise RuntimeError("cannot initialize with %s" % finit)
return fdict
def _get_api(f):
flocal = f
flocal.is_global = True
return flocal
def _init_api(namespace, target_module_name=None):
"""Initialize api for a given module name
namespace : str
The namespace of the source registry
target_module_name : str
The target module name if different from namespace
"""
target_module_name = (
target_module_name if target_module_name else namespace)
if namespace.startswith("tvm."):
_init_api_prefix(target_module_name, namespace[4:])
else:
_init_api_prefix(target_module_name, namespace)
def _init_api_prefix(module_name, prefix):
module = sys.modules[module_name]
for name in list_global_func_names():
if prefix == "api":
fname = name
if name.startswith("_"):
target_module = sys.modules["tvm._api_internal"]
else:
target_module = module
else:
if not name.startswith(prefix):
continue
fname = name[len(prefix)+1:]
target_module = module
if fname.find(".") != -1:
continue
f = get_global_func(name)
ff = _get_api(f)
ff.__name__ = fname
ff.__doc__ = ("TVM PackedFunc %s. " % fname)
setattr(target_module, ff.__name__, ff)
_set_class_function(Function)
"""Library information."""
from __future__ import absolute_import
import sys
import os
def find_lib_path(name=None, search_path=None, optional=False):
"""Find dynamic library files.
Parameters
----------
name : list of str
List of names to be found.
Returns
-------
lib_path : list(string)
List of all found path to the libraries
"""
# See https://github.com/dmlc/tvm/issues/281 for some background.
# NB: This will either be the source directory (if TVM is run
# inplace) or the install directory (if TVM is installed).
# An installed TVM's curr_path will look something like:
# $PREFIX/lib/python3.6/site-packages/tvm/_ffi
ffi_dir = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
source_dir = os.path.join(ffi_dir, "..", "..", "..")
install_lib_dir = os.path.join(ffi_dir, "..", "..", "..", "..")
dll_path = []
if os.environ.get('DGL_LIBRARY_PATH', None):
dll_path.append(os.environ['DGL_LIBRARY_PATH'])
if sys.platform.startswith('linux') and os.environ.get('LD_LIBRARY_PATH', None):
dll_path.extend([p.strip() for p in os.environ['LD_LIBRARY_PATH'].split(":")])
elif sys.platform.startswith('darwin') and os.environ.get('DYLD_LIBRARY_PATH', None):
dll_path.extend([p.strip() for p in os.environ['DYLD_LIBRARY_PATH'].split(":")])
# Pip lib directory
dll_path.append(os.path.join(ffi_dir, ".."))
# Default cmake build directory
dll_path.append(os.path.join(source_dir, "build"))
dll_path.append(os.path.join(source_dir, "build", "Release"))
# Default make build directory
dll_path.append(os.path.join(source_dir, "lib"))
dll_path.append(install_lib_dir)
dll_path = [os.path.abspath(x) for x in dll_path]
if search_path is not None:
if search_path is list:
dll_path = dll_path + search_path
else:
dll_path.append(search_path)
if name is not None:
if isinstance(name, list):
lib_dll_path = []
for n in name:
lib_dll_path += [os.path.join(p, n) for p in dll_path]
else:
lib_dll_path = [os.path.join(p, name) for p in dll_path]
else:
if sys.platform.startswith('win32'):
lib_dll_path = [os.path.join(p, 'libdgl.dll') for p in dll_path] +\
[os.path.join(p, 'dgl.dll') for p in dll_path]
elif sys.platform.startswith('darwin'):
lib_dll_path = [os.path.join(p, 'libdgl.dylib') for p in dll_path]
else:
lib_dll_path = [os.path.join(p, 'libdgl.so') for p in dll_path]
# try to find lib_dll_path
lib_found = [p for p in lib_dll_path if os.path.exists(p) and os.path.isfile(p)]
if not lib_found:
message = ('Cannot find the files.\n' +
'List of candidates:\n' +
str('\n'.join(lib_dll_path + runtime_dll_path)))
if not optional:
raise RuntimeError(message)
return None
return lib_found
# current version
# We use the version of the incoming release for code
# that is under development.
# The following line is set by tvm/python/update_version.py
__version__ = "0.5.dev"
This diff is collapsed.
This diff is collapsed.
# C API and runtime
Borrowed and adapted from TVM project.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment