Commit 758cb16e authored by Minjie Wang's avatar Minjie Wang
Browse files

ndarray argument

parent c086d454
from __future__ import absolute_import
import ctypes
import torch as th
from .._ffi.base import _LIB, check_call, c_array
from .._ffi.runtime_ctypes import TVMType, TVMContext, TVMArray
from .._ffi.runtime_ctypes import TypeCode, tvm_shape_index_t
from .. import ndarray as nd
# Tensor types
Tensor = th.Tensor
......@@ -84,14 +87,31 @@ def get_context(arr):
return TVMContext(
TVMContext.STR2MASK[arr.device.type], arr.device.index)
def _typestr(arr_dtype):
if arr_dtype in (th.float16, th.half):
return 'float16'
elif arr_dtype in (th.float32, th.float):
return 'float32'
elif arr_dtype in (th.float64, th.double):
return 'float64'
elif arr_dtype in (th.int16, th.short):
return 'int16'
elif arr_dtype in (th.int32, th.int):
return 'int32'
elif arr_dtype in (th.int64, th.long):
return 'int64'
elif arr_dtype == th.int8:
return 'int8'
elif arr_dtype == th.uint8:
return 'uint8'
else:
raise RuntimeError('Unsupported data type:', arr_dtype)
def asdglarray(arr):
"""The data is copied to the new array."""
assert arr.is_contiguous()
rst = TVMArray()
rst.data = arr.data_ptr()
rst.shape = c_array(tvm_shape_index_t, arr.shape)
rst.strides = None
# TODO: dtype
rst.dtype = TVMType(arr.dtype)
rst.ndim = arr.ndimension()
# TODO: ctx
rst = nd.empty(tuple(arr.shape), _typestr(arr.dtype), get_context(arr))
data = ctypes.cast(arr.data_ptr(), ctypes.c_void_p)
nbytes = ctypes.c_size_t(arr.numel() * arr.element_size())
check_call(_LIB.TVMArrayCopyFromBytes(rst.handle, data, nbytes))
return rst
......@@ -2,6 +2,7 @@ from __future__ import absolute_import
from ._ffi.function import _init_api
from . import backend as F
from . import utils
class DGLGraph(object):
def __init__(self):
......@@ -17,7 +18,14 @@ class DGLGraph(object):
_CAPI_DGLGraphAddEdge(self._handle, u, v);
def add_edges(self, u, v):
pass
u = utils.Index(u)
v = utils.Index(v)
u_array = F.asdglarray(u.totensor())
v_array = F.asdglarray(v.totensor())
_CAPI_DGLGraphAddEdges(
self._handle,
u_array,
v_array)
def number_of_nodes(self):
return _CAPI_DGLGraphNumVertices(self._handle)
......
"""DGL Runtime NDArray API.
dgl.ndarray provides a minimum runtime array API to unify
different array libraries used as backend.
"""
# pylint: disable=invalid-name,unused-import
from __future__ import absolute_import as _abs
import numpy as _np
from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase
from ._ffi.ndarray import context, empty, from_dlpack
from ._ffi.ndarray import _set_class_ndarray
class NDArray(NDArrayBase):
"""Lightweight NDArray class for DGL framework."""
pass
def cpu(dev_id=0):
"""Construct a CPU device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(1, dev_id)
def gpu(dev_id=0):
"""Construct a CPU device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(2, dev_id)
def array(arr, ctx=cpu(0)):
"""Create an array from source arr.
Parameters
----------
arr : numpy.ndarray
The array to be copied from
ctx : TVMContext, optional
The device context to create the array
Returns
-------
ret : NDArray
The created array
"""
if not isinstance(arr, (_np.ndarray, NDArray)):
arr = _np.array(arr)
return empty(arr.shape, arr.dtype, ctx).copyfrom(arr)
_set_class_ndarray(NDArray)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment