Commit 2694b127 authored by Minjie Wang's avatar Minjie Wang
Browse files

import ffi solution from TVM

parent 61fa3c6c
# pylint: disable=invalid-name
"""Runtime NDArray api"""
from __future__ import absolute_import
import ctypes
from ..base import _LIB, check_call, c_str
from ..runtime_ctypes import TVMArrayHandle
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _return_handle
# ctypes prototype for a PyCapsule destructor: one pointer argument, no result.
TVMPyCapsuleDestructor = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
# Capsule names. `_from_dlpack` renames a capsule from the first to the second
# once it has been consumed, so it cannot be consumed twice.
_c_str_dltensor = c_str('dltensor')
_c_str_used_dltensor = c_str('used_dltensor')
# used for PyCapsule manipulation
if hasattr(ctypes, 'pythonapi'):
    # Declare the result types explicitly; ctypes would otherwise treat the
    # results as C ints.
    ctypes.pythonapi.PyCapsule_GetName.restype = ctypes.c_char_p
    ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p
    ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
def _from_dlpack(dltensor):
    """Build an array from a DLPack capsule and mark the capsule as consumed.

    Parameters
    ----------
    dltensor : PyCapsule
        Capsule named ``dltensor``.

    Returns
    -------
    arr : the object produced by ``_make_array(handle, False)``

    Raises
    ------
    ValueError
        If the capsule is not valid under the name ``dltensor`` (for example
        because it was already consumed and renamed).
    """
    pyapi = ctypes.pythonapi
    capsule = ctypes.py_object(dltensor)
    if not pyapi.PyCapsule_IsValid(capsule, _c_str_dltensor):
        raise ValueError("Expect a dltensor field, PyCapsule can only be consumed once")
    managed_ptr = pyapi.PyCapsule_GetPointer(capsule, _c_str_dltensor)
    handle = TVMArrayHandle()
    check_call(_LIB.TVMArrayFromDLPack(managed_ptr, ctypes.byref(handle)))
    # Rename the capsule and clear its destructor so it is neither consumed
    # nor released a second time.
    pyapi.PyCapsule_SetName(capsule, _c_str_used_dltensor)
    pyapi.PyCapsule_SetDestructor(capsule, TVMPyCapsuleDestructor(0))
    return _make_array(handle, False)
def _dlpack_deleter(pycapsule):
    """PyCapsule destructor that releases an unconsumed DLPack tensor.

    Parameters
    ----------
    pycapsule : pointer value
        The capsule object as passed to a ``TVMPyCapsuleDestructor`` callback.

    Notes
    -----
    The managed tensor's deleter is invoked only while the capsule is still
    named ``dltensor``; a capsule renamed by ``_from_dlpack`` is left alone.
    """
    pycapsule = ctypes.cast(pycapsule, ctypes.py_object)
    if ctypes.pythonapi.PyCapsule_IsValid(pycapsule, _c_str_dltensor):
        ptr = ctypes.pythonapi.PyCapsule_GetPointer(pycapsule, _c_str_dltensor)
        _LIB.TVMDLManagedTensorCallDeleter(ptr)
        # Clear the destructor on *this* capsule. The previous code passed an
        # undefined name here, which raised NameError inside the callback.
        ctypes.pythonapi.PyCapsule_SetDestructor(pycapsule, TVMPyCapsuleDestructor(0))


# C-callable wrapper; kept at module level so the callback object stays alive.
_c_dlpack_deleter = TVMPyCapsuleDestructor(_dlpack_deleter)
class NDArrayBase(object):
    """A simple Device/CPU Array object in runtime."""
    __slots__ = ["handle", "is_view"]
    # pylint: disable=no-member

    def __init__(self, handle, is_view=False):
        """Initialize the array with handle

        Parameters
        ----------
        handle : TVMArrayHandle
            the handle to the underlying C++ TVMArray
        is_view : bool, optional
            When true the object does not free ``handle`` on destruction.
        """
        self.handle = handle
        self.is_view = is_view

    def __del__(self):
        # Views never free the handle; `_LIB` is also tested so nothing is
        # called once the library object is falsy.
        if not self.is_view and _LIB:
            check_call(_LIB.TVMArrayFree(self.handle))

    @property
    def _tvm_handle(self):
        """Address of the underlying handle as a plain integer."""
        return ctypes.cast(self.handle, ctypes.c_void_p).value

    def to_dlpack(self):
        """Produce a DLPack tensor from the array without copying memory

        Returns
        -------
        dlpack : DLPack tensor view of the array data

        Raises
        ------
        ValueError
            If the array is a view; this matches the Cython implementation
            of ``NDArrayBase.to_dlpack``.
        """
        if self.is_view:
            raise ValueError("to_dlpack do not work with memory views")
        handle = ctypes.c_void_p()
        check_call(_LIB.TVMArrayToDLPack(self.handle, ctypes.byref(handle)))
        return ctypes.pythonapi.PyCapsule_New(handle, _c_str_dltensor, _c_dlpack_deleter)
def _make_array(handle, is_view):
    """Wrap ``handle`` (cast to ``TVMArrayHandle``) in the registered array class."""
    return _CLASS_NDARRAY(ctypes.cast(handle, TVMArrayHandle), is_view)
_TVM_COMPATS = ()


def _reg_extension(cls, fcreate):
    """Register an extension class and, optionally, its return converter.

    Parameters
    ----------
    cls : type
        Extension class; it is appended to ``_TVM_COMPATS``.
    fcreate : callable or None
        When truthy, called with the returned handle to build a ``cls``
        object; it is installed under ``cls._tvm_tcode`` in both
        ``RETURN_SWITCH`` and ``C_TO_PY_ARG_SWITCH``.
    """
    global _TVM_COMPATS
    _TVM_COMPATS += (cls,)
    if not fcreate:
        return

    def fret(x):
        return fcreate(_return_handle(x))

    type_code = cls._tvm_tcode
    RETURN_SWITCH[type_code] = fret
    C_TO_PY_ARG_SWITCH[type_code] = _wrap_arg_func(fret, type_code)
# Concrete array class instantiated by `_make_array`; set via `_set_class_ndarray`.
_CLASS_NDARRAY = None


def _set_class_ndarray(cls):
    """Record ``cls`` as the class that `_make_array` instantiates."""
    global _CLASS_NDARRAY
    _CLASS_NDARRAY = cls
"""The C Types used in API."""
# pylint: disable=invalid-name
from __future__ import absolute_import as _abs
import ctypes
from ..base import py_str, check_call, _LIB
from ..runtime_ctypes import TVMByteArray, TypeCode
class TVMValue(ctypes.Union):
    """TVMValue in C API"""
    # A union: the four fields share storage, so the accompanying type code
    # decides which one is meaningful.
    _fields_ = [("v_int64", ctypes.c_int64),
                ("v_float64", ctypes.c_double),
                ("v_handle", ctypes.c_void_p),
                ("v_str", ctypes.c_char_p)]
# C callback prototype for a packed function. Arguments in order: the value
# array, the type-code array, the number of arguments, and two opaque
# pointers. Returns a C int.
TVMPackedCFunc = ctypes.CFUNCTYPE(
    ctypes.c_int,
    ctypes.POINTER(TVMValue),
    ctypes.POINTER(ctypes.c_int),
    ctypes.c_int,
    ctypes.c_void_p,
    ctypes.c_void_p)
# C callback prototype for a finalizer: one opaque pointer, no result.
TVMCFuncFinalizer = ctypes.CFUNCTYPE(
    None,
    ctypes.c_void_p)
def _return_handle(x):
    """Return ``x.v_handle`` as a ``ctypes.c_void_p``.

    A value that already is a ``c_void_p`` is returned unchanged; anything
    else is wrapped in a new ``c_void_p``.
    """
    raw = x.v_handle
    return raw if isinstance(raw, ctypes.c_void_p) else ctypes.c_void_p(raw)
def _return_bytes(x):
    """Copy the ``TVMByteArray`` that ``x.v_handle`` points to into a bytearray.

    Returns
    -------
    res : bytearray
        A new buffer holding ``size`` bytes copied from the array's ``data``.

    Raises
    ------
    RuntimeError
        If ``ctypes.memmove`` returns a falsy value.
    """
    raw = x.v_handle
    pointer = raw if isinstance(raw, ctypes.c_void_p) else ctypes.c_void_p(raw)
    byte_arr = ctypes.cast(pointer, ctypes.POINTER(TVMByteArray))[0]
    nbytes = byte_arr.size
    out = bytearray(nbytes)
    # View the bytearray as a C array so memmove can write straight into it.
    dst = (ctypes.c_byte * nbytes).from_buffer(out)
    if not ctypes.memmove(dst, byte_arr.data, nbytes):
        raise RuntimeError('memmove failed')
    return out
def _wrap_arg_func(return_f, type_code):
    """Build a converter for callback arguments of type ``type_code``.

    The returned function first passes the value through
    ``TVMCbArgToReturn`` and then hands it to ``return_f``.
    """
    c_type_code = ctypes.c_int(type_code)

    def _convert(value):
        check_call(_LIB.TVMCbArgToReturn(ctypes.byref(value), c_type_code))
        return return_f(value)

    return _convert
# Maps a type code to the function that converts a TVMValue carrying that
# code into a Python object.
RETURN_SWITCH = {
    TypeCode.INT: lambda x: x.v_int64,
    TypeCode.FLOAT: lambda x: x.v_float64,
    TypeCode.HANDLE: _return_handle,
    TypeCode.NULL: lambda x: None,
    TypeCode.STR: lambda x: py_str(x.v_str),
    TypeCode.BYTES: _return_bytes
}
# Same shape as RETURN_SWITCH and initially the same entries; kept as a
# separate dict so extension types can register different converters here.
C_TO_PY_ARG_SWITCH = {
    TypeCode.INT: lambda x: x.v_int64,
    TypeCode.FLOAT: lambda x: x.v_float64,
    TypeCode.HANDLE: _return_handle,
    TypeCode.NULL: lambda x: None,
    TypeCode.STR: lambda x: py_str(x.v_str),
    TypeCode.BYTES: _return_bytes
}
"""cython2 namespace"""
"""cython3 namespace"""
from ..base import DGLError
from libcpp.vector cimport vector
from cpython.version cimport PY_MAJOR_VERSION
from cpython cimport pycapsule
from libc.stdint cimport int64_t, uint64_t, uint8_t, uint16_t
import ctypes
# Integer type codes that tag a TVMValue; used by make_arg/make_ret to pick
# which TVMValue field to read or write.
cdef enum TVMTypeCode:
    kInt = 0
    kUInt = 1
    kFloat = 2
    kHandle = 3
    kNull = 4
    kTVMType = 5
    kTVMContext = 6
    kArrayHandle = 7
    kNodeHandle = 8
    kModuleHandle = 9
    kFuncHandle = 10
    kStr = 11
    kBytes = 12
    kNDArrayContainer = 13
    # tvm_callback treats codes above this value like node/function/module handles.
    kExtBegin = 15
# Type declarations taken from the C runtime header. These only tell Cython
# which names and fields exist; the real layouts come from the header.
cdef extern from "tvm/runtime/c_runtime_api.h":
    ctypedef struct DLDataType:
        uint8_t code
        uint8_t bits
        uint16_t lanes

    ctypedef struct DLContext:
        int device_type
        int device_id

    ctypedef struct DLTensor:
        void* data
        DLContext ctx
        int ndim
        DLDataType dtype
        int64_t* shape
        int64_t* strides
        uint64_t byte_offset

    ctypedef struct DLManagedTensor:
        DLTensor dl_tensor
        void* manager_ctx
        void (*deleter)(DLManagedTensor* self)

    ctypedef struct TVMValue:
        int64_t v_int64
        double v_float64
        void* v_handle
        const char* v_str
        DLDataType v_type
        DLContext v_ctx

    ctypedef int64_t tvm_index_t
    ctypedef DLTensor* DLTensorHandle
    ctypedef void* TVMStreamHandle
    ctypedef void* TVMRetValueHandle
    ctypedef void* TVMFunctionHandle
    ctypedef void* NodeHandle

    # Signature of a packed-function callback (see tvm_callback).
    ctypedef int (*TVMPackedCFunc)(
        TVMValue* args,
        int* type_codes,
        int num_args,
        TVMRetValueHandle ret,
        void* resource_handle)

    # Signature of the finalizer paired with a packed-function callback.
    ctypedef void (*TVMPackedCFuncFinalizer)(void* resource_handle)
# Function declarations taken from the C runtime header. The int-returning
# ones are wrapped in CALL() throughout this module, which raises when the
# result is non-zero.
cdef extern from "tvm/runtime/c_runtime_api.h":
    void TVMAPISetLastError(const char* msg)
    const char *TVMGetLastError()
    int TVMFuncCall(TVMFunctionHandle func,
                    TVMValue* arg_values,
                    int* type_codes,
                    int num_args,
                    TVMValue* ret_val,
                    int* ret_type_code)
    int TVMFuncFree(TVMFunctionHandle func)
    int TVMCFuncSetReturn(TVMRetValueHandle ret,
                          TVMValue* value,
                          int* type_code,
                          int num_ret)
    int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
                               void* resource_handle,
                               TVMPackedCFuncFinalizer fin,
                               TVMFunctionHandle *out)
    int TVMCbArgToReturn(TVMValue* value, int code)
    int TVMArrayAlloc(tvm_index_t* shape,
                      tvm_index_t ndim,
                      DLDataType dtype,
                      DLContext ctx,
                      DLTensorHandle* out)
    int TVMArrayFree(DLTensorHandle handle)
    int TVMArrayCopyFromTo(DLTensorHandle src,
                           DLTensorHandle to,
                           TVMStreamHandle stream)
    int TVMArrayFromDLPack(DLManagedTensor* arr_from,
                           DLTensorHandle* out)
    int TVMArrayToDLPack(DLTensorHandle arr_from,
                         DLManagedTensor** out)
    void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor)
# Node-related function declarations taken from the DSL API header.
cdef extern from "tvm/c_dsl_api.h":
    int TVMNodeFree(NodeHandle handle)
    int TVMNodeTypeKey2Index(const char* type_key,
                             int* out_index)
    int TVMNodeGetTypeIndex(NodeHandle handle,
                            int* out_index)
    # out_success is checked by NodeBase.__getattr__: 0 means no such attribute.
    int TVMNodeGetAttr(NodeHandle handle,
                       const char* key,
                       TVMValue* out_value,
                       int* out_type_code,
                       int* out_success)
cdef inline py_str(const char* x):
    """Convert a C string to a Python string (UTF-8 decoded on Python 3)."""
    if PY_MAJOR_VERSION >= 3:
        return x.decode("utf-8")
    return x
cdef inline c_str(pystr):
    """Encode a python string as UTF-8

    Parameters
    ----------
    pystr : string type
        python string

    Returns
    -------
    encoded : bytes
        The UTF-8 encoded string, usable where a ``const char*`` is expected
    """
    encoded = pystr.encode("utf-8")
    return encoded
cdef inline CALL(int ret):
    """Raise DGLError carrying the runtime's last error when ``ret`` is non-zero."""
    if ret == 0:
        return
    raise DGLError(py_str(TVMGetLastError()))
cdef inline object ctypes_handle(void* chandle):
    """Cast C handle to ctypes handle."""
    cdef unsigned long long address = <unsigned long long>chandle
    return ctypes.cast(address, ctypes.c_void_p)
cdef inline void* c_handle(object handle):
    """Cast C types handle to c handle."""
    cdef unsigned long long address = handle.value
    return <void*>address
include "./base.pxi"
include "./node.pxi"
include "./function.pxi"
include "./ndarray.pxi"
import ctypes
import traceback
from cpython cimport Py_INCREF, Py_DECREF
from numbers import Number, Integral
from ..base import string_types
from ..node_generic import convert_to_node, NodeGeneric
from ..runtime_ctypes import TVMType, TVMContext, TVMByteArray
cdef void tvm_callback_finalize(void* fhandle):
    """Drop the reference that convert_to_tvm_func took on the callback."""
    callback = <object>fhandle
    Py_DECREF(callback)
cdef int tvm_callback(TVMValue* args,
                      int* type_codes,
                      int num_args,
                      TVMRetValueHandle ret,
                      void* fhandle) with gil:
    """C entry point that forwards a packed call to a Python callable.

    ``fhandle`` is the Python callable registered by convert_to_tvm_func.
    Arguments are converted to Python objects, the callable is invoked, and a
    non-None result is sent back through ``ret``.

    Returns 0 on success. On any Python exception the formatted traceback is
    stored with TVMAPISetLastError and -1 is returned.
    """
    cdef list pyargs
    cdef TVMValue value
    cdef int tcode
    local_pyfunc = <object>(fhandle)
    # This function declares no exception value, so a Python exception cannot
    # be propagated to the C caller. Everything that may raise therefore sits
    # inside the try block and is reported through the -1 return code.
    try:
        pyargs = []
        for i in range(num_args):
            value = args[i]
            tcode = type_codes[i]
            if (tcode == kNodeHandle or
                    tcode == kFuncHandle or
                    tcode == kModuleHandle or
                    tcode > kExtBegin):
                CALL(TVMCbArgToReturn(&value, tcode))

            if tcode != kArrayHandle:
                pyargs.append(make_ret(value, tcode))
            else:
                # Array arguments are exposed as views (is_view=True).
                pyargs.append(c_make_array(value.v_handle, True))

        rv = local_pyfunc(*pyargs)
        if rv is not None:
            if isinstance(rv, tuple):
                raise ValueError("PackedFunction can only support one return value")
            # temp_args keeps any temporaries created by make_arg alive until
            # the return value has been handed over.
            temp_args = []
            make_arg(rv, &value, &tcode, temp_args)
            CALL(TVMCFuncSetReturn(ret, &value, &tcode, 1))
    except Exception:
        msg = traceback.format_exc()
        TVMAPISetLastError(c_str(msg))
        return -1
    return 0
def convert_to_tvm_func(object pyfunc):
    """Convert a python function to TVM function

    Parameters
    ----------
    pyfunc : python function
        The python function to be converted.

    Returns
    -------
    tvmfunc: tvm.Function
        The converted tvm function.
    """
    cdef TVMFunctionHandle func_handle
    # The C side holds a raw pointer to pyfunc; this reference is released
    # again by tvm_callback_finalize.
    Py_INCREF(pyfunc)
    CALL(TVMFuncCreateFromCFunc(tvm_callback,
                                <void*>(pyfunc),
                                tvm_callback_finalize,
                                &func_handle))
    wrapper = _CLASS_FUNCTION(None, False)
    (<FunctionBase>wrapper).chandle = func_handle
    return wrapper
cdef inline int make_arg(object arg,
                         TVMValue* value,
                         int* tcode,
                         list temp_args) except -1:
    """Pack arguments into c args tvm call accept

    Writes the C representation of ``arg`` into ``value[0]`` and its type
    code into ``tcode[0]``. Objects created during conversion whose memory
    is referenced from ``value`` are appended to ``temp_args`` so the caller
    can keep them alive for the duration of the call.

    The isinstance checks are tried in order and the first match wins.
    Raises TypeError for an unsupported argument type.
    """
    cdef unsigned long long ptr
    if isinstance(arg, NodeBase):
        value[0].v_handle = (<NodeBase>arg).chandle
        tcode[0] = kNodeHandle
    elif isinstance(arg, NDArrayBase):
        value[0].v_handle = (<NDArrayBase>arg).chandle
        # Views are passed as plain array handles, owning arrays as containers.
        tcode[0] = (kNDArrayContainer if
                    not (<NDArrayBase>arg).c_is_view else kArrayHandle)
    elif isinstance(arg, _TVM_COMPATS):
        # Extension classes registered through _reg_extension.
        ptr = arg._tvm_handle
        value[0].v_handle = (<void*>ptr)
        tcode[0] = arg.__class__._tvm_tcode
    elif isinstance(arg, (int, long)):
        value[0].v_int64 = arg
        tcode[0] = kInt
    elif isinstance(arg, float):
        value[0].v_float64 = arg
        tcode[0] = kFloat
    elif isinstance(arg, str):
        tstr = c_str(arg)
        value[0].v_str = tstr
        tcode[0] = kStr
        # v_str points into tstr, so tstr must outlive the call.
        temp_args.append(tstr)
    elif arg is None:
        value[0].v_handle = NULL
        tcode[0] = kNull
    elif isinstance(arg, Number):
        # Any other Number not matched above is passed as a double.
        value[0].v_float64 = arg
        tcode[0] = kFloat
    elif isinstance(arg, TVMType):
        # Types are passed by their string form.
        tstr = c_str(str(arg))
        value[0].v_str = tstr
        tcode[0] = kStr
        temp_args.append(tstr)
    elif isinstance(arg, TVMContext):
        # Copy the context out of the ctypes object's memory.
        value[0].v_ctx = (<DLContext*>(
            <unsigned long long>ctypes.addressof(arg)))[0]
        tcode[0] = kTVMContext
    elif isinstance(arg, bytearray):
        arr = TVMByteArray()
        arr.data = ctypes.cast(
            (ctypes.c_byte * len(arg)).from_buffer(arg),
            ctypes.POINTER(ctypes.c_byte))
        arr.size = len(arg)
        value[0].v_handle = <void*>(
            <unsigned long long>ctypes.addressof(arr))
        tcode[0] = kBytes
        # v_handle is the address of arr, so arr must outlive the call.
        temp_args.append(arr)
    elif isinstance(arg, string_types):
        # String types that did not match the plain `str` check above.
        tstr = c_str(arg)
        value[0].v_str = tstr
        tcode[0] = kStr
        temp_args.append(tstr)
    elif isinstance(arg, (list, tuple, dict, NodeGeneric)):
        arg = convert_to_node(arg)
        value[0].v_handle = (<NodeBase>arg).chandle
        tcode[0] = kNodeHandle
        # Keep the freshly converted node alive for the call.
        temp_args.append(arg)
    elif isinstance(arg, _CLASS_MODULE):
        value[0].v_handle = c_handle(arg.handle)
        tcode[0] = kModuleHandle
    elif isinstance(arg, FunctionBase):
        value[0].v_handle = (<FunctionBase>arg).chandle
        tcode[0] = kFuncHandle
    elif isinstance(arg, ctypes.c_void_p):
        value[0].v_handle = c_handle(arg)
        tcode[0] = kHandle
    elif callable(arg):
        # Wrap arbitrary Python callables on the fly.
        arg = convert_to_tvm_func(arg)
        value[0].v_handle = (<FunctionBase>arg).chandle
        tcode[0] = kFuncHandle
        temp_args.append(arg)
    else:
        raise TypeError("Don't know how to handle type %s" % type(arg))
    return 0
cdef inline bytearray make_ret_bytes(void* chandle):
    """Copy the TVMByteArray behind ``chandle`` into a new bytearray."""
    byte_arr = ctypes.cast(ctypes_handle(chandle), ctypes.POINTER(TVMByteArray))[0]
    nbytes = byte_arr.size
    out = bytearray(nbytes)
    # View the bytearray as a C array so memmove can write straight into it.
    dst = (ctypes.c_byte * nbytes).from_buffer(out)
    if not ctypes.memmove(dst, byte_arr.data, nbytes):
        raise RuntimeError('memmove failed')
    return out
cdef inline object make_ret(TVMValue value, int tcode):
    """convert result to return value.

    Selects the TVMValue field that matches ``tcode`` and wraps it in the
    corresponding Python object. Raises ValueError for a type code that is
    neither handled here nor registered in _TVM_EXT_RET.
    """
    if tcode == kNodeHandle:
        return make_ret_node(value.v_handle)
    elif tcode == kNull:
        return None
    elif tcode == kInt:
        return value.v_int64
    elif tcode == kFloat:
        return value.v_float64
    elif tcode == kNDArrayContainer:
        # is_view=False: the new array object frees the handle when deallocated.
        return c_make_array(value.v_handle, False)
    elif tcode == kStr:
        return py_str(value.v_str)
    elif tcode == kBytes:
        return make_ret_bytes(value.v_handle)
    elif tcode == kHandle:
        return ctypes_handle(value.v_handle)
    elif tcode == kTVMContext:
        return TVMContext(value.v_ctx.device_type, value.v_ctx.device_id)
    elif tcode == kModuleHandle:
        return _CLASS_MODULE(ctypes_handle(value.v_handle))
    elif tcode == kFuncHandle:
        # Created with is_global=False, then given the returned handle.
        fobj = _CLASS_FUNCTION(None, False)
        (<FunctionBase>fobj).chandle = value.v_handle
        return fobj
    elif tcode in _TVM_EXT_RET:
        # Extension types registered through _reg_extension.
        return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle))

    raise ValueError("Unhandled type code %d" % tcode)
cdef inline int FuncCall3(void* chandle,
                          tuple args,
                          int nargs,
                          TVMValue* ret_val,
                          int* ret_tcode) except -1:
    """Call a packed function with at most three arguments.

    Uses fixed-size stack arrays for the argument values and type codes.
    """
    cdef TVMValue[3] values
    cdef int[3] tcodes
    nargs = len(args)
    # Holds temporaries created by make_arg until the call has returned.
    keepalive = []
    for idx in range(nargs):
        make_arg(args[idx], &values[idx], &tcodes[idx], keepalive)
    CALL(TVMFuncCall(chandle, &values[0], &tcodes[0],
                     nargs, ret_val, ret_tcode))
    return 0
cdef inline int FuncCall(void* chandle,
                         tuple args,
                         TVMValue* ret_val,
                         int* ret_tcode) except -1:
    """Call a packed function with an arbitrary number of arguments.

    Calls with up to three arguments are delegated to FuncCall3; larger
    calls use vectors for the argument values and type codes.
    """
    cdef int nargs = len(args)
    cdef vector[TVMValue] values
    cdef vector[int] tcodes

    if nargs <= 3:
        FuncCall3(chandle, args, nargs, ret_val, ret_tcode)
        return 0

    values.resize(max(nargs, 1))
    tcodes.resize(max(nargs, 1))
    # Holds temporaries created by make_arg until the call has returned.
    keepalive = []
    for idx in range(nargs):
        make_arg(args[idx], &values[idx], &tcodes[idx], keepalive)
    CALL(TVMFuncCall(chandle, &values[0], &tcodes[0],
                     nargs, ret_val, ret_tcode))
    return 0
cdef inline int ConstructorCall(void* constructor_handle,
                                int type_code,
                                tuple args,
                                void** handle) except -1:
    """Call constructor of a handle function

    Invokes the constructor with ``args``, asserts that the returned type
    code equals ``type_code`` and stores the returned handle in ``handle[0]``.
    """
    cdef TVMValue result
    cdef int result_tcode
    FuncCall(constructor_handle, args, &result, &result_tcode)
    assert result_tcode == type_code
    handle[0] = result.v_handle
    return 0
cdef class FunctionBase:
    """Cython base class wrapping a packed-function handle.

    ``handle`` exposes the C handle as a ``ctypes.c_void_p`` (or None when
    the handle is NULL). ``is_global`` is a boolean flag; when it is set the
    handle is not freed on deallocation.
    """
    cdef TVMFunctionHandle chandle
    # Backing field of the `is_global` property. It is named `c_is_global`
    # because that is the name the property, __init__ and __dealloc__ use
    # (same pattern as `c_is_view` in NDArrayBase). The previous declaration
    # `cdef int is_global` clashed with the property name and was never the
    # field that got written.
    cdef int c_is_global

    cdef inline _set_handle(self, handle):
        if handle is None:
            self.chandle = NULL
        else:
            self.chandle = c_handle(handle)

    property is_global:
        def __get__(self):
            return self.c_is_global != 0

        def __set__(self, value):
            self.c_is_global = value

    property handle:
        def __get__(self):
            if self.chandle == NULL:
                return None
            else:
                return ctypes.cast(<unsigned long long>self.chandle, ctypes.c_void_p)

        def __set__(self, value):
            self._set_handle(value)

    def __init__(self, handle, is_global):
        self._set_handle(handle)
        self.c_is_global = is_global

    def __dealloc__(self):
        # Test the same field the property and __init__ write, so a function
        # flagged as global is not freed here.
        if self.c_is_global == 0:
            CALL(TVMFuncFree(self.chandle))

    def __call__(self, *args):
        cdef TVMValue ret_val
        cdef int ret_tcode
        FuncCall(self.chandle, args, &ret_val, &ret_tcode)
        return make_ret(ret_val, ret_tcode)
# Concrete Python classes used when wrapping returned function and module
# handles (see make_ret and convert_to_tvm_func). Set through the functions below.
_CLASS_FUNCTION = None
_CLASS_MODULE = None

def _set_class_module(module_class):
    """Initialize the module."""
    global _CLASS_MODULE
    _CLASS_MODULE = module_class

def _set_class_function(func_class):
    """Record the class used to wrap function handles."""
    global _CLASS_FUNCTION
    _CLASS_FUNCTION = func_class
from ..runtime_ctypes import TVMArrayHandle
# Capsule names: a fresh DLPack capsule is named "dltensor"; _from_dlpack
# renames it to "used_dltensor" once it has been consumed.
cdef const char* _c_str_dltensor = "dltensor"
cdef const char* _c_str_used_dltensor = "used_dltensor"
cdef void _c_dlpack_deleter(object pycaps):
    """Capsule destructor: release the managed tensor if still unconsumed."""
    cdef DLManagedTensor* managed
    if not pycapsule.PyCapsule_IsValid(pycaps, _c_str_dltensor):
        return
    managed = <DLManagedTensor*>pycapsule.PyCapsule_GetPointer(pycaps, _c_str_dltensor)
    TVMDLManagedTensorCallDeleter(managed)
def _from_dlpack(object dltensor):
    """Build an array from a DLPack capsule and mark the capsule as consumed.

    Raises ValueError when the capsule is not valid under the name
    "dltensor", for example because it was already consumed.
    """
    cdef DLManagedTensor* managed
    cdef DLTensorHandle chandle
    if not pycapsule.PyCapsule_IsValid(dltensor, _c_str_dltensor):
        raise ValueError("Expect a dltensor field, pycapsule.PyCapsule can only be consumed once")
    managed = <DLManagedTensor*>pycapsule.PyCapsule_GetPointer(dltensor, _c_str_dltensor)
    CALL(TVMArrayFromDLPack(managed, &chandle))
    # set name and destructor to be empty
    pycapsule.PyCapsule_SetDestructor(dltensor, NULL)
    pycapsule.PyCapsule_SetName(dltensor, _c_str_used_dltensor)
    return c_make_array(chandle, 0)
cdef class NDArrayBase:
    """Cython base class wrapping a DLTensor handle.

    ``c_is_view`` is non-zero for arrays that must not free the handle when
    they are deallocated.
    """
    cdef DLTensor* chandle
    cdef int c_is_view

    cdef inline _set_handle(self, handle):
        cdef unsigned long long address
        if handle is not None:
            address = ctypes.cast(handle, ctypes.c_void_p).value
            self.chandle = <DLTensor*>(address)
        else:
            self.chandle = NULL

    property _tvm_handle:
        def __get__(self):
            return <unsigned long long>self.chandle

    property handle:
        def __get__(self):
            if self.chandle == NULL:
                return None
            return ctypes.cast(
                <unsigned long long>self.chandle, TVMArrayHandle)

        def __set__(self, value):
            self._set_handle(value)

    def __init__(self, handle, is_view):
        self._set_handle(handle)
        self.c_is_view = is_view

    def __dealloc__(self):
        # Only owning arrays free the handle.
        if not self.c_is_view:
            CALL(TVMArrayFree(self.chandle))

    def to_dlpack(self):
        """Produce a DLPack tensor from the array without copying memory

        Returns
        -------
        dlpack : DLPack tensor view of the array data
        """
        cdef DLManagedTensor* managed
        if self.c_is_view:
            raise ValueError("to_dlpack do not work with memory views")
        CALL(TVMArrayToDLPack(self.chandle, &managed))
        return pycapsule.PyCapsule_New(managed, _c_str_dltensor, _c_dlpack_deleter)
cdef c_make_array(void* chandle, is_view):
    """Wrap a raw DLTensor pointer in an instance of the registered array class."""
    arr = _CLASS_NDARRAY(None, is_view)
    # The object is created without a handle; assign the C pointer directly.
    (<NDArrayBase>arr).chandle = <DLTensor*>chandle
    return arr
cdef _TVM_COMPATS = ()
cdef _TVM_EXT_RET = {}

def _reg_extension(cls, fcreate):
    """Register an extension class and, optionally, its return constructor.

    ``cls`` is appended to _TVM_COMPATS. When ``fcreate`` is truthy it is
    stored in _TVM_EXT_RET under ``cls._tvm_tcode``.
    """
    global _TVM_COMPATS
    _TVM_COMPATS += (cls,)
    if not fcreate:
        return
    _TVM_EXT_RET[cls._tvm_tcode] = fcreate
def _make_array(handle, is_view):
    """Build an array object from a ctypes handle."""
    cdef unsigned long long address = ctypes.cast(handle, ctypes.c_void_p).value
    return c_make_array(<void*>address, is_view)
# Concrete array class instantiated by c_make_array; set via _set_class_ndarray.
cdef object _CLASS_NDARRAY = None

def _set_class_ndarray(cls):
    """Record ``cls`` as the class that c_make_array instantiates."""
    global _CLASS_NDARRAY
    _CLASS_NDARRAY = cls
from ... import _api_internal
from ..base import string_types
from ..node_generic import _set_class_node_base
"""Maps node type to its constructor"""
NODE_TYPE = []

def _register_node(int index, object cls):
    """register node class

    NODE_TYPE is padded with None up to ``index`` before ``cls`` is stored
    at that position.
    """
    missing = index + 1 - len(NODE_TYPE)
    if missing > 0:
        NODE_TYPE.extend([None] * missing)
    NODE_TYPE[index] = cls
cdef inline object make_ret_node(void* chandle):
    """Wrap a node handle in the Python class registered for its type index.

    Falls back to NodeBase when the index is outside NODE_TYPE or no class
    is registered at that index. The object is created with ``__new__``, so
    ``__init__`` is not run.
    """
    global NODE_TYPE
    cdef int tindex
    cdef list node_type
    cdef object cls
    node_type = NODE_TYPE
    CALL(TVMNodeGetTypeIndex(chandle, &tindex))
    cls = None
    if tindex < len(node_type):
        cls = node_type[tindex]
    if cls is None:
        cls = NodeBase
    obj = cls.__new__(cls)
    (<NodeBase>obj).chandle = chandle
    return obj
cdef class NodeBase:
    """Cython base class wrapping a node handle.

    Attribute lookups that Python cannot resolve are forwarded to
    TVMNodeGetAttr.
    """
    cdef void* chandle

    cdef _set_handle(self, handle):
        cdef unsigned long long address
        if handle is not None:
            address = handle.value
            self.chandle = <void*>(address)
        else:
            self.chandle = NULL

    property handle:
        def __get__(self):
            if self.chandle == NULL:
                return None
            return ctypes_handle(self.chandle)

        def __set__(self, value):
            self._set_handle(value)

    def __dealloc__(self):
        CALL(TVMNodeFree(self.chandle))

    def __getattr__(self, name):
        cdef TVMValue attr_val
        cdef int attr_tcode, found
        CALL(TVMNodeGetAttr(self.chandle, c_str(name),
                            &attr_val, &attr_tcode, &found))
        if found == 0:
            raise AttributeError(
                "'%s' object has no attribute '%s'" % (type(self), name))
        return make_ret(attr_val, attr_tcode)

    def __init_handle_by_constructor__(self, fconstructor, *args):
        """Initialize the handle by calling constructor function.

        Parameters
        ----------
        fconstructor : Function
            Constructor function.

        args: list of objects
            The arguments to the constructor

        Note
        ----
        We have a special calling convention to call constructor functions.
        So the return handle is directly set into the Node object
        instead of creating a new Node.
        """
        cdef void* new_handle
        ConstructorCall(
            (<FunctionBase>fconstructor).chandle,
            kNodeHandle, args, &new_handle)
        self.chandle = new_handle

# Hand the base class to node_generic.
_set_class_node_base(NodeBase)
# coding: utf-8
# pylint: disable=invalid-name
"""ctypes library and helper functions """
from __future__ import absolute_import
import sys
import os
import ctypes
import numpy as np
from . import libinfo
#----------------------------
# library loading
#----------------------------
# Python 2/3 compatibility shims: the string/number type tuples used in
# isinstance checks and the converter from C strings to Python strings.
if sys.version_info[0] == 3:
    string_types = (str,)
    numeric_types = (float, int, np.float32, np.int32)
    # this function is needed for python3
    # to convert ctypes.char_p .value back to python str
    py_str = lambda x: x.decode('utf-8')
else:
    string_types = (basestring,)
    numeric_types = (float, int, long, np.float32, np.int32)
    # On Python 2 the value is returned unchanged.
    py_str = lambda x: x
class DGLError(Exception):
    """Error thrown by DGL function"""
def _load_lib():
    """Load library by searching possible path.

    Returns
    -------
    lib : ctypes.CDLL
        The first library reported by ``libinfo.find_lib_path``.
    name : str
        Base name of the loaded library file.
    """
    first_path = libinfo.find_lib_path()[0]
    lib = ctypes.CDLL(first_path, ctypes.RTLD_GLOBAL)
    # The error string comes back as a C string, not an int.
    lib.TVMGetLastError.restype = ctypes.c_char_p
    return lib, os.path.basename(first_path)
# version number
__version__ = libinfo.__version__
# Shared library instance and its file name; loaded once at import time.
_LIB, _LIB_NAME = _load_lib()
# The FFI mode of DGL, read from the DGL_FFI environment variable
# (defaults to "auto").
_FFI_MODE = os.environ.get("DGL_FFI", "auto")
#----------------------------
# helper function in ctypes.
#----------------------------
def check_call(ret):
    """Check the return value of C API call

    This function will raise exception when error occurs.
    Wrap every API call with this function

    Parameters
    ----------
    ret : int
        return value from API calls

    Raises
    ------
    DGLError
        When ``ret`` is non-zero; the message is the library's last error.
    """
    if ret == 0:
        return
    raise DGLError(py_str(_LIB.TVMGetLastError()))
def c_str(string):
    """Create ctypes char * from a python string

    Parameters
    ----------
    string : string type
        python string

    Returns
    -------
    str : c_char_p
        A char pointer that can be passed to C API
    """
    encoded = string.encode('utf-8')
    return ctypes.c_char_p(encoded)
def c_array(ctype, values):
    """Create ctypes array from a python array

    Parameters
    ----------
    ctype : ctypes data type
        data type of the array we want to convert to

    values : tuple or list
        data content

    Returns
    -------
    out : ctypes array
        Created ctypes array
    """
    array_type = ctype * len(values)
    return array_type(*values)
def decorate(func, fwrapped):
    """A wrapper call of decorator package, differs to call time

    Parameters
    ----------
    func : function
        The original function

    fwrapped : function
        The wrapped function
    """
    # Imported lazily so the package is only needed when this is called.
    from decorator import decorate as _decorate
    return _decorate(func, fwrapped)
# pylint: disable=invalid-name, unused-import
"""Function namespace."""
from __future__ import absolute_import
import sys
import ctypes
from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE
# FFI backend selection. In "cython" mode a failed Cython import must not be
# swallowed, so the except clause below is made to match a different
# exception type (RuntimeError). In every other mode an ImportError falls
# back to the ctypes implementation.
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
try:
    # pylint: disable=wrong-import-position
    if _FFI_MODE == "ctypes":
        # Force the ctypes fallback.
        raise ImportError()
    if sys.version_info >= (3, 0):
        from ._cy3.core import _set_class_function, _set_class_module
        from ._cy3.core import FunctionBase as _FunctionBase
        from ._cy3.core import convert_to_tvm_func
    else:
        from ._cy2.core import _set_class_function, _set_class_module
        from ._cy2.core import FunctionBase as _FunctionBase
        from ._cy2.core import convert_to_tvm_func
except IMPORT_EXCEPT:
    # pylint: disable=wrong-import-position
    from ._ctypes.function import _set_class_function, _set_class_module
    from ._ctypes.function import FunctionBase as _FunctionBase
    from ._ctypes.function import convert_to_tvm_func
# Handle type used for out-parameters of the C function API.
FunctionHandle = ctypes.c_void_p
class Function(_FunctionBase):
    """The PackedFunc object used in TVM.

    Function plays an key role to bridge front and backend in TVM.
    Function provide a type-erased interface, you can call function with positional arguments.

    The compiled module returns Function.
    TVM backend also registers and exposes its API as Functions.
    For example, the developer function exposed in tvm.ir_pass are actually
    C++ functions that are registered as PackedFunc

    The following are list of common usage scenario of tvm.Function.

    - Automatic exposure of C++ API into python
    - To call PackedFunc from python side
    - To call python callbacks to inspect results in generated code
    - Bring python hook into C++ backend

    See Also
    --------
    tvm.register_func: How to register global function.
    tvm.get_global_func: How to get global function.
    """
class ModuleBase(object):
    """Base class for module"""
    __slots__ = ["handle", "_entry", "entry_name"]

    def __init__(self, handle):
        self.handle = handle
        self._entry = None
        self.entry_name = "__tvm_main__"

    def __del__(self):
        check_call(_LIB.TVMModFree(self.handle))

    @property
    def entry_func(self):
        """Get the entry function

        Returns
        -------
        f : Function
            The entry function if exist
        """
        # Looked up on first use and cached afterwards.
        if not self._entry:
            self._entry = self.get_function(self.entry_name)
        return self._entry

    def get_function(self, name, query_imports=False):
        """Get function from the module.

        Parameters
        ----------
        name : str
            The name of the function

        query_imports : bool
            Whether also query modules imported by this module.

        Returns
        -------
        f : Function
            The result function.

        Raises
        ------
        AttributeError
            If the lookup yields a null handle.
        """
        func_handle = FunctionHandle()
        check_call(_LIB.TVMModGetFunction(
            self.handle, c_str(name),
            ctypes.c_int(query_imports),
            ctypes.byref(func_handle)))
        if func_handle.value:
            return Function(func_handle, False)
        raise AttributeError(
            "Module has no function '%s'" % name)

    def import_module(self, module):
        """Add module to the import list of current one.

        Parameters
        ----------
        module : Module
            The other module.
        """
        check_call(_LIB.TVMModImport(self.handle, module.handle))

    def __getitem__(self, name):
        if isinstance(name, string_types):
            return self.get_function(name)
        raise ValueError("Can only take string as function name")

    def __call__(self, *args):
        # entry_func returns the cached entry when there is one.
        return self.entry_func(*args)
def register_func(func_name, f=None, override=False):
    """Register global function

    Parameters
    ----------
    func_name : str or function
        The function name

    f : function, optional
        The function to be registered.

    override: boolean optional
        Whether override existing entry.

    Returns
    -------
    fregister : function
        Register function if f is not specified.

    Examples
    --------
    The following code registers my_packed_func as global function.
    Note that we simply get it back from global function table to invoke
    it from python side. However, we can also invoke the same function
    from C++ backend, or in the compiled TVM code.

    .. code-block:: python

      targs = (10, 10.0, "hello")
      @tvm.register_func
      def my_packed_func(*args):
          assert(tuple(args) == targs)
          return 10
      # Get it out from global function table
      f = tvm.get_global_func("my_packed_func")
      assert isinstance(f, tvm.nd.Function)
      y = f(*targs)
      assert y == 10
    """
    # Used as a bare decorator: the "name" argument is the function itself.
    if callable(func_name):
        f = func_name
        func_name = f.__name__

    if not isinstance(func_name, str):
        raise ValueError("expect string function name")

    ioverride = ctypes.c_int(override)

    def register(myf):
        """internal register function"""
        packed = myf if isinstance(myf, Function) else convert_to_tvm_func(myf)
        check_call(_LIB.TVMFuncRegisterGlobal(
            c_str(func_name), packed.handle, ioverride))
        return packed

    return register(f) if f else register
def get_global_func(name, allow_missing=False):
    """Get a global function by name

    Parameters
    ----------
    name : str
        The name of the global function

    allow_missing : bool
        Whether allow missing function or raise an error.

    Returns
    -------
    func : tvm.Function
        The function to be returned, None if function is missing.
    """
    func_handle = FunctionHandle()
    check_call(_LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(func_handle)))
    if func_handle.value:
        return Function(func_handle, False)
    if allow_missing:
        return None
    raise ValueError("Cannot find global function %s" % name)
def list_global_func_names():
    """Get list of global functions registered.

    Returns
    -------
    names : list
       List of global functions names.
    """
    name_array = ctypes.POINTER(ctypes.c_char_p)()
    count = ctypes.c_uint()
    check_call(_LIB.TVMFuncListGlobalNames(ctypes.byref(count),
                                           ctypes.byref(name_array)))
    return [py_str(name_array[idx]) for idx in range(count.value)]
def extract_ext_funcs(finit):
    """
    Extract the extension PackedFuncs from a C module.

    Parameters
    ----------
    finit : ctypes function
        a ctypes that takes signature of TVMExtensionDeclarer

    Returns
    -------
    fdict : dict of str to Function
        The extracted functions

    Raises
    ------
    RuntimeError
        If ``finit`` returns a non-zero value.
    """
    collected = {}

    def _list(name, func):
        collected[name] = func

    collector = convert_to_tvm_func(_list)
    status = finit(collector.handle)
    # Reference the wrapper after the call so it is still alive while finit runs.
    _ = collector
    if status != 0:
        raise RuntimeError("cannot initialize with %s" % finit)
    return collected
def _get_api(f):
    """Mark ``f`` as a global function and return the same object."""
    f.is_global = True
    return f
def _init_api(namespace, target_module_name=None):
    """Initialize api for a given module name

    namespace : str
       The namespace of the source registry

    target_module_name : str
       The target module name if different from namespace
    """
    if not target_module_name:
        target_module_name = namespace
    # A leading "tvm." is dropped from the registry prefix.
    prefix = namespace[4:] if namespace.startswith("tvm.") else namespace
    _init_api_prefix(target_module_name, prefix)
def _init_api_prefix(module_name, prefix):
    """Attach the global functions selected by ``prefix`` to a module.

    With ``prefix == "api"`` every global name is considered and names that
    start with "_" go to the module registered as "tvm._api_internal".
    Otherwise only names starting with ``prefix`` are considered, and the
    prefix plus one separator character is stripped. Names that still
    contain a "." after that are skipped.
    """
    module = sys.modules[module_name]

    for name in list_global_func_names():
        if prefix == "api":
            fname = name
            # NOTE(review): the "tvm._api_internal" key is carried over
            # unchanged - confirm it matches the module name used in this package.
            target_module = sys.modules["tvm._api_internal"] if name.startswith("_") else module
        elif name.startswith(prefix):
            fname = name[len(prefix)+1:]
            target_module = module
        else:
            continue

        if "." in fname:
            continue

        ff = _get_api(get_global_func(name))
        ff.__name__ = fname
        ff.__doc__ = ("TVM PackedFunc %s. " % fname)
        setattr(target_module, ff.__name__, ff)
# Tell the selected FFI backend which class to use for function objects.
_set_class_function(Function)
"""Library information."""
from __future__ import absolute_import
import sys
import os
def find_lib_path(name=None, search_path=None, optional=False):
    """Find dynamic library files.

    Parameters
    ----------
    name : str or list of str, optional
        File name(s) to look for. When omitted, the platform default is
        used: ``libdgl.dll``/``dgl.dll`` on win32, ``libdgl.dylib`` on
        darwin and ``libdgl.so`` elsewhere.

    search_path : str or list of str, optional
        Extra directory, or directories, searched after the default ones.

    optional : bool, optional
        When True, return None instead of raising if nothing is found.

    Returns
    -------
    lib_path : list(string)
        List of all found path to the libraries, or None when nothing was
        found and ``optional`` is True.

    Raises
    ------
    RuntimeError
        If nothing was found and ``optional`` is False.
    """
    # See https://github.com/dmlc/tvm/issues/281 for some background.

    # NB: This will either be the source directory (if TVM is run
    # inplace) or the install directory (if TVM is installed).
    # An installed TVM's curr_path will look something like:
    #   $PREFIX/lib/python3.6/site-packages/tvm/_ffi
    ffi_dir = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
    source_dir = os.path.join(ffi_dir, "..", "..", "..")
    install_lib_dir = os.path.join(ffi_dir, "..", "..", "..", "..")

    dll_path = []

    if os.environ.get('DGL_LIBRARY_PATH', None):
        dll_path.append(os.environ['DGL_LIBRARY_PATH'])

    if sys.platform.startswith('linux') and os.environ.get('LD_LIBRARY_PATH', None):
        dll_path.extend([p.strip() for p in os.environ['LD_LIBRARY_PATH'].split(":")])
    elif sys.platform.startswith('darwin') and os.environ.get('DYLD_LIBRARY_PATH', None):
        dll_path.extend([p.strip() for p in os.environ['DYLD_LIBRARY_PATH'].split(":")])

    # Pip lib directory
    dll_path.append(os.path.join(ffi_dir, ".."))
    # Default cmake build directory
    dll_path.append(os.path.join(source_dir, "build"))
    dll_path.append(os.path.join(source_dir, "build", "Release"))
    # Default make build directory
    dll_path.append(os.path.join(source_dir, "lib"))

    dll_path.append(install_lib_dir)

    dll_path = [os.path.abspath(x) for x in dll_path]
    if search_path is not None:
        # The old test `search_path is list` compared against the type
        # object and was never true, so a list ended up nested in dll_path.
        if isinstance(search_path, (list, tuple)):
            dll_path = dll_path + list(search_path)
        else:
            dll_path.append(search_path)

    if name is not None:
        names = list(name) if isinstance(name, (list, tuple)) else [name]
        lib_dll_path = [os.path.join(p, n) for n in names for p in dll_path]
    else:
        if sys.platform.startswith('win32'):
            lib_dll_path = [os.path.join(p, 'libdgl.dll') for p in dll_path] +\
                           [os.path.join(p, 'dgl.dll') for p in dll_path]
        elif sys.platform.startswith('darwin'):
            lib_dll_path = [os.path.join(p, 'libdgl.dylib') for p in dll_path]
        else:
            lib_dll_path = [os.path.join(p, 'libdgl.so') for p in dll_path]

    # try to find lib_dll_path
    lib_found = [p for p in lib_dll_path if os.path.isfile(p)]

    if not lib_found:
        if optional:
            return None
        # Only the candidates computed above are listed; the old code also
        # referenced an undefined `runtime_dll_path` here and failed with
        # NameError instead of reporting the search.
        message = ('Cannot find the files.\n' +
                   'List of candidates:\n' +
                   '\n'.join(lib_dll_path))
        raise RuntimeError(message)
    return lib_found
# current version
# We use the version of the incoming release for code
# that is under development.
# The following line is set by tvm/python/update_version.py
__version__ = "0.5.dev"
# pylint: disable=invalid-name, unused-import
"""Runtime NDArray api"""
from __future__ import absolute_import
import sys
import ctypes
import numpy as np
from .base import _LIB, check_call, c_array, string_types, _FFI_MODE, c_str
from .runtime_ctypes import TVMType, TVMContext, TVMArray, TVMArrayHandle
from .runtime_ctypes import TypeCode, tvm_shape_index_t
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
try:
# pylint: disable=wrong-import-position
if _FFI_MODE == "ctypes":
raise ImportError()
if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_ndarray, _reg_extension, _make_array, _from_dlpack
from ._cy3.core import NDArrayBase as _NDArrayBase
else:
from ._cy2.core import _set_class_ndarray, _reg_extension, _make_array, _from_dlpack
from ._cy2.core import NDArrayBase as _NDArrayBase
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.ndarray import _set_class_ndarray, _reg_extension, _make_array, _from_dlpack
from ._ctypes.ndarray import NDArrayBase as _NDArrayBase
def context(dev_type, dev_id=0):
    """Build a TVMContext from a device type and a device id.

    Parameters
    ----------
    dev_type : int or str
        Either the integer device type mask or a device name. For a string,
        only its first whitespace-separated token is looked up in
        ``TVMContext.STR2MASK``.

    dev_id : int, optional
        The integer device id.

    Returns
    -------
    ctx : TVMContext
        The corresponding context.

    Raises
    ------
    ValueError
        If ``dev_type`` is a string whose first token is not a known device name.

    Examples
    --------
    A context can be created from the string name of the device type.

    .. code-block:: python

      assert tvm.context("cpu", 1) == tvm.cpu(1)
      assert tvm.context("gpu", 0) == tvm.gpu(0)
      assert tvm.context("cuda", 0) == tvm.gpu(0)
    """
    # Integer masks are used as-is.
    if not isinstance(dev_type, string_types):
        return TVMContext(dev_type, dev_id)

    # Only the leading token names the device; anything after it is ignored.
    device_name = dev_type.split()[0]
    if device_name not in TVMContext.STR2MASK:
        raise ValueError("Unknown device type %s" % device_name)
    return TVMContext(TVMContext.STR2MASK[device_name], dev_id)
def numpyasarray(np_data):
    """Describe a C-contiguous numpy array with a TVMArray structure.

    The TVMArray points at the numpy buffer; no data is copied here.

    Parameters
    ----------
    np_data : numpy.ndarray
        The source array. It must be C-contiguous (asserted).

    Returns
    -------
    arr : TVMArray
        The descriptor, with its context set to device type 1 ('cpu'), id 0.
    shape : ctypes array of tvm_shape_index_t
        The shape buffer that ``arr.shape`` refers to (presumably returned so
        the caller can keep it alive -- confirm against callers).
    """
    assert np_data.flags['C_CONTIGUOUS']
    shape_buf = c_array(tvm_shape_index_t, np_data.shape)

    tvm_arr = TVMArray()
    tvm_arr.data = np_data.ctypes.data_as(ctypes.c_void_p)
    tvm_arr.shape = shape_buf
    tvm_arr.strides = None
    tvm_arr.dtype = TVMType(np.dtype(np_data.dtype).name)
    tvm_arr.ndim = np_data.ndim
    # CPU device
    tvm_arr.ctx = context(1, 0)
    return tvm_arr, shape_buf
def empty(shape, dtype="float32", ctx=context(1, 0)):
    """Allocate an array of the given shape, dtype and device.

    Parameters
    ----------
    shape : tuple of int
        The shape of the array.

    dtype : type or str
        The data type of the array; converted with ``TVMType``.

    ctx : TVMContext
        The context of the array. Defaults to device type 1 ('cpu'), id 0.

    Returns
    -------
    arr : tvm.nd.NDArray
        The newly allocated array.
    """
    c_shape = c_array(tvm_shape_index_t, shape)
    num_dims = ctypes.c_int(len(c_shape))
    tvm_dtype = TVMType(dtype)
    # Filled in by TVMArrayAlloc.
    out_handle = TVMArrayHandle()
    status = _LIB.TVMArrayAlloc(
        c_shape,
        num_dims,
        ctypes.c_int(tvm_dtype.type_code),
        ctypes.c_int(tvm_dtype.bits),
        ctypes.c_int(tvm_dtype.lanes),
        ctx.device_type,
        ctx.device_id,
        ctypes.byref(out_handle))
    check_call(status)
    return _make_array(out_handle, False)
def from_dlpack(dltensor):
    """Create an array from a DLPack tensor without copying memory.

    Retrieves the underlying DLPack tensor's pointer and builds an array on
    top of that data. The original DLPack tensor's destructor is removed,
    since the array becomes responsible for destruction.

    Parameters
    ----------
    dltensor : DLPack tensor
        Input DLManagedTensor; it can only be consumed once.

    Returns
    -------
    arr : tvm.nd.NDArray
        The array view of the tensor data.
    """
    array_view = _from_dlpack(dltensor)
    return array_view
class NDArrayBase(_NDArrayBase):
    """A simple Device/CPU Array object in runtime.

    Adds shape/dtype/context accessors, handle-based equality and hashing,
    and copies to and from numpy arrays on top of the FFI base class.
    """

    @property
    def shape(self):
        """Shape of this array, as a tuple with one entry per dimension."""
        return tuple(self.handle.contents.shape[i] for i in range(self.handle.contents.ndim))

    @property
    def dtype(self):
        """Type of this array, as a string."""
        return str(self.handle.contents.dtype)

    @property
    def ctx(self):
        """Context of this array."""
        return self.handle.contents.ctx

    @property
    def context(self):
        """Context of this array (alias of ``ctx``)."""
        return self.ctx

    def __hash__(self):
        # The hash is the address held by the handle, so it follows the
        # identity-based equality implemented by same_as().
        return ctypes.cast(self.handle, ctypes.c_void_p).value

    def __eq__(self, other):
        return self.same_as(other)

    def __ne__(self, other):
        return not self.__eq__(other)

    def same_as(self, other):
        """Check object identity equality.

        Parameters
        ----------
        other : object
            The other object to compare to.

        Returns
        -------
        same : bool
            Whether other is the same as self. Always False when ``other``
            is not an NDArrayBase.
        """
        if not isinstance(other, NDArrayBase):
            return False
        return self.__hash__() == other.__hash__()

    def __setitem__(self, in_slice, value):
        """Set the array contents; the index must be a slice without start and stop.

        Raises
        ------
        ValueError
            If ``in_slice`` is not a slice, or has a start or a stop.
        TypeError
            If ``value`` is neither an NDArrayBase nor a numpy array/scalar.
        """
        if (not isinstance(in_slice, slice) or
                in_slice.start is not None
                or in_slice.stop is not None):
            raise ValueError('Array only support set from numpy array')
        if isinstance(value, NDArrayBase):
            # Nothing to copy when the value shares this array's handle.
            if value.handle is not self.handle:
                value.copyto(self)
        elif isinstance(value, (np.ndarray, np.generic)):
            self.copyfrom(value)
        else:
            raise TypeError('type %s not supported' % str(type(value)))

    def copyfrom(self, source_array):
        """Perform a synchronous copy from the given array.

        Parameters
        ----------
        source_array : array_like
            The data source to copy from. An NDArrayBase is copied with its
            own ``copyto``; anything else is converted to a numpy array first.

        Returns
        -------
        arr : NDArray
            Reference to self.

        Raises
        ------
        TypeError
            If ``source_array`` cannot be converted to a numpy array.
        ValueError
            If the source shape does not match the shape of this array.
        """
        if isinstance(source_array, NDArrayBase):
            source_array.copyto(self)
            return self

        if not isinstance(source_array, np.ndarray):
            try:
                source_array = np.array(source_array, dtype=self.dtype)
            except Exception:
                # Only conversion failures are translated; KeyboardInterrupt
                # and SystemExit propagate unchanged.
                raise TypeError('array must be an array_like data,' +
                                'type %s is not supported' % str(type(source_array)))
        t = TVMType(self.dtype)
        shape, dtype = self.shape, self.dtype
        if t.lanes > 1:
            # Vector types are exposed to numpy as an extra trailing axis of
            # the scalar element type.
            shape = shape + (t.lanes,)
            t.lanes = 1
            dtype = str(t)

        if source_array.shape != shape:
            raise ValueError("array shape do not match the shape of NDArray {0} vs {1}".format(
                source_array.shape, shape))
        source_array = np.ascontiguousarray(source_array, dtype=dtype)
        assert source_array.flags['C_CONTIGUOUS']
        data = source_array.ctypes.data_as(ctypes.c_void_p)
        nbytes = ctypes.c_size_t(source_array.size * source_array.dtype.itemsize)
        check_call(_LIB.TVMArrayCopyFromBytes(self.handle, data, nbytes))
        return self

    def __repr__(self):
        res = "<tvm.NDArray shape={0}, {1}>\n".format(self.shape, self.context)
        res += self.asnumpy().__repr__()
        return res

    def __str__(self):
        return str(self.asnumpy())

    def asnumpy(self):
        """Convert this array to a numpy array.

        Returns
        -------
        np_arr : numpy.ndarray
            A newly allocated numpy array holding a copy of the data. Vector
            types get an extra trailing axis of length ``lanes``.
        """
        t = TVMType(self.dtype)
        shape, dtype = self.shape, self.dtype
        if t.lanes > 1:
            shape = shape + (t.lanes,)
            t.lanes = 1
            dtype = str(t)
        np_arr = np.empty(shape, dtype=dtype)
        assert np_arr.flags['C_CONTIGUOUS']
        data = np_arr.ctypes.data_as(ctypes.c_void_p)
        nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize)
        check_call(_LIB.TVMArrayCopyToBytes(self.handle, data, nbytes))
        return np_arr

    def copyto(self, target):
        """Copy this array to the target.

        Parameters
        ----------
        target : NDArray or TVMContext
            The destination array, which must have the same shape as this
            array. When a TVMContext is given, a new array with this array's
            shape and dtype is allocated on that context and used instead.

        Returns
        -------
        target : NDArray
            The array that received the data.

        Raises
        ------
        ValueError
            If ``target`` is neither an NDArrayBase nor a TVMContext.
        """
        if isinstance(target, TVMContext):
            target = empty(self.shape, self.dtype, target)
        if isinstance(target, NDArrayBase):
            check_call(_LIB.TVMArrayCopyFromTo(
                self.handle, target.handle, None))
        else:
            raise ValueError("Unsupported target type %s" % str(type(target)))
        return target
def free_extension_handle(handle, type_code):
    """Release a C++ extension type handle.

    Parameters
    ----------
    handle : ctypes.c_void_p
        The handle to the extension type.

    type_code : int
        The type code of the extension.
    """
    c_type_code = ctypes.c_int(type_code)
    check_call(_LIB.TVMExtTypeFree(handle, c_type_code))
def register_extension(cls, fcreate=None):
    """Register an extension class with TVM.

    Once registered, instances of the class can be passed directly as
    arguments to functions generated by TVM.

    Parameters
    ----------
    cls : class
        The class object to be registered as extension.

    fcreate : function, optional
        The creation function that builds a class object from a handle value.
        Only allowed when the class's type code is at or above
        ``TypeCode.EXT_BEGIN``.

    Note
    ----
    The registered class requires one property, _tvm_handle, and one class
    attribute, _tvm_tcode.

    - ```_tvm_handle``` returns an integer representing the address of the handle.
    - ```_tvm_tcode``` gives an integer representing the type code of the class.

    Returns
    -------
    cls : class
        The class being registered.

    Raises
    ------
    ValueError
        If ``fcreate`` is given for a class whose type code is below
        ``TypeCode.EXT_BEGIN``.

    Example
    -------
    The following code registers the user defined class
    MyTensor to be DLTensor compatible.

    .. code-block:: python

       @tvm.register_extension
       class MyTensor(object):
           _tvm_tcode = tvm.TypeCode.ARRAY_HANDLE

           def __init__(self):
               self.handle = _LIB.NewDLTensor()

           @property
           def _tvm_handle(self):
               return self.handle.value
    """
    if fcreate:
        # The type code is only inspected when a creator is supplied.
        if cls._tvm_tcode < TypeCode.EXT_BEGIN:
            raise ValueError("Cannot register create when extension tcode is same as buildin")
    _reg_extension(cls, fcreate)
    return cls
"""Common runtime ctypes."""
# pylint: disable=invalid-name
from __future__ import absolute_import
import ctypes
import json
import numpy as np
from .base import _LIB, check_call
from .. import _api_internal
# ctypes integer type (64-bit signed) used for the entries of array shape
# and stride buffers.
tvm_shape_index_t = ctypes.c_int64
class TypeCode(object):
    """Integer type codes used to tag values in API calls."""
    # Plain values.
    INT = 0
    UINT = 1
    FLOAT = 2
    HANDLE = 3
    NULL = 4
    # Structured and handle values.
    TVM_TYPE = 5
    TVM_CONTEXT = 6
    ARRAY_HANDLE = 7
    NODE_HANDLE = 8
    MODULE_HANDLE = 9
    FUNC_HANDLE = 10
    STR = 11
    BYTES = 12
    NDARRAY_CONTAINER = 13
    # Codes at or above EXT_BEGIN are treated as extension types by
    # register_extension.
    EXT_BEGIN = 15
class TVMByteArray(ctypes.Structure):
    """Temporary structure describing a byte array: a pointer and a length."""
    _fields_ = [
        ("data", ctypes.POINTER(ctypes.c_byte)),
        ("size", ctypes.c_size_t),
    ]
class TVMType(ctypes.Structure):
    """TVM datatype structure.

    Holds a type code, the number of bits per lane and the number of lanes.
    It is constructed from a type string such as ``"float32"`` or
    ``"int8x4"`` (``<kind>[<bits>][x<lanes>]``), or from a ``numpy.dtype``.
    """
    _fields_ = [("type_code", ctypes.c_uint8),
                ("bits", ctypes.c_uint8),
                ("lanes", ctypes.c_uint16)]
    # Type code -> name prefix used by __repr__.
    CODE2STR = {
        0 : 'int',
        1 : 'uint',
        2 : 'float',
        4 : 'handle'
    }

    def __init__(self, type_str):
        """Parse ``type_str`` into type code, bits and lanes.

        Parameters
        ----------
        type_str : str or numpy.dtype
            The type description. Bits default to 32 (64 for ``handle``) and
            lanes default to 1 when they are not spelled out.

        Raises
        ------
        ValueError
            If the string does not start with int, uint, float or handle.
        """
        super(TVMType, self).__init__()
        if isinstance(type_str, np.dtype):
            type_str = str(type_str)
        arr = type_str.split("x")
        head = arr[0]
        self.lanes = int(arr[1]) if len(arr) > 1 else 1
        bits = 32

        if head.startswith("int"):
            self.type_code = 0
            head = head[3:]
        elif head.startswith("uint"):
            self.type_code = 1
            head = head[4:]
        elif head.startswith("float"):
            self.type_code = 2
            head = head[5:]
        elif head.startswith("handle"):
            self.type_code = 4
            bits = 64
            # Any suffix after "handle" is ignored; handles are always 64 bits.
            head = ""
        else:
            raise ValueError("Donot know how to handle type %s" % type_str)
        # Whatever is left of the head is the explicit bit width, if any.
        bits = int(head) if head else bits
        self.bits = bits

    def __repr__(self):
        x = "%s%d" % (TVMType.CODE2STR[self.type_code], self.bits)
        if self.lanes != 1:
            x += "x%d" % self.lanes
        return x

    def __eq__(self, other):
        # Comparing against anything that is not a TVMType is simply unequal
        # rather than an AttributeError.
        return (isinstance(other, TVMType) and
                self.bits == other.bits and
                self.type_code == other.type_code and
                self.lanes == other.lanes)

    def __ne__(self, other):
        return not self.__eq__(other)
# Device types at or above this value are printed as remote devices:
# device_type // RPC_SESS_MASK - 1 is the session table index and
# device_type % RPC_SESS_MASK is the device type inside that session.
RPC_SESS_MASK = 128


class TVMContext(ctypes.Structure):
    """TVM context structure: a (device_type, device_id) pair."""
    _fields_ = [("device_type", ctypes.c_int),
                ("device_id", ctypes.c_int)]
    # Device type code -> name used when printing a context.
    MASK2STR = {
        1 : 'cpu',
        2 : 'gpu',
        4 : 'opencl',
        5 : 'aocl',
        6 : 'sdaccel',
        7 : 'vulkan',
        8 : 'metal',
        9 : 'vpi',
        10: 'rocm',
        11: 'opengl',
        12: 'ext_dev',
    }
    # Device or target name -> device type code; several names map to the
    # same code.
    STR2MASK = {
        'llvm': 1,
        'stackvm': 1,
        'cpu': 1,
        'gpu': 2,
        'cuda': 2,
        'nvptx': 2,
        'cl': 4,
        'opencl': 4,
        'aocl' : 5,
        'aocl_sw_emu' : 5,
        'sdaccel': 6,
        'vulkan': 7,
        'metal': 8,
        'vpi': 9,
        'rocm': 10,
        'opengl': 11,
        'ext_dev': 12,
    }

    def __init__(self, device_type, device_id):
        super(TVMContext, self).__init__()
        self.device_type = device_type
        self.device_id = device_id

    # The properties below all query _GetDeviceAttr with a fixed attribute
    # kind index (0 through 8).

    @property
    def exist(self):
        """Whether this device exist."""
        return _api_internal._GetDeviceAttr(
            self.device_type, self.device_id, 0) != 0

    @property
    def max_threads_per_block(self):
        """Maximum number of threads on each block."""
        return _api_internal._GetDeviceAttr(
            self.device_type, self.device_id, 1)

    @property
    def warp_size(self):
        """Number of threads that executes in concurrent."""
        return _api_internal._GetDeviceAttr(
            self.device_type, self.device_id, 2)

    @property
    def max_shared_memory_per_block(self):
        """Total amount of shared memory per block in bytes."""
        return _api_internal._GetDeviceAttr(
            self.device_type, self.device_id, 3)

    @property
    def compute_version(self):
        """Get compute version number as a string.

        Currently used to get compute capability of CUDA device.

        Returns
        -------
        version : str
            The version string in `major.minor` format.
        """
        return _api_internal._GetDeviceAttr(
            self.device_type, self.device_id, 4)

    @property
    def device_name(self):
        """Return the string name of device."""
        return _api_internal._GetDeviceAttr(
            self.device_type, self.device_id, 5)

    @property
    def max_clock_rate(self):
        """Return the max clock frequency of device."""
        return _api_internal._GetDeviceAttr(
            self.device_type, self.device_id, 6)

    @property
    def multi_processor_count(self):
        """Return the number of compute units of device."""
        return _api_internal._GetDeviceAttr(
            self.device_type, self.device_id, 7)

    @property
    def max_thread_dimensions(self):
        """Return the maximum size of each thread axis.

        Returns
        -------
        dims: List of int
            The maximum length of threadIdx.x, threadIdx.y, threadIdx.z
        """
        # The attribute is returned as a JSON string and decoded here.
        return json.loads(_api_internal._GetDeviceAttr(
            self.device_type, self.device_id, 8))

    def sync(self):
        """Synchronize until jobs finished at the context."""
        check_call(_LIB.TVMSynchronize(self.device_type, self.device_id, None))

    def __eq__(self, other):
        return (isinstance(other, TVMContext) and
                self.device_id == other.device_id and
                self.device_type == other.device_type)

    def __ne__(self, other):
        return not self.__eq__(other)

    def __repr__(self):
        if self.device_type >= RPC_SESS_MASK:
            # Floor division keeps the table index an integer on Python 3 as
            # well as on Python 2.
            tbl_id = self.device_type // RPC_SESS_MASK - 1
            dev_type = self.device_type % RPC_SESS_MASK
            return "remote[%d]:%s(%d)" % (
                tbl_id, TVMContext.MASK2STR[dev_type], self.device_id)
        return "%s(%d)" % (
            TVMContext.MASK2STR[self.device_type], self.device_id)
class TVMArray(ctypes.Structure):
    """ctypes layout of the array structure exchanged with the C API.

    NOTE(review): presumably mirrors the C-side tensor struct field for
    field; confirm against the C header before changing the field order.
    """
    _fields_ = [
        ("data", ctypes.c_void_p),
        ("ctx", TVMContext),
        ("ndim", ctypes.c_int),
        ("dtype", TVMType),
        ("shape", ctypes.POINTER(tvm_shape_index_t)),
        ("strides", ctypes.POINTER(tvm_shape_index_t)),
        ("byte_offset", ctypes.c_uint64),
    ]


# Pointer-to-TVMArray type used as the array handle in C API calls.
TVMArrayHandle = ctypes.POINTER(TVMArray)
# C API and runtime
Borrowed and adapted from the TVM project.
/*!
* Copyright (c) 2016 by Contributors
* \file c_runtime_api.cc
* \brief Device specific implementations
*/
#include <dmlc/thread_local.h>
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/c_backend_api.h>
#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/module.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/device_api.h>
#include <array>
#include <algorithm>
#include <string>
#include <cstdlib>
#include "runtime_base.h"
namespace tvm {
namespace runtime {
/*!
 * \brief Map a device type code to the name used to look up its Device API factory.
 * \param type The device type.
 * \return The device name; an unrecognised type is reported through LOG(FATAL).
 */
inline std::string DeviceName(int type) {
  const char* name = nullptr;
  switch (type) {
    case kDLCPU: name = "cpu"; break;
    case kDLGPU: name = "gpu"; break;
    case kDLOpenCL: name = "opencl"; break;
    case kDLSDAccel: name = "sdaccel"; break;
    case kDLAOCL: name = "aocl"; break;
    case kDLVulkan: name = "vulkan"; break;
    case kDLMetal: name = "metal"; break;
    case kDLVPI: name = "vpi"; break;
    case kDLROCM: name = "rocm"; break;
    case kOpenGL: name = "opengl"; break;
    case kExtDev: name = "ext_dev"; break;
    default:
      LOG(FATAL) << "unknown type =" << type;
      name = "Unknown";
      break;
  }
  return name;
}
/*!
 * \brief Process-wide table of DeviceAPI instances, filled in on first use.
 *
 * Local device types are cached in a fixed-size table indexed by the device
 * type; every type at or above kRPCSessMask shares a single "rpc" API.
 * Instances are obtained from the global registry under "device_api.<name>".
 */
class DeviceAPIManager {
 public:
  /*! \brief Number of slots in the local device table. */
  static const int kMaxDeviceAPI = 32;
  // Get API
  static DeviceAPI* Get(const TVMContext& ctx) {
    return Get(ctx.device_type);
  }
  static DeviceAPI* Get(int dev_type, bool allow_missing = false) {
    return Global()->GetAPI(dev_type, allow_missing);
  }

 private:
  std::array<DeviceAPI*, kMaxDeviceAPI> api_;
  DeviceAPI* rpc_api_{nullptr};
  std::mutex mutex_;
  // constructor
  DeviceAPIManager() {
    std::fill(api_.begin(), api_.end(), nullptr);
  }
  // Global static variable.
  static DeviceAPIManager* Global() {
    static DeviceAPIManager inst;
    return &inst;
  }
  // Get or initialize API.
  DeviceAPI* GetAPI(int type, bool allow_missing) {
    if (type < kRPCSessMask) {
      // api_ only has kMaxDeviceAPI slots, while any type below kRPCSessMask
      // reaches this branch; reject values that would index outside the table.
      CHECK(type >= 0 && type < kMaxDeviceAPI)
          << "Invalid device type " << type;
      // Unlocked fast path, then re-check under the lock before creating.
      if (api_[type] != nullptr) return api_[type];
      std::lock_guard<std::mutex> lock(mutex_);
      if (api_[type] != nullptr) return api_[type];
      api_[type] = GetAPI(DeviceName(type), allow_missing);
      return api_[type];
    } else {
      if (rpc_api_ != nullptr) return rpc_api_;
      std::lock_guard<std::mutex> lock(mutex_);
      if (rpc_api_ != nullptr) return rpc_api_;
      rpc_api_ = GetAPI("rpc", allow_missing);
      return rpc_api_;
    }
  }
  // Look up the factory "device_api.<name>" in the registry and call it.
  // Returns nullptr when the factory is absent and allow_missing is set;
  // otherwise a missing factory fails the CHECK.
  DeviceAPI* GetAPI(const std::string& name, bool allow_missing) {
    std::string factory = "device_api." + name;
    auto* f = Registry::Get(factory);
    if (f == nullptr) {
      CHECK(allow_missing)
          << "Device API " << name << " is not enabled.";
      return nullptr;
    }
    // The factory returns the DeviceAPI as an opaque pointer.
    void* ptr = (*f)();
    return static_cast<DeviceAPI*>(ptr);
  }
};
// Return the DeviceAPI registered for the device type of ctx.
DeviceAPI* DeviceAPI::Get(TVMContext ctx, bool allow_missing) {
  const int dev_type = static_cast<int>(ctx.device_type);
  return DeviceAPIManager::Get(dev_type, allow_missing);
}
// Default workspace allocation: an ordinary data-space allocation aligned to
// kTempAllocaAlignment.
void* DeviceAPI::AllocWorkspace(TVMContext ctx,
                                size_t size,
                                TVMType type_hint) {
  void* workspace = AllocDataSpace(ctx, size, kTempAllocaAlignment, type_hint);
  return workspace;
}
// Default workspace release: forwards to FreeDataSpace.
void DeviceAPI::FreeWorkspace(TVMContext ctx, void* ptr) {
  this->FreeDataSpace(ctx, ptr);
}
// Default implementation: streams are unsupported, so this reports a fatal
// error. The null return only satisfies the signature.
TVMStreamHandle DeviceAPI::CreateStream(TVMContext ctx) {
  LOG(FATAL) << "Device does not support stream api.";
  return nullptr;
}
// Default implementation: streams are unsupported, so this reports a fatal error.
void DeviceAPI::FreeStream(TVMContext ctx, TVMStreamHandle stream) {
  LOG(FATAL) << "Device does not support stream api.";
}
// Default implementation: streams are unsupported, so this reports a fatal error.
void DeviceAPI::SyncStreamFromTo(TVMContext ctx,
                                 TVMStreamHandle event_src,
                                 TVMStreamHandle event_dst) {
  LOG(FATAL) << "Device does not support stream api.";
}
} // namespace runtime
} // namespace tvm
using namespace tvm::runtime;
// Per-thread scratch state used by the C API.
struct TVMRuntimeEntry {
  // Backing storage for string/bytes results handed back by TVMFuncCall.
  std::string ret_str;
  // Message returned by TVMGetLastError.
  std::string last_error;
  // Byte-array view over ret_str, used for kBytes results.
  TVMByteArray ret_bytes;
};

// Thread-local accessor for the entry above.
using TVMAPIRuntimeStore = dmlc::ThreadLocalStore<TVMRuntimeEntry>;
const char *TVMGetLastError() {
return TVMAPIRuntimeStore::Get()->last_error.c_str();
}
void TVMAPISetLastError(const char* msg) {
#ifndef _LIBCPP_SGX_CONFIG
TVMAPIRuntimeStore::Get()->last_error = msg;
#else
sgx::OCallPackedFunc("__sgx_set_last_error__", msg);
#endif
}
// Load a module from file_name and return a heap-allocated Module handle in
// *out. The handle is released with TVMModFree.
int TVMModLoadFromFile(const char* file_name,
                       const char* format,
                       TVMModuleHandle* out) {
  API_BEGIN();
  *out = new Module(Module::LoadFromFile(file_name, format));
  API_END();
}
// Make module `mod` import module `dep`.
int TVMModImport(TVMModuleHandle mod,
                 TVMModuleHandle dep) {
  API_BEGIN();
  Module* importer = static_cast<Module*>(mod);
  Module* imported = static_cast<Module*>(dep);
  importer->Import(*imported);
  API_END();
}
// Look up func_name in the module (and its imports when query_imports is
// non-zero). *func receives a new PackedFunc handle, or nullptr when the
// function is not found.
int TVMModGetFunction(TVMModuleHandle mod,
                      const char* func_name,
                      int query_imports,
                      TVMFunctionHandle *func) {
  API_BEGIN();
  const bool search_imports = (query_imports != 0);
  PackedFunc found =
      static_cast<Module*>(mod)->GetFunction(func_name, search_imports);
  TVMFunctionHandle result = nullptr;
  if (found != nullptr) {
    result = new PackedFunc(found);
  }
  *func = result;
  API_END();
}
// Destroy a module handle created by this API.
int TVMModFree(TVMModuleHandle mod) {
  API_BEGIN();
  Module* module = static_cast<Module*>(mod);
  delete module;
  API_END();
}
int TVMBackendGetFuncFromEnv(void* mod_node,
const char* func_name,
TVMFunctionHandle *func) {
API_BEGIN();
*func = (TVMFunctionHandle)(
static_cast<ModuleNode*>(mod_node)->GetFuncFromEnv(func_name));
API_END();
}
void* TVMBackendAllocWorkspace(int device_type,
int device_id,
uint64_t size,
int dtype_code_hint,
int dtype_bits_hint) {
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
TVMType type_hint;
type_hint.code = static_cast<decltype(type_hint.code)>(dtype_code_hint);
type_hint.bits = static_cast<decltype(type_hint.bits)>(dtype_bits_hint);
type_hint.lanes = 1;
return DeviceAPIManager::Get(ctx)->AllocWorkspace(ctx,
static_cast<size_t>(size),
type_hint);
}
int TVMBackendFreeWorkspace(int device_type,
int device_id,
void* ptr) {
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
DeviceAPIManager::Get(ctx)->FreeWorkspace(ctx, ptr);
return 0;
}
// Run f(cdata) the first time this is called with a given *handle.
// *handle is used as the "already ran" marker: it is set to a non-null
// value before f runs, and later calls with it return 0 without calling f.
// nbytes is not used.
int TVMBackendRunOnce(void** handle,
                      int (*f)(void*),
                      void* cdata,
                      int nbytes) {
  if (*handle != nullptr) {
    return 0;
  }
  *handle = reinterpret_cast<void*>(1);
  return (*f)(cdata);
}
// Destroy a PackedFunc handle created by this API.
int TVMFuncFree(TVMFunctionHandle func) {
  API_BEGIN();
  PackedFunc* packed = static_cast<PackedFunc*>(func);
  delete packed;
  API_END();
}
// Call a packed function and hand the result back through ret_val and
// ret_type_code.
//
// String-like results (kStr, kTVMType, kBytes) are copied into the calling
// thread's TVMRuntimeEntry, and the returned pointer refers to that copy.
// A kTVMType result is returned as its string form with type code kStr.
// Every other result is moved out with MoveToCHost.
int TVMFuncCall(TVMFunctionHandle func,
                TVMValue* args,
                int* arg_type_codes,
                int num_args,
                TVMValue* ret_val,
                int* ret_type_code) {
  API_BEGIN();
  TVMRetValue rv;
  const PackedFunc* callee = static_cast<const PackedFunc*>(func);
  callee->CallPacked(TVMArgs(args, arg_type_codes, num_args), &rv);

  const int tcode = rv.type_code();
  const bool is_bytes = (tcode == kBytes);
  if (tcode != kStr && tcode != kTVMType && !is_bytes) {
    rv.MoveToCHost(ret_val, ret_type_code);
  } else {
    TVMRuntimeEntry* entry = TVMAPIRuntimeStore::Get();
    if (tcode == kTVMType) {
      // Types are converted to their string form.
      entry->ret_str = rv.operator std::string();
    } else {
      entry->ret_str = *rv.ptr<std::string>();
    }
    if (is_bytes) {
      entry->ret_bytes.data = entry->ret_str.c_str();
      entry->ret_bytes.size = entry->ret_str.length();
      *ret_type_code = kBytes;
      ret_val->v_handle = &(entry->ret_bytes);
    } else {
      *ret_type_code = kStr;
      ret_val->v_str = entry->ret_str.c_str();
    }
  }
  API_END();
}
// Store the return value of a C callback into the TVMRetValue behind `ret`.
// Exactly one return value is supported (checked).
int TVMCFuncSetReturn(TVMRetValueHandle ret,
                      TVMValue* value,
                      int* type_code,
                      int num_ret) {
  API_BEGIN();
  CHECK_EQ(num_ret, 1);
  TVMRetValue* target = static_cast<TVMRetValue*>(ret);
  *target = TVMArgValue(value[0], type_code[0]);
  API_END();
}
// Wrap a C callback in a new PackedFunc and return its handle in *out.
//
// A non-zero return from the callback is turned into a dmlc::Error carrying
// the text of TVMGetLastError(). When a finalizer is supplied, the resource
// handle is owned by a shared_ptr captured in the closure, so `fin` runs
// when the last copy of the closure is destroyed.
int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
                           void* resource_handle,
                           TVMPackedCFuncFinalizer fin,
                           TVMFunctionHandle *out) {
  API_BEGIN();
  if (fin != nullptr) {
    // wrap it in a shared_ptr, with fin as deleter.
    std::shared_ptr<void> rpack(resource_handle, fin);
    *out = new PackedFunc(
        [func, rpack](TVMArgs args, TVMRetValue* rv) {
          const int status = func(
              (TVMValue*)args.values, (int*)args.type_codes,  // NOLINT(*)
              args.num_args, rv, rpack.get());
          if (status == 0) return;
          std::string err = "TVMCall CFunc Error:\n";
          err += TVMGetLastError();
          throw dmlc::Error(err);
        });
  } else {
    // No finalizer: the raw resource handle is captured as-is.
    *out = new PackedFunc(
        [func, resource_handle](TVMArgs args, TVMRetValue* rv) {
          const int status = func(
              (TVMValue*)args.values, (int*)args.type_codes,  // NOLINT(*)
              args.num_args, rv, resource_handle);
          if (status == 0) return;
          std::string err = "TVMCall CFunc Error:\n";
          err += TVMGetLastError();
          throw dmlc::Error(err);
        });
  }
  API_END();
}
int TVMStreamCreate(int device_type, int device_id, TVMStreamHandle* out) {
API_BEGIN();
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
*out = DeviceAPIManager::Get(ctx)->CreateStream(ctx);
API_END();
}
int TVMStreamFree(int device_type, int device_id, TVMStreamHandle stream) {
API_BEGIN();
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
DeviceAPIManager::Get(ctx)->FreeStream(ctx, stream);
API_END();
}
int TVMSetStream(int device_type, int device_id, TVMStreamHandle stream) {
API_BEGIN();
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
DeviceAPIManager::Get(ctx)->SetStream(ctx, stream);
API_END();
}
int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream) {
API_BEGIN();
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
DeviceAPIManager::Get(ctx)->StreamSync(ctx, stream);
API_END();
}
int TVMStreamStreamSynchronize(int device_type,
int device_id,
TVMStreamHandle src,
TVMStreamHandle dst) {
API_BEGIN();
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
DeviceAPIManager::Get(ctx)->SyncStreamFromTo(ctx, src, dst);
API_END();
}
// Convert a callback argument value in place into a return value: the
// argument is assigned to a TVMRetValue and then moved back into *value.
// The type code must be unchanged by the round trip (checked).
int TVMCbArgToReturn(TVMValue* value, int code) {
  API_BEGIN();
  tvm::runtime::TVMRetValue converted;
  converted = tvm::runtime::TVMArgValue(*value, code);
  int out_code;
  converted.MoveToCHost(value, &out_code);
  CHECK_EQ(out_code, code);
  API_END();
}
// Packed function that selects the current device.
// args: (device_type, device_id).
TVM_REGISTER_GLOBAL(tvm::runtime::symbol::tvm_set_device)
.set_body([](TVMArgs args, TVMRetValue *ret) {
    TVMContext device;
    device.device_type = static_cast<DLDeviceType>(args[0].operator int());
    device.device_id = args[1];
    DeviceAPI* api = DeviceAPIManager::Get(device);
    api->SetDevice(device);
  });
// Packed function that queries a device attribute.
// args: (device_type, device_id, attribute kind).
// For kExist, a device whose API is not available yields 0 instead of an
// error; every other kind requires the device API to be present.
TVM_REGISTER_GLOBAL("_GetDeviceAttr")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    TVMContext ctx;
    ctx.device_type = static_cast<DLDeviceType>(args[0].operator int());
    ctx.device_id = args[1];

    DeviceAttrKind kind = static_cast<DeviceAttrKind>(args[2].operator int());
    if (kind != kExist) {
      DeviceAPIManager::Get(ctx)->GetAttr(ctx, kind, ret);
      return;
    }
    DeviceAPI* api = DeviceAPIManager::Get(ctx.device_type, true);
    if (api == nullptr) {
      *ret = 0;
      return;
    }
    api->GetAttr(ctx, kind, ret);
  });
/*!
* Copyright (c) 2016 by Contributors
* \file cpu_device_api.cc
*/
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/device_api.h>
#include <cstdlib>
#include <cstring>
#include "workspace_pool.h"
namespace tvm {
namespace runtime {
// DeviceAPI implementation for the host CPU: aligned heap allocation,
// memcpy-based copies, and no-op device selection and stream sync.
class CPUDeviceAPI final : public DeviceAPI {
 public:
  void SetDevice(TVMContext ctx) final {}

  void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final {
    // Only existence is reported; *rv is left untouched for other kinds.
    if (kind != kExist) return;
    *rv = 1;
  }

  // Allocate nbytes with the requested alignment, using the platform's
  // aligned allocator. Throws std::bad_alloc on failure.
  void* AllocDataSpace(TVMContext ctx,
                       size_t nbytes,
                       size_t alignment,
                       TVMType type_hint) final {
    void* mem = nullptr;
#if _MSC_VER
    mem = _aligned_malloc(nbytes, alignment);
    if (mem == nullptr) throw std::bad_alloc();
#elif defined(_LIBCPP_SGX_CONFIG)
    mem = memalign(alignment, nbytes);
    if (mem == nullptr) throw std::bad_alloc();
#else
    if (posix_memalign(&mem, alignment, nbytes) != 0) throw std::bad_alloc();
#endif
    return mem;
  }

  // Release memory from AllocDataSpace with the matching deallocator.
  void FreeDataSpace(TVMContext ctx, void* ptr) final {
#if _MSC_VER
    _aligned_free(ptr);
#else
    free(ptr);
#endif
  }

  // Copy `size` bytes between two host buffers at the given byte offsets.
  // The contexts, type hint and stream are not used.
  void CopyDataFromTo(const void* from,
                      size_t from_offset,
                      void* to,
                      size_t to_offset,
                      size_t size,
                      TVMContext ctx_from,
                      TVMContext ctx_to,
                      TVMType type_hint,
                      TVMStreamHandle stream) final {
    char* dst = static_cast<char*>(to) + to_offset;
    const char* src = static_cast<const char*>(from) + from_offset;
    memcpy(dst, src, size);
  }

  void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
  }

  void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final;
  void FreeWorkspace(TVMContext ctx, void* data) final;

  // Shared singleton instance, created on first use.
  static const std::shared_ptr<CPUDeviceAPI>& Global() {
    static std::shared_ptr<CPUDeviceAPI> inst = std::make_shared<CPUDeviceAPI>();
    return inst;
  }
};
// Workspace pool bound to the CPU device type and the global CPUDeviceAPI.
struct CPUWorkspacePool : public WorkspacePool {
  CPUWorkspacePool()
      : WorkspacePool(kDLCPU, CPUDeviceAPI::Global()) {}
};
// Allocate workspace from the calling thread's CPU workspace pool.
// type_hint is not used.
void* CPUDeviceAPI::AllocWorkspace(TVMContext ctx,
                                   size_t size,
                                   TVMType type_hint) {
  CPUWorkspacePool* pool = dmlc::ThreadLocalStore<CPUWorkspacePool>::Get();
  return pool->AllocWorkspace(ctx, size);
}
// Return workspace to the calling thread's CPU workspace pool.
void CPUDeviceAPI::FreeWorkspace(TVMContext ctx, void* data) {
  CPUWorkspacePool* pool = dmlc::ThreadLocalStore<CPUWorkspacePool>::Get();
  pool->FreeWorkspace(ctx, data);
}
// Factory for the CPU device API: returns the global CPUDeviceAPI instance
// as an opaque pointer.
TVM_REGISTER_GLOBAL("device_api.cpu")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    DeviceAPI* cpu_api = CPUDeviceAPI::Global().get();
    *rv = static_cast<void*>(cpu_api);
  });
} // namespace runtime
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file dso_dll_module.cc
* \brief Module to load from dynamic shared library.
*/
#include <dgl/runtime/module.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/packed_func.h>
#include "module_util.h"
#if defined(_WIN32)
#include <windows.h>
#else
#include <dlfcn.h>
#endif
namespace tvm {
namespace runtime {
// Module backed by a dynamic shared library.
// This is the default module TVM uses for host-side AOT.
class DSOModuleNode final : public ModuleNode {
 public:
  ~DSOModuleNode() {
    if (lib_handle_) Unload();
  }

  const char* type_key() const final {
    return "dso";
  }

  // Resolve `name` to a packed function exported by the library.
  // For the module-main symbol, the library first supplies the name of the
  // real entry function, which is then looked up. Returns an empty
  // PackedFunc when the symbol is not found.
  PackedFunc GetFunction(
      const std::string& name,
      const std::shared_ptr<ModuleNode>& sptr_to_self) final {
    void* symbol = nullptr;
    if (name != runtime::symbol::tvm_module_main) {
      symbol = GetSymbol(name.c_str());
    } else {
      const char* entry_name = reinterpret_cast<const char*>(
          GetSymbol(runtime::symbol::tvm_module_main));
      CHECK(entry_name!= nullptr)
          << "Symbol " << runtime::symbol::tvm_module_main << " is not presented";
      symbol = GetSymbol(entry_name);
    }
    BackendPackedCFunc faddr = reinterpret_cast<BackendPackedCFunc>(symbol);
    if (faddr == nullptr) return PackedFunc();
    return WrapPackedFunc(faddr, sptr_to_self);
  }

  // Load the library `name` and wire it up to this module.
  void Init(const std::string& name) {
    Load(name);
    // If the library exports a module-context slot, point it at this node.
    void** ctx_slot = reinterpret_cast<void**>(
        GetSymbol(runtime::symbol::tvm_module_ctx));
    if (ctx_slot != nullptr) {
      *ctx_slot = this;
    }
    InitContextFunctions([this](const char* fname) {
        return GetSymbol(fname);
      });
    // Load the imported modules
    const char* dev_mblob = reinterpret_cast<const char*>(
        GetSymbol(runtime::symbol::tvm_dev_mblob));
    if (dev_mblob != nullptr) {
      ImportModuleBlob(dev_mblob, &imports_);
    }
  }

 private:
  // Platform dependent handling.
#if defined(_WIN32)
  // library handle
  HMODULE lib_handle_{nullptr};
  // Load the library; a failure to load fails the CHECK.
  void Load(const std::string& name) {
    // use wstring version that is needed by LLVM.
    std::wstring wide_name(name.begin(), name.end());
    lib_handle_ = LoadLibraryW(wide_name.c_str());
    CHECK(lib_handle_ != nullptr)
        << "Failed to load dynamic shared library " << name;
  }
  // Address of the exported symbol, or nullptr when it is absent.
  void* GetSymbol(const char* name) {
    return reinterpret_cast<void*>(
        GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*)
  }
  void Unload() {
    FreeLibrary(lib_handle_);
  }
#else
  // Library handle
  void* lib_handle_{nullptr};
  // Load the library; a failure to load fails the CHECK, with dlerror()
  // appended to the message.
  void Load(const std::string& name) {
    lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL);
    CHECK(lib_handle_ != nullptr)
        << "Failed to load dynamic shared library " << name
        << " " << dlerror();
  }
  // Address of the exported symbol, as returned by dlsym.
  void* GetSymbol(const char* name) {
    return dlsym(lib_handle_, name);
  }
  void Unload() {
    dlclose(lib_handle_);
  }
#endif
};
// Loader for shared-library modules: args[0] is the library path.
TVM_REGISTER_GLOBAL("module.loadfile_so")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    std::shared_ptr<DSOModuleNode> node = std::make_shared<DSOModuleNode>();
    node->Init(args[0]);
    *rv = runtime::Module(node);
  });
} // namespace runtime
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file file_util.cc
*/
#include <dmlc/json.h>
#include <dmlc/logging.h>
#include <dgl/runtime/serializer.h>
#include <fstream>
#include <vector>
#include "file_util.h"
namespace tvm {
namespace runtime {
// Write this function info as a JSON object with the keys "name",
// "arg_types" (as type strings) and "thread_axis_tags".
void FunctionInfo::Save(dmlc::JSONWriter* writer) const {
  std::vector<std::string> type_names;
  type_names.reserve(arg_types.size());
  for (const auto& arg_type : arg_types) {
    type_names.push_back(TVMType2String(arg_type));
  }
  writer->BeginObject();
  writer->WriteObjectKeyValue("name", name);
  writer->WriteObjectKeyValue("arg_types", type_names);
  writer->WriteObjectKeyValue("thread_axis_tags", thread_axis_tags);
  writer->EndObject();
}
// Read this function info from a JSON object with the keys "name",
// "arg_types" (type strings) and "thread_axis_tags".
void FunctionInfo::Load(dmlc::JSONReader* reader) {
  std::vector<std::string> type_names;
  dmlc::JSONObjectReadHelper helper;
  helper.DeclareField("name", &name);
  helper.DeclareField("arg_types", &type_names);
  helper.DeclareField("thread_axis_tags", &thread_axis_tags);
  helper.ReadAllFields(reader);

  // Convert the type strings back into TVMType values.
  arg_types.clear();
  arg_types.reserve(type_names.size());
  for (const auto& type_name : type_names) {
    arg_types.push_back(String2TVMType(type_name));
  }
}
// Binary serialization: name, arg_types and thread_axis_tags, in that order.
void FunctionInfo::Save(dmlc::Stream* writer) const {
  dmlc::Stream* out = writer;
  out->Write(name);
  out->Write(arg_types);
  out->Write(thread_axis_tags);
}
// Binary deserialization, reading the fields in the order Save writes them.
// Stops at the first field that fails to read and returns false.
bool FunctionInfo::Load(dmlc::Stream* reader) {
  return reader->Read(&name) &&
         reader->Read(&arg_types) &&
         reader->Read(&thread_axis_tags);
}
// Decide the format of a file.
// A non-empty `format` is returned as-is. Otherwise a name containing
// ".signed.so" is "sgx", and any other name yields the text after its last
// '.', or "" when it has no '.'.
std::string GetFileFormat(const std::string& file_name,
                          const std::string& format) {
  if (!format.empty()) return format;
  if (file_name.find(".signed.so") != std::string::npos) return "sgx";
  const size_t dot = file_name.find_last_of(".");
  if (dot == std::string::npos) return "";
  return file_name.substr(dot + 1);
}
// Pick the cache directory from the environment, in priority order:
// $TVM_CACHE_DIR, then $XDG_CACHE_HOME/tvm, then $HOME/.cache/tvm.
// Falls back to "." when none of the variables is set.
std::string GetCacheDir() {
  if (const char* explicit_dir = getenv("TVM_CACHE_DIR")) {
    return explicit_dir;
  }
  if (const char* xdg_cache = getenv("XDG_CACHE_HOME")) {
    return std::string(xdg_cache) + "/tvm";
  }
  if (const char* home = getenv("HOME")) {
    return std::string(home) + "/.cache/tvm";
  }
  return ".";
}
// Return the part of file_name after its last '/', or the whole name when
// it contains no '/'.
std::string GetFileBasename(const std::string& file_name) {
  const size_t slash = file_name.find_last_of("/");
  return slash == std::string::npos ? file_name
                                    : file_name.substr(slash + 1);
}
// Build the metadata path for a file: everything from the last '.' onwards
// is replaced by ".tvm_meta.json"; the suffix is appended when the name has
// no '.'.
std::string GetMetaFilePath(const std::string& file_name) {
  const size_t dot = file_name.find_last_of(".");
  const std::string stem =
      (dot == std::string::npos) ? file_name : file_name.substr(0, dot);
  return stem + ".tvm_meta.json";
}
// Read the whole of file_name, in binary mode, into *data.
// A file that cannot be opened fails the CHECK.
void LoadBinaryFromFile(const std::string& file_name,
                        std::string* data) {
  std::ifstream input(file_name, std::ios::in | std::ios::binary);
  CHECK(!input.fail()) << "Cannot open " << file_name;
  // The file size is the stream position at the end of the file.
  input.seekg(0, std::ios::end);
  const size_t num_bytes = static_cast<size_t>(input.tellg());
  input.seekg(0, std::ios::beg);
  data->resize(num_bytes);
  input.read(&(*data)[0], num_bytes);
}
// Write `data` to file_name in binary mode.
// A file that cannot be opened fails the CHECK.
void SaveBinaryToFile(
    const std::string& file_name,
    const std::string& data) {
  std::ofstream output(file_name, std::ios::out | std::ios::binary);
  CHECK(!output.fail()) << "Cannot open " << file_name;
  output.write(&data[0], data.length());
}
// Write the function-info map to file_name as a JSON object with the keys
// "tvm_version" (fixed at "0.1.0") and "func_info".
// A file that cannot be opened fails the CHECK.
void SaveMetaDataToFile(
    const std::string& file_name,
    const std::unordered_map<std::string, FunctionInfo>& fmap) {
  std::ofstream output(file_name.c_str());
  CHECK(!output.fail()) << "Cannot open file " << file_name;

  const std::string version = "0.1.0";
  dmlc::JSONWriter writer(&output);
  writer.BeginObject();
  writer.WriteObjectKeyValue("tvm_version", version);
  writer.WriteObjectKeyValue("func_info", fmap);
  writer.EndObject();
  output.close();
}
void LoadMetaDataFromFile(
const std::string& file_name,
std::unordered_map<std::string, FunctionInfo>* fmap) {
std::ifstream fs(file_name.c_str());
CHECK(!fs.fail()) << "Cannot open file " << file_name;
std::string version;
dmlc::JSONReader reader(&fs);
dmlc::JSONObjectReadHelper helper;
helper.DeclareField("tvm_version", &version);
helper.DeclareField("func_info", fmap);
helper.ReadAllFields(&reader);
fs.close();
}
} // namespace runtime
} // namespace tvm
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment