"git@developer.sourcefind.cn:OpenDAS/bitsandbytes.git" did not exist on "0bf71989566c63f4b301e5bdbf2cd73b5683a8e9"
Unverified Commit 8f0df39e authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] clang-format auto fix. (#4810)



* [Misc] clang-format auto fix.

* manual

* manual
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 401e1278
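For readers skimming the diff: every hunk below is mechanical reformatting. clang-format enforces an 80-column limit, breaks after an opening parenthesis with a 4-space continuation indent instead of aligning arguments under the parenthesis, writes `template <...>` with a space, and adds spaces around binary operators. A minimal sketch of the resulting style (the function and names here are hypothetical, not from this diff):

    #include <string>

    // Before clang-format, arguments were aligned under the open parenthesis:
    //   int ScaledSize(const T &container,
    //                  const std::string &name, int scale) { ... }
    // After: break after the parenthesis, 4-space continuation indent.
    template <typename T>  // "template<...>" also becomes "template <...>"
    int ScaledSize(
        const T &container, const std::string &name, int scale) {
      // "a*b" becomes "a * b": spaces around binary operators.
      return static_cast<int>(container.size() + name.size()) * scale;
    }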
@@ -13,8 +13,8 @@ namespace dgl {
 using dgl::runtime::NDArray;
 
-NDArray CreateNDArrayFromRawData(std::vector<int64_t> shape, DGLDataType dtype,
-                                 DGLContext ctx, void* raw) {
+NDArray CreateNDArrayFromRawData(
+    std::vector<int64_t> shape, DGLDataType dtype, DGLContext ctx, void* raw) {
   return NDArray::CreateFromRaw(shape, dtype, ctx, raw, true);
 }
@@ -25,9 +25,9 @@ void StreamWithBuffer::PushNDArray(const NDArray& tensor) {
   int ndim = tensor->ndim;
   this->WriteArray(tensor->shape, ndim);
   CHECK(tensor.IsContiguous())
-    << "StreamWithBuffer only supports contiguous tensor";
+      << "StreamWithBuffer only supports contiguous tensor";
   CHECK_EQ(tensor->byte_offset, 0)
-    << "StreamWithBuffer only supports zero byte offset tensor";
+      << "StreamWithBuffer only supports zero byte offset tensor";
   int type_bytes = tensor->dtype.bits / 8;
   int64_t num_elems = 1;
   for (int i = 0; i < ndim; ++i) {
@@ -40,7 +40,8 @@ void StreamWithBuffer::PushNDArray(const NDArray& tensor) {
   // If the stream is for remote communication or the data is not stored in
   // shared memory, serialize the data content as a buffer.
   this->Write<bool>(false);
-  // If this is a null ndarray, we will not push it into the underlying buffer_list
+  // If this is a null ndarray, we will not push it into the underlying
+  // buffer_list
   if (data_byte_size != 0) {
     buffer_list_.emplace_back(tensor, tensor->data, data_byte_size);
   }
@@ -90,8 +91,8 @@ NDArray StreamWithBuffer::PopNDArray() {
     // Mean this is a null ndarray
     ret = CreateNDArrayFromRawData(shape, dtype, cpu_ctx, nullptr);
   } else {
-    ret = CreateNDArrayFromRawData(shape, dtype, cpu_ctx,
-                                   buffer_list_.front().data);
+    ret = CreateNDArrayFromRawData(
+        shape, dtype, cpu_ctx, buffer_list_.front().data);
     buffer_list_.pop_front();
   }
   return ret;
...
@@ -31,8 +31,8 @@ using namespace dgl::aten;
 namespace dgl {
 
 template <>
-NDArray SharedMemManager::CopyToSharedMem<NDArray>(const NDArray &data,
-                                                   std::string name) {
+NDArray SharedMemManager::CopyToSharedMem<NDArray>(
+    const NDArray &data, std::string name) {
   DGLContext ctx = {kDGLCPU, 0};
   std::vector<int64_t> shape(data->shape, data->shape + data->ndim);
   strm_->Write(data->ndim);
@@ -46,28 +46,29 @@ NDArray SharedMemManager::CopyToSharedMem<NDArray>(
     return data;
   } else {
     auto nd =
         NDArray::EmptyShared(graph_name_ + name, shape, data->dtype, ctx, true);
     nd.CopyFrom(data);
     return nd;
   }
 }
 
 template <>
-CSRMatrix SharedMemManager::CopyToSharedMem<CSRMatrix>(const CSRMatrix &csr,
-                                                       std::string name) {
+CSRMatrix SharedMemManager::CopyToSharedMem<CSRMatrix>(
+    const CSRMatrix &csr, std::string name) {
   auto indptr_shared_mem = CopyToSharedMem(csr.indptr, name + "_indptr");
   auto indices_shared_mem = CopyToSharedMem(csr.indices, name + "_indices");
   auto data_shared_mem = CopyToSharedMem(csr.data, name + "_data");
   strm_->Write(csr.num_rows);
   strm_->Write(csr.num_cols);
   strm_->Write(csr.sorted);
-  return CSRMatrix(csr.num_rows, csr.num_cols, indptr_shared_mem,
-                   indices_shared_mem, data_shared_mem, csr.sorted);
+  return CSRMatrix(
+      csr.num_rows, csr.num_cols, indptr_shared_mem, indices_shared_mem,
+      data_shared_mem, csr.sorted);
 }
 
 template <>
-COOMatrix SharedMemManager::CopyToSharedMem<COOMatrix>(const COOMatrix &coo,
-                                                       std::string name) {
+COOMatrix SharedMemManager::CopyToSharedMem<COOMatrix>(
+    const COOMatrix &coo, std::string name) {
   auto row_shared_mem = CopyToSharedMem(coo.row, name + "_row");
   auto col_shared_mem = CopyToSharedMem(coo.col, name + "_col");
   auto data_shared_mem = CopyToSharedMem(coo.data, name + "_data");
@@ -75,13 +76,14 @@ COOMatrix SharedMemManager::CopyToSharedMem<COOMatrix>(
   strm_->Write(coo.num_cols);
   strm_->Write(coo.row_sorted);
   strm_->Write(coo.col_sorted);
-  return COOMatrix(coo.num_rows, coo.num_cols, row_shared_mem, col_shared_mem,
-                   data_shared_mem, coo.row_sorted, coo.col_sorted);
+  return COOMatrix(
+      coo.num_rows, coo.num_cols, row_shared_mem, col_shared_mem,
+      data_shared_mem, coo.row_sorted, coo.col_sorted);
 }
 
 template <>
-bool SharedMemManager::CreateFromSharedMem<NDArray>(NDArray *nd,
-                                                    std::string name) {
+bool SharedMemManager::CreateFromSharedMem<NDArray>(
+    NDArray *nd, std::string name) {
   int ndim;
   DGLContext ctx = {kDGLCPU, 0};
   DGLDataType dtype;
@@ -98,15 +100,14 @@ bool SharedMemManager::CreateFromSharedMem<NDArray>(
   if (is_null) {
     *nd = NDArray::Empty(shape, dtype, ctx);
   } else {
-    *nd =
-        NDArray::EmptyShared(graph_name_ + name, shape, dtype, ctx, false);
+    *nd = NDArray::EmptyShared(graph_name_ + name, shape, dtype, ctx, false);
   }
   return true;
 }
 
 template <>
-bool SharedMemManager::CreateFromSharedMem<COOMatrix>(COOMatrix *coo,
-                                                      std::string name) {
+bool SharedMemManager::CreateFromSharedMem<COOMatrix>(
+    COOMatrix *coo, std::string name) {
   CreateFromSharedMem(&coo->row, name + "_row");
   CreateFromSharedMem(&coo->col, name + "_col");
   CreateFromSharedMem(&coo->data, name + "_data");
@@ -118,8 +119,8 @@ bool SharedMemManager::CreateFromSharedMem<COOMatrix>(
 }
 
 template <>
-bool SharedMemManager::CreateFromSharedMem<CSRMatrix>(CSRMatrix *csr,
-                                                      std::string name) {
+bool SharedMemManager::CreateFromSharedMem<CSRMatrix>(
+    CSRMatrix *csr, std::string name) {
   CreateFromSharedMem(&csr->indptr, name + "_indptr");
   CreateFromSharedMem(&csr->indices, name + "_indices");
   CreateFromSharedMem(&csr->data, name + "_data");
...
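An aside on the code being reformatted here: the CSRMatrix/COOMatrix specializations serialize a composite matrix by recursing into its component arrays under suffixed names ("_indptr", "_indices", "_data", "_row", "_col"), prefixed by the graph name. A standalone sketch of that naming scheme, with stand-in types (illustrative only, not DGL's API):

    #include <iostream>
    #include <string>
    #include <vector>

    // Stand-in for CSRMatrix; the real code puts each component array into a
    // named shared-memory segment via CopyToSharedMem.
    struct FakeCsr {
      std::vector<int64_t> indptr, indices, data;
    };

    // Mirrors the recursive naming in CopyToSharedMem<CSRMatrix>: each
    // component lives under "<graph_name><name><suffix>".
    void CopyCsrToSharedMem(
        const FakeCsr &csr, const std::string &graph_name,
        const std::string &name) {
      std::cout << graph_name << name << "_indptr (" << csr.indptr.size()
                << " entries)\n";
      std::cout << graph_name << name << "_indices (" << csr.indices.size()
                << " entries)\n";
      std::cout << graph_name << name << "_data (" << csr.data.size()
                << " entries)\n";
    }

    int main() {
      // prints graph1_csr_indptr, graph1_csr_indices, graph1_csr_data
      CopyCsrToSharedMem(FakeCsr{}, "graph1", "_csr");
    }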
@@ -29,8 +29,7 @@ const size_t SHARED_MEM_METAINFO_SIZE_MAX = 1024 * 32;
 class SharedMemManager : public dmlc::Stream {
  public:
   explicit SharedMemManager(std::string graph_name, dmlc::Stream* strm)
-      : graph_name_(graph_name),
-        strm_(strm) {}
+      : graph_name_(graph_name), strm_(strm) {}
 
   template <typename T>
   T CopyToSharedMem(const T& data, std::string name);
...
@@ -11,7 +11,8 @@ namespace dgl {
 HeteroSubgraph InEdgeGraphRelabelNodes(
     const HeteroGraphPtr graph, const std::vector<IdArray>& vids) {
   CHECK_EQ(vids.size(), graph->NumVertexTypes())
-      << "Invalid input: the input list size must be the same as the number of vertex types.";
+      << "Invalid input: the input list size must be the same as the number of "
+         "vertex types.";
   std::vector<IdArray> eids(graph->NumEdgeTypes());
   DGLContext ctx = aten::GetContextOf(vids);
   for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
@@ -29,9 +30,11 @@ HeteroSubgraph InEdgeGraphRelabelNodes(
 HeteroSubgraph InEdgeGraphNoRelabelNodes(
     const HeteroGraphPtr graph, const std::vector<IdArray>& vids) {
-  // TODO(mufei): This should also use EdgeSubgraph once it is supported for CSR graphs
+  // TODO(mufei): This should also use EdgeSubgraph once it is supported for CSR
+  // graphs
   CHECK_EQ(vids.size(), graph->NumVertexTypes())
-      << "Invalid input: the input list size must be the same as the number of vertex types.";
+      << "Invalid input: the input list size must be the same as the number of "
+         "vertex types.";
   std::vector<HeteroGraphPtr> subrels(graph->NumEdgeTypes());
   std::vector<IdArray> induced_edges(graph->NumEdgeTypes());
   DGLContext ctx = aten::GetContextOf(vids);
@@ -43,30 +46,28 @@ HeteroSubgraph InEdgeGraphNoRelabelNodes(
     if (aten::IsNullArray(vids[dst_vtype])) {
       // create a placeholder graph
       subrels[etype] = UnitGraph::Empty(
-          relgraph->NumVertexTypes(),
-          graph->NumVertices(src_vtype),
-          graph->NumVertices(dst_vtype),
-          graph->DataType(), ctx);
-      induced_edges[etype] = IdArray::Empty({0}, graph->DataType(), graph->Context());
+          relgraph->NumVertexTypes(), graph->NumVertices(src_vtype),
+          graph->NumVertices(dst_vtype), graph->DataType(), ctx);
+      induced_edges[etype] =
+          IdArray::Empty({0}, graph->DataType(), graph->Context());
     } else {
       const auto& earr = graph->InEdges(etype, {vids[dst_vtype]});
       subrels[etype] = UnitGraph::CreateFromCOO(
-          relgraph->NumVertexTypes(),
-          graph->NumVertices(src_vtype),
-          graph->NumVertices(dst_vtype),
-          earr.src,
-          earr.dst);
+          relgraph->NumVertexTypes(), graph->NumVertices(src_vtype),
+          graph->NumVertices(dst_vtype), earr.src, earr.dst);
       induced_edges[etype] = earr.id;
     }
   }
   HeteroSubgraph ret;
-  ret.graph = CreateHeteroGraph(graph->meta_graph(), subrels, graph->NumVerticesPerType());
+  ret.graph = CreateHeteroGraph(
+      graph->meta_graph(), subrels, graph->NumVerticesPerType());
   ret.induced_edges = std::move(induced_edges);
   return ret;
 }
 
 HeteroSubgraph InEdgeGraph(
-    const HeteroGraphPtr graph, const std::vector<IdArray>& vids, bool relabel_nodes) {
+    const HeteroGraphPtr graph, const std::vector<IdArray>& vids,
+    bool relabel_nodes) {
   if (relabel_nodes) {
     return InEdgeGraphRelabelNodes(graph, vids);
   } else {
@@ -77,7 +78,8 @@ HeteroSubgraph InEdgeGraph(
 HeteroSubgraph OutEdgeGraphRelabelNodes(
     const HeteroGraphPtr graph, const std::vector<IdArray>& vids) {
   CHECK_EQ(vids.size(), graph->NumVertexTypes())
-      << "Invalid input: the input list size must be the same as the number of vertex types.";
+      << "Invalid input: the input list size must be the same as the number of "
+         "vertex types.";
   std::vector<IdArray> eids(graph->NumEdgeTypes());
   DGLContext ctx = aten::GetContextOf(vids);
   for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
@@ -95,9 +97,11 @@ HeteroSubgraph OutEdgeGraphRelabelNodes(
 HeteroSubgraph OutEdgeGraphNoRelabelNodes(
     const HeteroGraphPtr graph, const std::vector<IdArray>& vids) {
-  // TODO(mufei): This should also use EdgeSubgraph once it is supported for CSR graphs
+  // TODO(mufei): This should also use EdgeSubgraph once it is supported for CSR
+  // graphs
   CHECK_EQ(vids.size(), graph->NumVertexTypes())
-      << "Invalid input: the input list size must be the same as the number of vertex types.";
+      << "Invalid input: the input list size must be the same as the number of "
+         "vertex types.";
   std::vector<HeteroGraphPtr> subrels(graph->NumEdgeTypes());
   std::vector<IdArray> induced_edges(graph->NumEdgeTypes());
   DGLContext ctx = aten::GetContextOf(vids);
@@ -109,30 +113,28 @@ HeteroSubgraph OutEdgeGraphNoRelabelNodes(
     if (aten::IsNullArray(vids[src_vtype])) {
       // create a placeholder graph
       subrels[etype] = UnitGraph::Empty(
-          relgraph->NumVertexTypes(),
-          graph->NumVertices(src_vtype),
-          graph->NumVertices(dst_vtype),
-          graph->DataType(), ctx);
-      induced_edges[etype] = IdArray::Empty({0}, graph->DataType(), graph->Context());
+          relgraph->NumVertexTypes(), graph->NumVertices(src_vtype),
+          graph->NumVertices(dst_vtype), graph->DataType(), ctx);
+      induced_edges[etype] =
+          IdArray::Empty({0}, graph->DataType(), graph->Context());
     } else {
       const auto& earr = graph->OutEdges(etype, {vids[src_vtype]});
       subrels[etype] = UnitGraph::CreateFromCOO(
-          relgraph->NumVertexTypes(),
-          graph->NumVertices(src_vtype),
-          graph->NumVertices(dst_vtype),
-          earr.src,
-          earr.dst);
+          relgraph->NumVertexTypes(), graph->NumVertices(src_vtype),
+          graph->NumVertices(dst_vtype), earr.src, earr.dst);
       induced_edges[etype] = earr.id;
     }
   }
   HeteroSubgraph ret;
-  ret.graph = CreateHeteroGraph(graph->meta_graph(), subrels, graph->NumVerticesPerType());
+  ret.graph = CreateHeteroGraph(
+      graph->meta_graph(), subrels, graph->NumVerticesPerType());
   ret.induced_edges = std::move(induced_edges);
   return ret;
 }
 
 HeteroSubgraph OutEdgeGraph(
-    const HeteroGraphPtr graph, const std::vector<IdArray>& vids, bool relabel_nodes) {
+    const HeteroGraphPtr graph, const std::vector<IdArray>& vids,
+    bool relabel_nodes) {
   if (relabel_nodes) {
     return OutEdgeGraphRelabelNodes(graph, vids);
   } else {
...
@@ -19,18 +19,20 @@
 #include "compact.h"
 
-#include <dgl/base_heterograph.h>
-#include <dgl/transform.h>
 #include <dgl/array.h>
+#include <dgl/base_heterograph.h>
 #include <dgl/packed_func_ext.h>
-#include <dgl/runtime/registry.h>
 #include <dgl/runtime/container.h>
-#include <vector>
+#include <dgl/runtime/registry.h>
+#include <dgl/transform.h>
+
 #include <utility>
+#include <vector>
 
 #include "../../c_api_common.h"
 #include "../unit_graph.h"
-// TODO(BarclayII): currently CompactGraphs depend on IdHashMap implementation which
-// only works on CPU. Should fix later to make it device agnostic.
+// TODO(BarclayII): currently CompactGraphs depend on IdHashMap implementation
+// which only works on CPU. Should fix later to make it device agnostic.
 #include "../../array/cpu/array_utils.h"
 
 namespace dgl {
@@ -42,16 +44,16 @@ namespace transform {
 namespace {
 
-template<typename IdType>
-std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>>
-CompactGraphsCPU(
+template <typename IdType>
+std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>> CompactGraphsCPU(
    const std::vector<HeteroGraphPtr> &graphs,
    const std::vector<IdArray> &always_preserve) {
-  // TODO(BarclayII): check whether the node space and metagraph of each graph is the same.
-  // Step 1: Collect the nodes that has connections for each type.
+  // TODO(BarclayII): check whether the node space and metagraph of each graph
+  // is the same. Step 1: Collect the nodes that has connections for each type.
   const int64_t num_ntypes = graphs[0]->NumVertexTypes();
   std::vector<aten::IdHashMap<IdType>> hashmaps(num_ntypes);
-  std::vector<std::vector<EdgeArray>> all_edges(graphs.size());  // all_edges[i][etype]
+  std::vector<std::vector<EdgeArray>> all_edges(
+      graphs.size());  // all_edges[i][etype]
 
   std::vector<int64_t> max_vertex_cnt(num_ntypes, 0);
   for (size_t i = 0; i < graphs.size(); ++i) {
@@ -98,7 +100,8 @@ std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>> CompactGraphsCPU(
     }
   }
 
-  // Step 2: Relabel the nodes for each type to a smaller ID space and save the mapping.
+  // Step 2: Relabel the nodes for each type to a smaller ID space and save the
+  // mapping.
   std::vector<IdArray> induced_nodes(num_ntypes);
   std::vector<int64_t> num_induced_nodes(num_ntypes);
   for (int64_t i = 0; i < num_ntypes; ++i) {
@@ -123,14 +126,12 @@ std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>> CompactGraphsCPU(
       const IdArray mapped_cols = hashmaps[dsttype].Map(edges.dst, -1);
 
       rel_graphs.push_back(UnitGraph::CreateFromCOO(
-          srctype == dsttype ? 1 : 2,
-          induced_nodes[srctype]->shape[0],
-          induced_nodes[dsttype]->shape[0],
-          mapped_rows,
-          mapped_cols));
+          srctype == dsttype ? 1 : 2, induced_nodes[srctype]->shape[0],
+          induced_nodes[dsttype]->shape[0], mapped_rows, mapped_cols));
     }
 
-    new_graphs.push_back(CreateHeteroGraph(meta_graph, rel_graphs, num_induced_nodes));
+    new_graphs.push_back(
+        CreateHeteroGraph(meta_graph, rel_graphs, num_induced_nodes));
   }
 
   return std::make_pair(new_graphs, induced_nodes);
@@ -138,7 +139,7 @@ std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>> CompactGraphsCPU(
 
 };  // namespace
 
-template<>
+template <>
 std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>>
 CompactGraphs<kDGLCPU, int32_t>(
     const std::vector<HeteroGraphPtr> &graphs,
@@ -146,7 +147,7 @@ CompactGraphs<kDGLCPU, int32_t>(
   return CompactGraphsCPU<int32_t>(graphs, always_preserve);
 }
 
-template<>
+template <>
 std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>>
 CompactGraphs<kDGLCPU, int64_t>(
     const std::vector<HeteroGraphPtr> &graphs,
@@ -155,44 +156,44 @@ CompactGraphs<kDGLCPU, int64_t>(
 }
 
 DGL_REGISTER_GLOBAL("transform._CAPI_DGLCompactGraphs")
-.set_body([] (DGLArgs args, DGLRetValue* rv) {
-    List<HeteroGraphRef> graph_refs = args[0];
-    List<Value> always_preserve_refs = args[1];
-
-    std::vector<HeteroGraphPtr> graphs;
-    std::vector<IdArray> always_preserve;
-    for (HeteroGraphRef gref : graph_refs)
-      graphs.push_back(gref.sptr());
-    for (Value array : always_preserve_refs)
-      always_preserve.push_back(array->data);
-
-    // TODO(BarclayII): check for all IdArrays
-    CHECK(graphs[0]->DataType() == always_preserve[0]->dtype) << "data type mismatch.";
-
-    std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>> result_pair;
-    ATEN_XPU_SWITCH_CUDA(graphs[0]->Context().device_type, XPU, "CompactGraphs", {
-      ATEN_ID_TYPE_SWITCH(graphs[0]->DataType(), IdType, {
-        result_pair = CompactGraphs<XPU, IdType>(
-            graphs, always_preserve);
-      });
-    });
-
-    List<HeteroGraphRef> compacted_graph_refs;
-    List<Value> induced_nodes;
-
-    for (const HeteroGraphPtr g : result_pair.first)
-      compacted_graph_refs.push_back(HeteroGraphRef(g));
-    for (const IdArray &ids : result_pair.second)
-      induced_nodes.push_back(Value(MakeValue(ids)));
-
-    List<ObjectRef> result;
-    result.push_back(compacted_graph_refs);
-    result.push_back(induced_nodes);
-
-    *rv = result;
-  });
+    .set_body([](DGLArgs args, DGLRetValue *rv) {
+      List<HeteroGraphRef> graph_refs = args[0];
+      List<Value> always_preserve_refs = args[1];
+
+      std::vector<HeteroGraphPtr> graphs;
+      std::vector<IdArray> always_preserve;
+      for (HeteroGraphRef gref : graph_refs) graphs.push_back(gref.sptr());
+      for (Value array : always_preserve_refs)
+        always_preserve.push_back(array->data);
+
+      // TODO(BarclayII): check for all IdArrays
+      CHECK(graphs[0]->DataType() == always_preserve[0]->dtype)
+          << "data type mismatch.";
+
+      std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>> result_pair;
+      ATEN_XPU_SWITCH_CUDA(
+          graphs[0]->Context().device_type, XPU, "CompactGraphs", {
+            ATEN_ID_TYPE_SWITCH(graphs[0]->DataType(), IdType, {
+              result_pair = CompactGraphs<XPU, IdType>(graphs, always_preserve);
+            });
+          });
+
+      List<HeteroGraphRef> compacted_graph_refs;
+      List<Value> induced_nodes;
+
+      for (const HeteroGraphPtr g : result_pair.first)
+        compacted_graph_refs.push_back(HeteroGraphRef(g));
+      for (const IdArray &ids : result_pair.second)
+        induced_nodes.push_back(Value(MakeValue(ids)));
+
+      List<ObjectRef> result;
+      result.push_back(compacted_graph_refs);
+      result.push_back(induced_nodes);
+
+      *rv = result;
+    });
 
 };  // namespace transform
 };  // namespace dgl
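Behind the formatting churn, CompactGraphsCPU follows the two steps named in its comments: collect the node ids that appear in any edge (plus always_preserve), then relabel them to a dense 0..N-1 id space. A self-contained sketch of that relabeling step, using std::unordered_map in place of DGL's IdHashMap (assuming IdHashMap's first-seen-order semantics):

    #include <cstdint>
    #include <iostream>
    #include <unordered_map>
    #include <vector>

    // Relabel ids to a dense [0, N) space, preserving first-seen order,
    // analogous to Step 2 of CompactGraphsCPU. induced_nodes[k] records the
    // original id of compacted id k.
    std::vector<int64_t> Relabel(
        const std::vector<int64_t> &ids, std::vector<int64_t> *induced_nodes) {
      std::unordered_map<int64_t, int64_t> to_new;
      std::vector<int64_t> mapped;
      mapped.reserve(ids.size());
      for (int64_t id : ids) {
        auto it = to_new.find(id);
        if (it == to_new.end()) {
          int64_t next = static_cast<int64_t>(induced_nodes->size());
          it = to_new.emplace(id, next).first;
          induced_nodes->push_back(id);
        }
        mapped.push_back(it->second);
      }
      return mapped;
    }

    int main() {
      std::vector<int64_t> induced;
      auto mapped = Relabel({42, 7, 42, 100}, &induced);
      // mapped = {0, 1, 0, 2}; induced = {42, 7, 100}
      for (int64_t m : mapped) std::cout << m << " ";
    }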
@@ -24,8 +24,8 @@
 #include <dgl/array.h>
 #include <dgl/base_heterograph.h>
 
-#include <vector>
 #include <utility>
+#include <vector>
 
 namespace dgl {
 namespace transform {
@@ -41,9 +41,8 @@ namespace transform {
  *
  * @return The vector of compacted graphs and the vector of induced nodes.
  */
-template<DGLDeviceType XPU, typename IdType>
-std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>>
-CompactGraphs(
+template <DGLDeviceType XPU, typename IdType>
+std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>> CompactGraphs(
    const std::vector<HeteroGraphPtr> &graphs,
    const std::vector<IdArray> &always_preserve);
...
@@ -9,7 +9,9 @@
 #include <dgl/array.h>
 #include <dmlc/logging.h>
+
 #include <nanoflann.hpp>
+
 #include "../../../c_api_common.h"
 
 namespace dgl {
@@ -17,78 +19,75 @@ namespace transform {
 namespace knn_utils {
 /*!
- * \brief A simple 2D NDArray adapter for nanoflann, without duplicating the storage.
+ * \brief A simple 2D NDArray adapter for nanoflann, without duplicating the
+ * storage.
  *
- * \tparam FloatType: The type of the point coordinates (typically, double or float).
- * \tparam IdType: The type for indices in the KD-tree index (typically, size_t of int)
- * \tparam FeatureDim: If set to > 0, it specifies a compile-time fixed dimensionality
- *         for the points in the data set, allowing more compiler optimizations.
- * \tparam Dist: The distance metric to use: nanoflann::metric_L1, nanoflann::metric_L2,
- *         nanoflann::metric_L2_Simple, etc.
- * \note The spelling of dgl's adapter ("adapter") is different from naneflann ("adaptor")
+ * \tparam FloatType: The type of the point coordinates (typically, double or
+ * float).
+ * \tparam IdType: The type for indices in the KD-tree index (typically,
+ * size_t of int)
+ * \tparam FeatureDim: If set to > 0, it specifies a compile-time fixed
+ * dimensionality for the points in the data set, allowing more compiler
+ * optimizations.
+ * \tparam Dist: The distance metric to use: nanoflann::metric_L1,
+ * nanoflann::metric_L2, nanoflann::metric_L2_Simple, etc.
+ * \note The spelling of dgl's adapter ("adapter") is different from naneflann
+ * ("adaptor")
 */
-template <typename FloatType,
-          typename IdType,
-          int FeatureDim = -1,
-          typename Dist = nanoflann::metric_L2>
+template <
+    typename FloatType, typename IdType, int FeatureDim = -1,
+    typename Dist = nanoflann::metric_L2>
 class KDTreeNDArrayAdapter {
  public:
   using self_type = KDTreeNDArrayAdapter<FloatType, IdType, FeatureDim, Dist>;
-  using metric_type = typename Dist::template traits<FloatType, self_type>::distance_t;
+  using metric_type =
+      typename Dist::template traits<FloatType, self_type>::distance_t;
   using index_type = nanoflann::KDTreeSingleIndexAdaptor<
       metric_type, self_type, FeatureDim, IdType>;
 
-  KDTreeNDArrayAdapter(const size_t /* dims */,
-                       const NDArray data_points,
-                       const int leaf_max_size = 10)
+  KDTreeNDArrayAdapter(
+      const size_t /* dims */, const NDArray data_points,
+      const int leaf_max_size = 10)
       : data_(data_points) {
     CHECK(data_points->shape[0] != 0 && data_points->shape[1] != 0)
        << "Tensor containing input data point set must be 2D.";
    const size_t dims = data_points->shape[1];
    CHECK(!(FeatureDim > 0 && static_cast<int>(dims) != FeatureDim))
        << "Data set feature dimension does not match the 'FeatureDim' "
        << "template argument.";
    index_ = new index_type(
-        static_cast<int>(dims), *this, nanoflann::KDTreeSingleIndexAdaptorParams(leaf_max_size));
+        static_cast<int>(dims), *this,
+        nanoflann::KDTreeSingleIndexAdaptorParams(leaf_max_size));
    index_->buildIndex();
  }
 
-  ~KDTreeNDArrayAdapter() {
-    delete index_;
-  }
+  ~KDTreeNDArrayAdapter() { delete index_; }
 
-  index_type* GetIndex() {
-    return index_;
-  }
+  index_type* GetIndex() { return index_; }
 
   /*!
    * \brief Query for the \a num_closest points to a given point
    * Note that this is a short-cut method for GetIndex()->findNeighbors().
    */
-  void query(const FloatType* query_pt, const size_t num_closest,
-             IdType* out_idxs, FloatType* out_dists) const {
+  void query(
+      const FloatType* query_pt, const size_t num_closest, IdType* out_idxs,
+      FloatType* out_dists) const {
    nanoflann::KNNResultSet<FloatType, IdType> resultSet(num_closest);
    resultSet.init(out_idxs, out_dists);
    index_->findNeighbors(resultSet, query_pt, nanoflann::SearchParams());
  }
 
   /*! \brief Interface expected by KDTreeSingleIndexAdaptor */
-  const self_type& derived() const {
-    return *this;
-  }
+  const self_type& derived() const { return *this; }
 
   /*! \brief Interface expected by KDTreeSingleIndexAdaptor */
-  self_type& derived() {
-    return *this;
-  }
+  self_type& derived() { return *this; }
 
   /*!
    * \brief Interface expected by KDTreeSingleIndexAdaptor,
    *        return the number of data points
    */
-  size_t kdtree_get_point_count() const {
-    return data_->shape[0];
-  }
+  size_t kdtree_get_point_count() const { return data_->shape[0]; }
 
   /*!
    * \brief Interface expected by KDTreeSingleIndexAdaptor,
@@ -110,7 +109,7 @@ class KDTreeNDArrayAdapter {
   }
 
  private:
   index_type* index_;   // The kd tree index
   const NDArray data_;  // data points
 };
...
@@ -18,13 +18,13 @@
  * all given graphs with the same set of nodes.
  */
 
-#include <dgl/runtime/device_api.h>
-#include <dgl/immutable_graph.h>
 #include <cuda_runtime.h>
-#include <utility>
+#include <dgl/immutable_graph.h>
+#include <dgl/runtime/device_api.h>
 
 #include <algorithm>
 #include <memory>
+#include <utility>
 
 #include "../../../runtime/cuda/cuda_common.h"
 #include "../../heterograph.h"
@@ -41,54 +41,45 @@ namespace transform {
 namespace {
 
 /**
  * \brief This function builds node maps for each node type, preserving the
  * order of the input nodes. Here it is assumed the nodes are not unique,
  * and thus a unique list is generated.
  *
  * \param input_nodes The set of input nodes.
  * \param node_maps The node maps to be constructed.
  * \param count_unique_device The number of unique nodes (on the GPU).
  * \param unique_nodes_device The unique nodes (on the GPU).
  * \param stream The stream to operate on.
  */
-template<typename IdType>
+template <typename IdType>
 void BuildNodeMaps(
-    const std::vector<IdArray>& input_nodes,
-    DeviceNodeMap<IdType> * const node_maps,
-    int64_t * const count_unique_device,
-    std::vector<IdArray>* const unique_nodes_device,
-    cudaStream_t stream) {
+    const std::vector<IdArray> &input_nodes,
+    DeviceNodeMap<IdType> *const node_maps, int64_t *const count_unique_device,
+    std::vector<IdArray> *const unique_nodes_device, cudaStream_t stream) {
   const int64_t num_ntypes = static_cast<int64_t>(input_nodes.size());
   CUDA_CALL(cudaMemsetAsync(
-      count_unique_device,
-      0,
-      num_ntypes*sizeof(*count_unique_device),
-      stream));
+      count_unique_device, 0, num_ntypes * sizeof(*count_unique_device),
+      stream));
 
   // possibly duplicated nodes
   for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
-    const IdArray& nodes = input_nodes[ntype];
+    const IdArray &nodes = input_nodes[ntype];
    if (nodes->shape[0] > 0) {
      CHECK_EQ(nodes->ctx.device_type, kDGLCUDA);
      node_maps->LhsHashTable(ntype).FillWithDuplicates(
-          nodes.Ptr<IdType>(),
-          nodes->shape[0],
+          nodes.Ptr<IdType>(), nodes->shape[0],
          (*unique_nodes_device)[ntype].Ptr<IdType>(),
-          count_unique_device+ntype,
-          stream);
+          count_unique_device + ntype, stream);
    }
  }
 }
 
-template<typename IdType>
-std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>>
-CompactGraphsGPU(
+template <typename IdType>
+std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>> CompactGraphsGPU(
    const std::vector<HeteroGraphPtr> &graphs,
    const std::vector<IdArray> &always_preserve) {
-  const auto& ctx = graphs[0]->Context();
+  const auto &ctx = graphs[0]->Context();
  auto device = runtime::DeviceAPI::Get(ctx);
  cudaStream_t stream = runtime::getCurrentCUDAStream();
@@ -96,7 +87,8 @@ CompactGraphsGPU(
   // Step 1: Collect the nodes that has connections for each type.
   const uint64_t num_ntypes = graphs[0]->NumVertexTypes();
-  std::vector<std::vector<EdgeArray>> all_edges(graphs.size());  // all_edges[i][etype]
+  std::vector<std::vector<EdgeArray>> all_edges(
+      graphs.size());  // all_edges[i][etype]
 
   // count the number of nodes per type
   std::vector<int64_t> max_vertex_cnt(num_ntypes, 0);
@@ -123,19 +115,18 @@ CompactGraphsGPU(
   std::vector<int64_t> node_offsets(num_ntypes, 0);
 
   for (uint64_t ntype = 0; ntype < num_ntypes; ++ntype) {
-    all_nodes[ntype] = NewIdArray(max_vertex_cnt[ntype], ctx,
-        sizeof(IdType)*8);
+    all_nodes[ntype] =
+        NewIdArray(max_vertex_cnt[ntype], ctx, sizeof(IdType) * 8);
    // copy the nodes in always_preserve
-    if (ntype < always_preserve.size() && always_preserve[ntype]->shape[0] > 0) {
+    if (ntype < always_preserve.size() &&
+        always_preserve[ntype]->shape[0] > 0) {
      device->CopyDataFromTo(
          always_preserve[ntype].Ptr<IdType>(), 0,
-          all_nodes[ntype].Ptr<IdType>(),
-          node_offsets[ntype],
-          sizeof(IdType)*always_preserve[ntype]->shape[0],
-          always_preserve[ntype]->ctx,
-          all_nodes[ntype]->ctx,
+          all_nodes[ntype].Ptr<IdType>(), node_offsets[ntype],
+          sizeof(IdType) * always_preserve[ntype]->shape[0],
+          always_preserve[ntype]->ctx, all_nodes[ntype]->ctx,
          always_preserve[ntype]->dtype);
-      node_offsets[ntype] += sizeof(IdType)*always_preserve[ntype]->shape[0];
+      node_offsets[ntype] += sizeof(IdType) * always_preserve[ntype]->shape[0];
    }
  }
@@ -152,25 +143,17 @@ CompactGraphsGPU(
       if (edges.src.defined()) {
         device->CopyDataFromTo(
-            edges.src.Ptr<IdType>(), 0,
-            all_nodes[srctype].Ptr<IdType>(),
-            node_offsets[srctype],
-            sizeof(IdType)*edges.src->shape[0],
-            edges.src->ctx,
-            all_nodes[srctype]->ctx,
-            edges.src->dtype);
-        node_offsets[srctype] += sizeof(IdType)*edges.src->shape[0];
+            edges.src.Ptr<IdType>(), 0, all_nodes[srctype].Ptr<IdType>(),
+            node_offsets[srctype], sizeof(IdType) * edges.src->shape[0],
+            edges.src->ctx, all_nodes[srctype]->ctx, edges.src->dtype);
+        node_offsets[srctype] += sizeof(IdType) * edges.src->shape[0];
      }
      if (edges.dst.defined()) {
        device->CopyDataFromTo(
-            edges.dst.Ptr<IdType>(), 0,
-            all_nodes[dsttype].Ptr<IdType>(),
-            node_offsets[dsttype],
-            sizeof(IdType)*edges.dst->shape[0],
-            edges.dst->ctx,
-            all_nodes[dsttype]->ctx,
-            edges.dst->dtype);
-        node_offsets[dsttype] += sizeof(IdType)*edges.dst->shape[0];
+            edges.dst.Ptr<IdType>(), 0, all_nodes[dsttype].Ptr<IdType>(),
+            node_offsets[dsttype], sizeof(IdType) * edges.dst->shape[0],
+            edges.dst->ctx, all_nodes[dsttype]->ctx, edges.dst->dtype);
+        node_offsets[dsttype] += sizeof(IdType) * edges.dst->shape[0];
      }
      all_edges[i].push_back(edges);
    }
@@ -185,29 +168,22 @@ CompactGraphsGPU(
   // number of unique nodes per type on CPU
   std::vector<int64_t> num_induced_nodes(num_ntypes);
   // number of unique nodes per type on GPU
-  int64_t * count_unique_device = static_cast<int64_t*>(
-      device->AllocWorkspace(ctx, sizeof(int64_t)*num_ntypes));
+  int64_t *count_unique_device = static_cast<int64_t *>(
+      device->AllocWorkspace(ctx, sizeof(int64_t) * num_ntypes));
  // the set of unique nodes per type
  std::vector<IdArray> induced_nodes(num_ntypes);
  for (uint64_t ntype = 0; ntype < num_ntypes; ++ntype) {
-    induced_nodes[ntype] = NewIdArray(max_vertex_cnt[ntype], ctx,
-        sizeof(IdType)*8);
+    induced_nodes[ntype] =
+        NewIdArray(max_vertex_cnt[ntype], ctx, sizeof(IdType) * 8);
  }
 
  BuildNodeMaps(
-      all_nodes,
-      &node_maps,
-      count_unique_device,
-      &induced_nodes,
-      stream);
+      all_nodes, &node_maps, count_unique_device, &induced_nodes, stream);
 
  device->CopyDataFromTo(
-      count_unique_device, 0,
-      num_induced_nodes.data(), 0,
-      sizeof(*num_induced_nodes.data())*num_ntypes,
-      ctx,
-      DGLContext{kDGLCPU, 0},
-      DGLDataType{kDGLInt, 64, 1});
+      count_unique_device, 0, num_induced_nodes.data(), 0,
+      sizeof(*num_induced_nodes.data()) * num_ntypes, ctx,
+      DGLContext{kDGLCPU, 0}, DGLDataType{kDGLInt, 64, 1});
  device->StreamSync(ctx, stream);
 
  // wait for the node counts to finish transferring
@@ -230,22 +206,20 @@ CompactGraphsGPU(
     std::vector<IdArray> new_src;
     std::vector<IdArray> new_dst;
-    std::tie(new_src, new_dst) = MapEdges(
-        curr_graph, all_edges[i], node_maps, stream);
+    std::tie(new_src, new_dst) =
+        MapEdges(curr_graph, all_edges[i], node_maps, stream);
 
    for (IdType etype = 0; etype < num_etypes; ++etype) {
      IdType srctype, dsttype;
      std::tie(srctype, dsttype) = curr_graph->GetEndpointTypes(etype);
 
      rel_graphs.push_back(UnitGraph::CreateFromCOO(
-          srctype == dsttype ? 1 : 2,
-          induced_nodes[srctype]->shape[0],
-          induced_nodes[dsttype]->shape[0],
-          new_src[etype],
-          new_dst[etype]));
+          srctype == dsttype ? 1 : 2, induced_nodes[srctype]->shape[0],
+          induced_nodes[dsttype]->shape[0], new_src[etype], new_dst[etype]));
    }
 
-    new_graphs.push_back(CreateHeteroGraph(meta_graph, rel_graphs, num_induced_nodes));
+    new_graphs.push_back(
+        CreateHeteroGraph(meta_graph, rel_graphs, num_induced_nodes));
  }
 
  return std::make_pair(new_graphs, induced_nodes);
@@ -253,7 +227,7 @@ CompactGraphsGPU(
 
 }  // namespace
 
-template<>
+template <>
 std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>>
 CompactGraphs<kDGLCUDA, int32_t>(
    const std::vector<HeteroGraphPtr> &graphs,
@@ -261,7 +235,7 @@ CompactGraphs<kDGLCUDA, int32_t>(
   return CompactGraphsGPU<int32_t>(graphs, always_preserve);
 }
 
-template<>
+template <>
 std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>>
 CompactGraphs<kDGLCUDA, int64_t>(
    const std::vector<HeteroGraphPtr> &graphs,
...
@@ -20,13 +20,14 @@
 #ifndef DGL_GRAPH_TRANSFORM_CUDA_CUDA_MAP_EDGES_CUH_
 #define DGL_GRAPH_TRANSFORM_CUDA_CUDA_MAP_EDGES_CUH_
 
-#include <dgl/runtime/c_runtime_api.h>
 #include <cuda_runtime.h>
+#include <dgl/runtime/c_runtime_api.h>
 
 #include <algorithm>
 #include <memory>
 #include <tuple>
-#include <vector>
 #include <utility>
+#include <vector>
 
 #include "../../../runtime/cuda/cuda_common.h"
 #include "../../../runtime/cuda/cuda_hashtable.cuh"
@@ -39,48 +40,46 @@ namespace transform {
 namespace cuda {
 
-template<typename IdType, int BLOCK_SIZE, IdType TILE_SIZE>
+template <typename IdType, int BLOCK_SIZE, IdType TILE_SIZE>
 __device__ void map_vertex_ids(
-    const IdType * const global,
-    IdType * const new_global,
-    const IdType num_vertices,
-    const DeviceOrderedHashTable<IdType>& table) {
+    const IdType* const global, IdType* const new_global,
+    const IdType num_vertices, const DeviceOrderedHashTable<IdType>& table) {
   assert(BLOCK_SIZE == blockDim.x);
 
   using Mapping = typename OrderedHashTable<IdType>::Mapping;
 
-  const IdType tile_start = TILE_SIZE*blockIdx.x;
-  const IdType tile_end = min(TILE_SIZE*(blockIdx.x+1), num_vertices);
+  const IdType tile_start = TILE_SIZE * blockIdx.x;
+  const IdType tile_end = min(TILE_SIZE * (blockIdx.x + 1), num_vertices);
 
-  for (IdType idx = threadIdx.x+tile_start; idx < tile_end; idx+=BLOCK_SIZE) {
+  for (IdType idx = threadIdx.x + tile_start; idx < tile_end;
+       idx += BLOCK_SIZE) {
    const Mapping& mapping = *table.Search(global[idx]);
    new_global[idx] = mapping.local;
  }
 }
 
 /**
  * \brief Generate mapped edge endpoint ids.
  *
  * \tparam IdType The type of id.
  * \tparam BLOCK_SIZE The size of each thread block.
  * \tparam TILE_SIZE The number of edges to process per thread block.
  * \param global_srcs_device The source ids to map.
  * \param new_global_srcs_device The mapped source ids (output).
  * \param global_dsts_device The destination ids to map.
  * \param new_global_dsts_device The mapped destination ids (output).
  * \param num_edges The number of edges to map.
  * \param src_mapping The mapping of sources ids.
  * \param src_hash_size The the size of source id hash table/mapping.
  * \param dst_mapping The mapping of destination ids.
  * \param dst_hash_size The the size of destination id hash table/mapping.
  */
-template<typename IdType, int BLOCK_SIZE, IdType TILE_SIZE>
+template <typename IdType, int BLOCK_SIZE, IdType TILE_SIZE>
 __global__ void map_edge_ids(
-    const IdType * const global_srcs_device,
-    IdType * const new_global_srcs_device,
-    const IdType * const global_dsts_device,
-    IdType * const new_global_dsts_device,
-    const IdType num_edges,
+    const IdType* const global_srcs_device,
+    IdType* const new_global_srcs_device,
+    const IdType* const global_dsts_device,
+    IdType* const new_global_dsts_device, const IdType num_edges,
    DeviceOrderedHashTable<IdType> src_mapping,
    DeviceOrderedHashTable<IdType> dst_mapping) {
  assert(BLOCK_SIZE == blockDim.x);
@@ -88,87 +87,67 @@ __global__ void map_edge_ids(
   if (blockIdx.y == 0) {
     map_vertex_ids<IdType, BLOCK_SIZE, TILE_SIZE>(
-        global_srcs_device,
-        new_global_srcs_device,
-        num_edges,
-        src_mapping);
+        global_srcs_device, new_global_srcs_device, num_edges, src_mapping);
  } else {
    map_vertex_ids<IdType, BLOCK_SIZE, TILE_SIZE>(
-        global_dsts_device,
-        new_global_dsts_device,
-        num_edges,
-        dst_mapping);
+        global_dsts_device, new_global_dsts_device, num_edges, dst_mapping);
  }
 }
 
 /**
  * \brief Device level node maps for each node type.
  *
  * \param num_nodes Number of nodes per type.
- * \param offset When offset is set to 0, LhsHashTable is identical to RhsHashTable.
- *        Or set to num_nodes.size()/2 to use seperated LhsHashTable and RhsHashTable.
+ * \param offset When offset is set to 0, LhsHashTable is identical to
+ * RhsHashTable. Or set to num_nodes.size()/2 to use seperated
+ * LhsHashTable and RhsHashTable.
  * \param ctx The DGL context.
  * \param stream The stream to operate on.
  */
-template<typename IdType>
+template <typename IdType>
 class DeviceNodeMap {
  public:
   using Mapping = typename OrderedHashTable<IdType>::Mapping;
 
   DeviceNodeMap(
-      const std::vector<int64_t>& num_nodes,
-      const int64_t offset,
-      DGLContext ctx,
-      cudaStream_t stream) :
-      num_types_(num_nodes.size()),
-      rhs_offset_(offset),
-      hash_tables_(),
-      ctx_(ctx) {
+      const std::vector<int64_t>& num_nodes, const int64_t offset,
+      DGLContext ctx, cudaStream_t stream)
+      : num_types_(num_nodes.size()),
+        rhs_offset_(offset),
+        hash_tables_(),
+        ctx_(ctx) {
    auto device = runtime::DeviceAPI::Get(ctx);
 
    hash_tables_.reserve(num_types_);
    for (int64_t i = 0; i < num_types_; ++i) {
      hash_tables_.emplace_back(
-          new OrderedHashTable<IdType>(
-              num_nodes[i],
-              ctx_,
-              stream));
+          new OrderedHashTable<IdType>(num_nodes[i], ctx_, stream));
    }
  }
 
-  OrderedHashTable<IdType>& LhsHashTable(
-      const size_t index) {
+  OrderedHashTable<IdType>& LhsHashTable(const size_t index) {
    return HashData(index);
  }
 
-  OrderedHashTable<IdType>& RhsHashTable(
-      const size_t index) {
-    return HashData(index+rhs_offset_);
+  OrderedHashTable<IdType>& RhsHashTable(const size_t index) {
+    return HashData(index + rhs_offset_);
  }
 
-  const OrderedHashTable<IdType>& LhsHashTable(
-      const size_t index) const {
+  const OrderedHashTable<IdType>& LhsHashTable(const size_t index) const {
    return HashData(index);
  }
 
-  const OrderedHashTable<IdType>& RhsHashTable(
-      const size_t index) const {
-    return HashData(index+rhs_offset_);
+  const OrderedHashTable<IdType>& RhsHashTable(const size_t index) const {
+    return HashData(index + rhs_offset_);
  }
 
-  IdType LhsHashSize(
-      const size_t index) const {
-    return HashSize(index);
-  }
+  IdType LhsHashSize(const size_t index) const { return HashSize(index); }
 
-  IdType RhsHashSize(
-      const size_t index) const {
-    return HashSize(rhs_offset_+index);
+  IdType RhsHashSize(const size_t index) const {
+    return HashSize(rhs_offset_ + index);
  }
 
-  size_t Size() const {
-    return hash_tables_.size();
-  }
+  size_t Size() const { return hash_tables_.size(); }
 
  private:
   int64_t num_types_;
@@ -176,45 +155,35 @@ class DeviceNodeMap {
   std::vector<std::unique_ptr<OrderedHashTable<IdType>>> hash_tables_;
   DGLContext ctx_;
 
-  inline OrderedHashTable<IdType>& HashData(
-      const size_t index) {
+  inline OrderedHashTable<IdType>& HashData(const size_t index) {
    CHECK_LT(index, hash_tables_.size());
    return *hash_tables_[index];
  }
 
-  inline const OrderedHashTable<IdType>& HashData(
-      const size_t index) const {
+  inline const OrderedHashTable<IdType>& HashData(const size_t index) const {
    CHECK_LT(index, hash_tables_.size());
    return *hash_tables_[index];
  }
 
-  inline IdType HashSize(
-      const size_t index) const {
+  inline IdType HashSize(const size_t index) const {
    return HashData(index).size();
  }
 };
 
-template<typename IdType>
-inline size_t RoundUpDiv(
-    const IdType num,
-    const size_t divisor) {
-  return static_cast<IdType>(num/divisor) + (num % divisor == 0 ? 0 : 1);
+template <typename IdType>
+inline size_t RoundUpDiv(const IdType num, const size_t divisor) {
+  return static_cast<IdType>(num / divisor) + (num % divisor == 0 ? 0 : 1);
 }
 
-template<typename IdType>
-inline IdType RoundUp(
-    const IdType num,
-    const size_t unit) {
-  return RoundUpDiv(num, unit)*unit;
+template <typename IdType>
+inline IdType RoundUp(const IdType num, const size_t unit) {
+  return RoundUpDiv(num, unit) * unit;
 }
 
-template<typename IdType>
-std::tuple<std::vector<IdArray>, std::vector<IdArray>>
-MapEdges(
-    HeteroGraphPtr graph,
-    const std::vector<EdgeArray>& edge_sets,
-    const DeviceNodeMap<IdType>& node_map,
-    cudaStream_t stream) {
+template <typename IdType>
+std::tuple<std::vector<IdArray>, std::vector<IdArray>> MapEdges(
+    HeteroGraphPtr graph, const std::vector<EdgeArray>& edge_sets,
+    const DeviceNodeMap<IdType>& node_map, cudaStream_t stream) {
   constexpr const int BLOCK_SIZE = 128;
   constexpr const size_t TILE_SIZE = 1024;
@@ -233,8 +202,8 @@ MapEdges(
     if (edges.id.defined() && edges.src->shape[0] > 0) {
       const int64_t num_edges = edges.src->shape[0];
 
-      new_lhs.emplace_back(NewIdArray(num_edges, ctx, sizeof(IdType)*8));
-      new_rhs.emplace_back(NewIdArray(num_edges, ctx, sizeof(IdType)*8));
+      new_lhs.emplace_back(NewIdArray(num_edges, ctx, sizeof(IdType) * 8));
+      new_rhs.emplace_back(NewIdArray(num_edges, ctx, sizeof(IdType) * 8));
 
      const auto src_dst_types = graph->GetEndpointTypes(etype);
      const int src_type = src_dst_types.first;
@@ -244,20 +213,17 @@ MapEdges(
       const dim3 block(BLOCK_SIZE);
 
       // map the srcs
-      CUDA_KERNEL_CALL((map_edge_ids<IdType, BLOCK_SIZE, TILE_SIZE>),
-          grid, block, 0, stream,
-          edges.src.Ptr<IdType>(),
-          new_lhs.back().Ptr<IdType>(),
-          edges.dst.Ptr<IdType>(),
-          new_rhs.back().Ptr<IdType>(),
-          num_edges,
-          node_map.LhsHashTable(src_type).DeviceHandle(),
-          node_map.RhsHashTable(dst_type).DeviceHandle());
+      CUDA_KERNEL_CALL(
+          (map_edge_ids<IdType, BLOCK_SIZE, TILE_SIZE>), grid, block, 0, stream,
+          edges.src.Ptr<IdType>(), new_lhs.back().Ptr<IdType>(),
+          edges.dst.Ptr<IdType>(), new_rhs.back().Ptr<IdType>(), num_edges,
+          node_map.LhsHashTable(src_type).DeviceHandle(),
+          node_map.RhsHashTable(dst_type).DeviceHandle());
    } else {
      new_lhs.emplace_back(
-          aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType)*8, 1}, ctx));
+          aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType) * 8, 1}, ctx));
      new_rhs.emplace_back(
-          aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType)*8, 1}, ctx));
+          aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType) * 8, 1}, ctx));
    }
  }
@@ -265,7 +231,6 @@ MapEdges(
       std::move(new_lhs), std::move(new_rhs));
 }
 
 }  // namespace cuda
 }  // namespace transform
 }  // namespace dgl
...
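The RoundUpDiv/RoundUp helpers reformatted above size the kernel grid (one thread block per TILE_SIZE edges). Their bodies are fully visible in the hunk, so they can be checked in isolation; a quick standalone sanity test of the ceiling-division behavior:

    #include <cassert>
    #include <cstddef>
    #include <cstdint>

    // Same bodies as the helpers in the header above, copied out of DGL.
    template <typename IdType>
    inline size_t RoundUpDiv(const IdType num, const size_t divisor) {
      return static_cast<IdType>(num / divisor) + (num % divisor == 0 ? 0 : 1);
    }

    template <typename IdType>
    inline IdType RoundUp(const IdType num, const size_t unit) {
      return RoundUpDiv(num, unit) * unit;
    }

    int main() {
      // grid sizing: 2500 edges with TILE_SIZE = 1024 needs 3 blocks
      assert(RoundUpDiv<int64_t>(2500, 1024) == 3);
      assert(RoundUp<int64_t>(2500, 1024) == 3072);
      assert(RoundUpDiv<int64_t>(2048, 1024) == 2);  // exact multiple
    }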
...@@ -18,13 +18,13 @@ ...@@ -18,13 +18,13 @@
* ids. * ids.
*/ */
#include <dgl/runtime/device_api.h>
#include <dgl/immutable_graph.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <utility> #include <dgl/immutable_graph.h>
#include <dgl/runtime/device_api.h>
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <utility>
#include "../../../runtime/cuda/cuda_common.h" #include "../../../runtime/cuda/cuda_common.h"
#include "../../heterograph.h" #include "../../heterograph.h"
...@@ -40,42 +40,36 @@ namespace transform { ...@@ -40,42 +40,36 @@ namespace transform {
namespace { namespace {
template<typename IdType> template <typename IdType>
class DeviceNodeMapMaker { class DeviceNodeMapMaker {
public: public:
explicit DeviceNodeMapMaker( explicit DeviceNodeMapMaker(const std::vector<int64_t>& maxNodesPerType)
const std::vector<int64_t>& maxNodesPerType) : : max_num_nodes_(0) {
max_num_nodes_(0) { max_num_nodes_ =
max_num_nodes_ = *std::max_element(maxNodesPerType.begin(), *std::max_element(maxNodesPerType.begin(), maxNodesPerType.end());
maxNodesPerType.end());
} }
/** /**
* \brief This function builds node maps for each node type, preserving the * \brief This function builds node maps for each node type, preserving the
* order of the input nodes. Here it is assumed the lhs_nodes are not unique, * order of the input nodes. Here it is assumed the lhs_nodes are not unique,
* and thus a unique list is generated. * and thus a unique list is generated.
* *
* \param lhs_nodes The set of source input nodes. * \param lhs_nodes The set of source input nodes.
* \param rhs_nodes The set of destination input nodes. * \param rhs_nodes The set of destination input nodes.
* \param node_maps The node maps to be constructed. * \param node_maps The node maps to be constructed.
* \param count_lhs_device The number of unique source nodes (on the GPU). * \param count_lhs_device The number of unique source nodes (on the GPU).
* \param lhs_device The unique source nodes (on the GPU). * \param lhs_device The unique source nodes (on the GPU).
* \param stream The stream to operate on. * \param stream The stream to operate on.
*/ */
void Make( void Make(
const std::vector<IdArray>& lhs_nodes, const std::vector<IdArray>& lhs_nodes,
const std::vector<IdArray>& rhs_nodes, const std::vector<IdArray>& rhs_nodes,
DeviceNodeMap<IdType> * const node_maps, DeviceNodeMap<IdType>* const node_maps, int64_t* const count_lhs_device,
int64_t * const count_lhs_device, std::vector<IdArray>* const lhs_device, cudaStream_t stream) {
std::vector<IdArray>* const lhs_device,
cudaStream_t stream) {
const int64_t num_ntypes = lhs_nodes.size() + rhs_nodes.size(); const int64_t num_ntypes = lhs_nodes.size() + rhs_nodes.size();
CUDA_CALL(cudaMemsetAsync( CUDA_CALL(cudaMemsetAsync(
count_lhs_device, count_lhs_device, 0, num_ntypes * sizeof(*count_lhs_device), stream));
0,
num_ntypes*sizeof(*count_lhs_device),
stream));
// possibly duplicate lhs nodes // possibly duplicate lhs nodes
const int64_t lhs_num_ntypes = static_cast<int64_t>(lhs_nodes.size()); const int64_t lhs_num_ntypes = static_cast<int64_t>(lhs_nodes.size());
...@@ -84,10 +78,8 @@ class DeviceNodeMapMaker { ...@@ -84,10 +78,8 @@ class DeviceNodeMapMaker {
if (nodes->shape[0] > 0) { if (nodes->shape[0] > 0) {
CHECK_EQ(nodes->ctx.device_type, kDGLCUDA); CHECK_EQ(nodes->ctx.device_type, kDGLCUDA);
node_maps->LhsHashTable(ntype).FillWithDuplicates( node_maps->LhsHashTable(ntype).FillWithDuplicates(
nodes.Ptr<IdType>(), nodes.Ptr<IdType>(), nodes->shape[0],
nodes->shape[0], (*lhs_device)[ntype].Ptr<IdType>(), count_lhs_device + ntype,
(*lhs_device)[ntype].Ptr<IdType>(),
count_lhs_device+ntype,
stream); stream);
} }
} }
...@@ -98,28 +90,25 @@ class DeviceNodeMapMaker { ...@@ -98,28 +90,25 @@ class DeviceNodeMapMaker {
const IdArray& nodes = rhs_nodes[ntype]; const IdArray& nodes = rhs_nodes[ntype];
if (nodes->shape[0] > 0) { if (nodes->shape[0] > 0) {
node_maps->RhsHashTable(ntype).FillWithUnique( node_maps->RhsHashTable(ntype).FillWithUnique(
nodes.Ptr<IdType>(), nodes.Ptr<IdType>(), nodes->shape[0], stream);
nodes->shape[0],
stream);
} }
} }
} }
/** /**
* \brief This function builds node maps for each node type, preserving the * \brief This function builds node maps for each node type, preserving the
* order of the input nodes. Here it is assumed both lhs_nodes and rhs_nodes * order of the input nodes. Here it is assumed both lhs_nodes and rhs_nodes
* are unique. * are unique.
* *
* \param lhs_nodes The set of source input nodes. * \param lhs_nodes The set of source input nodes.
* \param rhs_nodes The set of destination input nodes. * \param rhs_nodes The set of destination input nodes.
* \param node_maps The node maps to be constructed. * \param node_maps The node maps to be constructed.
* \param stream The stream to operate on. * \param stream The stream to operate on.
*/ */
void Make( void Make(
const std::vector<IdArray>& lhs_nodes, const std::vector<IdArray>& lhs_nodes,
const std::vector<IdArray>& rhs_nodes, const std::vector<IdArray>& rhs_nodes,
DeviceNodeMap<IdType> * const node_maps, DeviceNodeMap<IdType>* const node_maps, cudaStream_t stream) {
cudaStream_t stream) {
const int64_t num_ntypes = lhs_nodes.size() + rhs_nodes.size(); const int64_t num_ntypes = lhs_nodes.size() + rhs_nodes.size();
// unique lhs nodes // unique lhs nodes
...@@ -129,9 +118,7 @@ class DeviceNodeMapMaker { ...@@ -129,9 +118,7 @@ class DeviceNodeMapMaker {
if (nodes->shape[0] > 0) { if (nodes->shape[0] > 0) {
CHECK_EQ(nodes->ctx.device_type, kDGLCUDA); CHECK_EQ(nodes->ctx.device_type, kDGLCUDA);
node_maps->LhsHashTable(ntype).FillWithUnique( node_maps->LhsHashTable(ntype).FillWithUnique(
nodes.Ptr<IdType>(), nodes.Ptr<IdType>(), nodes->shape[0], stream);
nodes->shape[0],
stream);
} }
} }
...@@ -141,9 +128,7 @@ class DeviceNodeMapMaker { ...@@ -141,9 +128,7 @@ class DeviceNodeMapMaker {
const IdArray& nodes = rhs_nodes[ntype]; const IdArray& nodes = rhs_nodes[ntype];
if (nodes->shape[0] > 0) { if (nodes->shape[0] > 0) {
node_maps->RhsHashTable(ntype).FillWithUnique( node_maps->RhsHashTable(ntype).FillWithUnique(
nodes.Ptr<IdType>(), nodes.Ptr<IdType>(), nodes->shape[0], stream);
nodes->shape[0],
stream);
} }
} }
} }
...@@ -152,20 +137,15 @@ class DeviceNodeMapMaker { ...@@ -152,20 +137,15 @@ class DeviceNodeMapMaker {
IdType max_num_nodes_; IdType max_num_nodes_;
}; };
// Since partial specialization is not allowed for functions, use this as an // Since partial specialization is not allowed for functions, use this as an
// intermediate for ToBlock where XPU = kDGLCUDA. // intermediate for ToBlock where XPU = kDGLCUDA.
template<typename IdType> template <typename IdType>
std::tuple<HeteroGraphPtr, std::vector<IdArray>> std::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlockGPU(
ToBlockGPU( HeteroGraphPtr graph, const std::vector<IdArray>& rhs_nodes,
HeteroGraphPtr graph, const bool include_rhs_in_lhs, std::vector<IdArray>* const lhs_nodes_ptr) {
const std::vector<IdArray> &rhs_nodes,
const bool include_rhs_in_lhs,
std::vector<IdArray>* const lhs_nodes_ptr) {
std::vector<IdArray>& lhs_nodes = *lhs_nodes_ptr; std::vector<IdArray>& lhs_nodes = *lhs_nodes_ptr;
const bool generate_lhs_nodes = lhs_nodes.empty(); const bool generate_lhs_nodes = lhs_nodes.empty();
const auto& ctx = graph->Context(); const auto& ctx = graph->Context();
auto device = runtime::DeviceAPI::Get(ctx); auto device = runtime::DeviceAPI::Get(ctx);
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
...@@ -176,16 +156,17 @@ ToBlockGPU( ...@@ -176,16 +156,17 @@ ToBlockGPU(
} }
// Since DST nodes are included in SRC nodes, a common requirement is to fetch // Since DST nodes are included in SRC nodes, a common requirement is to fetch
// the DST node features from the SRC nodes features. To avoid expensive sparse lookup, // the DST node features from the SRC nodes features. To avoid expensive
// the function assures that the DST nodes in both SRC and DST sets have the same ids. // sparse lookup, the function assures that the DST nodes in both SRC and DST
// As a result, given the node feature tensor ``X`` of type ``utype``, // sets have the same ids. As a result, given the node feature tensor ``X`` of
// the following code finds the corresponding DST node features of type ``vtype``: // type ``utype``, the following code finds the corresponding DST node
// features of type ``vtype``:
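//
//   // a sketch, assuming the invariant above; X stores SRC features in lhs
//   // order, so the leading rows are exactly the DST nodes:
//   X_dst = X[0 : num_dst_nodes(vtype)]
//
// (The accessor name is an assumption; the guarantee is only that DST nodes
// occupy the first ids of the SRC id space.)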
const int64_t num_etypes = graph->NumEdgeTypes(); const int64_t num_etypes = graph->NumEdgeTypes();
const int64_t num_ntypes = graph->NumVertexTypes(); const int64_t num_ntypes = graph->NumVertexTypes();
CHECK(rhs_nodes.size() == static_cast<size_t>(num_ntypes)) CHECK(rhs_nodes.size() == static_cast<size_t>(num_ntypes))
<< "rhs_nodes not given for every node type"; << "rhs_nodes not given for every node type";
std::vector<EdgeArray> edge_arrays(num_etypes); std::vector<EdgeArray> edge_arrays(num_etypes);
for (int64_t etype = 0; etype < num_etypes; ++etype) { for (int64_t etype = 0; etype < num_etypes; ++etype) {
...@@ -197,9 +178,9 @@ ToBlockGPU( ...@@ -197,9 +178,9 @@ ToBlockGPU(
} }
// count lhs and rhs nodes // count lhs and rhs nodes
std::vector<int64_t> maxNodesPerType(num_ntypes*2, 0); std::vector<int64_t> maxNodesPerType(num_ntypes * 2, 0);
for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) { for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
maxNodesPerType[ntype+num_ntypes] += rhs_nodes[ntype]->shape[0]; maxNodesPerType[ntype + num_ntypes] += rhs_nodes[ntype]->shape[0];
if (generate_lhs_nodes) { if (generate_lhs_nodes) {
if (include_rhs_in_lhs) { if (include_rhs_in_lhs) {
...@@ -226,16 +207,16 @@ ToBlockGPU( ...@@ -226,16 +207,16 @@ ToBlockGPU(
if (generate_lhs_nodes) { if (generate_lhs_nodes) {
std::vector<int64_t> src_node_offsets(num_ntypes, 0); std::vector<int64_t> src_node_offsets(num_ntypes, 0);
for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) { for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
src_nodes[ntype] = NewIdArray(maxNodesPerType[ntype], ctx, src_nodes[ntype] =
sizeof(IdType)*8); NewIdArray(maxNodesPerType[ntype], ctx, sizeof(IdType) * 8);
if (include_rhs_in_lhs) { if (include_rhs_in_lhs) {
// place rhs nodes first // place rhs nodes first
device->CopyDataFromTo(rhs_nodes[ntype].Ptr<IdType>(), 0, device->CopyDataFromTo(
src_nodes[ntype].Ptr<IdType>(), src_node_offsets[ntype], rhs_nodes[ntype].Ptr<IdType>(), 0, src_nodes[ntype].Ptr<IdType>(),
sizeof(IdType)*rhs_nodes[ntype]->shape[0], src_node_offsets[ntype],
rhs_nodes[ntype]->ctx, src_nodes[ntype]->ctx, sizeof(IdType) * rhs_nodes[ntype]->shape[0], rhs_nodes[ntype]->ctx,
rhs_nodes[ntype]->dtype); src_nodes[ntype]->ctx, rhs_nodes[ntype]->dtype);
src_node_offsets[ntype] += sizeof(IdType)*rhs_nodes[ntype]->shape[0]; src_node_offsets[ntype] += sizeof(IdType) * rhs_nodes[ntype]->shape[0];
} }
} }
for (int64_t etype = 0; etype < num_etypes; ++etype) { for (int64_t etype = 0; etype < num_etypes; ++etype) {
...@@ -244,14 +225,13 @@ ToBlockGPU( ...@@ -244,14 +225,13 @@ ToBlockGPU(
if (edge_arrays[etype].src.defined()) { if (edge_arrays[etype].src.defined()) {
device->CopyDataFromTo( device->CopyDataFromTo(
edge_arrays[etype].src.Ptr<IdType>(), 0, edge_arrays[etype].src.Ptr<IdType>(), 0,
src_nodes[srctype].Ptr<IdType>(), src_nodes[srctype].Ptr<IdType>(), src_node_offsets[srctype],
src_node_offsets[srctype], sizeof(IdType) * edge_arrays[etype].src->shape[0],
sizeof(IdType)*edge_arrays[etype].src->shape[0], rhs_nodes[srctype]->ctx, src_nodes[srctype]->ctx,
rhs_nodes[srctype]->ctx,
src_nodes[srctype]->ctx,
rhs_nodes[srctype]->dtype); rhs_nodes[srctype]->dtype);
src_node_offsets[srctype] += sizeof(IdType)*edge_arrays[etype].src->shape[0]; src_node_offsets[srctype] +=
sizeof(IdType) * edge_arrays[etype].src->shape[0];
} }
} }
} else { } else {
...@@ -267,47 +247,35 @@ ToBlockGPU( ...@@ -267,47 +247,35 @@ ToBlockGPU(
if (generate_lhs_nodes) { if (generate_lhs_nodes) {
lhs_nodes.reserve(num_ntypes); lhs_nodes.reserve(num_ntypes);
for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) { for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
lhs_nodes.emplace_back(NewIdArray( lhs_nodes.emplace_back(
maxNodesPerType[ntype], ctx, sizeof(IdType)*8)); NewIdArray(maxNodesPerType[ntype], ctx, sizeof(IdType) * 8));
} }
} }
std::vector<int64_t> num_nodes_per_type(num_ntypes*2); std::vector<int64_t> num_nodes_per_type(num_ntypes * 2);
// populate RHS nodes from what we already know // populate RHS nodes from what we already know
for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) { for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
num_nodes_per_type[num_ntypes+ntype] = rhs_nodes[ntype]->shape[0]; num_nodes_per_type[num_ntypes + ntype] = rhs_nodes[ntype]->shape[0];
} }
// populate the mappings // populate the mappings
if (generate_lhs_nodes) { if (generate_lhs_nodes) {
int64_t * count_lhs_device = static_cast<int64_t*>( int64_t* count_lhs_device = static_cast<int64_t*>(
device->AllocWorkspace(ctx, sizeof(int64_t)*num_ntypes*2)); device->AllocWorkspace(ctx, sizeof(int64_t) * num_ntypes * 2));
maker.Make( maker.Make(
src_nodes, src_nodes, rhs_nodes, &node_maps, count_lhs_device, &lhs_nodes, stream);
rhs_nodes,
&node_maps,
count_lhs_device,
&lhs_nodes,
stream);
device->CopyDataFromTo( device->CopyDataFromTo(
count_lhs_device, 0, count_lhs_device, 0, num_nodes_per_type.data(), 0,
num_nodes_per_type.data(), 0, sizeof(*num_nodes_per_type.data()) * num_ntypes, ctx,
sizeof(*num_nodes_per_type.data())*num_ntypes, DGLContext{kDGLCPU, 0}, DGLDataType{kDGLInt, 64, 1});
ctx,
DGLContext{kDGLCPU, 0},
DGLDataType{kDGLInt, 64, 1});
device->StreamSync(ctx, stream); device->StreamSync(ctx, stream);
// wait for the node counts to finish transferring // wait for the node counts to finish transferring
device->FreeWorkspace(ctx, count_lhs_device); device->FreeWorkspace(ctx, count_lhs_device);
} else { } else {
maker.Make( maker.Make(lhs_nodes, rhs_nodes, &node_maps, stream);
lhs_nodes,
rhs_nodes,
&node_maps,
stream);
for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) { for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
num_nodes_per_type[ntype] = lhs_nodes[ntype]->shape[0]; num_nodes_per_type[ntype] = lhs_nodes[ntype]->shape[0];
...@@ -321,7 +289,7 @@ ToBlockGPU( ...@@ -321,7 +289,7 @@ ToBlockGPU(
induced_edges.push_back(edge_arrays[etype].id); induced_edges.push_back(edge_arrays[etype].id);
} else { } else {
induced_edges.push_back( induced_edges.push_back(
aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType)*8, 1}, ctx)); aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType) * 8, 1}, ctx));
} }
} }
...@@ -329,8 +297,8 @@ ToBlockGPU( ...@@ -329,8 +297,8 @@ ToBlockGPU(
const auto meta_graph = graph->meta_graph(); const auto meta_graph = graph->meta_graph();
const EdgeArray etypes = meta_graph->Edges("eid"); const EdgeArray etypes = meta_graph->Edges("eid");
const IdArray new_dst = Add(etypes.dst, num_ntypes); const IdArray new_dst = Add(etypes.dst, num_ntypes);
const auto new_meta_graph = ImmutableGraph::CreateFromCOO( const auto new_meta_graph =
num_ntypes * 2, etypes.src, new_dst); ImmutableGraph::CreateFromCOO(num_ntypes * 2, etypes.src, new_dst);
// allocate vector for graph relations while GPU is busy // allocate vector for graph relations while GPU is busy
std::vector<HeteroGraphPtr> rel_graphs; std::vector<HeteroGraphPtr> rel_graphs;
...@@ -358,20 +326,17 @@ ToBlockGPU( ...@@ -358,20 +326,17 @@ ToBlockGPU(
// No rhs nodes are given for this edge type. Create an empty graph. // No rhs nodes are given for this edge type. Create an empty graph.
rel_graphs.push_back(CreateFromCOO( rel_graphs.push_back(CreateFromCOO(
2, lhs_nodes[srctype]->shape[0], rhs_nodes[dsttype]->shape[0], 2, lhs_nodes[srctype]->shape[0], rhs_nodes[dsttype]->shape[0],
aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType)*8, 1}, ctx), aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType) * 8, 1}, ctx),
aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType)*8, 1}, ctx))); aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType) * 8, 1}, ctx)));
} else { } else {
rel_graphs.push_back(CreateFromCOO( rel_graphs.push_back(CreateFromCOO(
2, 2, lhs_nodes[srctype]->shape[0], rhs_nodes[dsttype]->shape[0],
lhs_nodes[srctype]->shape[0], new_lhs[etype], new_rhs[etype]));
rhs_nodes[dsttype]->shape[0],
new_lhs[etype],
new_rhs[etype]));
} }
} }
HeteroGraphPtr new_graph = CreateHeteroGraph( HeteroGraphPtr new_graph =
new_meta_graph, rel_graphs, num_nodes_per_type); CreateHeteroGraph(new_meta_graph, rel_graphs, num_nodes_per_type);
// return the new graph, the new src nodes, and new edges // return the new graph, the new src nodes, and new edges
return std::make_tuple(new_graph, induced_edges); return std::make_tuple(new_graph, induced_edges);
...@@ -379,26 +344,22 @@ ToBlockGPU( ...@@ -379,26 +344,22 @@ ToBlockGPU(
} // namespace } // namespace
// Use explicit names to get around MSVC's broken mangling that thinks the following two // Use explicit names to get around MSVC's broken mangling that thinks the
// functions are the same. // following two functions are the same. Using template<> fails to export the
// Using template<> fails to export the symbols. // symbols.
std::tuple<HeteroGraphPtr, std::vector<IdArray>> std::tuple<HeteroGraphPtr, std::vector<IdArray>>
// ToBlock<kDGLCUDA, int32_t> // ToBlock<kDGLCUDA, int32_t>
ToBlockGPU32( ToBlockGPU32(
HeteroGraphPtr graph, HeteroGraphPtr graph, const std::vector<IdArray>& rhs_nodes,
const std::vector<IdArray> &rhs_nodes, bool include_rhs_in_lhs, std::vector<IdArray>* const lhs_nodes) {
bool include_rhs_in_lhs,
std::vector<IdArray>* const lhs_nodes) {
return ToBlockGPU<int32_t>(graph, rhs_nodes, include_rhs_in_lhs, lhs_nodes); return ToBlockGPU<int32_t>(graph, rhs_nodes, include_rhs_in_lhs, lhs_nodes);
} }
std::tuple<HeteroGraphPtr, std::vector<IdArray>> std::tuple<HeteroGraphPtr, std::vector<IdArray>>
// ToBlock<kDGLCUDA, int64_t> // ToBlock<kDGLCUDA, int64_t>
ToBlockGPU64( ToBlockGPU64(
HeteroGraphPtr graph, HeteroGraphPtr graph, const std::vector<IdArray>& rhs_nodes,
const std::vector<IdArray> &rhs_nodes, bool include_rhs_in_lhs, std::vector<IdArray>* const lhs_nodes) {
bool include_rhs_in_lhs,
std::vector<IdArray>* const lhs_nodes) {
return ToBlockGPU<int64_t>(graph, rhs_nodes, include_rhs_in_lhs, lhs_nodes); return ToBlockGPU<int64_t>(graph, rhs_nodes, include_rhs_in_lhs, lhs_nodes);
} }
...@@ -4,9 +4,11 @@ ...@@ -4,9 +4,11 @@
* \brief k-nearest-neighbor (KNN) interface * \brief k-nearest-neighbor (KNN) interface
*/ */
#include <dgl/runtime/registry.h>
#include <dgl/runtime/packed_func.h>
#include "knn.h" #include "knn.h"
#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/registry.h>
#include "../../array/check.h" #include "../../array/check.h"
using namespace dgl::runtime; using namespace dgl::runtime;
...@@ -14,57 +16,59 @@ namespace dgl { ...@@ -14,57 +16,59 @@ namespace dgl {
namespace transform { namespace transform {
DGL_REGISTER_GLOBAL("transform._CAPI_DGLKNN") DGL_REGISTER_GLOBAL("transform._CAPI_DGLKNN")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
const NDArray data_points = args[0]; const NDArray data_points = args[0];
const IdArray data_offsets = args[1]; const IdArray data_offsets = args[1];
const NDArray query_points = args[2]; const NDArray query_points = args[2];
const IdArray query_offsets = args[3]; const IdArray query_offsets = args[3];
const int k = args[4]; const int k = args[4];
IdArray result = args[5]; IdArray result = args[5];
const std::string algorithm = args[6]; const std::string algorithm = args[6];
aten::CheckContiguous( aten::CheckContiguous(
{data_points, data_offsets, query_points, query_offsets, result}, {data_points, data_offsets, query_points, query_offsets, result},
{"data_points", "data_offsets", "query_points", "query_offsets", "result"}); {"data_points", "data_offsets", "query_points", "query_offsets",
aten::CheckCtx( "result"});
data_points->ctx, {data_offsets, query_points, query_offsets, result}, aten::CheckCtx(
{"data_offsets", "query_points", "query_offsets", "result"}); data_points->ctx, {data_offsets, query_points, query_offsets, result},
{"data_offsets", "query_points", "query_offsets", "result"});
ATEN_XPU_SWITCH_CUDA(data_points->ctx.device_type, XPU, "KNN", { ATEN_XPU_SWITCH_CUDA(data_points->ctx.device_type, XPU, "KNN", {
ATEN_FLOAT_TYPE_SWITCH(data_points->dtype, FloatType, "data_points", { ATEN_FLOAT_TYPE_SWITCH(data_points->dtype, FloatType, "data_points", {
ATEN_ID_TYPE_SWITCH(result->dtype, IdType, { ATEN_ID_TYPE_SWITCH(result->dtype, IdType, {
KNN<XPU, FloatType, IdType>( KNN<XPU, FloatType, IdType>(
data_points, data_offsets, query_points, data_points, data_offsets, query_points, query_offsets, k,
query_offsets, k, result, algorithm); result, algorithm);
});
}); });
}); });
}); });
});
DGL_REGISTER_GLOBAL("transform._CAPI_DGLNNDescent") DGL_REGISTER_GLOBAL("transform._CAPI_DGLNNDescent")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
const NDArray points = args[0]; const NDArray points = args[0];
const IdArray offsets = args[1]; const IdArray offsets = args[1];
const IdArray result = args[2]; const IdArray result = args[2];
const int k = args[3]; const int k = args[3];
const int num_iters = args[4]; const int num_iters = args[4];
const int num_candidates = args[5]; const int num_candidates = args[5];
const double delta = args[6]; const double delta = args[6];
aten::CheckContiguous( aten::CheckContiguous(
{points, offsets, result}, {"points", "offsets", "result"}); {points, offsets, result}, {"points", "offsets", "result"});
aten::CheckCtx( aten::CheckCtx(
points->ctx, {points, offsets, result}, {"points", "offsets", "result"}); points->ctx, {points, offsets, result},
{"points", "offsets", "result"});
ATEN_XPU_SWITCH_CUDA(points->ctx.device_type, XPU, "NNDescent", { ATEN_XPU_SWITCH_CUDA(points->ctx.device_type, XPU, "NNDescent", {
ATEN_FLOAT_TYPE_SWITCH(points->dtype, FloatType, "points", { ATEN_FLOAT_TYPE_SWITCH(points->dtype, FloatType, "points", {
ATEN_ID_TYPE_SWITCH(result->dtype, IdType, { ATEN_ID_TYPE_SWITCH(result->dtype, IdType, {
NNDescent<XPU, FloatType, IdType>( NNDescent<XPU, FloatType, IdType>(
points, offsets, result, k, num_iters, num_candidates, delta); points, offsets, result, k, num_iters, num_candidates, delta);
});
}); });
}); });
}); });
});
} // namespace transform } // namespace transform
} // namespace dgl } // namespace dgl
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#define DGL_GRAPH_TRANSFORM_KNN_H_ #define DGL_GRAPH_TRANSFORM_KNN_H_
#include <dgl/array.h> #include <dgl/array.h>
#include <string> #include <string>
namespace dgl { namespace dgl {
...@@ -15,42 +16,45 @@ namespace transform { ...@@ -15,42 +16,45 @@ namespace transform {
/*! /*!
* \brief For each point in each segment in \a query_points, find \a k nearest * \brief For each point in each segment in \a query_points, find \a k nearest
* points in the same segment in \a data_points. \a data_offsets and \a query_offsets * points in the same segment in \a data_points. \a data_offsets and \a
* determine the start index of each segment in \a data_points and \a query_points. * query_offsets determine the start index of each segment in \a
* data_points and \a query_points.
* *
* \param data_points dataset points. * \param data_points dataset points.
* \param data_offsets offsets of point index in \a data_points. * \param data_offsets offsets of point index in \a data_points.
* \param query_points query points. * \param query_points query points.
* \param query_offsets offsets of point index in \a query_points. * \param query_offsets offsets of point index in \a query_points.
* \param k the number of nearest points. * \param k the number of nearest points.
* \param result output array. A 2D tensor indicating the index * \param result output array. A 2D tensor indicating the index relation
* relation between \a query_points and \a data_points. * between \a query_points and \a data_points.
* \param algorithm algorithm used to compute the k-nearest neighbors. * \param algorithm algorithm used to compute the k-nearest neighbors.
*/ */
template <DGLDeviceType XPU, typename FloatType, typename IdType> template <DGLDeviceType XPU, typename FloatType, typename IdType>
void KNN(const NDArray& data_points, const IdArray& data_offsets, void KNN(
const NDArray& query_points, const IdArray& query_offsets, const NDArray& data_points, const IdArray& data_offsets,
const int k, IdArray result, const std::string& algorithm); const NDArray& query_points, const IdArray& query_offsets, const int k,
IdArray result, const std::string& algorithm);
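//
// A worked sketch of the offsets convention (inferred from the description
// above): with two segments of 3 and 2 data points, data_offsets = [0, 3, 5],
// and segment i spans indices [data_offsets[i], data_offsets[i + 1]) in
// data_points; query_offsets partitions query_points the same way.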
/*! /*!
* \brief For each input point, find \a k approximate nearest points in the same * \brief For each input point, find \a k approximate nearest points in the same
 * segment using the NN-descent algorithm. * segment using the NN-descent algorithm.
* *
* \param points input points. * \param points input points.
* \param offsets offsets of point index. * \param offsets offsets of point index.
* \param result output array. A 2D tensor indicating the index relation between points. * \param result output array. A 2D tensor indicating the index relation between
* points.
* \param k the number of nearest points. * \param k the number of nearest points.
* \param num_iters The maximum number of NN-descent iterations to perform. * \param num_iters The maximum number of NN-descent iterations to perform.
* \param num_candidates The maximum number of candidates to be considered during one iteration. * \param num_candidates The maximum number of candidates to be considered
* during one iteration.
 * \param delta A value that controls early abort. * \param delta A value that controls early abort.
*/ */
template <DGLDeviceType XPU, typename FloatType, typename IdType> template <DGLDeviceType XPU, typename FloatType, typename IdType>
void NNDescent(const NDArray& points, const IdArray& offsets, void NNDescent(
IdArray result, const int k, const int num_iters, const NDArray& points, const IdArray& offsets, IdArray result, const int k,
const int num_candidates, const double delta); const int num_iters, const int num_candidates, const double delta);
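//
// By the usual NN-descent convention (an assumption here, not stated in this
// header), the iteration aborts early once the number of neighbor updates in
// a round drops below delta * k * num_points, so a smaller delta means more
// iterations and a more accurate neighbor graph.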
} // namespace transform } // namespace transform
} // namespace dgl } // namespace dgl
#endif // DGL_GRAPH_TRANSFORM_KNN_H_ #endif // DGL_GRAPH_TRANSFORM_KNN_H_
...@@ -4,12 +4,14 @@ ...@@ -4,12 +4,14 @@
* \brief Line graph implementation * \brief Line graph implementation
*/ */
#include <dgl/base_heterograph.h>
#include <dgl/transform.h>
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <dgl/packed_func_ext.h> #include <dgl/packed_func_ext.h>
#include <vector> #include <dgl/transform.h>
#include <utility> #include <utility>
#include <vector>
#include "../../c_api_common.h" #include "../../c_api_common.h"
#include "../heterograph.h" #include "../heterograph.h"
...@@ -21,26 +23,25 @@ using namespace dgl::aten; ...@@ -21,26 +23,25 @@ using namespace dgl::aten;
namespace transform { namespace transform {
/*! /*!
* \brief Create Line Graph * \brief Create Line Graph.
* \param hg Graph * \param hg Graph.
 * \param backtracking whether the pair of edges (v, u) and (u, v) is treated as linked * \param backtracking whether the pair of edges (v, u) and (u, v) is treated
 * \return The Line Graph * as linked.
* \return The Line Graph.
*/ */
HeteroGraphPtr CreateLineGraph( HeteroGraphPtr CreateLineGraph(HeteroGraphPtr hg, bool backtracking) {
HeteroGraphPtr hg,
bool backtracking) {
const auto hgp = std::dynamic_pointer_cast<HeteroGraph>(hg); const auto hgp = std::dynamic_pointer_cast<HeteroGraph>(hg);
return hgp->LineGraph(backtracking); return hgp->LineGraph(backtracking);
} }
DGL_REGISTER_GLOBAL("transform._CAPI_DGLHeteroLineGraph") DGL_REGISTER_GLOBAL("transform._CAPI_DGLHeteroLineGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
bool backtracking = args[1]; bool backtracking = args[1];
auto hgptr = CreateLineGraph(hg.sptr(), backtracking); auto hgptr = CreateLineGraph(hg.sptr(), backtracking);
*rv = HeteroGraphRef(hgptr); *rv = HeteroGraphRef(hgptr);
}); });
}; // namespace transform }; // namespace transform
}; // namespace dgl }; // namespace dgl
...@@ -19,14 +19,15 @@ namespace transform { ...@@ -19,14 +19,15 @@ namespace transform {
#if !defined(_WIN32) #if !defined(_WIN32)
IdArray MetisPartition(UnitGraphPtr g, int k, NDArray vwgt_arr, IdArray MetisPartition(
const std::string &mode, bool obj_cut) { UnitGraphPtr g, int k, NDArray vwgt_arr, const std::string &mode,
bool obj_cut) {
// Mode can only be "k-way" or "recursive" // Mode can only be "k-way" or "recursive"
CHECK(mode == "k-way" || mode == "recursive") CHECK(mode == "k-way" || mode == "recursive")
<< "mode can only be \"k-way\" or \"recursive\""; << "mode can only be \"k-way\" or \"recursive\"";
// The index type of Metis needs to be compatible with DGL index type. // The index type of Metis needs to be compatible with DGL index type.
CHECK_EQ(sizeof(idx_t), sizeof(int64_t)) CHECK_EQ(sizeof(idx_t), sizeof(int64_t))
<< "Metis only supports int64 graph for now"; << "Metis only supports int64 graph for now";
// This is a symmetric graph, so in-csr and out-csr are the same. // This is a symmetric graph, so in-csr and out-csr are the same.
const auto mat = g->GetCSCMatrix(0); const auto mat = g->GetCSCMatrix(0);
// const auto mat = g->GetInCSR()->ToCSRMatrix(); // const auto mat = g->GetInCSR()->ToCSRMatrix();
...@@ -42,16 +43,17 @@ IdArray MetisPartition(UnitGraphPtr g, int k, NDArray vwgt_arr, ...@@ -42,16 +43,17 @@ IdArray MetisPartition(UnitGraphPtr g, int k, NDArray vwgt_arr,
int64_t vwgt_len = vwgt_arr->shape[0]; int64_t vwgt_len = vwgt_arr->shape[0];
CHECK_EQ(sizeof(idx_t), vwgt_arr->dtype.bits / 8) CHECK_EQ(sizeof(idx_t), vwgt_arr->dtype.bits / 8)
<< "The vertex weight array doesn't have right type"; << "The vertex weight array doesn't have right type";
CHECK(vwgt_len % g->NumVertices(0) == 0) CHECK(vwgt_len % g->NumVertices(0) == 0)
<< "The vertex weight array doesn't have right number of elements"; << "The vertex weight array doesn't have right number of elements";
idx_t *vwgt = NULL; idx_t *vwgt = NULL;
if (vwgt_len > 0) { if (vwgt_len > 0) {
ncon = vwgt_len / g->NumVertices(0); ncon = vwgt_len / g->NumVertices(0);
vwgt = static_cast<idx_t *>(vwgt_arr->data); vwgt = static_cast<idx_t *>(vwgt_arr->data);
} }
auto partition_func = (mode == "k-way") ? METIS_PartGraphKway : METIS_PartGraphRecursive; auto partition_func =
(mode == "k-way") ? METIS_PartGraphKway : METIS_PartGraphRecursive;
idx_t options[METIS_NOPTIONS]; idx_t options[METIS_NOPTIONS];
METIS_SetDefaultOptions(options); METIS_SetDefaultOptions(options);
...@@ -67,21 +69,21 @@ IdArray MetisPartition(UnitGraphPtr g, int k, NDArray vwgt_arr, ...@@ -67,21 +69,21 @@ IdArray MetisPartition(UnitGraphPtr g, int k, NDArray vwgt_arr,
} }
int ret = partition_func( int ret = partition_func(
&nvtxs, // The number of vertices &nvtxs, // The number of vertices
&ncon, // The number of balancing constraints. &ncon, // The number of balancing constraints.
xadj, // indptr xadj, // indptr
adjncy, // indices adjncy, // indices
vwgt, // the weights of the vertices vwgt, // the weights of the vertices
NULL, // The size of the vertices for computing NULL, // The size of the vertices for computing
// the total communication volume // the total communication volume
NULL, // The weights of the edges NULL, // The weights of the edges
&nparts, // The number of partitions. &nparts, // The number of partitions.
NULL, // the desired weight for each partition and constraint NULL, // the desired weight for each partition and constraint
NULL, // the allowed load imbalance tolerance NULL, // the allowed load imbalance tolerance
options, // the array of options options, // the array of options
&objval, // the edge-cut or the total communication volume of &objval, // the edge-cut or the total communication volume of
// the partitioning solution // the partitioning solution
part); part);
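// For reference, xadj/adjncy follow the CSR convention: a 3-node path graph
// 0-1-2 has xadj = {0, 1, 3, 4} and adjncy = {1, 0, 2, 1}, and METIS writes
// one partition id per vertex into part.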
if (obj_cut) { if (obj_cut) {
LOG(INFO) << "Partition a graph with " << g->NumVertices(0) << " nodes and " LOG(INFO) << "Partition a graph with " << g->NumVertices(0) << " nodes and "
...@@ -110,22 +112,22 @@ IdArray MetisPartition(UnitGraphPtr g, int k, NDArray vwgt_arr, ...@@ -110,22 +112,22 @@ IdArray MetisPartition(UnitGraphPtr g, int k, NDArray vwgt_arr,
#endif // !defined(_WIN32) #endif // !defined(_WIN32)
DGL_REGISTER_GLOBAL("partition._CAPI_DGLMetisPartition_Hetero") DGL_REGISTER_GLOBAL("partition._CAPI_DGLMetisPartition_Hetero")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroGraphRef g = args[0]; HeteroGraphRef g = args[0];
auto hgptr = std::dynamic_pointer_cast<HeteroGraph>(g.sptr()); auto hgptr = std::dynamic_pointer_cast<HeteroGraph>(g.sptr());
CHECK(hgptr) << "Invalid HeteroGraph object"; CHECK(hgptr) << "Invalid HeteroGraph object";
CHECK_EQ(hgptr->relation_graphs().size(), 1) CHECK_EQ(hgptr->relation_graphs().size(), 1)
<< "Metis partition only supports HomoGraph"; << "Metis partition only supports HomoGraph";
auto ugptr = hgptr->relation_graphs()[0]; auto ugptr = hgptr->relation_graphs()[0];
int k = args[1]; int k = args[1];
NDArray vwgt = args[2]; NDArray vwgt = args[2];
std::string mode = args[3]; std::string mode = args[3];
bool obj_cut = args[4]; bool obj_cut = args[4];
#if !defined(_WIN32) #if !defined(_WIN32)
*rv = MetisPartition(ugptr, k, vwgt, mode, obj_cut); *rv = MetisPartition(ugptr, k, vwgt, mode, obj_cut);
#else #else
LOG(FATAL) << "Metis partition does not support Windows."; LOG(FATAL) << "Metis partition does not support Windows.";
#endif // !defined(_WIN32) #endif // !defined(_WIN32)
}); });
} // namespace transform } // namespace transform
} // namespace dgl } // namespace dgl
...@@ -37,25 +37,28 @@ HeteroGraphPtr ReorderUnitGraph(UnitGraphPtr ug, IdArray new_order) { ...@@ -37,25 +37,28 @@ HeteroGraphPtr ReorderUnitGraph(UnitGraphPtr ug, IdArray new_order) {
if (format & CSC_CODE) { if (format & CSC_CODE) {
auto cscmat = ug->GetCSCMatrix(0); auto cscmat = ug->GetCSCMatrix(0);
auto new_cscmat = aten::CSRReorder(cscmat, new_order, new_order); auto new_cscmat = aten::CSRReorder(cscmat, new_order, new_order);
return UnitGraph::CreateFromCSC(ug->NumVertexTypes(), new_cscmat, ug->GetAllowedFormats()); return UnitGraph::CreateFromCSC(
ug->NumVertexTypes(), new_cscmat, ug->GetAllowedFormats());
} else if (format & CSR_CODE) { } else if (format & CSR_CODE) {
auto csrmat = ug->GetCSRMatrix(0); auto csrmat = ug->GetCSRMatrix(0);
auto new_csrmat = aten::CSRReorder(csrmat, new_order, new_order); auto new_csrmat = aten::CSRReorder(csrmat, new_order, new_order);
return UnitGraph::CreateFromCSR(ug->NumVertexTypes(), new_csrmat, ug->GetAllowedFormats()); return UnitGraph::CreateFromCSR(
ug->NumVertexTypes(), new_csrmat, ug->GetAllowedFormats());
} else { } else {
auto coomat = ug->GetCOOMatrix(0); auto coomat = ug->GetCOOMatrix(0);
auto new_coomat = aten::COOReorder(coomat, new_order, new_order); auto new_coomat = aten::COOReorder(coomat, new_order, new_order);
return UnitGraph::CreateFromCOO(ug->NumVertexTypes(), new_coomat, ug->GetAllowedFormats()); return UnitGraph::CreateFromCOO(
ug->NumVertexTypes(), new_coomat, ug->GetAllowedFormats());
} }
} }
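// A note on the dispatch above: it reorders whichever sparse format the graph
// exposes, preferring CSC, then CSR, then COO, so no format outside the
// allowed set has to be materialized (a reading of the code, not a documented
// guarantee).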
HaloHeteroSubgraph GetSubgraphWithHalo(std::shared_ptr<HeteroGraph> hg, HaloHeteroSubgraph GetSubgraphWithHalo(
IdArray nodes, int num_hops) { std::shared_ptr<HeteroGraph> hg, IdArray nodes, int num_hops) {
CHECK_EQ(hg->NumBits(), 64) << "halo subgraph only supports 64-bit graphs"; CHECK_EQ(hg->NumBits(), 64) << "halo subgraph only supports 64-bit graphs";
CHECK_EQ(hg->relation_graphs().size(), 1) CHECK_EQ(hg->relation_graphs().size(), 1)
<< "halo subgraph only supports homogeneous graph"; << "halo subgraph only supports homogeneous graph";
CHECK_EQ(nodes->dtype.bits, 64) CHECK_EQ(nodes->dtype.bits, 64)
<< "halo subgraph only supports 64bits nodes tensor"; << "halo subgraph only supports 64bits nodes tensor";
const dgl_id_t *nid = static_cast<dgl_id_t *>(nodes->data); const dgl_id_t *nid = static_cast<dgl_id_t *>(nodes->data);
const auto id_len = nodes->shape[0]; const auto id_len = nodes->shape[0];
// A map contains all nodes in the subgraph. // A map contains all nodes in the subgraph.
...@@ -113,8 +116,8 @@ HaloHeteroSubgraph GetSubgraphWithHalo(std::shared_ptr<HeteroGraph> hg, ...@@ -113,8 +116,8 @@ HaloHeteroSubgraph GetSubgraphWithHalo(std::shared_ptr<HeteroGraph> hg,
const dgl_id_t *eid_data = static_cast<dgl_id_t *>(eid->data); const dgl_id_t *eid_data = static_cast<dgl_id_t *>(eid->data);
for (int64_t i = 0; i < num_edges; i++) { for (int64_t i = 0; i < num_edges; i++) {
auto it1 = orig_nodes.find(src_data[i]); auto it1 = orig_nodes.find(src_data[i]);
// If the source node is in the partition, we already got this edge when we iterated over // If the source node is in the partition, we already got this edge when we
// the out-edges above. // iterated over the out-edges above.
if (it1 == orig_nodes.end()) { if (it1 == orig_nodes.end()) {
edge_src.push_back(src_data[i]); edge_src.push_back(src_data[i]);
edge_dst.push_back(dst_data[i]); edge_dst.push_back(dst_data[i]);
...@@ -164,10 +167,10 @@ HaloHeteroSubgraph GetSubgraphWithHalo(std::shared_ptr<HeteroGraph> hg, ...@@ -164,10 +167,10 @@ HaloHeteroSubgraph GetSubgraphWithHalo(std::shared_ptr<HeteroGraph> hg,
} }
num_edges = edge_src.size(); num_edges = edge_src.size();
IdArray new_src = IdArray::Empty({num_edges}, DGLDataType{kDGLInt, 64, 1}, IdArray new_src = IdArray::Empty(
DGLContext{kDGLCPU, 0}); {num_edges}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray new_dst = IdArray::Empty({num_edges}, DGLDataType{kDGLInt, 64, 1}, IdArray new_dst = IdArray::Empty(
DGLContext{kDGLCPU, 0}); {num_edges}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
dgl_id_t *new_src_data = static_cast<dgl_id_t *>(new_src->data); dgl_id_t *new_src_data = static_cast<dgl_id_t *>(new_src->data);
dgl_id_t *new_dst_data = static_cast<dgl_id_t *>(new_dst->data); dgl_id_t *new_dst_data = static_cast<dgl_id_t *>(new_dst->data);
for (size_t i = 0; i < edge_src.size(); i++) { for (size_t i = 0; i < edge_src.size(); i++) {
...@@ -180,8 +183,8 @@ HaloHeteroSubgraph GetSubgraphWithHalo(std::shared_ptr<HeteroGraph> hg, ...@@ -180,8 +183,8 @@ HaloHeteroSubgraph GetSubgraphWithHalo(std::shared_ptr<HeteroGraph> hg,
dgl_id_t old_nid = old_node_ids[i]; dgl_id_t old_nid = old_node_ids[i];
inner_nodes[i] = all_nodes[old_nid]; inner_nodes[i] = all_nodes[old_nid];
} }
aten::COOMatrix coo(old_node_ids.size(), old_node_ids.size(), new_src, aten::COOMatrix coo(
new_dst); old_node_ids.size(), old_node_ids.size(), new_src, new_dst);
HeteroGraphPtr ugptr = UnitGraph::CreateFromCOO(1, coo); HeteroGraphPtr ugptr = UnitGraph::CreateFromCOO(1, coo);
HeteroGraphPtr subg = CreateHeteroGraph(hg->meta_graph(), {ugptr}); HeteroGraphPtr subg = CreateHeteroGraph(hg->meta_graph(), {ugptr});
HaloHeteroSubgraph halo_subg; HaloHeteroSubgraph halo_subg;
...@@ -194,83 +197,83 @@ HaloHeteroSubgraph GetSubgraphWithHalo(std::shared_ptr<HeteroGraph> hg, ...@@ -194,83 +197,83 @@ HaloHeteroSubgraph GetSubgraphWithHalo(std::shared_ptr<HeteroGraph> hg,
} }
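// A minimal usage sketch (assuming a 64-bit homogeneous graph `hgptr`; the
// seed ids are illustrative):
//   IdArray seeds = aten::VecToIdArray(std::vector<int64_t>{0, 1, 2});
//   HaloHeteroSubgraph sg = GetSubgraphWithHalo(hgptr, seeds, /*num_hops=*/1);
//   // sg holds the partition plus its 1-hop halo; inner_nodes marks which
//   // nodes belong to the original partition.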
DGL_REGISTER_GLOBAL("partition._CAPI_DGLReorderGraph_Hetero") DGL_REGISTER_GLOBAL("partition._CAPI_DGLReorderGraph_Hetero")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroGraphRef g = args[0]; HeteroGraphRef g = args[0];
auto hgptr = std::dynamic_pointer_cast<HeteroGraph>(g.sptr()); auto hgptr = std::dynamic_pointer_cast<HeteroGraph>(g.sptr());
CHECK(hgptr) << "Invalid HeteroGraph object"; CHECK(hgptr) << "Invalid HeteroGraph object";
CHECK_EQ(hgptr->relation_graphs().size(), 1) CHECK_EQ(hgptr->relation_graphs().size(), 1)
<< "Reorder only supports HomoGraph"; << "Reorder only supports HomoGraph";
auto ugptr = hgptr->relation_graphs()[0]; auto ugptr = hgptr->relation_graphs()[0];
const IdArray new_order = args[1]; const IdArray new_order = args[1];
auto reorder_ugptr = ReorderUnitGraph(ugptr, new_order); auto reorder_ugptr = ReorderUnitGraph(ugptr, new_order);
std::vector<HeteroGraphPtr> rel_graphs = {reorder_ugptr}; std::vector<HeteroGraphPtr> rel_graphs = {reorder_ugptr};
*rv = HeteroGraphRef(std::make_shared<HeteroGraph>( *rv = HeteroGraphRef(std::make_shared<HeteroGraph>(
hgptr->meta_graph(), rel_graphs, hgptr->NumVerticesPerType())); hgptr->meta_graph(), rel_graphs, hgptr->NumVerticesPerType()));
}); });
DGL_REGISTER_GLOBAL("partition._CAPI_DGLPartitionWithHalo_Hetero") DGL_REGISTER_GLOBAL("partition._CAPI_DGLPartitionWithHalo_Hetero")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroGraphRef g = args[0]; HeteroGraphRef g = args[0];
auto hgptr = std::dynamic_pointer_cast<HeteroGraph>(g.sptr()); auto hgptr = std::dynamic_pointer_cast<HeteroGraph>(g.sptr());
CHECK(hgptr) << "Invalid HeteroGraph object"; CHECK(hgptr) << "Invalid HeteroGraph object";
CHECK_EQ(hgptr->relation_graphs().size(), 1) CHECK_EQ(hgptr->relation_graphs().size(), 1)
<< "Metis partition only supports HomoGraph"; << "Metis partition only supports HomoGraph";
auto ugptr = hgptr->relation_graphs()[0]; auto ugptr = hgptr->relation_graphs()[0];
IdArray node_parts = args[1]; IdArray node_parts = args[1];
int num_hops = args[2]; int num_hops = args[2];
CHECK_EQ(node_parts->dtype.bits, 64) CHECK_EQ(node_parts->dtype.bits, 64)
<< "Only supports 64bits tensor for now"; << "Only supports 64bits tensor for now";
const int64_t *part_data = static_cast<int64_t *>(node_parts->data); const int64_t *part_data = static_cast<int64_t *>(node_parts->data);
int64_t num_nodes = node_parts->shape[0]; int64_t num_nodes = node_parts->shape[0];
std::unordered_map<int, std::vector<int64_t>> part_map; std::unordered_map<int, std::vector<int64_t>> part_map;
for (int64_t i = 0; i < num_nodes; i++) { for (int64_t i = 0; i < num_nodes; i++) {
dgl_id_t part_id = part_data[i]; dgl_id_t part_id = part_data[i];
auto it = part_map.find(part_id); auto it = part_map.find(part_id);
if (it == part_map.end()) { if (it == part_map.end()) {
std::vector<int64_t> vec; std::vector<int64_t> vec;
vec.push_back(i); vec.push_back(i);
part_map[part_id] = vec; part_map[part_id] = vec;
} else { } else {
it->second.push_back(i); it->second.push_back(i);
}
} }
} std::vector<int> part_ids;
std::vector<int> part_ids; std::vector<std::vector<int64_t>> part_nodes;
std::vector<std::vector<int64_t>> part_nodes; int max_part_id = 0;
int max_part_id = 0; for (auto it = part_map.begin(); it != part_map.end(); it++) {
for (auto it = part_map.begin(); it != part_map.end(); it++) { max_part_id = std::max(it->first, max_part_id);
max_part_id = std::max(it->first, max_part_id); part_ids.push_back(it->first);
part_ids.push_back(it->first); part_nodes.push_back(it->second);
part_nodes.push_back(it->second); }
} // When we construct subgraphs, we need to access both in-edges and
// When we construct subgraphs, we need to access both in-edges and out-edges. // out-edges. We need to make sure the in-CSR and out-CSR exist.
// We need to make sure the in-CSR and out-CSR exist. Otherwise, we'll // Otherwise, we'll try to construct in-CSR and out-CSR in openmp for
// try to construct in-CSR and out-CSR in openmp for loop, which will lead // loop, which will lead to some unexpected results.
// to some unexpected results. ugptr->GetInCSR();
ugptr->GetInCSR(); ugptr->GetOutCSR();
ugptr->GetOutCSR(); std::vector<std::shared_ptr<HaloHeteroSubgraph>> subgs(max_part_id + 1);
std::vector<std::shared_ptr<HaloHeteroSubgraph>> subgs(max_part_id + 1); int num_partitions = part_nodes.size();
int num_partitions = part_nodes.size(); runtime::parallel_for(0, num_partitions, [&](int b, int e) {
runtime::parallel_for(0, num_partitions, [&](int b, int e) { for (auto i = b; i < e; i++) {
for (auto i = b; i < e; i++) { auto nodes = aten::VecToIdArray(part_nodes[i]);
auto nodes = aten::VecToIdArray(part_nodes[i]); HaloHeteroSubgraph subg = GetSubgraphWithHalo(hgptr, nodes, num_hops);
HaloHeteroSubgraph subg = GetSubgraphWithHalo(hgptr, nodes, num_hops); std::shared_ptr<HaloHeteroSubgraph> subg_ptr(
std::shared_ptr<HaloHeteroSubgraph> subg_ptr( new HaloHeteroSubgraph(subg));
new HaloHeteroSubgraph(subg)); int part_id = part_ids[i];
int part_id = part_ids[i]; subgs[part_id] = subg_ptr;
subgs[part_id] = subg_ptr; }
});
List<HeteroSubgraphRef> ret_list;
for (size_t i = 0; i < subgs.size(); i++) {
ret_list.push_back(HeteroSubgraphRef(subgs[i]));
} }
*rv = ret_list;
}); });
List<HeteroSubgraphRef> ret_list;
for (size_t i = 0; i < subgs.size(); i++) {
ret_list.push_back(HeteroSubgraphRef(subgs[i]));
}
*rv = ret_list;
});
template<class IdType> template <class IdType>
struct EdgeProperty { struct EdgeProperty {
IdType eid; IdType eid;
int64_t idx; int64_t idx;
...@@ -280,98 +283,101 @@ struct EdgeProperty { ...@@ -280,98 +283,101 @@ struct EdgeProperty {
// Reassign edge IDs so that all edges in a partition have contiguous edge IDs. // Reassign edge IDs so that all edges in a partition have contiguous edge IDs.
// The original edge IDs are returned. // The original edge IDs are returned.
DGL_REGISTER_GLOBAL("partition._CAPI_DGLReassignEdges_Hetero") DGL_REGISTER_GLOBAL("partition._CAPI_DGLReassignEdges_Hetero")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroGraphRef g = args[0]; HeteroGraphRef g = args[0];
auto hgptr = std::dynamic_pointer_cast<HeteroGraph>(g.sptr()); auto hgptr = std::dynamic_pointer_cast<HeteroGraph>(g.sptr());
CHECK(hgptr) << "Invalid HeteroGraph object"; CHECK(hgptr) << "Invalid HeteroGraph object";
CHECK_EQ(hgptr->relation_graphs().size(), 1) CHECK_EQ(hgptr->relation_graphs().size(), 1)
<< "Reorder only supports HomoGraph"; << "Reorder only supports HomoGraph";
auto ugptr = hgptr->relation_graphs()[0]; auto ugptr = hgptr->relation_graphs()[0];
IdArray etype = args[1]; IdArray etype = args[1];
IdArray part_id = args[2]; IdArray part_id = args[2];
bool is_incsr = args[3]; bool is_incsr = args[3];
auto csrmat = is_incsr ? ugptr->GetCSCMatrix(0) : ugptr->GetCSRMatrix(0); auto csrmat = is_incsr ? ugptr->GetCSCMatrix(0) : ugptr->GetCSRMatrix(0);
int64_t num_edges = csrmat.data->shape[0]; int64_t num_edges = csrmat.data->shape[0];
int64_t num_rows = csrmat.indptr->shape[0] - 1; int64_t num_rows = csrmat.indptr->shape[0] - 1;
IdArray new_data = IdArray new_data =
IdArray::Empty({num_edges}, csrmat.data->dtype, csrmat.data->ctx); IdArray::Empty({num_edges}, csrmat.data->dtype, csrmat.data->ctx);
// Return the original edge Ids. // Return the original edge Ids.
*rv = new_data; *rv = new_data;
// Generate new edge Ids. // Generate new edge Ids.
ATEN_ID_TYPE_SWITCH(new_data->dtype, IdType, { ATEN_ID_TYPE_SWITCH(new_data->dtype, IdType, {
CHECK(etype->dtype.bits == sizeof(IdType) * 8); CHECK(etype->dtype.bits == sizeof(IdType) * 8);
CHECK(part_id->dtype.bits == sizeof(IdType) * 8); CHECK(part_id->dtype.bits == sizeof(IdType) * 8);
const IdType *part_id_data = static_cast<IdType *>(part_id->data); const IdType *part_id_data = static_cast<IdType *>(part_id->data);
const IdType *etype_data = static_cast<IdType *>(etype->data); const IdType *etype_data = static_cast<IdType *>(etype->data);
const IdType *indptr_data = static_cast<IdType *>(csrmat.indptr->data); const IdType *indptr_data = static_cast<IdType *>(csrmat.indptr->data);
IdType *typed_data = static_cast<IdType *>(csrmat.data->data); IdType *typed_data = static_cast<IdType *>(csrmat.data->data);
IdType *typed_new_data = static_cast<IdType *>(new_data->data); IdType *typed_new_data = static_cast<IdType *>(new_data->data);
std::vector<EdgeProperty<IdType>> indexed_eids(num_edges); std::vector<EdgeProperty<IdType>> indexed_eids(num_edges);
for (int64_t i = 0; i < num_rows; i++) { for (int64_t i = 0; i < num_rows; i++) {
for (int64_t j = indptr_data[i]; j < indptr_data[i + 1]; j++) { for (int64_t j = indptr_data[i]; j < indptr_data[i + 1]; j++) {
indexed_eids[j].eid = typed_data[j]; indexed_eids[j].eid = typed_data[j];
indexed_eids[j].idx = j; indexed_eids[j].idx = j;
indexed_eids[j].part_id = part_id_data[i]; indexed_eids[j].part_id = part_id_data[i];
}
} }
} auto comp = [etype_data](
auto comp = [etype_data](const EdgeProperty<IdType> &a, const EdgeProperty<IdType> &b) { const EdgeProperty<IdType> &a,
if (a.part_id == b.part_id) { const EdgeProperty<IdType> &b) {
return etype_data[a.eid] < etype_data[b.eid]; if (a.part_id == b.part_id) {
} else { return etype_data[a.eid] < etype_data[b.eid];
return a.part_id < b.part_id; } else {
return a.part_id < b.part_id;
}
};
// We only need to sort the edges if the input graph has multiple
// relations. If it's a homogeneous graph, we'll just assign edge Ids
// based on its previous order.
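// E.g., entries with (part_id, etype) pairs (1, 0), (0, 1), (0, 0) sort to
// (0, 0), (0, 1), (1, 0): edges are grouped by partition first and by edge
// type within each partition.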
if (etype->shape[0] > 0) {
std::sort(indexed_eids.begin(), indexed_eids.end(), comp);
} }
}; for (int64_t new_eid = 0; new_eid < num_edges; new_eid++) {
// We only need to sort the edges if the input graph has multiple relations. int64_t orig_idx = indexed_eids[new_eid].idx;
// If it's a homogeneous graph, we'll just assign edge Ids based on its previous order. typed_new_data[new_eid] = typed_data[orig_idx];
if (etype->shape[0] > 0) { typed_data[orig_idx] = new_eid;
std::sort(indexed_eids.begin(), indexed_eids.end(), comp); }
} });
for (int64_t new_eid = 0; new_eid < num_edges; new_eid++) { ugptr->InvalidateCSR();
int64_t orig_idx = indexed_eids[new_eid].idx; ugptr->InvalidateCOO();
typed_new_data[new_eid] = typed_data[orig_idx];
typed_data[orig_idx] = new_eid;
}
}); });
ugptr->InvalidateCSR();
ugptr->InvalidateCOO();
});
DGL_REGISTER_GLOBAL("partition._CAPI_GetHaloSubgraphInnerNodes_Hetero") DGL_REGISTER_GLOBAL("partition._CAPI_GetHaloSubgraphInnerNodes_Hetero")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroSubgraphRef g = args[0]; HeteroSubgraphRef g = args[0];
auto gptr = std::dynamic_pointer_cast<HaloHeteroSubgraph>(g.sptr()); auto gptr = std::dynamic_pointer_cast<HaloHeteroSubgraph>(g.sptr());
CHECK(gptr) << "The input graph has to be HaloHeteroSubgraph"; CHECK(gptr) << "The input graph has to be HaloHeteroSubgraph";
*rv = gptr->inner_nodes[0]; *rv = gptr->inner_nodes[0];
}); });
DGL_REGISTER_GLOBAL("partition._CAPI_DGLMakeSymmetric_Hetero") DGL_REGISTER_GLOBAL("partition._CAPI_DGLMakeSymmetric_Hetero")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroGraphRef g = args[0]; HeteroGraphRef g = args[0];
auto hgptr = std::dynamic_pointer_cast<HeteroGraph>(g.sptr()); auto hgptr = std::dynamic_pointer_cast<HeteroGraph>(g.sptr());
CHECK(hgptr) << "Invalid HeteroGraph object"; CHECK(hgptr) << "Invalid HeteroGraph object";
CHECK_EQ(hgptr->relation_graphs().size(), 1) CHECK_EQ(hgptr->relation_graphs().size(), 1)
<< "Metis partition only supports homogeneous graph"; << "Metis partition only supports homogeneous graph";
auto ugptr = hgptr->relation_graphs()[0]; auto ugptr = hgptr->relation_graphs()[0];
#if !defined(_WIN32) #if !defined(_WIN32)
// TODO(zhengda) should we get whatever CSR exists in the graph. // TODO(zhengda) should we get whatever CSR exists in the graph.
gk_csr_t *gk_csr = Convert2GKCsr(ugptr->GetCSCMatrix(0), true); gk_csr_t *gk_csr = Convert2GKCsr(ugptr->GetCSCMatrix(0), true);
gk_csr_t *sym_gk_csr = gk_csr_MakeSymmetric(gk_csr, GK_CSR_SYM_SUM); gk_csr_t *sym_gk_csr = gk_csr_MakeSymmetric(gk_csr, GK_CSR_SYM_SUM);
auto mat = Convert2DGLCsr(sym_gk_csr, true); auto mat = Convert2DGLCsr(sym_gk_csr, true);
gk_csr_Free(&gk_csr); gk_csr_Free(&gk_csr);
gk_csr_Free(&sym_gk_csr); gk_csr_Free(&sym_gk_csr);
auto new_ugptr = UnitGraph::CreateFromCSC(ugptr->NumVertexTypes(), mat, auto new_ugptr = UnitGraph::CreateFromCSC(
ugptr->GetAllowedFormats()); ugptr->NumVertexTypes(), mat, ugptr->GetAllowedFormats());
std::vector<HeteroGraphPtr> rel_graphs = {new_ugptr}; std::vector<HeteroGraphPtr> rel_graphs = {new_ugptr};
*rv = HeteroGraphRef(std::make_shared<HeteroGraph>( *rv = HeteroGraphRef(std::make_shared<HeteroGraph>(
hgptr->meta_graph(), rel_graphs, hgptr->NumVerticesPerType())); hgptr->meta_graph(), rel_graphs, hgptr->NumVerticesPerType()));
#else #else
LOG(FATAL) << "The fast version of making symmetric graph is not supported in Windows."; LOG(FATAL) << "The fast version of making symmetric graph is not "
"supported in Windows.";
#endif // !defined(_WIN32) #endif // !defined(_WIN32)
}); });
} // namespace transform } // namespace transform
} // namespace dgl } // namespace dgl
...@@ -4,15 +4,16 @@ ...@@ -4,15 +4,16 @@
* \brief Remove edges. * \brief Remove edges.
*/ */
#include <dgl/base_heterograph.h>
#include <dgl/transform.h>
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <dgl/packed_func_ext.h> #include <dgl/packed_func_ext.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/container.h> #include <dgl/runtime/container.h>
#include <vector> #include <dgl/runtime/registry.h>
#include <utility> #include <dgl/transform.h>
#include <tuple> #include <tuple>
#include <utility>
#include <vector>
namespace dgl { namespace dgl {
...@@ -21,8 +22,8 @@ using namespace dgl::aten; ...@@ -21,8 +22,8 @@ using namespace dgl::aten;
namespace transform { namespace transform {
std::pair<HeteroGraphPtr, std::vector<IdArray>> std::pair<HeteroGraphPtr, std::vector<IdArray>> RemoveEdges(
RemoveEdges(const HeteroGraphPtr graph, const std::vector<IdArray> &eids) { const HeteroGraphPtr graph, const std::vector<IdArray> &eids) {
std::vector<IdArray> induced_eids; std::vector<IdArray> induced_eids;
std::vector<HeteroGraphPtr> rel_graphs; std::vector<HeteroGraphPtr> rel_graphs;
const int64_t num_etypes = graph->NumEdgeTypes(); const int64_t num_etypes = graph->NumEdgeTypes();
...@@ -40,23 +41,30 @@ RemoveEdges(const HeteroGraphPtr graph, const std::vector<IdArray> &eids) { ...@@ -40,23 +41,30 @@ RemoveEdges(const HeteroGraphPtr graph, const std::vector<IdArray> &eids) {
const COOMatrix &coo = graph->GetCOOMatrix(etype); const COOMatrix &coo = graph->GetCOOMatrix(etype);
const COOMatrix &result = COORemove(coo, eids[etype]); const COOMatrix &result = COORemove(coo, eids[etype]);
new_rel_graph = CreateFromCOO( new_rel_graph = CreateFromCOO(
num_ntypes_rel, result.num_rows, result.num_cols, result.row, result.col); num_ntypes_rel, result.num_rows, result.num_cols, result.row,
result.col);
induced_eids_rel = result.data; induced_eids_rel = result.data;
} else if (fmt == SparseFormat::kCSR) { } else if (fmt == SparseFormat::kCSR) {
const CSRMatrix &csr = graph->GetCSRMatrix(etype); const CSRMatrix &csr = graph->GetCSRMatrix(etype);
const CSRMatrix &result = CSRRemove(csr, eids[etype]); const CSRMatrix &result = CSRRemove(csr, eids[etype]);
new_rel_graph = CreateFromCSR( new_rel_graph = CreateFromCSR(
num_ntypes_rel, result.num_rows, result.num_cols, result.indptr, result.indices, num_ntypes_rel, result.num_rows, result.num_cols, result.indptr,
result.indices,
// TODO(BarclayII): make CSR support null eid array // TODO(BarclayII): make CSR support null eid array
Range(0, result.indices->shape[0], result.indices->dtype.bits, result.indices->ctx)); Range(
0, result.indices->shape[0], result.indices->dtype.bits,
result.indices->ctx));
induced_eids_rel = result.data; induced_eids_rel = result.data;
} else if (fmt == SparseFormat::kCSC) { } else if (fmt == SparseFormat::kCSC) {
const CSRMatrix &csc = graph->GetCSCMatrix(etype); const CSRMatrix &csc = graph->GetCSCMatrix(etype);
const CSRMatrix &result = CSRRemove(csc, eids[etype]); const CSRMatrix &result = CSRRemove(csc, eids[etype]);
new_rel_graph = CreateFromCSC( new_rel_graph = CreateFromCSC(
num_ntypes_rel, result.num_rows, result.num_cols, result.indptr, result.indices, num_ntypes_rel, result.num_rows, result.num_cols, result.indptr,
result.indices,
// TODO(BarclayII): make CSR support null eid array // TODO(BarclayII): make CSR support null eid array
Range(0, result.indices->shape[0], result.indices->dtype.bits, result.indices->ctx)); Range(
0, result.indices->shape[0], result.indices->dtype.bits,
result.indices->ctx));
induced_eids_rel = result.data; induced_eids_rel = result.data;
} }
...@@ -70,24 +78,24 @@ RemoveEdges(const HeteroGraphPtr graph, const std::vector<IdArray> &eids) { ...@@ -70,24 +78,24 @@ RemoveEdges(const HeteroGraphPtr graph, const std::vector<IdArray> &eids) {
} }
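// A minimal usage sketch (the graph `g` and the edge ids are illustrative;
// one IdArray of edge ids is passed per edge type):
//   IdArray eids = aten::VecToIdArray(std::vector<int64_t>{0, 5, 7});
//   HeteroGraphPtr pruned;
//   std::vector<IdArray> induced;
//   std::tie(pruned, induced) = RemoveEdges(g, {eids});
//   // induced[0] maps each surviving edge back to its original edge id.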
DGL_REGISTER_GLOBAL("transform._CAPI_DGLRemoveEdges") DGL_REGISTER_GLOBAL("transform._CAPI_DGLRemoveEdges")
.set_body([] (DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
const HeteroGraphRef graph_ref = args[0]; const HeteroGraphRef graph_ref = args[0];
const std::vector<IdArray> &eids = ListValueToVector<IdArray>(args[1]); const std::vector<IdArray> &eids = ListValueToVector<IdArray>(args[1]);
HeteroGraphPtr new_graph; HeteroGraphPtr new_graph;
std::vector<IdArray> induced_eids; std::vector<IdArray> induced_eids;
std::tie(new_graph, induced_eids) = RemoveEdges(graph_ref.sptr(), eids); std::tie(new_graph, induced_eids) = RemoveEdges(graph_ref.sptr(), eids);
List<Value> induced_eids_ref; List<Value> induced_eids_ref;
for (IdArray &array : induced_eids) for (IdArray &array : induced_eids)
induced_eids_ref.push_back(Value(MakeValue(array))); induced_eids_ref.push_back(Value(MakeValue(array)));
List<ObjectRef> ret; List<ObjectRef> ret;
ret.push_back(HeteroGraphRef(new_graph)); ret.push_back(HeteroGraphRef(new_graph));
ret.push_back(induced_eids_ref); ret.push_back(induced_eids_ref);
*rv = ret; *rv = ret;
}); });
}; // namespace transform }; // namespace transform
...@@ -19,16 +19,18 @@ ...@@ -19,16 +19,18 @@
#include "to_bipartite.h" #include "to_bipartite.h"
#include <dgl/base_heterograph.h>
#include <dgl/transform.h>
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/packed_func_ext.h> #include <dgl/base_heterograph.h>
#include <dgl/immutable_graph.h> #include <dgl/immutable_graph.h>
#include <dgl/runtime/registry.h> #include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h> #include <dgl/runtime/container.h>
#include <vector> #include <dgl/runtime/registry.h>
#include <dgl/transform.h>
#include <tuple> #include <tuple>
#include <utility> #include <utility>
#include <vector>
#include "../../array/cpu/array_utils.h" #include "../../array/cpu/array_utils.h"
namespace dgl { namespace dgl {
...@@ -42,11 +44,11 @@ namespace { ...@@ -42,11 +44,11 @@ namespace {
// Since partial specialization is not allowed for functions, use this as an
// intermediate for ToBlock where XPU = kDGLCPU.
template <typename IdType>
std::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlockCPU(
    HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,
    bool include_rhs_in_lhs, std::vector<IdArray> *const lhs_nodes_ptr) {
  std::vector<IdArray> &lhs_nodes = *lhs_nodes_ptr;
  const bool generate_lhs_nodes = lhs_nodes.empty();

  const int64_t num_etypes = graph->NumEdgeTypes();
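As the comment notes, C++ permits partial specialization of class templates but not of function templates, so `ToBlock<XPU, IdType>` cannot be specialized on the device parameter alone; the CPU path instead funnels through the ordinary helper template `ToBlockCPU<IdType>`, to which the full specializations further down forward. A minimal standalone sketch of that pattern — illustrative names and types only, not DGL code:

// Sketch: work around the ban on partial specialization of function templates.
enum DeviceType { kCPU, kGPU };

// The primary template is only ever reached through full specializations.
template <DeviceType XPU, typename IdType>
long long ToBlock(IdType seed);

// "Partial specialization on the device" is expressed as a separate helper
// template instead.
template <typename IdType>
long long ToBlockCPU(IdType seed) {
  return static_cast<long long>(seed) + 1;
}

// Full (explicit) specializations are legal; each simply forwards.
template <>
long long ToBlock<kCPU, int>(int seed) { return ToBlockCPU<int>(seed); }
template <>
long long ToBlock<kCPU, long long>(long long seed) {
  return ToBlockCPU<long long>(seed);
}

int main() { return ToBlock<kCPU, int>(41) == 42 ? 0 : 1; }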
@@ -54,28 +56,29 @@ ToBlockCPU(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,
  std::vector<EdgeArray> edge_arrays(num_etypes);

  CHECK(rhs_nodes.size() == static_cast<size_t>(num_ntypes))
      << "rhs_nodes not given for every node type";

  const std::vector<IdHashMap<IdType>> rhs_node_mappings(
      rhs_nodes.begin(), rhs_nodes.end());
  std::vector<IdHashMap<IdType>> lhs_node_mappings;

  if (generate_lhs_nodes) {
    // build lhs_node_mappings -- if we don't have them already
    if (include_rhs_in_lhs)
      lhs_node_mappings = rhs_node_mappings;  // copy
    else
      lhs_node_mappings.resize(num_ntypes);
  } else {
    lhs_node_mappings =
        std::vector<IdHashMap<IdType>>(lhs_nodes.begin(), lhs_nodes.end());
  }

  for (int64_t etype = 0; etype < num_etypes; ++etype) {
    const auto src_dst_types = graph->GetEndpointTypes(etype);
    const dgl_type_t srctype = src_dst_types.first;
    const dgl_type_t dsttype = src_dst_types.second;

    if (!aten::IsNullArray(rhs_nodes[dsttype])) {
      const EdgeArray &edges = graph->Edges(etype);
      if (generate_lhs_nodes) {
        lhs_node_mappings[srctype].Update(edges.src);
      }
@@ -89,8 +92,8 @@ ToBlockCPU(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,
  const auto meta_graph = graph->meta_graph();
  const EdgeArray etypes = meta_graph->Edges("eid");
  const IdArray new_dst = Add(etypes.dst, num_ntypes);
  const auto new_meta_graph =
      ImmutableGraph::CreateFromCOO(num_ntypes * 2, etypes.src, new_dst);

  for (int64_t ntype = 0; ntype < num_ntypes; ++ntype)
    num_nodes_per_type.push_back(lhs_node_mappings[ntype].Size());
@@ -108,8 +111,8 @@ ToBlockCPU(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,
    if (rhs_map.Size() == 0) {
      // No rhs nodes are given for this edge type. Create an empty graph.
      rel_graphs.push_back(CreateFromCOO(
          2, lhs_map.Size(), rhs_map.Size(), aten::NullArray(),
          aten::NullArray()));
      induced_edges.push_back(aten::NullArray());
    } else {
      IdArray new_src = lhs_map.Map(edge_arrays[etype].src, -1);
@@ -117,22 +120,22 @@ ToBlockCPU(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,
      // Check whether there are unmapped IDs and raise error.
      for (int64_t i = 0; i < new_dst->shape[0]; ++i)
        CHECK_NE(new_dst.Ptr<IdType>()[i], -1)
            << "Node " << edge_arrays[etype].dst.Ptr<IdType>()[i]
            << " does not exist"
            << " in `rhs_nodes`. Argument `rhs_nodes` must contain all the edge"
            << " destination nodes.";
      rel_graphs.push_back(
          CreateFromCOO(2, lhs_map.Size(), rhs_map.Size(), new_src, new_dst));
      induced_edges.push_back(edge_arrays[etype].id);
    }
  }

  const HeteroGraphPtr new_graph =
      CreateHeteroGraph(new_meta_graph, rel_graphs, num_nodes_per_type);

  if (generate_lhs_nodes) {
    CHECK_EQ(lhs_nodes.size(), 0) << "InternalError: lhs_nodes should be empty "
                                     "when generating it.";
    for (const IdHashMap<IdType> &lhs_map : lhs_node_mappings)
      lhs_nodes.push_back(lhs_map.Values());
  }
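The `Update`/`Map` pair above is what compacts arbitrary global node IDs into the dense 0..N-1 ID range of the block, with -1 as the sentinel for IDs absent from the mapping; the `CHECK_NE` then rejects any destination node that `rhs_nodes` failed to cover. A rough standalone sketch of that compaction idea, using a plain `std::unordered_map` as a stand-in for DGL's `IdHashMap`:

#include <cassert>
#include <cstdint>
#include <unordered_map>
#include <vector>

// Sketch only: assigns each previously unseen global ID the next dense
// local ID, in the spirit of IdHashMap::Update.
struct IdCompactor {
  std::unordered_map<int64_t, int64_t> to_local;

  void Update(const std::vector<int64_t> &ids) {
    for (int64_t id : ids)
      to_local.emplace(id, static_cast<int64_t>(to_local.size()));
  }

  // Maps global IDs to local ones, writing `sentinel` for unknown IDs,
  // in the spirit of IdHashMap::Map called with -1.
  std::vector<int64_t> Map(
      const std::vector<int64_t> &ids, int64_t sentinel) const {
    std::vector<int64_t> out;
    out.reserve(ids.size());
    for (int64_t id : ids) {
      auto it = to_local.find(id);
      out.push_back(it == to_local.end() ? sentinel : it->second);
    }
    return out;
  }
};

int main() {
  IdCompactor m;
  m.Update({100, 7, 100, 42});         // duplicates are ignored
  auto local = m.Map({7, 42, 9}, -1);  // 9 was never registered
  assert(local[0] == 1 && local[1] == 2 && local[2] == -1);
  return 0;
}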
@@ -141,87 +144,83 @@ ToBlockCPU(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,
}  // namespace
template <>
std::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlock<kDGLCPU, int32_t>(
    HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,
    bool include_rhs_in_lhs, std::vector<IdArray> *const lhs_nodes) {
  return ToBlockCPU<int32_t>(graph, rhs_nodes, include_rhs_in_lhs, lhs_nodes);
}

template <>
std::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlock<kDGLCPU, int64_t>(
    HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,
    bool include_rhs_in_lhs, std::vector<IdArray> *const lhs_nodes) {
  return ToBlockCPU<int64_t>(graph, rhs_nodes, include_rhs_in_lhs, lhs_nodes);
}
#ifdef DGL_USE_CUDA

// Forward declaration of GPU ToBlock implementations - actual implementation
// is in ./cuda/cuda_to_block.cu.
// This is to get around the broken name mangling in VS2019 CL 16.5.5 +
// CUDA 11.3, which complains that the two template specializations have the
// same signature.
std::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlockGPU32(
    HeteroGraphPtr, const std::vector<IdArray> &, bool,
    std::vector<IdArray> *const);
std::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlockGPU64(
    HeteroGraphPtr, const std::vector<IdArray> &, bool,
    std::vector<IdArray> *const);

template <>
std::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlock<kDGLCUDA, int32_t>(
    HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,
    bool include_rhs_in_lhs, std::vector<IdArray> *const lhs_nodes) {
  return ToBlockGPU32(graph, rhs_nodes, include_rhs_in_lhs, lhs_nodes);
}

template <>
std::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlock<kDGLCUDA, int64_t>(
    HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,
    bool include_rhs_in_lhs, std::vector<IdArray> *const lhs_nodes) {
  return ToBlockGPU64(graph, rhs_nodes, include_rhs_in_lhs, lhs_nodes);
}

#endif  // DGL_USE_CUDA
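The `ToBlockGPU32`/`ToBlockGPU64` declarations above encode the workaround: the real CUDA implementations are compiled as ordinary functions with distinct names, whose symbols can never collide even under the faulty mangler, and the `template <>` specializations here reduce to one-line forwarders. A schematic, self-contained version of the pattern — stub bodies stand in for the real implementations in ./cuda/cuda_to_block.cu:

#include <cstdint>

// Ordinary functions with distinct names: their mangled symbols are always
// distinct, even when a broken toolchain conflates the specializations below.
int64_t ToBlockGPU32(int32_t seed) { return seed; }  // stub
int64_t ToBlockGPU64(int64_t seed) { return seed; }  // stub

template <int XPU, typename IdType>
int64_t ToBlock(IdType seed);

// The specializations are now trivial forwarders, so the buggy compiler never
// has to emit distinct symbols for any interesting code.
template <>
int64_t ToBlock<1, int32_t>(int32_t seed) { return ToBlockGPU32(seed); }
template <>
int64_t ToBlock<1, int64_t>(int64_t seed) { return ToBlockGPU64(seed); }

int main() { return static_cast<int>(ToBlock<1, int32_t>(0)); }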
DGL_REGISTER_GLOBAL("transform._CAPI_DGLToBlock")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      const HeteroGraphRef graph_ref = args[0];
      const std::vector<IdArray> &rhs_nodes =
          ListValueToVector<IdArray>(args[1]);
      const bool include_rhs_in_lhs = args[2];
      std::vector<IdArray> lhs_nodes = ListValueToVector<IdArray>(args[3]);

      HeteroGraphPtr new_graph;
      std::vector<IdArray> induced_edges;

      ATEN_XPU_SWITCH_CUDA(graph_ref->Context().device_type, XPU, "ToBlock", {
        ATEN_ID_TYPE_SWITCH(graph_ref->DataType(), IdType, {
          std::tie(new_graph, induced_edges) = ToBlock<XPU, IdType>(
              graph_ref.sptr(), rhs_nodes, include_rhs_in_lhs, &lhs_nodes);
        });
      });

      List<Value> lhs_nodes_ref;
      for (IdArray &array : lhs_nodes)
        lhs_nodes_ref.push_back(Value(MakeValue(array)));
      List<Value> induced_edges_ref;
      for (IdArray &array : induced_edges)
        induced_edges_ref.push_back(Value(MakeValue(array)));

      List<ObjectRef> ret;
      ret.push_back(HeteroGraphRef(new_graph));
      ret.push_back(lhs_nodes_ref);
      ret.push_back(induced_edges_ref);

      *rv = ret;
    });

};  // namespace transform
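`DGL_REGISTER_GLOBAL(...).set_body(...)` publishes the lambda under a string key ("transform._CAPI_DGLToBlock") in a process-wide table, which — roughly speaking — is how the language bindings locate the C++ entry point at runtime. The following is a deliberately simplified illustration of such a string-keyed registry, not DGL's actual runtime machinery:

#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

// Toy registry in the spirit of DGL_REGISTER_GLOBAL: bodies are registered
// during static initialization and looked up later by name.
using Body = std::function<int(int)>;

std::unordered_map<std::string, Body> &Registry() {
  static std::unordered_map<std::string, Body> reg;
  return reg;
}

struct Register {
  Register(const std::string &name, Body body) {
    Registry().emplace(name, std::move(body));
  }
};

// Analogous to: DGL_REGISTER_GLOBAL("...").set_body([](...) { ... });
static Register reg_double("demo._CAPI_Double", [](int x) { return 2 * x; });

int main() {
  std::cout << Registry().at("demo._CAPI_Double")(21) << "\n";  // prints 42
  return 0;
}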
...
@@ -44,10 +44,10 @@ namespace transform {
 *
 * @return The block and the induced edges.
 */
template <DGLDeviceType XPU, typename IdType>
std::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlock(
    HeteroGraphPtr graph, const std::vector<IdArray>& rhs_nodes,
    bool include_rhs_in_lhs, std::vector<IdArray>* lhs_nodes);

}  // namespace transform
}  // namespace dgl
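Note that `lhs_nodes` acts as a combined optional-input/output parameter: a non-empty vector supplies the LHS node mappings, while an empty one asks the implementation to generate them (the `generate_lhs_nodes` flag in `ToBlockCPU`) and write them back through the pointer. A tiny standalone sketch of that contract, with hypothetical names:

#include <cassert>
#include <vector>

// If *out is empty, the callee generates a default and writes it back;
// otherwise the caller-provided contents are honored (mirroring how
// ToBlockCPU treats lhs_nodes).
void FillIfEmpty(std::vector<int> *const out) {
  if (out->empty()) *out = {1, 2, 3};  // generated path
  // non-empty: treat *out as an input and leave it untouched
}

int main() {
  std::vector<int> generated;
  FillIfEmpty(&generated);
  assert(generated.size() == 3);

  std::vector<int> supplied = {9};
  FillIfEmpty(&supplied);
  assert(supplied.size() == 1);  // caller's input preserved
  return 0;
}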
...