"gallery/vscode:/vscode.git/clone" did not exist on "11e49de410ec84ec669293a91dfaa13a53c9bc47"
Unverified commit cded5b80, authored by Xin Yao, committed by GitHub
Browse files

[Feature] Bump DLPack to v0.7 and decouple DLPack from the core library (#4454)

* rename `DLContext` to `DGLContext`

* rename `kDLGPU` to `kDLCUDA`

* replace DLTensor with DGLArray

* fix linting

* Unify DGLType and DLDataType to DGLDataType

* Fix FFI

* rename DLDeviceType to DGLDeviceType

* decouple dlpack from the core library

* fix bug

* fix lint

* fix merge

* fix build

* address comments

* rename dl_converter to dlpack_convert

* remove redundant comments
parent f1689ad0
......@@ -3,16 +3,16 @@ from libcpp.vector cimport vector
from libcpp cimport bool
from cpython.version cimport PY_MAJOR_VERSION
from cpython cimport pycapsule
from libc.stdint cimport int64_t, uint64_t, uint8_t, uint16_t
from libc.stdint cimport int32_t, int64_t, uint64_t, uint8_t, uint16_t
import ctypes
cdef enum DGLTypeCode:
cdef enum DGLObjectTypeCode:
kInt = 0
kUInt = 1
kFloat = 2
kHandle = 3
kNull = 4
kDGLType = 5
kDGLDataType = 5
kDGLContext = 6
kArrayHandle = 7
kObjectHandle = 8
......@@ -24,26 +24,26 @@ cdef enum DGLTypeCode:
kExtBegin = 15
cdef extern from "dgl/runtime/c_runtime_api.h":
ctypedef struct DLDataType:
ctypedef struct DGLDataType:
uint8_t code
uint8_t bits
uint16_t lanes
ctypedef struct DLContext:
int device_type
int device_id
ctypedef struct DGLContext:
int32_t device_type
int32_t device_id
ctypedef struct DLTensor:
ctypedef struct DGLArray:
void* data
DLContext ctx
int ndim
DLDataType dtype
DGLContext ctx
int32_t ndim
DGLDataType dtype
int64_t* shape
int64_t* strides
uint64_t byte_offset
ctypedef struct DLManagedTensor:
DLTensor dl_tensor
DGLArray dl_tensor
void* manager_ctx
void (*deleter)(DLManagedTensor* self)
......@@ -52,13 +52,11 @@ cdef extern from "dgl/runtime/c_runtime_api.h":
double v_float64
void* v_handle
const char* v_str
DLDataType v_type
DLContext v_ctx
DGLDataType v_type
DGLContext v_ctx
ctypedef int64_t dgl_index_t
ctypedef DLTensor* DLTensorHandle
ctypedef DLTensor DGLArray
ctypedef DGLArray* CDGLArrayHandle
ctypedef DGLArray* DGLArrayHandle
ctypedef void* DGLStreamHandle
ctypedef void* DGLRetValueHandle
ctypedef void* DGLFunctionHandle
......@@ -94,9 +92,9 @@ cdef extern from "dgl/runtime/c_runtime_api.h":
int DGLCbArgToReturn(DGLValue* value, int code)
int DGLArrayAlloc(dgl_index_t* shape,
dgl_index_t ndim,
DLDataType dtype,
DLContext ctx,
DLTensorHandle* out)
DGLDataType dtype,
DGLContext ctx,
DGLArrayHandle* out)
int DGLArrayAllocSharedMem(const char *mem_name,
const dgl_index_t *shape,
int ndim,
......@@ -104,16 +102,10 @@ cdef extern from "dgl/runtime/c_runtime_api.h":
int dtype_bits,
int dtype_lanes,
bool is_create,
CDGLArrayHandle* out)
int DGLArrayFree(DLTensorHandle handle)
int DGLArrayCopyFromTo(DLTensorHandle src,
DLTensorHandle to)
int DGLArrayFromDLPack(DLManagedTensor* arr_from,
DLTensorHandle* out)
int DGLArrayToDLPack(DLTensorHandle arr_from,
DLManagedTensor** out,
int alignment)
void DGLDLManagedTensorCallDeleter(DLManagedTensor* dltensor)
DGLArrayHandle* out)
int DGLArrayFree(DGLArrayHandle handle)
int DGLArrayCopyFromTo(DGLArrayHandle src,
DGLArrayHandle to)
cdef extern from "dgl/runtime/c_object_api.h":
int DGLObjectFree(ObjectHandle handle)
......@@ -127,6 +119,14 @@ cdef extern from "dgl/runtime/c_object_api.h":
int* out_type_code,
int* out_success)
# DLPack conversion entry points, declared from their own header
# ("dgl/runtime/dlpack_convert.h"); the same three signatures appear earlier in this
# diff under the c_runtime_api.h extern block using the DLTensorHandle spelling.
cdef extern from "dgl/runtime/dlpack_convert.h":
# Presumably wraps a DLPack managed tensor as a DGL array and writes the handle
# to `out`; returns an int status code — TODO confirm against the C header.
int DGLArrayFromDLPack(DLManagedTensor* arr_from,
DGLArrayHandle* out)
# Presumably the reverse direction: writes a DLManagedTensor* for `arr_from`
# into `out`; the meaning of `alignment` is not visible here — TODO confirm.
int DGLArrayToDLPack(DGLArrayHandle arr_from,
DLManagedTensor** out,
int alignment)
# By its name, invokes the deleter of the given managed tensor — TODO confirm.
void DGLDLManagedTensorCallDeleter(DLManagedTensor* dltensor)
cdef inline py_str(const char* x):
if PY_MAJOR_VERSION < 3:
return x
......
......@@ -4,7 +4,9 @@ from cpython cimport Py_INCREF, Py_DECREF
from numbers import Number, Integral
from ..base import string_types
from ..object_generic import convert_to_object, ObjectGeneric
from ..runtime_ctypes import DGLType, DGLContext, DGLByteArray
from ..runtime_ctypes import DGLDataType as CTypesDGLDataType, \
DGLContext as CTypesDGLContext, \
DGLByteArray
cdef void dgl_callback_finalize(void* fhandle):
......@@ -107,13 +109,13 @@ cdef inline int make_arg(object arg,
elif isinstance(arg, Number):
value[0].v_float64 = arg
tcode[0] = kFloat
elif isinstance(arg, DGLType):
elif isinstance(arg, CTypesDGLDataType):
tstr = c_str(str(arg))
value[0].v_str = tstr
tcode[0] = kStr
temp_args.append(tstr)
elif isinstance(arg, DGLContext):
value[0].v_ctx = (<DLContext*>(
elif isinstance(arg, CTypesDGLContext):
value[0].v_ctx = (<DGLContext*>(
<unsigned long long>ctypes.addressof(arg)))[0]
tcode[0] = kDGLContext
elif isinstance(arg, bytearray):
......@@ -183,7 +185,7 @@ cdef inline object make_ret(DGLValue value, int tcode):
elif tcode == kHandle:
return ctypes_handle(value.v_handle)
elif tcode == kDGLContext:
return DGLContext(value.v_ctx.device_type, value.v_ctx.device_id)
return CTypesDGLContext(value.v_ctx.device_type, value.v_ctx.device_id)
# (minjie): class module are not used in DGL.
#elif tcode == kModuleHandle:
# return _CLASS_MODULE(ctypes_handle(value.v_handle))
......
from ..runtime_ctypes import DGLArrayHandle
from ..runtime_ctypes import DGLArrayHandle as PyDGLArrayHandle
cdef const char* _c_str_dltensor = "dltensor"
cdef const char* _c_str_used_dltensor = "used_dltensor"
......@@ -13,7 +13,7 @@ cdef void _c_dlpack_deleter(object pycaps):
def _from_dlpack(object dltensor):
cdef DLManagedTensor* ptr
cdef DLTensorHandle chandle
cdef DGLArrayHandle chandle
if pycapsule.PyCapsule_IsValid(dltensor, _c_str_dltensor):
ptr = <DLManagedTensor*>pycapsule.PyCapsule_GetPointer(dltensor, _c_str_dltensor)
CALL(DGLArrayFromDLPack(ptr, &chandle))
......@@ -25,7 +25,7 @@ def _from_dlpack(object dltensor):
cdef class NDArrayBase:
cdef DLTensor* chandle
cdef DGLArray* chandle
cdef int c_is_view
cdef inline _set_handle(self, handle):
......@@ -34,7 +34,7 @@ cdef class NDArrayBase:
self.chandle = NULL
else:
ptr = ctypes.cast(handle, ctypes.c_void_p).value
self.chandle = <DLTensor*>(ptr)
self.chandle = <DGLArray*>(ptr)
property _dgl_handle:
def __get__(self):
......@@ -46,7 +46,7 @@ cdef class NDArrayBase:
return None
else:
return ctypes.cast(
<unsigned long long>self.chandle, DGLArrayHandle)
<unsigned long long>self.chandle, PyDGLArrayHandle)
def __set__(self, value):
self._set_handle(value)
......@@ -82,7 +82,7 @@ cdef class NDArrayBase:
cdef c_make_array(void* chandle, is_view):
ret = _CLASS_NDARRAY(None, is_view)
(<NDArrayBase>ret).chandle = <DLTensor*>chandle
(<NDArrayBase>ret).chandle = <DGLArray*>chandle
return ret
......
......@@ -6,7 +6,7 @@ import sys
import ctypes
import numpy as np
from .base import _LIB, check_call, c_array, string_types, _FFI_MODE, c_str
from .runtime_ctypes import DGLType, DGLContext, DGLArray, DGLArrayHandle
from .runtime_ctypes import DGLDataType, DGLContext, DGLArray, DGLArrayHandle
from .runtime_ctypes import TypeCode, dgl_shape_index_t
......@@ -72,7 +72,7 @@ def numpyasarray(np_data):
arr.data = data.ctypes.data_as(ctypes.c_void_p)
arr.shape = shape
arr.strides = None
arr.dtype = DGLType(np.dtype(data.dtype).name)
arr.dtype = DGLDataType(np.dtype(data.dtype).name)
arr.ndim = data.ndim
# CPU device
arr.ctx = context(1, 0)
......@@ -101,7 +101,7 @@ def empty(shape, dtype="float32", ctx=context(1, 0)):
shape = c_array(dgl_shape_index_t, shape)
ndim = ctypes.c_int(len(shape))
handle = DGLArrayHandle()
dtype = DGLType(dtype)
dtype = DGLDataType(dtype)
check_call(_LIB.DGLArrayAlloc(
shape, ndim,
ctypes.c_int(dtype.type_code),
......@@ -139,7 +139,7 @@ def empty_shared_mem(name, is_create, shape, dtype="float32"):
shape = c_array(dgl_shape_index_t, shape)
ndim = ctypes.c_int(len(shape))
handle = DGLArrayHandle()
dtype = DGLType(dtype)
dtype = DGLDataType(dtype)
check_call(_LIB.DGLArrayAllocSharedMem(
name, shape, ndim,
ctypes.c_int(dtype.type_code),
......@@ -254,7 +254,7 @@ class NDArrayBase(_NDArrayBase):
except:
raise TypeError('array must be an array_like data,' +
'type %s is not supported' % str(type(source_array)))
t = DGLType(self.dtype)
t = DGLDataType(self.dtype)
shape, dtype = self.shape, self.dtype
if t.lanes > 1:
shape = shape + (t.lanes,)
......@@ -286,7 +286,7 @@ class NDArrayBase(_NDArrayBase):
np_arr : numpy.ndarray
The corresponding numpy array.
"""
t = DGLType(self.dtype)
t = DGLDataType(self.dtype)
shape, dtype = self.shape, self.dtype
if t.lanes > 1:
shape = shape + (t.lanes,)
......
......@@ -17,7 +17,7 @@ class TypeCode(object):
FLOAT = 2
HANDLE = 3
NULL = 4
DGL_TYPE = 5
DGL_DATA_TYPE = 5
DGL_CONTEXT = 6
ARRAY_HANDLE = 7
OBJECT_HANDLE = 8
......@@ -33,7 +33,7 @@ class DGLByteArray(ctypes.Structure):
_fields_ = [("data", ctypes.POINTER(ctypes.c_byte)),
("size", ctypes.c_size_t)]
class DGLType(ctypes.Structure):
class DGLDataType(ctypes.Structure):
"""DGL datatype structure"""
_fields_ = [("type_code", ctypes.c_uint8),
("bits", ctypes.c_uint8),
......@@ -50,7 +50,7 @@ class DGLType(ctypes.Structure):
if type_str in cls._cache:
return cls._cache[type_str]
inst = super(DGLType, cls).__new__(DGLType)
inst = super(DGLDataType, cls).__new__(DGLDataType)
if isinstance(type_str, np.dtype):
type_str = str(type_str)
......@@ -84,7 +84,7 @@ class DGLType(ctypes.Structure):
pass
# Render the data type as "<name><bits>", where <name> is looked up in CODE2STR by
# type_code, and append "x<lanes>" only when lanes differs from 1.
# NOTE(review): diff rendering — the two `x = ...` lines are the same statement
# before (DGLType) and after (DGLDataType) the class rename.
def __repr__(self):
x = "%s%d" % (DGLType.CODE2STR[self.type_code], self.bits)
x = "%s%d" % (DGLDataType.CODE2STR[self.type_code], self.bits)
if self.lanes != 1:
x += "x%d" % self.lanes
return x
......@@ -250,7 +250,7 @@ class DGLArray(ctypes.Structure):
_fields_ = [("data", ctypes.c_void_p),
("ctx", DGLContext),
("ndim", ctypes.c_int),
("dtype", DGLType),
("dtype", DGLDataType),
("shape", ctypes.POINTER(dgl_shape_index_t)),
("strides", ctypes.POINTER(dgl_shape_index_t)),
("byte_offset", ctypes.c_uint64)]
......
......@@ -13,7 +13,7 @@ import numpy as _np
from ._ffi.object import register_object, ObjectBase
from ._ffi.function import _init_api
from ._ffi.ndarray import DGLContext, DGLType, NDArrayBase
from ._ffi.ndarray import DGLContext, DGLDataType, NDArrayBase
from ._ffi.ndarray import context, empty, empty_shared_mem, from_dlpack, numpyasarray
from ._ffi.ndarray import _set_class_ndarray
from . import backend as F
......
......@@ -19,8 +19,8 @@ using namespace dgl::runtime;
namespace dgl {
namespace aten {
// Create a one-dimensional id array of `length` elements on `ctx`, allocated through
// IdArray::Empty with an integer dtype of `nbits` bits and a lane count of 1.
// NOTE(review): diff rendering — the first signature/return pair is the pre-rename
// form (DLContext / DLDataType / kDLInt); the second pair is its DGL-prefixed replacement.
IdArray NewIdArray(int64_t length, DLContext ctx, uint8_t nbits) {
return IdArray::Empty({length}, DLDataType{kDLInt, nbits, 1}, ctx);
IdArray NewIdArray(int64_t length, DGLContext ctx, uint8_t nbits) {
return IdArray::Empty({length}, DGLDataType{kDGLInt, nbits, 1}, ctx);
}
IdArray Clone(IdArray arr) {
......@@ -29,7 +29,7 @@ IdArray Clone(IdArray arr) {
return ret;
}
IdArray Range(int64_t low, int64_t high, uint8_t nbits, DLContext ctx) {
IdArray Range(int64_t low, int64_t high, uint8_t nbits, DGLContext ctx) {
IdArray ret;
ATEN_XPU_SWITCH_CUDA(ctx.device_type, XPU, "Range", {
if (nbits == 32) {
......@@ -43,7 +43,7 @@ IdArray Range(int64_t low, int64_t high, uint8_t nbits, DLContext ctx) {
return ret;
}
IdArray Full(int64_t val, int64_t length, uint8_t nbits, DLContext ctx) {
IdArray Full(int64_t val, int64_t length, uint8_t nbits, DGLContext ctx) {
IdArray ret;
ATEN_XPU_SWITCH_CUDA(ctx.device_type, XPU, "Full", {
if (nbits == 32) {
......@@ -58,7 +58,7 @@ IdArray Full(int64_t val, int64_t length, uint8_t nbits, DLContext ctx) {
}
template <typename DType>
NDArray Full(DType val, int64_t length, DLContext ctx) {
NDArray Full(DType val, int64_t length, DGLContext ctx) {
NDArray ret;
ATEN_XPU_SWITCH_CUDA(ctx.device_type, XPU, "Full", {
ret = impl::Full<XPU, DType>(val, length, ctx);
......@@ -66,10 +66,10 @@ NDArray Full(DType val, int64_t length, DLContext ctx) {
return ret;
}
template NDArray Full<int32_t>(int32_t val, int64_t length, DLContext ctx);
template NDArray Full<int64_t>(int64_t val, int64_t length, DLContext ctx);
template NDArray Full<float>(float val, int64_t length, DLContext ctx);
template NDArray Full<double>(double val, int64_t length, DLContext ctx);
template NDArray Full<int32_t>(int32_t val, int64_t length, DGLContext ctx);
template NDArray Full<int64_t>(int64_t val, int64_t length, DGLContext ctx);
template NDArray Full<float>(float val, int64_t length, DGLContext ctx);
template NDArray Full<double>(double val, int64_t length, DGLContext ctx);
IdArray AsNumBits(IdArray arr, uint8_t bits) {
CHECK(bits == 32 || bits == 64)
......@@ -315,7 +315,7 @@ std::pair<IdArray, IdArray> Sort(IdArray array, const int num_bits) {
std::string ToDebugString(NDArray array) {
std::ostringstream oss;
NDArray a = array.CopyTo(DLContext{kDLCPU, 0});
NDArray a = array.CopyTo(DGLContext{kDGLCPU, 0});
oss << "array([";
ATEN_DTYPE_SWITCH(a->dtype, DType, "array", {
for (int64_t i = 0; i < std::min<int64_t>(a.NumElements(), 10L); ++i) {
......@@ -1132,10 +1132,10 @@ DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLExistSharedMemArray")
// Registers "ndarray._CAPI_DGLArrayCastToSigned": takes one array argument whose dtype
// code must be unsigned-int and returns the result of CreateView with the same shape
// and a dtype that differs only in its code (switched to signed int).
// NOTE(review): diff rendering — the kDL*/DLDataType lines are the pre-rename form of
// the kDGL*/DGLDataType lines that follow them.
DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLArrayCastToSigned")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
NDArray array = args[0];
// Only unsigned-integer arrays are accepted.
CHECK_EQ(array->dtype.code, kDLUInt);
CHECK_EQ(array->dtype.code, kDGLUInt);
// Copy the `ndim` shape entries so the view is created with the same shape.
std::vector<int64_t> shape(array->shape, array->shape + array->ndim);
// Start from the existing dtype so bits and lanes are kept; only the code changes.
DLDataType dtype = array->dtype;
dtype.code = kDLInt;
DGLDataType dtype = array->dtype;
dtype.code = kDGLInt;
// Third argument is 0 — presumably a byte offset into the buffer; TODO confirm.
*rv = array.CreateView(shape, dtype, 0);
});
......
......@@ -16,176 +16,176 @@ namespace dgl {
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType>
IdArray Full(IdType val, int64_t length, DLContext ctx);
template <DGLDeviceType XPU, typename IdType>
IdArray Full(IdType val, int64_t length, DGLContext ctx);
template <DLDeviceType XPU, typename IdType>
IdArray Range(IdType low, IdType high, DLContext ctx);
template <DGLDeviceType XPU, typename IdType>
IdArray Range(IdType low, IdType high, DGLContext ctx);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
IdArray AsNumBits(IdArray arr, uint8_t bits);
template <DLDeviceType XPU, typename IdType, typename Op>
template <DGLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdArray lhs, IdArray rhs);
template <DLDeviceType XPU, typename IdType, typename Op>
template <DGLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdArray lhs, IdType rhs);
template <DLDeviceType XPU, typename IdType, typename Op>
template <DGLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdType lhs, IdArray rhs);
template <DLDeviceType XPU, typename IdType, typename Op>
template <DGLDeviceType XPU, typename IdType, typename Op>
IdArray UnaryElewise(IdArray array);
template <DLDeviceType XPU, typename DType, typename IdType>
template <DGLDeviceType XPU, typename DType, typename IdType>
NDArray IndexSelect(NDArray array, IdArray index);
template <DLDeviceType XPU, typename DType>
template <DGLDeviceType XPU, typename DType>
DType IndexSelect(NDArray array, int64_t index);
template <DLDeviceType XPU, typename DType>
template <DGLDeviceType XPU, typename DType>
IdArray NonZero(BoolArray bool_arr);
template <DLDeviceType XPU, typename DType>
template <DGLDeviceType XPU, typename DType>
std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits);
template <DLDeviceType XPU, typename DType, typename IdType>
template <DGLDeviceType XPU, typename DType, typename IdType>
NDArray Scatter(NDArray array, IdArray indices);
template <DLDeviceType XPU, typename DType, typename IdType>
template <DGLDeviceType XPU, typename DType, typename IdType>
void Scatter_(IdArray index, NDArray value, NDArray out);
template <DLDeviceType XPU, typename DType, typename IdType>
template <DGLDeviceType XPU, typename DType, typename IdType>
NDArray Repeat(NDArray array, IdArray repeats);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
IdArray Relabel_(const std::vector<IdArray>& arrays);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
NDArray Concat(const std::vector<IdArray>& arrays);
template <DLDeviceType XPU, typename DType>
template <DGLDeviceType XPU, typename DType>
std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, DType pad_value);
template <DLDeviceType XPU, typename DType, typename IdType>
template <DGLDeviceType XPU, typename DType, typename IdType>
std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
IdArray CumSum(IdArray array, bool prepend_zero);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
IdArray NonZero(NDArray array);
// sparse arrays
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
runtime::NDArray CSRIsNonZero(CSRMatrix csr, runtime::NDArray row, runtime::NDArray col);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
bool CSRHasDuplicate(CSRMatrix csr);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
runtime::NDArray CSRGetRowNNZ(CSRMatrix csr, runtime::NDArray row);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
runtime::NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
runtime::NDArray CSRGetRowData(CSRMatrix csr, int64_t row);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
bool CSRIsSorted(CSRMatrix csr);
template <DLDeviceType XPU, typename IdType, typename DType>
template <DGLDeviceType XPU, typename IdType, typename DType>
runtime::NDArray CSRGetData(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols, bool return_eids,
runtime::NDArray weights, DType filler);
// Convenience overload of CSRGetData without the boolean flag: forwards to the
// six-argument CSRGetData declared just above with `return_eids` fixed to false,
// passing `weights` and `filler` through unchanged.
// NOTE(review): diff rendering — the DLDeviceType template line is the pre-rename
// form of the DGLDeviceType line that follows it.
template <DLDeviceType XPU, typename IdType, typename DType>
template <DGLDeviceType XPU, typename IdType, typename DType>
runtime::NDArray CSRGetData(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols,
runtime::NDArray weights, DType filler) {
return CSRGetData<XPU, IdType, DType>(csr, rows, cols, false, weights, filler);
}
// Index-only overload of CSRGetData: forwards to the six-argument form with
// DType = IdType, `return_eids` set to true, a NullArray of `rows`' dtype in the
// weights position, and -1 as the filler value.
// NOTE(review): diff rendering — the DLDeviceType template line is the pre-rename
// form of the DGLDeviceType line that follows it.
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
return CSRGetData<XPU, IdType, IdType>(csr, rows, cols, true, NullArray(rows->dtype), -1);
}
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::vector<runtime::NDArray> CSRGetDataAndIndices(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRTranspose(CSRMatrix csr);
// Convert CSR to COO
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRToCOO(CSRMatrix csr);
// Convert CSR to COO using data array as order
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRToCOODataAsOrder(CSRMatrix csr);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceRows(CSRMatrix csr, runtime::NDArray rows);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
void CSRSort_(CSRMatrix* csr);
template <DLDeviceType XPU, typename IdType, typename TagType>
template <DGLDeviceType XPU, typename IdType, typename TagType>
std::pair<CSRMatrix, NDArray> CSRSortByTag(
const CSRMatrix &csr, IdArray tag_array, int64_t num_tags);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRReorder(CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries);
// FloatType is the type of probability data.
template <DLDeviceType XPU, typename IdType, typename FloatType>
template <DGLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix CSRRowWiseSampling(
CSRMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob, bool replace);
// FloatType is the type of probability data.
template <DLDeviceType XPU, typename IdType, typename FloatType>
template <DGLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix CSRRowWisePerEtypeSampling(
CSRMatrix mat, IdArray rows, IdArray etypes,
const std::vector<int64_t>& num_samples, FloatArray prob, bool replace,
bool etype_sorted);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRRowWiseSamplingUniform(
CSRMatrix mat, IdArray rows, int64_t num_samples, bool replace);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRRowWisePerEtypeSamplingUniform(
CSRMatrix mat, IdArray rows, IdArray etypes, const std::vector<int64_t>& num_samples,
bool replace, bool etype_sorted);
// FloatType is the type of weight data.
template <DLDeviceType XPU, typename IdType, typename DType>
template <DGLDeviceType XPU, typename IdType, typename DType>
COOMatrix CSRRowWiseTopk(
CSRMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending);
template <DLDeviceType XPU, typename IdType, typename FloatType>
template <DGLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix CSRRowWiseSamplingBiased(
CSRMatrix mat,
IdArray rows,
......@@ -194,7 +194,7 @@ COOMatrix CSRRowWiseSamplingBiased(
FloatArray bias,
bool replace);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
const CSRMatrix& csr,
int64_t num_samples,
......@@ -204,117 +204,117 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
double redundancy);
// Union CSRMatrixes
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(CSRMatrix csr);
///////////////////////////////////////////////////////////////////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
bool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
runtime::NDArray COOIsNonZero(COOMatrix coo, runtime::NDArray row, runtime::NDArray col);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
bool COOHasDuplicate(COOMatrix coo);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
int64_t COOGetRowNNZ(COOMatrix coo, int64_t row);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
runtime::NDArray COOGetRowNNZ(COOMatrix coo, runtime::NDArray row);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::pair<runtime::NDArray, runtime::NDArray>
COOGetRowDataAndIndices(COOMatrix coo, int64_t row);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::vector<runtime::NDArray> COOGetDataAndIndices(
COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
runtime::NDArray COOGetData(COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix COOTranspose(COOMatrix coo);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
CSRMatrix COOToCSR(COOMatrix coo);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix COOSliceRows(COOMatrix coo, runtime::NDArray rows);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix COOSliceMatrix(COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
void COOSort_(COOMatrix* mat, bool sort_column);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::pair<bool, bool> COOIsSorted(COOMatrix coo);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix COORemove(COOMatrix coo, IdArray entries);
// FloatType is the type of probability data.
template <DLDeviceType XPU, typename IdType, typename FloatType>
template <DGLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix COORowWiseSampling(
COOMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob, bool replace);
// FloatType is the type of probability data.
template <DLDeviceType XPU, typename IdType, typename FloatType>
template <DGLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix COORowWisePerEtypeSampling(
COOMatrix mat, IdArray rows, IdArray etypes,
const std::vector<int64_t>& num_samples, FloatArray prob, bool replace, bool etype_sorted);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix COORowWiseSamplingUniform(
COOMatrix mat, IdArray rows, int64_t num_samples, bool replace);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix COORowWisePerEtypeSamplingUniform(
COOMatrix mat, IdArray rows, IdArray etypes, const std::vector<int64_t>& num_samples,
bool replace, bool etype_sorted);
// FloatType is the type of weight data.
template <DLDeviceType XPU, typename IdType, typename FloatType>
template <DGLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix COORowWiseTopk(
COOMatrix mat, IdArray rows, int64_t k, FloatArray weight, bool ascending);
///////////////////////// Graph Traverse routines //////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr,
IdArray source,
const bool has_reverse_edge,
const bool has_nontree_edge,
const bool return_labels);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix COOLineGraph(const COOMatrix &coo, bool backtracking);
} // namespace impl
......
......@@ -16,7 +16,7 @@ namespace aten {
// Check whether the given arguments have the same context.
inline void CheckCtx(
const DLContext& ctx,
const DGLContext& ctx,
const std::vector<NDArray>& arrays,
const std::vector<std::string>& names) {
for (size_t i = 0; i < arrays.size(); ++i) {
......
......@@ -10,7 +10,7 @@ using runtime::NDArray;
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
IdArray CumSum(IdArray array, bool prepend_zero) {
const int64_t len = array.NumElements();
if (len == 0)
......@@ -34,8 +34,8 @@ IdArray CumSum(IdArray array, bool prepend_zero) {
}
}
template IdArray CumSum<kDLCPU, int32_t>(IdArray, bool);
template IdArray CumSum<kDLCPU, int64_t>(IdArray, bool);
template IdArray CumSum<kDGLCPU, int32_t>(IdArray, bool);
template IdArray CumSum<kDGLCPU, int64_t>(IdArray, bool);
} // namespace impl
} // namespace aten
......
......@@ -10,7 +10,7 @@ using runtime::NDArray;
namespace aten {
namespace impl {
template<DLDeviceType XPU, typename DType, typename IdType>
template<DGLDeviceType XPU, typename DType, typename IdType>
NDArray IndexSelect(NDArray array, IdArray index) {
CHECK_EQ(array->shape[0], array.NumElements()) << "Only support tensor"
<< " whose first dimension equals number of elements, e.g. (5,), (5, 1)";
......@@ -28,25 +28,25 @@ NDArray IndexSelect(NDArray array, IdArray index) {
return ret;
}
template NDArray IndexSelect<kDLCPU, int32_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLCPU, int32_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLCPU, int64_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLCPU, int64_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLCPU, float, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLCPU, float, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLCPU, double, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLCPU, double, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCPU, int32_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCPU, int32_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCPU, int64_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCPU, int64_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCPU, float, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCPU, float, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCPU, double, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCPU, double, int64_t>(NDArray, IdArray);
// Scalar IndexSelect: returns the element at position `index`, reading array->data
// as a buffer of DType. No bounds check on `index` is performed in this function.
// NOTE(review): diff rendering — the DLDeviceType template line is the pre-rename
// form of the DGLDeviceType line that follows it.
template <DLDeviceType XPU, typename DType>
template <DGLDeviceType XPU, typename DType>
DType IndexSelect(NDArray array, int64_t index) {
// Reinterpret the untyped data pointer as DType elements.
const DType* data = static_cast<DType*>(array->data);
return data[index];
}
template int32_t IndexSelect<kDLCPU, int32_t>(NDArray array, int64_t index);
template int64_t IndexSelect<kDLCPU, int64_t>(NDArray array, int64_t index);
template float IndexSelect<kDLCPU, float>(NDArray array, int64_t index);
template double IndexSelect<kDLCPU, double>(NDArray array, int64_t index);
template int32_t IndexSelect<kDGLCPU, int32_t>(NDArray array, int64_t index);
template int64_t IndexSelect<kDGLCPU, int64_t>(NDArray array, int64_t index);
template float IndexSelect<kDGLCPU, float>(NDArray array, int64_t index);
template double IndexSelect<kDGLCPU, double>(NDArray array, int64_t index);
} // namespace impl
} // namespace aten
......
......@@ -10,7 +10,7 @@ using runtime::NDArray;
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
IdArray NonZero(IdArray array) {
std::vector<int64_t> ret;
const IdType* data = array.Ptr<IdType>();
......@@ -20,8 +20,8 @@ IdArray NonZero(IdArray array) {
return NDArray::FromVector(ret, array->ctx);
}
template IdArray NonZero<kDLCPU, int32_t>(IdArray);
template IdArray NonZero<kDLCPU, int64_t>(IdArray);
template IdArray NonZero<kDGLCPU, int32_t>(IdArray);
template IdArray NonZero<kDGLCPU, int64_t>(IdArray);
} // namespace impl
} // namespace aten
......
......@@ -17,7 +17,7 @@ namespace impl {
///////////////////////////// AsNumBits /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
IdArray AsNumBits(IdArray arr, uint8_t bits) {
CHECK(bits == 32 || bits == 64) << "invalid number of integer bits";
if (sizeof(IdType) * 8 == bits) {
......@@ -40,12 +40,12 @@ IdArray AsNumBits(IdArray arr, uint8_t bits) {
return ret;
}
template IdArray AsNumBits<kDLCPU, int32_t>(IdArray arr, uint8_t bits);
template IdArray AsNumBits<kDLCPU, int64_t>(IdArray arr, uint8_t bits);
template IdArray AsNumBits<kDGLCPU, int32_t>(IdArray arr, uint8_t bits);
template IdArray AsNumBits<kDGLCPU, int64_t>(IdArray arr, uint8_t bits);
///////////////////////////// BinaryElewise /////////////////////////////
template <DLDeviceType XPU, typename IdType, typename Op>
template <DGLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdArray lhs, IdArray rhs) {
IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);
const IdType* lhs_data = static_cast<IdType*>(lhs->data);
......@@ -59,30 +59,30 @@ IdArray BinaryElewise(IdArray lhs, IdArray rhs) {
return ret;
}
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Add>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Sub>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mul>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Div>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mod>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::EQ>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::NE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Add>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Sub>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mul>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Div>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mod>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::EQ>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::NE>(IdArray lhs, IdArray rhs);
template <DLDeviceType XPU, typename IdType, typename Op>
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Add>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Sub>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Mul>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Div>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Mod>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::GT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::LT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::GE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::LE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::EQ>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::NE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Add>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Sub>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Mul>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Div>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Mod>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::GT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::LT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::GE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::LE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::EQ>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::NE>(IdArray lhs, IdArray rhs);
template <DGLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdArray lhs, IdType rhs) {
IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);
const IdType* lhs_data = static_cast<IdType*>(lhs->data);
......@@ -95,30 +95,30 @@ IdArray BinaryElewise(IdArray lhs, IdType rhs) {
return ret;
}
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Add>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Sub>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mul>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Div>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mod>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GT>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LT>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GE>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LE>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::EQ>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::NE>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Add>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Sub>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mul>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Div>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mod>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GT>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LT>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GE>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LE>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::EQ>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::NE>(IdArray lhs, int64_t rhs);
template <DLDeviceType XPU, typename IdType, typename Op>
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Add>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Sub>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Mul>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Div>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Mod>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::GT>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::LT>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::GE>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::LE>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::EQ>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::NE>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Add>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Sub>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Mul>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Div>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Mod>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::GT>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::LT>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::GE>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::LE>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::EQ>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::NE>(IdArray lhs, int64_t rhs);
template <DGLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdType lhs, IdArray rhs) {
IdArray ret = NewIdArray(rhs->shape[0], rhs->ctx, rhs->dtype.bits);
const IdType* rhs_data = static_cast<IdType*>(rhs->data);
......@@ -131,30 +131,30 @@ IdArray BinaryElewise(IdType lhs, IdArray rhs) {
return ret;
}
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Add>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Sub>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mul>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Div>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mod>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GT>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LT>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GE>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LE>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::EQ>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::NE>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Add>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Sub>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mul>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Div>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mod>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GT>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LT>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GE>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LE>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::EQ>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::NE>(int64_t lhs, IdArray rhs);
template <DLDeviceType XPU, typename IdType, typename Op>
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Add>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Sub>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Mul>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Div>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Mod>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::GT>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::LT>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::GE>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::LE>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::EQ>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::NE>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Add>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Sub>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Mul>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Div>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Mod>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::GT>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::LT>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::GE>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::LE>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::EQ>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::NE>(int64_t lhs, IdArray rhs);
template <DGLDeviceType XPU, typename IdType, typename Op>
IdArray UnaryElewise(IdArray lhs) {
IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);
const IdType* lhs_data = static_cast<IdType*>(lhs->data);
......@@ -167,28 +167,28 @@ IdArray UnaryElewise(IdArray lhs) {
return ret;
}
template IdArray UnaryElewise<kDLCPU, int32_t, arith::Neg>(IdArray lhs);
template IdArray UnaryElewise<kDLCPU, int64_t, arith::Neg>(IdArray lhs);
template IdArray UnaryElewise<kDGLCPU, int32_t, arith::Neg>(IdArray lhs);
template IdArray UnaryElewise<kDGLCPU, int64_t, arith::Neg>(IdArray lhs);
///////////////////////////// Full /////////////////////////////
// NOTE: this span comes from a rendered diff. The template header, the signature
// and the first statement appear twice: first the removed version (DLDeviceType /
// DLContext / DLDataTypeTraits), then the added version (DGLDeviceType /
// DGLContext / DGLDataTypeTraits). The remaining lines are shared by both.
//
// Full: allocates an NDArray of shape {length} on `ctx`, with the dtype taken from
// the traits class for DType, and sets every element to `val`.
//   val    - value written to every element.
//   length - number of elements in the returned array.
//   ctx    - context passed through to NDArray::Empty.
// Returns the newly created, filled array.
template <DLDeviceType XPU, typename DType>
NDArray Full(DType val, int64_t length, DLContext ctx) {
NDArray ret = NDArray::Empty({length}, DLDataTypeTraits<DType>::dtype, ctx);
template <DGLDeviceType XPU, typename DType>
NDArray Full(DType val, int64_t length, DGLContext ctx) {
NDArray ret = NDArray::Empty({length}, DGLDataTypeTraits<DType>::dtype, ctx);
// Fill through the raw data pointer; XPU is not consulted in the body.
DType* ret_data = static_cast<DType*>(ret->data);
std::fill(ret_data, ret_data + length, val);
return ret;
}
// Explicit instantiations for element types int32_t, int64_t, float and double
// (removed kDLCPU / DLContext lines first, then added kDGLCPU / DGLContext lines).
template NDArray Full<kDLCPU, int32_t>(int32_t val, int64_t length, DLContext ctx);
template NDArray Full<kDLCPU, int64_t>(int64_t val, int64_t length, DLContext ctx);
template NDArray Full<kDLCPU, float>(float val, int64_t length, DLContext ctx);
template NDArray Full<kDLCPU, double>(double val, int64_t length, DLContext ctx);
template NDArray Full<kDGLCPU, int32_t>(int32_t val, int64_t length, DGLContext ctx);
template NDArray Full<kDGLCPU, int64_t>(int64_t val, int64_t length, DGLContext ctx);
template NDArray Full<kDGLCPU, float>(float val, int64_t length, DGLContext ctx);
template NDArray Full<kDGLCPU, double>(double val, int64_t length, DGLContext ctx);
///////////////////////////// Range /////////////////////////////
template <DLDeviceType XPU, typename IdType>
IdArray Range(IdType low, IdType high, DLContext ctx) {
template <DGLDeviceType XPU, typename IdType>
IdArray Range(IdType low, IdType high, DGLContext ctx) {
CHECK(high >= low) << "high must be bigger than low";
IdArray ret = NewIdArray(high - low, ctx, sizeof(IdType) * 8);
IdType* ret_data = static_cast<IdType*>(ret->data);
......@@ -196,12 +196,12 @@ IdArray Range(IdType low, IdType high, DLContext ctx) {
return ret;
}
template IdArray Range<kDLCPU, int32_t>(int32_t, int32_t, DLContext);
template IdArray Range<kDLCPU, int64_t>(int64_t, int64_t, DLContext);
template IdArray Range<kDGLCPU, int32_t>(int32_t, int32_t, DGLContext);
template IdArray Range<kDGLCPU, int64_t>(int64_t, int64_t, DGLContext);
///////////////////////////// Relabel_ /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
IdArray Relabel_(const std::vector<IdArray>& arrays) {
// build map & relabel
IdType newid = 0;
......@@ -216,7 +216,7 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
}
}
// map array
IdArray maparr = NewIdArray(newid, DLContext{kDLCPU, 0}, sizeof(IdType) * 8);
IdArray maparr = NewIdArray(newid, DGLContext{kDGLCPU, 0}, sizeof(IdType) * 8);
IdType* maparr_data = static_cast<IdType*>(maparr->data);
for (const auto& kv : oldv2newv) {
maparr_data[kv.second] = kv.first;
......@@ -224,8 +224,8 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
return maparr;
}
template IdArray Relabel_<kDLCPU, int32_t>(const std::vector<IdArray>& arrays);
template IdArray Relabel_<kDLCPU, int64_t>(const std::vector<IdArray>& arrays);
template IdArray Relabel_<kDGLCPU, int32_t>(const std::vector<IdArray>& arrays);
template IdArray Relabel_<kDGLCPU, int64_t>(const std::vector<IdArray>& arrays);
} // namespace impl
} // namespace aten
......
......@@ -14,7 +14,7 @@ using runtime::parallel_for;
namespace aten {
namespace impl {
template<DLDeviceType XPU, typename DType, typename IdType>
template<DGLDeviceType XPU, typename DType, typename IdType>
std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths) {
const int64_t rows = lengths->shape[0];
const int64_t cols = (array->ndim == 1 ? array->shape[0] : array->shape[1]);
......@@ -41,16 +41,16 @@ std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths) {
return std::make_pair(concat, offsets);
}
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, int32_t, int32_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, int64_t, int32_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, float, int32_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, double, int32_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, int32_t, int64_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, int64_t, int64_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, float, int64_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, double, int64_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, int32_t, int32_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, int64_t, int32_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, float, int32_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, double, int32_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, int32_t, int64_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, int64_t, int64_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, float, int64_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, double, int64_t>(NDArray, IdArray);
template<DLDeviceType XPU, typename DType>
template<DGLDeviceType XPU, typename DType>
std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, DType pad_value) {
CHECK_NDIM(array, 2, "array");
const DType *array_data = static_cast<DType *>(array->data);
......@@ -75,10 +75,10 @@ std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, DType pad_value) {
return std::make_tuple(ret.first, length, ret.second);
}
template std::tuple<NDArray, IdArray, IdArray> Pack<kDLCPU, int32_t>(NDArray, int32_t);
template std::tuple<NDArray, IdArray, IdArray> Pack<kDLCPU, int64_t>(NDArray, int64_t);
template std::tuple<NDArray, IdArray, IdArray> Pack<kDLCPU, float>(NDArray, float);
template std::tuple<NDArray, IdArray, IdArray> Pack<kDLCPU, double>(NDArray, double);
template std::tuple<NDArray, IdArray, IdArray> Pack<kDGLCPU, int32_t>(NDArray, int32_t);
template std::tuple<NDArray, IdArray, IdArray> Pack<kDGLCPU, int64_t>(NDArray, int64_t);
template std::tuple<NDArray, IdArray, IdArray> Pack<kDGLCPU, float>(NDArray, float);
template std::tuple<NDArray, IdArray, IdArray> Pack<kDGLCPU, double>(NDArray, double);
} // namespace impl
} // namespace aten
......
......@@ -11,7 +11,7 @@ using runtime::NDArray;
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename DType, typename IdType>
template <DGLDeviceType XPU, typename DType, typename IdType>
NDArray Repeat(NDArray array, IdArray repeats) {
CHECK(array->shape[0] == repeats->shape[0]) << "shape of array and repeats mismatch";
......@@ -34,14 +34,14 @@ NDArray Repeat(NDArray array, IdArray repeats) {
return result;
}
template NDArray Repeat<kDLCPU, int32_t, int32_t>(NDArray, IdArray);
template NDArray Repeat<kDLCPU, int64_t, int32_t>(NDArray, IdArray);
template NDArray Repeat<kDLCPU, float, int32_t>(NDArray, IdArray);
template NDArray Repeat<kDLCPU, double, int32_t>(NDArray, IdArray);
template NDArray Repeat<kDLCPU, int32_t, int64_t>(NDArray, IdArray);
template NDArray Repeat<kDLCPU, int64_t, int64_t>(NDArray, IdArray);
template NDArray Repeat<kDLCPU, float, int64_t>(NDArray, IdArray);
template NDArray Repeat<kDLCPU, double, int64_t>(NDArray, IdArray);
template NDArray Repeat<kDGLCPU, int32_t, int32_t>(NDArray, IdArray);
template NDArray Repeat<kDGLCPU, int64_t, int32_t>(NDArray, IdArray);
template NDArray Repeat<kDGLCPU, float, int32_t>(NDArray, IdArray);
template NDArray Repeat<kDGLCPU, double, int32_t>(NDArray, IdArray);
template NDArray Repeat<kDGLCPU, int32_t, int64_t>(NDArray, IdArray);
template NDArray Repeat<kDGLCPU, int64_t, int64_t>(NDArray, IdArray);
template NDArray Repeat<kDGLCPU, float, int64_t>(NDArray, IdArray);
template NDArray Repeat<kDGLCPU, double, int64_t>(NDArray, IdArray);
}; // namespace impl
}; // namespace aten
......
......@@ -11,7 +11,7 @@ using runtime::NDArray;
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename DType, typename IdType>
template <DGLDeviceType XPU, typename DType, typename IdType>
NDArray Scatter(NDArray array, IdArray indices) {
NDArray result = NDArray::Empty({indices->shape[0]}, array->dtype, array->ctx);
......@@ -25,16 +25,16 @@ NDArray Scatter(NDArray array, IdArray indices) {
return result;
}
template NDArray Scatter<kDLCPU, int32_t, int32_t>(NDArray, IdArray);
template NDArray Scatter<kDLCPU, int64_t, int32_t>(NDArray, IdArray);
template NDArray Scatter<kDLCPU, float, int32_t>(NDArray, IdArray);
template NDArray Scatter<kDLCPU, double, int32_t>(NDArray, IdArray);
template NDArray Scatter<kDLCPU, int32_t, int64_t>(NDArray, IdArray);
template NDArray Scatter<kDLCPU, int64_t, int64_t>(NDArray, IdArray);
template NDArray Scatter<kDLCPU, float, int64_t>(NDArray, IdArray);
template NDArray Scatter<kDLCPU, double, int64_t>(NDArray, IdArray);
template NDArray Scatter<kDGLCPU, int32_t, int32_t>(NDArray, IdArray);
template NDArray Scatter<kDGLCPU, int64_t, int32_t>(NDArray, IdArray);
template NDArray Scatter<kDGLCPU, float, int32_t>(NDArray, IdArray);
template NDArray Scatter<kDGLCPU, double, int32_t>(NDArray, IdArray);
template NDArray Scatter<kDGLCPU, int32_t, int64_t>(NDArray, IdArray);
template NDArray Scatter<kDGLCPU, int64_t, int64_t>(NDArray, IdArray);
template NDArray Scatter<kDGLCPU, float, int64_t>(NDArray, IdArray);
template NDArray Scatter<kDGLCPU, double, int64_t>(NDArray, IdArray);
template <DLDeviceType XPU, typename DType, typename IdType>
template <DGLDeviceType XPU, typename DType, typename IdType>
void Scatter_(IdArray index, NDArray value, NDArray out) {
const int64_t len = index->shape[0];
const IdType* idx = index.Ptr<IdType>();
......@@ -47,14 +47,14 @@ void Scatter_(IdArray index, NDArray value, NDArray out) {
});
}
template void Scatter_<kDLCPU, int32_t, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLCPU, int64_t, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLCPU, float, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLCPU, double, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLCPU, int32_t, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLCPU, int64_t, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLCPU, float, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLCPU, double, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCPU, int32_t, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCPU, int64_t, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCPU, float, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCPU, double, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCPU, int32_t, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCPU, int64_t, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCPU, float, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCPU, double, int64_t>(IdArray, NDArray, NDArray);
}; // namespace impl
}; // namespace aten
......
......@@ -160,7 +160,7 @@ using runtime::NDArray;
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> Sort(IdArray array, int /* num_bits */) {
const int64_t nitem = array->shape[0];
IdArray val = array.Clone();
......@@ -181,8 +181,8 @@ std::pair<IdArray, IdArray> Sort(IdArray array, int /* num_bits */) {
return std::make_pair(val, idx);
}
template std::pair<IdArray, IdArray> Sort<kDLCPU, int32_t>(IdArray, int num_bits);
template std::pair<IdArray, IdArray> Sort<kDLCPU, int64_t>(IdArray, int num_bits);
template std::pair<IdArray, IdArray> Sort<kDGLCPU, int32_t>(IdArray, int num_bits);
template std::pair<IdArray, IdArray> Sort<kDGLCPU, int64_t>(IdArray, int num_bits);
} // namespace impl
} // namespace aten
......
......@@ -88,7 +88,7 @@ class IdHashMap {
// Return all the old ids collected so far, ordered by new id.
IdArray Values() const {
IdArray values = NewIdArray(oldv2newv_.size(), DLContext{kDLCPU, 0}, sizeof(IdType) * 8);
IdArray values = NewIdArray(oldv2newv_.size(), DGLContext{kDGLCPU, 0}, sizeof(IdType) * 8);
IdType* values_data = static_cast<IdType*>(values->data);
for (auto pair : oldv2newv_)
values_data[pair.second] = pair.first;
......
......@@ -13,7 +13,7 @@ namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo) {
const int64_t nnz = coo.row->shape[0];
const IdType* coo_row_data = static_cast<IdType*>(coo.row->data);
......@@ -44,8 +44,8 @@ std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo) {
return std::make_pair(coo_result, NDArray::FromVector(count));
}
template std::pair<COOMatrix, IdArray> COOCoalesce<kDLCPU, int32_t>(COOMatrix);
template std::pair<COOMatrix, IdArray> COOCoalesce<kDLCPU, int64_t>(COOMatrix);
template std::pair<COOMatrix, IdArray> COOCoalesce<kDGLCPU, int32_t>(COOMatrix);
template std::pair<COOMatrix, IdArray> COOCoalesce<kDGLCPU, int64_t>(COOMatrix);
}; // namespace impl
}; // namespace aten
......
......@@ -14,7 +14,7 @@ namespace dgl {
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix COOLineGraph(const COOMatrix &coo, bool backtracking) {
const int64_t nnz = coo.row->shape[0];
IdType* coo_row = coo.row.Ptr<IdType>();
......@@ -50,8 +50,8 @@ COOMatrix COOLineGraph(const COOMatrix &coo, bool backtracking) {
}
template COOMatrix COOLineGraph<kDLCPU, int32_t>(const COOMatrix &coo, bool backtracking);
template COOMatrix COOLineGraph<kDLCPU, int64_t>(const COOMatrix &coo, bool backtracking);
template COOMatrix COOLineGraph<kDGLCPU, int32_t>(const COOMatrix &coo, bool backtracking);
template COOMatrix COOLineGraph<kDGLCPU, int64_t>(const COOMatrix &coo, bool backtracking);
} // namespace impl
} // namespace aten
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment