Commit 842d3768 authored by Minjie Wang's avatar Minjie Wang
Browse files

python api

parent 14d88497
......@@ -103,7 +103,7 @@ class vector_view {
*/
ValueType& operator[](size_t i) {
CHECK(!is_view_);
return data_[i];
return (*data_)[i];
}
/*!
......@@ -113,9 +113,9 @@ class vector_view {
*/
const ValueType& operator[](size_t i) const {
if (is_view_) {
return data_[index_[i]];
return (*data_)[index_[i]];
} else {
return data_[i];
return (*data_)[i];
}
}
......
......@@ -118,9 +118,10 @@ def _make_tvm_args(args, temp_args):
elif isinstance(arg, string_types):
values[i].v_str = c_str(arg)
type_codes[i] = TypeCode.STR
elif isinstance(arg, _CLASS_MODULE):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.MODULE_HANDLE
# NOTE(minjie): module is not used in DGL
#elif isinstance(arg, _CLASS_MODULE):
# values[i].v_handle = arg.handle
# type_codes[i] = TypeCode.MODULE_HANDLE
elif isinstance(arg, FunctionBase):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.FUNC_HANDLE
......
......@@ -29,7 +29,7 @@ except IMPORT_EXCEPT:
FunctionHandle = ctypes.c_void_p
class Function(_FunctionBase):
"""The PackedFunc object used in TVM.
"""The PackedFunc object.
Function plays an key role to bridge front and backend in TVM.
Function provide a type-erased interface, you can call function with positional arguments.
......@@ -275,7 +275,7 @@ def _init_api(namespace, target_module_name=None):
"""
target_module_name = (
target_module_name if target_module_name else namespace)
if namespace.startswith("tvm."):
if namespace.startswith("dgl."):
_init_api_prefix(target_module_name, namespace[4:])
else:
_init_api_prefix(target_module_name, namespace)
......@@ -288,7 +288,7 @@ def _init_api_prefix(module_name, prefix):
if prefix == "api":
fname = name
if name.startswith("_"):
target_module = sys.modules["tvm._api_internal"]
target_module = sys.modules["dgl._api_internal"]
else:
target_module = module
else:
......@@ -302,7 +302,7 @@ def _init_api_prefix(module_name, prefix):
f = get_global_func(name)
ff = _get_api(f)
ff.__name__ = fname
ff.__doc__ = ("TVM PackedFunc %s. " % fname)
ff.__doc__ = ("DGL PackedFunc %s. " % fname)
setattr(target_module, ff.__name__, ff)
_set_class_function(Function)
from __future__ import absolute_import
import torch as th
import scipy.sparse
import dgl.context as context
from .._ffi.runtime_ctypes import TVMType, TVMContext, TVMArray
from .._ffi.runtime_ctypes import TypeCode, tvm_shape_index_t
from ..context as cpu, gpu
# Tensor types
Tensor = th.Tensor
......@@ -78,6 +80,18 @@ def to_context(x, ctx):
def get_context(x):
if x.device.type == 'cpu':
return context.cpu()
return cpu()
else:
return context.gpu(x.device.index)
return gpu(x.device.index)
def asdglarray(arr):
assert arr.is_contiguous()
rst = TVMArray()
rst.data = arr.data_ptr()
rst.shape = c_array(tvm_shape_index_t, arr.shape)
rst.strides = None
# TODO: dtype
rst.dtype = TVMType(arr.dtype)
rst.ndim = arr.ndimension()
# TODO: ctx
return rst
from __future__ import absolute_import
from ._ffi.function import _init_api
import .backend as F
class DGLGraph(object):
def __init__(self):
self._handle = _CAPI_DGLGraphCreate()
def __del__(self):
_CAPI_DGLGraphFree(self._handle)
def add_nodes(self, num):
_CAPI_DGLGraphAddVertices(self._handle, num);
def add_edge(self, u, v):
_CAPI_DGLGraphAddEdge(self._handle, u, v);
def add_edges(self, u, v):
pass
def number_of_nodes(self):
return _CAPI_DGLGraphNumVertices(self._handle)
def number_of_edges(self):
return _CAPI_DGLGraphNumEdges(self._handle)
_init_api("dgl.cgraph")
......@@ -58,7 +58,7 @@ BoolArray Graph::HasVertices(IdArray vids) const {
BoolArray rst = BoolArray::Empty({len}, vids->dtype, vids->ctx);
const int64_t* vid_data = static_cast<int64_t*>(vids->data);
int64_t* rst_data = static_cast<int64_t*>(rst->data);
const uint64_t nverts = NumVertices();
const int64_t nverts = NumVertices();
for (int64_t i = 0; i < len; ++i) {
rst_data[i] = (vid_data[i] < nverts)? 1 : 0;
}
......
#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/registry.h>
#include <dgl/graph.h>
using tvm::runtime::TVMArgs;
using tvm::runtime::TVMArgValue;
using tvm::runtime::TVMRetValue;
using tvm::runtime::PackedFunc;
namespace dgl {
typedef void* GraphHandle;
void DGLGraphCreate(TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = new Graph();
*rv = ghandle;
}
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphCreate")
.set_body(DGLGraphCreate);
void DGLGraphFree(TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
Graph* gptr = static_cast<Graph*>(ghandle);
delete gptr;
}
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphFree")
.set_body(DGLGraphFree);
void DGLGraphAddVertices(TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
Graph* gptr = static_cast<Graph*>(ghandle);
uint64_t num_vertices = args[1];
gptr->AddVertices(num_vertices);
}
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphAddVertices")
.set_body(DGLGraphAddVertices);
void DGLGraphAddEdge(TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t src = args[1];
const dgl_id_t dst = args[2];
gptr->AddEdge(src, dst);
}
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphAddEdge")
.set_body(DGLGraphAddEdge);
void DGLGraphAddEdges(TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray src = args[1];
const IdArray dst = args[2];
gptr->AddEdges(src, dst);
}
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphAddEdges")
.set_body(DGLGraphAddEdges);
void DGLGraphNumVertices(TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
*rv = static_cast<int64_t>(gptr->NumVertices());
}
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphNumVertices")
.set_body(DGLGraphNumVertices);
void DGLGraphNumEdges(TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
*rv = static_cast<int64_t>(gptr->NumEdges());
}
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphNumEdges")
.set_body(DGLGraphNumEdges);
} // namespace dgl
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment