Unverified Commit cded5b80 authored by Xin Yao's avatar Xin Yao Committed by GitHub
Browse files

[Feature] Bump DLPack to v0.7 and decouple DLPack from the core library (#4454)

* rename `DLContext` to `DGLContext`

* rename `kDLGPU` to `kDLCUDA`

* replace DLTensor with DGLArray

* fix linting

* Unify DGLType and DLDataType to DGLDataType

* Fix FFI

* rename DLDeviceType to DGLDeviceType

* decouple dlpack from the core library

* fix bug

* fix lint

* fix merge

* fix build

* address comments

* rename dl_converter to dlpack_convert

* remove redundant comments
parent f1689ad0
...@@ -3,16 +3,16 @@ from libcpp.vector cimport vector ...@@ -3,16 +3,16 @@ from libcpp.vector cimport vector
from libcpp cimport bool from libcpp cimport bool
from cpython.version cimport PY_MAJOR_VERSION from cpython.version cimport PY_MAJOR_VERSION
from cpython cimport pycapsule from cpython cimport pycapsule
from libc.stdint cimport int64_t, uint64_t, uint8_t, uint16_t from libc.stdint cimport int32_t, int64_t, uint64_t, uint8_t, uint16_t
import ctypes import ctypes
cdef enum DGLTypeCode: cdef enum DGLObjectTypeCode:
kInt = 0 kInt = 0
kUInt = 1 kUInt = 1
kFloat = 2 kFloat = 2
kHandle = 3 kHandle = 3
kNull = 4 kNull = 4
kDGLType = 5 kDGLDataType = 5
kDGLContext = 6 kDGLContext = 6
kArrayHandle = 7 kArrayHandle = 7
kObjectHandle = 8 kObjectHandle = 8
...@@ -24,26 +24,26 @@ cdef enum DGLTypeCode: ...@@ -24,26 +24,26 @@ cdef enum DGLTypeCode:
kExtBegin = 15 kExtBegin = 15
cdef extern from "dgl/runtime/c_runtime_api.h": cdef extern from "dgl/runtime/c_runtime_api.h":
ctypedef struct DLDataType: ctypedef struct DGLDataType:
uint8_t code uint8_t code
uint8_t bits uint8_t bits
uint16_t lanes uint16_t lanes
ctypedef struct DLContext: ctypedef struct DGLContext:
int device_type int32_t device_type
int device_id int32_t device_id
ctypedef struct DLTensor: ctypedef struct DGLArray:
void* data void* data
DLContext ctx DGLContext ctx
int ndim int32_t ndim
DLDataType dtype DGLDataType dtype
int64_t* shape int64_t* shape
int64_t* strides int64_t* strides
uint64_t byte_offset uint64_t byte_offset
ctypedef struct DLManagedTensor: ctypedef struct DLManagedTensor:
DLTensor dl_tensor DGLArray dl_tensor
void* manager_ctx void* manager_ctx
void (*deleter)(DLManagedTensor* self) void (*deleter)(DLManagedTensor* self)
...@@ -52,13 +52,11 @@ cdef extern from "dgl/runtime/c_runtime_api.h": ...@@ -52,13 +52,11 @@ cdef extern from "dgl/runtime/c_runtime_api.h":
double v_float64 double v_float64
void* v_handle void* v_handle
const char* v_str const char* v_str
DLDataType v_type DGLDataType v_type
DLContext v_ctx DGLContext v_ctx
ctypedef int64_t dgl_index_t ctypedef int64_t dgl_index_t
ctypedef DLTensor* DLTensorHandle ctypedef DGLArray* DGLArrayHandle
ctypedef DLTensor DGLArray
ctypedef DGLArray* CDGLArrayHandle
ctypedef void* DGLStreamHandle ctypedef void* DGLStreamHandle
ctypedef void* DGLRetValueHandle ctypedef void* DGLRetValueHandle
ctypedef void* DGLFunctionHandle ctypedef void* DGLFunctionHandle
...@@ -94,9 +92,9 @@ cdef extern from "dgl/runtime/c_runtime_api.h": ...@@ -94,9 +92,9 @@ cdef extern from "dgl/runtime/c_runtime_api.h":
int DGLCbArgToReturn(DGLValue* value, int code) int DGLCbArgToReturn(DGLValue* value, int code)
int DGLArrayAlloc(dgl_index_t* shape, int DGLArrayAlloc(dgl_index_t* shape,
dgl_index_t ndim, dgl_index_t ndim,
DLDataType dtype, DGLDataType dtype,
DLContext ctx, DGLContext ctx,
DLTensorHandle* out) DGLArrayHandle* out)
int DGLArrayAllocSharedMem(const char *mem_name, int DGLArrayAllocSharedMem(const char *mem_name,
const dgl_index_t *shape, const dgl_index_t *shape,
int ndim, int ndim,
...@@ -104,16 +102,10 @@ cdef extern from "dgl/runtime/c_runtime_api.h": ...@@ -104,16 +102,10 @@ cdef extern from "dgl/runtime/c_runtime_api.h":
int dtype_bits, int dtype_bits,
int dtype_lanes, int dtype_lanes,
bool is_create, bool is_create,
CDGLArrayHandle* out) DGLArrayHandle* out)
int DGLArrayFree(DLTensorHandle handle) int DGLArrayFree(DGLArrayHandle handle)
int DGLArrayCopyFromTo(DLTensorHandle src, int DGLArrayCopyFromTo(DGLArrayHandle src,
DLTensorHandle to) DGLArrayHandle to)
int DGLArrayFromDLPack(DLManagedTensor* arr_from,
DLTensorHandle* out)
int DGLArrayToDLPack(DLTensorHandle arr_from,
DLManagedTensor** out,
int alignment)
void DGLDLManagedTensorCallDeleter(DLManagedTensor* dltensor)
cdef extern from "dgl/runtime/c_object_api.h": cdef extern from "dgl/runtime/c_object_api.h":
int DGLObjectFree(ObjectHandle handle) int DGLObjectFree(ObjectHandle handle)
...@@ -127,6 +119,14 @@ cdef extern from "dgl/runtime/c_object_api.h": ...@@ -127,6 +119,14 @@ cdef extern from "dgl/runtime/c_object_api.h":
int* out_type_code, int* out_type_code,
int* out_success) int* out_success)
cdef extern from "dgl/runtime/dlpack_convert.h":
int DGLArrayFromDLPack(DLManagedTensor* arr_from,
DGLArrayHandle* out)
int DGLArrayToDLPack(DGLArrayHandle arr_from,
DLManagedTensor** out,
int alignment)
void DGLDLManagedTensorCallDeleter(DLManagedTensor* dltensor)
cdef inline py_str(const char* x): cdef inline py_str(const char* x):
if PY_MAJOR_VERSION < 3: if PY_MAJOR_VERSION < 3:
return x return x
......
...@@ -4,7 +4,9 @@ from cpython cimport Py_INCREF, Py_DECREF ...@@ -4,7 +4,9 @@ from cpython cimport Py_INCREF, Py_DECREF
from numbers import Number, Integral from numbers import Number, Integral
from ..base import string_types from ..base import string_types
from ..object_generic import convert_to_object, ObjectGeneric from ..object_generic import convert_to_object, ObjectGeneric
from ..runtime_ctypes import DGLType, DGLContext, DGLByteArray from ..runtime_ctypes import DGLDataType as CTypesDGLDataType, \
DGLContext as CTypesDGLContext, \
DGLByteArray
cdef void dgl_callback_finalize(void* fhandle): cdef void dgl_callback_finalize(void* fhandle):
...@@ -107,13 +109,13 @@ cdef inline int make_arg(object arg, ...@@ -107,13 +109,13 @@ cdef inline int make_arg(object arg,
elif isinstance(arg, Number): elif isinstance(arg, Number):
value[0].v_float64 = arg value[0].v_float64 = arg
tcode[0] = kFloat tcode[0] = kFloat
elif isinstance(arg, DGLType): elif isinstance(arg, CTypesDGLDataType):
tstr = c_str(str(arg)) tstr = c_str(str(arg))
value[0].v_str = tstr value[0].v_str = tstr
tcode[0] = kStr tcode[0] = kStr
temp_args.append(tstr) temp_args.append(tstr)
elif isinstance(arg, DGLContext): elif isinstance(arg, CTypesDGLContext):
value[0].v_ctx = (<DLContext*>( value[0].v_ctx = (<DGLContext*>(
<unsigned long long>ctypes.addressof(arg)))[0] <unsigned long long>ctypes.addressof(arg)))[0]
tcode[0] = kDGLContext tcode[0] = kDGLContext
elif isinstance(arg, bytearray): elif isinstance(arg, bytearray):
...@@ -183,7 +185,7 @@ cdef inline object make_ret(DGLValue value, int tcode): ...@@ -183,7 +185,7 @@ cdef inline object make_ret(DGLValue value, int tcode):
elif tcode == kHandle: elif tcode == kHandle:
return ctypes_handle(value.v_handle) return ctypes_handle(value.v_handle)
elif tcode == kDGLContext: elif tcode == kDGLContext:
return DGLContext(value.v_ctx.device_type, value.v_ctx.device_id) return CTypesDGLContext(value.v_ctx.device_type, value.v_ctx.device_id)
# (minjie): class module are not used in DGL. # (minjie): class module are not used in DGL.
#elif tcode == kModuleHandle: #elif tcode == kModuleHandle:
# return _CLASS_MODULE(ctypes_handle(value.v_handle)) # return _CLASS_MODULE(ctypes_handle(value.v_handle))
......
from ..runtime_ctypes import DGLArrayHandle from ..runtime_ctypes import DGLArrayHandle as PyDGLArrayHandle
cdef const char* _c_str_dltensor = "dltensor" cdef const char* _c_str_dltensor = "dltensor"
cdef const char* _c_str_used_dltensor = "used_dltensor" cdef const char* _c_str_used_dltensor = "used_dltensor"
...@@ -13,7 +13,7 @@ cdef void _c_dlpack_deleter(object pycaps): ...@@ -13,7 +13,7 @@ cdef void _c_dlpack_deleter(object pycaps):
def _from_dlpack(object dltensor): def _from_dlpack(object dltensor):
cdef DLManagedTensor* ptr cdef DLManagedTensor* ptr
cdef DLTensorHandle chandle cdef DGLArrayHandle chandle
if pycapsule.PyCapsule_IsValid(dltensor, _c_str_dltensor): if pycapsule.PyCapsule_IsValid(dltensor, _c_str_dltensor):
ptr = <DLManagedTensor*>pycapsule.PyCapsule_GetPointer(dltensor, _c_str_dltensor) ptr = <DLManagedTensor*>pycapsule.PyCapsule_GetPointer(dltensor, _c_str_dltensor)
CALL(DGLArrayFromDLPack(ptr, &chandle)) CALL(DGLArrayFromDLPack(ptr, &chandle))
...@@ -25,7 +25,7 @@ def _from_dlpack(object dltensor): ...@@ -25,7 +25,7 @@ def _from_dlpack(object dltensor):
cdef class NDArrayBase: cdef class NDArrayBase:
cdef DLTensor* chandle cdef DGLArray* chandle
cdef int c_is_view cdef int c_is_view
cdef inline _set_handle(self, handle): cdef inline _set_handle(self, handle):
...@@ -34,7 +34,7 @@ cdef class NDArrayBase: ...@@ -34,7 +34,7 @@ cdef class NDArrayBase:
self.chandle = NULL self.chandle = NULL
else: else:
ptr = ctypes.cast(handle, ctypes.c_void_p).value ptr = ctypes.cast(handle, ctypes.c_void_p).value
self.chandle = <DLTensor*>(ptr) self.chandle = <DGLArray*>(ptr)
property _dgl_handle: property _dgl_handle:
def __get__(self): def __get__(self):
...@@ -46,7 +46,7 @@ cdef class NDArrayBase: ...@@ -46,7 +46,7 @@ cdef class NDArrayBase:
return None return None
else: else:
return ctypes.cast( return ctypes.cast(
<unsigned long long>self.chandle, DGLArrayHandle) <unsigned long long>self.chandle, PyDGLArrayHandle)
def __set__(self, value): def __set__(self, value):
self._set_handle(value) self._set_handle(value)
...@@ -82,7 +82,7 @@ cdef class NDArrayBase: ...@@ -82,7 +82,7 @@ cdef class NDArrayBase:
cdef c_make_array(void* chandle, is_view): cdef c_make_array(void* chandle, is_view):
ret = _CLASS_NDARRAY(None, is_view) ret = _CLASS_NDARRAY(None, is_view)
(<NDArrayBase>ret).chandle = <DLTensor*>chandle (<NDArrayBase>ret).chandle = <DGLArray*>chandle
return ret return ret
......
...@@ -6,7 +6,7 @@ import sys ...@@ -6,7 +6,7 @@ import sys
import ctypes import ctypes
import numpy as np import numpy as np
from .base import _LIB, check_call, c_array, string_types, _FFI_MODE, c_str from .base import _LIB, check_call, c_array, string_types, _FFI_MODE, c_str
from .runtime_ctypes import DGLType, DGLContext, DGLArray, DGLArrayHandle from .runtime_ctypes import DGLDataType, DGLContext, DGLArray, DGLArrayHandle
from .runtime_ctypes import TypeCode, dgl_shape_index_t from .runtime_ctypes import TypeCode, dgl_shape_index_t
...@@ -72,7 +72,7 @@ def numpyasarray(np_data): ...@@ -72,7 +72,7 @@ def numpyasarray(np_data):
arr.data = data.ctypes.data_as(ctypes.c_void_p) arr.data = data.ctypes.data_as(ctypes.c_void_p)
arr.shape = shape arr.shape = shape
arr.strides = None arr.strides = None
arr.dtype = DGLType(np.dtype(data.dtype).name) arr.dtype = DGLDataType(np.dtype(data.dtype).name)
arr.ndim = data.ndim arr.ndim = data.ndim
# CPU device # CPU device
arr.ctx = context(1, 0) arr.ctx = context(1, 0)
...@@ -101,7 +101,7 @@ def empty(shape, dtype="float32", ctx=context(1, 0)): ...@@ -101,7 +101,7 @@ def empty(shape, dtype="float32", ctx=context(1, 0)):
shape = c_array(dgl_shape_index_t, shape) shape = c_array(dgl_shape_index_t, shape)
ndim = ctypes.c_int(len(shape)) ndim = ctypes.c_int(len(shape))
handle = DGLArrayHandle() handle = DGLArrayHandle()
dtype = DGLType(dtype) dtype = DGLDataType(dtype)
check_call(_LIB.DGLArrayAlloc( check_call(_LIB.DGLArrayAlloc(
shape, ndim, shape, ndim,
ctypes.c_int(dtype.type_code), ctypes.c_int(dtype.type_code),
...@@ -139,7 +139,7 @@ def empty_shared_mem(name, is_create, shape, dtype="float32"): ...@@ -139,7 +139,7 @@ def empty_shared_mem(name, is_create, shape, dtype="float32"):
shape = c_array(dgl_shape_index_t, shape) shape = c_array(dgl_shape_index_t, shape)
ndim = ctypes.c_int(len(shape)) ndim = ctypes.c_int(len(shape))
handle = DGLArrayHandle() handle = DGLArrayHandle()
dtype = DGLType(dtype) dtype = DGLDataType(dtype)
check_call(_LIB.DGLArrayAllocSharedMem( check_call(_LIB.DGLArrayAllocSharedMem(
name, shape, ndim, name, shape, ndim,
ctypes.c_int(dtype.type_code), ctypes.c_int(dtype.type_code),
...@@ -254,7 +254,7 @@ class NDArrayBase(_NDArrayBase): ...@@ -254,7 +254,7 @@ class NDArrayBase(_NDArrayBase):
except: except:
raise TypeError('array must be an array_like data,' + raise TypeError('array must be an array_like data,' +
'type %s is not supported' % str(type(source_array))) 'type %s is not supported' % str(type(source_array)))
t = DGLType(self.dtype) t = DGLDataType(self.dtype)
shape, dtype = self.shape, self.dtype shape, dtype = self.shape, self.dtype
if t.lanes > 1: if t.lanes > 1:
shape = shape + (t.lanes,) shape = shape + (t.lanes,)
...@@ -286,7 +286,7 @@ class NDArrayBase(_NDArrayBase): ...@@ -286,7 +286,7 @@ class NDArrayBase(_NDArrayBase):
np_arr : numpy.ndarray np_arr : numpy.ndarray
The corresponding numpy array. The corresponding numpy array.
""" """
t = DGLType(self.dtype) t = DGLDataType(self.dtype)
shape, dtype = self.shape, self.dtype shape, dtype = self.shape, self.dtype
if t.lanes > 1: if t.lanes > 1:
shape = shape + (t.lanes,) shape = shape + (t.lanes,)
......
...@@ -17,7 +17,7 @@ class TypeCode(object): ...@@ -17,7 +17,7 @@ class TypeCode(object):
FLOAT = 2 FLOAT = 2
HANDLE = 3 HANDLE = 3
NULL = 4 NULL = 4
DGL_TYPE = 5 DGL_DATA_TYPE = 5
DGL_CONTEXT = 6 DGL_CONTEXT = 6
ARRAY_HANDLE = 7 ARRAY_HANDLE = 7
OBJECT_HANDLE = 8 OBJECT_HANDLE = 8
...@@ -33,7 +33,7 @@ class DGLByteArray(ctypes.Structure): ...@@ -33,7 +33,7 @@ class DGLByteArray(ctypes.Structure):
_fields_ = [("data", ctypes.POINTER(ctypes.c_byte)), _fields_ = [("data", ctypes.POINTER(ctypes.c_byte)),
("size", ctypes.c_size_t)] ("size", ctypes.c_size_t)]
class DGLType(ctypes.Structure): class DGLDataType(ctypes.Structure):
"""DGL datatype structure""" """DGL datatype structure"""
_fields_ = [("type_code", ctypes.c_uint8), _fields_ = [("type_code", ctypes.c_uint8),
("bits", ctypes.c_uint8), ("bits", ctypes.c_uint8),
...@@ -50,7 +50,7 @@ class DGLType(ctypes.Structure): ...@@ -50,7 +50,7 @@ class DGLType(ctypes.Structure):
if type_str in cls._cache: if type_str in cls._cache:
return cls._cache[type_str] return cls._cache[type_str]
inst = super(DGLType, cls).__new__(DGLType) inst = super(DGLDataType, cls).__new__(DGLDataType)
if isinstance(type_str, np.dtype): if isinstance(type_str, np.dtype):
type_str = str(type_str) type_str = str(type_str)
...@@ -84,7 +84,7 @@ class DGLType(ctypes.Structure): ...@@ -84,7 +84,7 @@ class DGLType(ctypes.Structure):
pass pass
def __repr__(self): def __repr__(self):
x = "%s%d" % (DGLType.CODE2STR[self.type_code], self.bits) x = "%s%d" % (DGLDataType.CODE2STR[self.type_code], self.bits)
if self.lanes != 1: if self.lanes != 1:
x += "x%d" % self.lanes x += "x%d" % self.lanes
return x return x
...@@ -250,7 +250,7 @@ class DGLArray(ctypes.Structure): ...@@ -250,7 +250,7 @@ class DGLArray(ctypes.Structure):
_fields_ = [("data", ctypes.c_void_p), _fields_ = [("data", ctypes.c_void_p),
("ctx", DGLContext), ("ctx", DGLContext),
("ndim", ctypes.c_int), ("ndim", ctypes.c_int),
("dtype", DGLType), ("dtype", DGLDataType),
("shape", ctypes.POINTER(dgl_shape_index_t)), ("shape", ctypes.POINTER(dgl_shape_index_t)),
("strides", ctypes.POINTER(dgl_shape_index_t)), ("strides", ctypes.POINTER(dgl_shape_index_t)),
("byte_offset", ctypes.c_uint64)] ("byte_offset", ctypes.c_uint64)]
......
...@@ -13,7 +13,7 @@ import numpy as _np ...@@ -13,7 +13,7 @@ import numpy as _np
from ._ffi.object import register_object, ObjectBase from ._ffi.object import register_object, ObjectBase
from ._ffi.function import _init_api from ._ffi.function import _init_api
from ._ffi.ndarray import DGLContext, DGLType, NDArrayBase from ._ffi.ndarray import DGLContext, DGLDataType, NDArrayBase
from ._ffi.ndarray import context, empty, empty_shared_mem, from_dlpack, numpyasarray from ._ffi.ndarray import context, empty, empty_shared_mem, from_dlpack, numpyasarray
from ._ffi.ndarray import _set_class_ndarray from ._ffi.ndarray import _set_class_ndarray
from . import backend as F from . import backend as F
......
...@@ -19,8 +19,8 @@ using namespace dgl::runtime; ...@@ -19,8 +19,8 @@ using namespace dgl::runtime;
namespace dgl { namespace dgl {
namespace aten { namespace aten {
IdArray NewIdArray(int64_t length, DLContext ctx, uint8_t nbits) { IdArray NewIdArray(int64_t length, DGLContext ctx, uint8_t nbits) {
return IdArray::Empty({length}, DLDataType{kDLInt, nbits, 1}, ctx); return IdArray::Empty({length}, DGLDataType{kDGLInt, nbits, 1}, ctx);
} }
IdArray Clone(IdArray arr) { IdArray Clone(IdArray arr) {
...@@ -29,7 +29,7 @@ IdArray Clone(IdArray arr) { ...@@ -29,7 +29,7 @@ IdArray Clone(IdArray arr) {
return ret; return ret;
} }
IdArray Range(int64_t low, int64_t high, uint8_t nbits, DLContext ctx) { IdArray Range(int64_t low, int64_t high, uint8_t nbits, DGLContext ctx) {
IdArray ret; IdArray ret;
ATEN_XPU_SWITCH_CUDA(ctx.device_type, XPU, "Range", { ATEN_XPU_SWITCH_CUDA(ctx.device_type, XPU, "Range", {
if (nbits == 32) { if (nbits == 32) {
...@@ -43,7 +43,7 @@ IdArray Range(int64_t low, int64_t high, uint8_t nbits, DLContext ctx) { ...@@ -43,7 +43,7 @@ IdArray Range(int64_t low, int64_t high, uint8_t nbits, DLContext ctx) {
return ret; return ret;
} }
IdArray Full(int64_t val, int64_t length, uint8_t nbits, DLContext ctx) { IdArray Full(int64_t val, int64_t length, uint8_t nbits, DGLContext ctx) {
IdArray ret; IdArray ret;
ATEN_XPU_SWITCH_CUDA(ctx.device_type, XPU, "Full", { ATEN_XPU_SWITCH_CUDA(ctx.device_type, XPU, "Full", {
if (nbits == 32) { if (nbits == 32) {
...@@ -58,7 +58,7 @@ IdArray Full(int64_t val, int64_t length, uint8_t nbits, DLContext ctx) { ...@@ -58,7 +58,7 @@ IdArray Full(int64_t val, int64_t length, uint8_t nbits, DLContext ctx) {
} }
template <typename DType> template <typename DType>
NDArray Full(DType val, int64_t length, DLContext ctx) { NDArray Full(DType val, int64_t length, DGLContext ctx) {
NDArray ret; NDArray ret;
ATEN_XPU_SWITCH_CUDA(ctx.device_type, XPU, "Full", { ATEN_XPU_SWITCH_CUDA(ctx.device_type, XPU, "Full", {
ret = impl::Full<XPU, DType>(val, length, ctx); ret = impl::Full<XPU, DType>(val, length, ctx);
...@@ -66,10 +66,10 @@ NDArray Full(DType val, int64_t length, DLContext ctx) { ...@@ -66,10 +66,10 @@ NDArray Full(DType val, int64_t length, DLContext ctx) {
return ret; return ret;
} }
template NDArray Full<int32_t>(int32_t val, int64_t length, DLContext ctx); template NDArray Full<int32_t>(int32_t val, int64_t length, DGLContext ctx);
template NDArray Full<int64_t>(int64_t val, int64_t length, DLContext ctx); template NDArray Full<int64_t>(int64_t val, int64_t length, DGLContext ctx);
template NDArray Full<float>(float val, int64_t length, DLContext ctx); template NDArray Full<float>(float val, int64_t length, DGLContext ctx);
template NDArray Full<double>(double val, int64_t length, DLContext ctx); template NDArray Full<double>(double val, int64_t length, DGLContext ctx);
IdArray AsNumBits(IdArray arr, uint8_t bits) { IdArray AsNumBits(IdArray arr, uint8_t bits) {
CHECK(bits == 32 || bits == 64) CHECK(bits == 32 || bits == 64)
...@@ -315,7 +315,7 @@ std::pair<IdArray, IdArray> Sort(IdArray array, const int num_bits) { ...@@ -315,7 +315,7 @@ std::pair<IdArray, IdArray> Sort(IdArray array, const int num_bits) {
std::string ToDebugString(NDArray array) { std::string ToDebugString(NDArray array) {
std::ostringstream oss; std::ostringstream oss;
NDArray a = array.CopyTo(DLContext{kDLCPU, 0}); NDArray a = array.CopyTo(DGLContext{kDGLCPU, 0});
oss << "array(["; oss << "array([";
ATEN_DTYPE_SWITCH(a->dtype, DType, "array", { ATEN_DTYPE_SWITCH(a->dtype, DType, "array", {
for (int64_t i = 0; i < std::min<int64_t>(a.NumElements(), 10L); ++i) { for (int64_t i = 0; i < std::min<int64_t>(a.NumElements(), 10L); ++i) {
...@@ -1132,10 +1132,10 @@ DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLExistSharedMemArray") ...@@ -1132,10 +1132,10 @@ DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLExistSharedMemArray")
DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLArrayCastToSigned") DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLArrayCastToSigned")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
NDArray array = args[0]; NDArray array = args[0];
CHECK_EQ(array->dtype.code, kDLUInt); CHECK_EQ(array->dtype.code, kDGLUInt);
std::vector<int64_t> shape(array->shape, array->shape + array->ndim); std::vector<int64_t> shape(array->shape, array->shape + array->ndim);
DLDataType dtype = array->dtype; DGLDataType dtype = array->dtype;
dtype.code = kDLInt; dtype.code = kDGLInt;
*rv = array.CreateView(shape, dtype, 0); *rv = array.CreateView(shape, dtype, 0);
}); });
......
...@@ -16,176 +16,176 @@ namespace dgl { ...@@ -16,176 +16,176 @@ namespace dgl {
namespace aten { namespace aten {
namespace impl { namespace impl {
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
IdArray Full(IdType val, int64_t length, DLContext ctx); IdArray Full(IdType val, int64_t length, DGLContext ctx);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
IdArray Range(IdType low, IdType high, DLContext ctx); IdArray Range(IdType low, IdType high, DGLContext ctx);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
IdArray AsNumBits(IdArray arr, uint8_t bits); IdArray AsNumBits(IdArray arr, uint8_t bits);
template <DLDeviceType XPU, typename IdType, typename Op> template <DGLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdArray lhs, IdArray rhs); IdArray BinaryElewise(IdArray lhs, IdArray rhs);
template <DLDeviceType XPU, typename IdType, typename Op> template <DGLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdArray lhs, IdType rhs); IdArray BinaryElewise(IdArray lhs, IdType rhs);
template <DLDeviceType XPU, typename IdType, typename Op> template <DGLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdType lhs, IdArray rhs); IdArray BinaryElewise(IdType lhs, IdArray rhs);
template <DLDeviceType XPU, typename IdType, typename Op> template <DGLDeviceType XPU, typename IdType, typename Op>
IdArray UnaryElewise(IdArray array); IdArray UnaryElewise(IdArray array);
template <DLDeviceType XPU, typename DType, typename IdType> template <DGLDeviceType XPU, typename DType, typename IdType>
NDArray IndexSelect(NDArray array, IdArray index); NDArray IndexSelect(NDArray array, IdArray index);
template <DLDeviceType XPU, typename DType> template <DGLDeviceType XPU, typename DType>
DType IndexSelect(NDArray array, int64_t index); DType IndexSelect(NDArray array, int64_t index);
template <DLDeviceType XPU, typename DType> template <DGLDeviceType XPU, typename DType>
IdArray NonZero(BoolArray bool_arr); IdArray NonZero(BoolArray bool_arr);
template <DLDeviceType XPU, typename DType> template <DGLDeviceType XPU, typename DType>
std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits); std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits);
template <DLDeviceType XPU, typename DType, typename IdType> template <DGLDeviceType XPU, typename DType, typename IdType>
NDArray Scatter(NDArray array, IdArray indices); NDArray Scatter(NDArray array, IdArray indices);
template <DLDeviceType XPU, typename DType, typename IdType> template <DGLDeviceType XPU, typename DType, typename IdType>
void Scatter_(IdArray index, NDArray value, NDArray out); void Scatter_(IdArray index, NDArray value, NDArray out);
template <DLDeviceType XPU, typename DType, typename IdType> template <DGLDeviceType XPU, typename DType, typename IdType>
NDArray Repeat(NDArray array, IdArray repeats); NDArray Repeat(NDArray array, IdArray repeats);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
IdArray Relabel_(const std::vector<IdArray>& arrays); IdArray Relabel_(const std::vector<IdArray>& arrays);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
NDArray Concat(const std::vector<IdArray>& arrays); NDArray Concat(const std::vector<IdArray>& arrays);
template <DLDeviceType XPU, typename DType> template <DGLDeviceType XPU, typename DType>
std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, DType pad_value); std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, DType pad_value);
template <DLDeviceType XPU, typename DType, typename IdType> template <DGLDeviceType XPU, typename DType, typename IdType>
std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths); std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
IdArray CumSum(IdArray array, bool prepend_zero); IdArray CumSum(IdArray array, bool prepend_zero);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
IdArray NonZero(NDArray array); IdArray NonZero(NDArray array);
// sparse arrays // sparse arrays
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col); bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
runtime::NDArray CSRIsNonZero(CSRMatrix csr, runtime::NDArray row, runtime::NDArray col); runtime::NDArray CSRIsNonZero(CSRMatrix csr, runtime::NDArray row, runtime::NDArray col);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
bool CSRHasDuplicate(CSRMatrix csr); bool CSRHasDuplicate(CSRMatrix csr);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row); int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
runtime::NDArray CSRGetRowNNZ(CSRMatrix csr, runtime::NDArray row); runtime::NDArray CSRGetRowNNZ(CSRMatrix csr, runtime::NDArray row);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
runtime::NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row); runtime::NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
runtime::NDArray CSRGetRowData(CSRMatrix csr, int64_t row); runtime::NDArray CSRGetRowData(CSRMatrix csr, int64_t row);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
bool CSRIsSorted(CSRMatrix csr); bool CSRIsSorted(CSRMatrix csr);
template <DLDeviceType XPU, typename IdType, typename DType> template <DGLDeviceType XPU, typename IdType, typename DType>
runtime::NDArray CSRGetData( runtime::NDArray CSRGetData(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols, bool return_eids, CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols, bool return_eids,
runtime::NDArray weights, DType filler); runtime::NDArray weights, DType filler);
template <DLDeviceType XPU, typename IdType, typename DType> template <DGLDeviceType XPU, typename IdType, typename DType>
runtime::NDArray CSRGetData( runtime::NDArray CSRGetData(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols, CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols,
runtime::NDArray weights, DType filler) { runtime::NDArray weights, DType filler) {
return CSRGetData<XPU, IdType, DType>(csr, rows, cols, false, weights, filler); return CSRGetData<XPU, IdType, DType>(csr, rows, cols, false, weights, filler);
} }
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) { NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
return CSRGetData<XPU, IdType, IdType>(csr, rows, cols, true, NullArray(rows->dtype), -1); return CSRGetData<XPU, IdType, IdType>(csr, rows, cols, true, NullArray(rows->dtype), -1);
} }
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
std::vector<runtime::NDArray> CSRGetDataAndIndices( std::vector<runtime::NDArray> CSRGetDataAndIndices(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols); CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRTranspose(CSRMatrix csr); CSRMatrix CSRTranspose(CSRMatrix csr);
// Convert CSR to COO // Convert CSR to COO
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRToCOO(CSRMatrix csr); COOMatrix CSRToCOO(CSRMatrix csr);
// Convert CSR to COO using data array as order // Convert CSR to COO using data array as order
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRToCOODataAsOrder(CSRMatrix csr); COOMatrix CSRToCOODataAsOrder(CSRMatrix csr);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end); CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceRows(CSRMatrix csr, runtime::NDArray rows); CSRMatrix CSRSliceRows(CSRMatrix csr, runtime::NDArray rows);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols); CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
void CSRSort_(CSRMatrix* csr); void CSRSort_(CSRMatrix* csr);
template <DLDeviceType XPU, typename IdType, typename TagType> template <DGLDeviceType XPU, typename IdType, typename TagType>
std::pair<CSRMatrix, NDArray> CSRSortByTag( std::pair<CSRMatrix, NDArray> CSRSortByTag(
const CSRMatrix &csr, IdArray tag_array, int64_t num_tags); const CSRMatrix &csr, IdArray tag_array, int64_t num_tags);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRReorder(CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids); CSRMatrix CSRReorder(CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids); COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries); CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries);
// FloatType is the type of probability data. // FloatType is the type of probability data.
template <DLDeviceType XPU, typename IdType, typename FloatType> template <DGLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix CSRRowWiseSampling( COOMatrix CSRRowWiseSampling(
CSRMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob, bool replace); CSRMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob, bool replace);
// FloatType is the type of probability data. // FloatType is the type of probability data.
template <DLDeviceType XPU, typename IdType, typename FloatType> template <DGLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix CSRRowWisePerEtypeSampling( COOMatrix CSRRowWisePerEtypeSampling(
CSRMatrix mat, IdArray rows, IdArray etypes, CSRMatrix mat, IdArray rows, IdArray etypes,
const std::vector<int64_t>& num_samples, FloatArray prob, bool replace, const std::vector<int64_t>& num_samples, FloatArray prob, bool replace,
bool etype_sorted); bool etype_sorted);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRRowWiseSamplingUniform( COOMatrix CSRRowWiseSamplingUniform(
CSRMatrix mat, IdArray rows, int64_t num_samples, bool replace); CSRMatrix mat, IdArray rows, int64_t num_samples, bool replace);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRRowWisePerEtypeSamplingUniform( COOMatrix CSRRowWisePerEtypeSamplingUniform(
CSRMatrix mat, IdArray rows, IdArray etypes, const std::vector<int64_t>& num_samples, CSRMatrix mat, IdArray rows, IdArray etypes, const std::vector<int64_t>& num_samples,
bool replace, bool etype_sorted); bool replace, bool etype_sorted);
// FloatType is the type of weight data. // FloatType is the type of weight data.
template <DLDeviceType XPU, typename IdType, typename DType> template <DGLDeviceType XPU, typename IdType, typename DType>
COOMatrix CSRRowWiseTopk( COOMatrix CSRRowWiseTopk(
CSRMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending); CSRMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending);
template <DLDeviceType XPU, typename IdType, typename FloatType> template <DGLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix CSRRowWiseSamplingBiased( COOMatrix CSRRowWiseSamplingBiased(
CSRMatrix mat, CSRMatrix mat,
IdArray rows, IdArray rows,
...@@ -194,7 +194,7 @@ COOMatrix CSRRowWiseSamplingBiased( ...@@ -194,7 +194,7 @@ COOMatrix CSRRowWiseSamplingBiased(
FloatArray bias, FloatArray bias,
bool replace); bool replace);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling( std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
const CSRMatrix& csr, const CSRMatrix& csr,
int64_t num_samples, int64_t num_samples,
...@@ -204,117 +204,117 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling( ...@@ -204,117 +204,117 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
double redundancy); double redundancy);
// Union CSRMatrixes // Union CSRMatrixes
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs); CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(CSRMatrix csr); std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(CSRMatrix csr);
/////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////////
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
bool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col); bool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
runtime::NDArray COOIsNonZero(COOMatrix coo, runtime::NDArray row, runtime::NDArray col); runtime::NDArray COOIsNonZero(COOMatrix coo, runtime::NDArray row, runtime::NDArray col);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
bool COOHasDuplicate(COOMatrix coo); bool COOHasDuplicate(COOMatrix coo);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
int64_t COOGetRowNNZ(COOMatrix coo, int64_t row); int64_t COOGetRowNNZ(COOMatrix coo, int64_t row);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
runtime::NDArray COOGetRowNNZ(COOMatrix coo, runtime::NDArray row); runtime::NDArray COOGetRowNNZ(COOMatrix coo, runtime::NDArray row);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
std::pair<runtime::NDArray, runtime::NDArray> std::pair<runtime::NDArray, runtime::NDArray>
COOGetRowDataAndIndices(COOMatrix coo, int64_t row); COOGetRowDataAndIndices(COOMatrix coo, int64_t row);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
std::vector<runtime::NDArray> COOGetDataAndIndices( std::vector<runtime::NDArray> COOGetDataAndIndices(
COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols); COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
runtime::NDArray COOGetData(COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols); runtime::NDArray COOGetData(COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
COOMatrix COOTranspose(COOMatrix coo); COOMatrix COOTranspose(COOMatrix coo);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
CSRMatrix COOToCSR(COOMatrix coo); CSRMatrix COOToCSR(COOMatrix coo);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end); COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
COOMatrix COOSliceRows(COOMatrix coo, runtime::NDArray rows); COOMatrix COOSliceRows(COOMatrix coo, runtime::NDArray rows);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
COOMatrix COOSliceMatrix(COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols); COOMatrix COOSliceMatrix(COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo); std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos); COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
void COOSort_(COOMatrix* mat, bool sort_column); void COOSort_(COOMatrix* mat, bool sort_column);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
std::pair<bool, bool> COOIsSorted(COOMatrix coo); std::pair<bool, bool> COOIsSorted(COOMatrix coo);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
COOMatrix COORemove(COOMatrix coo, IdArray entries); COOMatrix COORemove(COOMatrix coo, IdArray entries);
// FloatType is the type of probability data. // FloatType is the type of probability data.
template <DLDeviceType XPU, typename IdType, typename FloatType> template <DGLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix COORowWiseSampling( COOMatrix COORowWiseSampling(
COOMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob, bool replace); COOMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob, bool replace);
// FloatType is the type of probability data. // FloatType is the type of probability data.
template <DLDeviceType XPU, typename IdType, typename FloatType> template <DGLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix COORowWisePerEtypeSampling( COOMatrix COORowWisePerEtypeSampling(
COOMatrix mat, IdArray rows, IdArray etypes, COOMatrix mat, IdArray rows, IdArray etypes,
const std::vector<int64_t>& num_samples, FloatArray prob, bool replace, bool etype_sorted); const std::vector<int64_t>& num_samples, FloatArray prob, bool replace, bool etype_sorted);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
COOMatrix COORowWiseSamplingUniform( COOMatrix COORowWiseSamplingUniform(
COOMatrix mat, IdArray rows, int64_t num_samples, bool replace); COOMatrix mat, IdArray rows, int64_t num_samples, bool replace);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
COOMatrix COORowWisePerEtypeSamplingUniform( COOMatrix COORowWisePerEtypeSamplingUniform(
COOMatrix mat, IdArray rows, IdArray etypes, const std::vector<int64_t>& num_samples, COOMatrix mat, IdArray rows, IdArray etypes, const std::vector<int64_t>& num_samples,
bool replace, bool etype_sorted); bool replace, bool etype_sorted);
// FloatType is the type of weight data. // FloatType is the type of weight data.
template <DLDeviceType XPU, typename IdType, typename FloatType> template <DGLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix COORowWiseTopk( COOMatrix COORowWiseTopk(
COOMatrix mat, IdArray rows, int64_t k, FloatArray weight, bool ascending); COOMatrix mat, IdArray rows, int64_t k, FloatArray weight, bool ascending);
///////////////////////// Graph Traverse routines ////////////////////////// ///////////////////////// Graph Traverse routines //////////////////////////
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source); Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source); Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr); Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source); Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr, Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr,
IdArray source, IdArray source,
const bool has_reverse_edge, const bool has_reverse_edge,
const bool has_nontree_edge, const bool has_nontree_edge,
const bool return_labels); const bool return_labels);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
COOMatrix COOLineGraph(const COOMatrix &coo, bool backtracking); COOMatrix COOLineGraph(const COOMatrix &coo, bool backtracking);
} // namespace impl } // namespace impl
......
...@@ -16,7 +16,7 @@ namespace aten { ...@@ -16,7 +16,7 @@ namespace aten {
// Check whether the given arguments have the same context. // Check whether the given arguments have the same context.
inline void CheckCtx( inline void CheckCtx(
const DLContext& ctx, const DGLContext& ctx,
const std::vector<NDArray>& arrays, const std::vector<NDArray>& arrays,
const std::vector<std::string>& names) { const std::vector<std::string>& names) {
for (size_t i = 0; i < arrays.size(); ++i) { for (size_t i = 0; i < arrays.size(); ++i) {
......
...@@ -10,7 +10,7 @@ using runtime::NDArray; ...@@ -10,7 +10,7 @@ using runtime::NDArray;
namespace aten { namespace aten {
namespace impl { namespace impl {
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
IdArray CumSum(IdArray array, bool prepend_zero) { IdArray CumSum(IdArray array, bool prepend_zero) {
const int64_t len = array.NumElements(); const int64_t len = array.NumElements();
if (len == 0) if (len == 0)
...@@ -34,8 +34,8 @@ IdArray CumSum(IdArray array, bool prepend_zero) { ...@@ -34,8 +34,8 @@ IdArray CumSum(IdArray array, bool prepend_zero) {
} }
} }
template IdArray CumSum<kDLCPU, int32_t>(IdArray, bool); template IdArray CumSum<kDGLCPU, int32_t>(IdArray, bool);
template IdArray CumSum<kDLCPU, int64_t>(IdArray, bool); template IdArray CumSum<kDGLCPU, int64_t>(IdArray, bool);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
...@@ -10,7 +10,7 @@ using runtime::NDArray; ...@@ -10,7 +10,7 @@ using runtime::NDArray;
namespace aten { namespace aten {
namespace impl { namespace impl {
template<DLDeviceType XPU, typename DType, typename IdType> template<DGLDeviceType XPU, typename DType, typename IdType>
NDArray IndexSelect(NDArray array, IdArray index) { NDArray IndexSelect(NDArray array, IdArray index) {
CHECK_EQ(array->shape[0], array.NumElements()) << "Only support tensor" CHECK_EQ(array->shape[0], array.NumElements()) << "Only support tensor"
<< " whose first dimension equals number of elements, e.g. (5,), (5, 1)"; << " whose first dimension equals number of elements, e.g. (5,), (5, 1)";
...@@ -28,25 +28,25 @@ NDArray IndexSelect(NDArray array, IdArray index) { ...@@ -28,25 +28,25 @@ NDArray IndexSelect(NDArray array, IdArray index) {
return ret; return ret;
} }
template NDArray IndexSelect<kDLCPU, int32_t, int32_t>(NDArray, IdArray); template NDArray IndexSelect<kDGLCPU, int32_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLCPU, int32_t, int64_t>(NDArray, IdArray); template NDArray IndexSelect<kDGLCPU, int32_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLCPU, int64_t, int32_t>(NDArray, IdArray); template NDArray IndexSelect<kDGLCPU, int64_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLCPU, int64_t, int64_t>(NDArray, IdArray); template NDArray IndexSelect<kDGLCPU, int64_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLCPU, float, int32_t>(NDArray, IdArray); template NDArray IndexSelect<kDGLCPU, float, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLCPU, float, int64_t>(NDArray, IdArray); template NDArray IndexSelect<kDGLCPU, float, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLCPU, double, int32_t>(NDArray, IdArray); template NDArray IndexSelect<kDGLCPU, double, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLCPU, double, int64_t>(NDArray, IdArray); template NDArray IndexSelect<kDGLCPU, double, int64_t>(NDArray, IdArray);
template <DLDeviceType XPU, typename DType> template <DGLDeviceType XPU, typename DType>
DType IndexSelect(NDArray array, int64_t index) { DType IndexSelect(NDArray array, int64_t index) {
const DType* data = static_cast<DType*>(array->data); const DType* data = static_cast<DType*>(array->data);
return data[index]; return data[index];
} }
template int32_t IndexSelect<kDLCPU, int32_t>(NDArray array, int64_t index); template int32_t IndexSelect<kDGLCPU, int32_t>(NDArray array, int64_t index);
template int64_t IndexSelect<kDLCPU, int64_t>(NDArray array, int64_t index); template int64_t IndexSelect<kDGLCPU, int64_t>(NDArray array, int64_t index);
template float IndexSelect<kDLCPU, float>(NDArray array, int64_t index); template float IndexSelect<kDGLCPU, float>(NDArray array, int64_t index);
template double IndexSelect<kDLCPU, double>(NDArray array, int64_t index); template double IndexSelect<kDGLCPU, double>(NDArray array, int64_t index);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
...@@ -10,7 +10,7 @@ using runtime::NDArray; ...@@ -10,7 +10,7 @@ using runtime::NDArray;
namespace aten { namespace aten {
namespace impl { namespace impl {
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
IdArray NonZero(IdArray array) { IdArray NonZero(IdArray array) {
std::vector<int64_t> ret; std::vector<int64_t> ret;
const IdType* data = array.Ptr<IdType>(); const IdType* data = array.Ptr<IdType>();
...@@ -20,8 +20,8 @@ IdArray NonZero(IdArray array) { ...@@ -20,8 +20,8 @@ IdArray NonZero(IdArray array) {
return NDArray::FromVector(ret, array->ctx); return NDArray::FromVector(ret, array->ctx);
} }
template IdArray NonZero<kDLCPU, int32_t>(IdArray); template IdArray NonZero<kDGLCPU, int32_t>(IdArray);
template IdArray NonZero<kDLCPU, int64_t>(IdArray); template IdArray NonZero<kDGLCPU, int64_t>(IdArray);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
...@@ -17,7 +17,7 @@ namespace impl { ...@@ -17,7 +17,7 @@ namespace impl {
///////////////////////////// AsNumBits ///////////////////////////// ///////////////////////////// AsNumBits /////////////////////////////
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
IdArray AsNumBits(IdArray arr, uint8_t bits) { IdArray AsNumBits(IdArray arr, uint8_t bits) {
CHECK(bits == 32 || bits == 64) << "invalid number of integer bits"; CHECK(bits == 32 || bits == 64) << "invalid number of integer bits";
if (sizeof(IdType) * 8 == bits) { if (sizeof(IdType) * 8 == bits) {
...@@ -40,12 +40,12 @@ IdArray AsNumBits(IdArray arr, uint8_t bits) { ...@@ -40,12 +40,12 @@ IdArray AsNumBits(IdArray arr, uint8_t bits) {
return ret; return ret;
} }
template IdArray AsNumBits<kDLCPU, int32_t>(IdArray arr, uint8_t bits); template IdArray AsNumBits<kDGLCPU, int32_t>(IdArray arr, uint8_t bits);
template IdArray AsNumBits<kDLCPU, int64_t>(IdArray arr, uint8_t bits); template IdArray AsNumBits<kDGLCPU, int64_t>(IdArray arr, uint8_t bits);
///////////////////////////// BinaryElewise ///////////////////////////// ///////////////////////////// BinaryElewise /////////////////////////////
template <DLDeviceType XPU, typename IdType, typename Op> template <DGLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdArray lhs, IdArray rhs) { IdArray BinaryElewise(IdArray lhs, IdArray rhs) {
IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits); IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);
const IdType* lhs_data = static_cast<IdType*>(lhs->data); const IdType* lhs_data = static_cast<IdType*>(lhs->data);
...@@ -59,30 +59,30 @@ IdArray BinaryElewise(IdArray lhs, IdArray rhs) { ...@@ -59,30 +59,30 @@ IdArray BinaryElewise(IdArray lhs, IdArray rhs) {
return ret; return ret;
} }
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Add>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Add>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Sub>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Sub>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mul>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Mul>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Div>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Div>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mod>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Mod>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GT>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int32_t, arith::GT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LT>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int32_t, arith::LT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GE>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int32_t, arith::GE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LE>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int32_t, arith::LE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::EQ>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int32_t, arith::EQ>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::NE>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int32_t, arith::NE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Add>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Add>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Sub>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Sub>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mul>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Mul>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Div>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Div>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mod>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Mod>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GT>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int64_t, arith::GT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LT>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int64_t, arith::LT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GE>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int64_t, arith::GE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LE>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int64_t, arith::LE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::EQ>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int64_t, arith::EQ>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::NE>(IdArray lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int64_t, arith::NE>(IdArray lhs, IdArray rhs);
template <DLDeviceType XPU, typename IdType, typename Op> template <DGLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdArray lhs, IdType rhs) { IdArray BinaryElewise(IdArray lhs, IdType rhs) {
IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits); IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);
const IdType* lhs_data = static_cast<IdType*>(lhs->data); const IdType* lhs_data = static_cast<IdType*>(lhs->data);
...@@ -95,30 +95,30 @@ IdArray BinaryElewise(IdArray lhs, IdType rhs) { ...@@ -95,30 +95,30 @@ IdArray BinaryElewise(IdArray lhs, IdType rhs) {
return ret; return ret;
} }
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Add>(IdArray lhs, int32_t rhs); template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Add>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Sub>(IdArray lhs, int32_t rhs); template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Sub>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mul>(IdArray lhs, int32_t rhs); template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Mul>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Div>(IdArray lhs, int32_t rhs); template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Div>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mod>(IdArray lhs, int32_t rhs); template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Mod>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GT>(IdArray lhs, int32_t rhs); template IdArray BinaryElewise<kDGLCPU, int32_t, arith::GT>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LT>(IdArray lhs, int32_t rhs); template IdArray BinaryElewise<kDGLCPU, int32_t, arith::LT>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GE>(IdArray lhs, int32_t rhs); template IdArray BinaryElewise<kDGLCPU, int32_t, arith::GE>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LE>(IdArray lhs, int32_t rhs); template IdArray BinaryElewise<kDGLCPU, int32_t, arith::LE>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::EQ>(IdArray lhs, int32_t rhs); template IdArray BinaryElewise<kDGLCPU, int32_t, arith::EQ>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::NE>(IdArray lhs, int32_t rhs); template IdArray BinaryElewise<kDGLCPU, int32_t, arith::NE>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Add>(IdArray lhs, int64_t rhs); template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Add>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Sub>(IdArray lhs, int64_t rhs); template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Sub>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mul>(IdArray lhs, int64_t rhs); template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Mul>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Div>(IdArray lhs, int64_t rhs); template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Div>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mod>(IdArray lhs, int64_t rhs); template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Mod>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GT>(IdArray lhs, int64_t rhs); template IdArray BinaryElewise<kDGLCPU, int64_t, arith::GT>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LT>(IdArray lhs, int64_t rhs); template IdArray BinaryElewise<kDGLCPU, int64_t, arith::LT>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GE>(IdArray lhs, int64_t rhs); template IdArray BinaryElewise<kDGLCPU, int64_t, arith::GE>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LE>(IdArray lhs, int64_t rhs); template IdArray BinaryElewise<kDGLCPU, int64_t, arith::LE>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::EQ>(IdArray lhs, int64_t rhs); template IdArray BinaryElewise<kDGLCPU, int64_t, arith::EQ>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::NE>(IdArray lhs, int64_t rhs); template IdArray BinaryElewise<kDGLCPU, int64_t, arith::NE>(IdArray lhs, int64_t rhs);
template <DLDeviceType XPU, typename IdType, typename Op> template <DGLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdType lhs, IdArray rhs) { IdArray BinaryElewise(IdType lhs, IdArray rhs) {
IdArray ret = NewIdArray(rhs->shape[0], rhs->ctx, rhs->dtype.bits); IdArray ret = NewIdArray(rhs->shape[0], rhs->ctx, rhs->dtype.bits);
const IdType* rhs_data = static_cast<IdType*>(rhs->data); const IdType* rhs_data = static_cast<IdType*>(rhs->data);
...@@ -131,30 +131,30 @@ IdArray BinaryElewise(IdType lhs, IdArray rhs) { ...@@ -131,30 +131,30 @@ IdArray BinaryElewise(IdType lhs, IdArray rhs) {
return ret; return ret;
} }
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Add>(int32_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Add>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Sub>(int32_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Sub>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mul>(int32_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Mul>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Div>(int32_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Div>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mod>(int32_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Mod>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GT>(int32_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int32_t, arith::GT>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LT>(int32_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int32_t, arith::LT>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GE>(int32_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int32_t, arith::GE>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LE>(int32_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int32_t, arith::LE>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::EQ>(int32_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int32_t, arith::EQ>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::NE>(int32_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int32_t, arith::NE>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Add>(int64_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Add>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Sub>(int64_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Sub>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mul>(int64_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Mul>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Div>(int64_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Div>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mod>(int64_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Mod>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GT>(int64_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int64_t, arith::GT>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LT>(int64_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int64_t, arith::LT>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GE>(int64_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int64_t, arith::GE>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LE>(int64_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int64_t, arith::LE>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::EQ>(int64_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int64_t, arith::EQ>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::NE>(int64_t lhs, IdArray rhs); template IdArray BinaryElewise<kDGLCPU, int64_t, arith::NE>(int64_t lhs, IdArray rhs);
template <DLDeviceType XPU, typename IdType, typename Op> template <DGLDeviceType XPU, typename IdType, typename Op>
IdArray UnaryElewise(IdArray lhs) { IdArray UnaryElewise(IdArray lhs) {
IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits); IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);
const IdType* lhs_data = static_cast<IdType*>(lhs->data); const IdType* lhs_data = static_cast<IdType*>(lhs->data);
...@@ -167,28 +167,28 @@ IdArray UnaryElewise(IdArray lhs) { ...@@ -167,28 +167,28 @@ IdArray UnaryElewise(IdArray lhs) {
return ret; return ret;
} }
template IdArray UnaryElewise<kDLCPU, int32_t, arith::Neg>(IdArray lhs); template IdArray UnaryElewise<kDGLCPU, int32_t, arith::Neg>(IdArray lhs);
template IdArray UnaryElewise<kDLCPU, int64_t, arith::Neg>(IdArray lhs); template IdArray UnaryElewise<kDGLCPU, int64_t, arith::Neg>(IdArray lhs);
///////////////////////////// Full ///////////////////////////// ///////////////////////////// Full /////////////////////////////
template <DLDeviceType XPU, typename DType> template <DGLDeviceType XPU, typename DType>
NDArray Full(DType val, int64_t length, DLContext ctx) { NDArray Full(DType val, int64_t length, DGLContext ctx) {
NDArray ret = NDArray::Empty({length}, DLDataTypeTraits<DType>::dtype, ctx); NDArray ret = NDArray::Empty({length}, DGLDataTypeTraits<DType>::dtype, ctx);
DType* ret_data = static_cast<DType*>(ret->data); DType* ret_data = static_cast<DType*>(ret->data);
std::fill(ret_data, ret_data + length, val); std::fill(ret_data, ret_data + length, val);
return ret; return ret;
} }
template NDArray Full<kDLCPU, int32_t>(int32_t val, int64_t length, DLContext ctx); template NDArray Full<kDGLCPU, int32_t>(int32_t val, int64_t length, DGLContext ctx);
template NDArray Full<kDLCPU, int64_t>(int64_t val, int64_t length, DLContext ctx); template NDArray Full<kDGLCPU, int64_t>(int64_t val, int64_t length, DGLContext ctx);
template NDArray Full<kDLCPU, float>(float val, int64_t length, DLContext ctx); template NDArray Full<kDGLCPU, float>(float val, int64_t length, DGLContext ctx);
template NDArray Full<kDLCPU, double>(double val, int64_t length, DLContext ctx); template NDArray Full<kDGLCPU, double>(double val, int64_t length, DGLContext ctx);
///////////////////////////// Range ///////////////////////////// ///////////////////////////// Range /////////////////////////////
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
IdArray Range(IdType low, IdType high, DLContext ctx) { IdArray Range(IdType low, IdType high, DGLContext ctx) {
CHECK(high >= low) << "high must be bigger than low"; CHECK(high >= low) << "high must be bigger than low";
IdArray ret = NewIdArray(high - low, ctx, sizeof(IdType) * 8); IdArray ret = NewIdArray(high - low, ctx, sizeof(IdType) * 8);
IdType* ret_data = static_cast<IdType*>(ret->data); IdType* ret_data = static_cast<IdType*>(ret->data);
...@@ -196,12 +196,12 @@ IdArray Range(IdType low, IdType high, DLContext ctx) { ...@@ -196,12 +196,12 @@ IdArray Range(IdType low, IdType high, DLContext ctx) {
return ret; return ret;
} }
template IdArray Range<kDLCPU, int32_t>(int32_t, int32_t, DLContext); template IdArray Range<kDGLCPU, int32_t>(int32_t, int32_t, DGLContext);
template IdArray Range<kDLCPU, int64_t>(int64_t, int64_t, DLContext); template IdArray Range<kDGLCPU, int64_t>(int64_t, int64_t, DGLContext);
///////////////////////////// Relabel_ ///////////////////////////// ///////////////////////////// Relabel_ /////////////////////////////
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
IdArray Relabel_(const std::vector<IdArray>& arrays) { IdArray Relabel_(const std::vector<IdArray>& arrays) {
// build map & relabel // build map & relabel
IdType newid = 0; IdType newid = 0;
...@@ -216,7 +216,7 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) { ...@@ -216,7 +216,7 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
} }
} }
// map array // map array
IdArray maparr = NewIdArray(newid, DLContext{kDLCPU, 0}, sizeof(IdType) * 8); IdArray maparr = NewIdArray(newid, DGLContext{kDGLCPU, 0}, sizeof(IdType) * 8);
IdType* maparr_data = static_cast<IdType*>(maparr->data); IdType* maparr_data = static_cast<IdType*>(maparr->data);
for (const auto& kv : oldv2newv) { for (const auto& kv : oldv2newv) {
maparr_data[kv.second] = kv.first; maparr_data[kv.second] = kv.first;
...@@ -224,8 +224,8 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) { ...@@ -224,8 +224,8 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
return maparr; return maparr;
} }
template IdArray Relabel_<kDLCPU, int32_t>(const std::vector<IdArray>& arrays); template IdArray Relabel_<kDGLCPU, int32_t>(const std::vector<IdArray>& arrays);
template IdArray Relabel_<kDLCPU, int64_t>(const std::vector<IdArray>& arrays); template IdArray Relabel_<kDGLCPU, int64_t>(const std::vector<IdArray>& arrays);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
...@@ -14,7 +14,7 @@ using runtime::parallel_for; ...@@ -14,7 +14,7 @@ using runtime::parallel_for;
namespace aten { namespace aten {
namespace impl { namespace impl {
template<DLDeviceType XPU, typename DType, typename IdType> template<DGLDeviceType XPU, typename DType, typename IdType>
std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths) { std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths) {
const int64_t rows = lengths->shape[0]; const int64_t rows = lengths->shape[0];
const int64_t cols = (array->ndim == 1 ? array->shape[0] : array->shape[1]); const int64_t cols = (array->ndim == 1 ? array->shape[0] : array->shape[1]);
...@@ -41,16 +41,16 @@ std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths) { ...@@ -41,16 +41,16 @@ std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths) {
return std::make_pair(concat, offsets); return std::make_pair(concat, offsets);
} }
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, int32_t, int32_t>(NDArray, IdArray); template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, int32_t, int32_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, int64_t, int32_t>(NDArray, IdArray); template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, int64_t, int32_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, float, int32_t>(NDArray, IdArray); template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, float, int32_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, double, int32_t>(NDArray, IdArray); template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, double, int32_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, int32_t, int64_t>(NDArray, IdArray); template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, int32_t, int64_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, int64_t, int64_t>(NDArray, IdArray); template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, int64_t, int64_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, float, int64_t>(NDArray, IdArray); template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, float, int64_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, double, int64_t>(NDArray, IdArray); template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, double, int64_t>(NDArray, IdArray);
template<DLDeviceType XPU, typename DType> template<DGLDeviceType XPU, typename DType>
std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, DType pad_value) { std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, DType pad_value) {
CHECK_NDIM(array, 2, "array"); CHECK_NDIM(array, 2, "array");
const DType *array_data = static_cast<DType *>(array->data); const DType *array_data = static_cast<DType *>(array->data);
...@@ -75,10 +75,10 @@ std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, DType pad_value) { ...@@ -75,10 +75,10 @@ std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, DType pad_value) {
return std::make_tuple(ret.first, length, ret.second); return std::make_tuple(ret.first, length, ret.second);
} }
template std::tuple<NDArray, IdArray, IdArray> Pack<kDLCPU, int32_t>(NDArray, int32_t); template std::tuple<NDArray, IdArray, IdArray> Pack<kDGLCPU, int32_t>(NDArray, int32_t);
template std::tuple<NDArray, IdArray, IdArray> Pack<kDLCPU, int64_t>(NDArray, int64_t); template std::tuple<NDArray, IdArray, IdArray> Pack<kDGLCPU, int64_t>(NDArray, int64_t);
template std::tuple<NDArray, IdArray, IdArray> Pack<kDLCPU, float>(NDArray, float); template std::tuple<NDArray, IdArray, IdArray> Pack<kDGLCPU, float>(NDArray, float);
template std::tuple<NDArray, IdArray, IdArray> Pack<kDLCPU, double>(NDArray, double); template std::tuple<NDArray, IdArray, IdArray> Pack<kDGLCPU, double>(NDArray, double);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
...@@ -11,7 +11,7 @@ using runtime::NDArray; ...@@ -11,7 +11,7 @@ using runtime::NDArray;
namespace aten { namespace aten {
namespace impl { namespace impl {
template <DLDeviceType XPU, typename DType, typename IdType> template <DGLDeviceType XPU, typename DType, typename IdType>
NDArray Repeat(NDArray array, IdArray repeats) { NDArray Repeat(NDArray array, IdArray repeats) {
CHECK(array->shape[0] == repeats->shape[0]) << "shape of array and repeats mismatch"; CHECK(array->shape[0] == repeats->shape[0]) << "shape of array and repeats mismatch";
...@@ -34,14 +34,14 @@ NDArray Repeat(NDArray array, IdArray repeats) { ...@@ -34,14 +34,14 @@ NDArray Repeat(NDArray array, IdArray repeats) {
return result; return result;
} }
template NDArray Repeat<kDLCPU, int32_t, int32_t>(NDArray, IdArray); template NDArray Repeat<kDGLCPU, int32_t, int32_t>(NDArray, IdArray);
template NDArray Repeat<kDLCPU, int64_t, int32_t>(NDArray, IdArray); template NDArray Repeat<kDGLCPU, int64_t, int32_t>(NDArray, IdArray);
template NDArray Repeat<kDLCPU, float, int32_t>(NDArray, IdArray); template NDArray Repeat<kDGLCPU, float, int32_t>(NDArray, IdArray);
template NDArray Repeat<kDLCPU, double, int32_t>(NDArray, IdArray); template NDArray Repeat<kDGLCPU, double, int32_t>(NDArray, IdArray);
template NDArray Repeat<kDLCPU, int32_t, int64_t>(NDArray, IdArray); template NDArray Repeat<kDGLCPU, int32_t, int64_t>(NDArray, IdArray);
template NDArray Repeat<kDLCPU, int64_t, int64_t>(NDArray, IdArray); template NDArray Repeat<kDGLCPU, int64_t, int64_t>(NDArray, IdArray);
template NDArray Repeat<kDLCPU, float, int64_t>(NDArray, IdArray); template NDArray Repeat<kDGLCPU, float, int64_t>(NDArray, IdArray);
template NDArray Repeat<kDLCPU, double, int64_t>(NDArray, IdArray); template NDArray Repeat<kDGLCPU, double, int64_t>(NDArray, IdArray);
}; // namespace impl }; // namespace impl
}; // namespace aten }; // namespace aten
......
...@@ -11,7 +11,7 @@ using runtime::NDArray; ...@@ -11,7 +11,7 @@ using runtime::NDArray;
namespace aten { namespace aten {
namespace impl { namespace impl {
template <DLDeviceType XPU, typename DType, typename IdType> template <DGLDeviceType XPU, typename DType, typename IdType>
NDArray Scatter(NDArray array, IdArray indices) { NDArray Scatter(NDArray array, IdArray indices) {
NDArray result = NDArray::Empty({indices->shape[0]}, array->dtype, array->ctx); NDArray result = NDArray::Empty({indices->shape[0]}, array->dtype, array->ctx);
...@@ -25,16 +25,16 @@ NDArray Scatter(NDArray array, IdArray indices) { ...@@ -25,16 +25,16 @@ NDArray Scatter(NDArray array, IdArray indices) {
return result; return result;
} }
template NDArray Scatter<kDLCPU, int32_t, int32_t>(NDArray, IdArray); template NDArray Scatter<kDGLCPU, int32_t, int32_t>(NDArray, IdArray);
template NDArray Scatter<kDLCPU, int64_t, int32_t>(NDArray, IdArray); template NDArray Scatter<kDGLCPU, int64_t, int32_t>(NDArray, IdArray);
template NDArray Scatter<kDLCPU, float, int32_t>(NDArray, IdArray); template NDArray Scatter<kDGLCPU, float, int32_t>(NDArray, IdArray);
template NDArray Scatter<kDLCPU, double, int32_t>(NDArray, IdArray); template NDArray Scatter<kDGLCPU, double, int32_t>(NDArray, IdArray);
template NDArray Scatter<kDLCPU, int32_t, int64_t>(NDArray, IdArray); template NDArray Scatter<kDGLCPU, int32_t, int64_t>(NDArray, IdArray);
template NDArray Scatter<kDLCPU, int64_t, int64_t>(NDArray, IdArray); template NDArray Scatter<kDGLCPU, int64_t, int64_t>(NDArray, IdArray);
template NDArray Scatter<kDLCPU, float, int64_t>(NDArray, IdArray); template NDArray Scatter<kDGLCPU, float, int64_t>(NDArray, IdArray);
template NDArray Scatter<kDLCPU, double, int64_t>(NDArray, IdArray); template NDArray Scatter<kDGLCPU, double, int64_t>(NDArray, IdArray);
template <DLDeviceType XPU, typename DType, typename IdType> template <DGLDeviceType XPU, typename DType, typename IdType>
void Scatter_(IdArray index, NDArray value, NDArray out) { void Scatter_(IdArray index, NDArray value, NDArray out) {
const int64_t len = index->shape[0]; const int64_t len = index->shape[0];
const IdType* idx = index.Ptr<IdType>(); const IdType* idx = index.Ptr<IdType>();
...@@ -47,14 +47,14 @@ void Scatter_(IdArray index, NDArray value, NDArray out) { ...@@ -47,14 +47,14 @@ void Scatter_(IdArray index, NDArray value, NDArray out) {
}); });
} }
template void Scatter_<kDLCPU, int32_t, int32_t>(IdArray, NDArray, NDArray); template void Scatter_<kDGLCPU, int32_t, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLCPU, int64_t, int32_t>(IdArray, NDArray, NDArray); template void Scatter_<kDGLCPU, int64_t, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLCPU, float, int32_t>(IdArray, NDArray, NDArray); template void Scatter_<kDGLCPU, float, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLCPU, double, int32_t>(IdArray, NDArray, NDArray); template void Scatter_<kDGLCPU, double, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLCPU, int32_t, int64_t>(IdArray, NDArray, NDArray); template void Scatter_<kDGLCPU, int32_t, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLCPU, int64_t, int64_t>(IdArray, NDArray, NDArray); template void Scatter_<kDGLCPU, int64_t, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLCPU, float, int64_t>(IdArray, NDArray, NDArray); template void Scatter_<kDGLCPU, float, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLCPU, double, int64_t>(IdArray, NDArray, NDArray); template void Scatter_<kDGLCPU, double, int64_t>(IdArray, NDArray, NDArray);
}; // namespace impl }; // namespace impl
}; // namespace aten }; // namespace aten
......
...@@ -160,7 +160,7 @@ using runtime::NDArray; ...@@ -160,7 +160,7 @@ using runtime::NDArray;
namespace aten { namespace aten {
namespace impl { namespace impl {
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> Sort(IdArray array, int /* num_bits */) { std::pair<IdArray, IdArray> Sort(IdArray array, int /* num_bits */) {
const int64_t nitem = array->shape[0]; const int64_t nitem = array->shape[0];
IdArray val = array.Clone(); IdArray val = array.Clone();
...@@ -181,8 +181,8 @@ std::pair<IdArray, IdArray> Sort(IdArray array, int /* num_bits */) { ...@@ -181,8 +181,8 @@ std::pair<IdArray, IdArray> Sort(IdArray array, int /* num_bits */) {
return std::make_pair(val, idx); return std::make_pair(val, idx);
} }
template std::pair<IdArray, IdArray> Sort<kDLCPU, int32_t>(IdArray, int num_bits); template std::pair<IdArray, IdArray> Sort<kDGLCPU, int32_t>(IdArray, int num_bits);
template std::pair<IdArray, IdArray> Sort<kDLCPU, int64_t>(IdArray, int num_bits); template std::pair<IdArray, IdArray> Sort<kDGLCPU, int64_t>(IdArray, int num_bits);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
...@@ -88,7 +88,7 @@ class IdHashMap { ...@@ -88,7 +88,7 @@ class IdHashMap {
// Return all the old ids collected so far, ordered by new id. // Return all the old ids collected so far, ordered by new id.
IdArray Values() const { IdArray Values() const {
IdArray values = NewIdArray(oldv2newv_.size(), DLContext{kDLCPU, 0}, sizeof(IdType) * 8); IdArray values = NewIdArray(oldv2newv_.size(), DGLContext{kDGLCPU, 0}, sizeof(IdType) * 8);
IdType* values_data = static_cast<IdType*>(values->data); IdType* values_data = static_cast<IdType*>(values->data);
for (auto pair : oldv2newv_) for (auto pair : oldv2newv_)
values_data[pair.second] = pair.first; values_data[pair.second] = pair.first;
......
...@@ -13,7 +13,7 @@ namespace aten { ...@@ -13,7 +13,7 @@ namespace aten {
namespace impl { namespace impl {
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo) { std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo) {
const int64_t nnz = coo.row->shape[0]; const int64_t nnz = coo.row->shape[0];
const IdType* coo_row_data = static_cast<IdType*>(coo.row->data); const IdType* coo_row_data = static_cast<IdType*>(coo.row->data);
...@@ -44,8 +44,8 @@ std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo) { ...@@ -44,8 +44,8 @@ std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo) {
return std::make_pair(coo_result, NDArray::FromVector(count)); return std::make_pair(coo_result, NDArray::FromVector(count));
} }
template std::pair<COOMatrix, IdArray> COOCoalesce<kDLCPU, int32_t>(COOMatrix); template std::pair<COOMatrix, IdArray> COOCoalesce<kDGLCPU, int32_t>(COOMatrix);
template std::pair<COOMatrix, IdArray> COOCoalesce<kDLCPU, int64_t>(COOMatrix); template std::pair<COOMatrix, IdArray> COOCoalesce<kDGLCPU, int64_t>(COOMatrix);
}; // namespace impl }; // namespace impl
}; // namespace aten }; // namespace aten
......
...@@ -14,7 +14,7 @@ namespace dgl { ...@@ -14,7 +14,7 @@ namespace dgl {
namespace aten { namespace aten {
namespace impl { namespace impl {
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
COOMatrix COOLineGraph(const COOMatrix &coo, bool backtracking) { COOMatrix COOLineGraph(const COOMatrix &coo, bool backtracking) {
const int64_t nnz = coo.row->shape[0]; const int64_t nnz = coo.row->shape[0];
IdType* coo_row = coo.row.Ptr<IdType>(); IdType* coo_row = coo.row.Ptr<IdType>();
...@@ -50,8 +50,8 @@ COOMatrix COOLineGraph(const COOMatrix &coo, bool backtracking) { ...@@ -50,8 +50,8 @@ COOMatrix COOLineGraph(const COOMatrix &coo, bool backtracking) {
} }
template COOMatrix COOLineGraph<kDLCPU, int32_t>(const COOMatrix &coo, bool backtracking); template COOMatrix COOLineGraph<kDGLCPU, int32_t>(const COOMatrix &coo, bool backtracking);
template COOMatrix COOLineGraph<kDLCPU, int64_t>(const COOMatrix &coo, bool backtracking); template COOMatrix COOLineGraph<kDGLCPU, int64_t>(const COOMatrix &coo, bool backtracking);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment