Unverified Commit a1d50f0f authored by Lingfan Yu's avatar Lingfan Yu Committed by GitHub
Browse files

[Refactor] Rename before release (#261)

* include/dgl/runtime

* include

* src/runtime

* src/graph

* src/scheduler

* src

* clean up CMakeLists

* further clean up in cmake

* install commands

* python/dgl/_ffi/_cython

* python/dgl/_ffi/_ctypes

* python/dgl/_ffi

* python/dgl

* some fix

* copy right
parent aabba9d4
...@@ -4,20 +4,20 @@ from cpython cimport Py_INCREF, Py_DECREF ...@@ -4,20 +4,20 @@ from cpython cimport Py_INCREF, Py_DECREF
from numbers import Number, Integral from numbers import Number, Integral
from ..base import string_types from ..base import string_types
from ..node_generic import convert_to_node, NodeGeneric from ..node_generic import convert_to_node, NodeGeneric
from ..runtime_ctypes import TVMType, TVMContext, TVMByteArray from ..runtime_ctypes import DGLType, DGLContext, DGLByteArray
cdef void tvm_callback_finalize(void* fhandle): cdef void dgl_callback_finalize(void* fhandle):
local_pyfunc = <object>(fhandle) local_pyfunc = <object>(fhandle)
Py_DECREF(local_pyfunc) Py_DECREF(local_pyfunc)
cdef int tvm_callback(TVMValue* args, cdef int dgl_callback(DGLValue* args,
int* type_codes, int* type_codes,
int num_args, int num_args,
TVMRetValueHandle ret, DGLRetValueHandle ret,
void* fhandle) with gil: void* fhandle) with gil:
cdef list pyargs cdef list pyargs
cdef TVMValue value cdef DGLValue value
cdef int tcode cdef int tcode
local_pyfunc = <object>(fhandle) local_pyfunc = <object>(fhandle)
pyargs = [] pyargs = []
...@@ -28,7 +28,7 @@ cdef int tvm_callback(TVMValue* args, ...@@ -28,7 +28,7 @@ cdef int tvm_callback(TVMValue* args,
tcode == kFuncHandle or tcode == kFuncHandle or
tcode == kModuleHandle or tcode == kModuleHandle or
tcode > kExtBegin): tcode > kExtBegin):
CALL(TVMCbArgToReturn(&value, tcode)) CALL(DGLCbArgToReturn(&value, tcode))
if tcode != kArrayHandle: if tcode != kArrayHandle:
pyargs.append(make_ret(value, tcode)) pyargs.append(make_ret(value, tcode))
...@@ -38,19 +38,19 @@ cdef int tvm_callback(TVMValue* args, ...@@ -38,19 +38,19 @@ cdef int tvm_callback(TVMValue* args,
rv = local_pyfunc(*pyargs) rv = local_pyfunc(*pyargs)
except Exception: except Exception:
msg = traceback.format_exc() msg = traceback.format_exc()
TVMAPISetLastError(c_str(msg)) DGLAPISetLastError(c_str(msg))
return -1 return -1
if rv is not None: if rv is not None:
if isinstance(rv, tuple): if isinstance(rv, tuple):
raise ValueError("PackedFunction can only support one return value") raise ValueError("PackedFunction can only support one return value")
temp_args = [] temp_args = []
make_arg(rv, &value, &tcode, temp_args) make_arg(rv, &value, &tcode, temp_args)
CALL(TVMCFuncSetReturn(ret, &value, &tcode, 1)) CALL(DGLCFuncSetReturn(ret, &value, &tcode, 1))
return 0 return 0
def convert_to_tvm_func(object pyfunc): def convert_to_dgl_func(object pyfunc):
"""Convert a python function to TVM function """Convert a python function to DGL function
Parameters Parameters
---------- ----------
...@@ -59,14 +59,14 @@ def convert_to_tvm_func(object pyfunc): ...@@ -59,14 +59,14 @@ def convert_to_tvm_func(object pyfunc):
Returns Returns
------- -------
tvmfunc: tvm.Function dglfunc: dgl.Function
The converted tvm function. The converted dgl function.
""" """
cdef TVMFunctionHandle chandle cdef DGLFunctionHandle chandle
Py_INCREF(pyfunc) Py_INCREF(pyfunc)
CALL(TVMFuncCreateFromCFunc(tvm_callback, CALL(DGLFuncCreateFromCFunc(dgl_callback,
<void*>(pyfunc), <void*>(pyfunc),
tvm_callback_finalize, dgl_callback_finalize,
&chandle)) &chandle))
ret = _CLASS_FUNCTION(None, False) ret = _CLASS_FUNCTION(None, False)
(<FunctionBase>ret).chandle = chandle (<FunctionBase>ret).chandle = chandle
...@@ -74,10 +74,10 @@ def convert_to_tvm_func(object pyfunc): ...@@ -74,10 +74,10 @@ def convert_to_tvm_func(object pyfunc):
cdef inline int make_arg(object arg, cdef inline int make_arg(object arg,
TVMValue* value, DGLValue* value,
int* tcode, int* tcode,
list temp_args) except -1: list temp_args) except -1:
"""Pack arguments into c args tvm call accept""" """Pack arguments into c args dgl call accept"""
cdef unsigned long long ptr cdef unsigned long long ptr
if isinstance(arg, NodeBase): if isinstance(arg, NodeBase):
value[0].v_handle = (<NodeBase>arg).chandle value[0].v_handle = (<NodeBase>arg).chandle
...@@ -86,10 +86,10 @@ cdef inline int make_arg(object arg, ...@@ -86,10 +86,10 @@ cdef inline int make_arg(object arg,
value[0].v_handle = (<NDArrayBase>arg).chandle value[0].v_handle = (<NDArrayBase>arg).chandle
tcode[0] = (kNDArrayContainer if tcode[0] = (kNDArrayContainer if
not (<NDArrayBase>arg).c_is_view else kArrayHandle) not (<NDArrayBase>arg).c_is_view else kArrayHandle)
elif isinstance(arg, _TVM_COMPATS): elif isinstance(arg, _DGL_COMPATS):
ptr = arg._tvm_handle ptr = arg._dgl_handle
value[0].v_handle = (<void*>ptr) value[0].v_handle = (<void*>ptr)
tcode[0] = arg.__class__._tvm_tcode tcode[0] = arg.__class__._dgl_tcode
elif isinstance(arg, (int, long)): elif isinstance(arg, (int, long)):
value[0].v_int64 = arg value[0].v_int64 = arg
tcode[0] = kInt tcode[0] = kInt
...@@ -107,17 +107,17 @@ cdef inline int make_arg(object arg, ...@@ -107,17 +107,17 @@ cdef inline int make_arg(object arg,
elif isinstance(arg, Number): elif isinstance(arg, Number):
value[0].v_float64 = arg value[0].v_float64 = arg
tcode[0] = kFloat tcode[0] = kFloat
elif isinstance(arg, TVMType): elif isinstance(arg, DGLType):
tstr = c_str(str(arg)) tstr = c_str(str(arg))
value[0].v_str = tstr value[0].v_str = tstr
tcode[0] = kStr tcode[0] = kStr
temp_args.append(tstr) temp_args.append(tstr)
elif isinstance(arg, TVMContext): elif isinstance(arg, DGLContext):
value[0].v_ctx = (<DLContext*>( value[0].v_ctx = (<DLContext*>(
<unsigned long long>ctypes.addressof(arg)))[0] <unsigned long long>ctypes.addressof(arg)))[0]
tcode[0] = kTVMContext tcode[0] = kDGLContext
elif isinstance(arg, bytearray): elif isinstance(arg, bytearray):
arr = TVMByteArray() arr = DGLByteArray()
arr.data = ctypes.cast( arr.data = ctypes.cast(
(ctypes.c_byte * len(arg)).from_buffer(arg), (ctypes.c_byte * len(arg)).from_buffer(arg),
ctypes.POINTER(ctypes.c_byte)) ctypes.POINTER(ctypes.c_byte))
...@@ -146,7 +146,7 @@ cdef inline int make_arg(object arg, ...@@ -146,7 +146,7 @@ cdef inline int make_arg(object arg,
value[0].v_handle = c_handle(arg) value[0].v_handle = c_handle(arg)
tcode[0] = kHandle tcode[0] = kHandle
elif callable(arg): elif callable(arg):
arg = convert_to_tvm_func(arg) arg = convert_to_dgl_func(arg)
value[0].v_handle = (<FunctionBase>arg).chandle value[0].v_handle = (<FunctionBase>arg).chandle
tcode[0] = kFuncHandle tcode[0] = kFuncHandle
temp_args.append(arg) temp_args.append(arg)
...@@ -156,7 +156,7 @@ cdef inline int make_arg(object arg, ...@@ -156,7 +156,7 @@ cdef inline int make_arg(object arg,
cdef inline bytearray make_ret_bytes(void* chandle): cdef inline bytearray make_ret_bytes(void* chandle):
handle = ctypes_handle(chandle) handle = ctypes_handle(chandle)
arr = ctypes.cast(handle, ctypes.POINTER(TVMByteArray))[0] arr = ctypes.cast(handle, ctypes.POINTER(DGLByteArray))[0]
size = arr.size size = arr.size
res = bytearray(size) res = bytearray(size)
rptr = (ctypes.c_byte * size).from_buffer(res) rptr = (ctypes.c_byte * size).from_buffer(res)
...@@ -164,7 +164,7 @@ cdef inline bytearray make_ret_bytes(void* chandle): ...@@ -164,7 +164,7 @@ cdef inline bytearray make_ret_bytes(void* chandle):
raise RuntimeError('memmove failed') raise RuntimeError('memmove failed')
return res return res
cdef inline object make_ret(TVMValue value, int tcode): cdef inline object make_ret(DGLValue value, int tcode):
"""convert result to return value.""" """convert result to return value."""
if tcode == kNodeHandle: if tcode == kNodeHandle:
return make_ret_node(value.v_handle) return make_ret_node(value.v_handle)
...@@ -182,16 +182,16 @@ cdef inline object make_ret(TVMValue value, int tcode): ...@@ -182,16 +182,16 @@ cdef inline object make_ret(TVMValue value, int tcode):
return make_ret_bytes(value.v_handle) return make_ret_bytes(value.v_handle)
elif tcode == kHandle: elif tcode == kHandle:
return ctypes_handle(value.v_handle) return ctypes_handle(value.v_handle)
elif tcode == kTVMContext: elif tcode == kDGLContext:
return TVMContext(value.v_ctx.device_type, value.v_ctx.device_id) return DGLContext(value.v_ctx.device_type, value.v_ctx.device_id)
elif tcode == kModuleHandle: elif tcode == kModuleHandle:
return _CLASS_MODULE(ctypes_handle(value.v_handle)) return _CLASS_MODULE(ctypes_handle(value.v_handle))
elif tcode == kFuncHandle: elif tcode == kFuncHandle:
fobj = _CLASS_FUNCTION(None, False) fobj = _CLASS_FUNCTION(None, False)
(<FunctionBase>fobj).chandle = value.v_handle (<FunctionBase>fobj).chandle = value.v_handle
return fobj return fobj
elif tcode in _TVM_EXT_RET: elif tcode in _DGL_EXT_RET:
return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle)) return _DGL_EXT_RET[tcode](ctypes_handle(value.v_handle))
raise ValueError("Unhandled type code %d" % tcode) raise ValueError("Unhandled type code %d" % tcode)
...@@ -199,21 +199,21 @@ cdef inline object make_ret(TVMValue value, int tcode): ...@@ -199,21 +199,21 @@ cdef inline object make_ret(TVMValue value, int tcode):
cdef inline int FuncCall3(void* chandle, cdef inline int FuncCall3(void* chandle,
tuple args, tuple args,
int nargs, int nargs,
TVMValue* ret_val, DGLValue* ret_val,
int* ret_tcode) except -1: int* ret_tcode) except -1:
cdef TVMValue[3] values cdef DGLValue[3] values
cdef int[3] tcodes cdef int[3] tcodes
nargs = len(args) nargs = len(args)
temp_args = [] temp_args = []
for i in range(nargs): for i in range(nargs):
make_arg(args[i], &values[i], &tcodes[i], temp_args) make_arg(args[i], &values[i], &tcodes[i], temp_args)
CALL(TVMFuncCall(chandle, &values[0], &tcodes[0], CALL(DGLFuncCall(chandle, &values[0], &tcodes[0],
nargs, ret_val, ret_tcode)) nargs, ret_val, ret_tcode))
return 0 return 0
cdef inline int FuncCall(void* chandle, cdef inline int FuncCall(void* chandle,
tuple args, tuple args,
TVMValue* ret_val, DGLValue* ret_val,
int* ret_tcode) except -1: int* ret_tcode) except -1:
cdef int nargs cdef int nargs
nargs = len(args) nargs = len(args)
...@@ -221,14 +221,14 @@ cdef inline int FuncCall(void* chandle, ...@@ -221,14 +221,14 @@ cdef inline int FuncCall(void* chandle,
FuncCall3(chandle, args, nargs, ret_val, ret_tcode) FuncCall3(chandle, args, nargs, ret_val, ret_tcode)
return 0 return 0
cdef vector[TVMValue] values cdef vector[DGLValue] values
cdef vector[int] tcodes cdef vector[int] tcodes
values.resize(max(nargs, 1)) values.resize(max(nargs, 1))
tcodes.resize(max(nargs, 1)) tcodes.resize(max(nargs, 1))
temp_args = [] temp_args = []
for i in range(nargs): for i in range(nargs):
make_arg(args[i], &values[i], &tcodes[i], temp_args) make_arg(args[i], &values[i], &tcodes[i], temp_args)
CALL(TVMFuncCall(chandle, &values[0], &tcodes[0], CALL(DGLFuncCall(chandle, &values[0], &tcodes[0],
nargs, ret_val, ret_tcode)) nargs, ret_val, ret_tcode))
return 0 return 0
...@@ -238,7 +238,7 @@ cdef inline int ConstructorCall(void* constructor_handle, ...@@ -238,7 +238,7 @@ cdef inline int ConstructorCall(void* constructor_handle,
tuple args, tuple args,
void** handle) except -1: void** handle) except -1:
"""Call contructor of a handle function""" """Call contructor of a handle function"""
cdef TVMValue ret_val cdef DGLValue ret_val
cdef int ret_tcode cdef int ret_tcode
FuncCall(constructor_handle, args, &ret_val, &ret_tcode) FuncCall(constructor_handle, args, &ret_val, &ret_tcode)
assert ret_tcode == type_code assert ret_tcode == type_code
...@@ -247,7 +247,7 @@ cdef inline int ConstructorCall(void* constructor_handle, ...@@ -247,7 +247,7 @@ cdef inline int ConstructorCall(void* constructor_handle,
cdef class FunctionBase: cdef class FunctionBase:
cdef TVMFunctionHandle chandle cdef DGLFunctionHandle chandle
cdef int is_global cdef int is_global
cdef inline _set_handle(self, handle): cdef inline _set_handle(self, handle):
...@@ -278,10 +278,10 @@ cdef class FunctionBase: ...@@ -278,10 +278,10 @@ cdef class FunctionBase:
def __dealloc__(self): def __dealloc__(self):
if self.is_global == 0: if self.is_global == 0:
CALL(TVMFuncFree(self.chandle)) CALL(DGLFuncFree(self.chandle))
def __call__(self, *args): def __call__(self, *args):
cdef TVMValue ret_val cdef DGLValue ret_val
cdef int ret_tcode cdef int ret_tcode
FuncCall(self.chandle, args, &ret_val, &ret_tcode) FuncCall(self.chandle, args, &ret_val, &ret_tcode)
return make_ret(ret_val, ret_tcode) return make_ret(ret_val, ret_tcode)
......
from ..runtime_ctypes import TVMArrayHandle from ..runtime_ctypes import DGLArrayHandle
cdef const char* _c_str_dltensor = "dltensor" cdef const char* _c_str_dltensor = "dltensor"
cdef const char* _c_str_used_dltensor = "used_dltensor" cdef const char* _c_str_used_dltensor = "used_dltensor"
...@@ -8,7 +8,7 @@ cdef void _c_dlpack_deleter(object pycaps): ...@@ -8,7 +8,7 @@ cdef void _c_dlpack_deleter(object pycaps):
cdef DLManagedTensor* dltensor cdef DLManagedTensor* dltensor
if pycapsule.PyCapsule_IsValid(pycaps, _c_str_dltensor): if pycapsule.PyCapsule_IsValid(pycaps, _c_str_dltensor):
dltensor = <DLManagedTensor*>pycapsule.PyCapsule_GetPointer(pycaps, _c_str_dltensor) dltensor = <DLManagedTensor*>pycapsule.PyCapsule_GetPointer(pycaps, _c_str_dltensor)
TVMDLManagedTensorCallDeleter(dltensor) DGLDLManagedTensorCallDeleter(dltensor)
def _from_dlpack(object dltensor): def _from_dlpack(object dltensor):
...@@ -16,7 +16,7 @@ def _from_dlpack(object dltensor): ...@@ -16,7 +16,7 @@ def _from_dlpack(object dltensor):
cdef DLTensorHandle chandle cdef DLTensorHandle chandle
if pycapsule.PyCapsule_IsValid(dltensor, _c_str_dltensor): if pycapsule.PyCapsule_IsValid(dltensor, _c_str_dltensor):
ptr = <DLManagedTensor*>pycapsule.PyCapsule_GetPointer(dltensor, _c_str_dltensor) ptr = <DLManagedTensor*>pycapsule.PyCapsule_GetPointer(dltensor, _c_str_dltensor)
CALL(TVMArrayFromDLPack(ptr, &chandle)) CALL(DGLArrayFromDLPack(ptr, &chandle))
# set name and destructor to be empty # set name and destructor to be empty
pycapsule.PyCapsule_SetDestructor(dltensor, NULL) pycapsule.PyCapsule_SetDestructor(dltensor, NULL)
pycapsule.PyCapsule_SetName(dltensor, _c_str_used_dltensor) pycapsule.PyCapsule_SetName(dltensor, _c_str_used_dltensor)
...@@ -36,7 +36,7 @@ cdef class NDArrayBase: ...@@ -36,7 +36,7 @@ cdef class NDArrayBase:
ptr = ctypes.cast(handle, ctypes.c_void_p).value ptr = ctypes.cast(handle, ctypes.c_void_p).value
self.chandle = <DLTensor*>(ptr) self.chandle = <DLTensor*>(ptr)
property _tvm_handle: property _dgl_handle:
def __get__(self): def __get__(self):
return <unsigned long long>self.chandle return <unsigned long long>self.chandle
...@@ -46,7 +46,7 @@ cdef class NDArrayBase: ...@@ -46,7 +46,7 @@ cdef class NDArrayBase:
return None return None
else: else:
return ctypes.cast( return ctypes.cast(
<unsigned long long>self.chandle, TVMArrayHandle) <unsigned long long>self.chandle, DGLArrayHandle)
def __set__(self, value): def __set__(self, value):
self._set_handle(value) self._set_handle(value)
...@@ -57,7 +57,7 @@ cdef class NDArrayBase: ...@@ -57,7 +57,7 @@ cdef class NDArrayBase:
def __dealloc__(self): def __dealloc__(self):
if self.c_is_view == 0: if self.c_is_view == 0:
CALL(TVMArrayFree(self.chandle)) CALL(DGLArrayFree(self.chandle))
def to_dlpack(self): def to_dlpack(self):
"""Produce an array from a DLPack Tensor without copying memory """Produce an array from a DLPack Tensor without copying memory
...@@ -69,7 +69,7 @@ cdef class NDArrayBase: ...@@ -69,7 +69,7 @@ cdef class NDArrayBase:
cdef DLManagedTensor* dltensor cdef DLManagedTensor* dltensor
if self.c_is_view != 0: if self.c_is_view != 0:
raise ValueError("to_dlpack do not work with memory views") raise ValueError("to_dlpack do not work with memory views")
CALL(TVMArrayToDLPack(self.chandle, &dltensor)) CALL(DGLArrayToDLPack(self.chandle, &dltensor))
return pycapsule.PyCapsule_New(dltensor, _c_str_dltensor, _c_dlpack_deleter) return pycapsule.PyCapsule_New(dltensor, _c_str_dltensor, _c_dlpack_deleter)
...@@ -79,15 +79,15 @@ cdef c_make_array(void* chandle, is_view): ...@@ -79,15 +79,15 @@ cdef c_make_array(void* chandle, is_view):
return ret return ret
cdef _TVM_COMPATS = () cdef _DGL_COMPATS = ()
cdef _TVM_EXT_RET = {} cdef _DGL_EXT_RET = {}
def _reg_extension(cls, fcreate): def _reg_extension(cls, fcreate):
global _TVM_COMPATS global _DGL_COMPATS
_TVM_COMPATS += (cls,) _DGL_COMPATS += (cls,)
if fcreate: if fcreate:
_TVM_EXT_RET[cls._tvm_tcode] = fcreate _DGL_EXT_RET[cls._dgl_tcode] = fcreate
def _make_array(handle, is_view): def _make_array(handle, is_view):
......
...@@ -18,7 +18,7 @@ cdef inline object make_ret_node(void* chandle): ...@@ -18,7 +18,7 @@ cdef inline object make_ret_node(void* chandle):
cdef list node_type cdef list node_type
cdef object cls cdef object cls
node_type = NODE_TYPE node_type = NODE_TYPE
CALL(TVMNodeGetTypeIndex(chandle, &tindex)) CALL(DGLNodeGetTypeIndex(chandle, &tindex))
if tindex < len(node_type): if tindex < len(node_type):
cls = node_type[tindex] cls = node_type[tindex]
if cls is not None: if cls is not None:
...@@ -53,12 +53,12 @@ cdef class NodeBase: ...@@ -53,12 +53,12 @@ cdef class NodeBase:
self._set_handle(value) self._set_handle(value)
def __dealloc__(self): def __dealloc__(self):
CALL(TVMNodeFree(self.chandle)) CALL(DGLNodeFree(self.chandle))
def __getattr__(self, name): def __getattr__(self, name):
cdef TVMValue ret_val cdef DGLValue ret_val
cdef int ret_type_code, ret_succ cdef int ret_type_code, ret_succ
CALL(TVMNodeGetAttr(self.chandle, c_str(name), CALL(DGLNodeGetAttr(self.chandle, c_str(name),
&ret_val, &ret_type_code, &ret_succ)) &ret_val, &ret_type_code, &ret_succ))
if ret_succ == 0: if ret_succ == 0:
raise AttributeError( raise AttributeError(
......
...@@ -34,7 +34,7 @@ def _load_lib(): ...@@ -34,7 +34,7 @@ def _load_lib():
lib_path = libinfo.find_lib_path() lib_path = libinfo.find_lib_path()
lib = ctypes.CDLL(lib_path[0], ctypes.RTLD_GLOBAL) lib = ctypes.CDLL(lib_path[0], ctypes.RTLD_GLOBAL)
# DMatrix functions # DMatrix functions
lib.TVMGetLastError.restype = ctypes.c_char_p lib.DGLGetLastError.restype = ctypes.c_char_p
return lib, os.path.basename(lib_path[0]) return lib, os.path.basename(lib_path[0])
# version number # version number
...@@ -60,7 +60,7 @@ def check_call(ret): ...@@ -60,7 +60,7 @@ def check_call(ret):
return value from API calls return value from API calls
""" """
if ret != 0: if ret != 0:
raise DGLError(py_str(_LIB.TVMGetLastError())) raise DGLError(py_str(_LIB.DGLGetLastError()))
def c_str(string): def c_str(string):
......
...@@ -15,31 +15,31 @@ try: ...@@ -15,31 +15,31 @@ try:
if sys.version_info >= (3, 0): if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_function, _set_class_module from ._cy3.core import _set_class_function, _set_class_module
from ._cy3.core import FunctionBase as _FunctionBase from ._cy3.core import FunctionBase as _FunctionBase
from ._cy3.core import convert_to_tvm_func from ._cy3.core import convert_to_dgl_func
else: else:
from ._cy2.core import _set_class_function, _set_class_module from ._cy2.core import _set_class_function, _set_class_module
from ._cy2.core import FunctionBase as _FunctionBase from ._cy2.core import FunctionBase as _FunctionBase
from ._cy2.core import convert_to_tvm_func from ._cy2.core import convert_to_dgl_func
except IMPORT_EXCEPT: except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position # pylint: disable=wrong-import-position
from ._ctypes.function import _set_class_function, _set_class_module from ._ctypes.function import _set_class_function, _set_class_module
from ._ctypes.function import FunctionBase as _FunctionBase from ._ctypes.function import FunctionBase as _FunctionBase
from ._ctypes.function import convert_to_tvm_func from ._ctypes.function import convert_to_dgl_func
FunctionHandle = ctypes.c_void_p FunctionHandle = ctypes.c_void_p
class Function(_FunctionBase): class Function(_FunctionBase):
"""The PackedFunc object. """The PackedFunc object.
Function plays an key role to bridge front and backend in TVM. Function plays an key role to bridge front and backend in DGL.
Function provide a type-erased interface, you can call function with positional arguments. Function provide a type-erased interface, you can call function with positional arguments.
The compiled module returns Function. The compiled module returns Function.
TVM backend also registers and exposes its API as Functions. DGL backend also registers and exposes its API as Functions.
For example, the developer function exposed in tvm.ir_pass are actually For example, the developer function exposed in dgl.ir_pass are actually
C++ functions that are registered as PackedFunc C++ functions that are registered as PackedFunc
The following are list of common usage scenario of tvm.Function. The following are list of common usage scenario of dgl.Function.
- Automatic exposure of C++ API into python - Automatic exposure of C++ API into python
- To call PackedFunc from python side - To call PackedFunc from python side
...@@ -48,8 +48,8 @@ class Function(_FunctionBase): ...@@ -48,8 +48,8 @@ class Function(_FunctionBase):
See Also See Also
-------- --------
tvm.register_func: How to register global function. dgl.register_func: How to register global function.
tvm.get_global_func: How to get global function. dgl.get_global_func: How to get global function.
""" """
pass pass
...@@ -61,10 +61,10 @@ class ModuleBase(object): ...@@ -61,10 +61,10 @@ class ModuleBase(object):
def __init__(self, handle): def __init__(self, handle):
self.handle = handle self.handle = handle
self._entry = None self._entry = None
self.entry_name = "__tvm_main__" self.entry_name = "__dgl_main__"
def __del__(self): def __del__(self):
check_call(_LIB.TVMModFree(self.handle)) check_call(_LIB.DGLModFree(self.handle))
@property @property
def entry_func(self): def entry_func(self):
...@@ -97,7 +97,7 @@ class ModuleBase(object): ...@@ -97,7 +97,7 @@ class ModuleBase(object):
The result function. The result function.
""" """
ret_handle = FunctionHandle() ret_handle = FunctionHandle()
check_call(_LIB.TVMModGetFunction( check_call(_LIB.DGLModGetFunction(
self.handle, c_str(name), self.handle, c_str(name),
ctypes.c_int(query_imports), ctypes.c_int(query_imports),
ctypes.byref(ret_handle))) ctypes.byref(ret_handle)))
...@@ -114,7 +114,7 @@ class ModuleBase(object): ...@@ -114,7 +114,7 @@ class ModuleBase(object):
module : Module module : Module
The other module. The other module.
""" """
check_call(_LIB.TVMModImport(self.handle, module.handle)) check_call(_LIB.DGLModImport(self.handle, module.handle))
def __getitem__(self, name): def __getitem__(self, name):
if not isinstance(name, string_types): if not isinstance(name, string_types):
...@@ -152,18 +152,18 @@ def register_func(func_name, f=None, override=False): ...@@ -152,18 +152,18 @@ def register_func(func_name, f=None, override=False):
The following code registers my_packed_func as global function. The following code registers my_packed_func as global function.
Note that we simply get it back from global function table to invoke Note that we simply get it back from global function table to invoke
it from python side. However, we can also invoke the same function it from python side. However, we can also invoke the same function
from C++ backend, or in the compiled TVM code. from C++ backend, or in the compiled DGL code.
.. code-block:: python .. code-block:: python
targs = (10, 10.0, "hello") targs = (10, 10.0, "hello")
@tvm.register_func @dgl.register_func
def my_packed_func(*args): def my_packed_func(*args):
assert(tuple(args) == targs) assert(tuple(args) == targs)
return 10 return 10
# Get it out from global function table # Get it out from global function table
f = tvm.get_global_func("my_packed_func") f = dgl.get_global_func("my_packed_func")
assert isinstance(f, tvm.nd.Function) assert isinstance(f, dgl.nd.Function)
y = f(*targs) y = f(*targs)
assert y == 10 assert y == 10
""" """
...@@ -178,8 +178,8 @@ def register_func(func_name, f=None, override=False): ...@@ -178,8 +178,8 @@ def register_func(func_name, f=None, override=False):
def register(myf): def register(myf):
"""internal register function""" """internal register function"""
if not isinstance(myf, Function): if not isinstance(myf, Function):
myf = convert_to_tvm_func(myf) myf = convert_to_dgl_func(myf)
check_call(_LIB.TVMFuncRegisterGlobal( check_call(_LIB.DGLFuncRegisterGlobal(
c_str(func_name), myf.handle, ioverride)) c_str(func_name), myf.handle, ioverride))
return myf return myf
if f: if f:
...@@ -200,11 +200,11 @@ def get_global_func(name, allow_missing=False): ...@@ -200,11 +200,11 @@ def get_global_func(name, allow_missing=False):
Returns Returns
------- -------
func : tvm.Function func : dgl.Function
The function to be returned, None if function is missing. The function to be returned, None if function is missing.
""" """
handle = FunctionHandle() handle = FunctionHandle()
check_call(_LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(handle))) check_call(_LIB.DGLFuncGetGlobal(c_str(name), ctypes.byref(handle)))
if handle.value: if handle.value:
return Function(handle, False) return Function(handle, False)
else: else:
...@@ -226,7 +226,7 @@ def list_global_func_names(): ...@@ -226,7 +226,7 @@ def list_global_func_names():
plist = ctypes.POINTER(ctypes.c_char_p)() plist = ctypes.POINTER(ctypes.c_char_p)()
size = ctypes.c_uint() size = ctypes.c_uint()
check_call(_LIB.TVMFuncListGlobalNames(ctypes.byref(size), check_call(_LIB.DGLFuncListGlobalNames(ctypes.byref(size),
ctypes.byref(plist))) ctypes.byref(plist)))
fnames = [] fnames = []
for i in range(size.value): for i in range(size.value):
...@@ -241,7 +241,7 @@ def extract_ext_funcs(finit): ...@@ -241,7 +241,7 @@ def extract_ext_funcs(finit):
Parameters Parameters
---------- ----------
finit : ctypes function finit : ctypes function
a ctypes that takes signature of TVMExtensionDeclarer a ctypes that takes signature of DGLExtensionDeclarer
Returns Returns
------- -------
...@@ -251,7 +251,7 @@ def extract_ext_funcs(finit): ...@@ -251,7 +251,7 @@ def extract_ext_funcs(finit):
fdict = {} fdict = {}
def _list(name, func): def _list(name, func):
fdict[name] = func fdict[name] = func
myf = convert_to_tvm_func(_list) myf = convert_to_dgl_func(_list)
ret = finit(myf.handle) ret = finit(myf.handle)
_ = myf _ = myf
if ret != 0: if ret != 0:
......
...@@ -19,10 +19,10 @@ def find_lib_path(name=None, search_path=None, optional=False): ...@@ -19,10 +19,10 @@ def find_lib_path(name=None, search_path=None, optional=False):
""" """
# See https://github.com/dmlc/tvm/issues/281 for some background. # See https://github.com/dmlc/tvm/issues/281 for some background.
# NB: This will either be the source directory (if TVM is run # NB: This will either be the source directory (if DGL is run
# inplace) or the install directory (if TVM is installed). # inplace) or the install directory (if DGL is installed).
# An installed TVM's curr_path will look something like: # An installed DGL's curr_path will look something like:
# $PREFIX/lib/python3.6/site-packages/tvm/_ffi # $PREFIX/lib/python3.6/site-packages/dgl/_ffi
ffi_dir = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) ffi_dir = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
source_dir = os.path.join(ffi_dir, "..", "..", "..") source_dir = os.path.join(ffi_dir, "..", "..", "..")
install_lib_dir = os.path.join(ffi_dir, "..", "..", "..", "..") install_lib_dir = os.path.join(ffi_dir, "..", "..", "..", "..")
...@@ -71,7 +71,7 @@ def find_lib_path(name=None, search_path=None, optional=False): ...@@ -71,7 +71,7 @@ def find_lib_path(name=None, search_path=None, optional=False):
# try to find lib_dll_path # try to find lib_dll_path
lib_found = [p for p in lib_dll_path if os.path.exists(p) and os.path.isfile(p)] lib_found = [p for p in lib_dll_path if os.path.exists(p) and os.path.isfile(p)]
if not lib_found: if not lib_found:
message = ('Cannot find the files.\n' + message = ('Cannot find the files.\n' +
'List of candidates:\n' + 'List of candidates:\n' +
...@@ -86,5 +86,5 @@ def find_lib_path(name=None, search_path=None, optional=False): ...@@ -86,5 +86,5 @@ def find_lib_path(name=None, search_path=None, optional=False):
# current version # current version
# We use the version of the incoming release for code # We use the version of the incoming release for code
# that is under development. # that is under development.
# The following line is set by tvm/python/update_version.py # The following line is set by dgl/python/update_version.py
__version__ = "0.0.1" __version__ = "0.0.1"
...@@ -6,8 +6,8 @@ import sys ...@@ -6,8 +6,8 @@ import sys
import ctypes import ctypes
import numpy as np import numpy as np
from .base import _LIB, check_call, c_array, string_types, _FFI_MODE, c_str from .base import _LIB, check_call, c_array, string_types, _FFI_MODE, c_str
from .runtime_ctypes import TVMType, TVMContext, TVMArray, TVMArrayHandle from .runtime_ctypes import DGLType, DGLContext, DGLArray, DGLArrayHandle
from .runtime_ctypes import TypeCode, tvm_shape_index_t from .runtime_ctypes import TypeCode, dgl_shape_index_t
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
...@@ -28,7 +28,7 @@ except IMPORT_EXCEPT: ...@@ -28,7 +28,7 @@ except IMPORT_EXCEPT:
from ._ctypes.ndarray import NDArrayBase as _NDArrayBase from ._ctypes.ndarray import NDArrayBase as _NDArrayBase
def context(dev_type, dev_id=0): def context(dev_type, dev_id=0):
"""Construct a TVM context with given device type and id. """Construct a DGL context with given device type and id.
Parameters Parameters
---------- ----------
...@@ -40,7 +40,7 @@ def context(dev_type, dev_id=0): ...@@ -40,7 +40,7 @@ def context(dev_type, dev_id=0):
Returns Returns
------- -------
ctx: TVMContext ctx: DGLContext
The corresponding context. The corresponding context.
Examples Examples
...@@ -50,29 +50,29 @@ def context(dev_type, dev_id=0): ...@@ -50,29 +50,29 @@ def context(dev_type, dev_id=0):
.. code-block:: python .. code-block:: python
assert tvm.context("cpu", 1) == tvm.cpu(1) assert dgl.context("cpu", 1) == dgl.cpu(1)
assert tvm.context("gpu", 0) == tvm.gpu(0) assert dgl.context("gpu", 0) == dgl.gpu(0)
assert tvm.context("cuda", 0) == tvm.gpu(0) assert dgl.context("cuda", 0) == dgl.gpu(0)
""" """
if isinstance(dev_type, string_types): if isinstance(dev_type, string_types):
dev_type = dev_type.split()[0] dev_type = dev_type.split()[0]
if dev_type not in TVMContext.STR2MASK: if dev_type not in DGLContext.STR2MASK:
raise ValueError("Unknown device type %s" % dev_type) raise ValueError("Unknown device type %s" % dev_type)
dev_type = TVMContext.STR2MASK[dev_type] dev_type = DGLContext.STR2MASK[dev_type]
return TVMContext(dev_type, dev_id) return DGLContext(dev_type, dev_id)
def numpyasarray(np_data): def numpyasarray(np_data):
"""Return a TVMArray representation of a numpy array. """Return a DGLArray representation of a numpy array.
""" """
data = np_data data = np_data
assert data.flags['C_CONTIGUOUS'] assert data.flags['C_CONTIGUOUS']
arr = TVMArray() arr = DGLArray()
shape = c_array(tvm_shape_index_t, data.shape) shape = c_array(dgl_shape_index_t, data.shape)
arr.data = data.ctypes.data_as(ctypes.c_void_p) arr.data = data.ctypes.data_as(ctypes.c_void_p)
arr.shape = shape arr.shape = shape
arr.strides = None arr.strides = None
arr.dtype = TVMType(np.dtype(data.dtype).name) arr.dtype = DGLType(np.dtype(data.dtype).name)
arr.ndim = data.ndim arr.ndim = data.ndim
# CPU device # CPU device
arr.ctx = context(1, 0) arr.ctx = context(1, 0)
...@@ -90,19 +90,19 @@ def empty(shape, dtype="float32", ctx=context(1, 0)): ...@@ -90,19 +90,19 @@ def empty(shape, dtype="float32", ctx=context(1, 0)):
dtype : type or str dtype : type or str
The data type of the array. The data type of the array.
ctx : TVMContext ctx : DGLContext
The context of the array The context of the array
Returns Returns
------- -------
arr : tvm.nd.NDArray arr : dgl.nd.NDArray
The array tvm supported. The array dgl supported.
""" """
shape = c_array(tvm_shape_index_t, shape) shape = c_array(dgl_shape_index_t, shape)
ndim = ctypes.c_int(len(shape)) ndim = ctypes.c_int(len(shape))
handle = TVMArrayHandle() handle = DGLArrayHandle()
dtype = TVMType(dtype) dtype = DGLType(dtype)
check_call(_LIB.TVMArrayAlloc( check_call(_LIB.DGLArrayAlloc(
shape, ndim, shape, ndim,
ctypes.c_int(dtype.type_code), ctypes.c_int(dtype.type_code),
ctypes.c_int(dtype.bits), ctypes.c_int(dtype.bits),
...@@ -126,7 +126,7 @@ def from_dlpack(dltensor): ...@@ -126,7 +126,7 @@ def from_dlpack(dltensor):
Returns Returns
------- -------
arr: tvm.nd.NDArray arr: dgl.nd.NDArray
The array view of the tensor data. The array view of the tensor data.
""" """
return _from_dlpack(dltensor) return _from_dlpack(dltensor)
...@@ -217,7 +217,7 @@ class NDArrayBase(_NDArrayBase): ...@@ -217,7 +217,7 @@ class NDArrayBase(_NDArrayBase):
except: except:
raise TypeError('array must be an array_like data,' + raise TypeError('array must be an array_like data,' +
'type %s is not supported' % str(type(source_array))) 'type %s is not supported' % str(type(source_array)))
t = TVMType(self.dtype) t = DGLType(self.dtype)
shape, dtype = self.shape, self.dtype shape, dtype = self.shape, self.dtype
if t.lanes > 1: if t.lanes > 1:
shape = shape + (t.lanes,) shape = shape + (t.lanes,)
...@@ -231,11 +231,11 @@ class NDArrayBase(_NDArrayBase): ...@@ -231,11 +231,11 @@ class NDArrayBase(_NDArrayBase):
assert source_array.flags['C_CONTIGUOUS'] assert source_array.flags['C_CONTIGUOUS']
data = source_array.ctypes.data_as(ctypes.c_void_p) data = source_array.ctypes.data_as(ctypes.c_void_p)
nbytes = ctypes.c_size_t(source_array.size * source_array.dtype.itemsize) nbytes = ctypes.c_size_t(source_array.size * source_array.dtype.itemsize)
check_call(_LIB.TVMArrayCopyFromBytes(self.handle, data, nbytes)) check_call(_LIB.DGLArrayCopyFromBytes(self.handle, data, nbytes))
return self return self
def __repr__(self): def __repr__(self):
res = "<tvm.NDArray shape={0}, {1}>\n".format(self.shape, self.context) res = "<dgl.NDArray shape={0}, {1}>\n".format(self.shape, self.context)
res += self.asnumpy().__repr__() res += self.asnumpy().__repr__()
return res return res
...@@ -250,7 +250,7 @@ class NDArrayBase(_NDArrayBase): ...@@ -250,7 +250,7 @@ class NDArrayBase(_NDArrayBase):
np_arr : numpy.ndarray np_arr : numpy.ndarray
The corresponding numpy array. The corresponding numpy array.
""" """
t = TVMType(self.dtype) t = DGLType(self.dtype)
shape, dtype = self.shape, self.dtype shape, dtype = self.shape, self.dtype
if t.lanes > 1: if t.lanes > 1:
shape = shape + (t.lanes,) shape = shape + (t.lanes,)
...@@ -260,7 +260,7 @@ class NDArrayBase(_NDArrayBase): ...@@ -260,7 +260,7 @@ class NDArrayBase(_NDArrayBase):
assert np_arr.flags['C_CONTIGUOUS'] assert np_arr.flags['C_CONTIGUOUS']
data = np_arr.ctypes.data_as(ctypes.c_void_p) data = np_arr.ctypes.data_as(ctypes.c_void_p)
nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize) nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize)
check_call(_LIB.TVMArrayCopyToBytes(self.handle, data, nbytes)) check_call(_LIB.DGLArrayCopyToBytes(self.handle, data, nbytes))
return np_arr return np_arr
def copyto(self, target): def copyto(self, target):
...@@ -271,10 +271,10 @@ class NDArrayBase(_NDArrayBase): ...@@ -271,10 +271,10 @@ class NDArrayBase(_NDArrayBase):
target : NDArray target : NDArray
The target array to be copied, must have same shape as this array. The target array to be copied, must have same shape as this array.
""" """
if isinstance(target, TVMContext): if isinstance(target, DGLContext):
target = empty(self.shape, self.dtype, target) target = empty(self.shape, self.dtype, target)
if isinstance(target, NDArrayBase): if isinstance(target, NDArrayBase):
check_call(_LIB.TVMArrayCopyFromTo( check_call(_LIB.DGLArrayCopyFromTo(
self.handle, target.handle, None)) self.handle, target.handle, None))
else: else:
raise ValueError("Unsupported target type %s" % str(type(target))) raise ValueError("Unsupported target type %s" % str(type(target)))
...@@ -292,13 +292,13 @@ def free_extension_handle(handle, type_code): ...@@ -292,13 +292,13 @@ def free_extension_handle(handle, type_code):
type_code : int type_code : int
The tyoe code The tyoe code
""" """
check_call(_LIB.TVMExtTypeFree(handle, ctypes.c_int(type_code))) check_call(_LIB.DGLExtTypeFree(handle, ctypes.c_int(type_code)))
def register_extension(cls, fcreate=None): def register_extension(cls, fcreate=None):
"""Register a extension class to TVM. """Register a extension class to DGL.
After the class is registered, the class will be able After the class is registered, the class will be able
to directly pass as Function argument generated by TVM. to directly pass as Function argument generated by DGL.
Parameters Parameters
---------- ----------
...@@ -307,10 +307,10 @@ def register_extension(cls, fcreate=None): ...@@ -307,10 +307,10 @@ def register_extension(cls, fcreate=None):
Note Note
---- ----
The registered class is requires one property: _tvm_handle and a class attribute _tvm_tcode. The registered class is requires one property: _dgl_handle and a class attribute _dgl_tcode.
- ```_tvm_handle``` returns integer represents the address of the handle. - ```_dgl_handle``` returns integer represents the address of the handle.
- ```_tvm_tcode``` gives integer represents type code of the class. - ```_dgl_tcode``` gives integer represents type code of the class.
Returns Returns
------- -------
...@@ -327,18 +327,18 @@ def register_extension(cls, fcreate=None): ...@@ -327,18 +327,18 @@ def register_extension(cls, fcreate=None):
.. code-block:: python .. code-block:: python
@tvm.register_extension @dgl.register_extension
class MyTensor(object): class MyTensor(object):
_tvm_tcode = tvm.TypeCode.ARRAY_HANDLE _dgl_tcode = dgl.TypeCode.ARRAY_HANDLE
def __init__(self): def __init__(self):
self.handle = _LIB.NewDLTensor() self.handle = _LIB.NewDLTensor()
@property @property
def _tvm_handle(self): def _dgl_handle(self):
return self.handle.value return self.handle.value
""" """
if fcreate and cls._tvm_tcode < TypeCode.EXT_BEGIN: if fcreate and cls._dgl_tcode < TypeCode.EXT_BEGIN:
raise ValueError("Cannot register create when extension tcode is same as buildin") raise ValueError("Cannot register create when extension tcode is same as buildin")
_reg_extension(cls, fcreate) _reg_extension(cls, fcreate)
return cls return cls
...@@ -8,7 +8,7 @@ import numpy as np ...@@ -8,7 +8,7 @@ import numpy as np
from .base import _LIB, check_call from .base import _LIB, check_call
from .. import _api_internal from .. import _api_internal
tvm_shape_index_t = ctypes.c_int64 dgl_shape_index_t = ctypes.c_int64
class TypeCode(object): class TypeCode(object):
"""Type code used in API calls""" """Type code used in API calls"""
...@@ -17,8 +17,8 @@ class TypeCode(object): ...@@ -17,8 +17,8 @@ class TypeCode(object):
FLOAT = 2 FLOAT = 2
HANDLE = 3 HANDLE = 3
NULL = 4 NULL = 4
TVM_TYPE = 5 DGL_TYPE = 5
TVM_CONTEXT = 6 DGL_CONTEXT = 6
ARRAY_HANDLE = 7 ARRAY_HANDLE = 7
NODE_HANDLE = 8 NODE_HANDLE = 8
MODULE_HANDLE = 9 MODULE_HANDLE = 9
...@@ -28,13 +28,13 @@ class TypeCode(object): ...@@ -28,13 +28,13 @@ class TypeCode(object):
NDARRAY_CONTAINER = 13 NDARRAY_CONTAINER = 13
EXT_BEGIN = 15 EXT_BEGIN = 15
class TVMByteArray(ctypes.Structure): class DGLByteArray(ctypes.Structure):
"""Temp data structure for byte array.""" """Temp data structure for byte array."""
_fields_ = [("data", ctypes.POINTER(ctypes.c_byte)), _fields_ = [("data", ctypes.POINTER(ctypes.c_byte)),
("size", ctypes.c_size_t)] ("size", ctypes.c_size_t)]
class TVMType(ctypes.Structure): class DGLType(ctypes.Structure):
"""TVM datatype structure""" """DGL datatype structure"""
_fields_ = [("type_code", ctypes.c_uint8), _fields_ = [("type_code", ctypes.c_uint8),
("bits", ctypes.c_uint8), ("bits", ctypes.c_uint8),
("lanes", ctypes.c_uint16)] ("lanes", ctypes.c_uint16)]
...@@ -50,7 +50,7 @@ class TVMType(ctypes.Structure): ...@@ -50,7 +50,7 @@ class TVMType(ctypes.Structure):
if type_str in cls._cache: if type_str in cls._cache:
return cls._cache[type_str] return cls._cache[type_str]
inst = super(TVMType, cls).__new__(TVMType) inst = super(DGLType, cls).__new__(DGLType)
if isinstance(type_str, np.dtype): if isinstance(type_str, np.dtype):
type_str = str(type_str) type_str = str(type_str)
...@@ -84,7 +84,7 @@ class TVMType(ctypes.Structure): ...@@ -84,7 +84,7 @@ class TVMType(ctypes.Structure):
pass pass
def __repr__(self): def __repr__(self):
x = "%s%d" % (TVMType.CODE2STR[self.type_code], self.bits) x = "%s%d" % (DGLType.CODE2STR[self.type_code], self.bits)
if self.lanes != 1: if self.lanes != 1:
x += "x%d" % self.lanes x += "x%d" % self.lanes
return x return x
...@@ -99,8 +99,8 @@ class TVMType(ctypes.Structure): ...@@ -99,8 +99,8 @@ class TVMType(ctypes.Structure):
RPC_SESS_MASK = 128 RPC_SESS_MASK = 128
class TVMContext(ctypes.Structure): class DGLContext(ctypes.Structure):
"""TVM context strucure.""" """DGL context strucure."""
_fields_ = [("device_type", ctypes.c_int), _fields_ = [("device_type", ctypes.c_int),
("device_id", ctypes.c_int)] ("device_id", ctypes.c_int)]
MASK2STR = { MASK2STR = {
...@@ -141,7 +141,7 @@ class TVMContext(ctypes.Structure): ...@@ -141,7 +141,7 @@ class TVMContext(ctypes.Structure):
if (device_type, device_id) in cls._cache: if (device_type, device_id) in cls._cache:
return cls._cache[(device_type, device_id)] return cls._cache[(device_type, device_id)]
inst = super(TVMContext, cls).__new__(TVMContext) inst = super(DGLContext, cls).__new__(DGLContext)
inst.device_type = device_type inst.device_type = device_type
inst.device_id = device_id inst.device_id = device_id
...@@ -222,10 +222,10 @@ class TVMContext(ctypes.Structure): ...@@ -222,10 +222,10 @@ class TVMContext(ctypes.Structure):
def sync(self): def sync(self):
"""Synchronize until jobs finished at the context.""" """Synchronize until jobs finished at the context."""
check_call(_LIB.TVMSynchronize(self.device_type, self.device_id, None)) check_call(_LIB.DGLSynchronize(self.device_type, self.device_id, None))
def __eq__(self, other): def __eq__(self, other):
return (isinstance(other, TVMContext) and return (isinstance(other, DGLContext) and
self.device_id == other.device_id and self.device_id == other.device_id and
self.device_type == other.device_type) self.device_type == other.device_type)
...@@ -237,22 +237,22 @@ class TVMContext(ctypes.Structure): ...@@ -237,22 +237,22 @@ class TVMContext(ctypes.Structure):
tbl_id = self.device_type / RPC_SESS_MASK - 1 tbl_id = self.device_type / RPC_SESS_MASK - 1
dev_type = self.device_type % RPC_SESS_MASK dev_type = self.device_type % RPC_SESS_MASK
return "remote[%d]:%s(%d)" % ( return "remote[%d]:%s(%d)" % (
tbl_id, TVMContext.MASK2STR[dev_type], self.device_id) tbl_id, DGLContext.MASK2STR[dev_type], self.device_id)
return "%s(%d)" % ( return "%s(%d)" % (
TVMContext.MASK2STR[self.device_type], self.device_id) DGLContext.MASK2STR[self.device_type], self.device_id)
def __hash__(self): def __hash__(self):
return hash((self.device_type, self.device_id)) return hash((self.device_type, self.device_id))
class TVMArray(ctypes.Structure): class DGLArray(ctypes.Structure):
"""TVMValue in C API""" """DGLValue in C API"""
_fields_ = [("data", ctypes.c_void_p), _fields_ = [("data", ctypes.c_void_p),
("ctx", TVMContext), ("ctx", DGLContext),
("ndim", ctypes.c_int), ("ndim", ctypes.c_int),
("dtype", TVMType), ("dtype", DGLType),
("shape", ctypes.POINTER(tvm_shape_index_t)), ("shape", ctypes.POINTER(dgl_shape_index_t)),
("strides", ctypes.POINTER(tvm_shape_index_t)), ("strides", ctypes.POINTER(dgl_shape_index_t)),
("byte_offset", ctypes.c_uint64)] ("byte_offset", ctypes.c_uint64)]
TVMArrayHandle = ctypes.POINTER(TVMArray) DGLArrayHandle = ctypes.POINTER(DGLArray)
...@@ -306,7 +306,7 @@ class Frame(MutableMapping): ...@@ -306,7 +306,7 @@ class Frame(MutableMapping):
The column name. The column name.
scheme : Scheme scheme : Scheme
The column scheme. The column scheme.
ctx : TVMContext ctx : DGLContext
The column context. The column context.
""" """
if name in self: if name in self:
......
...@@ -11,7 +11,7 @@ import functools ...@@ -11,7 +11,7 @@ import functools
import operator import operator
import numpy as _np import numpy as _np
from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase from ._ffi.ndarray import DGLContext, DGLType, NDArrayBase
from ._ffi.ndarray import context, empty, from_dlpack, numpyasarray from ._ffi.ndarray import context, empty, from_dlpack, numpyasarray
from ._ffi.ndarray import _set_class_ndarray from ._ffi.ndarray import _set_class_ndarray
from . import backend as F from . import backend as F
...@@ -31,10 +31,10 @@ def cpu(dev_id=0): ...@@ -31,10 +31,10 @@ def cpu(dev_id=0):
Returns Returns
------- -------
ctx : TVMContext ctx : DGLContext
The created context The created context
""" """
return TVMContext(1, dev_id) return DGLContext(1, dev_id)
def gpu(dev_id=0): def gpu(dev_id=0):
"""Construct a CPU device """Construct a CPU device
...@@ -46,10 +46,10 @@ def gpu(dev_id=0): ...@@ -46,10 +46,10 @@ def gpu(dev_id=0):
Returns Returns
------- -------
ctx : TVMContext ctx : DGLContext
The created context The created context
""" """
return TVMContext(2, dev_id) return DGLContext(2, dev_id)
def array(arr, ctx=cpu(0)): def array(arr, ctx=cpu(0)):
"""Create an array from source arr. """Create an array from source arr.
...@@ -59,7 +59,7 @@ def array(arr, ctx=cpu(0)): ...@@ -59,7 +59,7 @@ def array(arr, ctx=cpu(0)):
arr : numpy.ndarray arr : numpy.ndarray
The array to be copied from The array to be copied from
ctx : TVMContext, optional ctx : DGLContext, optional
The device context to create the array The device context to create the array
Returns Returns
......
...@@ -5,15 +5,15 @@ ...@@ -5,15 +5,15 @@
*/ */
#include "c_api_common.h" #include "c_api_common.h"
using tvm::runtime::TVMArgs; using dgl::runtime::DGLArgs;
using tvm::runtime::TVMArgValue; using dgl::runtime::DGLArgValue;
using tvm::runtime::TVMRetValue; using dgl::runtime::DGLRetValue;
using tvm::runtime::PackedFunc; using dgl::runtime::PackedFunc;
using tvm::runtime::NDArray; using dgl::runtime::NDArray;
namespace dgl { namespace dgl {
DLManagedTensor* CreateTmpDLManagedTensor(const TVMArgValue& arg) { DLManagedTensor* CreateTmpDLManagedTensor(const DGLArgValue& arg) {
const DLTensor* dl_tensor = arg; const DLTensor* dl_tensor = arg;
DLManagedTensor* ret = new DLManagedTensor(); DLManagedTensor* ret = new DLManagedTensor();
ret->deleter = [] (DLManagedTensor* self) { delete self; }; ret->deleter = [] (DLManagedTensor* self) { delete self; };
...@@ -23,7 +23,7 @@ DLManagedTensor* CreateTmpDLManagedTensor(const TVMArgValue& arg) { ...@@ -23,7 +23,7 @@ DLManagedTensor* CreateTmpDLManagedTensor(const TVMArgValue& arg) {
} }
PackedFunc ConvertNDArrayVectorToPackedFunc(const std::vector<NDArray>& vec) { PackedFunc ConvertNDArrayVectorToPackedFunc(const std::vector<NDArray>& vec) {
auto body = [vec](TVMArgs args, TVMRetValue* rv) { auto body = [vec](DGLArgs args, DGLRetValue* rv) {
const int which = args[0]; const int which = args[0];
if (which >= vec.size()) { if (which >= vec.size()) {
LOG(FATAL) << "invalid choice"; LOG(FATAL) << "invalid choice";
......
...@@ -23,16 +23,16 @@ typedef void* GraphHandle; ...@@ -23,16 +23,16 @@ typedef void* GraphHandle;
* Return a temporary DLManagedTensor that does not own memory. * Return a temporary DLManagedTensor that does not own memory.
*/ */
DLManagedTensor* CreateTmpDLManagedTensor( DLManagedTensor* CreateTmpDLManagedTensor(
const tvm::runtime::TVMArgValue& arg); const dgl::runtime::DGLArgValue& arg);
/*! /*!
* \brief Convert a vector of NDArray to PackedFunc. * \brief Convert a vector of NDArray to PackedFunc.
*/ */
tvm::runtime::PackedFunc ConvertNDArrayVectorToPackedFunc( dgl::runtime::PackedFunc ConvertNDArrayVectorToPackedFunc(
const std::vector<tvm::runtime::NDArray>& vec); const std::vector<dgl::runtime::NDArray>& vec);
/*!\brief Return whether the array is a valid 1D int array*/ /*!\brief Return whether the array is a valid 1D int array*/
inline bool IsValidIdArray(const tvm::runtime::NDArray& arr) { inline bool IsValidIdArray(const dgl::runtime::NDArray& arr) {
return arr->ctx.device_type == kDLCPU && arr->ndim == 1 return arr->ctx.device_type == kDLCPU && arr->ndim == 1
&& arr->dtype.code == kDLInt && arr->dtype.bits == 64; && arr->dtype.code == kDLInt && arr->dtype.bits == 64;
} }
...@@ -43,9 +43,9 @@ inline bool IsValidIdArray(const tvm::runtime::NDArray& arr) { ...@@ -43,9 +43,9 @@ inline bool IsValidIdArray(const tvm::runtime::NDArray& arr) {
* The element type of the vector must be convertible to int64_t. * The element type of the vector must be convertible to int64_t.
*/ */
template<typename DType> template<typename DType>
tvm::runtime::NDArray CopyVectorToNDArray( dgl::runtime::NDArray CopyVectorToNDArray(
const std::vector<DType>& vec) { const std::vector<DType>& vec) {
using tvm::runtime::NDArray; using dgl::runtime::NDArray;
const int64_t len = vec.size(); const int64_t len = vec.size();
NDArray a = NDArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); NDArray a = NDArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
std::copy(vec.begin(), vec.end(), static_cast<int64_t*>(a->data)); std::copy(vec.begin(), vec.end(), static_cast<int64_t*>(a->data));
......
...@@ -7,18 +7,18 @@ ...@@ -7,18 +7,18 @@
#include <dgl/graph_op.h> #include <dgl/graph_op.h>
#include "../c_api_common.h" #include "../c_api_common.h"
using tvm::runtime::TVMArgs; using dgl::runtime::DGLArgs;
using tvm::runtime::TVMArgValue; using dgl::runtime::DGLArgValue;
using tvm::runtime::TVMRetValue; using dgl::runtime::DGLRetValue;
using tvm::runtime::PackedFunc; using dgl::runtime::PackedFunc;
using tvm::runtime::NDArray; using dgl::runtime::NDArray;
namespace dgl { namespace dgl {
namespace { namespace {
// Convert EdgeArray structure to PackedFunc. // Convert EdgeArray structure to PackedFunc.
PackedFunc ConvertEdgeArrayToPackedFunc(const Graph::EdgeArray& ea) { PackedFunc ConvertEdgeArrayToPackedFunc(const Graph::EdgeArray& ea) {
auto body = [ea] (TVMArgs args, TVMRetValue* rv) { auto body = [ea] (DGLArgs args, DGLRetValue* rv) {
const int which = args[0]; const int which = args[0];
if (which == 0) { if (which == 0) {
*rv = std::move(ea.src); *rv = std::move(ea.src);
...@@ -35,7 +35,7 @@ PackedFunc ConvertEdgeArrayToPackedFunc(const Graph::EdgeArray& ea) { ...@@ -35,7 +35,7 @@ PackedFunc ConvertEdgeArrayToPackedFunc(const Graph::EdgeArray& ea) {
// Convert Subgraph structure to PackedFunc. // Convert Subgraph structure to PackedFunc.
PackedFunc ConvertSubgraphToPackedFunc(const Subgraph& sg) { PackedFunc ConvertSubgraphToPackedFunc(const Subgraph& sg) {
auto body = [sg] (TVMArgs args, TVMRetValue* rv) { auto body = [sg] (DGLArgs args, DGLRetValue* rv) {
const int which = args[0]; const int which = args[0];
if (which == 0) { if (which == 0) {
Graph* gptr = new Graph(); Graph* gptr = new Graph();
...@@ -55,30 +55,30 @@ PackedFunc ConvertSubgraphToPackedFunc(const Subgraph& sg) { ...@@ -55,30 +55,30 @@ PackedFunc ConvertSubgraphToPackedFunc(const Subgraph& sg) {
} // namespace } // namespace
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCreate") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCreate")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
bool multigraph = static_cast<bool>(args[0]); bool multigraph = static_cast<bool>(args[0]);
GraphHandle ghandle = new Graph(multigraph); GraphHandle ghandle = new Graph(multigraph);
*rv = ghandle; *rv = ghandle;
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphFree") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphFree")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
Graph* gptr = static_cast<Graph*>(ghandle); Graph* gptr = static_cast<Graph*>(ghandle);
delete gptr; delete gptr;
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddVertices") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddVertices")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
Graph* gptr = static_cast<Graph*>(ghandle); Graph* gptr = static_cast<Graph*>(ghandle);
uint64_t num_vertices = args[1]; uint64_t num_vertices = args[1];
gptr->AddVertices(num_vertices); gptr->AddVertices(num_vertices);
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddEdge") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddEdge")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
Graph* gptr = static_cast<Graph*>(ghandle); Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t src = args[1]; const dgl_id_t src = args[1];
...@@ -86,8 +86,8 @@ TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddEdge") ...@@ -86,8 +86,8 @@ TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddEdge")
gptr->AddEdge(src, dst); gptr->AddEdge(src, dst);
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddEdges") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddEdges")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
Graph* gptr = static_cast<Graph*>(ghandle); Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray src = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1])); const IdArray src = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
...@@ -95,67 +95,67 @@ TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddEdges") ...@@ -95,67 +95,67 @@ TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddEdges")
gptr->AddEdges(src, dst); gptr->AddEdges(src, dst);
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphClear") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphClear")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
Graph* gptr = static_cast<Graph*>(ghandle); Graph* gptr = static_cast<Graph*>(ghandle);
gptr->Clear(); gptr->Clear();
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphIsMultigraph") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphIsMultigraph")
.set_body([] (TVMArgs args, TVMRetValue *rv) { .set_body([] (DGLArgs args, DGLRetValue *rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
// NOTE: not const since we have caches // NOTE: not const since we have caches
const Graph* gptr = static_cast<Graph*>(ghandle); const Graph* gptr = static_cast<Graph*>(ghandle);
*rv = gptr->IsMultigraph(); *rv = gptr->IsMultigraph();
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumVertices") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumVertices")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle); const Graph* gptr = static_cast<Graph*>(ghandle);
*rv = static_cast<int64_t>(gptr->NumVertices()); *rv = static_cast<int64_t>(gptr->NumVertices());
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumEdges") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumEdges")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle); const Graph* gptr = static_cast<Graph*>(ghandle);
*rv = static_cast<int64_t>(gptr->NumEdges()); *rv = static_cast<int64_t>(gptr->NumEdges());
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasVertex") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasVertex")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle); const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t vid = args[1]; const dgl_id_t vid = args[1];
*rv = gptr->HasVertex(vid); *rv = gptr->HasVertex(vid);
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasVertices") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasVertices")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle); const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray vids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1])); const IdArray vids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
*rv = gptr->HasVertices(vids); *rv = gptr->HasVertices(vids);
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLMapSubgraphNID") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLMapSubgraphNID")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
const IdArray parent_vids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[0])); const IdArray parent_vids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[0]));
const IdArray query = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1])); const IdArray query = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
*rv = GraphOp::MapParentIdToSubgraphId(parent_vids, query); *rv = GraphOp::MapParentIdToSubgraphId(parent_vids, query);
}); });
TVM_REGISTER_GLOBAL("immutable_graph_index._CAPI_DGLExpandIds") DGL_REGISTER_GLOBAL("immutable_graph_index._CAPI_DGLExpandIds")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
const IdArray ids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[0])); const IdArray ids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[0]));
const IdArray offsets = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1])); const IdArray offsets = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
*rv = GraphOp::ExpandIds(ids, offsets); *rv = GraphOp::ExpandIds(ids, offsets);
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasEdgeBetween") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasEdgeBetween")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle); const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t src = args[1]; const dgl_id_t src = args[1];
...@@ -163,8 +163,8 @@ TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasEdgeBetween") ...@@ -163,8 +163,8 @@ TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasEdgeBetween")
*rv = gptr->HasEdgeBetween(src, dst); *rv = gptr->HasEdgeBetween(src, dst);
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasEdgesBetween") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasEdgesBetween")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle); const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray src = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1])); const IdArray src = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
...@@ -172,8 +172,8 @@ TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasEdgesBetween") ...@@ -172,8 +172,8 @@ TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasEdgesBetween")
*rv = gptr->HasEdgesBetween(src, dst); *rv = gptr->HasEdgesBetween(src, dst);
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphPredecessors") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphPredecessors")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle); const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t vid = args[1]; const dgl_id_t vid = args[1];
...@@ -181,8 +181,8 @@ TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphPredecessors") ...@@ -181,8 +181,8 @@ TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphPredecessors")
*rv = gptr->Predecessors(vid, radius); *rv = gptr->Predecessors(vid, radius);
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphSuccessors") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphSuccessors")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle); const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t vid = args[1]; const dgl_id_t vid = args[1];
...@@ -190,8 +190,8 @@ TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphSuccessors") ...@@ -190,8 +190,8 @@ TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphSuccessors")
*rv = gptr->Successors(vid, radius); *rv = gptr->Successors(vid, radius);
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeId") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeId")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle); const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t src = args[1]; const dgl_id_t src = args[1];
...@@ -199,8 +199,8 @@ TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeId") ...@@ -199,8 +199,8 @@ TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeId")
*rv = gptr->EdgeId(src, dst); *rv = gptr->EdgeId(src, dst);
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeIds") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeIds")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle); const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray src = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1])); const IdArray src = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
...@@ -208,104 +208,104 @@ TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeIds") ...@@ -208,104 +208,104 @@ TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeIds")
*rv = ConvertEdgeArrayToPackedFunc(gptr->EdgeIds(src, dst)); *rv = ConvertEdgeArrayToPackedFunc(gptr->EdgeIds(src, dst));
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphFindEdges") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphFindEdges")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle); const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray eids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1])); const IdArray eids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
*rv = ConvertEdgeArrayToPackedFunc(gptr->FindEdges(eids)); *rv = ConvertEdgeArrayToPackedFunc(gptr->FindEdges(eids));
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInEdges_1") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInEdges_1")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle); const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t vid = args[1]; const dgl_id_t vid = args[1];
*rv = ConvertEdgeArrayToPackedFunc(gptr->InEdges(vid)); *rv = ConvertEdgeArrayToPackedFunc(gptr->InEdges(vid));
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInEdges_2") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInEdges_2")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle); const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray vids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1])); const IdArray vids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
*rv = ConvertEdgeArrayToPackedFunc(gptr->InEdges(vids)); *rv = ConvertEdgeArrayToPackedFunc(gptr->InEdges(vids));
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutEdges_1") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutEdges_1")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle); const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t vid = args[1]; const dgl_id_t vid = args[1];
*rv = ConvertEdgeArrayToPackedFunc(gptr->OutEdges(vid)); *rv = ConvertEdgeArrayToPackedFunc(gptr->OutEdges(vid));
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutEdges_2") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutEdges_2")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle); const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray vids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1])); const IdArray vids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
*rv = ConvertEdgeArrayToPackedFunc(gptr->OutEdges(vids)); *rv = ConvertEdgeArrayToPackedFunc(gptr->OutEdges(vids));
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdges") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdges")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle); const Graph* gptr = static_cast<Graph*>(ghandle);
const bool sorted = args[1]; const bool sorted = args[1];
*rv = ConvertEdgeArrayToPackedFunc(gptr->Edges(sorted)); *rv = ConvertEdgeArrayToPackedFunc(gptr->Edges(sorted));
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInDegree") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInDegree")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle); const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t vid = args[1]; const dgl_id_t vid = args[1];
*rv = static_cast<int64_t>(gptr->InDegree(vid)); *rv = static_cast<int64_t>(gptr->InDegree(vid));
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInDegrees") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInDegrees")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle); const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray vids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1])); const IdArray vids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
*rv = gptr->InDegrees(vids); *rv = gptr->InDegrees(vids);
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutDegree") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutDegree")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle); const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t vid = args[1]; const dgl_id_t vid = args[1];
*rv = static_cast<int64_t>(gptr->OutDegree(vid)); *rv = static_cast<int64_t>(gptr->OutDegree(vid));
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutDegrees") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutDegrees")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle); const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray vids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1])); const IdArray vids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
*rv = gptr->OutDegrees(vids); *rv = gptr->OutDegrees(vids);
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphVertexSubgraph") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphVertexSubgraph")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle); const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray vids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1])); const IdArray vids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
*rv = ConvertSubgraphToPackedFunc(gptr->VertexSubgraph(vids)); *rv = ConvertSubgraphToPackedFunc(gptr->VertexSubgraph(vids));
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeSubgraph") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeSubgraph")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
const Graph *gptr = static_cast<Graph*>(ghandle); const Graph *gptr = static_cast<Graph*>(ghandle);
const IdArray eids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1])); const IdArray eids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
*rv = ConvertSubgraphToPackedFunc(gptr->EdgeSubgraph(eids)); *rv = ConvertSubgraphToPackedFunc(gptr->EdgeSubgraph(eids));
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointUnion") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointUnion")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
void* list = args[0]; void* list = args[0];
GraphHandle* inhandles = static_cast<GraphHandle*>(list); GraphHandle* inhandles = static_cast<GraphHandle*>(list);
int list_size = args[1]; int list_size = args[1];
...@@ -320,8 +320,8 @@ TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointUnion") ...@@ -320,8 +320,8 @@ TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointUnion")
*rv = ghandle; *rv = ghandle;
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointPartitionByNum") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointPartitionByNum")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle); const Graph* gptr = static_cast<Graph*>(ghandle);
int64_t num = args[1]; int64_t num = args[1];
...@@ -338,8 +338,8 @@ TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointPartitionByNum") ...@@ -338,8 +338,8 @@ TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointPartitionByNum")
*rv = ptr_array; *rv = ptr_array;
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointPartitionBySizes") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointPartitionBySizes")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle); const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray sizes = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1])); const IdArray sizes = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
...@@ -356,8 +356,8 @@ TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointPartitionBySizes") ...@@ -356,8 +356,8 @@ TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointPartitionBySizes")
*rv = ptr_array; *rv = ptr_array;
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphLineGraph") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphLineGraph")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
bool backtracking = args[1]; bool backtracking = args[1];
const Graph* gptr = static_cast<Graph*>(ghandle); const Graph* gptr = static_cast<Graph*>(ghandle);
......
...@@ -8,11 +8,11 @@ ...@@ -8,11 +8,11 @@
#include "./traversal.h" #include "./traversal.h"
#include "../c_api_common.h" #include "../c_api_common.h"
using tvm::runtime::TVMArgs; using dgl::runtime::DGLArgs;
using tvm::runtime::TVMArgValue; using dgl::runtime::DGLArgValue;
using tvm::runtime::TVMRetValue; using dgl::runtime::DGLRetValue;
using tvm::runtime::PackedFunc; using dgl::runtime::PackedFunc;
using tvm::runtime::NDArray; using dgl::runtime::NDArray;
namespace dgl { namespace dgl {
namespace traverse { namespace traverse {
...@@ -129,8 +129,8 @@ Frontiers BFSNodesFrontiers(const Graph& graph, IdArray source, bool reversed) { ...@@ -129,8 +129,8 @@ Frontiers BFSNodesFrontiers(const Graph& graph, IdArray source, bool reversed) {
return front; return front;
} }
TVM_REGISTER_GLOBAL("traversal._CAPI_DGLBFSNodes") DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSNodes")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle); const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray src = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1])); const IdArray src = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
...@@ -160,8 +160,8 @@ Frontiers BFSEdgesFrontiers(const Graph& graph, IdArray source, bool reversed) { ...@@ -160,8 +160,8 @@ Frontiers BFSEdgesFrontiers(const Graph& graph, IdArray source, bool reversed) {
return front; return front;
} }
TVM_REGISTER_GLOBAL("traversal._CAPI_DGLBFSEdges") DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSEdges")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle); const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray src = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1])); const IdArray src = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
...@@ -186,8 +186,8 @@ Frontiers TopologicalNodesFrontiers(const Graph& graph, bool reversed) { ...@@ -186,8 +186,8 @@ Frontiers TopologicalNodesFrontiers(const Graph& graph, bool reversed) {
return front; return front;
} }
TVM_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes") DGL_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle); const Graph* gptr = static_cast<Graph*>(ghandle);
bool reversed = args[1]; bool reversed = args[1];
...@@ -198,8 +198,8 @@ TVM_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes") ...@@ -198,8 +198,8 @@ TVM_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes")
}); });
TVM_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges") DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle); const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray source = args[1]; const IdArray source = args[1];
...@@ -217,8 +217,8 @@ TVM_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges") ...@@ -217,8 +217,8 @@ TVM_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges")
*rv = ConvertNDArrayVectorToPackedFunc({ids, sections}); *rv = ConvertNDArrayVectorToPackedFunc({ids, sections});
}); });
TVM_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges") DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle); const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray source = args[1]; const IdArray source = args[1];
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include <cstdlib> #include <cstdlib>
#include "runtime_base.h" #include "runtime_base.h"
namespace tvm { namespace dgl {
namespace runtime { namespace runtime {
/*! /*!
...@@ -44,7 +44,7 @@ class DeviceAPIManager { ...@@ -44,7 +44,7 @@ class DeviceAPIManager {
public: public:
static const int kMaxDeviceAPI = 32; static const int kMaxDeviceAPI = 32;
// Get API // Get API
static DeviceAPI* Get(const TVMContext& ctx) { static DeviceAPI* Get(const DGLContext& ctx) {
return Get(ctx.device_type); return Get(ctx.device_type);
} }
static DeviceAPI* Get(int dev_type, bool allow_missing = false) { static DeviceAPI* Get(int dev_type, bool allow_missing = false) {
...@@ -93,81 +93,81 @@ class DeviceAPIManager { ...@@ -93,81 +93,81 @@ class DeviceAPIManager {
} }
}; };
DeviceAPI* DeviceAPI::Get(TVMContext ctx, bool allow_missing) { DeviceAPI* DeviceAPI::Get(DGLContext ctx, bool allow_missing) {
return DeviceAPIManager::Get( return DeviceAPIManager::Get(
static_cast<int>(ctx.device_type), allow_missing); static_cast<int>(ctx.device_type), allow_missing);
} }
void* DeviceAPI::AllocWorkspace(TVMContext ctx, void* DeviceAPI::AllocWorkspace(DGLContext ctx,
size_t size, size_t size,
TVMType type_hint) { DGLType type_hint) {
return AllocDataSpace(ctx, size, kTempAllocaAlignment, type_hint); return AllocDataSpace(ctx, size, kTempAllocaAlignment, type_hint);
} }
void DeviceAPI::FreeWorkspace(TVMContext ctx, void* ptr) { void DeviceAPI::FreeWorkspace(DGLContext ctx, void* ptr) {
FreeDataSpace(ctx, ptr); FreeDataSpace(ctx, ptr);
} }
TVMStreamHandle DeviceAPI::CreateStream(TVMContext ctx) { DGLStreamHandle DeviceAPI::CreateStream(DGLContext ctx) {
LOG(FATAL) << "Device does not support stream api."; LOG(FATAL) << "Device does not support stream api.";
return 0; return 0;
} }
void DeviceAPI::FreeStream(TVMContext ctx, TVMStreamHandle stream) { void DeviceAPI::FreeStream(DGLContext ctx, DGLStreamHandle stream) {
LOG(FATAL) << "Device does not support stream api."; LOG(FATAL) << "Device does not support stream api.";
} }
void DeviceAPI::SyncStreamFromTo(TVMContext ctx, void DeviceAPI::SyncStreamFromTo(DGLContext ctx,
TVMStreamHandle event_src, DGLStreamHandle event_src,
TVMStreamHandle event_dst) { DGLStreamHandle event_dst) {
LOG(FATAL) << "Device does not support stream api."; LOG(FATAL) << "Device does not support stream api.";
} }
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace dgl
using namespace tvm::runtime; using namespace dgl::runtime;
struct TVMRuntimeEntry { struct DGLRuntimeEntry {
std::string ret_str; std::string ret_str;
std::string last_error; std::string last_error;
TVMByteArray ret_bytes; DGLByteArray ret_bytes;
}; };
typedef dmlc::ThreadLocalStore<TVMRuntimeEntry> TVMAPIRuntimeStore; typedef dmlc::ThreadLocalStore<DGLRuntimeEntry> DGLAPIRuntimeStore;
const char *TVMGetLastError() { const char *DGLGetLastError() {
return TVMAPIRuntimeStore::Get()->last_error.c_str(); return DGLAPIRuntimeStore::Get()->last_error.c_str();
} }
void TVMAPISetLastError(const char* msg) { void DGLAPISetLastError(const char* msg) {
#ifndef _LIBCPP_SGX_CONFIG #ifndef _LIBCPP_SGX_CONFIG
TVMAPIRuntimeStore::Get()->last_error = msg; DGLAPIRuntimeStore::Get()->last_error = msg;
#else #else
sgx::OCallPackedFunc("__sgx_set_last_error__", msg); sgx::OCallPackedFunc("__sgx_set_last_error__", msg);
#endif #endif
} }
int TVMModLoadFromFile(const char* file_name, int DGLModLoadFromFile(const char* file_name,
const char* format, const char* format,
TVMModuleHandle* out) { DGLModuleHandle* out) {
API_BEGIN(); API_BEGIN();
Module m = Module::LoadFromFile(file_name, format); Module m = Module::LoadFromFile(file_name, format);
*out = new Module(m); *out = new Module(m);
API_END(); API_END();
} }
int TVMModImport(TVMModuleHandle mod, int DGLModImport(DGLModuleHandle mod,
TVMModuleHandle dep) { DGLModuleHandle dep) {
API_BEGIN(); API_BEGIN();
static_cast<Module*>(mod)->Import( static_cast<Module*>(mod)->Import(
*static_cast<Module*>(dep)); *static_cast<Module*>(dep));
API_END(); API_END();
} }
int TVMModGetFunction(TVMModuleHandle mod, int DGLModGetFunction(DGLModuleHandle mod,
const char* func_name, const char* func_name,
int query_imports, int query_imports,
TVMFunctionHandle *func) { DGLFunctionHandle *func) {
API_BEGIN(); API_BEGIN();
PackedFunc pf = static_cast<Module*>(mod)->GetFunction( PackedFunc pf = static_cast<Module*>(mod)->GetFunction(
func_name, query_imports != 0); func_name, query_imports != 0);
...@@ -179,31 +179,31 @@ int TVMModGetFunction(TVMModuleHandle mod, ...@@ -179,31 +179,31 @@ int TVMModGetFunction(TVMModuleHandle mod,
API_END(); API_END();
} }
int TVMModFree(TVMModuleHandle mod) { int DGLModFree(DGLModuleHandle mod) {
API_BEGIN(); API_BEGIN();
delete static_cast<Module*>(mod); delete static_cast<Module*>(mod);
API_END(); API_END();
} }
int TVMBackendGetFuncFromEnv(void* mod_node, int DGLBackendGetFuncFromEnv(void* mod_node,
const char* func_name, const char* func_name,
TVMFunctionHandle *func) { DGLFunctionHandle *func) {
API_BEGIN(); API_BEGIN();
*func = (TVMFunctionHandle)( *func = (DGLFunctionHandle)(
static_cast<ModuleNode*>(mod_node)->GetFuncFromEnv(func_name)); static_cast<ModuleNode*>(mod_node)->GetFuncFromEnv(func_name));
API_END(); API_END();
} }
void* TVMBackendAllocWorkspace(int device_type, void* DGLBackendAllocWorkspace(int device_type,
int device_id, int device_id,
uint64_t size, uint64_t size,
int dtype_code_hint, int dtype_code_hint,
int dtype_bits_hint) { int dtype_bits_hint) {
TVMContext ctx; DGLContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type); ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id; ctx.device_id = device_id;
TVMType type_hint; DGLType type_hint;
type_hint.code = static_cast<decltype(type_hint.code)>(dtype_code_hint); type_hint.code = static_cast<decltype(type_hint.code)>(dtype_code_hint);
type_hint.bits = static_cast<decltype(type_hint.bits)>(dtype_bits_hint); type_hint.bits = static_cast<decltype(type_hint.bits)>(dtype_bits_hint);
type_hint.lanes = 1; type_hint.lanes = 1;
...@@ -213,17 +213,17 @@ void* TVMBackendAllocWorkspace(int device_type, ...@@ -213,17 +213,17 @@ void* TVMBackendAllocWorkspace(int device_type,
type_hint); type_hint);
} }
int TVMBackendFreeWorkspace(int device_type, int DGLBackendFreeWorkspace(int device_type,
int device_id, int device_id,
void* ptr) { void* ptr) {
TVMContext ctx; DGLContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type); ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id; ctx.device_id = device_id;
DeviceAPIManager::Get(ctx)->FreeWorkspace(ctx, ptr); DeviceAPIManager::Get(ctx)->FreeWorkspace(ctx, ptr);
return 0; return 0;
} }
int TVMBackendRunOnce(void** handle, int DGLBackendRunOnce(void** handle,
int (*f)(void*), int (*f)(void*),
void* cdata, void* cdata,
int nbytes) { int nbytes) {
...@@ -234,28 +234,28 @@ int TVMBackendRunOnce(void** handle, ...@@ -234,28 +234,28 @@ int TVMBackendRunOnce(void** handle,
return 0; return 0;
} }
int TVMFuncFree(TVMFunctionHandle func) { int DGLFuncFree(DGLFunctionHandle func) {
API_BEGIN(); API_BEGIN();
delete static_cast<PackedFunc*>(func); delete static_cast<PackedFunc*>(func);
API_END(); API_END();
} }
int TVMFuncCall(TVMFunctionHandle func, int DGLFuncCall(DGLFunctionHandle func,
TVMValue* args, DGLValue* args,
int* arg_type_codes, int* arg_type_codes,
int num_args, int num_args,
TVMValue* ret_val, DGLValue* ret_val,
int* ret_type_code) { int* ret_type_code) {
API_BEGIN(); API_BEGIN();
TVMRetValue rv; DGLRetValue rv;
(*static_cast<const PackedFunc*>(func)).CallPacked( (*static_cast<const PackedFunc*>(func)).CallPacked(
TVMArgs(args, arg_type_codes, num_args), &rv); DGLArgs(args, arg_type_codes, num_args), &rv);
// handle return string. // handle return string.
if (rv.type_code() == kStr || if (rv.type_code() == kStr ||
rv.type_code() == kTVMType || rv.type_code() == kDGLType ||
rv.type_code() == kBytes) { rv.type_code() == kBytes) {
TVMRuntimeEntry* e = TVMAPIRuntimeStore::Get(); DGLRuntimeEntry* e = DGLAPIRuntimeStore::Get();
if (rv.type_code() != kTVMType) { if (rv.type_code() != kDGLType) {
e->ret_str = *rv.ptr<std::string>(); e->ret_str = *rv.ptr<std::string>();
} else { } else {
e->ret_str = rv.operator std::string(); e->ret_str = rv.operator std::string();
...@@ -275,30 +275,30 @@ int TVMFuncCall(TVMFunctionHandle func, ...@@ -275,30 +275,30 @@ int TVMFuncCall(TVMFunctionHandle func,
API_END(); API_END();
} }
int TVMCFuncSetReturn(TVMRetValueHandle ret, int DGLCFuncSetReturn(DGLRetValueHandle ret,
TVMValue* value, DGLValue* value,
int* type_code, int* type_code,
int num_ret) { int num_ret) {
API_BEGIN(); API_BEGIN();
CHECK_EQ(num_ret, 1); CHECK_EQ(num_ret, 1);
TVMRetValue* rv = static_cast<TVMRetValue*>(ret); DGLRetValue* rv = static_cast<DGLRetValue*>(ret);
*rv = TVMArgValue(value[0], type_code[0]); *rv = DGLArgValue(value[0], type_code[0]);
API_END(); API_END();
} }
int TVMFuncCreateFromCFunc(TVMPackedCFunc func, int DGLFuncCreateFromCFunc(DGLPackedCFunc func,
void* resource_handle, void* resource_handle,
TVMPackedCFuncFinalizer fin, DGLPackedCFuncFinalizer fin,
TVMFunctionHandle *out) { DGLFunctionHandle *out) {
API_BEGIN(); API_BEGIN();
if (fin == nullptr) { if (fin == nullptr) {
*out = new PackedFunc( *out = new PackedFunc(
[func, resource_handle](TVMArgs args, TVMRetValue* rv) { [func, resource_handle](DGLArgs args, DGLRetValue* rv) {
int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*) int ret = func((DGLValue*)args.values, (int*)args.type_codes, // NOLINT(*)
args.num_args, rv, resource_handle); args.num_args, rv, resource_handle);
if (ret != 0) { if (ret != 0) {
std::string err = "TVMCall CFunc Error:\n"; std::string err = "DGLCall CFunc Error:\n";
err += TVMGetLastError(); err += DGLGetLastError();
throw dmlc::Error(err); throw dmlc::Error(err);
} }
}); });
...@@ -307,12 +307,12 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func, ...@@ -307,12 +307,12 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
// so fin will be called when the lambda went out of scope. // so fin will be called when the lambda went out of scope.
std::shared_ptr<void> rpack(resource_handle, fin); std::shared_ptr<void> rpack(resource_handle, fin);
*out = new PackedFunc( *out = new PackedFunc(
[func, rpack](TVMArgs args, TVMRetValue* rv) { [func, rpack](DGLArgs args, DGLRetValue* rv) {
int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*) int ret = func((DGLValue*)args.values, (int*)args.type_codes, // NOLINT(*)
args.num_args, rv, rpack.get()); args.num_args, rv, rpack.get());
if (ret != 0) { if (ret != 0) {
std::string err = "TVMCall CFunc Error:\n"; std::string err = "DGLCall CFunc Error:\n";
err += TVMGetLastError(); err += DGLGetLastError();
throw dmlc::Error(err); throw dmlc::Error(err);
} }
}); });
...@@ -320,58 +320,58 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func, ...@@ -320,58 +320,58 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
API_END(); API_END();
} }
int TVMStreamCreate(int device_type, int device_id, TVMStreamHandle* out) { int DGLStreamCreate(int device_type, int device_id, DGLStreamHandle* out) {
API_BEGIN(); API_BEGIN();
TVMContext ctx; DGLContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type); ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id; ctx.device_id = device_id;
*out = DeviceAPIManager::Get(ctx)->CreateStream(ctx); *out = DeviceAPIManager::Get(ctx)->CreateStream(ctx);
API_END(); API_END();
} }
int TVMStreamFree(int device_type, int device_id, TVMStreamHandle stream) { int DGLStreamFree(int device_type, int device_id, DGLStreamHandle stream) {
API_BEGIN(); API_BEGIN();
TVMContext ctx; DGLContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type); ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id; ctx.device_id = device_id;
DeviceAPIManager::Get(ctx)->FreeStream(ctx, stream); DeviceAPIManager::Get(ctx)->FreeStream(ctx, stream);
API_END(); API_END();
} }
int TVMSetStream(int device_type, int device_id, TVMStreamHandle stream) { int DGLSetStream(int device_type, int device_id, DGLStreamHandle stream) {
API_BEGIN(); API_BEGIN();
TVMContext ctx; DGLContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type); ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id; ctx.device_id = device_id;
DeviceAPIManager::Get(ctx)->SetStream(ctx, stream); DeviceAPIManager::Get(ctx)->SetStream(ctx, stream);
API_END(); API_END();
} }
int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream) { int DGLSynchronize(int device_type, int device_id, DGLStreamHandle stream) {
API_BEGIN(); API_BEGIN();
TVMContext ctx; DGLContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type); ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id; ctx.device_id = device_id;
DeviceAPIManager::Get(ctx)->StreamSync(ctx, stream); DeviceAPIManager::Get(ctx)->StreamSync(ctx, stream);
API_END(); API_END();
} }
int TVMStreamStreamSynchronize(int device_type, int DGLStreamStreamSynchronize(int device_type,
int device_id, int device_id,
TVMStreamHandle src, DGLStreamHandle src,
TVMStreamHandle dst) { DGLStreamHandle dst) {
API_BEGIN(); API_BEGIN();
TVMContext ctx; DGLContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type); ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id; ctx.device_id = device_id;
DeviceAPIManager::Get(ctx)->SyncStreamFromTo(ctx, src, dst); DeviceAPIManager::Get(ctx)->SyncStreamFromTo(ctx, src, dst);
API_END(); API_END();
} }
int TVMCbArgToReturn(TVMValue* value, int code) { int DGLCbArgToReturn(DGLValue* value, int code) {
API_BEGIN(); API_BEGIN();
tvm::runtime::TVMRetValue rv; dgl::runtime::DGLRetValue rv;
rv = tvm::runtime::TVMArgValue(*value, code); rv = dgl::runtime::DGLArgValue(*value, code);
int tcode; int tcode;
rv.MoveToCHost(value, &tcode); rv.MoveToCHost(value, &tcode);
CHECK_EQ(tcode, code); CHECK_EQ(tcode, code);
...@@ -379,18 +379,18 @@ int TVMCbArgToReturn(TVMValue* value, int code) { ...@@ -379,18 +379,18 @@ int TVMCbArgToReturn(TVMValue* value, int code) {
} }
// set device api // set device api
TVM_REGISTER_GLOBAL(tvm::runtime::symbol::tvm_set_device) DGL_REGISTER_GLOBAL(dgl::runtime::symbol::dgl_set_device)
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](DGLArgs args, DGLRetValue *ret) {
TVMContext ctx; DGLContext ctx;
ctx.device_type = static_cast<DLDeviceType>(args[0].operator int()); ctx.device_type = static_cast<DLDeviceType>(args[0].operator int());
ctx.device_id = args[1]; ctx.device_id = args[1];
DeviceAPIManager::Get(ctx)->SetDevice(ctx); DeviceAPIManager::Get(ctx)->SetDevice(ctx);
}); });
// set device api // set device api
TVM_REGISTER_GLOBAL("_GetDeviceAttr") DGL_REGISTER_GLOBAL("_GetDeviceAttr")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](DGLArgs args, DGLRetValue *ret) {
TVMContext ctx; DGLContext ctx;
ctx.device_type = static_cast<DLDeviceType>(args[0].operator int()); ctx.device_type = static_cast<DLDeviceType>(args[0].operator int());
ctx.device_id = args[1]; ctx.device_id = args[1];
......
...@@ -10,20 +10,20 @@ ...@@ -10,20 +10,20 @@
#include <cstring> #include <cstring>
#include "workspace_pool.h" #include "workspace_pool.h"
namespace tvm { namespace dgl {
namespace runtime { namespace runtime {
class CPUDeviceAPI final : public DeviceAPI { class CPUDeviceAPI final : public DeviceAPI {
public: public:
void SetDevice(TVMContext ctx) final {} void SetDevice(DGLContext ctx) final {}
void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final { void GetAttr(DGLContext ctx, DeviceAttrKind kind, DGLRetValue* rv) final {
if (kind == kExist) { if (kind == kExist) {
*rv = 1; *rv = 1;
} }
} }
void* AllocDataSpace(TVMContext ctx, void* AllocDataSpace(DGLContext ctx,
size_t nbytes, size_t nbytes,
size_t alignment, size_t alignment,
TVMType type_hint) final { DGLType type_hint) final {
void* ptr; void* ptr;
#if _MSC_VER || defined(__MINGW32__) #if _MSC_VER || defined(__MINGW32__)
ptr = _aligned_malloc(nbytes, alignment); ptr = _aligned_malloc(nbytes, alignment);
...@@ -38,7 +38,7 @@ class CPUDeviceAPI final : public DeviceAPI { ...@@ -38,7 +38,7 @@ class CPUDeviceAPI final : public DeviceAPI {
return ptr; return ptr;
} }
void FreeDataSpace(TVMContext ctx, void* ptr) final { void FreeDataSpace(DGLContext ctx, void* ptr) final {
#if _MSC_VER || defined(__MINGW32__) #if _MSC_VER || defined(__MINGW32__)
_aligned_free(ptr); _aligned_free(ptr);
#else #else
...@@ -51,20 +51,20 @@ class CPUDeviceAPI final : public DeviceAPI { ...@@ -51,20 +51,20 @@ class CPUDeviceAPI final : public DeviceAPI {
void* to, void* to,
size_t to_offset, size_t to_offset,
size_t size, size_t size,
TVMContext ctx_from, DGLContext ctx_from,
TVMContext ctx_to, DGLContext ctx_to,
TVMType type_hint, DGLType type_hint,
TVMStreamHandle stream) final { DGLStreamHandle stream) final {
memcpy(static_cast<char*>(to) + to_offset, memcpy(static_cast<char*>(to) + to_offset,
static_cast<const char*>(from) + from_offset, static_cast<const char*>(from) + from_offset,
size); size);
} }
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final { void StreamSync(DGLContext ctx, DGLStreamHandle stream) final {
} }
void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final; void* AllocWorkspace(DGLContext ctx, size_t size, DGLType type_hint) final;
void FreeWorkspace(TVMContext ctx, void* data) final; void FreeWorkspace(DGLContext ctx, void* data) final;
static const std::shared_ptr<CPUDeviceAPI>& Global() { static const std::shared_ptr<CPUDeviceAPI>& Global() {
static std::shared_ptr<CPUDeviceAPI> inst = static std::shared_ptr<CPUDeviceAPI> inst =
...@@ -78,21 +78,21 @@ struct CPUWorkspacePool : public WorkspacePool { ...@@ -78,21 +78,21 @@ struct CPUWorkspacePool : public WorkspacePool {
WorkspacePool(kDLCPU, CPUDeviceAPI::Global()) {} WorkspacePool(kDLCPU, CPUDeviceAPI::Global()) {}
}; };
void* CPUDeviceAPI::AllocWorkspace(TVMContext ctx, void* CPUDeviceAPI::AllocWorkspace(DGLContext ctx,
size_t size, size_t size,
TVMType type_hint) { DGLType type_hint) {
return dmlc::ThreadLocalStore<CPUWorkspacePool>::Get() return dmlc::ThreadLocalStore<CPUWorkspacePool>::Get()
->AllocWorkspace(ctx, size); ->AllocWorkspace(ctx, size);
} }
void CPUDeviceAPI::FreeWorkspace(TVMContext ctx, void* data) { void CPUDeviceAPI::FreeWorkspace(DGLContext ctx, void* data) {
dmlc::ThreadLocalStore<CPUWorkspacePool>::Get()->FreeWorkspace(ctx, data); dmlc::ThreadLocalStore<CPUWorkspacePool>::Get()->FreeWorkspace(ctx, data);
} }
TVM_REGISTER_GLOBAL("device_api.cpu") DGL_REGISTER_GLOBAL("device_api.cpu")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
DeviceAPI* ptr = CPUDeviceAPI::Global().get(); DeviceAPI* ptr = CPUDeviceAPI::Global().get();
*rv = static_cast<void*>(ptr); *rv = static_cast<void*>(ptr);
}); });
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace dgl
...@@ -14,11 +14,11 @@ ...@@ -14,11 +14,11 @@
#include <dlfcn.h> #include <dlfcn.h>
#endif #endif
namespace tvm { namespace dgl {
namespace runtime { namespace runtime {
// Module to load from dynamic shared libary. // Module to load from dynamic shared libary.
// This is the default module TVM used for host-side AOT // This is the default module DGL used for host-side AOT
class DSOModuleNode final : public ModuleNode { class DSOModuleNode final : public ModuleNode {
public: public:
~DSOModuleNode() { ~DSOModuleNode() {
...@@ -33,11 +33,11 @@ class DSOModuleNode final : public ModuleNode { ...@@ -33,11 +33,11 @@ class DSOModuleNode final : public ModuleNode {
const std::string& name, const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final { const std::shared_ptr<ModuleNode>& sptr_to_self) final {
BackendPackedCFunc faddr; BackendPackedCFunc faddr;
if (name == runtime::symbol::tvm_module_main) { if (name == runtime::symbol::dgl_module_main) {
const char* entry_name = reinterpret_cast<const char*>( const char* entry_name = reinterpret_cast<const char*>(
GetSymbol(runtime::symbol::tvm_module_main)); GetSymbol(runtime::symbol::dgl_module_main));
CHECK(entry_name!= nullptr) CHECK(entry_name!= nullptr)
<< "Symbol " << runtime::symbol::tvm_module_main << " is not presented"; << "Symbol " << runtime::symbol::dgl_module_main << " is not presented";
faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(entry_name)); faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(entry_name));
} else { } else {
faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(name.c_str())); faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(name.c_str()));
...@@ -49,7 +49,7 @@ class DSOModuleNode final : public ModuleNode { ...@@ -49,7 +49,7 @@ class DSOModuleNode final : public ModuleNode {
void Init(const std::string& name) { void Init(const std::string& name) {
Load(name); Load(name);
if (auto *ctx_addr = if (auto *ctx_addr =
reinterpret_cast<void**>(GetSymbol(runtime::symbol::tvm_module_ctx))) { reinterpret_cast<void**>(GetSymbol(runtime::symbol::dgl_module_ctx))) {
*ctx_addr = this; *ctx_addr = this;
} }
InitContextFunctions([this](const char* fname) { InitContextFunctions([this](const char* fname) {
...@@ -58,7 +58,7 @@ class DSOModuleNode final : public ModuleNode { ...@@ -58,7 +58,7 @@ class DSOModuleNode final : public ModuleNode {
// Load the imported modules // Load the imported modules
const char* dev_mblob = const char* dev_mblob =
reinterpret_cast<const char*>( reinterpret_cast<const char*>(
GetSymbol(runtime::symbol::tvm_dev_mblob)); GetSymbol(runtime::symbol::dgl_dev_mblob));
if (dev_mblob != nullptr) { if (dev_mblob != nullptr) {
ImportModuleBlob(dev_mblob, &imports_); ImportModuleBlob(dev_mblob, &imports_);
} }
...@@ -103,11 +103,11 @@ class DSOModuleNode final : public ModuleNode { ...@@ -103,11 +103,11 @@ class DSOModuleNode final : public ModuleNode {
#endif #endif
}; };
TVM_REGISTER_GLOBAL("module.loadfile_so") DGL_REGISTER_GLOBAL("module.loadfile_so")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
std::shared_ptr<DSOModuleNode> n = std::make_shared<DSOModuleNode>(); std::shared_ptr<DSOModuleNode> n = std::make_shared<DSOModuleNode>();
n->Init(args[0]); n->Init(args[0]);
*rv = runtime::Module(n); *rv = runtime::Module(n);
}); });
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace dgl
...@@ -10,13 +10,13 @@ ...@@ -10,13 +10,13 @@
#include "file_util.h" #include "file_util.h"
namespace tvm { namespace dgl {
namespace runtime { namespace runtime {
void FunctionInfo::Save(dmlc::JSONWriter* writer) const { void FunctionInfo::Save(dmlc::JSONWriter* writer) const {
std::vector<std::string> sarg_types(arg_types.size()); std::vector<std::string> sarg_types(arg_types.size());
for (size_t i = 0; i < arg_types.size(); ++i) { for (size_t i = 0; i < arg_types.size(); ++i) {
sarg_types[i] = TVMType2String(arg_types[i]); sarg_types[i] = DGLType2String(arg_types[i]);
} }
writer->BeginObject(); writer->BeginObject();
writer->WriteObjectKeyValue("name", name); writer->WriteObjectKeyValue("name", name);
...@@ -34,7 +34,7 @@ void FunctionInfo::Load(dmlc::JSONReader* reader) { ...@@ -34,7 +34,7 @@ void FunctionInfo::Load(dmlc::JSONReader* reader) {
helper.ReadAllFields(reader); helper.ReadAllFields(reader);
arg_types.resize(sarg_types.size()); arg_types.resize(sarg_types.size());
for (size_t i = 0; i < arg_types.size(); ++i) { for (size_t i = 0; i < arg_types.size(); ++i) {
arg_types[i] = String2TVMType(sarg_types[i]); arg_types[i] = String2DGLType(sarg_types[i]);
} }
} }
...@@ -69,12 +69,12 @@ std::string GetFileFormat(const std::string& file_name, ...@@ -69,12 +69,12 @@ std::string GetFileFormat(const std::string& file_name,
std::string GetCacheDir() { std::string GetCacheDir() {
char* env_cache_dir; char* env_cache_dir;
if ((env_cache_dir = getenv("TVM_CACHE_DIR"))) return env_cache_dir; if ((env_cache_dir = getenv("DGL_CACHE_DIR"))) return env_cache_dir;
if ((env_cache_dir = getenv("XDG_CACHE_HOME"))) { if ((env_cache_dir = getenv("XDG_CACHE_HOME"))) {
return std::string(env_cache_dir) + "/tvm"; return std::string(env_cache_dir) + "/dgl";
} }
if ((env_cache_dir = getenv("HOME"))) { if ((env_cache_dir = getenv("HOME"))) {
return std::string(env_cache_dir) + "/.cache/tvm"; return std::string(env_cache_dir) + "/.cache/dgl";
} }
return "."; return ".";
} }
...@@ -88,9 +88,9 @@ std::string GetFileBasename(const std::string& file_name) { ...@@ -88,9 +88,9 @@ std::string GetFileBasename(const std::string& file_name) {
std::string GetMetaFilePath(const std::string& file_name) { std::string GetMetaFilePath(const std::string& file_name) {
size_t pos = file_name.find_last_of("."); size_t pos = file_name.find_last_of(".");
if (pos != std::string::npos) { if (pos != std::string::npos) {
return file_name.substr(0, pos) + ".tvm_meta.json"; return file_name.substr(0, pos) + ".dgl_meta.json";
} else { } else {
return file_name + ".tvm_meta.json"; return file_name + ".dgl_meta.json";
} }
} }
...@@ -122,7 +122,7 @@ void SaveMetaDataToFile( ...@@ -122,7 +122,7 @@ void SaveMetaDataToFile(
CHECK(!fs.fail()) << "Cannot open file " << file_name; CHECK(!fs.fail()) << "Cannot open file " << file_name;
dmlc::JSONWriter writer(&fs); dmlc::JSONWriter writer(&fs);
writer.BeginObject(); writer.BeginObject();
writer.WriteObjectKeyValue("tvm_version", version); writer.WriteObjectKeyValue("dgl_version", version);
writer.WriteObjectKeyValue("func_info", fmap); writer.WriteObjectKeyValue("func_info", fmap);
writer.EndObject(); writer.EndObject();
fs.close(); fs.close();
...@@ -136,11 +136,11 @@ void LoadMetaDataFromFile( ...@@ -136,11 +136,11 @@ void LoadMetaDataFromFile(
std::string version; std::string version;
dmlc::JSONReader reader(&fs); dmlc::JSONReader reader(&fs);
dmlc::JSONObjectReadHelper helper; dmlc::JSONObjectReadHelper helper;
helper.DeclareField("tvm_version", &version); helper.DeclareField("dgl_version", &version);
helper.DeclareField("func_info", fmap); helper.DeclareField("func_info", fmap);
helper.ReadAllFields(&reader); helper.ReadAllFields(&reader);
fs.close(); fs.close();
} }
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace dgl
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
#include <string> #include <string>
#include "meta_data.h" #include "meta_data.h"
namespace tvm { namespace dgl {
namespace runtime { namespace runtime {
/*! /*!
* \brief Get file format from given file name or format argument. * \brief Get file format from given file name or format argument.
...@@ -20,8 +20,8 @@ std::string GetFileFormat(const std::string& file_name, ...@@ -20,8 +20,8 @@ std::string GetFileFormat(const std::string& file_name,
const std::string& format); const std::string& format);
/*! /*!
* \return the directory in which TVM stores cached files. * \return the directory in which DGL stores cached files.
* May be set using TVM_CACHE_DIR; defaults to system locations. * May be set using DGL_CACHE_DIR; defaults to system locations.
*/ */
std::string GetCacheDir(); std::string GetCacheDir();
...@@ -72,5 +72,5 @@ void LoadMetaDataFromFile( ...@@ -72,5 +72,5 @@ void LoadMetaDataFromFile(
const std::string& file_name, const std::string& file_name,
std::unordered_map<std::string, FunctionInfo>* fmap); std::unordered_map<std::string, FunctionInfo>* fmap);
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace dgl
#endif // DGL_RUNTIME_FILE_UTIL_H_ #endif // DGL_RUNTIME_FILE_UTIL_H_
...@@ -13,13 +13,13 @@ ...@@ -13,13 +13,13 @@
#include <vector> #include <vector>
#include "runtime_base.h" #include "runtime_base.h"
namespace tvm { namespace dgl {
namespace runtime { namespace runtime {
/*! \brief function information needed by device */ /*! \brief function information needed by device */
struct FunctionInfo { struct FunctionInfo {
std::string name; std::string name;
std::vector<TVMType> arg_types; std::vector<DGLType> arg_types;
std::vector<std::string> thread_axis_tags; std::vector<std::string> thread_axis_tags;
void Save(dmlc::JSONWriter *writer) const; void Save(dmlc::JSONWriter *writer) const;
...@@ -28,9 +28,9 @@ struct FunctionInfo { ...@@ -28,9 +28,9 @@ struct FunctionInfo {
bool Load(dmlc::Stream *reader); bool Load(dmlc::Stream *reader);
}; };
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace dgl
namespace dmlc { namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, ::tvm::runtime::FunctionInfo, true); DMLC_DECLARE_TRAITS(has_saveload, ::dgl::runtime::FunctionInfo, true);
} // namespace dmlc } // namespace dmlc
#endif // DGL_RUNTIME_META_DATA_H_ #endif // DGL_RUNTIME_META_DATA_H_
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment