Unverified Commit cded5b80 authored by Xin Yao, committed by GitHub

[Feature] Bump DLPack to v0.7 and decouple DLPack from the core library (#4454)

* rename `DLContext` to `DGLContext`

* rename `kDLGPU` to `kDLCUDA`

* replace DLTensor with DGLArray

* fix linting

* Unify DGLType and DLDataType to DGLDataType

* Fix FFI

* rename DLDeviceType to DGLDeviceType

* decouple dlpack from the core library

* fix bug

* fix lint

* fix merge

* fix build

* address comments

* rename dl_converter to dlpack_convert

* remove redundant comments
parent f1689ad0
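
The renames above are mechanical: every DLPack name used by the core library gets a DGL-owned counterpart. A minimal before/after sketch of the new spelling, following only the mappings listed in this commit message (the array length is illustrative; the snippet itself is not part of the diff):

// Before: DLPack types used directly by the core library.
//   DLContext ctx{kDLCPU, 0};
//   DLDataType dtype{kDLInt, 64, 1};   // DGLType was an alias of DLDataType
//   IdArray ret = IdArray::Empty({8}, dtype, ctx);
// After: DGL-owned equivalents, decoupled from dlpack.h.
DGLContext ctx{kDGLCPU, 0};             // DLContext -> DGLContext, kDLCPU -> kDGLCPU
DGLDataType dtype{kDGLInt, 64, 1};      // DLDataType / DGLType -> DGLDataType
IdArray ret = IdArray::Empty({8}, dtype, ctx);
// Likewise kDLGPU -> kDGLCUDA, DLTensor -> DGLArray, DLDeviceType -> DGLDeviceType.
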
...@@ -44,7 +44,7 @@ namespace transform { ...@@ -44,7 +44,7 @@ namespace transform {
* *
* @return The block and the induced edges. * @return The block and the induced edges.
*/ */
template<DLDeviceType XPU, typename IdType> template<DGLDeviceType XPU, typename IdType>
std::tuple<HeteroGraphPtr, std::vector<IdArray>> std::tuple<HeteroGraphPtr, std::vector<IdArray>>
ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes, ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,
bool include_rhs_in_lhs, std::vector<IdArray>* lhs_nodes); bool include_rhs_in_lhs, std::vector<IdArray>* lhs_nodes);
......
...@@ -54,7 +54,7 @@ IdArray MergeMultipleTraversals( ...@@ -54,7 +54,7 @@ IdArray MergeMultipleTraversals(
max_len = std::max(max_len, tracelen); max_len = std::max(max_len, tracelen);
total_len += traces[i].size(); total_len += traces[i].size();
} }
IdArray ret = IdArray::Empty({total_len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); IdArray ret = IdArray::Empty({total_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
int64_t* ret_data = static_cast<int64_t*>(ret->data); int64_t* ret_data = static_cast<int64_t*>(ret->data);
for (int64_t i = 0; i < max_len; ++i) { for (int64_t i = 0; i < max_len; ++i) {
for (size_t j = 0; j < traces.size(); ++j) { for (size_t j = 0; j < traces.size(); ++j) {
...@@ -78,7 +78,7 @@ IdArray ComputeMergedSections( ...@@ -78,7 +78,7 @@ IdArray ComputeMergedSections(
const int64_t tracelen = traces[i].size(); const int64_t tracelen = traces[i].size();
max_len = std::max(max_len, tracelen); max_len = std::max(max_len, tracelen);
} }
IdArray ret = IdArray::Empty({max_len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); IdArray ret = IdArray::Empty({max_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
int64_t* ret_data = static_cast<int64_t*>(ret->data); int64_t* ret_data = static_cast<int64_t*>(ret->data);
for (int64_t i = 0; i < max_len; ++i) { for (int64_t i = 0; i < max_len; ++i) {
int64_t sec_len = 0; int64_t sec_len = 0;
......
...@@ -125,11 +125,11 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -125,11 +125,11 @@ class UnitGraph::COO : public BaseHeteroGraph {
LOG(FATAL) << "UnitGraph graph is not mutable."; LOG(FATAL) << "UnitGraph graph is not mutable.";
} }
DLDataType DataType() const override { DGLDataType DataType() const override {
return adj_.row->dtype; return adj_.row->dtype;
} }
DLContext Context() const override { DGLContext Context() const override {
return adj_.row->ctx; return adj_.row->ctx;
} }
...@@ -153,7 +153,7 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -153,7 +153,7 @@ class UnitGraph::COO : public BaseHeteroGraph {
return ret; return ret;
} }
COO CopyTo(const DLContext &ctx) const { COO CopyTo(const DGLContext &ctx) const {
if (Context() == ctx) if (Context() == ctx)
return *this; return *this;
return COO(meta_graph_, adj_.CopyTo(ctx)); return COO(meta_graph_, adj_.CopyTo(ctx));
...@@ -385,7 +385,7 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -385,7 +385,7 @@ class UnitGraph::COO : public BaseHeteroGraph {
CHECK(aten::IsValidIdArray(dstvids)) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(dstvids)) << "Invalid vertex id array.";
HeteroSubgraph subg; HeteroSubgraph subg;
const auto& submat = aten::COOSliceMatrix(adj_, srcvids, dstvids); const auto& submat = aten::COOSliceMatrix(adj_, srcvids, dstvids);
DLContext ctx = aten::GetContextOf(vids); DGLContext ctx = aten::GetContextOf(vids);
IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), ctx); IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), ctx);
subg.graph = std::make_shared<COO>(meta_graph(), submat.num_rows, submat.num_cols, subg.graph = std::make_shared<COO>(meta_graph(), submat.num_rows, submat.num_cols,
submat.row, submat.col); submat.row, submat.col);
...@@ -412,9 +412,9 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -412,9 +412,9 @@ class UnitGraph::COO : public BaseHeteroGraph {
IdArray new_src = aten::IndexSelect(adj_.row, eids[0]); IdArray new_src = aten::IndexSelect(adj_.row, eids[0]);
IdArray new_dst = aten::IndexSelect(adj_.col, eids[0]); IdArray new_dst = aten::IndexSelect(adj_.col, eids[0]);
subg.induced_vertices.emplace_back( subg.induced_vertices.emplace_back(
aten::NullArray(DLDataType{kDLInt, NumBits(), 1}, Context())); aten::NullArray(DGLDataType{kDGLInt, NumBits(), 1}, Context()));
subg.induced_vertices.emplace_back( subg.induced_vertices.emplace_back(
aten::NullArray(DLDataType{kDLInt, NumBits(), 1}, Context())); aten::NullArray(DGLDataType{kDGLInt, NumBits(), 1}, Context()));
subg.graph = std::make_shared<COO>( subg.graph = std::make_shared<COO>(
meta_graph(), NumVertices(SrcType()), NumVertices(DstType()), new_src, new_dst); meta_graph(), NumVertices(SrcType()), NumVertices(DstType()), new_src, new_dst);
subg.induced_edges = eids; subg.induced_edges = eids;
...@@ -532,11 +532,11 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -532,11 +532,11 @@ class UnitGraph::CSR : public BaseHeteroGraph {
LOG(FATAL) << "UnitGraph graph is not mutable."; LOG(FATAL) << "UnitGraph graph is not mutable.";
} }
DLDataType DataType() const override { DGLDataType DataType() const override {
return adj_.indices->dtype; return adj_.indices->dtype;
} }
DLContext Context() const override { DGLContext Context() const override {
return adj_.indices->ctx; return adj_.indices->ctx;
} }
...@@ -562,7 +562,7 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -562,7 +562,7 @@ class UnitGraph::CSR : public BaseHeteroGraph {
} }
} }
CSR CopyTo(const DLContext &ctx) const { CSR CopyTo(const DGLContext &ctx) const {
if (Context() == ctx) { if (Context() == ctx) {
return *this; return *this;
} else { } else {
...@@ -810,7 +810,7 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -810,7 +810,7 @@ class UnitGraph::CSR : public BaseHeteroGraph {
CHECK(aten::IsValidIdArray(dstvids)) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(dstvids)) << "Invalid vertex id array.";
HeteroSubgraph subg; HeteroSubgraph subg;
const auto& submat = aten::CSRSliceMatrix(adj_, srcvids, dstvids); const auto& submat = aten::CSRSliceMatrix(adj_, srcvids, dstvids);
DLContext ctx = aten::GetContextOf(vids); DGLContext ctx = aten::GetContextOf(vids);
IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), ctx); IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), ctx);
subg.graph = std::make_shared<CSR>(meta_graph(), submat.num_rows, submat.num_cols, subg.graph = std::make_shared<CSR>(meta_graph(), submat.num_rows, submat.num_cols,
submat.indptr, submat.indices, sub_eids); submat.indptr, submat.indices, sub_eids);
...@@ -860,11 +860,11 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -860,11 +860,11 @@ class UnitGraph::CSR : public BaseHeteroGraph {
// //
////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////
DLDataType UnitGraph::DataType() const { DGLDataType UnitGraph::DataType() const {
return GetAny()->DataType(); return GetAny()->DataType();
} }
DLContext UnitGraph::Context() const { DGLContext UnitGraph::Context() const {
return GetAny()->Context(); return GetAny()->Context();
} }
...@@ -1285,7 +1285,7 @@ HeteroGraphPtr UnitGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) { ...@@ -1285,7 +1285,7 @@ HeteroGraphPtr UnitGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
} }
} }
HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DLContext &ctx) { HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DGLContext &ctx) {
if (ctx == g->Context()) { if (ctx == g->Context()) {
return g; return g;
} else { } else {
......
...@@ -79,9 +79,9 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -79,9 +79,9 @@ class UnitGraph : public BaseHeteroGraph {
LOG(FATAL) << "UnitGraph graph is not mutable."; LOG(FATAL) << "UnitGraph graph is not mutable.";
} }
DLDataType DataType() const override; DGLDataType DataType() const override;
DLContext Context() const override; DGLContext Context() const override;
bool IsPinned() const override; bool IsPinned() const override;
...@@ -167,7 +167,7 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -167,7 +167,7 @@ class UnitGraph : public BaseHeteroGraph {
/*! \brief Create a graph with no edges */ /*! \brief Create a graph with no edges */
static HeteroGraphPtr Empty( static HeteroGraphPtr Empty(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, int64_t num_vtypes, int64_t num_src, int64_t num_dst,
DLDataType dtype, DLContext ctx) { DGLDataType dtype, DGLContext ctx) {
IdArray row = IdArray::Empty({0}, dtype, ctx); IdArray row = IdArray::Empty({0}, dtype, ctx);
IdArray col = IdArray::Empty({0}, dtype, ctx); IdArray col = IdArray::Empty({0}, dtype, ctx);
return CreateFromCOO(num_vtypes, num_src, num_dst, row, col); return CreateFromCOO(num_vtypes, num_src, num_dst, row, col);
...@@ -207,14 +207,14 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -207,14 +207,14 @@ class UnitGraph : public BaseHeteroGraph {
static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits); static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);
/*! \brief Copy the data to another context */ /*! \brief Copy the data to another context */
static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DLContext &ctx); static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DGLContext &ctx);
/*! /*!
* \brief Pin the in_csr_, out_csr_ and coo_ of the current graph. * \brief Pin the in_csr_, out_csr_ and coo_ of the current graph.
* \note The graph will be pinned inplace. Behavior depends on the current context, * \note The graph will be pinned inplace. Behavior depends on the current context,
* kDLCPU: will be pinned; * kDGLCPU: will be pinned;
* IsPinned: directly return; * IsPinned: directly return;
* kDLGPU: invalid, will throw an error. * kDGLCUDA: invalid, will throw an error.
* The context check is deferred to pinning the NDArray. * The context check is deferred to pinning the NDArray.
*/ */
void PinMemory_() override; void PinMemory_() override;
......
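
For context on the pinning contract documented above, a hypothetical caller that honors it could look like the sketch below (the helper name and the CHECK message are illustrative, not part of the diff):

// Hypothetical helper mirroring the documented PinMemory_() behavior:
// pin on CPU, no-op when already pinned, error on CUDA contexts.
void PinIfNeeded(HeteroGraphPtr g) {
  if (g->IsPinned())
    return;                               // IsPinned: directly return
  CHECK_EQ(g->Context().device_type, kDGLCPU)
      << "only kDGLCPU graphs can be pinned in place; kDGLCUDA graphs throw";
  g->PinMemory_();                        // pins in_csr_, out_csr_ and coo_ in place
}
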
...@@ -251,7 +251,7 @@ __global__ void _MapGlobalIndexByRangeKernel( ...@@ -251,7 +251,7 @@ __global__ void _MapGlobalIndexByRangeKernel(
// Remainder Based Partition Operations // Remainder Based Partition Operations
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, NDArray> std::pair<IdArray, NDArray>
GeneratePermutationFromRemainder( GeneratePermutationFromRemainder(
int64_t array_size, int64_t array_size,
...@@ -376,18 +376,18 @@ GeneratePermutationFromRemainder( ...@@ -376,18 +376,18 @@ GeneratePermutationFromRemainder(
template std::pair<IdArray, IdArray> template std::pair<IdArray, IdArray>
GeneratePermutationFromRemainder<kDLGPU, int32_t>( GeneratePermutationFromRemainder<kDGLCUDA, int32_t>(
int64_t array_size, int64_t array_size,
int num_parts, int num_parts,
IdArray in_idx); IdArray in_idx);
template std::pair<IdArray, IdArray> template std::pair<IdArray, IdArray>
GeneratePermutationFromRemainder<kDLGPU, int64_t>( GeneratePermutationFromRemainder<kDGLCUDA, int64_t>(
int64_t array_size, int64_t array_size,
int num_parts, int num_parts,
IdArray in_idx); IdArray in_idx);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
IdArray MapToLocalFromRemainder( IdArray MapToLocalFromRemainder(
const int num_parts, const int num_parts,
IdArray global_idx) { IdArray global_idx) {
...@@ -420,15 +420,15 @@ IdArray MapToLocalFromRemainder( ...@@ -420,15 +420,15 @@ IdArray MapToLocalFromRemainder(
} }
template IdArray template IdArray
MapToLocalFromRemainder<kDLGPU, int32_t>( MapToLocalFromRemainder<kDGLCUDA, int32_t>(
int num_parts, int num_parts,
IdArray in_idx); IdArray in_idx);
template IdArray template IdArray
MapToLocalFromRemainder<kDLGPU, int64_t>( MapToLocalFromRemainder<kDGLCUDA, int64_t>(
int num_parts, int num_parts,
IdArray in_idx); IdArray in_idx);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
IdArray MapToGlobalFromRemainder( IdArray MapToGlobalFromRemainder(
const int num_parts, const int num_parts,
IdArray local_idx, IdArray local_idx,
...@@ -468,12 +468,12 @@ IdArray MapToGlobalFromRemainder( ...@@ -468,12 +468,12 @@ IdArray MapToGlobalFromRemainder(
} }
template IdArray template IdArray
MapToGlobalFromRemainder<kDLGPU, int32_t>( MapToGlobalFromRemainder<kDGLCUDA, int32_t>(
int num_parts, int num_parts,
IdArray in_idx, IdArray in_idx,
int part_id); int part_id);
template IdArray template IdArray
MapToGlobalFromRemainder<kDLGPU, int64_t>( MapToGlobalFromRemainder<kDGLCUDA, int64_t>(
int num_parts, int num_parts,
IdArray in_idx, IdArray in_idx,
int part_id); int part_id);
...@@ -481,7 +481,7 @@ MapToGlobalFromRemainder<kDLGPU, int64_t>( ...@@ -481,7 +481,7 @@ MapToGlobalFromRemainder<kDLGPU, int64_t>(
// Range Based Partition Operations // Range Based Partition Operations
template <DLDeviceType XPU, typename IdType, typename RangeType> template <DGLDeviceType XPU, typename IdType, typename RangeType>
std::pair<IdArray, NDArray> std::pair<IdArray, NDArray>
GeneratePermutationFromRange( GeneratePermutationFromRange(
int64_t array_size, int64_t array_size,
...@@ -598,31 +598,31 @@ GeneratePermutationFromRange( ...@@ -598,31 +598,31 @@ GeneratePermutationFromRange(
template std::pair<IdArray, IdArray> template std::pair<IdArray, IdArray>
GeneratePermutationFromRange<kDLGPU, int32_t, int32_t>( GeneratePermutationFromRange<kDGLCUDA, int32_t, int32_t>(
int64_t array_size, int64_t array_size,
int num_parts, int num_parts,
IdArray range, IdArray range,
IdArray in_idx); IdArray in_idx);
template std::pair<IdArray, IdArray> template std::pair<IdArray, IdArray>
GeneratePermutationFromRange<kDLGPU, int64_t, int32_t>( GeneratePermutationFromRange<kDGLCUDA, int64_t, int32_t>(
int64_t array_size, int64_t array_size,
int num_parts, int num_parts,
IdArray range, IdArray range,
IdArray in_idx); IdArray in_idx);
template std::pair<IdArray, IdArray> template std::pair<IdArray, IdArray>
GeneratePermutationFromRange<kDLGPU, int32_t, int64_t>( GeneratePermutationFromRange<kDGLCUDA, int32_t, int64_t>(
int64_t array_size, int64_t array_size,
int num_parts, int num_parts,
IdArray range, IdArray range,
IdArray in_idx); IdArray in_idx);
template std::pair<IdArray, IdArray> template std::pair<IdArray, IdArray>
GeneratePermutationFromRange<kDLGPU, int64_t, int64_t>( GeneratePermutationFromRange<kDGLCUDA, int64_t, int64_t>(
int64_t array_size, int64_t array_size,
int num_parts, int num_parts,
IdArray range, IdArray range,
IdArray in_idx); IdArray in_idx);
template <DLDeviceType XPU, typename IdType, typename RangeType> template <DGLDeviceType XPU, typename IdType, typename RangeType>
IdArray MapToLocalFromRange( IdArray MapToLocalFromRange(
const int num_parts, const int num_parts,
IdArray range, IdArray range,
...@@ -657,28 +657,28 @@ IdArray MapToLocalFromRange( ...@@ -657,28 +657,28 @@ IdArray MapToLocalFromRange(
} }
template IdArray template IdArray
MapToLocalFromRange<kDLGPU, int32_t, int32_t>( MapToLocalFromRange<kDGLCUDA, int32_t, int32_t>(
int num_parts, int num_parts,
IdArray range, IdArray range,
IdArray in_idx); IdArray in_idx);
template IdArray template IdArray
MapToLocalFromRange<kDLGPU, int64_t, int32_t>( MapToLocalFromRange<kDGLCUDA, int64_t, int32_t>(
int num_parts, int num_parts,
IdArray range, IdArray range,
IdArray in_idx); IdArray in_idx);
template IdArray template IdArray
MapToLocalFromRange<kDLGPU, int32_t, int64_t>( MapToLocalFromRange<kDGLCUDA, int32_t, int64_t>(
int num_parts, int num_parts,
IdArray range, IdArray range,
IdArray in_idx); IdArray in_idx);
template IdArray template IdArray
MapToLocalFromRange<kDLGPU, int64_t, int64_t>( MapToLocalFromRange<kDGLCUDA, int64_t, int64_t>(
int num_parts, int num_parts,
IdArray range, IdArray range,
IdArray in_idx); IdArray in_idx);
template <DLDeviceType XPU, typename IdType, typename RangeType> template <DGLDeviceType XPU, typename IdType, typename RangeType>
IdArray MapToGlobalFromRange( IdArray MapToGlobalFromRange(
const int num_parts, const int num_parts,
IdArray range, IdArray range,
...@@ -720,25 +720,25 @@ IdArray MapToGlobalFromRange( ...@@ -720,25 +720,25 @@ IdArray MapToGlobalFromRange(
} }
template IdArray template IdArray
MapToGlobalFromRange<kDLGPU, int32_t, int32_t>( MapToGlobalFromRange<kDGLCUDA, int32_t, int32_t>(
int num_parts, int num_parts,
IdArray range, IdArray range,
IdArray in_idx, IdArray in_idx,
int part_id); int part_id);
template IdArray template IdArray
MapToGlobalFromRange<kDLGPU, int64_t, int32_t>( MapToGlobalFromRange<kDGLCUDA, int64_t, int32_t>(
int num_parts, int num_parts,
IdArray range, IdArray range,
IdArray in_idx, IdArray in_idx,
int part_id); int part_id);
template IdArray template IdArray
MapToGlobalFromRange<kDLGPU, int32_t, int64_t>( MapToGlobalFromRange<kDGLCUDA, int32_t, int64_t>(
int num_parts, int num_parts,
IdArray range, IdArray range,
IdArray in_idx, IdArray in_idx,
int part_id); int part_id);
template IdArray template IdArray
MapToGlobalFromRange<kDLGPU, int64_t, int64_t>( MapToGlobalFromRange<kDGLCUDA, int64_t, int64_t>(
int num_parts, int num_parts,
IdArray range, IdArray range,
IdArray in_idx, IdArray in_idx,
......
...@@ -46,9 +46,9 @@ class RemainderPartition : public NDArrayPartition { ...@@ -46,9 +46,9 @@ class RemainderPartition : public NDArrayPartition {
IdArray in_idx) const override { IdArray in_idx) const override {
#ifdef DGL_USE_CUDA #ifdef DGL_USE_CUDA
auto ctx = in_idx->ctx; auto ctx = in_idx->ctx;
if (ctx.device_type == kDLGPU) { if (ctx.device_type == kDGLCUDA) {
ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, { ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
return impl::GeneratePermutationFromRemainder<kDLGPU, IdType>( return impl::GeneratePermutationFromRemainder<kDGLCUDA, IdType>(
ArraySize(), NumParts(), in_idx); ArraySize(), NumParts(), in_idx);
}); });
} }
...@@ -64,9 +64,9 @@ class RemainderPartition : public NDArrayPartition { ...@@ -64,9 +64,9 @@ class RemainderPartition : public NDArrayPartition {
IdArray in_idx) const override { IdArray in_idx) const override {
#ifdef DGL_USE_CUDA #ifdef DGL_USE_CUDA
auto ctx = in_idx->ctx; auto ctx = in_idx->ctx;
if (ctx.device_type == kDLGPU) { if (ctx.device_type == kDGLCUDA) {
ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, { ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
return impl::MapToLocalFromRemainder<kDLGPU, IdType>( return impl::MapToLocalFromRemainder<kDGLCUDA, IdType>(
NumParts(), in_idx); NumParts(), in_idx);
}); });
} }
...@@ -83,9 +83,9 @@ class RemainderPartition : public NDArrayPartition { ...@@ -83,9 +83,9 @@ class RemainderPartition : public NDArrayPartition {
const int part_id) const override { const int part_id) const override {
#ifdef DGL_USE_CUDA #ifdef DGL_USE_CUDA
auto ctx = in_idx->ctx; auto ctx = in_idx->ctx;
if (ctx.device_type == kDLGPU) { if (ctx.device_type == kDGLCUDA) {
ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, { ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
return impl::MapToGlobalFromRemainder<kDLGPU, IdType>( return impl::MapToGlobalFromRemainder<kDGLCUDA, IdType>(
NumParts(), in_idx, part_id); NumParts(), in_idx, part_id);
}); });
} }
...@@ -116,9 +116,9 @@ class RangePartition : public NDArrayPartition { ...@@ -116,9 +116,9 @@ class RangePartition : public NDArrayPartition {
// sizes. We require the input range on the GPU, as if we have multiple // sizes. We require the input range on the GPU, as if we have multiple
// GPUs, we can't know which is the proper one to copy the array to, but we // GPUs, we can't know which is the proper one to copy the array to, but we
// have only one CPU context, and can safely copy the array to that. // have only one CPU context, and can safely copy the array to that.
range_cpu_(range.CopyTo(DGLContext{kDLCPU, 0})) { range_cpu_(range.CopyTo(DGLContext{kDGLCPU, 0})) {
auto ctx = range->ctx; auto ctx = range->ctx;
if (ctx.device_type != kDLGPU) { if (ctx.device_type != kDGLCUDA) {
LOG(FATAL) << "The range for an NDArrayPartition is only supported " LOG(FATAL) << "The range for an NDArrayPartition is only supported "
" on GPUs. Transfer the range to the target device before " " on GPUs. Transfer the range to the target device before "
"creating the partition."; "creating the partition.";
...@@ -130,7 +130,7 @@ class RangePartition : public NDArrayPartition { ...@@ -130,7 +130,7 @@ class RangePartition : public NDArrayPartition {
IdArray in_idx) const override { IdArray in_idx) const override {
#ifdef DGL_USE_CUDA #ifdef DGL_USE_CUDA
auto ctx = in_idx->ctx; auto ctx = in_idx->ctx;
if (ctx.device_type == kDLGPU) { if (ctx.device_type == kDGLCUDA) {
if (ctx.device_type != range_->ctx.device_type || if (ctx.device_type != range_->ctx.device_type ||
ctx.device_id != range_->ctx.device_id) { ctx.device_id != range_->ctx.device_id) {
LOG(FATAL) << "The range for the NDArrayPartition and the input " LOG(FATAL) << "The range for the NDArrayPartition and the input "
...@@ -138,7 +138,7 @@ class RangePartition : public NDArrayPartition { ...@@ -138,7 +138,7 @@ class RangePartition : public NDArrayPartition {
} }
ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, { ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
ATEN_ID_TYPE_SWITCH(range_->dtype, RangeType, { ATEN_ID_TYPE_SWITCH(range_->dtype, RangeType, {
return impl::GeneratePermutationFromRange<kDLGPU, IdType, RangeType>( return impl::GeneratePermutationFromRange<kDGLCUDA, IdType, RangeType>(
ArraySize(), NumParts(), range_, in_idx); ArraySize(), NumParts(), range_, in_idx);
}); });
}); });
...@@ -155,10 +155,10 @@ class RangePartition : public NDArrayPartition { ...@@ -155,10 +155,10 @@ class RangePartition : public NDArrayPartition {
IdArray in_idx) const override { IdArray in_idx) const override {
#ifdef DGL_USE_CUDA #ifdef DGL_USE_CUDA
auto ctx = in_idx->ctx; auto ctx = in_idx->ctx;
if (ctx.device_type == kDLGPU) { if (ctx.device_type == kDGLCUDA) {
ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, { ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
ATEN_ID_TYPE_SWITCH(range_->dtype, RangeType, { ATEN_ID_TYPE_SWITCH(range_->dtype, RangeType, {
return impl::MapToLocalFromRange<kDLGPU, IdType, RangeType>( return impl::MapToLocalFromRange<kDGLCUDA, IdType, RangeType>(
NumParts(), range_, in_idx); NumParts(), range_, in_idx);
}); });
}); });
...@@ -176,10 +176,10 @@ class RangePartition : public NDArrayPartition { ...@@ -176,10 +176,10 @@ class RangePartition : public NDArrayPartition {
const int part_id) const override { const int part_id) const override {
#ifdef DGL_USE_CUDA #ifdef DGL_USE_CUDA
auto ctx = in_idx->ctx; auto ctx = in_idx->ctx;
if (ctx.device_type == kDLGPU) { if (ctx.device_type == kDGLCUDA) {
ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, { ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
ATEN_ID_TYPE_SWITCH(range_->dtype, RangeType, { ATEN_ID_TYPE_SWITCH(range_->dtype, RangeType, {
return impl::MapToGlobalFromRange<kDLGPU, IdType, RangeType>( return impl::MapToGlobalFromRange<kDGLCUDA, IdType, RangeType>(
NumParts(), range_, in_idx, part_id); NumParts(), range_, in_idx, part_id);
}); });
}); });
......
...@@ -32,7 +32,7 @@ namespace impl { ...@@ -32,7 +32,7 @@ namespace impl {
* @return The permutation to group the indices by part id, and the number of * @return The permutation to group the indices by part id, and the number of
* indices in each part. * indices in each part.
*/ */
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> std::pair<IdArray, IdArray>
GeneratePermutationFromRemainder( GeneratePermutationFromRemainder(
int64_t array_size, int64_t array_size,
...@@ -51,7 +51,7 @@ GeneratePermutationFromRemainder( ...@@ -51,7 +51,7 @@ GeneratePermutationFromRemainder(
* *
* @return The array of local indices. * @return The array of local indices.
*/ */
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
IdArray MapToLocalFromRemainder( IdArray MapToLocalFromRemainder(
int num_parts, int num_parts,
IdArray global_idx); IdArray global_idx);
...@@ -69,7 +69,7 @@ IdArray MapToLocalFromRemainder( ...@@ -69,7 +69,7 @@ IdArray MapToLocalFromRemainder(
* *
* @return The array of global indices. * @return The array of global indices.
*/ */
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
IdArray MapToGlobalFromRemainder( IdArray MapToGlobalFromRemainder(
int num_parts, int num_parts,
IdArray local_idx, IdArray local_idx,
...@@ -95,7 +95,7 @@ IdArray MapToGlobalFromRemainder( ...@@ -95,7 +95,7 @@ IdArray MapToGlobalFromRemainder(
* @return The permutation to group the indices by part id, and the number of * @return The permutation to group the indices by part id, and the number of
* indices in each part. * indices in each part.
*/ */
template <DLDeviceType XPU, typename IdType, typename RangeType> template <DGLDeviceType XPU, typename IdType, typename RangeType>
std::pair<IdArray, IdArray> std::pair<IdArray, IdArray>
GeneratePermutationFromRange( GeneratePermutationFromRange(
int64_t array_size, int64_t array_size,
...@@ -118,7 +118,7 @@ GeneratePermutationFromRange( ...@@ -118,7 +118,7 @@ GeneratePermutationFromRange(
* *
* @return The array of local indices. * @return The array of local indices.
*/ */
template <DLDeviceType XPU, typename IdType, typename RangeType> template <DGLDeviceType XPU, typename IdType, typename RangeType>
IdArray MapToLocalFromRange( IdArray MapToLocalFromRange(
int num_parts, int num_parts,
IdArray range, IdArray range,
...@@ -140,7 +140,7 @@ IdArray MapToLocalFromRange( ...@@ -140,7 +140,7 @@ IdArray MapToLocalFromRange(
* *
* @return The array of global indices. * @return The array of global indices.
*/ */
template <DLDeviceType XPU, typename IdType, typename RangeType> template <DGLDeviceType XPU, typename IdType, typename RangeType>
IdArray MapToGlobalFromRange( IdArray MapToGlobalFromRange(
int num_parts, int num_parts,
IdArray range, IdArray range,
......
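
Per the function names, the remainder-based operations declared above reduce to modular arithmetic on the global index; a host-side sketch of that arithmetic (my reading of the interface, not code from this diff):

#include <cstdint>

// Illustrative host-side arithmetic behind the remainder partition:
// part id = global % num_parts, local id = global / num_parts, and back.
struct RemainderPartitionSketch {
  int num_parts;
  int PartId(int64_t global) const      { return static_cast<int>(global % num_parts); }
  int64_t ToLocal(int64_t global) const { return global / num_parts; }
  int64_t ToGlobal(int64_t local, int part_id) const {
    return local * static_cast<int64_t>(num_parts) + part_id;
  }
};
// The range-based variants replace the modulus with a sorted boundary array `range`;
// both families dispatch on DGLDeviceType and the index/range integer widths.
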
...@@ -29,7 +29,7 @@ DGL_REGISTER_GLOBAL("rng._CAPI_SetSeed") ...@@ -29,7 +29,7 @@ DGL_REGISTER_GLOBAL("rng._CAPI_SetSeed")
} }
}); });
#ifdef DGL_USE_CUDA #ifdef DGL_USE_CUDA
if (DeviceAPI::Get(kDLGPU)->IsAvailable()) { if (DeviceAPI::Get(kDGLCUDA)->IsAvailable()) {
auto* thr_entry = CUDAThreadEntry::ThreadLocal(); auto* thr_entry = CUDAThreadEntry::ThreadLocal();
if (!thr_entry->curand_gen) { if (!thr_entry->curand_gen) {
CURAND_CALL(curandCreateGenerator(&thr_entry->curand_gen, CURAND_RNG_PSEUDO_DEFAULT)); CURAND_CALL(curandCreateGenerator(&thr_entry->curand_gen, CURAND_RNG_PSEUDO_DEFAULT));
......
...@@ -517,7 +517,7 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFastPull") ...@@ -517,7 +517,7 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFastPull")
local_data_shape[0] = ID_size; local_data_shape[0] = ID_size;
NDArray res_tensor = NDArray::Empty(local_data_shape, NDArray res_tensor = NDArray::Empty(local_data_shape,
local_data->dtype, local_data->dtype,
DLContext{kDLCPU, 0}); DGLContext{kDGLCPU, 0});
char* return_data = static_cast<char*>(res_tensor->data); char* return_data = static_cast<char*>(res_tensor->data);
// Copy local data // Copy local data
parallel_for(0, local_ids.size(), [&](size_t b, size_t e) { parallel_for(0, local_ids.size(), [&](size_t b, size_t e) {
......
...@@ -137,7 +137,7 @@ int DGLObjectGetAttr(ObjectHandle handle, ...@@ -137,7 +137,7 @@ int DGLObjectGetAttr(ObjectHandle handle,
(*tobject)->VisitAttrs(&getter); (*tobject)->VisitAttrs(&getter);
*ret_success = getter.found_object_ref || rv.type_code() != kNull; *ret_success = getter.found_object_ref || rv.type_code() != kNull;
if (rv.type_code() == kStr || if (rv.type_code() == kStr ||
rv.type_code() == kDGLType) { rv.type_code() == kDGLDataType) {
DGLAPIThreadLocalEntry *e = DGLAPIThreadLocalStore::Get(); DGLAPIThreadLocalEntry *e = DGLAPIThreadLocalStore::Get();
e->ret_str = rv.operator std::string(); e->ret_str = rv.operator std::string();
*ret_type_code = kStr; *ret_type_code = kStr;
......
/*! /*!
* Copyright (c) 2016 by Contributors * Copyright (c) 2016-2022 by Contributors
* \file c_runtime_api.cc * \file c_runtime_api.cc
* \brief Runtime API implementation * \brief Runtime API implementation
*/ */
...@@ -26,17 +26,9 @@ namespace runtime { ...@@ -26,17 +26,9 @@ namespace runtime {
*/ */
inline std::string DeviceName(int type) { inline std::string DeviceName(int type) {
switch (type) { switch (type) {
case kDLCPU: return "cpu"; case kDGLCPU: return "cpu";
case kDLGPU: return "gpu"; case kDGLCUDA: return "cuda";
case kDLOpenCL: return "opencl"; // add more device here once supported
case kDLSDAccel: return "sdaccel";
case kDLAOCL: return "aocl";
case kDLVulkan: return "vulkan";
case kDLMetal: return "metal";
case kDLVPI: return "vpi";
case kDLROCM: return "rocm";
case kOpenGL: return "opengl";
case kExtDev: return "ext_dev";
default: LOG(FATAL) << "unknown type =" << type; return "Unknown"; default: LOG(FATAL) << "unknown type =" << type; return "Unknown";
} }
} }
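
With the device table trimmed as above, only the two device types DGL ships resolve to a name; an illustrative expectation (not part of the diff):

// Illustrative only: behavior of the trimmed DeviceName().
//   DeviceName(kDGLCPU)  returns "cpu"
//   DeviceName(kDGLCUDA) returns "cuda"
//   any other value hits the default branch: LOG(FATAL) << "unknown type =" << type
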
...@@ -99,13 +91,13 @@ DeviceAPI* DeviceAPI::Get(DGLContext ctx, bool allow_missing) { ...@@ -99,13 +91,13 @@ DeviceAPI* DeviceAPI::Get(DGLContext ctx, bool allow_missing) {
static_cast<int>(ctx.device_type), allow_missing); static_cast<int>(ctx.device_type), allow_missing);
} }
DeviceAPI* DeviceAPI::Get(DLDeviceType dev_type, bool allow_missing) { DeviceAPI* DeviceAPI::Get(DGLDeviceType dev_type, bool allow_missing) {
return DeviceAPIManager::Get(static_cast<int>(dev_type), allow_missing); return DeviceAPIManager::Get(static_cast<int>(dev_type), allow_missing);
} }
void* DeviceAPI::AllocWorkspace(DGLContext ctx, void* DeviceAPI::AllocWorkspace(DGLContext ctx,
size_t size, size_t size,
DGLType type_hint) { DGLDataType type_hint) {
return AllocDataSpace(ctx, size, kTempAllocaAlignment, type_hint); return AllocDataSpace(ctx, size, kTempAllocaAlignment, type_hint);
} }
...@@ -213,10 +205,10 @@ void* DGLBackendAllocWorkspace(int device_type, ...@@ -213,10 +205,10 @@ void* DGLBackendAllocWorkspace(int device_type,
int dtype_code_hint, int dtype_code_hint,
int dtype_bits_hint) { int dtype_bits_hint) {
DGLContext ctx; DGLContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type); ctx.device_type = static_cast<DGLDeviceType>(device_type);
ctx.device_id = device_id; ctx.device_id = device_id;
DGLType type_hint; DGLDataType type_hint;
type_hint.code = static_cast<decltype(type_hint.code)>(dtype_code_hint); type_hint.code = static_cast<decltype(type_hint.code)>(dtype_code_hint);
type_hint.bits = static_cast<decltype(type_hint.bits)>(dtype_bits_hint); type_hint.bits = static_cast<decltype(type_hint.bits)>(dtype_bits_hint);
type_hint.lanes = 1; type_hint.lanes = 1;
...@@ -230,7 +222,7 @@ int DGLBackendFreeWorkspace(int device_type, ...@@ -230,7 +222,7 @@ int DGLBackendFreeWorkspace(int device_type,
int device_id, int device_id,
void* ptr) { void* ptr) {
DGLContext ctx; DGLContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type); ctx.device_type = static_cast<DGLDeviceType>(device_type);
ctx.device_id = device_id; ctx.device_id = device_id;
DeviceAPIManager::Get(ctx)->FreeWorkspace(ctx, ptr); DeviceAPIManager::Get(ctx)->FreeWorkspace(ctx, ptr);
return 0; return 0;
...@@ -265,10 +257,10 @@ int DGLFuncCall(DGLFunctionHandle func, ...@@ -265,10 +257,10 @@ int DGLFuncCall(DGLFunctionHandle func,
DGLArgs(args, arg_type_codes, num_args), &rv); DGLArgs(args, arg_type_codes, num_args), &rv);
// handle return string. // handle return string.
if (rv.type_code() == kStr || if (rv.type_code() == kStr ||
rv.type_code() == kDGLType || rv.type_code() == kDGLDataType ||
rv.type_code() == kBytes) { rv.type_code() == kBytes) {
DGLRuntimeEntry* e = DGLAPIRuntimeStore::Get(); DGLRuntimeEntry* e = DGLAPIRuntimeStore::Get();
if (rv.type_code() != kDGLType) { if (rv.type_code() != kDGLDataType) {
e->ret_str = *rv.ptr<std::string>(); e->ret_str = *rv.ptr<std::string>();
} else { } else {
e->ret_str = rv.operator std::string(); e->ret_str = rv.operator std::string();
...@@ -336,7 +328,7 @@ int DGLFuncCreateFromCFunc(DGLPackedCFunc func, ...@@ -336,7 +328,7 @@ int DGLFuncCreateFromCFunc(DGLPackedCFunc func,
int DGLStreamCreate(int device_type, int device_id, DGLStreamHandle* out) { int DGLStreamCreate(int device_type, int device_id, DGLStreamHandle* out) {
API_BEGIN(); API_BEGIN();
DGLContext ctx; DGLContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type); ctx.device_type = static_cast<DGLDeviceType>(device_type);
ctx.device_id = device_id; ctx.device_id = device_id;
*out = DeviceAPIManager::Get(ctx)->CreateStream(ctx); *out = DeviceAPIManager::Get(ctx)->CreateStream(ctx);
API_END(); API_END();
...@@ -345,7 +337,7 @@ int DGLStreamCreate(int device_type, int device_id, DGLStreamHandle* out) { ...@@ -345,7 +337,7 @@ int DGLStreamCreate(int device_type, int device_id, DGLStreamHandle* out) {
int DGLStreamFree(int device_type, int device_id, DGLStreamHandle stream) { int DGLStreamFree(int device_type, int device_id, DGLStreamHandle stream) {
API_BEGIN(); API_BEGIN();
DGLContext ctx; DGLContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type); ctx.device_type = static_cast<DGLDeviceType>(device_type);
ctx.device_id = device_id; ctx.device_id = device_id;
DeviceAPIManager::Get(ctx)->FreeStream(ctx, stream); DeviceAPIManager::Get(ctx)->FreeStream(ctx, stream);
API_END(); API_END();
...@@ -354,7 +346,7 @@ int DGLStreamFree(int device_type, int device_id, DGLStreamHandle stream) { ...@@ -354,7 +346,7 @@ int DGLStreamFree(int device_type, int device_id, DGLStreamHandle stream) {
int DGLSetStream(int device_type, int device_id, DGLStreamHandle stream) { int DGLSetStream(int device_type, int device_id, DGLStreamHandle stream) {
API_BEGIN(); API_BEGIN();
DGLContext ctx; DGLContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type); ctx.device_type = static_cast<DGLDeviceType>(device_type);
ctx.device_id = device_id; ctx.device_id = device_id;
DeviceAPIManager::Get(ctx)->SetStream(ctx, stream); DeviceAPIManager::Get(ctx)->SetStream(ctx, stream);
API_END(); API_END();
...@@ -363,7 +355,7 @@ int DGLSetStream(int device_type, int device_id, DGLStreamHandle stream) { ...@@ -363,7 +355,7 @@ int DGLSetStream(int device_type, int device_id, DGLStreamHandle stream) {
int DGLGetStream(int device_type, int device_id, DGLStreamHandle* stream) { int DGLGetStream(int device_type, int device_id, DGLStreamHandle* stream) {
API_BEGIN(); API_BEGIN();
DGLContext ctx; DGLContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type); ctx.device_type = static_cast<DGLDeviceType>(device_type);
ctx.device_id = device_id; ctx.device_id = device_id;
*stream = DeviceAPIManager::Get(ctx)->GetStream(); *stream = DeviceAPIManager::Get(ctx)->GetStream();
API_END(); API_END();
...@@ -372,7 +364,7 @@ int DGLGetStream(int device_type, int device_id, DGLStreamHandle* stream) { ...@@ -372,7 +364,7 @@ int DGLGetStream(int device_type, int device_id, DGLStreamHandle* stream) {
int DGLSynchronize(int device_type, int device_id, DGLStreamHandle stream) { int DGLSynchronize(int device_type, int device_id, DGLStreamHandle stream) {
API_BEGIN(); API_BEGIN();
DGLContext ctx; DGLContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type); ctx.device_type = static_cast<DGLDeviceType>(device_type);
ctx.device_id = device_id; ctx.device_id = device_id;
DeviceAPIManager::Get(ctx)->StreamSync(ctx, stream); DeviceAPIManager::Get(ctx)->StreamSync(ctx, stream);
API_END(); API_END();
...@@ -384,7 +376,7 @@ int DGLStreamStreamSynchronize(int device_type, ...@@ -384,7 +376,7 @@ int DGLStreamStreamSynchronize(int device_type,
DGLStreamHandle dst) { DGLStreamHandle dst) {
API_BEGIN(); API_BEGIN();
DGLContext ctx; DGLContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type); ctx.device_type = static_cast<DGLDeviceType>(device_type);
ctx.device_id = device_id; ctx.device_id = device_id;
DeviceAPIManager::Get(ctx)->SyncStreamFromTo(ctx, src, dst); DeviceAPIManager::Get(ctx)->SyncStreamFromTo(ctx, src, dst);
API_END(); API_END();
...@@ -408,7 +400,7 @@ int DGLLoadTensorAdapter(const char *path) { ...@@ -408,7 +400,7 @@ int DGLLoadTensorAdapter(const char *path) {
DGL_REGISTER_GLOBAL(dgl::runtime::symbol::dgl_set_device) DGL_REGISTER_GLOBAL(dgl::runtime::symbol::dgl_set_device)
.set_body([](DGLArgs args, DGLRetValue *ret) { .set_body([](DGLArgs args, DGLRetValue *ret) {
DGLContext ctx; DGLContext ctx;
ctx.device_type = static_cast<DLDeviceType>(args[0].operator int()); ctx.device_type = static_cast<DGLDeviceType>(args[0].operator int());
ctx.device_id = args[1]; ctx.device_id = args[1];
DeviceAPIManager::Get(ctx)->SetDevice(ctx); DeviceAPIManager::Get(ctx)->SetDevice(ctx);
}); });
...@@ -417,7 +409,7 @@ DGL_REGISTER_GLOBAL(dgl::runtime::symbol::dgl_set_device) ...@@ -417,7 +409,7 @@ DGL_REGISTER_GLOBAL(dgl::runtime::symbol::dgl_set_device)
DGL_REGISTER_GLOBAL("_GetDeviceAttr") DGL_REGISTER_GLOBAL("_GetDeviceAttr")
.set_body([](DGLArgs args, DGLRetValue *ret) { .set_body([](DGLArgs args, DGLRetValue *ret) {
DGLContext ctx; DGLContext ctx;
ctx.device_type = static_cast<DLDeviceType>(args[0].operator int()); ctx.device_type = static_cast<DGLDeviceType>(args[0].operator int());
ctx.device_id = args[1]; ctx.device_id = args[1];
DeviceAttrKind kind = static_cast<DeviceAttrKind>(args[2].operator int()); DeviceAttrKind kind = static_cast<DeviceAttrKind>(args[2].operator int());
......
...@@ -24,7 +24,7 @@ class CPUDeviceAPI final : public DeviceAPI { ...@@ -24,7 +24,7 @@ class CPUDeviceAPI final : public DeviceAPI {
void* AllocDataSpace(DGLContext ctx, void* AllocDataSpace(DGLContext ctx,
size_t nbytes, size_t nbytes,
size_t alignment, size_t alignment,
DGLType type_hint) final { DGLDataType type_hint) final {
TensorDispatcher* td = TensorDispatcher::Global(); TensorDispatcher* td = TensorDispatcher::Global();
if (td->IsAvailable()) if (td->IsAvailable())
return td->CPUAllocWorkspace(nbytes); return td->CPUAllocWorkspace(nbytes);
...@@ -62,7 +62,7 @@ class CPUDeviceAPI final : public DeviceAPI { ...@@ -62,7 +62,7 @@ class CPUDeviceAPI final : public DeviceAPI {
size_t size, size_t size,
DGLContext ctx_from, DGLContext ctx_from,
DGLContext ctx_to, DGLContext ctx_to,
DGLType type_hint) final { DGLDataType type_hint) final {
memcpy(static_cast<char*>(to) + to_offset, memcpy(static_cast<char*>(to) + to_offset,
static_cast<const char*>(from) + from_offset, static_cast<const char*>(from) + from_offset,
size); size);
...@@ -73,7 +73,7 @@ class CPUDeviceAPI final : public DeviceAPI { ...@@ -73,7 +73,7 @@ class CPUDeviceAPI final : public DeviceAPI {
void StreamSync(DGLContext ctx, DGLStreamHandle stream) final { void StreamSync(DGLContext ctx, DGLStreamHandle stream) final {
} }
void* AllocWorkspace(DGLContext ctx, size_t size, DGLType type_hint) final; void* AllocWorkspace(DGLContext ctx, size_t size, DGLDataType type_hint) final;
void FreeWorkspace(DGLContext ctx, void* data) final; void FreeWorkspace(DGLContext ctx, void* data) final;
static const std::shared_ptr<CPUDeviceAPI>& Global() { static const std::shared_ptr<CPUDeviceAPI>& Global() {
...@@ -85,12 +85,12 @@ class CPUDeviceAPI final : public DeviceAPI { ...@@ -85,12 +85,12 @@ class CPUDeviceAPI final : public DeviceAPI {
struct CPUWorkspacePool : public WorkspacePool { struct CPUWorkspacePool : public WorkspacePool {
CPUWorkspacePool() : CPUWorkspacePool() :
WorkspacePool(kDLCPU, CPUDeviceAPI::Global()) {} WorkspacePool(kDGLCPU, CPUDeviceAPI::Global()) {}
}; };
void* CPUDeviceAPI::AllocWorkspace(DGLContext ctx, void* CPUDeviceAPI::AllocWorkspace(DGLContext ctx,
size_t size, size_t size,
DGLType type_hint) { DGLDataType type_hint) {
TensorDispatcher* td = TensorDispatcher::Global(); TensorDispatcher* td = TensorDispatcher::Global();
if (td->IsAvailable()) if (td->IsAvailable())
return td->CPUAllocWorkspace(size); return td->CPUAllocWorkspace(size);
......
...@@ -106,7 +106,7 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -106,7 +106,7 @@ class CUDADeviceAPI final : public DeviceAPI {
void* AllocDataSpace(DGLContext ctx, void* AllocDataSpace(DGLContext ctx,
size_t nbytes, size_t nbytes,
size_t alignment, size_t alignment,
DGLType type_hint) final { DGLDataType type_hint) final {
SetDevice(ctx); SetDevice(ctx);
// Redirect to PyTorch's allocator when available. // Redirect to PyTorch's allocator when available.
TensorDispatcher* td = TensorDispatcher::Global(); TensorDispatcher* td = TensorDispatcher::Global();
...@@ -136,12 +136,12 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -136,12 +136,12 @@ class CUDADeviceAPI final : public DeviceAPI {
size_t size, size_t size,
DGLContext ctx_from, DGLContext ctx_from,
DGLContext ctx_to, DGLContext ctx_to,
DGLType type_hint, DGLDataType type_hint,
DGLStreamHandle stream) { DGLStreamHandle stream) {
cudaStream_t cu_stream = static_cast<cudaStream_t>(stream); cudaStream_t cu_stream = static_cast<cudaStream_t>(stream);
from = static_cast<const char*>(from) + from_offset; from = static_cast<const char*>(from) + from_offset;
to = static_cast<char*>(to) + to_offset; to = static_cast<char*>(to) + to_offset;
if (ctx_from.device_type == kDLGPU && ctx_to.device_type == kDLGPU) { if (ctx_from.device_type == kDGLCUDA && ctx_to.device_type == kDGLCUDA) {
CUDA_CALL(cudaSetDevice(ctx_from.device_id)); CUDA_CALL(cudaSetDevice(ctx_from.device_id));
if (ctx_from.device_id == ctx_to.device_id) { if (ctx_from.device_id == ctx_to.device_id) {
GPUCopy(from, to, size, cudaMemcpyDeviceToDevice, cu_stream); GPUCopy(from, to, size, cudaMemcpyDeviceToDevice, cu_stream);
...@@ -150,10 +150,10 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -150,10 +150,10 @@ class CUDADeviceAPI final : public DeviceAPI {
from, ctx_from.device_id, from, ctx_from.device_id,
size, cu_stream)); size, cu_stream));
} }
} else if (ctx_from.device_type == kDLGPU && ctx_to.device_type == kDLCPU) { } else if (ctx_from.device_type == kDGLCUDA && ctx_to.device_type == kDGLCPU) {
CUDA_CALL(cudaSetDevice(ctx_from.device_id)); CUDA_CALL(cudaSetDevice(ctx_from.device_id));
GPUCopy(from, to, size, cudaMemcpyDeviceToHost, cu_stream); GPUCopy(from, to, size, cudaMemcpyDeviceToHost, cu_stream);
} else if (ctx_from.device_type == kDLCPU && ctx_to.device_type == kDLGPU) { } else if (ctx_from.device_type == kDGLCPU && ctx_to.device_type == kDGLCUDA) {
CUDA_CALL(cudaSetDevice(ctx_to.device_id)); CUDA_CALL(cudaSetDevice(ctx_to.device_id));
GPUCopy(from, to, size, cudaMemcpyHostToDevice, cu_stream); GPUCopy(from, to, size, cudaMemcpyHostToDevice, cu_stream);
} else { } else {
...@@ -168,7 +168,7 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -168,7 +168,7 @@ class CUDADeviceAPI final : public DeviceAPI {
size_t size, size_t size,
DGLContext ctx_from, DGLContext ctx_from,
DGLContext ctx_to, DGLContext ctx_to,
DGLType type_hint) final { DGLDataType type_hint) final {
auto stream = GetStream(); auto stream = GetStream();
CopyDataFromTo(from, from_offset, to, to_offset, size, ctx_from, ctx_to, type_hint, stream); CopyDataFromTo(from, from_offset, to, to_offset, size, ctx_from, ctx_to, type_hint, stream);
} }
...@@ -269,7 +269,7 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -269,7 +269,7 @@ class CUDADeviceAPI final : public DeviceAPI {
return result; return result;
} }
void* AllocWorkspace(DGLContext ctx, size_t size, DGLType type_hint) final { void* AllocWorkspace(DGLContext ctx, size_t size, DGLDataType type_hint) final {
SetDevice(ctx); SetDevice(ctx);
// Redirect to PyTorch's allocator when available. // Redirect to PyTorch's allocator when available.
TensorDispatcher* td = TensorDispatcher::Global(); TensorDispatcher* td = TensorDispatcher::Global();
...@@ -313,7 +313,7 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -313,7 +313,7 @@ class CUDADeviceAPI final : public DeviceAPI {
typedef dmlc::ThreadLocalStore<CUDAThreadEntry> CUDAThreadStore; typedef dmlc::ThreadLocalStore<CUDAThreadEntry> CUDAThreadStore;
CUDAThreadEntry::CUDAThreadEntry() CUDAThreadEntry::CUDAThreadEntry()
: pool(kDLGPU, CUDADeviceAPI::Global()) { : pool(kDGLCUDA, CUDADeviceAPI::Global()) {
} }
CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() { CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() {
...@@ -328,7 +328,7 @@ cudaStream_t getCurrentCUDAStream() { ...@@ -328,7 +328,7 @@ cudaStream_t getCurrentCUDAStream() {
return nullptr; return nullptr;
} }
DGL_REGISTER_GLOBAL("device_api.gpu") DGL_REGISTER_GLOBAL("device_api.cuda")
.set_body([](DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
DeviceAPI* ptr = CUDADeviceAPI::Global().get(); DeviceAPI* ptr = CUDADeviceAPI::Global().get();
*rv = static_cast<void*>(ptr); *rv = static_cast<void*>(ptr);
......
...@@ -222,8 +222,8 @@ std::pair<IdArray, NDArray> SparsePush( ...@@ -222,8 +222,8 @@ std::pair<IdArray, NDArray> SparsePush(
0, 0,
send_prefix_host.size()*sizeof(*send_prefix.get()), send_prefix_host.size()*sizeof(*send_prefix.get()),
ctx, ctx,
DGLContext{kDLCPU, 0}, DGLContext{kDGLCPU, 0},
DGLType{kDLInt, sizeof(*send_prefix.get())*8, 1}); DGLDataType{kDGLInt, sizeof(*send_prefix.get())*8, 1});
send_prefix.free(); send_prefix.free();
CHECK_EQ(send_prefix_host.back(), num_in) << "Internal Error: " CHECK_EQ(send_prefix_host.back(), num_in) << "Internal Error: "
...@@ -260,8 +260,8 @@ std::pair<IdArray, NDArray> SparsePush( ...@@ -260,8 +260,8 @@ std::pair<IdArray, NDArray> SparsePush(
0, 0,
recv_prefix_host.size()*sizeof(*recv_prefix.get()), recv_prefix_host.size()*sizeof(*recv_prefix.get()),
ctx, ctx,
DGLContext{kDLCPU, 0}, DGLContext{kDGLCPU, 0},
DGLType{kDLInt, sizeof(*recv_prefix.get())*8, 1}); DGLDataType{kDGLInt, sizeof(*recv_prefix.get())*8, 1});
recv_prefix.free(); recv_prefix.free();
// use an event to track when copying is done // use an event to track when copying is done
...@@ -376,8 +376,8 @@ NDArray SparsePull( ...@@ -376,8 +376,8 @@ NDArray SparsePull(
0, 0,
request_prefix_host.size()*sizeof(*request_prefix.get()), request_prefix_host.size()*sizeof(*request_prefix.get()),
ctx, ctx,
DGLContext{kDLCPU, 0}, DGLContext{kDGLCPU, 0},
DGLType{kDLInt, sizeof(*request_prefix.get())*8, 1}); DGLDataType{kDGLInt, sizeof(*request_prefix.get())*8, 1});
request_prefix.free(); request_prefix.free();
CHECK_EQ(request_prefix_host.back(), num_in) << "Internal Error: " CHECK_EQ(request_prefix_host.back(), num_in) << "Internal Error: "
"request_prefix_host.back() = " << request_prefix_host.back() << "request_prefix_host.back() = " << request_prefix_host.back() <<
...@@ -411,8 +411,8 @@ NDArray SparsePull( ...@@ -411,8 +411,8 @@ NDArray SparsePull(
0, 0,
response_prefix_host.size()*sizeof(*response_prefix.get()), response_prefix_host.size()*sizeof(*response_prefix.get()),
ctx, ctx,
DGLContext{kDLCPU, 0}, DGLContext{kDGLCPU, 0},
DGLType{kDLInt, sizeof(*response_prefix.get())*8, 1}); DGLDataType{kDGLInt, sizeof(*response_prefix.get())*8, 1});
response_prefix.free(); response_prefix.free();
// use an event to track when copying is done // use an event to track when copying is done
...@@ -617,10 +617,10 @@ void NCCLCommunicator::AllToAllV( ...@@ -617,10 +617,10 @@ void NCCLCommunicator::AllToAllV(
int dev_id; int dev_id;
CUDA_CALL(cudaGetDevice(&dev_id)); CUDA_CALL(cudaGetDevice(&dev_id));
DGLContext ctx{kDLGPU, dev_id}; DGLContext ctx{kDGLCUDA, dev_id};
auto device = runtime::DeviceAPI::Get(ctx); auto device = runtime::DeviceAPI::Get(ctx);
auto dtype = DLDataTypeTraits<DType>::dtype; auto dtype = DGLDataTypeTraits<DType>::dtype;
// copy using the same stream (local current stream), no need to sync // copy using the same stream (local current stream), no need to sync
device->CopyDataFromTo(send, send_prefix[0], device->CopyDataFromTo(send, send_prefix[0],
...@@ -679,10 +679,10 @@ void NCCLCommunicator::AllToAll( ...@@ -679,10 +679,10 @@ void NCCLCommunicator::AllToAll(
#else #else
int dev_id; int dev_id;
CUDA_CALL(cudaGetDevice(&dev_id)); CUDA_CALL(cudaGetDevice(&dev_id));
DGLContext ctx{kDLGPU, dev_id}; DGLContext ctx{kDGLCUDA, dev_id};
auto device = runtime::DeviceAPI::Get(ctx); auto device = runtime::DeviceAPI::Get(ctx);
auto dtype = DLDataTypeTraits<IdType>::dtype; auto dtype = DGLDataTypeTraits<IdType>::dtype;
// copy using the same stream (local current stream), no need to sync // copy using the same stream (local current stream), no need to sync
device->CopyDataFromTo(send, 0, recv, 0, count, ctx, ctx, dtype); device->CopyDataFromTo(send, 0, recv, 0, count, ctx, ctx, dtype);
......
/*!
* Copyright (c) 2022 by Contributors
* \file src/runtime/dlpack_convert.cc
* \brief Conversion between NDArray and DLPack.
*/
#include <dgl/runtime/dlpack_convert.h>
#include <dlpack/dlpack.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/device_api.h>
#include "runtime_base.h"
// deleter for arrays used by DLPack exporter
extern "C" void NDArrayDLPackDeleter(DLManagedTensor* tensor);
namespace dgl {
namespace runtime {
void NDArrayDLPackDeleter(DLManagedTensor* tensor) {
static_cast<NDArray::Container*>(tensor->manager_ctx)->DecRef();
delete tensor;
}
inline DGLContext ToDGLContext(const DLDevice& device) {
DGLContext ctx;
ctx.device_type = static_cast<DGLDeviceType>(device.device_type);
ctx.device_id = device.device_id;
return ctx;
}
inline DGLDataType ToDGLDataType(const DLDataType& src) {
DGLDataType ret;
ret.code = src.code;
ret.bits = src.bits;
ret.lanes = src.lanes;
return ret;
}
inline DLDevice ToDLDevice(const DGLContext& ctx) {
DLDevice device;
device.device_type = static_cast<DLDeviceType>(ctx.device_type);
device.device_id = ctx.device_id;
return device;
}
inline DLDataType ToDLDataType(const DGLDataType& src) {
DLDataType ret;
ret.code = src.code;
ret.bits = src.bits;
ret.lanes = src.lanes;
return ret;
}
NDArray DLPackConvert::FromDLPack(DLManagedTensor* tensor) {
NDArray::Container* data = new NDArray::Container();
data->deleter = DLPackConvert::DLPackDeleter;
data->manager_ctx = tensor;
data->dl_tensor.data = tensor->dl_tensor.data;
data->dl_tensor.ctx = ToDGLContext(tensor->dl_tensor.device);
data->dl_tensor.ndim = tensor->dl_tensor.ndim;
data->dl_tensor.dtype = ToDGLDataType(tensor->dl_tensor.dtype);
data->dl_tensor.shape = tensor->dl_tensor.shape;
data->dl_tensor.strides = tensor->dl_tensor.strides;
data->dl_tensor.byte_offset = tensor->dl_tensor.byte_offset;
return NDArray(data);
}
void DLPackConvert::DLPackDeleter(NDArray::Container* ptr) {
// if the array is pinned by dgl, unpin it before freeing
if (ptr->pinned_by_dgl_)
NDArray::UnpinContainer(ptr);
DLManagedTensor* tensor = static_cast<DLManagedTensor*>(ptr->manager_ctx);
if (tensor->deleter != nullptr) {
(*tensor->deleter)(tensor);
}
delete ptr;
}
DLManagedTensor* ContainerToDLPack(NDArray::Container* from) {
CHECK(from != nullptr);
DLManagedTensor* ret = new DLManagedTensor();
ret->dl_tensor.data = from->dl_tensor.data;
ret->dl_tensor.device = ToDLDevice(from->dl_tensor.ctx);
ret->dl_tensor.ndim = from->dl_tensor.ndim;
ret->dl_tensor.dtype = ToDLDataType(from->dl_tensor.dtype);
ret->dl_tensor.shape = from->dl_tensor.shape;
ret->dl_tensor.strides = from->dl_tensor.strides;
ret->dl_tensor.byte_offset = from->dl_tensor.byte_offset;
ret->manager_ctx = from;
from->IncRef();
ret->deleter = NDArrayDLPackDeleter;
return ret;
}
DLManagedTensor* DLPackConvert::ToDLPack(const NDArray &from) {
return ContainerToDLPack(from.data_);
}
} // namespace runtime
} // namespace dgl
using namespace dgl::runtime;
void DGLDLManagedTensorCallDeleter(DLManagedTensor* dltensor) {
(*(dltensor->deleter))(dltensor);
}
inline bool IsAligned(const void* ptr, std::uintptr_t alignment) noexcept {
auto iptr = reinterpret_cast<std::uintptr_t>(ptr);
return !(iptr % alignment);
}
int DGLArrayFromDLPack(DLManagedTensor* from,
DGLArrayHandle* out) {
API_BEGIN();
*out = NDArray::Internal::MoveAsDGLArray(DLPackConvert::FromDLPack(from));
API_END();
}
int DGLArrayToDLPack(DGLArrayHandle from, DLManagedTensor** out,
int alignment) {
API_BEGIN();
auto* nd_container = reinterpret_cast<NDArray::Container*>(from);
DGLArray* nd = &(nd_container->dl_tensor);
// If the source DGLArray is not aligned, we should create a new aligned one
if (alignment != 0 && !IsAligned(nd->data, alignment)) {
std::vector<int64_t> shape_vec(nd->shape, nd->shape + nd->ndim);
NDArray copy_ndarray = NDArray::Empty(shape_vec, nd->dtype, nd->ctx);
copy_ndarray.CopyFrom(nd);
*out = DLPackConvert::ToDLPack(copy_ndarray);
} else {
*out = ContainerToDLPack(nd_container);
}
API_END();
}
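
The new dlpack_convert.cc above concentrates DLPack interop in a single translation unit outside the core NDArray code. A hedged round-trip sketch through its entry points (the caller and the external tensor source are hypothetical; only the conversion functions are from the diff):

// Hypothetical caller of the new conversion layer.
#include <dlpack/dlpack.h>
#include <dgl/runtime/dlpack_convert.h>
#include <dgl/runtime/ndarray.h>

dgl::runtime::NDArray ImportAndReexport(DLManagedTensor* external) {
  using dgl::runtime::DLPackConvert;
  // Zero-copy import: the container keeps `external` as manager_ctx, and
  // DLPackConvert::DLPackDeleter invokes its deleter when the NDArray is released.
  dgl::runtime::NDArray arr = DLPackConvert::FromDLPack(external);
  // Export back to DLPack: bumps the container refcount and installs
  // NDArrayDLPackDeleter so the consumer can drop its view independently.
  DLManagedTensor* exported = DLPackConvert::ToDLPack(arr);
  DGLDLManagedTensorCallDeleter(exported);  // consumer is done with the exported view
  return arr;
}
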
...@@ -17,7 +17,7 @@ namespace runtime { ...@@ -17,7 +17,7 @@ namespace runtime {
void FunctionInfo::Save(dmlc::JSONWriter* writer) const { void FunctionInfo::Save(dmlc::JSONWriter* writer) const {
std::vector<std::string> sarg_types(arg_types.size()); std::vector<std::string> sarg_types(arg_types.size());
for (size_t i = 0; i < arg_types.size(); ++i) { for (size_t i = 0; i < arg_types.size(); ++i) {
sarg_types[i] = DGLType2String(arg_types[i]); sarg_types[i] = DGLDataType2String(arg_types[i]);
} }
writer->BeginObject(); writer->BeginObject();
writer->WriteObjectKeyValue("name", name); writer->WriteObjectKeyValue("name", name);
...@@ -35,7 +35,7 @@ void FunctionInfo::Load(dmlc::JSONReader* reader) { ...@@ -35,7 +35,7 @@ void FunctionInfo::Load(dmlc::JSONReader* reader) {
helper.ReadAllFields(reader); helper.ReadAllFields(reader);
arg_types.resize(sarg_types.size()); arg_types.resize(sarg_types.size());
for (size_t i = 0; i < arg_types.size(); ++i) { for (size_t i = 0; i < arg_types.size(); ++i) {
arg_types[i] = String2DGLType(sarg_types[i]); arg_types[i] = String2DGLDataType(sarg_types[i]);
} }
} }
......
...@@ -19,7 +19,7 @@ namespace runtime { ...@@ -19,7 +19,7 @@ namespace runtime {
/*! \brief function information needed by device */ /*! \brief function information needed by device */
struct FunctionInfo { struct FunctionInfo {
std::string name; std::string name;
std::vector<DGLType> arg_types; std::vector<DGLDataType> arg_types;
std::vector<std::string> thread_axis_tags; std::vector<std::string> thread_axis_tags;
void Save(dmlc::JSONWriter *writer) const; void Save(dmlc::JSONWriter *writer) const;
......
...@@ -105,7 +105,7 @@ bool RuntimeEnabled(const std::string& target) { ...@@ -105,7 +105,7 @@ bool RuntimeEnabled(const std::string& target) {
if (target == "cpu") { if (target == "cpu") {
return true; return true;
} else if (target == "cuda" || target == "gpu") { } else if (target == "cuda" || target == "gpu") {
f_name = "device_api.gpu"; f_name = "device_api.cuda";
} else if (target == "cl" || target == "opencl" || target == "sdaccel") { } else if (target == "cl" || target == "opencl" || target == "sdaccel") {
f_name = "device_api.opencl"; f_name = "device_api.opencl";
} else if (target == "gl" || target == "opengl") { } else if (target == "gl" || target == "opengl") {
...@@ -121,7 +121,7 @@ bool RuntimeEnabled(const std::string& target) { ...@@ -121,7 +121,7 @@ bool RuntimeEnabled(const std::string& target) {
} else if (target == "vpi" || target == "verilog") { } else if (target == "vpi" || target == "verilog") {
f_name = "device_api.vpi"; f_name = "device_api.vpi";
} else if (target.length() >= 5 && target.substr(0, 5) == "nvptx") { } else if (target.length() >= 5 && target.substr(0, 5) == "nvptx") {
f_name = "device_api.gpu"; f_name = "device_api.cuda";
} else if (target.length() >= 4 && target.substr(0, 4) == "rocm") { } else if (target.length() >= 4 && target.substr(0, 4) == "rocm") {
f_name = "device_api.rocm"; f_name = "device_api.rocm";
} else if (target.length() >= 4 && target.substr(0, 4) == "llvm") { } else if (target.length() >= 4 && target.substr(0, 4) == "llvm") {
......
...@@ -13,28 +13,25 @@ ...@@ -13,28 +13,25 @@
#include <dgl/runtime/tensordispatch.h> #include <dgl/runtime/tensordispatch.h>
#include "runtime_base.h" #include "runtime_base.h"
// deleter for arrays used by DLPack exporter
extern "C" void NDArrayDLPackDeleter(DLManagedTensor* tensor);
namespace dgl { namespace dgl {
constexpr DLDataType DLDataTypeTraits<int8_t>::dtype; constexpr DGLDataType DGLDataTypeTraits<int8_t>::dtype;
constexpr DLDataType DLDataTypeTraits<int16_t>::dtype; constexpr DGLDataType DGLDataTypeTraits<int16_t>::dtype;
constexpr DLDataType DLDataTypeTraits<int32_t>::dtype; constexpr DGLDataType DGLDataTypeTraits<int32_t>::dtype;
constexpr DLDataType DLDataTypeTraits<int64_t>::dtype; constexpr DGLDataType DGLDataTypeTraits<int64_t>::dtype;
constexpr DLDataType DLDataTypeTraits<uint32_t>::dtype; constexpr DGLDataType DGLDataTypeTraits<uint32_t>::dtype;
constexpr DLDataType DLDataTypeTraits<uint64_t>::dtype; constexpr DGLDataType DGLDataTypeTraits<uint64_t>::dtype;
#ifdef USE_FP16 #ifdef USE_FP16
constexpr DLDataType DLDataTypeTraits<__half>::dtype; constexpr DGLDataType DGLDataTypeTraits<__half>::dtype;
#endif #endif
constexpr DLDataType DLDataTypeTraits<float>::dtype; constexpr DGLDataType DGLDataTypeTraits<float>::dtype;
constexpr DLDataType DLDataTypeTraits<double>::dtype; constexpr DGLDataType DGLDataTypeTraits<double>::dtype;
namespace runtime { namespace runtime {
inline void VerifyDataType(DLDataType dtype) { inline void VerifyDataType(DGLDataType dtype) {
CHECK_GE(dtype.lanes, 1); CHECK_GE(dtype.lanes, 1);
if (dtype.code == kDLFloat) { if (dtype.code == kDGLFloat) {
CHECK_EQ(dtype.bits % 8, 0); CHECK_EQ(dtype.bits % 8, 0);
} else { } else {
CHECK_EQ(dtype.bits % 8, 0); CHECK_EQ(dtype.bits % 8, 0);
...@@ -42,7 +39,7 @@ inline void VerifyDataType(DLDataType dtype) { ...@@ -42,7 +39,7 @@ inline void VerifyDataType(DLDataType dtype) {
CHECK_EQ(dtype.bits & (dtype.bits - 1), 0); CHECK_EQ(dtype.bits & (dtype.bits - 1), 0);
} }
inline size_t GetDataSize(const DLTensor& arr) { inline size_t GetDataSize(const DGLArray& arr) {
size_t size = 1; size_t size = 1;
for (dgl_index_t i = 0; i < arr.ndim; ++i) { for (dgl_index_t i = 0; i < arr.ndim; ++i) {
size *= arr.shape[i]; size *= arr.shape[i];
...@@ -51,91 +48,61 @@ inline size_t GetDataSize(const DLTensor& arr) { ...@@ -51,91 +48,61 @@ inline size_t GetDataSize(const DLTensor& arr) {
return size; return size;
} }
inline size_t GetDataAlignment(const DLTensor& arr) { inline size_t GetDataAlignment(const DGLArray& arr) {
size_t align = (arr.dtype.bits / 8) * arr.dtype.lanes; size_t align = (arr.dtype.bits / 8) * arr.dtype.lanes;
if (align < kAllocAlignment) return kAllocAlignment; if (align < kAllocAlignment) return kAllocAlignment;
return align; return align;
} }
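GetDataSize multiplies the shape by the element width in bytes, and GetDataAlignment never reports less than kAllocAlignment. A quick worked example for a 3x4 float32 array:

    // 3 * 4 elements * (32 bits / 8) * 1 lane = 48 bytes.
    NDArray a = NDArray::Empty({3, 4}, DGLDataType{kDGLFloat, 32, 1}, DGLContext{kDGLCPU, 0});
    size_t nbytes = a.GetSize();  // 48; the 4-byte natural alignment is rounded up to kAllocAlignment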
struct NDArray::Internal { void NDArray::Internal::DefaultDeleter(NDArray::Container* ptr) {
// Default deleter for the container using dgl::runtime::NDArray;
static void DefaultDeleter(NDArray::Container* ptr) { if (ptr->manager_ctx != nullptr) {
using dgl::runtime::NDArray; static_cast<NDArray::Container*>(ptr->manager_ctx)->DecRef();
if (ptr->manager_ctx != nullptr) { } else if (ptr->mem) {
static_cast<NDArray::Container*>(ptr->manager_ctx)->DecRef(); ptr->mem = nullptr;
} else if (ptr->mem) { } else if (ptr->dl_tensor.data != nullptr) {
ptr->mem = nullptr; // if the array is still pinned before freeing, unpin it.
} else if (ptr->dl_tensor.data != nullptr) {
// if the array is still pinned before freeing, unpin it.
if (ptr->pinned_by_dgl_)
UnpinContainer(ptr);
dgl::runtime::DeviceAPI::Get(ptr->dl_tensor.ctx)->FreeDataSpace(
ptr->dl_tensor.ctx, ptr->dl_tensor.data);
}
delete ptr;
}
// Deleter for NDArray converted from DLPack
// This is used from data which is passed from external DLPack(DLManagedTensor)
// that are not allocated inside of DGL.
// This enables us to create NDArray from memory allocated by other
// frameworks that are DLPack compatible
static void DLPackDeleter(NDArray::Container* ptr) {
// if the array is pinned by dgl, unpin it before freeing
if (ptr->pinned_by_dgl_) if (ptr->pinned_by_dgl_)
UnpinContainer(ptr); UnpinContainer(ptr);
DLManagedTensor* tensor = static_cast<DLManagedTensor*>(ptr->manager_ctx); dgl::runtime::DeviceAPI::Get(ptr->dl_tensor.ctx)->FreeDataSpace(
if (tensor->deleter != nullptr) { ptr->dl_tensor.ctx, ptr->dl_tensor.data);
(*tensor->deleter)(tensor);
}
delete ptr;
}
// Local create function which allocates tensor metadata
// but does not allocate space for the data.
static NDArray Create(std::vector<int64_t> shape,
DLDataType dtype,
DLContext ctx) {
VerifyDataType(dtype);
// critical zone
NDArray::Container* data = new NDArray::Container();
data->deleter = DefaultDeleter;
NDArray ret(data);
ret.data_ = data;
// RAII now in effect
// setup shape
data->shape_ = std::move(shape);
data->dl_tensor.shape = dmlc::BeginPtr(data->shape_);
data->dl_tensor.ndim = static_cast<int>(data->shape_.size());
// setup stride (this should be optional, but some framework
// does not support NULL stride and thus will crash the program).
data->stride_.resize(data->dl_tensor.ndim, 1);
for (int i = data->dl_tensor.ndim - 2; i >= 0; --i) {
data->stride_[i] = data->shape_[i+1] * data->stride_[i+1];
}
data->dl_tensor.strides = dmlc::BeginPtr(data->stride_);
// setup dtype
data->dl_tensor.dtype = dtype;
// setup ctx
data->dl_tensor.ctx = ctx;
return ret;
}
// Implementation of API function
static DLTensor* MoveAsDLTensor(NDArray arr) {
DLTensor* tensor = reinterpret_cast<DLTensor*>(arr.data_);
CHECK(tensor == const_cast<DLTensor*>(arr.operator->()));
arr.data_ = nullptr;
return tensor;
} }
// Container to DLManagedTensor delete ptr;
static DLManagedTensor* ToDLPack(NDArray::Container* from) { }
CHECK(from != nullptr);
DLManagedTensor* ret = new DLManagedTensor(); NDArray NDArray::Internal::Create(std::vector<int64_t> shape,
ret->dl_tensor = from->dl_tensor; DGLDataType dtype, DGLContext ctx) {
ret->manager_ctx = from; VerifyDataType(dtype);
from->IncRef(); // critical zone
ret->deleter = NDArrayDLPackDeleter; NDArray::Container* data = new NDArray::Container();
return ret; data->deleter = DefaultDeleter;
NDArray ret(data);
ret.data_ = data;
// RAII now in effect
// setup shape
data->shape_ = std::move(shape);
data->dl_tensor.shape = dmlc::BeginPtr(data->shape_);
data->dl_tensor.ndim = static_cast<int>(data->shape_.size());
// setup stride (this should be optional, but some framework
// does not support NULL stride and thus will crash the program).
data->stride_.resize(data->dl_tensor.ndim, 1);
for (int i = data->dl_tensor.ndim - 2; i >= 0; --i) {
data->stride_[i] = data->shape_[i+1] * data->stride_[i+1];
} }
}; data->dl_tensor.strides = dmlc::BeginPtr(data->stride_);
// setup dtype
data->dl_tensor.dtype = dtype;
// setup ctx
data->dl_tensor.ctx = ctx;
return ret;
}
DGLArray* NDArray::Internal::MoveAsDGLArray(NDArray arr) {
DGLArray* tensor = reinterpret_cast<DGLArray*>(arr.data_);
CHECK(tensor == const_cast<DGLArray*>(arr.operator->()));
arr.data_ = nullptr;
return tensor;
}
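Internal::Create fills in compact row-major strides even though they are conceptually optional, because some frameworks crash on a NULL strides pointer. For shape {2, 3, 4} the backward loop above yields strides {12, 4, 1}; the same computation in isolation:

    // Standalone version of the stride loop in Internal::Create.
    std::vector<int64_t> shape = {2, 3, 4};
    std::vector<int64_t> stride(shape.size(), 1);
    for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i)
      stride[i] = shape[i + 1] * stride[i + 1];
    // stride == {12, 4, 1}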
size_t NDArray::GetSize() const { size_t NDArray::GetSize() const {
return GetDataSize(data_->dl_tensor); return GetDataSize(data_->dl_tensor);
...@@ -170,7 +137,7 @@ bool NDArray::IsContiguous() const { ...@@ -170,7 +137,7 @@ bool NDArray::IsContiguous() const {
} }
NDArray NDArray::CreateView(std::vector<int64_t> shape, NDArray NDArray::CreateView(std::vector<int64_t> shape,
DLDataType dtype, DGLDataType dtype,
int64_t offset) { int64_t offset) {
CHECK(data_ != nullptr); CHECK(data_ != nullptr);
CHECK(IsContiguous()) << "Can only create view for compact tensor"; CHECK(IsContiguous()) << "Can only create view for compact tensor";
...@@ -189,14 +156,10 @@ NDArray NDArray::CreateView(std::vector<int64_t> shape, ...@@ -189,14 +156,10 @@ NDArray NDArray::CreateView(std::vector<int64_t> shape,
return ret; return ret;
} }
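CreateView reinterprets a contiguous tensor under a new shape, dtype, and offset without copying, which is why the IsContiguous() check above is required. A hedged sketch that flattens a 2x3 array into a length-6 view:

    // Sketch only: zero-copy reshape of a contiguous CPU tensor.
    NDArray arr = NDArray::Empty({2, 3}, DGLDataType{kDGLFloat, 32, 1}, DGLContext{kDGLCPU, 0});
    NDArray flat = arr.CreateView({6}, arr->dtype, /*offset=*/0);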
DLManagedTensor* NDArray::ToDLPack() const {
return Internal::ToDLPack(data_);
}
NDArray NDArray::EmptyShared(const std::string &name, NDArray NDArray::EmptyShared(const std::string &name,
std::vector<int64_t> shape, std::vector<int64_t> shape,
DLDataType dtype, DGLDataType dtype,
DLContext ctx, bool is_create) { DGLContext ctx, bool is_create) {
NDArray ret = Internal::Create(shape, dtype, ctx); NDArray ret = Internal::Create(shape, dtype, ctx);
// setup memory content // setup memory content
size_t size = GetDataSize(ret.data_->dl_tensor); size_t size = GetDataSize(ret.data_->dl_tensor);
...@@ -212,8 +175,8 @@ NDArray NDArray::EmptyShared(const std::string &name, ...@@ -212,8 +175,8 @@ NDArray NDArray::EmptyShared(const std::string &name,
} }
NDArray NDArray::Empty(std::vector<int64_t> shape, NDArray NDArray::Empty(std::vector<int64_t> shape,
DLDataType dtype, DGLDataType dtype,
DLContext ctx) { DGLContext ctx) {
NDArray ret = Internal::Create(shape, dtype, ctx); NDArray ret = Internal::Create(shape, dtype, ctx);
// setup memory content // setup memory content
size_t size = GetDataSize(ret.data_->dl_tensor); size_t size = GetDataSize(ret.data_->dl_tensor);
...@@ -225,30 +188,21 @@ NDArray NDArray::Empty(std::vector<int64_t> shape, ...@@ -225,30 +188,21 @@ NDArray NDArray::Empty(std::vector<int64_t> shape,
return ret; return ret;
} }
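Empty allocates uninitialized memory through the DeviceAPI matching ctx; callers are expected to fill it themselves. A minimal CPU usage sketch, mirroring call sites elsewhere in the codebase:

    // Allocate an uninitialized length-8 int64 array on CPU and fill it.
    NDArray ids = NDArray::Empty({8}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
    int64_t* data = static_cast<int64_t*>(ids->data);
    for (int64_t i = 0; i < 8; ++i) data[i] = i;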
NDArray NDArray::FromDLPack(DLManagedTensor* tensor) { void NDArray::CopyFromTo(DGLArray* from,
NDArray::Container* data = new NDArray::Container(); DGLArray* to) {
data->deleter = Internal::DLPackDeleter;
data->manager_ctx = tensor;
data->dl_tensor = tensor->dl_tensor;
return NDArray(data);
}
void NDArray::CopyFromTo(DLTensor* from,
DLTensor* to) {
size_t from_size = GetDataSize(*from); size_t from_size = GetDataSize(*from);
size_t to_size = GetDataSize(*to); size_t to_size = GetDataSize(*to);
CHECK_EQ(from_size, to_size) CHECK_EQ(from_size, to_size)
<< "DGLArrayCopyFromTo: The size must exactly match"; << "DGLArrayCopyFromTo: The size must exactly match";
CHECK(from->ctx.device_type == to->ctx.device_type CHECK(from->ctx.device_type == to->ctx.device_type
|| from->ctx.device_type == kDLCPU || from->ctx.device_type == kDGLCPU
|| to->ctx.device_type == kDLCPU) || to->ctx.device_type == kDGLCPU)
<< "Can not copy across different ctx types directly"; << "Can not copy across different ctx types directly";
// Use the context that is *not* a cpu context to get the correct device // Use the context that is *not* a cpu context to get the correct device
// api manager. // api manager.
DGLContext ctx = from->ctx.device_type != kDLCPU ? from->ctx : to->ctx; DGLContext ctx = from->ctx.device_type != kDGLCPU ? from->ctx : to->ctx;
// default: local current cuda stream // default: local current cuda stream
DeviceAPI::Get(ctx)->CopyDataFromTo( DeviceAPI::Get(ctx)->CopyDataFromTo(
...@@ -260,9 +214,9 @@ void NDArray::CopyFromTo(DLTensor* from, ...@@ -260,9 +214,9 @@ void NDArray::CopyFromTo(DLTensor* from,
void NDArray::PinContainer(NDArray::Container* ptr) { void NDArray::PinContainer(NDArray::Container* ptr) {
if (IsContainerPinned(ptr)) return; if (IsContainerPinned(ptr)) return;
auto* tensor = &(ptr->dl_tensor); auto* tensor = &(ptr->dl_tensor);
CHECK_EQ(tensor->ctx.device_type, kDLCPU) CHECK_EQ(tensor->ctx.device_type, kDGLCPU)
<< "Only NDArray on CPU can be pinned"; << "Only NDArray on CPU can be pinned";
DeviceAPI::Get(kDLGPU)->PinData(tensor->data, GetDataSize(*tensor)); DeviceAPI::Get(kDGLCUDA)->PinData(tensor->data, GetDataSize(*tensor));
ptr->pinned_by_dgl_ = true; ptr->pinned_by_dgl_ = true;
} }
...@@ -275,22 +229,22 @@ void NDArray::UnpinContainer(NDArray::Container* ptr) { ...@@ -275,22 +229,22 @@ void NDArray::UnpinContainer(NDArray::Container* ptr) {
// 1. not pinned, do nothing // 1. not pinned, do nothing
if (!container_is_pinned) return; if (!container_is_pinned) return;
// 2. pinned by DGL, unpin it // 2. pinned by DGL, unpin it
DeviceAPI::Get(kDLGPU)->UnpinData(ptr->dl_tensor.data); DeviceAPI::Get(kDGLCUDA)->UnpinData(ptr->dl_tensor.data);
ptr->pinned_by_dgl_ = false; ptr->pinned_by_dgl_ = false;
} }
void NDArray::RecordStream(DGLArray* tensor, DGLStreamHandle stream) { void NDArray::RecordStream(DGLArray* tensor, DGLStreamHandle stream) {
TensorDispatcher* td = TensorDispatcher::Global(); TensorDispatcher* td = TensorDispatcher::Global();
CHECK(td->IsAvailable()) << "RecordStream only works when TensorAdaptor is available."; CHECK(td->IsAvailable()) << "RecordStream only works when TensorAdaptor is available.";
CHECK_EQ(tensor->ctx.device_type, kDLGPU) CHECK_EQ(tensor->ctx.device_type, kDGLCUDA)
<< "RecordStream only works with GPU tensors."; << "RecordStream only works with GPU tensors.";
td->RecordStream(tensor->data, stream, tensor->ctx.device_id); td->RecordStream(tensor->data, stream, tensor->ctx.device_id);
} }
template<typename T> template<typename T>
NDArray NDArray::FromVector(const std::vector<T>& vec, DLContext ctx) { NDArray NDArray::FromVector(const std::vector<T>& vec, DGLContext ctx) {
const DLDataType dtype = DLDataTypeTraits<T>::dtype; const DGLDataType dtype = DGLDataTypeTraits<T>::dtype;
int64_t size = static_cast<int64_t>(vec.size()); int64_t size = static_cast<int64_t>(vec.size());
NDArray ret = NDArray::Empty({size}, dtype, ctx); NDArray ret = NDArray::Empty({size}, dtype, ctx);
DeviceAPI::Get(ctx)->CopyDataFromTo( DeviceAPI::Get(ctx)->CopyDataFromTo(
...@@ -299,29 +253,38 @@ NDArray NDArray::FromVector(const std::vector<T>& vec, DLContext ctx) { ...@@ -299,29 +253,38 @@ NDArray NDArray::FromVector(const std::vector<T>& vec, DLContext ctx) {
static_cast<T*>(ret->data), static_cast<T*>(ret->data),
0, 0,
size * sizeof(T), size * sizeof(T),
DLContext{kDLCPU, 0}, DGLContext{kDGLCPU, 0},
ctx, ctx,
dtype); dtype);
return ret; return ret;
} }
NDArray NDArray::CreateFromRaw(const std::vector<int64_t>& shape,
DGLDataType dtype, DGLContext ctx, void* raw, bool auto_free) {
NDArray ret = Internal::Create(shape, dtype, ctx);
ret.data_->dl_tensor.data = raw;
if (!auto_free)
ret.data_->deleter = nullptr;
return ret;
}
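CreateFromRaw is new here: it wraps caller-owned memory in an NDArray, and with auto_free == false the deleter is cleared so DGL never frees the pointer. A hedged sketch wrapping an existing host buffer (the buffer must outlive the array in this mode):

    // Sketch only: wrap externally owned memory without transferring ownership.
    static float buf[6] = {0, 1, 2, 3, 4, 5};
    NDArray view = NDArray::CreateFromRaw(
        {2, 3}, DGLDataType{kDGLFloat, 32, 1}, DGLContext{kDGLCPU, 0},
        buf, /*auto_free=*/false);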
// export specializations // export specializations
template NDArray NDArray::FromVector<int32_t>(const std::vector<int32_t>&, DLContext); template NDArray NDArray::FromVector<int32_t>(const std::vector<int32_t>&, DGLContext);
template NDArray NDArray::FromVector<int64_t>(const std::vector<int64_t>&, DLContext); template NDArray NDArray::FromVector<int64_t>(const std::vector<int64_t>&, DGLContext);
template NDArray NDArray::FromVector<uint32_t>(const std::vector<uint32_t>&, DLContext); template NDArray NDArray::FromVector<uint32_t>(const std::vector<uint32_t>&, DGLContext);
template NDArray NDArray::FromVector<uint64_t>(const std::vector<uint64_t>&, DLContext); template NDArray NDArray::FromVector<uint64_t>(const std::vector<uint64_t>&, DGLContext);
template NDArray NDArray::FromVector<float>(const std::vector<float>&, DLContext); template NDArray NDArray::FromVector<float>(const std::vector<float>&, DGLContext);
template NDArray NDArray::FromVector<double>(const std::vector<double>&, DLContext); template NDArray NDArray::FromVector<double>(const std::vector<double>&, DGLContext);
template<typename T> template<typename T>
std::vector<T> NDArray::ToVector() const { std::vector<T> NDArray::ToVector() const {
const DLDataType dtype = DLDataTypeTraits<T>::dtype; const DGLDataType dtype = DGLDataTypeTraits<T>::dtype;
CHECK(data_->dl_tensor.ndim == 1) << "ToVector() only supported for 1D arrays"; CHECK(data_->dl_tensor.ndim == 1) << "ToVector() only supported for 1D arrays";
CHECK(data_->dl_tensor.dtype == dtype) << "dtype mismatch"; CHECK(data_->dl_tensor.dtype == dtype) << "dtype mismatch";
int64_t size = data_->dl_tensor.shape[0]; int64_t size = data_->dl_tensor.shape[0];
std::vector<T> vec(size); std::vector<T> vec(size);
const DLContext &ctx = data_->dl_tensor.ctx; const DGLContext &ctx = data_->dl_tensor.ctx;
DeviceAPI::Get(ctx)->CopyDataFromTo( DeviceAPI::Get(ctx)->CopyDataFromTo(
static_cast<T*>(data_->dl_tensor.data), static_cast<T*>(data_->dl_tensor.data),
0, 0,
...@@ -329,7 +292,7 @@ std::vector<T> NDArray::ToVector() const { ...@@ -329,7 +292,7 @@ std::vector<T> NDArray::ToVector() const {
0, 0,
size * sizeof(T), size * sizeof(T),
ctx, ctx,
DLContext{kDLCPU, 0}, DGLContext{kDGLCPU, 0},
dtype); dtype);
return vec; return vec;
} }
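FromVector and ToVector copy through DeviceAPI in both directions, so they also work for non-CPU contexts; the explicit instantiations restrict T to the listed integer and floating-point types. A CPU round-trip sketch:

    // Round trip: std::vector -> NDArray -> std::vector.
    std::vector<int64_t> src = {1, 2, 3};
    NDArray arr = NDArray::FromVector(src, DGLContext{kDGLCPU, 0});
    std::vector<int64_t> back = arr.ToVector<int64_t>();  // {1, 2, 3}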
...@@ -350,10 +313,10 @@ bool NDArray::IsContainerPinned(NDArray::Container* ptr) { ...@@ -350,10 +313,10 @@ bool NDArray::IsContainerPinned(NDArray::Container* ptr) {
return true; return true;
auto* tensor = &(ptr->dl_tensor); auto* tensor = &(ptr->dl_tensor);
// Can only be pinned if on CPU... // Can only be pinned if on CPU...
if (tensor->ctx.device_type != kDLCPU) if (tensor->ctx.device_type != kDGLCPU)
return false; return false;
// ... and CUDA device API is enabled, and the tensor is indeed in pinned memory. // ... and CUDA device API is enabled, and the tensor is indeed in pinned memory.
auto device = DeviceAPI::Get(kDLGPU, true); auto device = DeviceAPI::Get(kDGLCUDA, true);
return device && device->IsPinned(tensor->data); return device && device->IsPinned(tensor->data);
} }
...@@ -363,7 +326,7 @@ void NDArray::Save(dmlc::Stream* strm) const { ...@@ -363,7 +326,7 @@ void NDArray::Save(dmlc::Stream* strm) const {
zc_strm->PushNDArray(*this); zc_strm->PushNDArray(*this);
return; return;
} }
SaveDLTensor(strm, const_cast<DLTensor*>(operator->())); SaveDGLArray(strm, const_cast<DGLArray*>(operator->()));
} }
bool NDArray::Load(dmlc::Stream* strm) { bool NDArray::Load(dmlc::Stream* strm) {
...@@ -374,26 +337,26 @@ bool NDArray::Load(dmlc::Stream* strm) { ...@@ -374,26 +337,26 @@ bool NDArray::Load(dmlc::Stream* strm) {
} }
uint64_t header, reserved; uint64_t header, reserved;
CHECK(strm->Read(&header)) CHECK(strm->Read(&header))
<< "Invalid DLTensor file format"; << "Invalid DGLArray file format";
CHECK(strm->Read(&reserved)) CHECK(strm->Read(&reserved))
<< "Invalid DLTensor file format"; << "Invalid DGLArray file format";
CHECK(header == kDGLNDArrayMagic) CHECK(header == kDGLNDArrayMagic)
<< "Invalid DLTensor file format"; << "Invalid DGLArray file format";
DLContext ctx; DGLContext ctx;
int ndim; int ndim;
DLDataType dtype; DGLDataType dtype;
CHECK(strm->Read(&ctx)) CHECK(strm->Read(&ctx))
<< "Invalid DLTensor file format"; << "Invalid DGLArray file format";
CHECK(strm->Read(&ndim)) CHECK(strm->Read(&ndim))
<< "Invalid DLTensor file format"; << "Invalid DGLArray file format";
CHECK(strm->Read(&dtype)) CHECK(strm->Read(&dtype))
<< "Invalid DLTensor file format"; << "Invalid DGLArray file format";
CHECK_EQ(ctx.device_type, kDLCPU) CHECK_EQ(ctx.device_type, kDGLCPU)
<< "Invalid DLTensor context: can only save as CPU tensor"; << "Invalid DGLArray context: can only save as CPU tensor";
std::vector<int64_t> shape(ndim); std::vector<int64_t> shape(ndim);
if (ndim != 0) { if (ndim != 0) {
CHECK(strm->ReadArray(&shape[0], ndim)) CHECK(strm->ReadArray(&shape[0], ndim))
<< "Invalid DLTensor file format"; << "Invalid DGLArray file format";
} }
NDArray ret = NDArray::Empty(shape, dtype, ctx); NDArray ret = NDArray::Empty(shape, dtype, ctx);
int64_t num_elems = 1; int64_t num_elems = 1;
...@@ -403,14 +366,14 @@ bool NDArray::Load(dmlc::Stream* strm) { ...@@ -403,14 +366,14 @@ bool NDArray::Load(dmlc::Stream* strm) {
} }
int64_t data_byte_size; int64_t data_byte_size;
CHECK(strm->Read(&data_byte_size)) CHECK(strm->Read(&data_byte_size))
<< "Invalid DLTensor file format"; << "Invalid DGLArray file format";
CHECK(data_byte_size == num_elems * elem_bytes) CHECK(data_byte_size == num_elems * elem_bytes)
<< "Invalid DLTensor file format"; << "Invalid DGLArray file format";
if (data_byte_size != 0) { if (data_byte_size != 0) {
// strm->Read will return the total number of elements successfully read. // strm->Read will return the total number of elements successfully read.
// Therefore if data_byte_size is zero, the CHECK below would fail. // Therefore if data_byte_size is zero, the CHECK below would fail.
CHECK(strm->Read(ret->data, data_byte_size)) CHECK(strm->Read(ret->data, data_byte_size))
<< "Invalid DLTensor file format"; << "Invalid DGLArray file format";
} }
if (!DMLC_IO_NO_ENDIAN_SWAP) { if (!DMLC_IO_NO_ENDIAN_SWAP) {
dmlc::ByteSwap(ret->data, elem_bytes, num_elems); dmlc::ByteSwap(ret->data, elem_bytes, num_elems);
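The checks in Load mirror the on-disk layout that SaveDGLArray writes. Summarized from the reads above:

    // Stream layout expected by NDArray::Load (reconstructed from the CHECKs above):
    //   uint64      header          == kDGLNDArrayMagic
    //   uint64      reserved
    //   DGLContext  ctx             (must be kDGLCPU)
    //   int         ndim
    //   DGLDataType dtype
    //   int64[ndim] shape
    //   int64       data_byte_size  == num_elems * elem_bytes
    //   byte[data_byte_size] data   (byte-swapped on big-endian hosts)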
...@@ -425,11 +388,6 @@ bool NDArray::Load(dmlc::Stream* strm) { ...@@ -425,11 +388,6 @@ bool NDArray::Load(dmlc::Stream* strm) {
using namespace dgl::runtime; using namespace dgl::runtime;
void NDArrayDLPackDeleter(DLManagedTensor* tensor) {
static_cast<NDArray::Container*>(tensor->manager_ctx)->DecRef();
delete tensor;
}
int DGLArrayAlloc(const dgl_index_t* shape, int DGLArrayAlloc(const dgl_index_t* shape,
int ndim, int ndim,
int dtype_code, int dtype_code,
...@@ -439,14 +397,14 @@ int DGLArrayAlloc(const dgl_index_t* shape, ...@@ -439,14 +397,14 @@ int DGLArrayAlloc(const dgl_index_t* shape,
int device_id, int device_id,
DGLArrayHandle* out) { DGLArrayHandle* out) {
API_BEGIN(); API_BEGIN();
DLDataType dtype; DGLDataType dtype;
dtype.code = static_cast<uint8_t>(dtype_code); dtype.code = static_cast<uint8_t>(dtype_code);
dtype.bits = static_cast<uint8_t>(dtype_bits); dtype.bits = static_cast<uint8_t>(dtype_bits);
dtype.lanes = static_cast<uint16_t>(dtype_lanes); dtype.lanes = static_cast<uint16_t>(dtype_lanes);
DLContext ctx; DGLContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type); ctx.device_type = static_cast<DGLDeviceType>(device_type);
ctx.device_id = device_id; ctx.device_id = device_id;
*out = NDArray::Internal::MoveAsDLTensor( *out = NDArray::Internal::MoveAsDGLArray(
NDArray::Empty(std::vector<int64_t>(shape, shape + ndim), dtype, ctx)); NDArray::Empty(std::vector<int64_t>(shape, shape + ndim), dtype, ctx));
API_END(); API_END();
} }
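DGLArrayAlloc is the C entry point behind the FFI: it rebuilds DGLDataType and DGLContext from plain integers and releases the container as a DGLArrayHandle. A hedged sketch allocating a length-4 int64 CPU array; the middle parameters are elided by the hunk above, so the conventional (shape, ndim, dtype_code, dtype_bits, dtype_lanes, device_type, device_id, out) ordering is assumed here:

    // Sketch only: allocate through the C API and check the return code.
    dgl_index_t shape[1] = {4};
    DGLArrayHandle handle;
    int rc = DGLArrayAlloc(shape, 1, kDGLInt, 64, 1, kDGLCPU, 0, &handle);
    // rc != 0 indicates failure; the handle is released later via the corresponding
    // free entry point (not shown in this diff).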
...@@ -460,14 +418,14 @@ int DGLArrayAllocSharedMem(const char *mem_name, ...@@ -460,14 +418,14 @@ int DGLArrayAllocSharedMem(const char *mem_name,
bool is_create, bool is_create,
DGLArrayHandle* out) { DGLArrayHandle* out) {
API_BEGIN(); API_BEGIN();
DLDataType dtype; DGLDataType dtype;
dtype.code = static_cast<uint8_t>(dtype_code); dtype.code = static_cast<uint8_t>(dtype_code);
dtype.bits = static_cast<uint8_t>(dtype_bits); dtype.bits = static_cast<uint8_t>(dtype_bits);
dtype.lanes = static_cast<uint16_t>(dtype_lanes); dtype.lanes = static_cast<uint16_t>(dtype_lanes);
std::vector<int64_t> shape_vec(shape, shape + ndim); std::vector<int64_t> shape_vec(shape, shape + ndim);
NDArray arr = NDArray::EmptyShared(mem_name, shape_vec, dtype, NDArray arr = NDArray::EmptyShared(mem_name, shape_vec, dtype,
DLContext{kDLCPU, 0}, is_create); DGLContext{kDGLCPU, 0}, is_create);
*out = NDArray::Internal::MoveAsDLTensor(arr); *out = NDArray::Internal::MoveAsDGLArray(arr);
API_END(); API_END();
} }
...@@ -484,44 +442,12 @@ int DGLArrayCopyFromTo(DGLArrayHandle from, ...@@ -484,44 +442,12 @@ int DGLArrayCopyFromTo(DGLArrayHandle from,
API_END(); API_END();
} }
int DGLArrayFromDLPack(DLManagedTensor* from,
DGLArrayHandle* out) {
API_BEGIN();
*out = NDArray::Internal::MoveAsDLTensor(NDArray::FromDLPack(from));
API_END();
}
inline bool is_aligned(const void* ptr, std::uintptr_t alignment) noexcept {
auto iptr = reinterpret_cast<std::uintptr_t>(ptr);
return !(iptr % alignment);
}
int DGLArrayToDLPack(DGLArrayHandle from, DLManagedTensor** out,
int alignment) {
API_BEGIN();
auto* nd_container = reinterpret_cast<NDArray::Container*>(from);
DLTensor* nd = &(nd_container->dl_tensor);
if (alignment != 0 && !is_aligned(nd->data, alignment)) {
std::vector<int64_t> shape_vec(nd->shape, nd->shape + nd->ndim);
NDArray copy_ndarray = NDArray::Empty(shape_vec, nd->dtype, nd->ctx);
copy_ndarray.CopyFrom(nd);
*out = copy_ndarray.ToDLPack();
} else {
*out = NDArray::Internal::ToDLPack(nd_container);
}
API_END();
}
void DGLDLManagedTensorCallDeleter(DLManagedTensor* dltensor) {
(*(dltensor->deleter))(dltensor);
}
int DGLArrayCopyFromBytes(DGLArrayHandle handle, int DGLArrayCopyFromBytes(DGLArrayHandle handle,
void* data, void* data,
size_t nbytes) { size_t nbytes) {
API_BEGIN(); API_BEGIN();
DGLContext cpu_ctx; DGLContext cpu_ctx;
cpu_ctx.device_type = kDLCPU; cpu_ctx.device_type = kDGLCPU;
cpu_ctx.device_id = 0; cpu_ctx.device_id = 0;
size_t arr_size = GetDataSize(*handle); size_t arr_size = GetDataSize(*handle);
CHECK_EQ(arr_size, nbytes) CHECK_EQ(arr_size, nbytes)
...@@ -538,7 +464,7 @@ int DGLArrayCopyToBytes(DGLArrayHandle handle, ...@@ -538,7 +464,7 @@ int DGLArrayCopyToBytes(DGLArrayHandle handle,
size_t nbytes) { size_t nbytes) {
API_BEGIN(); API_BEGIN();
DGLContext cpu_ctx; DGLContext cpu_ctx;
cpu_ctx.device_type = kDLCPU; cpu_ctx.device_type = kDGLCPU;
cpu_ctx.device_id = 0; cpu_ctx.device_id = 0;
size_t arr_size = GetDataSize(*handle); size_t arr_size = GetDataSize(*handle);
CHECK_EQ(arr_size, nbytes) CHECK_EQ(arr_size, nbytes)
...@@ -551,7 +477,7 @@ int DGLArrayCopyToBytes(DGLArrayHandle handle, ...@@ -551,7 +477,7 @@ int DGLArrayCopyToBytes(DGLArrayHandle handle,
} }
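CopyFromBytes and CopyToBytes stage the transfer through a synthetic CPU context, and the byte count must match GetDataSize exactly. A hedged sketch, assuming handle refers to a 4-element int64 CPU array (for example, one allocated as above):

    // Sketch only: raw byte round trip through an existing DGLArrayHandle.
    int64_t host_in[4] = {10, 20, 30, 40};
    DGLArrayCopyFromBytes(handle, host_in, sizeof(host_in));
    int64_t host_out[4];
    DGLArrayCopyToBytes(handle, host_out, sizeof(host_out));  // host_out == host_in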
int DGLArrayPinData(DGLArrayHandle handle, int DGLArrayPinData(DGLArrayHandle handle,
DLContext ctx) { DGLContext ctx) {
API_BEGIN(); API_BEGIN();
auto* nd_container = reinterpret_cast<NDArray::Container*>(handle); auto* nd_container = reinterpret_cast<NDArray::Container*>(handle);
NDArray::PinContainer(nd_container); NDArray::PinContainer(nd_container);
...@@ -559,7 +485,7 @@ int DGLArrayPinData(DGLArrayHandle handle, ...@@ -559,7 +485,7 @@ int DGLArrayPinData(DGLArrayHandle handle,
} }
int DGLArrayUnpinData(DGLArrayHandle handle, int DGLArrayUnpinData(DGLArrayHandle handle,
DLContext ctx) { DGLContext ctx) {
API_BEGIN(); API_BEGIN();
auto* nd_container = reinterpret_cast<NDArray::Container*>(handle); auto* nd_container = reinterpret_cast<NDArray::Container*>(handle);
NDArray::UnpinContainer(nd_container); NDArray::UnpinContainer(nd_container);
......
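DGLArrayPinData and DGLArrayUnpinData forward to Pin/UnpinContainer; the ctx argument is not inspected in the bodies shown, and pinning is only legal for CPU arrays. A hedged C-API sketch:

    // Sketch only: pin a CPU array so host-to-device copies can be asynchronous.
    DGLContext cpu_ctx{kDGLCPU, 0};
    DGLArrayPinData(handle, cpu_ctx);    // no-op if already pinned
    // ... issue async copies to the GPU ...
    DGLArrayUnpinData(handle, cpu_ctx);  // no-op unless pinned by DGL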
...@@ -39,7 +39,7 @@ union ArgUnion { ...@@ -39,7 +39,7 @@ union ArgUnion {
* \return The wrapped packed function. * \return The wrapped packed function.
*/ */
template<typename F> template<typename F>
inline PackedFunc PackFuncVoidAddr(F f, const std::vector<DGLType>& arg_types); inline PackedFunc PackFuncVoidAddr(F f, const std::vector<DGLDataType>& arg_types);
/*! /*!
 * \brief Create a packed function from a function, packing only the buffer arguments. * \brief Create a packed function from a function, packing only the buffer arguments.
* *
...@@ -50,7 +50,7 @@ inline PackedFunc PackFuncVoidAddr(F f, const std::vector<DGLType>& arg_types); ...@@ -50,7 +50,7 @@ inline PackedFunc PackFuncVoidAddr(F f, const std::vector<DGLType>& arg_types);
* \return The wrapped packed function. * \return The wrapped packed function.
*/ */
template<typename F> template<typename F>
inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<DGLType>& arg_types); inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<DGLDataType>& arg_types);
/*! /*!
 * \brief Create a packed function from a function that takes packed arguments. * \brief Create a packed function from a function that takes packed arguments.
* *
...@@ -61,13 +61,13 @@ inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<DGLType>& arg_type ...@@ -61,13 +61,13 @@ inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<DGLType>& arg_type
* \return The wrapped packed function. * \return The wrapped packed function.
*/ */
template<typename F> template<typename F>
inline PackedFunc PackFuncPackedArg(F f, const std::vector<DGLType>& arg_types); inline PackedFunc PackFuncPackedArg(F f, const std::vector<DGLDataType>& arg_types);
/*! /*!
 * \brief Extract the number of buffer arguments from the argument types. * \brief Extract the number of buffer arguments from the argument types.
* \param arg_types The argument types. * \param arg_types The argument types.
* \return number of buffer arguments * \return number of buffer arguments
*/ */
inline size_t NumBufferArgs(const std::vector<DGLType>& arg_types); inline size_t NumBufferArgs(const std::vector<DGLDataType>& arg_types);
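NumBufferArgs counts the leading kHandle (buffer) arguments and requires every non-handle argument to come after them, as the definition further below enforces. A small illustration, assuming brace construction of DGLDataType:

    // Illustration only: two leading buffer arguments followed by an int64 scalar.
    std::vector<DGLDataType> arg_types = {
        {kHandle, 64, 1}, {kHandle, 64, 1}, {kDGLInt, 64, 1}};
    size_t num_buffer = NumBufferArgs(arg_types);  // == 2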
// implementation details // implementation details
namespace detail { namespace detail {
...@@ -102,15 +102,15 @@ enum ArgConvertCode { ...@@ -102,15 +102,15 @@ enum ArgConvertCode {
HANDLE_TO_HANDLE HANDLE_TO_HANDLE
}; };
inline ArgConvertCode GetArgConvertCode(DGLType t) { inline ArgConvertCode GetArgConvertCode(DGLDataType t) {
CHECK_EQ(t.lanes, 1U) CHECK_EQ(t.lanes, 1U)
<< "Cannot pass vector type argument to devic function for now"; << "Cannot pass vector type argument to devic function for now";
if (t.code == kDLInt) { if (t.code == kDGLInt) {
if (t.bits == 64U) return INT64_TO_INT64; if (t.bits == 64U) return INT64_TO_INT64;
if (t.bits == 32U) return INT64_TO_INT32; if (t.bits == 32U) return INT64_TO_INT32;
} else if (t.code == kDLUInt) { } else if (t.code == kDGLUInt) {
if (t.bits == 32U) return INT64_TO_UINT32; if (t.bits == 32U) return INT64_TO_UINT32;
} else if (t.code == kDLFloat) { } else if (t.code == kDGLFloat) {
if (t.bits == 64U) return FLOAT64_TO_FLOAT64; if (t.bits == 64U) return FLOAT64_TO_FLOAT64;
if (t.bits == 32U) return FLOAT64_TO_FLOAT32; if (t.bits == 32U) return FLOAT64_TO_FLOAT32;
} else if (t.code == kHandle) { } else if (t.code == kHandle) {
...@@ -245,7 +245,7 @@ inline PackedFunc PackFuncPackedArg_( ...@@ -245,7 +245,7 @@ inline PackedFunc PackFuncPackedArg_(
} // namespace detail } // namespace detail
template<typename F> template<typename F>
inline PackedFunc PackFuncVoidAddr(F f, const std::vector<DGLType>& arg_types) { inline PackedFunc PackFuncVoidAddr(F f, const std::vector<DGLDataType>& arg_types) {
std::vector<detail::ArgConvertCode> codes(arg_types.size()); std::vector<detail::ArgConvertCode> codes(arg_types.size());
for (size_t i = 0; i < arg_types.size(); ++i) { for (size_t i = 0; i < arg_types.size(); ++i) {
codes[i] = detail::GetArgConvertCode(arg_types[i]); codes[i] = detail::GetArgConvertCode(arg_types[i]);
...@@ -261,7 +261,7 @@ inline PackedFunc PackFuncVoidAddr(F f, const std::vector<DGLType>& arg_types) { ...@@ -261,7 +261,7 @@ inline PackedFunc PackFuncVoidAddr(F f, const std::vector<DGLType>& arg_types) {
} }
} }
inline size_t NumBufferArgs(const std::vector<DGLType>& arg_types) { inline size_t NumBufferArgs(const std::vector<DGLDataType>& arg_types) {
size_t base = arg_types.size(); size_t base = arg_types.size();
for (size_t i = 0; i < arg_types.size(); ++i) { for (size_t i = 0; i < arg_types.size(); ++i) {
if (arg_types[i].code != kHandle) { if (arg_types[i].code != kHandle) {
...@@ -276,7 +276,7 @@ inline size_t NumBufferArgs(const std::vector<DGLType>& arg_types) { ...@@ -276,7 +276,7 @@ inline size_t NumBufferArgs(const std::vector<DGLType>& arg_types) {
} }
template<typename F> template<typename F>
inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<DGLType>& arg_types) { inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<DGLDataType>& arg_types) {
size_t num_buffer = NumBufferArgs(arg_types); size_t num_buffer = NumBufferArgs(arg_types);
std::vector<detail::ArgConvertCode> codes; std::vector<detail::ArgConvertCode> codes;
for (size_t i = num_buffer; i < arg_types.size(); ++i) { for (size_t i = num_buffer; i < arg_types.size(); ++i) {
...@@ -293,7 +293,7 @@ inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<DGLType>& arg_type ...@@ -293,7 +293,7 @@ inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<DGLType>& arg_type
} }
template<typename F> template<typename F>
inline PackedFunc PackFuncPackedArg(F f, const std::vector<DGLType>& arg_types) { inline PackedFunc PackFuncPackedArg(F f, const std::vector<DGLDataType>& arg_types) {
std::vector<detail::ArgConvertCode> codes; std::vector<detail::ArgConvertCode> codes;
for (size_t i = 0; i < arg_types.size(); ++i) { for (size_t i = 0; i < arg_types.size(); ++i) {
codes.push_back(detail::GetArgConvertCode(arg_types[i])); codes.push_back(detail::GetArgConvertCode(arg_types[i]));
......