Unverified Commit cded5b80 authored by Xin Yao's avatar Xin Yao Committed by GitHub
Browse files

[Feature] Bump DLPack to v0.7 and decouple DLPack from the core library (#4454)

* rename `DLContext` to `DGLContext`

* rename `kDLGPU` to `kDLCUDA`

* replace DLTensor with DGLArray

* fix linting

* Unify DGLType and DLDataType to DGLDataType

* Fix FFI

* rename DLDeviceType to DGLDeviceType

* decouple dlpack from the core library

* fix bug

* fix lint

* fix merge

* fix build

* address comments

* rename dl_converter to dlpack_convert

* remove redundant comments
parent f1689ad0
...@@ -89,7 +89,7 @@ __global__ void fps_kernel(const FloatType *array_data, const int64_t batch_size ...@@ -89,7 +89,7 @@ __global__ void fps_kernel(const FloatType *array_data, const int64_t batch_size
} }
} }
template <DLDeviceType XPU, typename FloatType, typename IdType> template <DGLDeviceType XPU, typename FloatType, typename IdType>
void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_points, void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_points,
NDArray dist, IdArray start_idx, IdArray result) { NDArray dist, IdArray start_idx, IdArray result) {
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
...@@ -115,16 +115,16 @@ void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_poin ...@@ -115,16 +115,16 @@ void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_poin
point_in_batch, dim, start_idx_data, dist_data, ret_data); point_in_batch, dim, start_idx_data, dist_data, ret_data);
} }
template void FarthestPointSampler<kDLGPU, float, int32_t>( template void FarthestPointSampler<kDGLCUDA, float, int32_t>(
NDArray array, int64_t batch_size, int64_t sample_points, NDArray array, int64_t batch_size, int64_t sample_points,
NDArray dist, IdArray start_idx, IdArray result); NDArray dist, IdArray start_idx, IdArray result);
template void FarthestPointSampler<kDLGPU, float, int64_t>( template void FarthestPointSampler<kDGLCUDA, float, int64_t>(
NDArray array, int64_t batch_size, int64_t sample_points, NDArray array, int64_t batch_size, int64_t sample_points,
NDArray dist, IdArray start_idx, IdArray result); NDArray dist, IdArray start_idx, IdArray result);
template void FarthestPointSampler<kDLGPU, double, int32_t>( template void FarthestPointSampler<kDGLCUDA, double, int32_t>(
NDArray array, int64_t batch_size, int64_t sample_points, NDArray array, int64_t batch_size, int64_t sample_points,
NDArray dist, IdArray start_idx, IdArray result); NDArray dist, IdArray start_idx, IdArray result);
template void FarthestPointSampler<kDLGPU, double, int64_t>( template void FarthestPointSampler<kDGLCUDA, double, int64_t>(
NDArray array, int64_t batch_size, int64_t sample_points, NDArray array, int64_t batch_size, int64_t sample_points,
NDArray dist, IdArray start_idx, IdArray result); NDArray dist, IdArray start_idx, IdArray result);
......
...@@ -12,7 +12,7 @@ namespace dgl { ...@@ -12,7 +12,7 @@ namespace dgl {
namespace geometry { namespace geometry {
namespace impl { namespace impl {
template <DLDeviceType XPU, typename FloatType, typename IdType> template <DGLDeviceType XPU, typename FloatType, typename IdType>
void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_points, void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_points,
NDArray dist, IdArray start_idx, IdArray result); NDArray dist, IdArray start_idx, IdArray result);
...@@ -21,7 +21,7 @@ void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_poin ...@@ -21,7 +21,7 @@ void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_poin
* picking an unmarked vertex and matching it with one its unmarked neighbors * picking an unmarked vertex and matching it with one its unmarked neighbors
* (that maximizes its edge weight) until no match can be done. * (that maximizes its edge weight) until no match can be done.
*/ */
template <DLDeviceType XPU, typename FloatType, typename IdType> template <DGLDeviceType XPU, typename FloatType, typename IdType>
void WeightedNeighborMatching(const aten::CSRMatrix &csr, const NDArray weight, IdArray result); void WeightedNeighborMatching(const aten::CSRMatrix &csr, const NDArray weight, IdArray result);
/*! \brief Implementation of neighbor matching process of edge coarsening used /*! \brief Implementation of neighbor matching process of edge coarsening used
...@@ -29,7 +29,7 @@ void WeightedNeighborMatching(const aten::CSRMatrix &csr, const NDArray weight, ...@@ -29,7 +29,7 @@ void WeightedNeighborMatching(const aten::CSRMatrix &csr, const NDArray weight,
* picking an unmarked vertex and matching it with one its unmarked neighbors * picking an unmarked vertex and matching it with one its unmarked neighbors
* (that maximizes its edge weight) until no match can be done. * (that maximizes its edge weight) until no match can be done.
*/ */
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
void NeighborMatching(const aten::CSRMatrix &csr, IdArray result); void NeighborMatching(const aten::CSRMatrix &csr, IdArray result);
} // namespace impl } // namespace impl
......
...@@ -178,7 +178,7 @@ IdArray Graph::Predecessors(dgl_id_t vid, uint64_t radius) const { ...@@ -178,7 +178,7 @@ IdArray Graph::Predecessors(dgl_id_t vid, uint64_t radius) const {
vset.insert(it); vset.insert(it);
const int64_t len = vset.size(); const int64_t len = vset.size();
IdArray rst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); IdArray rst = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
int64_t* rst_data = static_cast<int64_t*>(rst->data); int64_t* rst_data = static_cast<int64_t*>(rst->data);
std::copy(vset.begin(), vset.end(), rst_data); std::copy(vset.begin(), vset.end(), rst_data);
...@@ -195,7 +195,7 @@ IdArray Graph::Successors(dgl_id_t vid, uint64_t radius) const { ...@@ -195,7 +195,7 @@ IdArray Graph::Successors(dgl_id_t vid, uint64_t radius) const {
vset.insert(it); vset.insert(it);
const int64_t len = vset.size(); const int64_t len = vset.size();
IdArray rst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); IdArray rst = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
int64_t* rst_data = static_cast<int64_t*>(rst->data); int64_t* rst_data = static_cast<int64_t*>(rst->data);
std::copy(vset.begin(), vset.end(), rst_data); std::copy(vset.begin(), vset.end(), rst_data);
...@@ -216,7 +216,7 @@ IdArray Graph::EdgeId(dgl_id_t src, dgl_id_t dst) const { ...@@ -216,7 +216,7 @@ IdArray Graph::EdgeId(dgl_id_t src, dgl_id_t dst) const {
// FIXME: signed? Also it seems that we are using int64_t everywhere... // FIXME: signed? Also it seems that we are using int64_t everywhere...
const int64_t len = edgelist.size(); const int64_t len = edgelist.size();
IdArray rst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); IdArray rst = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
// FIXME: signed? // FIXME: signed?
int64_t* rst_data = static_cast<int64_t*>(rst->data); int64_t* rst_data = static_cast<int64_t*>(rst->data);
...@@ -301,9 +301,9 @@ EdgeArray Graph::FindEdges(IdArray eids) const { ...@@ -301,9 +301,9 @@ EdgeArray Graph::FindEdges(IdArray eids) const {
EdgeArray Graph::InEdges(dgl_id_t vid) const { EdgeArray Graph::InEdges(dgl_id_t vid) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid; CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
const int64_t len = reverse_adjlist_[vid].succ.size(); const int64_t len = reverse_adjlist_[vid].succ.size();
IdArray src = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); IdArray src = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray dst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); IdArray dst = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray eid = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); IdArray eid = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
int64_t* src_data = static_cast<int64_t*>(src->data); int64_t* src_data = static_cast<int64_t*>(src->data);
int64_t* dst_data = static_cast<int64_t*>(dst->data); int64_t* dst_data = static_cast<int64_t*>(dst->data);
int64_t* eid_data = static_cast<int64_t*>(eid->data); int64_t* eid_data = static_cast<int64_t*>(eid->data);
...@@ -347,9 +347,9 @@ EdgeArray Graph::InEdges(IdArray vids) const { ...@@ -347,9 +347,9 @@ EdgeArray Graph::InEdges(IdArray vids) const {
EdgeArray Graph::OutEdges(dgl_id_t vid) const { EdgeArray Graph::OutEdges(dgl_id_t vid) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid; CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
const int64_t len = adjlist_[vid].succ.size(); const int64_t len = adjlist_[vid].succ.size();
IdArray src = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); IdArray src = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray dst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); IdArray dst = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray eid = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); IdArray eid = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
int64_t* src_data = static_cast<int64_t*>(src->data); int64_t* src_data = static_cast<int64_t*>(src->data);
int64_t* dst_data = static_cast<int64_t*>(dst->data); int64_t* dst_data = static_cast<int64_t*>(dst->data);
int64_t* eid_data = static_cast<int64_t*>(eid->data); int64_t* eid_data = static_cast<int64_t*>(eid->data);
...@@ -392,9 +392,9 @@ EdgeArray Graph::OutEdges(IdArray vids) const { ...@@ -392,9 +392,9 @@ EdgeArray Graph::OutEdges(IdArray vids) const {
// O(E*log(E)) if sort is required; otherwise, O(E) // O(E*log(E)) if sort is required; otherwise, O(E)
EdgeArray Graph::Edges(const std::string &order) const { EdgeArray Graph::Edges(const std::string &order) const {
const int64_t len = num_edges_; const int64_t len = num_edges_;
IdArray src = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); IdArray src = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray dst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); IdArray dst = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray eid = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); IdArray eid = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
if (order == "srcdst") { if (order == "srcdst") {
typedef std::tuple<int64_t, int64_t, int64_t> Tuple; typedef std::tuple<int64_t, int64_t, int64_t> Tuple;
...@@ -553,8 +553,8 @@ std::vector<IdArray> Graph::GetAdj(bool transpose, const std::string &fmt) const ...@@ -553,8 +553,8 @@ std::vector<IdArray> Graph::GetAdj(bool transpose, const std::string &fmt) const
if (fmt == "coo") { if (fmt == "coo") {
IdArray idx = IdArray::Empty( IdArray idx = IdArray::Empty(
{2 * static_cast<int64_t>(num_edges)}, {2 * static_cast<int64_t>(num_edges)},
DLDataType{kDLInt, 64, 1}, DGLDataType{kDGLInt, 64, 1},
DLContext{kDLCPU, 0}); DGLContext{kDGLCPU, 0});
int64_t *idx_data = static_cast<int64_t*>(idx->data); int64_t *idx_data = static_cast<int64_t*>(idx->data);
if (transpose) { if (transpose) {
std::copy(all_edges_src_.begin(), all_edges_src_.end(), idx_data); std::copy(all_edges_src_.begin(), all_edges_src_.end(), idx_data);
...@@ -565,8 +565,8 @@ std::vector<IdArray> Graph::GetAdj(bool transpose, const std::string &fmt) const ...@@ -565,8 +565,8 @@ std::vector<IdArray> Graph::GetAdj(bool transpose, const std::string &fmt) const
} }
IdArray eid = IdArray::Empty( IdArray eid = IdArray::Empty(
{static_cast<int64_t>(num_edges)}, {static_cast<int64_t>(num_edges)},
DLDataType{kDLInt, 64, 1}, DGLDataType{kDGLInt, 64, 1},
DLContext{kDLCPU, 0}); DGLContext{kDGLCPU, 0});
int64_t *eid_data = static_cast<int64_t*>(eid->data); int64_t *eid_data = static_cast<int64_t*>(eid->data);
for (uint64_t eid = 0; eid < num_edges; ++eid) { for (uint64_t eid = 0; eid < num_edges; ++eid) {
eid_data[eid] = eid; eid_data[eid] = eid;
...@@ -575,16 +575,16 @@ std::vector<IdArray> Graph::GetAdj(bool transpose, const std::string &fmt) const ...@@ -575,16 +575,16 @@ std::vector<IdArray> Graph::GetAdj(bool transpose, const std::string &fmt) const
} else if (fmt == "csr") { } else if (fmt == "csr") {
IdArray indptr = IdArray::Empty( IdArray indptr = IdArray::Empty(
{static_cast<int64_t>(num_nodes) + 1}, {static_cast<int64_t>(num_nodes) + 1},
DLDataType{kDLInt, 64, 1}, DGLDataType{kDGLInt, 64, 1},
DLContext{kDLCPU, 0}); DGLContext{kDGLCPU, 0});
IdArray indices = IdArray::Empty( IdArray indices = IdArray::Empty(
{static_cast<int64_t>(num_edges)}, {static_cast<int64_t>(num_edges)},
DLDataType{kDLInt, 64, 1}, DGLDataType{kDGLInt, 64, 1},
DLContext{kDLCPU, 0}); DGLContext{kDGLCPU, 0});
IdArray eid = IdArray::Empty( IdArray eid = IdArray::Empty(
{static_cast<int64_t>(num_edges)}, {static_cast<int64_t>(num_edges)},
DLDataType{kDLInt, 64, 1}, DGLDataType{kDGLInt, 64, 1},
DLContext{kDLCPU, 0}); DGLContext{kDGLCPU, 0});
int64_t *indptr_data = static_cast<int64_t*>(indptr->data); int64_t *indptr_data = static_cast<int64_t*>(indptr->data);
int64_t *indices_data = static_cast<int64_t*>(indices->data); int64_t *indices_data = static_cast<int64_t*>(indices->data);
int64_t *eid_data = static_cast<int64_t*>(eid->data); int64_t *eid_data = static_cast<int64_t*>(eid->data);
......
...@@ -47,7 +47,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreate") ...@@ -47,7 +47,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreate")
const std::string edge_dir = args[2]; const std::string edge_dir = args[2];
IdArray edge_ids = IdArray::Empty({indices->shape[0]}, IdArray edge_ids = IdArray::Empty({indices->shape[0]},
DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
int64_t *edge_data = static_cast<int64_t *>(edge_ids->data); int64_t *edge_data = static_cast<int64_t *>(edge_ids->data);
for (int64_t i = 0; i < edge_ids->shape[0]; i++) for (int64_t i = 0; i < edge_ids->shape[0]; i++)
edge_data[i] = i; edge_data[i] = i;
......
...@@ -143,7 +143,7 @@ GraphPtr GraphOp::DisjointUnion(std::vector<GraphPtr> graphs) { ...@@ -143,7 +143,7 @@ GraphPtr GraphOp::DisjointUnion(std::vector<GraphPtr> graphs) {
std::vector<GraphPtr> GraphOp::DisjointPartitionByNum(GraphPtr graph, int64_t num) { std::vector<GraphPtr> GraphOp::DisjointPartitionByNum(GraphPtr graph, int64_t num) {
CHECK(num != 0 && graph->NumVertices() % num == 0) CHECK(num != 0 && graph->NumVertices() % num == 0)
<< "Number of partitions must evenly divide the number of nodes."; << "Number of partitions must evenly divide the number of nodes.";
IdArray sizes = IdArray::Empty({num}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); IdArray sizes = IdArray::Empty({num}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
int64_t* sizes_data = static_cast<int64_t*>(sizes->data); int64_t* sizes_data = static_cast<int64_t*>(sizes->data);
std::fill(sizes_data, sizes_data + num, graph->NumVertices() / num); std::fill(sizes_data, sizes_data + num, graph->NumVertices() / num);
return DisjointPartitionBySizes(graph, sizes); return DisjointPartitionBySizes(graph, sizes);
...@@ -257,7 +257,7 @@ IdArray GraphOp::MapParentIdToSubgraphId(IdArray parent_vids, IdArray query) { ...@@ -257,7 +257,7 @@ IdArray GraphOp::MapParentIdToSubgraphId(IdArray parent_vids, IdArray query) {
const auto query_len = query->shape[0]; const auto query_len = query->shape[0];
const dgl_id_t* parent_data = static_cast<dgl_id_t*>(parent_vids->data); const dgl_id_t* parent_data = static_cast<dgl_id_t*>(parent_vids->data);
const dgl_id_t* query_data = static_cast<dgl_id_t*>(query->data); const dgl_id_t* query_data = static_cast<dgl_id_t*>(query->data);
IdArray rst = IdArray::Empty({query_len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); IdArray rst = IdArray::Empty({query_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
dgl_id_t* rst_data = static_cast<dgl_id_t*>(rst->data); dgl_id_t* rst_data = static_cast<dgl_id_t*>(rst->data);
const bool is_sorted = std::is_sorted(parent_data, parent_data + parent_len); const bool is_sorted = std::is_sorted(parent_data, parent_data + parent_len);
...@@ -303,7 +303,7 @@ IdArray GraphOp::ExpandIds(IdArray ids, IdArray offset) { ...@@ -303,7 +303,7 @@ IdArray GraphOp::ExpandIds(IdArray ids, IdArray offset) {
const dgl_id_t *id_data = static_cast<dgl_id_t*>(ids->data); const dgl_id_t *id_data = static_cast<dgl_id_t*>(ids->data);
const dgl_id_t *off_data = static_cast<dgl_id_t*>(offset->data); const dgl_id_t *off_data = static_cast<dgl_id_t*>(offset->data);
const int64_t len = off_data[off_len - 1]; const int64_t len = off_data[off_len - 1];
IdArray rst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); IdArray rst = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
dgl_id_t *rst_data = static_cast<dgl_id_t*>(rst->data); dgl_id_t *rst_data = static_cast<dgl_id_t*>(rst->data);
for (int64_t i = 0; i < id_len; i++) { for (int64_t i = 0; i < id_len; i++) {
const int64_t local_len = off_data[i + 1] - off_data[i]; const int64_t local_len = off_data[i + 1] - off_data[i];
...@@ -482,8 +482,10 @@ HaloSubgraph GraphOp::GetSubgraphWithHalo(GraphPtr g, IdArray nodes, int num_hop ...@@ -482,8 +482,10 @@ HaloSubgraph GraphOp::GetSubgraphWithHalo(GraphPtr g, IdArray nodes, int num_hop
} }
num_edges = edge_src.size(); num_edges = edge_src.size();
IdArray new_src = IdArray::Empty({num_edges}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); IdArray new_src = IdArray::Empty({num_edges}, DGLDataType{kDGLInt, 64, 1},
IdArray new_dst = IdArray::Empty({num_edges}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); DGLContext{kDGLCPU, 0});
IdArray new_dst = IdArray::Empty({num_edges}, DGLDataType{kDGLInt, 64, 1},
DGLContext{kDGLCPU, 0});
dgl_id_t *new_src_data = static_cast<dgl_id_t *>(new_src->data); dgl_id_t *new_src_data = static_cast<dgl_id_t *>(new_src->data);
dgl_id_t *new_dst_data = static_cast<dgl_id_t *>(new_dst->data); dgl_id_t *new_dst_data = static_cast<dgl_id_t *>(new_dst->data);
for (size_t i = 0; i < edge_src.size(); i++) { for (size_t i = 0; i < edge_src.size(); i++) {
......
...@@ -252,7 +252,7 @@ HeteroGraphPtr HeteroGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) { ...@@ -252,7 +252,7 @@ HeteroGraphPtr HeteroGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
hgindex->num_verts_per_type_)); hgindex->num_verts_per_type_));
} }
HeteroGraphPtr HeteroGraph::CopyTo(HeteroGraphPtr g, const DLContext &ctx) { HeteroGraphPtr HeteroGraph::CopyTo(HeteroGraphPtr g, const DGLContext &ctx) {
if (ctx == g->Context()) { if (ctx == g->Context()) {
return g; return g;
} }
......
...@@ -50,11 +50,11 @@ class HeteroGraph : public BaseHeteroGraph { ...@@ -50,11 +50,11 @@ class HeteroGraph : public BaseHeteroGraph {
LOG(FATAL) << "Bipartite graph is not mutable."; LOG(FATAL) << "Bipartite graph is not mutable.";
} }
DLDataType DataType() const override { DGLDataType DataType() const override {
return relation_graphs_[0]->DataType(); return relation_graphs_[0]->DataType();
} }
DLContext Context() const override { DGLContext Context() const override {
return relation_graphs_[0]->Context(); return relation_graphs_[0]->Context();
} }
...@@ -229,15 +229,15 @@ class HeteroGraph : public BaseHeteroGraph { ...@@ -229,15 +229,15 @@ class HeteroGraph : public BaseHeteroGraph {
static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits); static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);
/*! \brief Copy the data to another context */ /*! \brief Copy the data to another context */
static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DLContext &ctx); static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DGLContext &ctx);
/*! /*!
* \brief Pin all relation graphs of the current graph. * \brief Pin all relation graphs of the current graph.
* \note The graph will be pinned inplace. Behavior depends on the current context, * \note The graph will be pinned inplace. Behavior depends on the current context,
* kDLCPU: will be pinned; * kDGLCPU: will be pinned;
* IsPinned: directly return; * IsPinned: directly return;
* kDLGPU: invalid, will throw an error. * kDGLCUDA: invalid, will throw an error.
* The context check is deferred to pinning the NDArray. * The context check is deferred to pinning the NDArray.
*/ */
void PinMemory_() override; void PinMemory_() override;
......
...@@ -470,8 +470,8 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCopyTo") ...@@ -470,8 +470,8 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCopyTo")
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
int device_type = args[1]; int device_type = args[1];
int device_id = args[2]; int device_id = args[2];
DLContext ctx; DGLContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type); ctx.device_type = static_cast<DGLDeviceType>(device_type);
ctx.device_id = device_id; ctx.device_id = device_id;
HeteroGraphPtr hg_new = HeteroGraph::CopyTo(hg.sptr(), ctx); HeteroGraphPtr hg_new = HeteroGraph::CopyTo(hg.sptr(), ctx);
*rv = HeteroGraphRef(hg_new); *rv = HeteroGraphRef(hg_new);
...@@ -550,7 +550,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroJointUnion") ...@@ -550,7 +550,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroJointUnion")
std::vector<HeteroGraphPtr> component_ptrs; std::vector<HeteroGraphPtr> component_ptrs;
component_ptrs.reserve(component_graphs.size()); component_ptrs.reserve(component_graphs.size());
const int64_t bits = component_graphs[0]->NumBits(); const int64_t bits = component_graphs[0]->NumBits();
const DLContext ctx = component_graphs[0]->Context(); const DGLContext ctx = component_graphs[0]->Context();
for (const auto& component : component_graphs) { for (const auto& component : component_graphs) {
component_ptrs.push_back(component.sptr()); component_ptrs.push_back(component.sptr());
CHECK_EQ(component->NumBits(), bits) CHECK_EQ(component->NumBits(), bits)
...@@ -574,7 +574,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointUnion_v2") ...@@ -574,7 +574,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointUnion_v2")
std::vector<HeteroGraphPtr> component_ptrs; std::vector<HeteroGraphPtr> component_ptrs;
component_ptrs.reserve(component_graphs.size()); component_ptrs.reserve(component_graphs.size());
const int64_t bits = component_graphs[0]->NumBits(); const int64_t bits = component_graphs[0]->NumBits();
const DLContext ctx = component_graphs[0]->Context(); const DGLContext ctx = component_graphs[0]->Context();
for (const auto& component : component_graphs) { for (const auto& component : component_graphs) {
component_ptrs.push_back(component.sptr()); component_ptrs.push_back(component.sptr());
CHECK_EQ(component->NumBits(), bits) CHECK_EQ(component->NumBits(), bits)
...@@ -723,9 +723,9 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLHeteroSortOutEdges") ...@@ -723,9 +723,9 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLHeteroSortOutEdges")
NDArray tag = args[1]; NDArray tag = args[1];
int64_t num_tag = args[2]; int64_t num_tag = args[2];
CHECK_EQ(hg->Context().device_type, kDLCPU) << "Only support sorting by tag on cpu"; CHECK_EQ(hg->Context().device_type, kDGLCPU) << "Only support sorting by tag on cpu";
CHECK(aten::IsValidIdArray(tag)); CHECK(aten::IsValidIdArray(tag));
CHECK_EQ(tag->ctx.device_type, kDLCPU) << "Only support sorting by tag on cpu"; CHECK_EQ(tag->ctx.device_type, kDGLCPU) << "Only support sorting by tag on cpu";
const auto csr = hg->GetCSRMatrix(0); const auto csr = hg->GetCSRMatrix(0);
...@@ -745,9 +745,9 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLHeteroSortInEdges") ...@@ -745,9 +745,9 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLHeteroSortInEdges")
NDArray tag = args[1]; NDArray tag = args[1];
int64_t num_tag = args[2]; int64_t num_tag = args[2];
CHECK_EQ(hg->Context().device_type, kDLCPU) << "Only support sorting by tag on cpu"; CHECK_EQ(hg->Context().device_type, kDGLCPU) << "Only support sorting by tag on cpu";
CHECK(aten::IsValidIdArray(tag)); CHECK(aten::IsValidIdArray(tag));
CHECK_EQ(tag->ctx.device_type, kDLCPU) << "Only support sorting by tag on cpu"; CHECK_EQ(tag->ctx.device_type, kDGLCPU) << "Only support sorting by tag on cpu";
const auto csc = hg->GetCSCMatrix(0); const auto csc = hg->GetCSCMatrix(0);
......
...@@ -51,8 +51,8 @@ NDArray SerializeMetadata(ImmutableGraphPtr gidx, const std::string &name) { ...@@ -51,8 +51,8 @@ NDArray SerializeMetadata(ImmutableGraphPtr gidx, const std::string &name) {
meta.has_out_csr = gidx->HasOutCSR(); meta.has_out_csr = gidx->HasOutCSR();
meta.has_coo = false; meta.has_coo = false;
NDArray meta_arr = NDArray::EmptyShared(name, {sizeof(meta)}, DLDataType{kDLInt, 8, 1}, NDArray meta_arr = NDArray::EmptyShared(name, {sizeof(meta)}, DGLDataType{kDGLInt, 8, 1},
DLContext{kDLCPU, 0}, true); DGLContext{kDGLCPU, 0}, true);
memcpy(meta_arr->data, &meta, sizeof(meta)); memcpy(meta_arr->data, &meta, sizeof(meta));
return meta_arr; return meta_arr;
#else #else
...@@ -67,8 +67,8 @@ NDArray SerializeMetadata(ImmutableGraphPtr gidx, const std::string &name) { ...@@ -67,8 +67,8 @@ NDArray SerializeMetadata(ImmutableGraphPtr gidx, const std::string &name) {
GraphIndexMetadata DeserializeMetadata(const std::string &name) { GraphIndexMetadata DeserializeMetadata(const std::string &name) {
GraphIndexMetadata meta; GraphIndexMetadata meta;
#ifndef _WIN32 #ifndef _WIN32
NDArray meta_arr = NDArray::EmptyShared(name, {sizeof(meta)}, DLDataType{kDLInt, 8, 1}, NDArray meta_arr = NDArray::EmptyShared(name, {sizeof(meta)}, DGLDataType{kDGLInt, 8, 1},
DLContext{kDLCPU, 0}, false); DGLContext{kDGLCPU, 0}, false);
memcpy(&meta, meta_arr->data, sizeof(meta)); memcpy(&meta, meta_arr->data, sizeof(meta));
#else #else
LOG(FATAL) << "CSR graph doesn't support shared memory in Windows yet"; LOG(FATAL) << "CSR graph doesn't support shared memory in Windows yet";
...@@ -82,13 +82,13 @@ std::tuple<IdArray, IdArray, IdArray> MapFromSharedMemory( ...@@ -82,13 +82,13 @@ std::tuple<IdArray, IdArray, IdArray> MapFromSharedMemory(
const int64_t file_size = (num_verts + 1 + num_edges * 2) * sizeof(dgl_id_t); const int64_t file_size = (num_verts + 1 + num_edges * 2) * sizeof(dgl_id_t);
IdArray sm_array = IdArray::EmptyShared( IdArray sm_array = IdArray::EmptyShared(
shared_mem_name, {file_size}, DLDataType{kDLInt, 8, 1}, DLContext{kDLCPU, 0}, is_create); shared_mem_name, {file_size}, DGLDataType{kDGLInt, 8, 1}, DGLContext{kDGLCPU, 0}, is_create);
// Create views from the shared memory array. Note that we don't need to save // Create views from the shared memory array. Note that we don't need to save
// the sm_array because the refcount is maintained by the view arrays. // the sm_array because the refcount is maintained by the view arrays.
IdArray indptr = sm_array.CreateView({num_verts + 1}, DLDataType{kDLInt, 64, 1}); IdArray indptr = sm_array.CreateView({num_verts + 1}, DGLDataType{kDGLInt, 64, 1});
IdArray indices = sm_array.CreateView({num_edges}, DLDataType{kDLInt, 64, 1}, IdArray indices = sm_array.CreateView({num_edges}, DGLDataType{kDGLInt, 64, 1},
(num_verts + 1) * sizeof(dgl_id_t)); (num_verts + 1) * sizeof(dgl_id_t));
IdArray edge_ids = sm_array.CreateView({num_edges}, DLDataType{kDLInt, 64, 1}, IdArray edge_ids = sm_array.CreateView({num_edges}, DGLDataType{kDGLInt, 64, 1},
(num_verts + 1 + num_edges) * sizeof(dgl_id_t)); (num_verts + 1 + num_edges) * sizeof(dgl_id_t));
return std::make_tuple(indptr, indices, edge_ids); return std::make_tuple(indptr, indices, edge_ids);
#else #else
...@@ -239,7 +239,7 @@ COOPtr CSR::ToCOO() const { ...@@ -239,7 +239,7 @@ COOPtr CSR::ToCOO() const {
return COOPtr(new COO(NumVertices(), coo.row, coo.col)); return COOPtr(new COO(NumVertices(), coo.row, coo.col));
} }
CSR CSR::CopyTo(const DLContext& ctx) const { CSR CSR::CopyTo(const DGLContext& ctx) const {
if (Context() == ctx) { if (Context() == ctx) {
return *this; return *this;
} else { } else {
...@@ -370,7 +370,7 @@ CSRPtr COO::ToCSR() const { ...@@ -370,7 +370,7 @@ CSRPtr COO::ToCSR() const {
return CSRPtr(new CSR(csr.indptr, csr.indices, csr.data)); return CSRPtr(new CSR(csr.indptr, csr.indices, csr.data));
} }
COO COO::CopyTo(const DLContext& ctx) const { COO COO::CopyTo(const DGLContext& ctx) const {
if (Context() == ctx) { if (Context() == ctx) {
return *this; return *this;
} else { } else {
...@@ -556,7 +556,7 @@ ImmutableGraphPtr ImmutableGraph::ToImmutable(GraphPtr graph) { ...@@ -556,7 +556,7 @@ ImmutableGraphPtr ImmutableGraph::ToImmutable(GraphPtr graph) {
} }
} }
ImmutableGraphPtr ImmutableGraph::CopyTo(ImmutableGraphPtr g, const DLContext& ctx) { ImmutableGraphPtr ImmutableGraph::CopyTo(ImmutableGraphPtr g, const DGLContext& ctx) {
if (ctx == g->Context()) { if (ctx == g->Context()) {
return g; return g;
} }
...@@ -656,8 +656,8 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyTo") ...@@ -656,8 +656,8 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyTo")
GraphRef g = args[0]; GraphRef g = args[0];
const int device_type = args[1]; const int device_type = args[1];
const int device_id = args[2]; const int device_id = args[2];
DLContext ctx; DGLContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type); ctx.device_type = static_cast<DGLDeviceType>(device_type);
ctx.device_id = device_id; ctx.device_id = device_id;
ImmutableGraphPtr ig = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr())); ImmutableGraphPtr ig = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
*rv = ImmutableGraph::CopyTo(ig, ctx); *rv = ImmutableGraph::CopyTo(ig, ctx);
......
/*! /*!
* Copyright (c) 2018 by Contributors * Copyright (c) 2018-2022 by Contributors
* \file graph/network.cc * \file graph/network.cc
* \brief DGL networking related APIs * \brief DGL networking related APIs
*/ */
...@@ -29,40 +29,13 @@ const bool AUTO_FREE = true; ...@@ -29,40 +29,13 @@ const bool AUTO_FREE = true;
namespace dgl { namespace dgl {
namespace network { namespace network {
static void NaiveDeleter(DLManagedTensor* managed_tensor) {
delete [] managed_tensor->dl_tensor.shape;
delete [] managed_tensor->dl_tensor.strides;
free(managed_tensor->dl_tensor.data);
delete managed_tensor;
}
NDArray CreateNDArrayFromRaw(std::vector<int64_t> shape, NDArray CreateNDArrayFromRaw(std::vector<int64_t> shape,
DLDataType dtype, DGLDataType dtype,
DLContext ctx, DGLContext ctx,
void* raw, void* raw,
bool auto_free) { bool auto_free) {
DLTensor tensor; return NDArray::CreateFromRaw(shape, dtype, ctx, raw, auto_free);
tensor.ctx = ctx;
tensor.ndim = static_cast<int>(shape.size());
tensor.dtype = dtype;
tensor.shape = new int64_t[tensor.ndim];
for (int i = 0; i < tensor.ndim; ++i) {
tensor.shape[i] = shape[i];
}
tensor.strides = new int64_t[tensor.ndim];
for (int i = 0; i < tensor.ndim; ++i) {
tensor.strides[i] = 1;
}
for (int i = tensor.ndim - 2; i >= 0; --i) {
tensor.strides[i] = tensor.shape[i+1] * tensor.strides[i+1];
}
tensor.data = raw;
DLManagedTensor *managed_tensor = new DLManagedTensor();
managed_tensor->dl_tensor = tensor;
if (auto_free) {
managed_tensor->deleter = NaiveDeleter;
}
return NDArray::FromDLPack(managed_tensor);
} }
void ArrayMeta::AddArray(const NDArray& array) { void ArrayMeta::AddArray(const NDArray& array) {
...@@ -87,7 +60,7 @@ char* ArrayMeta::Serialize(int64_t* size) { ...@@ -87,7 +60,7 @@ char* ArrayMeta::Serialize(int64_t* size) {
buffer_size += sizeof(int64_t) * data_shape_.size(); buffer_size += sizeof(int64_t) * data_shape_.size();
// we don't need to write data_type_.size() // we don't need to write data_type_.size()
// because it equals to ndarray_count_ * 3 // because it equals to ndarray_count_ * 3
buffer_size += sizeof(DLDataType) * data_type_.size(); buffer_size += sizeof(DGLDataType) * data_type_.size();
} }
// In the future, we should have a better memory management as // In the future, we should have a better memory management as
// allocating a large chunk of memory can be very expensive. // allocating a large chunk of memory can be very expensive.
...@@ -102,9 +75,9 @@ char* ArrayMeta::Serialize(int64_t* size) { ...@@ -102,9 +75,9 @@ char* ArrayMeta::Serialize(int64_t* size) {
pointer += sizeof(ndarray_count_); pointer += sizeof(ndarray_count_);
// Write data type // Write data type
memcpy(pointer, memcpy(pointer,
reinterpret_cast<DLDataType*>(data_type_.data()), reinterpret_cast<DGLDataType*>(data_type_.data()),
sizeof(DLDataType) * data_type_.size()); sizeof(DGLDataType) * data_type_.size());
pointer += (sizeof(DLDataType) * data_type_.size()); pointer += (sizeof(DGLDataType) * data_type_.size());
// Write size of data_shape_ // Write size of data_shape_
*(reinterpret_cast<size_t*>(pointer)) = data_shape_.size(); *(reinterpret_cast<size_t*>(pointer)) = data_shape_.size();
pointer += sizeof(data_shape_.size()); pointer += sizeof(data_shape_.size());
...@@ -131,9 +104,9 @@ void ArrayMeta::Deserialize(char* buffer, int64_t size) { ...@@ -131,9 +104,9 @@ void ArrayMeta::Deserialize(char* buffer, int64_t size) {
// Read data type // Read data type
data_type_.resize(ndarray_count_); data_type_.resize(ndarray_count_);
memcpy(data_type_.data(), buffer, memcpy(data_type_.data(), buffer,
ndarray_count_ * sizeof(DLDataType)); ndarray_count_ * sizeof(DGLDataType));
buffer += ndarray_count_ * sizeof(DLDataType); buffer += ndarray_count_ * sizeof(DGLDataType);
data_size += ndarray_count_ * sizeof(DLDataType); data_size += ndarray_count_ * sizeof(DGLDataType);
// Read size of data_shape_ // Read size of data_shape_
size_t count = *(reinterpret_cast<size_t*>(buffer)); size_t count = *(reinterpret_cast<size_t*>(buffer));
buffer += sizeof(size_t); buffer += sizeof(size_t);
...@@ -405,8 +378,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow") ...@@ -405,8 +378,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow")
CHECK_EQ(meta.data_shape_[0], 1); CHECK_EQ(meta.data_shape_[0], 1);
nf->node_mapping = CreateNDArrayFromRaw( nf->node_mapping = CreateNDArrayFromRaw(
{meta.data_shape_[1]}, {meta.data_shape_[1]},
DLDataType{kDLInt, 64, 1}, DGLDataType{kDGLInt, 64, 1},
DLContext{kDLCPU, 0}, DGLContext{kDGLCPU, 0},
array_0.data, array_0.data,
AUTO_FREE); AUTO_FREE);
// edge_mapping // edge_mapping
...@@ -415,8 +388,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow") ...@@ -415,8 +388,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow")
CHECK_EQ(meta.data_shape_[2], 1); CHECK_EQ(meta.data_shape_[2], 1);
nf->edge_mapping = CreateNDArrayFromRaw( nf->edge_mapping = CreateNDArrayFromRaw(
{meta.data_shape_[3]}, {meta.data_shape_[3]},
DLDataType{kDLInt, 64, 1}, DGLDataType{kDGLInt, 64, 1},
DLContext{kDLCPU, 0}, DGLContext{kDGLCPU, 0},
array_1.data, array_1.data,
AUTO_FREE); AUTO_FREE);
// layer_offset // layer_offset
...@@ -425,8 +398,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow") ...@@ -425,8 +398,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow")
CHECK_EQ(meta.data_shape_[4], 1); CHECK_EQ(meta.data_shape_[4], 1);
nf->layer_offsets = CreateNDArrayFromRaw( nf->layer_offsets = CreateNDArrayFromRaw(
{meta.data_shape_[5]}, {meta.data_shape_[5]},
DLDataType{kDLInt, 64, 1}, DGLDataType{kDGLInt, 64, 1},
DLContext{kDLCPU, 0}, DGLContext{kDGLCPU, 0},
array_2.data, array_2.data,
AUTO_FREE); AUTO_FREE);
// flow_offset // flow_offset
...@@ -435,8 +408,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow") ...@@ -435,8 +408,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow")
CHECK_EQ(meta.data_shape_[6], 1); CHECK_EQ(meta.data_shape_[6], 1);
nf->flow_offsets = CreateNDArrayFromRaw( nf->flow_offsets = CreateNDArrayFromRaw(
{meta.data_shape_[7]}, {meta.data_shape_[7]},
DLDataType{kDLInt, 64, 1}, DGLDataType{kDGLInt, 64, 1},
DLContext{kDLCPU, 0}, DGLContext{kDGLCPU, 0},
array_3.data, array_3.data,
AUTO_FREE); AUTO_FREE);
// CSR indptr // CSR indptr
...@@ -445,8 +418,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow") ...@@ -445,8 +418,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow")
CHECK_EQ(meta.data_shape_[8], 1); CHECK_EQ(meta.data_shape_[8], 1);
NDArray indptr = CreateNDArrayFromRaw( NDArray indptr = CreateNDArrayFromRaw(
{meta.data_shape_[9]}, {meta.data_shape_[9]},
DLDataType{kDLInt, 64, 1}, DGLDataType{kDGLInt, 64, 1},
DLContext{kDLCPU, 0}, DGLContext{kDGLCPU, 0},
array_4.data, array_4.data,
AUTO_FREE); AUTO_FREE);
// CSR indice // CSR indice
...@@ -455,8 +428,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow") ...@@ -455,8 +428,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow")
CHECK_EQ(meta.data_shape_[10], 1); CHECK_EQ(meta.data_shape_[10], 1);
NDArray indice = CreateNDArrayFromRaw( NDArray indice = CreateNDArrayFromRaw(
{meta.data_shape_[11]}, {meta.data_shape_[11]},
DLDataType{kDLInt, 64, 1}, DGLDataType{kDGLInt, 64, 1},
DLContext{kDLCPU, 0}, DGLContext{kDGLCPU, 0},
array_5.data, array_5.data,
AUTO_FREE); AUTO_FREE);
// CSR edge_ids // CSR edge_ids
...@@ -465,8 +438,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow") ...@@ -465,8 +438,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow")
CHECK_EQ(meta.data_shape_[12], 1); CHECK_EQ(meta.data_shape_[12], 1);
NDArray edge_ids = CreateNDArrayFromRaw( NDArray edge_ids = CreateNDArrayFromRaw(
{meta.data_shape_[13]}, {meta.data_shape_[13]},
DLDataType{kDLInt, 64, 1}, DGLDataType{kDGLInt, 64, 1},
DLContext{kDLCPU, 0}, DGLContext{kDGLCPU, 0},
array_6.data, array_6.data,
AUTO_FREE); AUTO_FREE);
// Create CSR // Create CSR
...@@ -598,7 +571,7 @@ static KVStoreMsg* recv_kv_message(network::Receiver* receiver) { ...@@ -598,7 +571,7 @@ static KVStoreMsg* recv_kv_message(network::Receiver* receiver) {
kv_msg->id = CreateNDArrayFromRaw( kv_msg->id = CreateNDArrayFromRaw(
{meta.data_shape_[1]}, {meta.data_shape_[1]},
meta.data_type_[0], meta.data_type_[0],
DLContext{kDLCPU, 0}, DGLContext{kDGLCPU, 0},
recv_id_msg.data, recv_id_msg.data,
AUTO_FREE); AUTO_FREE);
} }
...@@ -617,7 +590,7 @@ static KVStoreMsg* recv_kv_message(network::Receiver* receiver) { ...@@ -617,7 +590,7 @@ static KVStoreMsg* recv_kv_message(network::Receiver* receiver) {
kv_msg->data = CreateNDArrayFromRaw( kv_msg->data = CreateNDArrayFromRaw(
vec_shape, vec_shape,
meta.data_type_[1], meta.data_type_[1],
DLContext{kDLCPU, 0}, DGLContext{kDGLCPU, 0},
recv_data_msg.data, recv_data_msg.data,
AUTO_FREE); AUTO_FREE);
} }
...@@ -636,7 +609,7 @@ static KVStoreMsg* recv_kv_message(network::Receiver* receiver) { ...@@ -636,7 +609,7 @@ static KVStoreMsg* recv_kv_message(network::Receiver* receiver) {
kv_msg->shape = CreateNDArrayFromRaw( kv_msg->shape = CreateNDArrayFromRaw(
vec_shape, vec_shape,
meta.data_type_[0], meta.data_type_[0],
DLContext{kDLCPU, 0}, DGLContext{kDGLCPU, 0},
recv_shape_msg.data, recv_shape_msg.data,
AUTO_FREE); AUTO_FREE);
} }
...@@ -814,7 +787,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_FastPull") ...@@ -814,7 +787,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_FastPull")
kv_msg.name = name; kv_msg.name = name;
kv_msg.id = CreateNDArrayFromRaw({static_cast<int64_t>(remote_ids[i].size())}, kv_msg.id = CreateNDArrayFromRaw({static_cast<int64_t>(remote_ids[i].size())},
ID->dtype, ID->dtype,
DLContext{kDLCPU, 0}, DGLContext{kDGLCPU, 0},
remote_ids[i].data(), remote_ids[i].data(),
!AUTO_FREE); !AUTO_FREE);
int lower = i*group_count; int lower = i*group_count;
...@@ -859,7 +832,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_FastPull") ...@@ -859,7 +832,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_FastPull")
NDArray res_tensor = CreateNDArrayFromRaw( NDArray res_tensor = CreateNDArrayFromRaw(
local_data_shape, local_data_shape,
local_data->dtype, local_data->dtype,
DLContext{kDLCPU, 0}, DGLContext{kDGLCPU, 0},
return_data, return_data,
AUTO_FREE); AUTO_FREE);
*rv = res_tensor; *rv = res_tensor;
......
...@@ -25,8 +25,8 @@ namespace network { ...@@ -25,8 +25,8 @@ namespace network {
* \brief Create NDArray from raw data * \brief Create NDArray from raw data
*/ */
NDArray CreateNDArrayFromRaw(std::vector<int64_t> shape, NDArray CreateNDArrayFromRaw(std::vector<int64_t> shape,
DLDataType dtype, DGLDataType dtype,
DLContext ctx, DGLContext ctx,
void* raw); void* raw);
/*! /*!
...@@ -145,7 +145,7 @@ class ArrayMeta { ...@@ -145,7 +145,7 @@ class ArrayMeta {
/*! /*!
* \brief DataType for each NDArray * \brief DataType for each NDArray
*/ */
std::vector<DLDataType> data_type_; std::vector<DGLDataType> data_type_;
/*! /*!
* \brief We first write the ndim to data_shape_ * \brief We first write the ndim to data_shape_
......
...@@ -885,7 +885,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformSampling") ...@@ -885,7 +885,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformSampling")
CHECK(gptr) << "sampling isn't implemented in mutable graph"; CHECK(gptr) << "sampling isn't implemented in mutable graph";
CHECK(aten::IsValidIdArray(seed_nodes)); CHECK(aten::IsValidIdArray(seed_nodes));
CHECK_EQ(seed_nodes->ctx.device_type, kDLCPU) CHECK_EQ(seed_nodes->ctx.device_type, kDGLCPU)
<< "UniformSampler only support CPU sampling"; << "UniformSampler only support CPU sampling";
std::vector<NodeFlow> nflows = NeighborSamplingImpl<float>( std::vector<NodeFlow> nflows = NeighborSamplingImpl<float>(
...@@ -913,16 +913,16 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_NeighborSampling") ...@@ -913,16 +913,16 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_NeighborSampling")
CHECK(gptr) << "sampling isn't implemented in mutable graph"; CHECK(gptr) << "sampling isn't implemented in mutable graph";
CHECK(aten::IsValidIdArray(seed_nodes)); CHECK(aten::IsValidIdArray(seed_nodes));
CHECK_EQ(seed_nodes->ctx.device_type, kDLCPU) CHECK_EQ(seed_nodes->ctx.device_type, kDGLCPU)
<< "NeighborSampler only support CPU sampling"; << "NeighborSampler only support CPU sampling";
std::vector<NodeFlow> nflows; std::vector<NodeFlow> nflows;
CHECK(probability->dtype.code == kDLFloat) CHECK(probability->dtype.code == kDGLFloat)
<< "transition probability must be float"; << "transition probability must be float";
CHECK(probability->ndim == 1) CHECK(probability->ndim == 1)
<< "transition probability must be a 1-dimensional vector"; << "transition probability must be a 1-dimensional vector";
CHECK_EQ(probability->ctx.device_type, kDLCPU) CHECK_EQ(probability->ctx.device_type, kDGLCPU)
<< "NeighborSampling only support CPU sampling"; << "NeighborSampling only support CPU sampling";
ATEN_FLOAT_TYPE_SWITCH( ATEN_FLOAT_TYPE_SWITCH(
...@@ -964,11 +964,11 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_LayerSampling") ...@@ -964,11 +964,11 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_LayerSampling")
auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()); auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(gptr) << "sampling isn't implemented in mutable graph"; CHECK(gptr) << "sampling isn't implemented in mutable graph";
CHECK(aten::IsValidIdArray(seed_nodes)); CHECK(aten::IsValidIdArray(seed_nodes));
CHECK_EQ(seed_nodes->ctx.device_type, kDLCPU) CHECK_EQ(seed_nodes->ctx.device_type, kDGLCPU)
<< "LayerSampler only support CPU sampling"; << "LayerSampler only support CPU sampling";
CHECK(aten::IsValidIdArray(layer_sizes)); CHECK(aten::IsValidIdArray(layer_sizes));
CHECK_EQ(layer_sizes->ctx.device_type, kDLCPU) CHECK_EQ(layer_sizes->ctx.device_type, kDGLCPU)
<< "LayerSampler only support CPU sampling"; << "LayerSampler only support CPU sampling";
const dgl_id_t* seed_nodes_data = static_cast<dgl_id_t*>(seed_nodes->data); const dgl_id_t* seed_nodes_data = static_cast<dgl_id_t*>(seed_nodes->data);
...@@ -1477,7 +1477,7 @@ public: ...@@ -1477,7 +1477,7 @@ public:
IdArray worker_seeds; IdArray worker_seeds;
if (replacement_ == false) { if (replacement_ == false) {
worker_seeds = seed_edges_.CreateView({num_edges}, DLDataType{kDLInt, 64, 1}, worker_seeds = seed_edges_.CreateView({num_edges}, DGLDataType{kDGLInt, 64, 1},
sizeof(dgl_id_t) * start); sizeof(dgl_id_t) * start);
} else { } else {
std::vector<dgl_id_t> seeds; std::vector<dgl_id_t> seeds;
...@@ -1593,12 +1593,12 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_CreateUniformEdgeSampler") ...@@ -1593,12 +1593,12 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_CreateUniformEdgeSampler")
auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()); auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(gptr) << "sampling isn't implemented in mutable graph"; CHECK(gptr) << "sampling isn't implemented in mutable graph";
CHECK(aten::IsValidIdArray(seed_edges)); CHECK(aten::IsValidIdArray(seed_edges));
CHECK_EQ(seed_edges->ctx.device_type, kDLCPU) CHECK_EQ(seed_edges->ctx.device_type, kDGLCPU)
<< "UniformEdgeSampler only support CPU sampling"; << "UniformEdgeSampler only support CPU sampling";
if (relations->shape[0] > 0) { if (relations->shape[0] > 0) {
CHECK(aten::IsValidIdArray(relations)); CHECK(aten::IsValidIdArray(relations));
CHECK_EQ(relations->ctx.device_type, kDLCPU) CHECK_EQ(relations->ctx.device_type, kDGLCPU)
<< "WeightedEdgeSampler only support CPU sampling"; << "WeightedEdgeSampler only support CPU sampling";
} }
BuildCoo(*gptr); BuildCoo(*gptr);
...@@ -1879,21 +1879,21 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_CreateWeightedEdgeSampler") ...@@ -1879,21 +1879,21 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_CreateWeightedEdgeSampler")
auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()); auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(gptr) << "sampling isn't implemented in mutable graph"; CHECK(gptr) << "sampling isn't implemented in mutable graph";
CHECK(aten::IsValidIdArray(seed_edges)); CHECK(aten::IsValidIdArray(seed_edges));
CHECK_EQ(seed_edges->ctx.device_type, kDLCPU) CHECK_EQ(seed_edges->ctx.device_type, kDGLCPU)
<< "WeightedEdgeSampler only support CPU sampling"; << "WeightedEdgeSampler only support CPU sampling";
CHECK(edge_weight->dtype.code == kDLFloat) << "edge_weight should be FloatType"; CHECK(edge_weight->dtype.code == kDGLFloat) << "edge_weight should be FloatType";
CHECK(edge_weight->dtype.bits == 32) << "WeightedEdgeSampler only support float weight"; CHECK(edge_weight->dtype.bits == 32) << "WeightedEdgeSampler only support float weight";
CHECK_EQ(edge_weight->ctx.device_type, kDLCPU) CHECK_EQ(edge_weight->ctx.device_type, kDGLCPU)
<< "WeightedEdgeSampler only support CPU sampling"; << "WeightedEdgeSampler only support CPU sampling";
if (node_weight->shape[0] > 0) { if (node_weight->shape[0] > 0) {
CHECK(node_weight->dtype.code == kDLFloat) << "node_weight should be FloatType"; CHECK(node_weight->dtype.code == kDGLFloat) << "node_weight should be FloatType";
CHECK(node_weight->dtype.bits == 32) << "WeightedEdgeSampler only support float weight"; CHECK(node_weight->dtype.bits == 32) << "WeightedEdgeSampler only support float weight";
CHECK_EQ(node_weight->ctx.device_type, kDLCPU) CHECK_EQ(node_weight->ctx.device_type, kDGLCPU)
<< "WeightedEdgeSampler only support CPU sampling"; << "WeightedEdgeSampler only support CPU sampling";
} }
if (relations->shape[0] > 0) { if (relations->shape[0] > 0) {
CHECK(aten::IsValidIdArray(relations)); CHECK(aten::IsValidIdArray(relations));
CHECK_EQ(relations->ctx.device_type, kDLCPU) CHECK_EQ(relations->ctx.device_type, kDGLCPU)
<< "WeightedEdgeSampler only support CPU sampling"; << "WeightedEdgeSampler only support CPU sampling";
} }
BuildCoo(*gptr); BuildCoo(*gptr);
......
...@@ -79,7 +79,7 @@ HeteroSubgraph SampleNeighbors( ...@@ -79,7 +79,7 @@ HeteroSubgraph SampleNeighbors(
CHECK_EQ(prob.size(), hg->NumEdgeTypes()) CHECK_EQ(prob.size(), hg->NumEdgeTypes())
<< "Number of probability tensors must match the number of edge types."; << "Number of probability tensors must match the number of edge types.";
DLContext ctx = aten::GetContextOf(nodes); DGLContext ctx = aten::GetContextOf(nodes);
std::vector<HeteroGraphPtr> subrels(hg->NumEdgeTypes()); std::vector<HeteroGraphPtr> subrels(hg->NumEdgeTypes());
std::vector<IdArray> induced_edges(hg->NumEdgeTypes()); std::vector<IdArray> induced_edges(hg->NumEdgeTypes());
......
...@@ -274,7 +274,7 @@ FrequencyHashmap<IdxType>::~FrequencyHashmap() { ...@@ -274,7 +274,7 @@ FrequencyHashmap<IdxType>::~FrequencyHashmap() {
template <typename IdxType> template <typename IdxType>
std::tuple<IdArray, IdArray, IdArray> FrequencyHashmap<IdxType>::Topk( std::tuple<IdArray, IdArray, IdArray> FrequencyHashmap<IdxType>::Topk(
const IdxType *src_data, const IdxType *dst_data, DLDataType dtype, const IdxType *src_data, const IdxType *dst_data, DGLDataType dtype,
const int64_t num_edges, const int64_t num_edges_per_node, const int64_t num_edges, const int64_t num_edges_per_node,
const int64_t num_pick) { const int64_t num_pick) {
...@@ -323,9 +323,7 @@ std::tuple<IdArray, IdArray, IdArray> FrequencyHashmap<IdxType>::Topk( ...@@ -323,9 +323,7 @@ std::tuple<IdArray, IdArray, IdArray> FrequencyHashmap<IdxType>::Topk(
device->FreeWorkspace(_ctx, d_temp_storage); device->FreeWorkspace(_ctx, d_temp_storage);
std::swap(edge_blocks_prefix, edge_blocks_prefix_alternate); std::swap(edge_blocks_prefix, edge_blocks_prefix_alternate);
device->CopyDataFromTo(&edge_blocks_prefix[num_edge_blocks], 0, &num_unique_edges, 0, device->CopyDataFromTo(&edge_blocks_prefix[num_edge_blocks], 0, &num_unique_edges, 0,
sizeof(num_unique_edges), sizeof(num_unique_edges), _ctx, DGLContext{kDGLCPU, 0}, dtype);
_ctx, DGLContext{kDLCPU, 0},
dtype);
device->StreamSync(_ctx, _stream); device->StreamSync(_ctx, _stream);
// 2.2 Allocate the data of unique edges and frequency // 2.2 Allocate the data of unique edges and frequency
// double space to use SegmentedRadixSort // double space to use SegmentedRadixSort
...@@ -408,9 +406,7 @@ std::tuple<IdArray, IdArray, IdArray> FrequencyHashmap<IdxType>::Topk( ...@@ -408,9 +406,7 @@ std::tuple<IdArray, IdArray, IdArray> FrequencyHashmap<IdxType>::Topk(
// 5. Pick the data to result // 5. Pick the data to result
IdxType num_output = 0; IdxType num_output = 0;
device->CopyDataFromTo(&unique_output_offsets[num_dst_nodes], 0, &num_output, 0, device->CopyDataFromTo(&unique_output_offsets[num_dst_nodes], 0, &num_output, 0,
sizeof(num_output), sizeof(num_output), _ctx, DGLContext{kDGLCPU, 0}, dtype);
_ctx, DGLContext{kDLCPU, 0},
dtype);
device->StreamSync(_ctx, _stream); device->StreamSync(_ctx, _stream);
IdArray res_src = IdArray::Empty({static_cast<int64_t>(num_output)}, IdArray res_src = IdArray::Empty({static_cast<int64_t>(num_output)},
......
...@@ -54,7 +54,7 @@ public: ...@@ -54,7 +54,7 @@ public:
~FrequencyHashmap(); ~FrequencyHashmap();
using EdgeItem = typename DeviceEdgeHashmap<IdxType>::EdgeItem; using EdgeItem = typename DeviceEdgeHashmap<IdxType>::EdgeItem;
std::tuple<IdArray, IdArray, IdArray> Topk( std::tuple<IdArray, IdArray, IdArray> Topk(
const IdxType *src_data, const IdxType *dst_data, DLDataType dtype, const IdxType *src_data, const IdxType *dst_data, DGLDataType dtype,
const int64_t num_edges, const int64_t num_edges_per_node, const int64_t num_edges, const int64_t num_edges_per_node,
const int64_t num_pick); const int64_t num_pick);
private: private:
......
...@@ -18,7 +18,7 @@ namespace sampling { ...@@ -18,7 +18,7 @@ namespace sampling {
namespace impl { namespace impl {
template<DLDeviceType XPU, typename IdxType> template<DGLDeviceType XPU, typename IdxType>
TypeArray GetNodeTypesFromMetapath( TypeArray GetNodeTypesFromMetapath(
const HeteroGraphPtr hg, const HeteroGraphPtr hg,
const TypeArray metapath) { const TypeArray metapath) {
...@@ -49,11 +49,11 @@ TypeArray GetNodeTypesFromMetapath( ...@@ -49,11 +49,11 @@ TypeArray GetNodeTypesFromMetapath(
} }
template template
TypeArray GetNodeTypesFromMetapath<kDLCPU, int32_t>( TypeArray GetNodeTypesFromMetapath<kDGLCPU, int32_t>(
const HeteroGraphPtr hg, const HeteroGraphPtr hg,
const TypeArray metapath); const TypeArray metapath);
template template
TypeArray GetNodeTypesFromMetapath<kDLCPU, int64_t>( TypeArray GetNodeTypesFromMetapath<kDGLCPU, int64_t>(
const HeteroGraphPtr hg, const HeteroGraphPtr hg,
const TypeArray metapath); const TypeArray metapath);
......
...@@ -20,14 +20,14 @@ namespace sampling { ...@@ -20,14 +20,14 @@ namespace sampling {
namespace impl { namespace impl {
template<DLDeviceType XPU, typename IdxType> template<DGLDeviceType XPU, typename IdxType>
TypeArray GetNodeTypesFromMetapath( TypeArray GetNodeTypesFromMetapath(
const HeteroGraphPtr hg, const HeteroGraphPtr hg,
const TypeArray metapath) { const TypeArray metapath) {
uint64_t num_etypes = metapath->shape[0]; uint64_t num_etypes = metapath->shape[0];
auto cpu_ctx = DGLContext{kDLCPU, 0}; auto cpu_ctx = DGLContext{kDGLCPU, 0};
auto metapath_ctx = metapath->ctx; auto metapath_ctx = metapath->ctx;
auto stream = DeviceAPI::Get(metapath_ctx)->GetStream(); auto stream = DeviceAPI::Get(metapath_ctx)->GetStream();
...@@ -61,11 +61,11 @@ TypeArray GetNodeTypesFromMetapath( ...@@ -61,11 +61,11 @@ TypeArray GetNodeTypesFromMetapath(
} }
template template
TypeArray GetNodeTypesFromMetapath<kDLGPU, int32_t>( TypeArray GetNodeTypesFromMetapath<kDGLCUDA, int32_t>(
const HeteroGraphPtr hg, const HeteroGraphPtr hg,
const TypeArray metapath); const TypeArray metapath);
template template
TypeArray GetNodeTypesFromMetapath<kDLGPU, int64_t>( TypeArray GetNodeTypesFromMetapath<kDGLCUDA, int64_t>(
const HeteroGraphPtr hg, const HeteroGraphPtr hg,
const TypeArray metapath); const TypeArray metapath);
......
...@@ -51,7 +51,7 @@ using TerminatePredicate = std::function<bool(IdxType *, dgl_id_t, int64_t)>; ...@@ -51,7 +51,7 @@ using TerminatePredicate = std::function<bool(IdxType *, dgl_id_t, int64_t)>;
* \return A tuple of ID of next successor (-1 if not exist), the last traversed edge * \return A tuple of ID of next successor (-1 if not exist), the last traversed edge
* ID, as well as whether to terminate. * ID, as well as whether to terminate.
*/ */
template<DLDeviceType XPU, typename IdxType> template<DGLDeviceType XPU, typename IdxType>
std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStep( std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStep(
IdxType *data, IdxType *data,
dgl_id_t curr, dgl_id_t curr,
...@@ -119,7 +119,7 @@ std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStep( ...@@ -119,7 +119,7 @@ std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStep(
* \return A pair of ID of next successor (-1 if not exist), as well as whether to terminate. * \return A pair of ID of next successor (-1 if not exist), as well as whether to terminate.
* \note This function is called only if all the probability arrays are null. * \note This function is called only if all the probability arrays are null.
*/ */
template<DLDeviceType XPU, typename IdxType> template<DGLDeviceType XPU, typename IdxType>
std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStepUniform( std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStepUniform(
IdxType *data, IdxType *data,
dgl_id_t curr, dgl_id_t curr,
...@@ -167,7 +167,7 @@ std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStepUniform( ...@@ -167,7 +167,7 @@ std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStepUniform(
* \return A 2D array of shape (len(seeds), len(metapath) + 1) with node IDs, and * \return A 2D array of shape (len(seeds), len(metapath) + 1) with node IDs, and
* A 2D array of shape (len(seeds), len(metapath)) with edge IDs. * A 2D array of shape (len(seeds), len(metapath)) with edge IDs.
*/ */
template<DLDeviceType XPU, typename IdxType> template<DGLDeviceType XPU, typename IdxType>
std::pair<IdArray, IdArray> MetapathBasedRandomWalk( std::pair<IdArray, IdArray> MetapathBasedRandomWalk(
const HeteroGraphPtr hg, const HeteroGraphPtr hg,
const IdArray seeds, const IdArray seeds,
......
...@@ -19,7 +19,7 @@ namespace sampling { ...@@ -19,7 +19,7 @@ namespace sampling {
namespace impl { namespace impl {
template <DLDeviceType XPU, typename IdxType> template <DGLDeviceType XPU, typename IdxType>
std::pair<IdArray, IdArray> Node2vec( std::pair<IdArray, IdArray> Node2vec(
const HeteroGraphPtr hg, const IdArray seeds, const double p, const HeteroGraphPtr hg, const IdArray seeds, const double p,
const double q, const int64_t walk_length, const double q, const int64_t walk_length,
...@@ -31,13 +31,13 @@ std::pair<IdArray, IdArray> Node2vec( ...@@ -31,13 +31,13 @@ std::pair<IdArray, IdArray> Node2vec(
terminate); terminate);
} }
template std::pair<IdArray, IdArray> Node2vec<kDLCPU, int32_t>( template std::pair<IdArray, IdArray> Node2vec<kDGLCPU, int32_t>(
const HeteroGraphPtr hg, const HeteroGraphPtr hg,
const IdArray seeds, const double p, const IdArray seeds, const double p,
const double q, const double q,
const int64_t walk_length, const int64_t walk_length,
const FloatArray &prob); const FloatArray &prob);
template std::pair<IdArray, IdArray> Node2vec<kDLCPU, int64_t>( template std::pair<IdArray, IdArray> Node2vec<kDGLCPU, int64_t>(
const HeteroGraphPtr hg, const HeteroGraphPtr hg,
const IdArray seeds, const double p, const IdArray seeds, const double p,
const double q, const double q,
......
...@@ -40,7 +40,7 @@ namespace impl { ...@@ -40,7 +40,7 @@ namespace impl {
* \return A 2D array of shape (len(seeds), len(walk_length) + 1) * \return A 2D array of shape (len(seeds), len(walk_length) + 1)
* with node IDs. The paths that terminated early are padded with -1. * with node IDs. The paths that terminated early are padded with -1.
*/ */
template <DLDeviceType XPU, typename IdxType> template <DGLDeviceType XPU, typename IdxType>
std::pair<IdArray, IdArray> Node2vec( std::pair<IdArray, IdArray> Node2vec(
const HeteroGraphPtr hg, const IdArray seeds, const double p, const HeteroGraphPtr hg, const IdArray seeds, const double p,
const double q, const int64_t walk_length, const double q, const int64_t walk_length,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment