Commit 842d3768 authored by Minjie Wang's avatar Minjie Wang
Browse files

python api

parent 14d88497
...@@ -103,7 +103,7 @@ class vector_view { ...@@ -103,7 +103,7 @@ class vector_view {
*/ */
ValueType& operator[](size_t i) { ValueType& operator[](size_t i) {
CHECK(!is_view_); CHECK(!is_view_);
return data_[i]; return (*data_)[i];
} }
/*! /*!
...@@ -113,9 +113,9 @@ class vector_view { ...@@ -113,9 +113,9 @@ class vector_view {
*/ */
const ValueType& operator[](size_t i) const { const ValueType& operator[](size_t i) const {
if (is_view_) { if (is_view_) {
return data_[index_[i]]; return (*data_)[index_[i]];
} else { } else {
return data_[i]; return (*data_)[i];
} }
} }
......
...@@ -118,9 +118,10 @@ def _make_tvm_args(args, temp_args): ...@@ -118,9 +118,10 @@ def _make_tvm_args(args, temp_args):
elif isinstance(arg, string_types): elif isinstance(arg, string_types):
values[i].v_str = c_str(arg) values[i].v_str = c_str(arg)
type_codes[i] = TypeCode.STR type_codes[i] = TypeCode.STR
elif isinstance(arg, _CLASS_MODULE): # NOTE(minjie): module is not used in DGL
values[i].v_handle = arg.handle #elif isinstance(arg, _CLASS_MODULE):
type_codes[i] = TypeCode.MODULE_HANDLE # values[i].v_handle = arg.handle
# type_codes[i] = TypeCode.MODULE_HANDLE
elif isinstance(arg, FunctionBase): elif isinstance(arg, FunctionBase):
values[i].v_handle = arg.handle values[i].v_handle = arg.handle
type_codes[i] = TypeCode.FUNC_HANDLE type_codes[i] = TypeCode.FUNC_HANDLE
......
...@@ -29,7 +29,7 @@ except IMPORT_EXCEPT: ...@@ -29,7 +29,7 @@ except IMPORT_EXCEPT:
FunctionHandle = ctypes.c_void_p FunctionHandle = ctypes.c_void_p
class Function(_FunctionBase): class Function(_FunctionBase):
"""The PackedFunc object used in TVM. """The PackedFunc object.
Function plays an key role to bridge front and backend in TVM. Function plays an key role to bridge front and backend in TVM.
Function provide a type-erased interface, you can call function with positional arguments. Function provide a type-erased interface, you can call function with positional arguments.
...@@ -275,7 +275,7 @@ def _init_api(namespace, target_module_name=None): ...@@ -275,7 +275,7 @@ def _init_api(namespace, target_module_name=None):
""" """
target_module_name = ( target_module_name = (
target_module_name if target_module_name else namespace) target_module_name if target_module_name else namespace)
if namespace.startswith("tvm."): if namespace.startswith("dgl."):
_init_api_prefix(target_module_name, namespace[4:]) _init_api_prefix(target_module_name, namespace[4:])
else: else:
_init_api_prefix(target_module_name, namespace) _init_api_prefix(target_module_name, namespace)
...@@ -288,7 +288,7 @@ def _init_api_prefix(module_name, prefix): ...@@ -288,7 +288,7 @@ def _init_api_prefix(module_name, prefix):
if prefix == "api": if prefix == "api":
fname = name fname = name
if name.startswith("_"): if name.startswith("_"):
target_module = sys.modules["tvm._api_internal"] target_module = sys.modules["dgl._api_internal"]
else: else:
target_module = module target_module = module
else: else:
...@@ -302,7 +302,7 @@ def _init_api_prefix(module_name, prefix): ...@@ -302,7 +302,7 @@ def _init_api_prefix(module_name, prefix):
f = get_global_func(name) f = get_global_func(name)
ff = _get_api(f) ff = _get_api(f)
ff.__name__ = fname ff.__name__ = fname
ff.__doc__ = ("TVM PackedFunc %s. " % fname) ff.__doc__ = ("DGL PackedFunc %s. " % fname)
setattr(target_module, ff.__name__, ff) setattr(target_module, ff.__name__, ff)
_set_class_function(Function) _set_class_function(Function)
from __future__ import absolute_import from __future__ import absolute_import
import torch as th import torch as th
import scipy.sparse
import dgl.context as context from .._ffi.runtime_ctypes import TVMType, TVMContext, TVMArray
from .._ffi.runtime_ctypes import TypeCode, tvm_shape_index_t
from ..context as cpu, gpu
# Tensor types # Tensor types
Tensor = th.Tensor Tensor = th.Tensor
...@@ -78,6 +80,18 @@ def to_context(x, ctx): ...@@ -78,6 +80,18 @@ def to_context(x, ctx):
def get_context(x): def get_context(x):
if x.device.type == 'cpu': if x.device.type == 'cpu':
return context.cpu() return cpu()
else: else:
return context.gpu(x.device.index) return gpu(x.device.index)
def asdglarray(arr):
assert arr.is_contiguous()
rst = TVMArray()
rst.data = arr.data_ptr()
rst.shape = c_array(tvm_shape_index_t, arr.shape)
rst.strides = None
# TODO: dtype
rst.dtype = TVMType(arr.dtype)
rst.ndim = arr.ndimension()
# TODO: ctx
return rst
from __future__ import absolute_import
from ._ffi.function import _init_api
import .backend as F
class DGLGraph(object):
def __init__(self):
self._handle = _CAPI_DGLGraphCreate()
def __del__(self):
_CAPI_DGLGraphFree(self._handle)
def add_nodes(self, num):
_CAPI_DGLGraphAddVertices(self._handle, num);
def add_edge(self, u, v):
_CAPI_DGLGraphAddEdge(self._handle, u, v);
def add_edges(self, u, v):
pass
def number_of_nodes(self):
return _CAPI_DGLGraphNumVertices(self._handle)
def number_of_edges(self):
return _CAPI_DGLGraphNumEdges(self._handle)
_init_api("dgl.cgraph")
...@@ -58,7 +58,7 @@ BoolArray Graph::HasVertices(IdArray vids) const { ...@@ -58,7 +58,7 @@ BoolArray Graph::HasVertices(IdArray vids) const {
BoolArray rst = BoolArray::Empty({len}, vids->dtype, vids->ctx); BoolArray rst = BoolArray::Empty({len}, vids->dtype, vids->ctx);
const int64_t* vid_data = static_cast<int64_t*>(vids->data); const int64_t* vid_data = static_cast<int64_t*>(vids->data);
int64_t* rst_data = static_cast<int64_t*>(rst->data); int64_t* rst_data = static_cast<int64_t*>(rst->data);
const uint64_t nverts = NumVertices(); const int64_t nverts = NumVertices();
for (int64_t i = 0; i < len; ++i) { for (int64_t i = 0; i < len; ++i) {
rst_data[i] = (vid_data[i] < nverts)? 1 : 0; rst_data[i] = (vid_data[i] < nverts)? 1 : 0;
} }
......
#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/registry.h>
#include <dgl/graph.h>
using tvm::runtime::TVMArgs;
using tvm::runtime::TVMArgValue;
using tvm::runtime::TVMRetValue;
using tvm::runtime::PackedFunc;
namespace dgl {
typedef void* GraphHandle;
void DGLGraphCreate(TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = new Graph();
*rv = ghandle;
}
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphCreate")
.set_body(DGLGraphCreate);
void DGLGraphFree(TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
Graph* gptr = static_cast<Graph*>(ghandle);
delete gptr;
}
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphFree")
.set_body(DGLGraphFree);
void DGLGraphAddVertices(TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
Graph* gptr = static_cast<Graph*>(ghandle);
uint64_t num_vertices = args[1];
gptr->AddVertices(num_vertices);
}
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphAddVertices")
.set_body(DGLGraphAddVertices);
void DGLGraphAddEdge(TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t src = args[1];
const dgl_id_t dst = args[2];
gptr->AddEdge(src, dst);
}
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphAddEdge")
.set_body(DGLGraphAddEdge);
void DGLGraphAddEdges(TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray src = args[1];
const IdArray dst = args[2];
gptr->AddEdges(src, dst);
}
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphAddEdges")
.set_body(DGLGraphAddEdges);
void DGLGraphNumVertices(TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
*rv = static_cast<int64_t>(gptr->NumVertices());
}
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphNumVertices")
.set_body(DGLGraphNumVertices);
void DGLGraphNumEdges(TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
*rv = static_cast<int64_t>(gptr->NumEdges());
}
TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphNumEdges")
.set_body(DGLGraphNumEdges);
} // namespace dgl
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment