Commit 758cb16e authored by Minjie Wang's avatar Minjie Wang
Browse files

ndarray argument

parent c086d454
from __future__ import absolute_import from __future__ import absolute_import
import ctypes
import torch as th import torch as th
from .._ffi.base import _LIB, check_call, c_array
from .._ffi.runtime_ctypes import TVMType, TVMContext, TVMArray from .._ffi.runtime_ctypes import TVMType, TVMContext, TVMArray
from .._ffi.runtime_ctypes import TypeCode, tvm_shape_index_t from .._ffi.runtime_ctypes import TypeCode, tvm_shape_index_t
from .. import ndarray as nd
# Tensor types # Tensor types
Tensor = th.Tensor Tensor = th.Tensor
...@@ -84,14 +87,31 @@ def get_context(arr): ...@@ -84,14 +87,31 @@ def get_context(arr):
return TVMContext( return TVMContext(
TVMContext.STR2MASK[arr.device.type], arr.device.index) TVMContext.STR2MASK[arr.device.type], arr.device.index)
def _typestr(arr_dtype):
if arr_dtype in (th.float16, th.half):
return 'float16'
elif arr_dtype in (th.float32, th.float):
return 'float32'
elif arr_dtype in (th.float64, th.double):
return 'float64'
elif arr_dtype in (th.int16, th.short):
return 'int16'
elif arr_dtype in (th.int32, th.int):
return 'int32'
elif arr_dtype in (th.int64, th.long):
return 'int64'
elif arr_dtype == th.int8:
return 'int8'
elif arr_dtype == th.uint8:
return 'uint8'
else:
raise RuntimeError('Unsupported data type:', arr_dtype)
def asdglarray(arr): def asdglarray(arr):
"""The data is copied to the new array."""
assert arr.is_contiguous() assert arr.is_contiguous()
rst = TVMArray() rst = nd.empty(tuple(arr.shape), _typestr(arr.dtype), get_context(arr))
rst.data = arr.data_ptr() data = ctypes.cast(arr.data_ptr(), ctypes.c_void_p)
rst.shape = c_array(tvm_shape_index_t, arr.shape) nbytes = ctypes.c_size_t(arr.numel() * arr.element_size())
rst.strides = None check_call(_LIB.TVMArrayCopyFromBytes(rst.handle, data, nbytes))
# TODO: dtype
rst.dtype = TVMType(arr.dtype)
rst.ndim = arr.ndimension()
# TODO: ctx
return rst return rst
...@@ -2,6 +2,7 @@ from __future__ import absolute_import ...@@ -2,6 +2,7 @@ from __future__ import absolute_import
from ._ffi.function import _init_api from ._ffi.function import _init_api
from . import backend as F from . import backend as F
from . import utils
class DGLGraph(object): class DGLGraph(object):
def __init__(self): def __init__(self):
...@@ -17,7 +18,14 @@ class DGLGraph(object): ...@@ -17,7 +18,14 @@ class DGLGraph(object):
_CAPI_DGLGraphAddEdge(self._handle, u, v); _CAPI_DGLGraphAddEdge(self._handle, u, v);
def add_edges(self, u, v): def add_edges(self, u, v):
pass u = utils.Index(u)
v = utils.Index(v)
u_array = F.asdglarray(u.totensor())
v_array = F.asdglarray(v.totensor())
_CAPI_DGLGraphAddEdges(
self._handle,
u_array,
v_array)
def number_of_nodes(self): def number_of_nodes(self):
return _CAPI_DGLGraphNumVertices(self._handle) return _CAPI_DGLGraphNumVertices(self._handle)
......
"""DGL Runtime NDArray API.
dgl.ndarray provides a minimum runtime array API to unify
different array libraries used as backend.
"""
# pylint: disable=invalid-name,unused-import
from __future__ import absolute_import as _abs
import numpy as _np
from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase
from ._ffi.ndarray import context, empty, from_dlpack
from ._ffi.ndarray import _set_class_ndarray
class NDArray(NDArrayBase):
"""Lightweight NDArray class for DGL framework."""
pass
def cpu(dev_id=0):
"""Construct a CPU device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(1, dev_id)
def gpu(dev_id=0):
"""Construct a CPU device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(2, dev_id)
def array(arr, ctx=cpu(0)):
"""Create an array from source arr.
Parameters
----------
arr : numpy.ndarray
The array to be copied from
ctx : TVMContext, optional
The device context to create the array
Returns
-------
ret : NDArray
The created array
"""
if not isinstance(arr, (_np.ndarray, NDArray)):
arr = _np.array(arr)
return empty(arr.shape, arr.dtype, ctx).copyfrom(arr)
_set_class_ndarray(NDArray)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment