Unverified Commit 8f0df39e authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] clang-format auto fix. (#4810)



* [Misc] clang-format auto fix.

* manual

* manual
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 401e1278
@@ -13,8 +13,8 @@ namespace dgl {
using dgl::runtime::NDArray;

NDArray CreateNDArrayFromRawData(
    std::vector<int64_t> shape, DGLDataType dtype, DGLContext ctx, void* raw) {
  return NDArray::CreateFromRaw(shape, dtype, ctx, raw, true);
}
@@ -25,9 +25,9 @@ void StreamWithBuffer::PushNDArray(const NDArray& tensor) {
  int ndim = tensor->ndim;
  this->WriteArray(tensor->shape, ndim);
  CHECK(tensor.IsContiguous())
      << "StreamWithBuffer only supports contiguous tensor";
  CHECK_EQ(tensor->byte_offset, 0)
      << "StreamWithBuffer only supports zero byte offset tensor";
  int type_bytes = tensor->dtype.bits / 8;
  int64_t num_elems = 1;
  for (int i = 0; i < ndim; ++i) {
@@ -40,7 +40,8 @@ void StreamWithBuffer::PushNDArray(const NDArray& tensor) {
    // If the stream is for remote communication or the data is not stored in
    // shared memory, serialize the data content as a buffer.
    this->Write<bool>(false);
    // If this is a null ndarray, we will not push it into the underlying
    // buffer_list
    if (data_byte_size != 0) {
      buffer_list_.emplace_back(tensor, tensor->data, data_byte_size);
    }
@@ -90,8 +91,8 @@ NDArray StreamWithBuffer::PopNDArray() {
    // Mean this is a null ndarray
    ret = CreateNDArrayFromRawData(shape, dtype, cpu_ctx, nullptr);
  } else {
    ret = CreateNDArrayFromRawData(
        shape, dtype, cpu_ctx, buffer_list_.front().data);
    buffer_list_.pop_front();
  }
  return ret;
......
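For context, the zero-copy idea the hunks above reformat can be sketched briefly: tensor metadata (shape, flags) is written into the stream, while the raw payload is only referenced in a side buffer list rather than copied, and a null tensor contributes metadata but no buffer. The sketch below is a minimal, self-contained stand-in; ToyStream, Buffer, and PushTensor are hypothetical names, not DGL's StreamWithBuffer API.

#include <cstdint>
#include <vector>

struct Buffer {
  const void* data;
  size_t size;
};

struct ToyStream {
  std::vector<int64_t> metadata;    // serialized shapes / sizes
  std::vector<Buffer> buffer_list;  // zero-copy payload references

  void PushTensor(const float* data, int64_t num_elems) {
    metadata.push_back(num_elems);  // write metadata only
    if (num_elems != 0) {           // a null tensor adds no buffer
      buffer_list.push_back({data, num_elems * sizeof(float)});
    }
  }
};

int main() {
  std::vector<float> payload(1024, 1.0f);
  ToyStream strm;
  strm.PushTensor(payload.data(), payload.size());
  strm.PushTensor(nullptr, 0);  // null tensor: metadata only
  // strm.buffer_list.size() == 1; the payload itself was never copied.
  return 0;
}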
@@ -31,8 +31,8 @@ using namespace dgl::aten;
namespace dgl {

template <>
NDArray SharedMemManager::CopyToSharedMem<NDArray>(
    const NDArray &data, std::string name) {
  DGLContext ctx = {kDGLCPU, 0};
  std::vector<int64_t> shape(data->shape, data->shape + data->ndim);
  strm_->Write(data->ndim);
@@ -46,28 +46,29 @@ NDArray SharedMemManager::CopyToSharedMem<NDArray>(
    return data;
  } else {
    auto nd =
        NDArray::EmptyShared(graph_name_ + name, shape, data->dtype, ctx, true);
    nd.CopyFrom(data);
    return nd;
  }
}

template <>
CSRMatrix SharedMemManager::CopyToSharedMem<CSRMatrix>(
    const CSRMatrix &csr, std::string name) {
  auto indptr_shared_mem = CopyToSharedMem(csr.indptr, name + "_indptr");
  auto indices_shared_mem = CopyToSharedMem(csr.indices, name + "_indices");
  auto data_shared_mem = CopyToSharedMem(csr.data, name + "_data");
  strm_->Write(csr.num_rows);
  strm_->Write(csr.num_cols);
  strm_->Write(csr.sorted);
  return CSRMatrix(
      csr.num_rows, csr.num_cols, indptr_shared_mem, indices_shared_mem,
      data_shared_mem, csr.sorted);
}

template <>
COOMatrix SharedMemManager::CopyToSharedMem<COOMatrix>(
    const COOMatrix &coo, std::string name) {
  auto row_shared_mem = CopyToSharedMem(coo.row, name + "_row");
  auto col_shared_mem = CopyToSharedMem(coo.col, name + "_col");
  auto data_shared_mem = CopyToSharedMem(coo.data, name + "_data");
@@ -75,13 +76,14 @@ COOMatrix SharedMemManager::CopyToSharedMem<COOMatrix>(
  strm_->Write(coo.num_cols);
  strm_->Write(coo.row_sorted);
  strm_->Write(coo.col_sorted);
  return COOMatrix(
      coo.num_rows, coo.num_cols, row_shared_mem, col_shared_mem,
      data_shared_mem, coo.row_sorted, coo.col_sorted);
}

template <>
bool SharedMemManager::CreateFromSharedMem<NDArray>(
    NDArray *nd, std::string name) {
  int ndim;
  DGLContext ctx = {kDGLCPU, 0};
  DGLDataType dtype;
@@ -98,15 +100,14 @@ bool SharedMemManager::CreateFromSharedMem<NDArray>(
  if (is_null) {
    *nd = NDArray::Empty(shape, dtype, ctx);
  } else {
    *nd = NDArray::EmptyShared(graph_name_ + name, shape, dtype, ctx, false);
  }
  return true;
}

template <>
bool SharedMemManager::CreateFromSharedMem<COOMatrix>(
    COOMatrix *coo, std::string name) {
  CreateFromSharedMem(&coo->row, name + "_row");
  CreateFromSharedMem(&coo->col, name + "_col");
  CreateFromSharedMem(&coo->data, name + "_data");
@@ -118,8 +119,8 @@ bool SharedMemManager::CreateFromSharedMem<COOMatrix>(
}

template <>
bool SharedMemManager::CreateFromSharedMem<CSRMatrix>(
    CSRMatrix *csr, std::string name) {
  CreateFromSharedMem(&csr->indptr, name + "_indptr");
  CreateFromSharedMem(&csr->indices, name + "_indices");
  CreateFromSharedMem(&csr->data, name + "_data");
......
@@ -29,8 +29,7 @@ const size_t SHARED_MEM_METAINFO_SIZE_MAX = 1024 * 32;
class SharedMemManager : public dmlc::Stream {
 public:
  explicit SharedMemManager(std::string graph_name, dmlc::Stream* strm)
      : graph_name_(graph_name), strm_(strm) {}

  template <typename T>
  T CopyToSharedMem(const T& data, std::string name);
......
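The SharedMemManager hunks above all follow one pattern: a single CopyToSharedMem<T> template declared in the header, with explicit specializations in the .cc file that decompose composite types (CSR/COO) into their member arrays under derived names such as name + "_indptr". Below is a minimal, self-contained sketch of that pattern using hypothetical toy types (Array, CSR), not DGL's NDArray/CSRMatrix API.

#include <iostream>
#include <string>

struct Array { std::string shm_name; };
struct CSR { Array indptr, indices, data; };

// Primary template: one entry point for every type.
template <typename T>
T CopyToSharedMem(const T& value, const std::string& name);

// Leaf specialization: "map" a single array into shared memory (simulated).
template <>
Array CopyToSharedMem<Array>(const Array& a, const std::string& name) {
  Array out = a;
  out.shm_name = name;
  return out;
}

// Composite specialization: recurse into members with derived names.
template <>
CSR CopyToSharedMem<CSR>(const CSR& csr, const std::string& name) {
  CSR out;
  out.indptr  = CopyToSharedMem(csr.indptr,  name + "_indptr");
  out.indices = CopyToSharedMem(csr.indices, name + "_indices");
  out.data    = CopyToSharedMem(csr.data,    name + "_data");
  return out;
}

int main() {
  CSR g;
  CSR shared = CopyToSharedMem(g, "graph0_csr");
  std::cout << shared.indices.shm_name << "\n";  // prints graph0_csr_indices
  return 0;
}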
@@ -11,7 +11,8 @@ namespace dgl {
HeteroSubgraph InEdgeGraphRelabelNodes(
    const HeteroGraphPtr graph, const std::vector<IdArray>& vids) {
  CHECK_EQ(vids.size(), graph->NumVertexTypes())
      << "Invalid input: the input list size must be the same as the number of "
         "vertex types.";
  std::vector<IdArray> eids(graph->NumEdgeTypes());
  DGLContext ctx = aten::GetContextOf(vids);
  for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
@@ -29,9 +30,11 @@ HeteroSubgraph InEdgeGraphRelabelNodes(
HeteroSubgraph InEdgeGraphNoRelabelNodes(
    const HeteroGraphPtr graph, const std::vector<IdArray>& vids) {
  // TODO(mufei): This should also use EdgeSubgraph once it is supported for CSR
  // graphs
  CHECK_EQ(vids.size(), graph->NumVertexTypes())
      << "Invalid input: the input list size must be the same as the number of "
         "vertex types.";
  std::vector<HeteroGraphPtr> subrels(graph->NumEdgeTypes());
  std::vector<IdArray> induced_edges(graph->NumEdgeTypes());
  DGLContext ctx = aten::GetContextOf(vids);
@@ -43,30 +46,28 @@ HeteroSubgraph InEdgeGraphNoRelabelNodes(
    if (aten::IsNullArray(vids[dst_vtype])) {
      // create a placeholder graph
      subrels[etype] = UnitGraph::Empty(
          relgraph->NumVertexTypes(), graph->NumVertices(src_vtype),
          graph->NumVertices(dst_vtype), graph->DataType(), ctx);
      induced_edges[etype] =
          IdArray::Empty({0}, graph->DataType(), graph->Context());
    } else {
      const auto& earr = graph->InEdges(etype, {vids[dst_vtype]});
      subrels[etype] = UnitGraph::CreateFromCOO(
          relgraph->NumVertexTypes(), graph->NumVertices(src_vtype),
          graph->NumVertices(dst_vtype), earr.src, earr.dst);
      induced_edges[etype] = earr.id;
    }
  }
  HeteroSubgraph ret;
  ret.graph = CreateHeteroGraph(
      graph->meta_graph(), subrels, graph->NumVerticesPerType());
  ret.induced_edges = std::move(induced_edges);
  return ret;
}

HeteroSubgraph InEdgeGraph(
    const HeteroGraphPtr graph, const std::vector<IdArray>& vids,
    bool relabel_nodes) {
  if (relabel_nodes) {
    return InEdgeGraphRelabelNodes(graph, vids);
  } else {
@@ -77,7 +78,8 @@ HeteroSubgraph InEdgeGraph(
HeteroSubgraph OutEdgeGraphRelabelNodes(
    const HeteroGraphPtr graph, const std::vector<IdArray>& vids) {
  CHECK_EQ(vids.size(), graph->NumVertexTypes())
      << "Invalid input: the input list size must be the same as the number of "
         "vertex types.";
  std::vector<IdArray> eids(graph->NumEdgeTypes());
  DGLContext ctx = aten::GetContextOf(vids);
  for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
@@ -95,9 +97,11 @@ HeteroSubgraph OutEdgeGraphRelabelNodes(
HeteroSubgraph OutEdgeGraphNoRelabelNodes(
    const HeteroGraphPtr graph, const std::vector<IdArray>& vids) {
  // TODO(mufei): This should also use EdgeSubgraph once it is supported for CSR
  // graphs
  CHECK_EQ(vids.size(), graph->NumVertexTypes())
      << "Invalid input: the input list size must be the same as the number of "
         "vertex types.";
  std::vector<HeteroGraphPtr> subrels(graph->NumEdgeTypes());
  std::vector<IdArray> induced_edges(graph->NumEdgeTypes());
  DGLContext ctx = aten::GetContextOf(vids);
@@ -109,30 +113,28 @@ HeteroSubgraph OutEdgeGraphNoRelabelNodes(
    if (aten::IsNullArray(vids[src_vtype])) {
      // create a placeholder graph
      subrels[etype] = UnitGraph::Empty(
          relgraph->NumVertexTypes(), graph->NumVertices(src_vtype),
          graph->NumVertices(dst_vtype), graph->DataType(), ctx);
      induced_edges[etype] =
          IdArray::Empty({0}, graph->DataType(), graph->Context());
    } else {
      const auto& earr = graph->OutEdges(etype, {vids[src_vtype]});
      subrels[etype] = UnitGraph::CreateFromCOO(
          relgraph->NumVertexTypes(), graph->NumVertices(src_vtype),
          graph->NumVertices(dst_vtype), earr.src, earr.dst);
      induced_edges[etype] = earr.id;
    }
  }
  HeteroSubgraph ret;
  ret.graph = CreateHeteroGraph(
      graph->meta_graph(), subrels, graph->NumVerticesPerType());
  ret.induced_edges = std::move(induced_edges);
  return ret;
}

HeteroSubgraph OutEdgeGraph(
    const HeteroGraphPtr graph, const std::vector<IdArray>& vids,
    bool relabel_nodes) {
  if (relabel_nodes) {
    return OutEdgeGraphRelabelNodes(graph, vids);
  } else {
......
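The InEdgeGraph/OutEdgeGraph functions reformatted above build, per edge type, a subgraph containing only the edges incident to the requested nodes, and record the induced (original) edge ids. A small, self-contained sketch of the in-edge case over plain vectors follows; COO and InEdgeSubgraph here are hypothetical stand-ins, not DGL's UnitGraph/InEdges machinery.

#include <cstdint>
#include <unordered_set>
#include <vector>

struct COO {
  std::vector<int64_t> src, dst, eid;  // parallel arrays
};

COO InEdgeSubgraph(const COO& g, const std::unordered_set<int64_t>& nodes) {
  COO sub;
  for (size_t e = 0; e < g.dst.size(); ++e) {
    if (nodes.count(g.dst[e])) {    // keep edges pointing into the node set
      sub.src.push_back(g.src[e]);
      sub.dst.push_back(g.dst[e]);
      sub.eid.push_back(g.eid[e]);  // induced edge id, like earr.id above
    }
  }
  return sub;
}

int main() {
  COO g{{0, 1, 2, 2}, {1, 2, 0, 1}, {0, 1, 2, 3}};
  COO sub = InEdgeSubgraph(g, {1});  // in-edges of node 1: eids 0 and 3
  return 0;
}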
@@ -19,18 +19,20 @@
#include "compact.h"

#include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/registry.h>
#include <dgl/transform.h>

#include <utility>
#include <vector>

#include "../../c_api_common.h"
#include "../unit_graph.h"
// TODO(BarclayII): currently CompactGraphs depend on IdHashMap implementation
// which only works on CPU. Should fix later to make it device agnostic.
#include "../../array/cpu/array_utils.h"

namespace dgl {
@@ -42,16 +44,16 @@ namespace transform {
namespace {

template <typename IdType>
std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>> CompactGraphsCPU(
    const std::vector<HeteroGraphPtr> &graphs,
    const std::vector<IdArray> &always_preserve) {
  // TODO(BarclayII): check whether the node space and metagraph of each graph
  // is the same. Step 1: Collect the nodes that has connections for each type.
  const int64_t num_ntypes = graphs[0]->NumVertexTypes();
  std::vector<aten::IdHashMap<IdType>> hashmaps(num_ntypes);
  std::vector<std::vector<EdgeArray>> all_edges(
      graphs.size());  // all_edges[i][etype]
  std::vector<int64_t> max_vertex_cnt(num_ntypes, 0);

  for (size_t i = 0; i < graphs.size(); ++i) {
@@ -98,7 +100,8 @@ CompactGraphsCPU(
    }
  }

  // Step 2: Relabel the nodes for each type to a smaller ID space and save the
  // mapping.
  std::vector<IdArray> induced_nodes(num_ntypes);
  std::vector<int64_t> num_induced_nodes(num_ntypes);
  for (int64_t i = 0; i < num_ntypes; ++i) {
@@ -123,14 +126,12 @@ CompactGraphsCPU(
      const IdArray mapped_cols = hashmaps[dsttype].Map(edges.dst, -1);

      rel_graphs.push_back(UnitGraph::CreateFromCOO(
          srctype == dsttype ? 1 : 2, induced_nodes[srctype]->shape[0],
          induced_nodes[dsttype]->shape[0], mapped_rows, mapped_cols));
    }

    new_graphs.push_back(
        CreateHeteroGraph(meta_graph, rel_graphs, num_induced_nodes));
  }

  return std::make_pair(new_graphs, induced_nodes);
@@ -138,7 +139,7 @@ CompactGraphsCPU(
};  // namespace

template <>
std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>>
CompactGraphs<kDGLCPU, int32_t>(
    const std::vector<HeteroGraphPtr> &graphs,
@@ -146,7 +147,7 @@ CompactGraphs<kDGLCPU, int32_t>(
  return CompactGraphsCPU<int32_t>(graphs, always_preserve);
}

template <>
std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>>
CompactGraphs<kDGLCPU, int64_t>(
    const std::vector<HeteroGraphPtr> &graphs,
@@ -155,44 +156,44 @@ CompactGraphs<kDGLCPU, int64_t>(
}

DGL_REGISTER_GLOBAL("transform._CAPI_DGLCompactGraphs")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      List<HeteroGraphRef> graph_refs = args[0];
      List<Value> always_preserve_refs = args[1];

      std::vector<HeteroGraphPtr> graphs;
      std::vector<IdArray> always_preserve;
      for (HeteroGraphRef gref : graph_refs) graphs.push_back(gref.sptr());
      for (Value array : always_preserve_refs)
        always_preserve.push_back(array->data);

      // TODO(BarclayII): check for all IdArrays
      CHECK(graphs[0]->DataType() == always_preserve[0]->dtype)
          << "data type mismatch.";

      std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>> result_pair;
      ATEN_XPU_SWITCH_CUDA(
          graphs[0]->Context().device_type, XPU, "CompactGraphs", {
            ATEN_ID_TYPE_SWITCH(graphs[0]->DataType(), IdType, {
              result_pair = CompactGraphs<XPU, IdType>(graphs, always_preserve);
            });
          });

      List<HeteroGraphRef> compacted_graph_refs;
      List<Value> induced_nodes;

      for (const HeteroGraphPtr g : result_pair.first)
        compacted_graph_refs.push_back(HeteroGraphRef(g));
      for (const IdArray &ids : result_pair.second)
        induced_nodes.push_back(Value(MakeValue(ids)));

      List<ObjectRef> result;
      result.push_back(compacted_graph_refs);
      result.push_back(induced_nodes);

      *rv = result;
    });

};  // namespace transform
};  // namespace dgl
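CompactGraphsCPU, reformatted above, relabels the node ids that actually appear in any of the input graphs (plus always_preserve) onto a dense [0, n) range and then rewrites the edge endpoints. Below is a minimal, self-contained sketch of that relabeling step, using a plain std::unordered_map in place of DGL's IdHashMap.

#include <cstdint>
#include <unordered_map>
#include <vector>

int main() {
  std::vector<int64_t> src = {100, 42, 100};
  std::vector<int64_t> dst = {42, 7, 7};

  // Step 1: collect the ids that have connections, assigning dense new ids.
  std::unordered_map<int64_t, int64_t> to_new;
  std::vector<int64_t> induced_nodes;  // new id -> old id
  auto insert = [&](int64_t old_id) {
    if (to_new.emplace(old_id, static_cast<int64_t>(induced_nodes.size())).second)
      induced_nodes.push_back(old_id);
  };
  for (int64_t v : src) insert(v);
  for (int64_t v : dst) insert(v);

  // Step 2: relabel endpoints into the compacted id space.
  for (auto& v : src) v = to_new[v];
  for (auto& v : dst) v = to_new[v];
  // induced_nodes == {100, 42, 7}; src == {0, 1, 0}; dst == {1, 2, 2}.
  return 0;
}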
@@ -24,8 +24,8 @@
#include <dgl/array.h>
#include <dgl/base_heterograph.h>

#include <utility>
#include <vector>

namespace dgl {
namespace transform {
@@ -41,9 +41,8 @@ namespace transform {
 *
 * @return The vector of compacted graphs and the vector of induced nodes.
 */
template <DGLDeviceType XPU, typename IdType>
std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>> CompactGraphs(
    const std::vector<HeteroGraphPtr> &graphs,
    const std::vector<IdArray> &always_preserve);
......
@@ -9,7 +9,9 @@
#include <dgl/array.h>
#include <dmlc/logging.h>

#include <nanoflann.hpp>

#include "../../../c_api_common.h"

namespace dgl {
@@ -17,78 +19,75 @@ namespace transform {
namespace knn_utils {
/*!
 * \brief A simple 2D NDArray adapter for nanoflann, without duplicating the
 * storage.
 *
 * \tparam FloatType: The type of the point coordinates (typically, double or
 * float).
 * \tparam IdType: The type for indices in the KD-tree index (typically,
 * size_t of int)
 * \tparam FeatureDim: If set to > 0, it specifies a compile-time fixed
 * dimensionality for the points in the data set, allowing more compiler
 * optimizations.
 * \tparam Dist: The distance metric to use: nanoflann::metric_L1,
 * nanoflann::metric_L2, nanoflann::metric_L2_Simple, etc.
 * \note The spelling of dgl's adapter ("adapter") is different from naneflann
 * ("adaptor")
 */
template <
    typename FloatType, typename IdType, int FeatureDim = -1,
    typename Dist = nanoflann::metric_L2>
class KDTreeNDArrayAdapter {
 public:
  using self_type = KDTreeNDArrayAdapter<FloatType, IdType, FeatureDim, Dist>;
  using metric_type =
      typename Dist::template traits<FloatType, self_type>::distance_t;
  using index_type = nanoflann::KDTreeSingleIndexAdaptor<
      metric_type, self_type, FeatureDim, IdType>;

  KDTreeNDArrayAdapter(
      const size_t /* dims */, const NDArray data_points,
      const int leaf_max_size = 10)
      : data_(data_points) {
    CHECK(data_points->shape[0] != 0 && data_points->shape[1] != 0)
        << "Tensor containing input data point set must be 2D.";
    const size_t dims = data_points->shape[1];
    CHECK(!(FeatureDim > 0 && static_cast<int>(dims) != FeatureDim))
        << "Data set feature dimension does not match the 'FeatureDim' "
        << "template argument.";
    index_ = new index_type(
        static_cast<int>(dims), *this,
        nanoflann::KDTreeSingleIndexAdaptorParams(leaf_max_size));
    index_->buildIndex();
  }

  ~KDTreeNDArrayAdapter() { delete index_; }

  index_type* GetIndex() { return index_; }

  /*!
   * \brief Query for the \a num_closest points to a given point
   * Note that this is a short-cut method for GetIndex()->findNeighbors().
   */
  void query(
      const FloatType* query_pt, const size_t num_closest, IdType* out_idxs,
      FloatType* out_dists) const {
    nanoflann::KNNResultSet<FloatType, IdType> resultSet(num_closest);
    resultSet.init(out_idxs, out_dists);
    index_->findNeighbors(resultSet, query_pt, nanoflann::SearchParams());
  }

  /*! \brief Interface expected by KDTreeSingleIndexAdaptor */
  const self_type& derived() const { return *this; }

  /*! \brief Interface expected by KDTreeSingleIndexAdaptor */
  self_type& derived() { return *this; }

  /*!
   * \brief Interface expected by KDTreeSingleIndexAdaptor,
   * return the number of data points
   */
  size_t kdtree_get_point_count() const { return data_->shape[0]; }

  /*!
   * \brief Interface expected by KDTreeSingleIndexAdaptor,
@@ -110,7 +109,7 @@ class KDTreeNDArrayAdapter {
  }

 private:
  index_type* index_;   // The kd tree index
  const NDArray data_;  // data points
};
......
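KDTreeNDArrayAdapter above exposes a DGL NDArray to nanoflann through nanoflann's dataset-adaptor interface (kdtree_get_point_count, per-point accessors, kdtree_get_bbox). The sketch below shows that same interface over a plain std::vector so it stays self-contained; PointCloud2D is a hypothetical name, and it assumes a nanoflann version providing the SearchParams/L2_Simple_Adaptor API that this file already uses.

#include <nanoflann.hpp>

#include <cstddef>
#include <vector>

struct PointCloud2D {
  std::vector<float> pts;  // row-major [num_points x dim]
  size_t dim = 2;

  // Interface expected by nanoflann::KDTreeSingleIndexAdaptor:
  size_t kdtree_get_point_count() const { return pts.size() / dim; }
  float kdtree_get_pt(const size_t idx, const size_t d) const {
    return pts[idx * dim + d];
  }
  template <class BBox>
  bool kdtree_get_bbox(BBox&) const { return false; }  // use default bbox
};

int main() {
  PointCloud2D cloud;
  cloud.pts = {0.f, 0.f, 1.f, 0.f, 0.f, 1.f, 5.f, 5.f};

  using KDTree = nanoflann::KDTreeSingleIndexAdaptor<
      nanoflann::L2_Simple_Adaptor<float, PointCloud2D>, PointCloud2D, 2>;
  KDTree index(2, cloud, nanoflann::KDTreeSingleIndexAdaptorParams(10));
  index.buildIndex();

  const float query[2] = {0.9f, 0.1f};
  size_t out_idx[1];
  float out_dist[1];
  nanoflann::KNNResultSet<float, size_t> results(1);
  results.init(out_idx, out_dist);
  index.findNeighbors(results, query, nanoflann::SearchParams());
  // out_idx[0] should now be 1, the point at (1, 0).
  return 0;
}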
@@ -4,16 +4,19 @@
 * \brief k-nearest-neighbor (KNN) implementation
 */
#include "../knn.h"

#include <dgl/random.h>
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/parallel_for.h>
#include <dmlc/omp.h>

#include <algorithm>
#include <limits>
#include <tuple>
#include <vector>

#include "kdtree_ndarray_adapter.h"

using namespace dgl::runtime;
using namespace dgl::transform::knn_utils;
@@ -30,8 +33,9 @@ static constexpr int NN_DESCENT_BLOCK_SIZE = 16384;
 * distance.
 */
template <typename FloatType, typename IdType>
FloatType EuclideanDistWithCheck(
    const FloatType* vec1, const FloatType* vec2, int64_t dim,
    FloatType worst_dist = std::numeric_limits<FloatType>::max()) {
  FloatType dist = 0;
  bool early_stop = false;
@@ -52,7 +56,8 @@ FloatType EuclideanDistWithCheck(
/*! \brief Compute Euclidean distance between two vectors */
template <typename FloatType, typename IdType>
FloatType EuclideanDist(
    const FloatType* vec1, const FloatType* vec2, int64_t dim) {
  FloatType dist = 0;

  for (IdType idx = 0; idx < dim; ++idx) {
@@ -64,9 +69,9 @@ FloatType EuclideanDist(
/*! \brief Insert a new element into a heap */
template <typename FloatType, typename IdType>
void HeapInsert(
    IdType* out, FloatType* dist, IdType new_id, FloatType new_dist, int k,
    bool check_repeat = false) {
  if (new_dist > dist[0]) return;

  // check if we have it
@@ -99,11 +104,12 @@ void HeapInsert(
  }
}

/*! \brief Insert a new element and its flag into heap, return 1 if insert
 * successfully */
template <typename FloatType, typename IdType>
int FlaggedHeapInsert(
    IdType* out, FloatType* dist, bool* flag, IdType new_id, FloatType new_dist,
    bool new_flag, int k, bool check_repeat = false) {
  if (new_dist > dist[0]) return 0;

  if (check_repeat) {
@@ -170,16 +176,15 @@ void BuildHeap(IdType* index, FloatType* dist, int k) {
 * distance of these two points, we update the neighborhood of that point.
 */
template <typename FloatType, typename IdType>
int UpdateNeighbors(
    IdType* neighbors, FloatType* dists, const FloatType* points, bool* flags,
    IdType c1, IdType c2, IdType point_start, int64_t feature_size, int k) {
  IdType c1_local = c1 - point_start, c2_local = c2 - point_start;
  FloatType worst_c1_dist = dists[c1_local * k];
  FloatType worst_c2_dist = dists[c2_local * k];
  FloatType new_dist = EuclideanDistWithCheck<FloatType, IdType>(
      points + c1 * feature_size, points + c2 * feature_size, feature_size,
      std::max(worst_c1_dist, worst_c2_dist));

  int num_updates = 0;
  if (new_dist < worst_c1_dist) {
@@ -187,10 +192,8 @@ int UpdateNeighbors(
#pragma omp critical
    {
      FlaggedHeapInsert<FloatType, IdType>(
          neighbors + c1 * k, dists + c1_local * k, flags + c1_local * k, c2,
          new_dist, true, k, true);
    }
  }
  if (new_dist < worst_c2_dist) {
@@ -198,10 +201,8 @@ int UpdateNeighbors(
#pragma omp critical
    {
      FlaggedHeapInsert<FloatType, IdType>(
          neighbors + c2 * k, dists + c2_local * k, flags + c2_local * k, c1,
          new_dist, true, k, true);
    }
  }
  return num_updates;
@@ -209,9 +210,10 @@ int UpdateNeighbors(
/*! \brief The kd-tree implementation of K-Nearest Neighbors */
template <typename FloatType, typename IdType>
void KdTreeKNN(
    const NDArray& data_points, const IdArray& data_offsets,
    const NDArray& query_points, const IdArray& query_offsets, const int k,
    IdArray result) {
  const int64_t batch_size = data_offsets->shape[0] - 1;
  const int64_t feature_size = data_points->shape[1];
  const IdType* data_offsets_data = data_offsets.Ptr<IdType>();
@@ -228,11 +230,16 @@ void KdTreeKNN(
    auto out_offset = k * q_offset;

    // create view for each segment
    const NDArray current_data_points =
        const_cast<NDArray*>(&data_points)
            ->CreateView(
                {d_length, feature_size}, data_points->dtype,
                d_offset * feature_size * sizeof(FloatType));
    const FloatType* current_query_pts_data =
        query_points_data + q_offset * feature_size;

    KDTreeNDArrayAdapter<FloatType, IdType> kdtree(
        feature_size, current_data_points);

    // query
    parallel_for(0, q_length, [&](IdType b, IdType e) {
@@ -256,9 +263,10 @@ void KdTreeKNN(
}

template <typename FloatType, typename IdType>
void BruteForceKNN(
    const NDArray& data_points, const IdArray& data_offsets,
    const NDArray& query_points, const IdArray& query_offsets, const int k,
    IdArray result) {
  const int64_t batch_size = data_offsets->shape[0] - 1;
  const int64_t feature_size = data_points->shape[1];
  const IdType* data_offsets_data = data_offsets.Ptr<IdType>();
@@ -285,9 +293,9 @@ void BruteForceKNN(
      for (IdType d_idx = d_start; d_idx < d_end; ++d_idx) {
        FloatType tmp_dist = EuclideanDistWithCheck<FloatType, IdType>(
            query_points_data + q_idx * feature_size,
            data_points_data + d_idx * feature_size, feature_size,
            worst_dist);

        if (tmp_dist == std::numeric_limits<FloatType>::max()) {
          continue;
@@ -295,7 +303,7 @@ void BruteForceKNN(
        IdType out_offset = q_idx * k;
        HeapInsert<FloatType, IdType>(
            data_out + out_offset, dist_buffer.data(), d_idx, tmp_dist, k);
        worst_dist = dist_buffer[0];
      }
    }
@@ -305,25 +313,27 @@ void BruteForceKNN(
}  // namespace impl

template <DGLDeviceType XPU, typename FloatType, typename IdType>
void KNN(
    const NDArray& data_points, const IdArray& data_offsets,
    const NDArray& query_points, const IdArray& query_offsets, const int k,
    IdArray result, const std::string& algorithm) {
  if (algorithm == std::string("kd-tree")) {
    impl::KdTreeKNN<FloatType, IdType>(
        data_points, data_offsets, query_points, query_offsets, k, result);
  } else if (algorithm == std::string("bruteforce")) {
    impl::BruteForceKNN<FloatType, IdType>(
        data_points, data_offsets, query_points, query_offsets, k, result);
  } else {
    LOG(FATAL) << "Algorithm " << algorithm << " is not supported on CPU";
  }
}

template <DGLDeviceType XPU, typename FloatType, typename IdType>
void NNDescent(
    const NDArray& points, const IdArray& offsets, IdArray result, const int k,
    const int num_iters, const int num_candidates, const double delta) {
  using nnd_updates_t =
      std::vector<std::vector<std::tuple<IdType, IdType, FloatType>>>;
  const auto& ctx = points->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);
  const int64_t num_nodes = points->shape[0];
@@ -343,62 +353,69 @@ void NNDescent(
  }

  // allocate memory for candidate, sampling pool, distance and flag
  IdType* new_candidates = static_cast<IdType*>(device->AllocWorkspace(
      ctx, max_segment_size * num_candidates * sizeof(IdType)));
  IdType* old_candidates = static_cast<IdType*>(device->AllocWorkspace(
      ctx, max_segment_size * num_candidates * sizeof(IdType)));
  FloatType* new_candidates_dists =
      static_cast<FloatType*>(device->AllocWorkspace(
          ctx, max_segment_size * num_candidates * sizeof(FloatType)));
  FloatType* old_candidates_dists =
      static_cast<FloatType*>(device->AllocWorkspace(
          ctx, max_segment_size * num_candidates * sizeof(FloatType)));
  FloatType* neighbors_dists = static_cast<FloatType*>(
      device->AllocWorkspace(ctx, max_segment_size * k * sizeof(FloatType)));
  bool* flags = static_cast<bool*>(
      device->AllocWorkspace(ctx, max_segment_size * k * sizeof(bool)));

  for (IdType b = 0; b < batch_size; ++b) {
    IdType point_idx_start = offsets_data[b],
           point_idx_end = offsets_data[b + 1];
    IdType segment_size = point_idx_end - point_idx_start;

    // random initialization
    runtime::parallel_for(
        point_idx_start, point_idx_end, [&](size_t b, size_t e) {
          for (auto i = b; i < e; ++i) {
            IdType local_idx = i - point_idx_start;

            dgl::RandomEngine::ThreadLocal()->UniformChoice<IdType>(
                k, segment_size, neighbors + i * k, false);

            for (IdType n = 0; n < k; ++n) {
              central_nodes[i * k + n] = i;
              neighbors[i * k + n] += point_idx_start;
              flags[local_idx * k + n] = true;
              neighbors_dists[local_idx * k + n] =
                  impl::EuclideanDist<FloatType, IdType>(
                      points_data + i * feature_size,
                      points_data + neighbors[i * k + n] * feature_size,
                      feature_size);
            }
            impl::BuildHeap<FloatType, IdType>(
                neighbors + i * k, neighbors_dists + local_idx * k, k);
          }
        });

    size_t num_updates = 0;
    for (int iter = 0; iter < num_iters; ++iter) {
      num_updates = 0;

      // initialize candidates array as empty value
      runtime::parallel_for(
          point_idx_start, point_idx_end, [&](size_t b, size_t e) {
            for (auto i = b; i < e; ++i) {
              IdType local_idx = i - point_idx_start;
              for (IdType c = 0; c < num_candidates; ++c) {
                new_candidates[local_idx * num_candidates + c] = num_nodes;
                old_candidates[local_idx * num_candidates + c] = num_nodes;
                new_candidates_dists[local_idx * num_candidates + c] =
                    std::numeric_limits<FloatType>::max();
                old_candidates_dists[local_idx * num_candidates + c] =
                    std::numeric_limits<FloatType>::max();
              }
            }
          });

      // randomly select neighbors as candidates
      int num_threads = omp_get_max_threads();
@@ -410,33 +427,36 @@ void NNDescent(
          IdType neighbor_idx = neighbors[i * k + n];
          bool is_new = flags[local_idx * k + n];
          IdType local_neighbor_idx = neighbor_idx - point_idx_start;
          FloatType random_dist =
              dgl::RandomEngine::ThreadLocal()->Uniform<FloatType>();

          if (is_new) {
            if (local_idx % num_threads == tid) {
              impl::HeapInsert<FloatType, IdType>(
                  new_candidates + local_idx * num_candidates,
                  new_candidates_dists + local_idx * num_candidates,
                  neighbor_idx, random_dist, num_candidates, true);
            }
            if (local_neighbor_idx % num_threads == tid) {
              impl::HeapInsert<FloatType, IdType>(
                  new_candidates + local_neighbor_idx * num_candidates,
                  new_candidates_dists +
                      local_neighbor_idx * num_candidates,
                  i, random_dist, num_candidates, true);
            }
          } else {
            if (local_idx % num_threads == tid) {
              impl::HeapInsert<FloatType, IdType>(
                  old_candidates + local_idx * num_candidates,
                  old_candidates_dists + local_idx * num_candidates,
                  neighbor_idx, random_dist, num_candidates, true);
            }
            if (local_neighbor_idx % num_threads == tid) {
              impl::HeapInsert<FloatType, IdType>(
                  old_candidates + local_neighbor_idx * num_candidates,
                  old_candidates_dists +
                      local_neighbor_idx * num_candidates,
                  i, random_dist, num_candidates, true);
            }
          }
        }
@@ -445,27 +465,28 @@ void NNDescent(
      });

      // mark all elements in new_candidates as false
      runtime::parallel_for(
          point_idx_start, point_idx_end, [&](size_t b, size_t e) {
            for (auto i = b; i < e; ++i) {
              IdType local_idx = i - point_idx_start;
              for (IdType n = 0; n < k; ++n) {
                IdType n_idx = neighbors[i * k + n];

                for (IdType c = 0; c < num_candidates; ++c) {
                  if (new_candidates[local_idx * num_candidates + c] == n_idx) {
                    flags[local_idx * k + n] = false;
                    break;
                  }
                }
              }
            }
          });

      // update neighbors block by block
      for (IdType block_start = point_idx_start; block_start < point_idx_end;
           block_start += impl::NN_DESCENT_BLOCK_SIZE) {
        IdType block_end =
            std::min(point_idx_end, block_start + impl::NN_DESCENT_BLOCK_SIZE);
        IdType block_size = block_end - block_start;
        nnd_updates_t updates(block_size);
@@ -487,14 +508,15 @@ void NNDescent(
            FloatType worst_c1_dist = neighbors_dists[c1_local * k];
            FloatType worst_c2_dist = neighbors_dists[c2_local * k];
            FloatType new_dist =
                impl::EuclideanDistWithCheck<FloatType, IdType>(
                    points_data + new_c1 * feature_size,
                    points_data + new_c2 * feature_size, feature_size,
                    std::max(worst_c1_dist, worst_c2_dist));

            if (new_dist < worst_c1_dist || new_dist < worst_c2_dist) {
              updates[i - block_start].push_back(
                  std::make_tuple(new_c1, new_c2, new_dist));
            }
          }
@@ -506,14 +528,15 @@ void NNDescent(
            FloatType worst_c1_dist = neighbors_dists[c1_local * k];
            FloatType worst_c2_dist = neighbors_dists[c2_local * k];
            FloatType new_dist =
                impl::EuclideanDistWithCheck<FloatType, IdType>(
                    points_data + new_c1 * feature_size,
                    points_data + old_c2 * feature_size, feature_size,
                    std::max(worst_c1_dist, worst_c2_dist));

            if (new_dist < worst_c1_dist || new_dist < worst_c2_dist) {
              updates[i - block_start].push_back(
                  std::make_tuple(new_c1, old_c2, new_dist));
            }
          }
        }
@@ -521,12 +544,12 @@ void NNDescent(
      });

      int tid;
#pragma omp parallel private(tid, num_threads) reduction(+ : num_updates)
      {
        tid = omp_get_thread_num();
        num_threads = omp_get_num_threads();
        for (IdType i = 0; i < block_size; ++i) {
          for (const auto& u : updates[i]) {
            IdType p1, p2;
            FloatType d;
            std::tie(p1, p2, d) = u;
@@ -535,17 +558,13 @@ void NNDescent(
            if (p1 % num_threads == tid) {
              num_updates += impl::FlaggedHeapInsert<FloatType, IdType>(
                  neighbors + p1 * k, neighbors_dists + p1_local * k,
                  flags + p1_local * k, p2, d, true, k, true);
            }
            if (p2 % num_threads == tid) {
              num_updates += impl::FlaggedHeapInsert<FloatType, IdType>(
                  neighbors + p2 * k, neighbors_dists + p2_local * k,
                  flags + p2_local * k, p1, d, true, k, true);
            }
          }
        }
@@ -568,37 +587,33 @@ void NNDescent(
}

template void KNN<kDGLCPU, float, int32_t>(
    const NDArray& data_points, const IdArray& data_offsets,
    const NDArray& query_points, const IdArray& query_offsets, const int k,
    IdArray result, const std::string& algorithm);
template void KNN<kDGLCPU, float, int64_t>(
    const NDArray& data_points, const IdArray& data_offsets,
    const NDArray& query_points, const IdArray& query_offsets, const int k,
    IdArray result, const std::string& algorithm);
template void KNN<kDGLCPU, double, int32_t>(
    const NDArray& data_points, const IdArray& data_offsets,
    const NDArray& query_points, const IdArray& query_offsets, const int k,
    IdArray result, const std::string& algorithm);
template void KNN<kDGLCPU, double, int64_t>(
    const NDArray& data_points, const IdArray& data_offsets,
    const NDArray& query_points, const IdArray& query_offsets, const int k,
    IdArray result, const std::string& algorithm);

template void NNDescent<kDGLCPU, float, int32_t>(
    const NDArray& points, const IdArray& offsets, IdArray result, const int k,
    const int num_iters, const int num_candidates, const double delta);
template void NNDescent<kDGLCPU, float, int64_t>(
    const NDArray& points, const IdArray& offsets, IdArray result, const int k,
    const int num_iters, const int num_candidates, const double delta);
template void NNDescent<kDGLCPU, double, int32_t>(
const NDArray& points, const IdArray& offsets, const NDArray& points, const IdArray& offsets, IdArray result, const int k,
IdArray result, const int k, const int num_iters, const int num_iters, const int num_candidates, const double delta);
const int num_candidates, const double delta);
template void NNDescent<kDGLCPU, double, int64_t>( template void NNDescent<kDGLCPU, double, int64_t>(
const NDArray& points, const IdArray& offsets, const NDArray& points, const IdArray& offsets, IdArray result, const int k,
IdArray result, const int k, const int num_iters, const int num_iters, const int num_candidates, const double delta);
const int num_candidates, const double delta);
} // namespace transform } // namespace transform
} // namespace dgl } // namespace dgl
...@@ -18,13 +18,13 @@ ...@@ -18,13 +18,13 @@
* all given graphs with the same set of nodes. * all given graphs with the same set of nodes.
*/ */
#include <dgl/runtime/device_api.h>
#include <dgl/immutable_graph.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <utility> #include <dgl/immutable_graph.h>
#include <dgl/runtime/device_api.h>
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <utility>
#include "../../../runtime/cuda/cuda_common.h" #include "../../../runtime/cuda/cuda_common.h"
#include "../../heterograph.h" #include "../../heterograph.h"
...@@ -41,54 +41,45 @@ namespace transform { ...@@ -41,54 +41,45 @@ namespace transform {
namespace { namespace {
/** /**
* \brief This function builds node maps for each node type, preserving the * \brief This function builds node maps for each node type, preserving the
* order of the input nodes. Here it is assumed the nodes are not unique, * order of the input nodes. Here it is assumed the nodes are not unique,
* and thus a unique list is generated. * and thus a unique list is generated.
* *
* \param input_nodes The set of input nodes. * \param input_nodes The set of input nodes.
* \param node_maps The node maps to be constructed. * \param node_maps The node maps to be constructed.
* \param count_unique_device The number of unique nodes (on the GPU). * \param count_unique_device The number of unique nodes (on the GPU).
* \param unique_nodes_device The unique nodes (on the GPU). * \param unique_nodes_device The unique nodes (on the GPU).
* \param stream The stream to operate on. * \param stream The stream to operate on.
*/ */
template<typename IdType> template <typename IdType>
void BuildNodeMaps( void BuildNodeMaps(
const std::vector<IdArray>& input_nodes, const std::vector<IdArray> &input_nodes,
DeviceNodeMap<IdType> * const node_maps, DeviceNodeMap<IdType> *const node_maps, int64_t *const count_unique_device,
int64_t * const count_unique_device, std::vector<IdArray> *const unique_nodes_device, cudaStream_t stream) {
std::vector<IdArray>* const unique_nodes_device,
cudaStream_t stream) {
const int64_t num_ntypes = static_cast<int64_t>(input_nodes.size()); const int64_t num_ntypes = static_cast<int64_t>(input_nodes.size());
CUDA_CALL(cudaMemsetAsync( CUDA_CALL(cudaMemsetAsync(
count_unique_device, count_unique_device, 0, num_ntypes * sizeof(*count_unique_device),
0, stream));
num_ntypes*sizeof(*count_unique_device),
stream));
// possibly duplicated nodes // possibly duplicated nodes
for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) { for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
const IdArray& nodes = input_nodes[ntype]; const IdArray &nodes = input_nodes[ntype];
if (nodes->shape[0] > 0) { if (nodes->shape[0] > 0) {
CHECK_EQ(nodes->ctx.device_type, kDGLCUDA); CHECK_EQ(nodes->ctx.device_type, kDGLCUDA);
node_maps->LhsHashTable(ntype).FillWithDuplicates( node_maps->LhsHashTable(ntype).FillWithDuplicates(
nodes.Ptr<IdType>(), nodes.Ptr<IdType>(), nodes->shape[0],
nodes->shape[0],
(*unique_nodes_device)[ntype].Ptr<IdType>(), (*unique_nodes_device)[ntype].Ptr<IdType>(),
count_unique_device+ntype, count_unique_device + ntype, stream);
stream);
} }
} }
} }
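FillWithDuplicates performs an order-preserving deduplication: the first occurrence of every id keeps its position and receives the next free local id. A minimal host-side sketch of that mapping (illustrative only; the function and variable names are hypothetical, and the real work happens inside OrderedHashTable on the device):

// Requires <vector> and <unordered_map>.
std::vector<int64_t> DedupPreservingOrder(
    const std::vector<int64_t>& global_ids,
    std::unordered_map<int64_t, int64_t>* local_of) {
  std::vector<int64_t> unique;  // plays the role of (*unique_nodes_device)[ntype]
  for (int64_t id : global_ids) {
    if (local_of->count(id) == 0) {
      (*local_of)[id] = static_cast<int64_t>(unique.size());
      unique.push_back(id);
    }
  }
  return unique;
}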
template <typename IdType>
template<typename IdType> std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>> CompactGraphsGPU(
std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>>
CompactGraphsGPU(
const std::vector<HeteroGraphPtr> &graphs, const std::vector<HeteroGraphPtr> &graphs,
const std::vector<IdArray> &always_preserve) { const std::vector<IdArray> &always_preserve) {
const auto &ctx = graphs[0]->Context();
const auto& ctx = graphs[0]->Context();
auto device = runtime::DeviceAPI::Get(ctx); auto device = runtime::DeviceAPI::Get(ctx);
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
...@@ -96,7 +87,8 @@ CompactGraphsGPU( ...@@ -96,7 +87,8 @@ CompactGraphsGPU(
// Step 1: Collect the nodes that have connections for each type. // Step 1: Collect the nodes that have connections for each type.
const uint64_t num_ntypes = graphs[0]->NumVertexTypes(); const uint64_t num_ntypes = graphs[0]->NumVertexTypes();
std::vector<std::vector<EdgeArray>> all_edges(graphs.size()); // all_edges[i][etype] std::vector<std::vector<EdgeArray>> all_edges(
graphs.size()); // all_edges[i][etype]
// count the number of nodes per type // count the number of nodes per type
std::vector<int64_t> max_vertex_cnt(num_ntypes, 0); std::vector<int64_t> max_vertex_cnt(num_ntypes, 0);
...@@ -123,19 +115,18 @@ CompactGraphsGPU( ...@@ -123,19 +115,18 @@ CompactGraphsGPU(
std::vector<int64_t> node_offsets(num_ntypes, 0); std::vector<int64_t> node_offsets(num_ntypes, 0);
for (uint64_t ntype = 0; ntype < num_ntypes; ++ntype) { for (uint64_t ntype = 0; ntype < num_ntypes; ++ntype) {
all_nodes[ntype] = NewIdArray(max_vertex_cnt[ntype], ctx, all_nodes[ntype] =
sizeof(IdType)*8); NewIdArray(max_vertex_cnt[ntype], ctx, sizeof(IdType) * 8);
// copy the nodes in always_preserve // copy the nodes in always_preserve
if (ntype < always_preserve.size() && always_preserve[ntype]->shape[0] > 0) { if (ntype < always_preserve.size() &&
always_preserve[ntype]->shape[0] > 0) {
device->CopyDataFromTo( device->CopyDataFromTo(
always_preserve[ntype].Ptr<IdType>(), 0, always_preserve[ntype].Ptr<IdType>(), 0,
all_nodes[ntype].Ptr<IdType>(), all_nodes[ntype].Ptr<IdType>(), node_offsets[ntype],
node_offsets[ntype], sizeof(IdType) * always_preserve[ntype]->shape[0],
sizeof(IdType)*always_preserve[ntype]->shape[0], always_preserve[ntype]->ctx, all_nodes[ntype]->ctx,
always_preserve[ntype]->ctx,
all_nodes[ntype]->ctx,
always_preserve[ntype]->dtype); always_preserve[ntype]->dtype);
node_offsets[ntype] += sizeof(IdType)*always_preserve[ntype]->shape[0]; node_offsets[ntype] += sizeof(IdType) * always_preserve[ntype]->shape[0];
} }
} }
...@@ -152,25 +143,17 @@ CompactGraphsGPU( ...@@ -152,25 +143,17 @@ CompactGraphsGPU(
if (edges.src.defined()) { if (edges.src.defined()) {
device->CopyDataFromTo( device->CopyDataFromTo(
edges.src.Ptr<IdType>(), 0, edges.src.Ptr<IdType>(), 0, all_nodes[srctype].Ptr<IdType>(),
all_nodes[srctype].Ptr<IdType>(), node_offsets[srctype], sizeof(IdType) * edges.src->shape[0],
node_offsets[srctype], edges.src->ctx, all_nodes[srctype]->ctx, edges.src->dtype);
sizeof(IdType)*edges.src->shape[0], node_offsets[srctype] += sizeof(IdType) * edges.src->shape[0];
edges.src->ctx,
all_nodes[srctype]->ctx,
edges.src->dtype);
node_offsets[srctype] += sizeof(IdType)*edges.src->shape[0];
} }
if (edges.dst.defined()) { if (edges.dst.defined()) {
device->CopyDataFromTo( device->CopyDataFromTo(
edges.dst.Ptr<IdType>(), 0, edges.dst.Ptr<IdType>(), 0, all_nodes[dsttype].Ptr<IdType>(),
all_nodes[dsttype].Ptr<IdType>(), node_offsets[dsttype], sizeof(IdType) * edges.dst->shape[0],
node_offsets[dsttype], edges.dst->ctx, all_nodes[dsttype]->ctx, edges.dst->dtype);
sizeof(IdType)*edges.dst->shape[0], node_offsets[dsttype] += sizeof(IdType) * edges.dst->shape[0];
edges.dst->ctx,
all_nodes[dsttype]->ctx,
edges.dst->dtype);
node_offsets[dsttype] += sizeof(IdType)*edges.dst->shape[0];
} }
all_edges[i].push_back(edges); all_edges[i].push_back(edges);
} }
...@@ -185,29 +168,22 @@ CompactGraphsGPU( ...@@ -185,29 +168,22 @@ CompactGraphsGPU(
// number of unique nodes per type on CPU // number of unique nodes per type on CPU
std::vector<int64_t> num_induced_nodes(num_ntypes); std::vector<int64_t> num_induced_nodes(num_ntypes);
// number of unique nodes per type on GPU // number of unique nodes per type on GPU
int64_t * count_unique_device = static_cast<int64_t*>( int64_t *count_unique_device = static_cast<int64_t *>(
device->AllocWorkspace(ctx, sizeof(int64_t)*num_ntypes)); device->AllocWorkspace(ctx, sizeof(int64_t) * num_ntypes));
// the set of unique nodes per type // the set of unique nodes per type
std::vector<IdArray> induced_nodes(num_ntypes); std::vector<IdArray> induced_nodes(num_ntypes);
for (uint64_t ntype = 0; ntype < num_ntypes; ++ntype) { for (uint64_t ntype = 0; ntype < num_ntypes; ++ntype) {
induced_nodes[ntype] = NewIdArray(max_vertex_cnt[ntype], ctx, induced_nodes[ntype] =
sizeof(IdType)*8); NewIdArray(max_vertex_cnt[ntype], ctx, sizeof(IdType) * 8);
} }
BuildNodeMaps( BuildNodeMaps(
all_nodes, all_nodes, &node_maps, count_unique_device, &induced_nodes, stream);
&node_maps,
count_unique_device,
&induced_nodes,
stream);
device->CopyDataFromTo( device->CopyDataFromTo(
count_unique_device, 0, count_unique_device, 0, num_induced_nodes.data(), 0,
num_induced_nodes.data(), 0, sizeof(*num_induced_nodes.data()) * num_ntypes, ctx,
sizeof(*num_induced_nodes.data())*num_ntypes, DGLContext{kDGLCPU, 0}, DGLDataType{kDGLInt, 64, 1});
ctx,
DGLContext{kDGLCPU, 0},
DGLDataType{kDGLInt, 64, 1});
device->StreamSync(ctx, stream); device->StreamSync(ctx, stream);
// wait for the node counts to finish transferring // wait for the node counts to finish transferring
...@@ -230,22 +206,20 @@ CompactGraphsGPU( ...@@ -230,22 +206,20 @@ CompactGraphsGPU(
std::vector<IdArray> new_src; std::vector<IdArray> new_src;
std::vector<IdArray> new_dst; std::vector<IdArray> new_dst;
std::tie(new_src, new_dst) = MapEdges( std::tie(new_src, new_dst) =
curr_graph, all_edges[i], node_maps, stream); MapEdges(curr_graph, all_edges[i], node_maps, stream);
for (IdType etype = 0; etype < num_etypes; ++etype) { for (IdType etype = 0; etype < num_etypes; ++etype) {
IdType srctype, dsttype; IdType srctype, dsttype;
std::tie(srctype, dsttype) = curr_graph->GetEndpointTypes(etype); std::tie(srctype, dsttype) = curr_graph->GetEndpointTypes(etype);
rel_graphs.push_back(UnitGraph::CreateFromCOO( rel_graphs.push_back(UnitGraph::CreateFromCOO(
srctype == dsttype ? 1 : 2, srctype == dsttype ? 1 : 2, induced_nodes[srctype]->shape[0],
induced_nodes[srctype]->shape[0], induced_nodes[dsttype]->shape[0], new_src[etype], new_dst[etype]));
induced_nodes[dsttype]->shape[0],
new_src[etype],
new_dst[etype]));
} }
new_graphs.push_back(CreateHeteroGraph(meta_graph, rel_graphs, num_induced_nodes)); new_graphs.push_back(
CreateHeteroGraph(meta_graph, rel_graphs, num_induced_nodes));
} }
return std::make_pair(new_graphs, induced_nodes); return std::make_pair(new_graphs, induced_nodes);
...@@ -253,7 +227,7 @@ CompactGraphsGPU( ...@@ -253,7 +227,7 @@ CompactGraphsGPU(
} // namespace } // namespace
template<> template <>
std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>> std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>>
CompactGraphs<kDGLCUDA, int32_t>( CompactGraphs<kDGLCUDA, int32_t>(
const std::vector<HeteroGraphPtr> &graphs, const std::vector<HeteroGraphPtr> &graphs,
...@@ -261,7 +235,7 @@ CompactGraphs<kDGLCUDA, int32_t>( ...@@ -261,7 +235,7 @@ CompactGraphs<kDGLCUDA, int32_t>(
return CompactGraphsGPU<int32_t>(graphs, always_preserve); return CompactGraphsGPU<int32_t>(graphs, always_preserve);
} }
template<> template <>
std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>> std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>>
CompactGraphs<kDGLCUDA, int64_t>( CompactGraphs<kDGLCUDA, int64_t>(
const std::vector<HeteroGraphPtr> &graphs, const std::vector<HeteroGraphPtr> &graphs,
......
...@@ -20,13 +20,14 @@ ...@@ -20,13 +20,14 @@
#ifndef DGL_GRAPH_TRANSFORM_CUDA_CUDA_MAP_EDGES_CUH_ #ifndef DGL_GRAPH_TRANSFORM_CUDA_CUDA_MAP_EDGES_CUH_
#define DGL_GRAPH_TRANSFORM_CUDA_CUDA_MAP_EDGES_CUH_ #define DGL_GRAPH_TRANSFORM_CUDA_CUDA_MAP_EDGES_CUH_
#include <dgl/runtime/c_runtime_api.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <dgl/runtime/c_runtime_api.h>
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <tuple> #include <tuple>
#include <vector>
#include <utility> #include <utility>
#include <vector>
#include "../../../runtime/cuda/cuda_common.h" #include "../../../runtime/cuda/cuda_common.h"
#include "../../../runtime/cuda/cuda_hashtable.cuh" #include "../../../runtime/cuda/cuda_hashtable.cuh"
...@@ -39,48 +40,46 @@ namespace transform { ...@@ -39,48 +40,46 @@ namespace transform {
namespace cuda { namespace cuda {
template<typename IdType, int BLOCK_SIZE, IdType TILE_SIZE> template <typename IdType, int BLOCK_SIZE, IdType TILE_SIZE>
__device__ void map_vertex_ids( __device__ void map_vertex_ids(
const IdType * const global, const IdType* const global, IdType* const new_global,
IdType * const new_global, const IdType num_vertices, const DeviceOrderedHashTable<IdType>& table) {
const IdType num_vertices,
const DeviceOrderedHashTable<IdType>& table) {
assert(BLOCK_SIZE == blockDim.x); assert(BLOCK_SIZE == blockDim.x);
using Mapping = typename OrderedHashTable<IdType>::Mapping; using Mapping = typename OrderedHashTable<IdType>::Mapping;
const IdType tile_start = TILE_SIZE*blockIdx.x; const IdType tile_start = TILE_SIZE * blockIdx.x;
const IdType tile_end = min(TILE_SIZE*(blockIdx.x+1), num_vertices); const IdType tile_end = min(TILE_SIZE * (blockIdx.x + 1), num_vertices);
for (IdType idx = threadIdx.x+tile_start; idx < tile_end; idx+=BLOCK_SIZE) { for (IdType idx = threadIdx.x + tile_start; idx < tile_end;
idx += BLOCK_SIZE) {
const Mapping& mapping = *table.Search(global[idx]); const Mapping& mapping = *table.Search(global[idx]);
new_global[idx] = mapping.local; new_global[idx] = mapping.local;
} }
} }
/** /**
* \brief Generate mapped edge endpoint ids. * \brief Generate mapped edge endpoint ids.
* *
* \tparam IdType The type of id. * \tparam IdType The type of id.
* \tparam BLOCK_SIZE The size of each thread block. * \tparam BLOCK_SIZE The size of each thread block.
* \tparam TILE_SIZE The number of edges to process per thread block. * \tparam TILE_SIZE The number of edges to process per thread block.
* \param global_srcs_device The source ids to map. * \param global_srcs_device The source ids to map.
* \param new_global_srcs_device The mapped source ids (output). * \param new_global_srcs_device The mapped source ids (output).
* \param global_dsts_device The destination ids to map. * \param global_dsts_device The destination ids to map.
* \param new_global_dsts_device The mapped destination ids (output). * \param new_global_dsts_device The mapped destination ids (output).
* \param num_edges The number of edges to map. * \param num_edges The number of edges to map.
* \param src_mapping The mapping of source ids. * \param src_mapping The mapping of source ids.
* \param src_hash_size The size of the source id hash table/mapping. * \param src_hash_size The size of the source id hash table/mapping.
* \param dst_mapping The mapping of destination ids. * \param dst_mapping The mapping of destination ids.
* \param dst_hash_size The size of the destination id hash table/mapping. * \param dst_hash_size The size of the destination id hash table/mapping.
*/ */
template<typename IdType, int BLOCK_SIZE, IdType TILE_SIZE> template <typename IdType, int BLOCK_SIZE, IdType TILE_SIZE>
__global__ void map_edge_ids( __global__ void map_edge_ids(
const IdType * const global_srcs_device, const IdType* const global_srcs_device,
IdType * const new_global_srcs_device, IdType* const new_global_srcs_device,
const IdType * const global_dsts_device, const IdType* const global_dsts_device,
IdType * const new_global_dsts_device, IdType* const new_global_dsts_device, const IdType num_edges,
const IdType num_edges,
DeviceOrderedHashTable<IdType> src_mapping, DeviceOrderedHashTable<IdType> src_mapping,
DeviceOrderedHashTable<IdType> dst_mapping) { DeviceOrderedHashTable<IdType> dst_mapping) {
assert(BLOCK_SIZE == blockDim.x); assert(BLOCK_SIZE == blockDim.x);
...@@ -88,87 +87,67 @@ __global__ void map_edge_ids( ...@@ -88,87 +87,67 @@ __global__ void map_edge_ids(
if (blockIdx.y == 0) { if (blockIdx.y == 0) {
map_vertex_ids<IdType, BLOCK_SIZE, TILE_SIZE>( map_vertex_ids<IdType, BLOCK_SIZE, TILE_SIZE>(
global_srcs_device, global_srcs_device, new_global_srcs_device, num_edges, src_mapping);
new_global_srcs_device,
num_edges,
src_mapping);
} else { } else {
map_vertex_ids<IdType, BLOCK_SIZE, TILE_SIZE>( map_vertex_ids<IdType, BLOCK_SIZE, TILE_SIZE>(
global_dsts_device, global_dsts_device, new_global_dsts_device, num_edges, dst_mapping);
new_global_dsts_device,
num_edges,
dst_mapping);
} }
} }
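map_edge_ids is written for a two-dimensional grid: blockIdx.x walks tiles of TILE_SIZE edges while blockIdx.y selects whether a block maps the sources (0) or the destinations (1). A launch consistent with that layout, reusing the BLOCK_SIZE and TILE_SIZE constants from MapEdges further down, would look roughly as follows (the grid expression is an assumption, not the exact launch code of this file):

const dim3 grid(RoundUpDiv(num_edges, TILE_SIZE), 2);  // x: edge tiles, y: src/dst
const dim3 block(BLOCK_SIZE);
map_edge_ids<IdType, BLOCK_SIZE, TILE_SIZE><<<grid, block, 0, stream>>>(
    global_srcs_device, new_global_srcs_device,
    global_dsts_device, new_global_dsts_device,
    num_edges, src_mapping, dst_mapping);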
/** /**
* \brief Device level node maps for each node type. * \brief Device level node maps for each node type.
* *
* \param num_nodes Number of nodes per type. * \param num_nodes Number of nodes per type.
* \param offset When offset is set to 0, LhsHashTable is identical to RhsHashTable. * \param offset When offset is set to 0, LhsHashTable is identical to
* Or set to num_nodes.size()/2 to use separate LhsHashTable and RhsHashTable. * RhsHashTable. Or set to num_nodes.size()/2 to use separate
* \param ctx The DGL context. * LhsHashTable and RhsHashTable.
* \param stream The stream to operate on. * \param ctx The DGL context.
*/ * \param stream The stream to operate on.
template<typename IdType> */
template <typename IdType>
class DeviceNodeMap { class DeviceNodeMap {
public: public:
using Mapping = typename OrderedHashTable<IdType>::Mapping; using Mapping = typename OrderedHashTable<IdType>::Mapping;
DeviceNodeMap( DeviceNodeMap(
const std::vector<int64_t>& num_nodes, const std::vector<int64_t>& num_nodes, const int64_t offset,
const int64_t offset, DGLContext ctx, cudaStream_t stream)
DGLContext ctx, : num_types_(num_nodes.size()),
cudaStream_t stream) : rhs_offset_(offset),
num_types_(num_nodes.size()), hash_tables_(),
rhs_offset_(offset), ctx_(ctx) {
hash_tables_(),
ctx_(ctx) {
auto device = runtime::DeviceAPI::Get(ctx); auto device = runtime::DeviceAPI::Get(ctx);
hash_tables_.reserve(num_types_); hash_tables_.reserve(num_types_);
for (int64_t i = 0; i < num_types_; ++i) { for (int64_t i = 0; i < num_types_; ++i) {
hash_tables_.emplace_back( hash_tables_.emplace_back(
new OrderedHashTable<IdType>( new OrderedHashTable<IdType>(num_nodes[i], ctx_, stream));
num_nodes[i],
ctx_,
stream));
} }
} }
OrderedHashTable<IdType>& LhsHashTable( OrderedHashTable<IdType>& LhsHashTable(const size_t index) {
const size_t index) {
return HashData(index); return HashData(index);
} }
OrderedHashTable<IdType>& RhsHashTable( OrderedHashTable<IdType>& RhsHashTable(const size_t index) {
const size_t index) { return HashData(index + rhs_offset_);
return HashData(index+rhs_offset_);
} }
const OrderedHashTable<IdType>& LhsHashTable( const OrderedHashTable<IdType>& LhsHashTable(const size_t index) const {
const size_t index) const {
return HashData(index); return HashData(index);
} }
const OrderedHashTable<IdType>& RhsHashTable( const OrderedHashTable<IdType>& RhsHashTable(const size_t index) const {
const size_t index) const { return HashData(index + rhs_offset_);
return HashData(index+rhs_offset_);
} }
IdType LhsHashSize( IdType LhsHashSize(const size_t index) const { return HashSize(index); }
const size_t index) const {
return HashSize(index);
}
IdType RhsHashSize( IdType RhsHashSize(const size_t index) const {
const size_t index) const { return HashSize(rhs_offset_ + index);
return HashSize(rhs_offset_+index);
} }
size_t Size() const { size_t Size() const { return hash_tables_.size(); }
return hash_tables_.size();
}
private: private:
int64_t num_types_; int64_t num_types_;
...@@ -176,45 +155,35 @@ class DeviceNodeMap { ...@@ -176,45 +155,35 @@ class DeviceNodeMap {
std::vector<std::unique_ptr<OrderedHashTable<IdType>>> hash_tables_; std::vector<std::unique_ptr<OrderedHashTable<IdType>>> hash_tables_;
DGLContext ctx_; DGLContext ctx_;
inline OrderedHashTable<IdType>& HashData( inline OrderedHashTable<IdType>& HashData(const size_t index) {
const size_t index) {
CHECK_LT(index, hash_tables_.size()); CHECK_LT(index, hash_tables_.size());
return *hash_tables_[index]; return *hash_tables_[index];
} }
inline const OrderedHashTable<IdType>& HashData( inline const OrderedHashTable<IdType>& HashData(const size_t index) const {
const size_t index) const {
CHECK_LT(index, hash_tables_.size()); CHECK_LT(index, hash_tables_.size());
return *hash_tables_[index]; return *hash_tables_[index];
} }
inline IdType HashSize( inline IdType HashSize(const size_t index) const {
const size_t index) const {
return HashData(index).size(); return HashData(index).size();
} }
}; };
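As the constructor comment notes, the offset argument decides whether the left- and right-hand tables alias each other. A short sketch of the two configurations (num_nodes, ctx, and stream are assumed to be prepared by the caller):

// Separate tables: num_nodes holds the lhs sizes followed by the rhs sizes.
DeviceNodeMap<int64_t> split_maps(num_nodes, num_nodes.size() / 2, ctx, stream);
// Shared tables: LhsHashTable(i) and RhsHashTable(i) refer to the same table.
DeviceNodeMap<int64_t> shared_maps(num_nodes, /*offset=*/0, ctx, stream);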
template<typename IdType> template <typename IdType>
inline size_t RoundUpDiv( inline size_t RoundUpDiv(const IdType num, const size_t divisor) {
const IdType num, return static_cast<IdType>(num / divisor) + (num % divisor == 0 ? 0 : 1);
const size_t divisor) {
return static_cast<IdType>(num/divisor) + (num % divisor == 0 ? 0 : 1);
} }
template<typename IdType> template <typename IdType>
inline IdType RoundUp( inline IdType RoundUp(const IdType num, const size_t unit) {
const IdType num, return RoundUpDiv(num, unit) * unit;
const size_t unit) {
return RoundUpDiv(num, unit)*unit;
} }
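Both helpers are plain ceiling arithmetic, e.g.:

size_t tiles = RoundUpDiv<int64_t>(1000, 128);   // ceil(1000 / 128) == 8
int64_t padded = RoundUp<int64_t>(1000, 128);    // 8 * 128 == 1024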
template<typename IdType> template <typename IdType>
std::tuple<std::vector<IdArray>, std::vector<IdArray>> std::tuple<std::vector<IdArray>, std::vector<IdArray>> MapEdges(
MapEdges( HeteroGraphPtr graph, const std::vector<EdgeArray>& edge_sets,
HeteroGraphPtr graph, const DeviceNodeMap<IdType>& node_map, cudaStream_t stream) {
const std::vector<EdgeArray>& edge_sets,
const DeviceNodeMap<IdType>& node_map,
cudaStream_t stream) {
constexpr const int BLOCK_SIZE = 128; constexpr const int BLOCK_SIZE = 128;
constexpr const size_t TILE_SIZE = 1024; constexpr const size_t TILE_SIZE = 1024;
...@@ -233,8 +202,8 @@ MapEdges( ...@@ -233,8 +202,8 @@ MapEdges(
if (edges.id.defined() && edges.src->shape[0] > 0) { if (edges.id.defined() && edges.src->shape[0] > 0) {
const int64_t num_edges = edges.src->shape[0]; const int64_t num_edges = edges.src->shape[0];
new_lhs.emplace_back(NewIdArray(num_edges, ctx, sizeof(IdType)*8)); new_lhs.emplace_back(NewIdArray(num_edges, ctx, sizeof(IdType) * 8));
new_rhs.emplace_back(NewIdArray(num_edges, ctx, sizeof(IdType)*8)); new_rhs.emplace_back(NewIdArray(num_edges, ctx, sizeof(IdType) * 8));
const auto src_dst_types = graph->GetEndpointTypes(etype); const auto src_dst_types = graph->GetEndpointTypes(etype);
const int src_type = src_dst_types.first; const int src_type = src_dst_types.first;
...@@ -244,20 +213,17 @@ MapEdges( ...@@ -244,20 +213,17 @@ MapEdges(
const dim3 block(BLOCK_SIZE); const dim3 block(BLOCK_SIZE);
// map the srcs // map the srcs
CUDA_KERNEL_CALL((map_edge_ids<IdType, BLOCK_SIZE, TILE_SIZE>), CUDA_KERNEL_CALL(
grid, block, 0, stream, (map_edge_ids<IdType, BLOCK_SIZE, TILE_SIZE>), grid, block, 0, stream,
edges.src.Ptr<IdType>(), edges.src.Ptr<IdType>(), new_lhs.back().Ptr<IdType>(),
new_lhs.back().Ptr<IdType>(), edges.dst.Ptr<IdType>(), new_rhs.back().Ptr<IdType>(), num_edges,
edges.dst.Ptr<IdType>(), node_map.LhsHashTable(src_type).DeviceHandle(),
new_rhs.back().Ptr<IdType>(), node_map.RhsHashTable(dst_type).DeviceHandle());
num_edges,
node_map.LhsHashTable(src_type).DeviceHandle(),
node_map.RhsHashTable(dst_type).DeviceHandle());
} else { } else {
new_lhs.emplace_back( new_lhs.emplace_back(
aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType)*8, 1}, ctx)); aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType) * 8, 1}, ctx));
new_rhs.emplace_back( new_rhs.emplace_back(
aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType)*8, 1}, ctx)); aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType) * 8, 1}, ctx));
} }
} }
...@@ -265,7 +231,6 @@ MapEdges( ...@@ -265,7 +231,6 @@ MapEdges(
std::move(new_lhs), std::move(new_rhs)); std::move(new_lhs), std::move(new_rhs));
} }
} // namespace cuda } // namespace cuda
} // namespace transform } // namespace transform
} // namespace dgl } // namespace dgl
......
...@@ -18,13 +18,13 @@ ...@@ -18,13 +18,13 @@
* ids. * ids.
*/ */
#include <dgl/runtime/device_api.h>
#include <dgl/immutable_graph.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <utility> #include <dgl/immutable_graph.h>
#include <dgl/runtime/device_api.h>
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <utility>
#include "../../../runtime/cuda/cuda_common.h" #include "../../../runtime/cuda/cuda_common.h"
#include "../../heterograph.h" #include "../../heterograph.h"
...@@ -40,42 +40,36 @@ namespace transform { ...@@ -40,42 +40,36 @@ namespace transform {
namespace { namespace {
template<typename IdType> template <typename IdType>
class DeviceNodeMapMaker { class DeviceNodeMapMaker {
public: public:
explicit DeviceNodeMapMaker( explicit DeviceNodeMapMaker(const std::vector<int64_t>& maxNodesPerType)
const std::vector<int64_t>& maxNodesPerType) : : max_num_nodes_(0) {
max_num_nodes_(0) { max_num_nodes_ =
max_num_nodes_ = *std::max_element(maxNodesPerType.begin(), *std::max_element(maxNodesPerType.begin(), maxNodesPerType.end());
maxNodesPerType.end());
} }
/** /**
* \brief This function builds node maps for each node type, preserving the * \brief This function builds node maps for each node type, preserving the
* order of the input nodes. Here it is assumed the lhs_nodes are not unique, * order of the input nodes. Here it is assumed the lhs_nodes are not unique,
* and thus a unique list is generated. * and thus a unique list is generated.
* *
* \param lhs_nodes The set of source input nodes. * \param lhs_nodes The set of source input nodes.
* \param rhs_nodes The set of destination input nodes. * \param rhs_nodes The set of destination input nodes.
* \param node_maps The node maps to be constructed. * \param node_maps The node maps to be constructed.
* \param count_lhs_device The number of unique source nodes (on the GPU). * \param count_lhs_device The number of unique source nodes (on the GPU).
* \param lhs_device The unique source nodes (on the GPU). * \param lhs_device The unique source nodes (on the GPU).
* \param stream The stream to operate on. * \param stream The stream to operate on.
*/ */
void Make( void Make(
const std::vector<IdArray>& lhs_nodes, const std::vector<IdArray>& lhs_nodes,
const std::vector<IdArray>& rhs_nodes, const std::vector<IdArray>& rhs_nodes,
DeviceNodeMap<IdType> * const node_maps, DeviceNodeMap<IdType>* const node_maps, int64_t* const count_lhs_device,
int64_t * const count_lhs_device, std::vector<IdArray>* const lhs_device, cudaStream_t stream) {
std::vector<IdArray>* const lhs_device,
cudaStream_t stream) {
const int64_t num_ntypes = lhs_nodes.size() + rhs_nodes.size(); const int64_t num_ntypes = lhs_nodes.size() + rhs_nodes.size();
CUDA_CALL(cudaMemsetAsync( CUDA_CALL(cudaMemsetAsync(
count_lhs_device, count_lhs_device, 0, num_ntypes * sizeof(*count_lhs_device), stream));
0,
num_ntypes*sizeof(*count_lhs_device),
stream));
// possibly duplicate lhs nodes // possibly duplicate lhs nodes
const int64_t lhs_num_ntypes = static_cast<int64_t>(lhs_nodes.size()); const int64_t lhs_num_ntypes = static_cast<int64_t>(lhs_nodes.size());
...@@ -84,10 +78,8 @@ class DeviceNodeMapMaker { ...@@ -84,10 +78,8 @@ class DeviceNodeMapMaker {
if (nodes->shape[0] > 0) { if (nodes->shape[0] > 0) {
CHECK_EQ(nodes->ctx.device_type, kDGLCUDA); CHECK_EQ(nodes->ctx.device_type, kDGLCUDA);
node_maps->LhsHashTable(ntype).FillWithDuplicates( node_maps->LhsHashTable(ntype).FillWithDuplicates(
nodes.Ptr<IdType>(), nodes.Ptr<IdType>(), nodes->shape[0],
nodes->shape[0], (*lhs_device)[ntype].Ptr<IdType>(), count_lhs_device + ntype,
(*lhs_device)[ntype].Ptr<IdType>(),
count_lhs_device+ntype,
stream); stream);
} }
} }
...@@ -98,28 +90,25 @@ class DeviceNodeMapMaker { ...@@ -98,28 +90,25 @@ class DeviceNodeMapMaker {
const IdArray& nodes = rhs_nodes[ntype]; const IdArray& nodes = rhs_nodes[ntype];
if (nodes->shape[0] > 0) { if (nodes->shape[0] > 0) {
node_maps->RhsHashTable(ntype).FillWithUnique( node_maps->RhsHashTable(ntype).FillWithUnique(
nodes.Ptr<IdType>(), nodes.Ptr<IdType>(), nodes->shape[0], stream);
nodes->shape[0],
stream);
} }
} }
} }
/** /**
* \brief This function builds node maps for each node type, preserving the * \brief This function builds node maps for each node type, preserving the
* order of the input nodes. Here it is assumed both lhs_nodes and rhs_nodes * order of the input nodes. Here it is assumed both lhs_nodes and rhs_nodes
* are unique. * are unique.
* *
* \param lhs_nodes The set of source input nodes. * \param lhs_nodes The set of source input nodes.
* \param rhs_nodes The set of destination input nodes. * \param rhs_nodes The set of destination input nodes.
* \param node_maps The node maps to be constructed. * \param node_maps The node maps to be constructed.
* \param stream The stream to operate on. * \param stream The stream to operate on.
*/ */
void Make( void Make(
const std::vector<IdArray>& lhs_nodes, const std::vector<IdArray>& lhs_nodes,
const std::vector<IdArray>& rhs_nodes, const std::vector<IdArray>& rhs_nodes,
DeviceNodeMap<IdType> * const node_maps, DeviceNodeMap<IdType>* const node_maps, cudaStream_t stream) {
cudaStream_t stream) {
const int64_t num_ntypes = lhs_nodes.size() + rhs_nodes.size(); const int64_t num_ntypes = lhs_nodes.size() + rhs_nodes.size();
// unique lhs nodes // unique lhs nodes
...@@ -129,9 +118,7 @@ class DeviceNodeMapMaker { ...@@ -129,9 +118,7 @@ class DeviceNodeMapMaker {
if (nodes->shape[0] > 0) { if (nodes->shape[0] > 0) {
CHECK_EQ(nodes->ctx.device_type, kDGLCUDA); CHECK_EQ(nodes->ctx.device_type, kDGLCUDA);
node_maps->LhsHashTable(ntype).FillWithUnique( node_maps->LhsHashTable(ntype).FillWithUnique(
nodes.Ptr<IdType>(), nodes.Ptr<IdType>(), nodes->shape[0], stream);
nodes->shape[0],
stream);
} }
} }
...@@ -141,9 +128,7 @@ class DeviceNodeMapMaker { ...@@ -141,9 +128,7 @@ class DeviceNodeMapMaker {
const IdArray& nodes = rhs_nodes[ntype]; const IdArray& nodes = rhs_nodes[ntype];
if (nodes->shape[0] > 0) { if (nodes->shape[0] > 0) {
node_maps->RhsHashTable(ntype).FillWithUnique( node_maps->RhsHashTable(ntype).FillWithUnique(
nodes.Ptr<IdType>(), nodes.Ptr<IdType>(), nodes->shape[0], stream);
nodes->shape[0],
stream);
} }
} }
} }
...@@ -152,20 +137,15 @@ class DeviceNodeMapMaker { ...@@ -152,20 +137,15 @@ class DeviceNodeMapMaker {
IdType max_num_nodes_; IdType max_num_nodes_;
}; };
// Since partial specialization is not allowed for functions, use this as an // Since partial specialization is not allowed for functions, use this as an
// intermediate for ToBlock where XPU = kDGLCUDA. // intermediate for ToBlock where XPU = kDGLCUDA.
template<typename IdType> template <typename IdType>
std::tuple<HeteroGraphPtr, std::vector<IdArray>> std::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlockGPU(
ToBlockGPU( HeteroGraphPtr graph, const std::vector<IdArray>& rhs_nodes,
HeteroGraphPtr graph, const bool include_rhs_in_lhs, std::vector<IdArray>* const lhs_nodes_ptr) {
const std::vector<IdArray> &rhs_nodes,
const bool include_rhs_in_lhs,
std::vector<IdArray>* const lhs_nodes_ptr) {
std::vector<IdArray>& lhs_nodes = *lhs_nodes_ptr; std::vector<IdArray>& lhs_nodes = *lhs_nodes_ptr;
const bool generate_lhs_nodes = lhs_nodes.empty(); const bool generate_lhs_nodes = lhs_nodes.empty();
const auto& ctx = graph->Context(); const auto& ctx = graph->Context();
auto device = runtime::DeviceAPI::Get(ctx); auto device = runtime::DeviceAPI::Get(ctx);
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
...@@ -176,16 +156,17 @@ ToBlockGPU( ...@@ -176,16 +156,17 @@ ToBlockGPU(
} }
// Since DST nodes are included in SRC nodes, a common requirement is to fetch // Since DST nodes are included in SRC nodes, a common requirement is to fetch
// the DST node features from the SRC node features. To avoid expensive sparse lookup, // the DST node features from the SRC node features. To avoid expensive
// the function ensures that the DST nodes in both SRC and DST sets have the same ids. // sparse lookup, the function ensures that the DST nodes in both SRC and DST
// As a result, given the node feature tensor ``X`` of type ``utype``, // sets have the same ids. As a result, given the node feature tensor ``X`` of
// the following code finds the corresponding DST node features of type ``vtype``: // type ``utype``, the following code finds the corresponding DST node
// features of type ``vtype``:
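// For example (assuming the DST nodes occupy the first rows of the SRC feature
// tensor, and writing num_dst_nodes(vtype) as a stand-in for the frontend call
// that returns the DST node count):
//
//     X[:num_dst_nodes(vtype)]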
const int64_t num_etypes = graph->NumEdgeTypes(); const int64_t num_etypes = graph->NumEdgeTypes();
const int64_t num_ntypes = graph->NumVertexTypes(); const int64_t num_ntypes = graph->NumVertexTypes();
CHECK(rhs_nodes.size() == static_cast<size_t>(num_ntypes)) CHECK(rhs_nodes.size() == static_cast<size_t>(num_ntypes))
<< "rhs_nodes not given for every node type"; << "rhs_nodes not given for every node type";
std::vector<EdgeArray> edge_arrays(num_etypes); std::vector<EdgeArray> edge_arrays(num_etypes);
for (int64_t etype = 0; etype < num_etypes; ++etype) { for (int64_t etype = 0; etype < num_etypes; ++etype) {
...@@ -197,9 +178,9 @@ ToBlockGPU( ...@@ -197,9 +178,9 @@ ToBlockGPU(
} }
// count lhs and rhs nodes // count lhs and rhs nodes
std::vector<int64_t> maxNodesPerType(num_ntypes*2, 0); std::vector<int64_t> maxNodesPerType(num_ntypes * 2, 0);
for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) { for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
maxNodesPerType[ntype+num_ntypes] += rhs_nodes[ntype]->shape[0]; maxNodesPerType[ntype + num_ntypes] += rhs_nodes[ntype]->shape[0];
if (generate_lhs_nodes) { if (generate_lhs_nodes) {
if (include_rhs_in_lhs) { if (include_rhs_in_lhs) {
...@@ -226,16 +207,16 @@ ToBlockGPU( ...@@ -226,16 +207,16 @@ ToBlockGPU(
if (generate_lhs_nodes) { if (generate_lhs_nodes) {
std::vector<int64_t> src_node_offsets(num_ntypes, 0); std::vector<int64_t> src_node_offsets(num_ntypes, 0);
for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) { for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
src_nodes[ntype] = NewIdArray(maxNodesPerType[ntype], ctx, src_nodes[ntype] =
sizeof(IdType)*8); NewIdArray(maxNodesPerType[ntype], ctx, sizeof(IdType) * 8);
if (include_rhs_in_lhs) { if (include_rhs_in_lhs) {
// place rhs nodes first // place rhs nodes first
device->CopyDataFromTo(rhs_nodes[ntype].Ptr<IdType>(), 0, device->CopyDataFromTo(
src_nodes[ntype].Ptr<IdType>(), src_node_offsets[ntype], rhs_nodes[ntype].Ptr<IdType>(), 0, src_nodes[ntype].Ptr<IdType>(),
sizeof(IdType)*rhs_nodes[ntype]->shape[0], src_node_offsets[ntype],
rhs_nodes[ntype]->ctx, src_nodes[ntype]->ctx, sizeof(IdType) * rhs_nodes[ntype]->shape[0], rhs_nodes[ntype]->ctx,
rhs_nodes[ntype]->dtype); src_nodes[ntype]->ctx, rhs_nodes[ntype]->dtype);
src_node_offsets[ntype] += sizeof(IdType)*rhs_nodes[ntype]->shape[0]; src_node_offsets[ntype] += sizeof(IdType) * rhs_nodes[ntype]->shape[0];
} }
} }
for (int64_t etype = 0; etype < num_etypes; ++etype) { for (int64_t etype = 0; etype < num_etypes; ++etype) {
...@@ -244,14 +225,13 @@ ToBlockGPU( ...@@ -244,14 +225,13 @@ ToBlockGPU(
if (edge_arrays[etype].src.defined()) { if (edge_arrays[etype].src.defined()) {
device->CopyDataFromTo( device->CopyDataFromTo(
edge_arrays[etype].src.Ptr<IdType>(), 0, edge_arrays[etype].src.Ptr<IdType>(), 0,
src_nodes[srctype].Ptr<IdType>(), src_nodes[srctype].Ptr<IdType>(), src_node_offsets[srctype],
src_node_offsets[srctype], sizeof(IdType) * edge_arrays[etype].src->shape[0],
sizeof(IdType)*edge_arrays[etype].src->shape[0], rhs_nodes[srctype]->ctx, src_nodes[srctype]->ctx,
rhs_nodes[srctype]->ctx,
src_nodes[srctype]->ctx,
rhs_nodes[srctype]->dtype); rhs_nodes[srctype]->dtype);
src_node_offsets[srctype] += sizeof(IdType)*edge_arrays[etype].src->shape[0]; src_node_offsets[srctype] +=
sizeof(IdType) * edge_arrays[etype].src->shape[0];
} }
} }
} else { } else {
...@@ -267,47 +247,35 @@ ToBlockGPU( ...@@ -267,47 +247,35 @@ ToBlockGPU(
if (generate_lhs_nodes) { if (generate_lhs_nodes) {
lhs_nodes.reserve(num_ntypes); lhs_nodes.reserve(num_ntypes);
for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) { for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
lhs_nodes.emplace_back(NewIdArray( lhs_nodes.emplace_back(
maxNodesPerType[ntype], ctx, sizeof(IdType)*8)); NewIdArray(maxNodesPerType[ntype], ctx, sizeof(IdType) * 8));
} }
} }
std::vector<int64_t> num_nodes_per_type(num_ntypes*2); std::vector<int64_t> num_nodes_per_type(num_ntypes * 2);
// populate RHS nodes from what we already know // populate RHS nodes from what we already know
for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) { for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
num_nodes_per_type[num_ntypes+ntype] = rhs_nodes[ntype]->shape[0]; num_nodes_per_type[num_ntypes + ntype] = rhs_nodes[ntype]->shape[0];
} }
// populate the mappings // populate the mappings
if (generate_lhs_nodes) { if (generate_lhs_nodes) {
int64_t * count_lhs_device = static_cast<int64_t*>( int64_t* count_lhs_device = static_cast<int64_t*>(
device->AllocWorkspace(ctx, sizeof(int64_t)*num_ntypes*2)); device->AllocWorkspace(ctx, sizeof(int64_t) * num_ntypes * 2));
maker.Make( maker.Make(
src_nodes, src_nodes, rhs_nodes, &node_maps, count_lhs_device, &lhs_nodes, stream);
rhs_nodes,
&node_maps,
count_lhs_device,
&lhs_nodes,
stream);
device->CopyDataFromTo( device->CopyDataFromTo(
count_lhs_device, 0, count_lhs_device, 0, num_nodes_per_type.data(), 0,
num_nodes_per_type.data(), 0, sizeof(*num_nodes_per_type.data()) * num_ntypes, ctx,
sizeof(*num_nodes_per_type.data())*num_ntypes, DGLContext{kDGLCPU, 0}, DGLDataType{kDGLInt, 64, 1});
ctx,
DGLContext{kDGLCPU, 0},
DGLDataType{kDGLInt, 64, 1});
device->StreamSync(ctx, stream); device->StreamSync(ctx, stream);
// wait for the node counts to finish transferring // wait for the node counts to finish transferring
device->FreeWorkspace(ctx, count_lhs_device); device->FreeWorkspace(ctx, count_lhs_device);
} else { } else {
maker.Make( maker.Make(lhs_nodes, rhs_nodes, &node_maps, stream);
lhs_nodes,
rhs_nodes,
&node_maps,
stream);
for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) { for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
num_nodes_per_type[ntype] = lhs_nodes[ntype]->shape[0]; num_nodes_per_type[ntype] = lhs_nodes[ntype]->shape[0];
...@@ -321,7 +289,7 @@ ToBlockGPU( ...@@ -321,7 +289,7 @@ ToBlockGPU(
induced_edges.push_back(edge_arrays[etype].id); induced_edges.push_back(edge_arrays[etype].id);
} else { } else {
induced_edges.push_back( induced_edges.push_back(
aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType)*8, 1}, ctx)); aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType) * 8, 1}, ctx));
} }
} }
...@@ -329,8 +297,8 @@ ToBlockGPU( ...@@ -329,8 +297,8 @@ ToBlockGPU(
const auto meta_graph = graph->meta_graph(); const auto meta_graph = graph->meta_graph();
const EdgeArray etypes = meta_graph->Edges("eid"); const EdgeArray etypes = meta_graph->Edges("eid");
const IdArray new_dst = Add(etypes.dst, num_ntypes); const IdArray new_dst = Add(etypes.dst, num_ntypes);
const auto new_meta_graph = ImmutableGraph::CreateFromCOO( const auto new_meta_graph =
num_ntypes * 2, etypes.src, new_dst); ImmutableGraph::CreateFromCOO(num_ntypes * 2, etypes.src, new_dst);
// allocate vector for graph relations while GPU is busy // allocate vector for graph relations while GPU is busy
std::vector<HeteroGraphPtr> rel_graphs; std::vector<HeteroGraphPtr> rel_graphs;
...@@ -358,20 +326,17 @@ ToBlockGPU( ...@@ -358,20 +326,17 @@ ToBlockGPU(
// No rhs nodes are given for this edge type. Create an empty graph. // No rhs nodes are given for this edge type. Create an empty graph.
rel_graphs.push_back(CreateFromCOO( rel_graphs.push_back(CreateFromCOO(
2, lhs_nodes[srctype]->shape[0], rhs_nodes[dsttype]->shape[0], 2, lhs_nodes[srctype]->shape[0], rhs_nodes[dsttype]->shape[0],
aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType)*8, 1}, ctx), aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType) * 8, 1}, ctx),
aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType)*8, 1}, ctx))); aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType) * 8, 1}, ctx)));
} else { } else {
rel_graphs.push_back(CreateFromCOO( rel_graphs.push_back(CreateFromCOO(
2, 2, lhs_nodes[srctype]->shape[0], rhs_nodes[dsttype]->shape[0],
lhs_nodes[srctype]->shape[0], new_lhs[etype], new_rhs[etype]));
rhs_nodes[dsttype]->shape[0],
new_lhs[etype],
new_rhs[etype]));
} }
} }
HeteroGraphPtr new_graph = CreateHeteroGraph( HeteroGraphPtr new_graph =
new_meta_graph, rel_graphs, num_nodes_per_type); CreateHeteroGraph(new_meta_graph, rel_graphs, num_nodes_per_type);
// return the new graph, the new src nodes, and new edges // return the new graph, the new src nodes, and new edges
return std::make_tuple(new_graph, induced_edges); return std::make_tuple(new_graph, induced_edges);
...@@ -379,26 +344,22 @@ ToBlockGPU( ...@@ -379,26 +344,22 @@ ToBlockGPU(
} // namespace } // namespace
// Use explicit names to get around MSVC's broken mangling that thinks the following two // Use explicit names to get around MSVC's broken mangling that thinks the
// functions are the same. // following two functions are the same. Using template<> fails to export the
// Using template<> fails to export the symbols. // symbols.
std::tuple<HeteroGraphPtr, std::vector<IdArray>> std::tuple<HeteroGraphPtr, std::vector<IdArray>>
// ToBlock<kDGLCUDA, int32_t> // ToBlock<kDGLCUDA, int32_t>
ToBlockGPU32( ToBlockGPU32(
HeteroGraphPtr graph, HeteroGraphPtr graph, const std::vector<IdArray>& rhs_nodes,
const std::vector<IdArray> &rhs_nodes, bool include_rhs_in_lhs, std::vector<IdArray>* const lhs_nodes) {
bool include_rhs_in_lhs,
std::vector<IdArray>* const lhs_nodes) {
return ToBlockGPU<int32_t>(graph, rhs_nodes, include_rhs_in_lhs, lhs_nodes); return ToBlockGPU<int32_t>(graph, rhs_nodes, include_rhs_in_lhs, lhs_nodes);
} }
std::tuple<HeteroGraphPtr, std::vector<IdArray>> std::tuple<HeteroGraphPtr, std::vector<IdArray>>
// ToBlock<kDGLCUDA, int64_t> // ToBlock<kDGLCUDA, int64_t>
ToBlockGPU64( ToBlockGPU64(
HeteroGraphPtr graph, HeteroGraphPtr graph, const std::vector<IdArray>& rhs_nodes,
const std::vector<IdArray> &rhs_nodes, bool include_rhs_in_lhs, std::vector<IdArray>* const lhs_nodes) {
bool include_rhs_in_lhs,
std::vector<IdArray>* const lhs_nodes) {
return ToBlockGPU<int64_t>(graph, rhs_nodes, include_rhs_in_lhs, lhs_nodes); return ToBlockGPU<int64_t>(graph, rhs_nodes, include_rhs_in_lhs, lhs_nodes);
} }
......
...@@ -4,17 +4,19 @@ ...@@ -4,17 +4,19 @@
* \brief k-nearest-neighbor (KNN) implementation (cuda) * \brief k-nearest-neighbor (KNN) implementation (cuda)
*/ */
#include <curand_kernel.h>
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/random.h> #include <dgl/random.h>
#include <dgl/runtime/device_api.h> #include <dgl/runtime/device_api.h>
#include <curand_kernel.h>
#include <algorithm> #include <algorithm>
#include <limits>
#include <string> #include <string>
#include <vector> #include <vector>
#include <limits>
#include "../../../array/cuda/dgl_cub.cuh" #include "../../../array/cuda/dgl_cub.cuh"
#include "../../../runtime/cuda/cuda_common.h"
#include "../../../array/cuda/utils.h" #include "../../../array/cuda/utils.h"
#include "../../../runtime/cuda/cuda_common.h"
#include "../knn.h" #include "../knn.h"
namespace dgl { namespace dgl {
...@@ -26,12 +28,12 @@ namespace impl { ...@@ -26,12 +28,12 @@ namespace impl {
*/ */
template <typename Type> template <typename Type>
struct SharedMemory { struct SharedMemory {
__device__ inline operator Type* () { __device__ inline operator Type*() {
extern __shared__ int __smem[]; extern __shared__ int __smem[];
return reinterpret_cast<Type*>(__smem); return reinterpret_cast<Type*>(__smem);
} }
__device__ inline operator const Type* () const { __device__ inline operator const Type*() const {
extern __shared__ int __smem[]; extern __shared__ int __smem[];
return reinterpret_cast<Type*>(__smem); return reinterpret_cast<Type*>(__smem);
} }
...@@ -41,12 +43,12 @@ struct SharedMemory { ...@@ -41,12 +43,12 @@ struct SharedMemory {
// access compile errors // access compile errors
template <> template <>
struct SharedMemory<double> { struct SharedMemory<double> {
__device__ inline operator double* () { __device__ inline operator double*() {
extern __shared__ double __smem_d[]; extern __shared__ double __smem_d[];
return reinterpret_cast<double*>(__smem_d); return reinterpret_cast<double*>(__smem_d);
} }
__device__ inline operator const double* () const { __device__ inline operator const double*() const {
extern __shared__ double __smem_d[]; extern __shared__ double __smem_d[];
return reinterpret_cast<double*>(__smem_d); return reinterpret_cast<double*>(__smem_d);
} }
...@@ -54,9 +56,8 @@ struct SharedMemory<double> { ...@@ -54,9 +56,8 @@ struct SharedMemory<double> {
/*! \brief Compute Euclidean distance between two vectors in a cuda kernel */ /*! \brief Compute Euclidean distance between two vectors in a cuda kernel */
template <typename FloatType, typename IdType> template <typename FloatType, typename IdType>
__device__ FloatType EuclideanDist(const FloatType* vec1, __device__ FloatType
const FloatType* vec2, EuclideanDist(const FloatType* vec1, const FloatType* vec2, const int64_t dim) {
const int64_t dim) {
FloatType dist = 0; FloatType dist = 0;
IdType idx = 0; IdType idx = 0;
for (; idx < dim - 3; idx += 4) { for (; idx < dim - 3; idx += 4) {
...@@ -82,10 +83,9 @@ __device__ FloatType EuclideanDist(const FloatType* vec1, ...@@ -82,10 +83,9 @@ __device__ FloatType EuclideanDist(const FloatType* vec1,
* than the worst distance. * than the worst distance.
*/ */
template <typename FloatType, typename IdType> template <typename FloatType, typename IdType>
__device__ FloatType EuclideanDistWithCheck(const FloatType* vec1, __device__ FloatType EuclideanDistWithCheck(
const FloatType* vec2, const FloatType* vec1, const FloatType* vec2, const int64_t dim,
const int64_t dim, const FloatType worst_dist) {
const FloatType worst_dist) {
FloatType dist = 0; FloatType dist = 0;
IdType idx = 0; IdType idx = 0;
bool early_stop = false; bool early_stop = false;
...@@ -151,9 +151,9 @@ __device__ void BuildHeap(IdType* indices, FloatType* dists, int size) { ...@@ -151,9 +151,9 @@ __device__ void BuildHeap(IdType* indices, FloatType* dists, int size) {
} }
template <typename FloatType, typename IdType> template <typename FloatType, typename IdType>
__device__ void HeapInsert(IdType* indices, FloatType* dist, __device__ void HeapInsert(
IdType new_idx, FloatType new_dist, IdType* indices, FloatType* dist, IdType new_idx, FloatType new_dist,
int size, bool check_repeat = false) { int size, bool check_repeat = false) {
if (new_dist > dist[0]) return; if (new_dist > dist[0]) return;
// check if we have it // check if we have it
...@@ -192,9 +192,9 @@ __device__ void HeapInsert(IdType* indices, FloatType* dist, ...@@ -192,9 +192,9 @@ __device__ void HeapInsert(IdType* indices, FloatType* dist,
} }
template <typename FloatType, typename IdType> template <typename FloatType, typename IdType>
__device__ bool FlaggedHeapInsert(IdType* indices, FloatType* dist, bool* flags, __device__ bool FlaggedHeapInsert(
IdType new_idx, FloatType new_dist, bool new_flag, IdType* indices, FloatType* dist, bool* flags, IdType new_idx,
int size, bool check_repeat = false) { FloatType new_dist, bool new_flag, int size, bool check_repeat = false) {
if (new_dist > dist[0]) return false; if (new_dist > dist[0]) return false;
// check if we have it // check if we have it
...@@ -239,22 +239,26 @@ __device__ bool FlaggedHeapInsert(IdType* indices, FloatType* dist, bool* flags, ...@@ -239,22 +239,26 @@ __device__ bool FlaggedHeapInsert(IdType* indices, FloatType* dist, bool* flags,
} }
/*! /*!
* \brief Brute force kNN kernel. Compute distance for each pair of input points and get * \brief Brute force kNN kernel. Compute distance for each pair of input points
* the result directly (without a distance matrix). * and get the result directly (without a distance matrix).
*/ */
template <typename FloatType, typename IdType> template <typename FloatType, typename IdType>
__global__ void BruteforceKnnKernel(const FloatType* data_points, const IdType* data_offsets, __global__ void BruteforceKnnKernel(
const FloatType* query_points, const IdType* query_offsets, const FloatType* data_points, const IdType* data_offsets,
const int k, FloatType* dists, IdType* query_out, const FloatType* query_points, const IdType* query_offsets, const int k,
IdType* data_out, const int64_t num_batches, FloatType* dists, IdType* query_out, IdType* data_out,
const int64_t feature_size) { const int64_t num_batches, const int64_t feature_size) {
const IdType q_idx = blockIdx.x * blockDim.x + threadIdx.x; const IdType q_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (q_idx >= query_offsets[num_batches]) return; if (q_idx >= query_offsets[num_batches]) return;
IdType batch_idx = 0; IdType batch_idx = 0;
for (IdType b = 0; b < num_batches + 1; ++b) { for (IdType b = 0; b < num_batches + 1; ++b) {
if (query_offsets[b] > q_idx) { batch_idx = b - 1; break; } if (query_offsets[b] > q_idx) {
batch_idx = b - 1;
break;
}
} }
const IdType data_start = data_offsets[batch_idx], data_end = data_offsets[batch_idx + 1]; const IdType data_start = data_offsets[batch_idx],
data_end = data_offsets[batch_idx + 1];
for (IdType k_idx = 0; k_idx < k; ++k_idx) { for (IdType k_idx = 0; k_idx < k; ++k_idx) {
query_out[q_idx * k + k_idx] = q_idx; query_out[q_idx * k + k_idx] = q_idx;
...@@ -264,12 +268,12 @@ __global__ void BruteforceKnnKernel(const FloatType* data_points, const IdType* ...@@ -264,12 +268,12 @@ __global__ void BruteforceKnnKernel(const FloatType* data_points, const IdType*
for (IdType d_idx = data_start; d_idx < data_end; ++d_idx) { for (IdType d_idx = data_start; d_idx < data_end; ++d_idx) {
FloatType tmp_dist = EuclideanDistWithCheck<FloatType, IdType>( FloatType tmp_dist = EuclideanDistWithCheck<FloatType, IdType>(
query_points + q_idx * feature_size, query_points + q_idx * feature_size, data_points + d_idx * feature_size,
data_points + d_idx * feature_size, feature_size, worst_dist);
feature_size, worst_dist);
IdType out_offset = q_idx * k; IdType out_offset = q_idx * k;
HeapInsert<FloatType, IdType>(data_out + out_offset, dists + out_offset, d_idx, tmp_dist, k); HeapInsert<FloatType, IdType>(
data_out + out_offset, dists + out_offset, d_idx, tmp_dist, k);
worst_dist = dists[q_idx * k]; worst_dist = dists[q_idx * k];
} }
} }
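BruteforceKnnKernel assigns one thread per query point (q_idx = blockIdx.x * blockDim.x + threadIdx.x, bounded by query_offsets[num_batches]), so the host side only needs a one-dimensional launch that covers every query. A sketch under that assumption (the block size, the *_host suffix, and the launch form are illustrative, not the launch code of this file):

const int64_t num_queries = query_offsets_host[num_batches];
const int block_size = 256;  // any reasonable block size
const int num_blocks =
    static_cast<int>((num_queries + block_size - 1) / block_size);
BruteforceKnnKernel<FloatType, IdType><<<num_blocks, block_size, 0, stream>>>(
    data_points, data_offsets, query_points, query_offsets, k,
    dists, query_out, data_out, num_batches, feature_size);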
...@@ -281,22 +285,19 @@ __global__ void BruteforceKnnKernel(const FloatType* data_points, const IdType* ...@@ -281,22 +285,19 @@ __global__ void BruteforceKnnKernel(const FloatType* data_points, const IdType*
* This kernel is faster when the dimension of input points is not large. * This kernel is faster when the dimension of input points is not large.
*/ */
template <typename FloatType, typename IdType> template <typename FloatType, typename IdType>
__global__ void BruteforceKnnShareKernel(const FloatType* data_points, __global__ void BruteforceKnnShareKernel(
const IdType* data_offsets, const FloatType* data_points, const IdType* data_offsets,
const FloatType* query_points, const FloatType* query_points, const IdType* query_offsets,
const IdType* query_offsets, const IdType* block_batch_id, const IdType* local_block_id, const int k,
const IdType* block_batch_id, FloatType* dists, IdType* query_out, IdType* data_out,
const IdType* local_block_id, const int64_t num_batches, const int64_t feature_size) {
const int k, FloatType* dists,
IdType* query_out, IdType* data_out,
const int64_t num_batches,
const int64_t feature_size) {
const IdType block_idx = static_cast<IdType>(blockIdx.x); const IdType block_idx = static_cast<IdType>(blockIdx.x);
const IdType block_size = static_cast<IdType>(blockDim.x); const IdType block_size = static_cast<IdType>(blockDim.x);
const IdType batch_idx = block_batch_id[block_idx]; const IdType batch_idx = block_batch_id[block_idx];
const IdType local_bid = local_block_id[block_idx]; const IdType local_bid = local_block_id[block_idx];
const IdType query_start = query_offsets[batch_idx] + block_size * local_bid; const IdType query_start = query_offsets[batch_idx] + block_size * local_bid;
const IdType query_end = min(query_start + block_size, query_offsets[batch_idx + 1]); const IdType query_end =
min(query_start + block_size, query_offsets[batch_idx + 1]);
if (query_start >= query_end) return; if (query_start >= query_end) return;
const IdType query_idx = query_start + threadIdx.x; const IdType query_idx = query_start + threadIdx.x;
const IdType data_start = data_offsets[batch_idx]; const IdType data_start = data_offsets[batch_idx];
...@@ -318,17 +319,20 @@ __global__ void BruteforceKnnShareKernel(const FloatType* data_points, ...@@ -318,17 +319,20 @@ __global__ void BruteforceKnnShareKernel(const FloatType* data_points,
if (query_idx < query_end) { if (query_idx < query_end) {
for (auto i = 0; i < feature_size; ++i) { for (auto i = 0; i < feature_size; ++i) {
// to avoid bank conflict, we use transpose here // to avoid bank conflict, we use transpose here
query_buff[threadIdx.x + i * block_size] = query_points[query_idx * feature_size + i]; query_buff[threadIdx.x + i * block_size] =
query_points[query_idx * feature_size + i];
} }
} }
// perform computation on each tile // perform computation on each tile
for (auto tile_start = data_start; tile_start < data_end; tile_start += block_size) { for (auto tile_start = data_start; tile_start < data_end;
tile_start += block_size) {
// each thread load one data point into the shared memory // each thread load one data point into the shared memory
IdType load_idx = tile_start + threadIdx.x; IdType load_idx = tile_start + threadIdx.x;
if (load_idx < data_end) { if (load_idx < data_end) {
for (auto i = 0; i < feature_size; ++i) { for (auto i = 0; i < feature_size; ++i) {
data_buff[threadIdx.x * feature_size + i] = data_points[load_idx * feature_size + i]; data_buff[threadIdx.x * feature_size + i] =
data_points[load_idx * feature_size + i];
} }
} }
__syncthreads(); __syncthreads();
...@@ -342,16 +346,20 @@ __global__ void BruteforceKnnShareKernel(const FloatType* data_points, ...@@ -342,16 +346,20 @@ __global__ void BruteforceKnnShareKernel(const FloatType* data_points,
IdType dim_idx = 0; IdType dim_idx = 0;
for (; dim_idx < feature_size - 3; dim_idx += 4) { for (; dim_idx < feature_size - 3; dim_idx += 4) {
FloatType diff0 = query_buff[threadIdx.x + block_size * (dim_idx)] FloatType diff0 = query_buff[threadIdx.x + block_size * (dim_idx)] -
- data_buff[d_idx * feature_size + dim_idx]; data_buff[d_idx * feature_size + dim_idx];
FloatType diff1 = query_buff[threadIdx.x + block_size * (dim_idx + 1)] FloatType diff1 =
- data_buff[d_idx * feature_size + dim_idx + 1]; query_buff[threadIdx.x + block_size * (dim_idx + 1)] -
FloatType diff2 = query_buff[threadIdx.x + block_size * (dim_idx + 2)] data_buff[d_idx * feature_size + dim_idx + 1];
- data_buff[d_idx * feature_size + dim_idx + 2]; FloatType diff2 =
FloatType diff3 = query_buff[threadIdx.x + block_size * (dim_idx + 3)] query_buff[threadIdx.x + block_size * (dim_idx + 2)] -
- data_buff[d_idx * feature_size + dim_idx + 3]; data_buff[d_idx * feature_size + dim_idx + 2];
FloatType diff3 =
tmp_dist += diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3; query_buff[threadIdx.x + block_size * (dim_idx + 3)] -
data_buff[d_idx * feature_size + dim_idx + 3];
tmp_dist +=
diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3;
if (tmp_dist > worst_dist) { if (tmp_dist > worst_dist) {
early_stop = true; early_stop = true;
...@@ -361,8 +369,9 @@ __global__ void BruteforceKnnShareKernel(const FloatType* data_points, ...@@ -361,8 +369,9 @@ __global__ void BruteforceKnnShareKernel(const FloatType* data_points,
} }
for (; dim_idx < feature_size; ++dim_idx) { for (; dim_idx < feature_size; ++dim_idx) {
const FloatType diff = query_buff[threadIdx.x + dim_idx * block_size] const FloatType diff =
- data_buff[d_idx * feature_size + dim_idx]; query_buff[threadIdx.x + dim_idx * block_size] -
data_buff[d_idx * feature_size + dim_idx];
tmp_dist += diff * diff; tmp_dist += diff * diff;
if (tmp_dist > worst_dist) { if (tmp_dist > worst_dist) {
...@@ -374,8 +383,8 @@ __global__ void BruteforceKnnShareKernel(const FloatType* data_points, ...@@ -374,8 +383,8 @@ __global__ void BruteforceKnnShareKernel(const FloatType* data_points,
if (early_stop) continue; if (early_stop) continue;
HeapInsert<FloatType, IdType>( HeapInsert<FloatType, IdType>(
res_buff + threadIdx.x * k, dist_buff + threadIdx.x * k, res_buff + threadIdx.x * k, dist_buff + threadIdx.x * k,
d_idx + tile_start, tmp_dist, k); d_idx + tile_start, tmp_dist, k);
worst_dist = dist_buff[threadIdx.x * k]; worst_dist = dist_buff[threadIdx.x * k];
} }
} }
...@@ -393,9 +402,9 @@ __global__ void BruteforceKnnShareKernel(const FloatType* data_points, ...@@ -393,9 +402,9 @@ __global__ void BruteforceKnnShareKernel(const FloatType* data_points,
/*! \brief determine the number of blocks for each segment */ /*! \brief determine the number of blocks for each segment */
template <typename IdType> template <typename IdType>
__global__ void GetNumBlockPerSegment(const IdType* offsets, IdType* out, __global__ void GetNumBlockPerSegment(
const int64_t batch_size, const IdType* offsets, IdType* out, const int64_t batch_size,
const int64_t block_size) { const int64_t block_size) {
const IdType idx = blockIdx.x * blockDim.x + threadIdx.x; const IdType idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < batch_size) { if (idx < batch_size) {
out[idx] = (offsets[idx + 1] - offsets[idx] - 1) / block_size + 1; out[idx] = (offsets[idx + 1] - offsets[idx] - 1) / block_size + 1;
...@@ -404,9 +413,9 @@ __global__ void GetNumBlockPerSegment(const IdType* offsets, IdType* out, ...@@ -404,9 +413,9 @@ __global__ void GetNumBlockPerSegment(const IdType* offsets, IdType* out,
/*! \brief Get the batch index and local index in segment for each block */ /*! \brief Get the batch index and local index in segment for each block */
template <typename IdType> template <typename IdType>
__global__ void GetBlockInfo(const IdType* num_block_prefixsum, __global__ void GetBlockInfo(
IdType* block_batch_id, IdType* local_block_id, const IdType* num_block_prefixsum, IdType* block_batch_id,
size_t batch_size, size_t num_blocks) { IdType* local_block_id, size_t batch_size, size_t num_blocks) {
const IdType idx = blockIdx.x * blockDim.x + threadIdx.x; const IdType idx = blockIdx.x * blockDim.x + threadIdx.x;
IdType i = 0; IdType i = 0;
...@@ -421,8 +430,8 @@ __global__ void GetBlockInfo(const IdType* num_block_prefixsum, ...@@ -421,8 +430,8 @@ __global__ void GetBlockInfo(const IdType* num_block_prefixsum,
} }
/*! /*!
* \brief Brute force kNN. Compute distance for each pair of input points and get * \brief Brute force kNN. Compute distance for each pair of input points and
* the result directly (without a distance matrix). * get the result directly (without a distance matrix).
* *
* \tparam FloatType The type of input points. * \tparam FloatType The type of input points.
* \tparam IdType The type of id. * \tparam IdType The type of id.
...@@ -434,9 +443,10 @@ __global__ void GetBlockInfo(const IdType* num_block_prefixsum, ...@@ -434,9 +443,10 @@ __global__ void GetBlockInfo(const IdType* num_block_prefixsum,
* \param result output array * \param result output array
*/ */
template <typename FloatType, typename IdType> template <typename FloatType, typename IdType>
void BruteForceKNNCuda(const NDArray& data_points, const IdArray& data_offsets, void BruteForceKNNCuda(
const NDArray& query_points, const IdArray& query_offsets, const NDArray& data_points, const IdArray& data_offsets,
const int k, IdArray result) { const NDArray& query_points, const IdArray& query_offsets, const int k,
IdArray result) {
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
const auto& ctx = data_points->ctx; const auto& ctx = data_points->ctx;
auto device = runtime::DeviceAPI::Get(ctx); auto device = runtime::DeviceAPI::Get(ctx);
...@@ -450,13 +460,14 @@ void BruteForceKNNCuda(const NDArray& data_points, const IdArray& data_offsets, ...@@ -450,13 +460,14 @@ void BruteForceKNNCuda(const NDArray& data_points, const IdArray& data_offsets,
IdType* data_out = query_out + k * query_points->shape[0]; IdType* data_out = query_out + k * query_points->shape[0];
FloatType* dists = static_cast<FloatType*>(device->AllocWorkspace( FloatType* dists = static_cast<FloatType*>(device->AllocWorkspace(
ctx, k * query_points->shape[0] * sizeof(FloatType))); ctx, k * query_points->shape[0] * sizeof(FloatType)));
const int64_t block_size = cuda::FindNumThreads(query_points->shape[0]); const int64_t block_size = cuda::FindNumThreads(query_points->shape[0]);
const int64_t num_blocks = (query_points->shape[0] - 1) / block_size + 1; const int64_t num_blocks = (query_points->shape[0] - 1) / block_size + 1;
CUDA_KERNEL_CALL(BruteforceKnnKernel, num_blocks, block_size, 0, stream, CUDA_KERNEL_CALL(
data_points_data, data_offsets_data, query_points_data, query_offsets_data, BruteforceKnnKernel, num_blocks, block_size, 0, stream, data_points_data,
k, dists, query_out, data_out, batch_size, feature_size); data_offsets_data, query_points_data, query_offsets_data, k, dists,
query_out, data_out, batch_size, feature_size);
device->FreeWorkspace(ctx, dists); device->FreeWorkspace(ctx, dists);
} }
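// A CPU reference sketch of what BruteForceKNNCuda is assumed to compute,
// using the same segmented layout: data_offsets/query_offsets mark where each
// batch starts, and every query point receives the indices of its k closest
// data points from the same batch. std::partial_sort stands in for the GPU
// heap; the function name and flat result layout are illustrative only.
#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

std::vector<int64_t> knn_reference(
    const std::vector<float>& data, const std::vector<int64_t>& data_offsets,
    const std::vector<float>& query, const std::vector<int64_t>& query_offsets,
    int64_t k, int64_t dim) {
  std::vector<int64_t> out;  // k neighbor ids per query point, concatenated
  const int64_t num_batches = static_cast<int64_t>(data_offsets.size()) - 1;
  for (int64_t b = 0; b < num_batches; ++b) {
    for (int64_t q = query_offsets[b]; q < query_offsets[b + 1]; ++q) {
      std::vector<std::pair<float, int64_t>> cand;  // (squared distance, id)
      for (int64_t d = data_offsets[b]; d < data_offsets[b + 1]; ++d) {
        float dist = 0.f;
        for (int64_t f = 0; f < dim; ++f) {
          const float diff = query[q * dim + f] - data[d * dim + f];
          dist += diff * diff;
        }
        cand.emplace_back(dist, d);
      }
      const int64_t kk =
          std::min<int64_t>(k, static_cast<int64_t>(cand.size()));
      std::partial_sort(cand.begin(), cand.begin() + kk, cand.end());
      for (int64_t i = 0; i < kk; ++i) out.push_back(cand[i].second);
    }
  }
  return out;
}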
...@@ -477,9 +488,10 @@ void BruteForceKNNCuda(const NDArray& data_points, const IdArray& data_offsets, ...@@ -477,9 +488,10 @@ void BruteForceKNNCuda(const NDArray& data_points, const IdArray& data_offsets,
* \param result output array * \param result output array
*/ */
template <typename FloatType, typename IdType> template <typename FloatType, typename IdType>
void BruteForceKNNSharedCuda(const NDArray& data_points, const IdArray& data_offsets, void BruteForceKNNSharedCuda(
const NDArray& query_points, const IdArray& query_offsets, const NDArray& data_points, const IdArray& data_offsets,
const int k, IdArray result) { const NDArray& query_points, const IdArray& query_offsets, const int k,
IdArray result) {
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
const auto& ctx = data_points->ctx; const auto& ctx = data_points->ctx;
auto device = runtime::DeviceAPI::Get(ctx); auto device = runtime::DeviceAPI::Get(ctx);
...@@ -496,44 +508,44 @@ void BruteForceKNNSharedCuda(const NDArray& data_points, const IdArray& data_off ...@@ -496,44 +508,44 @@ void BruteForceKNNSharedCuda(const NDArray& data_points, const IdArray& data_off
// determine block size according to this value // determine block size according to this value
int max_sharedmem_per_block = 0; int max_sharedmem_per_block = 0;
CUDA_CALL(cudaDeviceGetAttribute( CUDA_CALL(cudaDeviceGetAttribute(
&max_sharedmem_per_block, cudaDevAttrMaxSharedMemoryPerBlock, ctx.device_id)); &max_sharedmem_per_block, cudaDevAttrMaxSharedMemoryPerBlock,
const int64_t single_shared_mem = (k + 2 * feature_size) * sizeof(FloatType) + ctx.device_id));
k * sizeof(IdType); const int64_t single_shared_mem =
const int64_t block_size = cuda::FindNumThreads(max_sharedmem_per_block / single_shared_mem); (k + 2 * feature_size) * sizeof(FloatType) + k * sizeof(IdType);
const int64_t block_size =
cuda::FindNumThreads(max_sharedmem_per_block / single_shared_mem);
// Determine the number of blocks. We first get the number of blocks for each // Determine the number of blocks. We first get the number of blocks for each
// segment. Then we get the block id offset via prefix sum. // segment. Then we get the block id offset via prefix sum.
IdType* num_block_per_segment = static_cast<IdType*>( IdType* num_block_per_segment = static_cast<IdType*>(
device->AllocWorkspace(ctx, batch_size * sizeof(IdType))); device->AllocWorkspace(ctx, batch_size * sizeof(IdType)));
IdType* num_block_prefixsum = static_cast<IdType*>( IdType* num_block_prefixsum = static_cast<IdType*>(
device->AllocWorkspace(ctx, batch_size * sizeof(IdType))); device->AllocWorkspace(ctx, batch_size * sizeof(IdType)));
// block size for GetNumBlockPerSegment computation // block size for GetNumBlockPerSegment computation
int64_t temp_block_size = cuda::FindNumThreads(batch_size); int64_t temp_block_size = cuda::FindNumThreads(batch_size);
int64_t temp_num_blocks = (batch_size - 1) / temp_block_size + 1; int64_t temp_num_blocks = (batch_size - 1) / temp_block_size + 1;
CUDA_KERNEL_CALL(GetNumBlockPerSegment, temp_num_blocks, CUDA_KERNEL_CALL(
temp_block_size, 0, stream, GetNumBlockPerSegment, temp_num_blocks, temp_block_size, 0, stream,
query_offsets_data, num_block_per_segment, query_offsets_data, num_block_per_segment, batch_size, block_size);
batch_size, block_size);
size_t prefix_temp_size = 0; size_t prefix_temp_size = 0;
CUDA_CALL(cub::DeviceScan::ExclusiveSum( CUDA_CALL(cub::DeviceScan::ExclusiveSum(
nullptr, prefix_temp_size, num_block_per_segment, nullptr, prefix_temp_size, num_block_per_segment, num_block_prefixsum,
num_block_prefixsum, batch_size, stream)); batch_size, stream));
void* prefix_temp = device->AllocWorkspace(ctx, prefix_temp_size); void* prefix_temp = device->AllocWorkspace(ctx, prefix_temp_size);
CUDA_CALL(cub::DeviceScan::ExclusiveSum( CUDA_CALL(cub::DeviceScan::ExclusiveSum(
prefix_temp, prefix_temp_size, num_block_per_segment, prefix_temp, prefix_temp_size, num_block_per_segment, num_block_prefixsum,
num_block_prefixsum, batch_size, stream)); batch_size, stream));
device->FreeWorkspace(ctx, prefix_temp); device->FreeWorkspace(ctx, prefix_temp);
int64_t num_blocks = 0, final_elem = 0, copyoffset = (batch_size - 1) * sizeof(IdType); int64_t num_blocks = 0, final_elem = 0,
copyoffset = (batch_size - 1) * sizeof(IdType);
device->CopyDataFromTo( device->CopyDataFromTo(
num_block_prefixsum, copyoffset, &num_blocks, 0, num_block_prefixsum, copyoffset, &num_blocks, 0, sizeof(IdType), ctx,
sizeof(IdType), ctx, DGLContext{kDGLCPU, 0}, DGLContext{kDGLCPU, 0}, query_offsets->dtype);
query_offsets->dtype);
device->CopyDataFromTo( device->CopyDataFromTo(
num_block_per_segment, copyoffset, &final_elem, 0, num_block_per_segment, copyoffset, &final_elem, 0, sizeof(IdType), ctx,
sizeof(IdType), ctx, DGLContext{kDGLCPU, 0}, DGLContext{kDGLCPU, 0}, query_offsets->dtype);
query_offsets->dtype);
num_blocks += final_elem; num_blocks += final_elem;
device->FreeWorkspace(ctx, num_block_per_segment); device->FreeWorkspace(ctx, num_block_per_segment);
device->FreeWorkspace(ctx, num_block_prefixsum); device->FreeWorkspace(ctx, num_block_prefixsum);
...@@ -541,22 +553,22 @@ void BruteForceKNNSharedCuda(const NDArray& data_points, const IdArray& data_off ...@@ -541,22 +553,22 @@ void BruteForceKNNSharedCuda(const NDArray& data_points, const IdArray& data_off
// get batch id and local id in segment // get batch id and local id in segment
temp_block_size = cuda::FindNumThreads(num_blocks); temp_block_size = cuda::FindNumThreads(num_blocks);
temp_num_blocks = (num_blocks - 1) / temp_block_size + 1; temp_num_blocks = (num_blocks - 1) / temp_block_size + 1;
IdType* block_batch_id = static_cast<IdType*>(device->AllocWorkspace( IdType* block_batch_id = static_cast<IdType*>(
ctx, num_blocks * sizeof(IdType))); device->AllocWorkspace(ctx, num_blocks * sizeof(IdType)));
IdType* local_block_id = static_cast<IdType*>(device->AllocWorkspace( IdType* local_block_id = static_cast<IdType*>(
ctx, num_blocks * sizeof(IdType))); device->AllocWorkspace(ctx, num_blocks * sizeof(IdType)));
CUDA_KERNEL_CALL( CUDA_KERNEL_CALL(
GetBlockInfo, temp_num_blocks, temp_block_size, 0, GetBlockInfo, temp_num_blocks, temp_block_size, 0, stream,
stream, num_block_prefixsum, block_batch_id, num_block_prefixsum, block_batch_id, local_block_id, batch_size,
local_block_id, batch_size, num_blocks); num_blocks);
FloatType* dists = static_cast<FloatType*>(device->AllocWorkspace( FloatType* dists = static_cast<FloatType*>(device->AllocWorkspace(
ctx, k * query_points->shape[0] * sizeof(FloatType))); ctx, k * query_points->shape[0] * sizeof(FloatType)));
CUDA_KERNEL_CALL(BruteforceKnnShareKernel, num_blocks, block_size, CUDA_KERNEL_CALL(
single_shared_mem * block_size, stream, data_points_data, BruteforceKnnShareKernel, num_blocks, block_size,
data_offsets_data, query_points_data, query_offsets_data, single_shared_mem * block_size, stream, data_points_data,
block_batch_id, local_block_id, k, dists, query_out, data_offsets_data, query_points_data, query_offsets_data, block_batch_id,
data_out, batch_size, feature_size); local_block_id, k, dists, query_out, data_out, batch_size, feature_size);
device->FreeWorkspace(ctx, dists); device->FreeWorkspace(ctx, dists);
device->FreeWorkspace(ctx, local_block_id); device->FreeWorkspace(ctx, local_block_id);
...@@ -564,9 +576,8 @@ void BruteForceKNNSharedCuda(const NDArray& data_points, const IdArray& data_off ...@@ -564,9 +576,8 @@ void BruteForceKNNSharedCuda(const NDArray& data_points, const IdArray& data_off
} }
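// A host-side sketch of the grid-planning step in BruteForceKNNSharedCuda,
// under the same assumptions: each block handles up to block_size query
// points of one segment, so a segment of length len needs
// ceil(len / block_size) blocks, and an exclusive prefix sum over these
// counts gives every block its id offset. std::exclusive_scan stands in for
// cub::DeviceScan::ExclusiveSum; the helper name is illustrative.
#include <cstdint>
#include <numeric>
#include <vector>

void plan_blocks(const std::vector<int64_t>& query_offsets, int64_t block_size,
                 std::vector<int64_t>* blocks_per_segment,
                 std::vector<int64_t>* block_prefixsum, int64_t* num_blocks) {
  const int64_t batch_size = static_cast<int64_t>(query_offsets.size()) - 1;
  if (batch_size <= 0) {
    *num_blocks = 0;
    return;
  }
  blocks_per_segment->resize(batch_size);
  for (int64_t b = 0; b < batch_size; ++b) {
    const int64_t len = query_offsets[b + 1] - query_offsets[b];
    (*blocks_per_segment)[b] = (len - 1) / block_size + 1;
  }
  block_prefixsum->resize(batch_size);
  std::exclusive_scan(blocks_per_segment->begin(), blocks_per_segment->end(),
                      block_prefixsum->begin(), int64_t{0});
  // total = block id offset of the last segment + its own block count
  *num_blocks = block_prefixsum->back() + blocks_per_segment->back();
}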
/*! \brief Setup rng state for nn-descent */ /*! \brief Setup rng state for nn-descent */
__global__ void SetupRngKernel(curandState* states, __global__ void SetupRngKernel(
const uint64_t seed, curandState* states, const uint64_t seed, const size_t n) {
const size_t n) {
size_t id = blockIdx.x * blockDim.x + threadIdx.x; size_t id = blockIdx.x * blockDim.x + threadIdx.x;
if (id < n) { if (id < n) {
curand_init(seed, id, 0, states + id); curand_init(seed, id, 0, states + id);
...@@ -578,16 +589,10 @@ __global__ void SetupRngKernel(curandState* states, ...@@ -578,16 +589,10 @@ __global__ void SetupRngKernel(curandState* states,
 * for each node * for each node
*/ */
template <typename FloatType, typename IdType> template <typename FloatType, typename IdType>
__global__ void RandomInitNeighborsKernel(const FloatType* points, __global__ void RandomInitNeighborsKernel(
const IdType* offsets, const FloatType* points, const IdType* offsets, IdType* central_nodes,
IdType* central_nodes, IdType* neighbors, FloatType* dists, bool* flags, const int k,
IdType* neighbors, const int64_t feature_size, const int64_t batch_size, const uint64_t seed) {
FloatType* dists,
bool* flags,
const int k,
const int64_t feature_size,
const int64_t batch_size,
const uint64_t seed) {
const IdType point_idx = blockIdx.x * blockDim.x + threadIdx.x; const IdType point_idx = blockIdx.x * blockDim.x + threadIdx.x;
IdType batch_idx = 0; IdType batch_idx = 0;
if (point_idx >= offsets[batch_size]) return; if (point_idx >= offsets[batch_size]) return;
...@@ -623,21 +628,23 @@ __global__ void RandomInitNeighborsKernel(const FloatType* points, ...@@ -623,21 +628,23 @@ __global__ void RandomInitNeighborsKernel(const FloatType* points,
for (IdType i = 0; i < k; ++i) { for (IdType i = 0; i < k; ++i) {
current_flags[i] = true; current_flags[i] = true;
current_dists[i] = EuclideanDist<FloatType, IdType>( current_dists[i] = EuclideanDist<FloatType, IdType>(
points + point_idx * feature_size, points + point_idx * feature_size,
points + current_neighbors[i] * feature_size, points + current_neighbors[i] * feature_size, feature_size);
feature_size);
} }
// build heap // build heap
BuildHeap<FloatType, IdType>(neighbors + point_idx * k, current_dists, k); BuildHeap<FloatType, IdType>(neighbors + point_idx * k, current_dists, k);
} }
/*! \brief Randomly select candidates from current knn and reverse-knn graph for nn-descent */ /*!
* \brief Randomly select candidates from current knn and reverse-knn graph for
* nn-descent.
*/
template <typename IdType> template <typename IdType>
__global__ void FindCandidatesKernel(const IdType* offsets, IdType* new_candidates, __global__ void FindCandidatesKernel(
IdType* old_candidates, IdType* neighbors, bool* flags, const IdType* offsets, IdType* new_candidates, IdType* old_candidates,
const uint64_t seed, const int64_t batch_size, IdType* neighbors, bool* flags, const uint64_t seed,
const int num_candidates, const int k) { const int64_t batch_size, const int num_candidates, const int k) {
const IdType point_idx = blockIdx.x * blockDim.x + threadIdx.x; const IdType point_idx = blockIdx.x * blockDim.x + threadIdx.x;
IdType batch_idx = 0; IdType batch_idx = 0;
if (point_idx >= offsets[batch_size]) return; if (point_idx >= offsets[batch_size]) return;
...@@ -652,13 +659,16 @@ __global__ void FindCandidatesKernel(const IdType* offsets, IdType* new_candidat ...@@ -652,13 +659,16 @@ __global__ void FindCandidatesKernel(const IdType* offsets, IdType* new_candidat
} }
} }
IdType segment_start = offsets[batch_idx], segment_end = offsets[batch_idx + 1]; IdType segment_start = offsets[batch_idx],
segment_end = offsets[batch_idx + 1];
IdType* current_neighbors = neighbors + point_idx * k; IdType* current_neighbors = neighbors + point_idx * k;
bool* current_flags = flags + point_idx * k; bool* current_flags = flags + point_idx * k;
// reset candidates // reset candidates
IdType* new_candidates_ptr = new_candidates + point_idx * (num_candidates + 1); IdType* new_candidates_ptr =
IdType* old_candidates_ptr = old_candidates + point_idx * (num_candidates + 1); new_candidates + point_idx * (num_candidates + 1);
IdType* old_candidates_ptr =
old_candidates + point_idx * (num_candidates + 1);
new_candidates_ptr[0] = 0; new_candidates_ptr[0] = 0;
old_candidates_ptr[0] = 0; old_candidates_ptr[0] = 0;
...@@ -666,7 +676,8 @@ __global__ void FindCandidatesKernel(const IdType* offsets, IdType* new_candidat ...@@ -666,7 +676,8 @@ __global__ void FindCandidatesKernel(const IdType* offsets, IdType* new_candidat
// here we use candidate[0] for reservoir sampling temporarily // here we use candidate[0] for reservoir sampling temporarily
for (IdType i = 0; i < k; ++i) { for (IdType i = 0; i < k; ++i) {
IdType candidate = current_neighbors[i]; IdType candidate = current_neighbors[i];
IdType* candidate_array = current_flags[i] ? new_candidates_ptr : old_candidates_ptr; IdType* candidate_array =
current_flags[i] ? new_candidates_ptr : old_candidates_ptr;
IdType curr_num = candidate_array[0]; IdType curr_num = candidate_array[0];
IdType* candidate_data = candidate_array + 1; IdType* candidate_data = candidate_array + 1;
...@@ -686,7 +697,8 @@ __global__ void FindCandidatesKernel(const IdType* offsets, IdType* new_candidat ...@@ -686,7 +697,8 @@ __global__ void FindCandidatesKernel(const IdType* offsets, IdType* new_candidat
for (IdType i = index_start; i < index_end; ++i) { for (IdType i = index_start; i < index_end; ++i) {
if (neighbors[i] == point_idx) { if (neighbors[i] == point_idx) {
IdType reverse_candidate = (i - index_start) / k + segment_start; IdType reverse_candidate = (i - index_start) / k + segment_start;
IdType* candidate_array = flags[i] ? new_candidates_ptr : old_candidates_ptr; IdType* candidate_array =
flags[i] ? new_candidates_ptr : old_candidates_ptr;
IdType curr_num = candidate_array[0]; IdType curr_num = candidate_array[0];
IdType* candidate_data = candidate_array + 1; IdType* candidate_data = candidate_array + 1;
...@@ -702,8 +714,10 @@ __global__ void FindCandidatesKernel(const IdType* offsets, IdType* new_candidat ...@@ -702,8 +714,10 @@ __global__ void FindCandidatesKernel(const IdType* offsets, IdType* new_candidat
} }
// set candidate[0] back to length // set candidate[0] back to length
if (new_candidates_ptr[0] > num_candidates) new_candidates_ptr[0] = num_candidates; if (new_candidates_ptr[0] > num_candidates)
if (old_candidates_ptr[0] > num_candidates) old_candidates_ptr[0] = num_candidates; new_candidates_ptr[0] = num_candidates;
if (old_candidates_ptr[0] > num_candidates)
old_candidates_ptr[0] = num_candidates;
// mark new_candidates as old // mark new_candidates as old
IdType num_new_candidates = new_candidates_ptr[0]; IdType num_new_candidates = new_candidates_ptr[0];
...@@ -723,19 +737,20 @@ __global__ void FindCandidatesKernel(const IdType* offsets, IdType* new_candidat ...@@ -723,19 +737,20 @@ __global__ void FindCandidatesKernel(const IdType* offsets, IdType* new_candidat
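// A host-side sketch of the reservoir-sampling trick noted in
// FindCandidatesKernel: candidate_array[0] temporarily counts how many
// candidates have been offered, so once the reservoir (of size
// num_candidates) is full, the j-th offer replaces a random slot with
// probability num_candidates / j; afterwards the count is clamped back to
// the reservoir length. std::mt19937_64 stands in for curand; the names
// below are illustrative.
#include <cstdint>
#include <random>
#include <vector>

void reservoir_offer(std::vector<int64_t>* candidate_array,  // size: num_candidates + 1
                     int64_t candidate, int64_t num_candidates,
                     std::mt19937_64* rng) {
  int64_t& offered = (*candidate_array)[0];  // running count of offers
  int64_t* slots = candidate_array->data() + 1;
  if (offered < num_candidates) {
    slots[offered] = candidate;  // reservoir not full yet: just append
  } else {
    std::uniform_int_distribution<int64_t> pick(0, offered);  // inclusive
    const int64_t slot = pick(*rng);
    if (slot < num_candidates) slots[slot] = candidate;
  }
  ++offered;
}
// After all offers, set candidate_array[0] back to
// min(candidate_array[0], num_candidates) so it stores the reservoir length.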
/*! \brief Update knn graph according to selected candidates for nn-descent */ /*! \brief Update knn graph according to selected candidates for nn-descent */
template <typename FloatType, typename IdType> template <typename FloatType, typename IdType>
__global__ void UpdateNeighborsKernel(const FloatType* points, const IdType* offsets, __global__ void UpdateNeighborsKernel(
IdType* neighbors, IdType* new_candidates, const FloatType* points, const IdType* offsets, IdType* neighbors,
IdType* old_candidates, FloatType* distances, IdType* new_candidates, IdType* old_candidates, FloatType* distances,
bool* flags, IdType* num_updates, bool* flags, IdType* num_updates, const int64_t batch_size,
const int64_t batch_size, const int num_candidates, const int num_candidates, const int k, const int64_t feature_size) {
const int k, const int64_t feature_size) {
const IdType point_idx = blockIdx.x * blockDim.x + threadIdx.x; const IdType point_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (point_idx >= offsets[batch_size]) return; if (point_idx >= offsets[batch_size]) return;
IdType* current_neighbors = neighbors + point_idx * k; IdType* current_neighbors = neighbors + point_idx * k;
bool* current_flags = flags + point_idx * k; bool* current_flags = flags + point_idx * k;
FloatType* current_dists = distances + point_idx * k; FloatType* current_dists = distances + point_idx * k;
IdType* new_candidates_ptr = new_candidates + point_idx * (num_candidates + 1); IdType* new_candidates_ptr =
IdType* old_candidates_ptr = old_candidates + point_idx * (num_candidates + 1); new_candidates + point_idx * (num_candidates + 1);
IdType* old_candidates_ptr =
old_candidates + point_idx * (num_candidates + 1);
IdType num_new_candidates = new_candidates_ptr[0]; IdType num_new_candidates = new_candidates_ptr[0];
IdType num_old_candidates = old_candidates_ptr[0]; IdType num_old_candidates = old_candidates_ptr[0];
IdType current_num_updates = 0; IdType current_num_updates = 0;
...@@ -755,15 +770,14 @@ __global__ void UpdateNeighborsKernel(const FloatType* points, const IdType* off ...@@ -755,15 +770,14 @@ __global__ void UpdateNeighborsKernel(const FloatType* points, const IdType* off
for (IdType j = 1; j <= num_twohop_new; ++j) { for (IdType j = 1; j <= num_twohop_new; ++j) {
IdType twohop_new_c = twohop_new_ptr[j]; IdType twohop_new_c = twohop_new_ptr[j];
FloatType new_dist = EuclideanDistWithCheck<FloatType, IdType>( FloatType new_dist = EuclideanDistWithCheck<FloatType, IdType>(
points + point_idx * feature_size, points + point_idx * feature_size,
points + twohop_new_c * feature_size, points + twohop_new_c * feature_size, feature_size, worst_dist);
feature_size, worst_dist);
if (FlaggedHeapInsert<FloatType, IdType>( if (FlaggedHeapInsert<FloatType, IdType>(
current_neighbors, current_dists, current_flags, current_neighbors, current_dists, current_flags, twohop_new_c,
twohop_new_c, new_dist, true, k, true)) { new_dist, true, k, true)) {
++current_num_updates; ++current_num_updates;
worst_dist = current_dists[0]; worst_dist = current_dists[0];
} }
} }
...@@ -771,15 +785,14 @@ __global__ void UpdateNeighborsKernel(const FloatType* points, const IdType* off ...@@ -771,15 +785,14 @@ __global__ void UpdateNeighborsKernel(const FloatType* points, const IdType* off
for (IdType j = 1; j <= num_twohop_old; ++j) { for (IdType j = 1; j <= num_twohop_old; ++j) {
IdType twohop_old_c = twohop_old_ptr[j]; IdType twohop_old_c = twohop_old_ptr[j];
FloatType new_dist = EuclideanDistWithCheck<FloatType, IdType>( FloatType new_dist = EuclideanDistWithCheck<FloatType, IdType>(
points + point_idx * feature_size, points + point_idx * feature_size,
points + twohop_old_c * feature_size, points + twohop_old_c * feature_size, feature_size, worst_dist);
feature_size, worst_dist);
if (FlaggedHeapInsert<FloatType, IdType>( if (FlaggedHeapInsert<FloatType, IdType>(
current_neighbors, current_dists, current_flags, current_neighbors, current_dists, current_flags, twohop_old_c,
twohop_old_c, new_dist, true, k, true)) { new_dist, true, k, true)) {
++current_num_updates; ++current_num_updates;
worst_dist = current_dists[0]; worst_dist = current_dists[0];
} }
} }
} }
...@@ -797,15 +810,14 @@ __global__ void UpdateNeighborsKernel(const FloatType* points, const IdType* off ...@@ -797,15 +810,14 @@ __global__ void UpdateNeighborsKernel(const FloatType* points, const IdType* off
for (IdType j = 1; j <= num_twohop_new; ++j) { for (IdType j = 1; j <= num_twohop_new; ++j) {
IdType twohop_new_c = twohop_new_ptr[j]; IdType twohop_new_c = twohop_new_ptr[j];
FloatType new_dist = EuclideanDistWithCheck<FloatType, IdType>( FloatType new_dist = EuclideanDistWithCheck<FloatType, IdType>(
points + point_idx * feature_size, points + point_idx * feature_size,
points + twohop_new_c * feature_size, points + twohop_new_c * feature_size, feature_size, worst_dist);
feature_size, worst_dist);
if (FlaggedHeapInsert<FloatType, IdType>( if (FlaggedHeapInsert<FloatType, IdType>(
current_neighbors, current_dists, current_flags, current_neighbors, current_dists, current_flags, twohop_new_c,
twohop_new_c, new_dist, true, k, true)) { new_dist, true, k, true)) {
++current_num_updates; ++current_num_updates;
worst_dist = current_dists[0]; worst_dist = current_dists[0];
} }
} }
} }
...@@ -816,24 +828,25 @@ __global__ void UpdateNeighborsKernel(const FloatType* points, const IdType* off ...@@ -816,24 +828,25 @@ __global__ void UpdateNeighborsKernel(const FloatType* points, const IdType* off
} // namespace impl } // namespace impl
template <DGLDeviceType XPU, typename FloatType, typename IdType> template <DGLDeviceType XPU, typename FloatType, typename IdType>
void KNN(const NDArray& data_points, const IdArray& data_offsets, void KNN(
const NDArray& query_points, const IdArray& query_offsets, const NDArray& data_points, const IdArray& data_offsets,
const int k, IdArray result, const std::string& algorithm) { const NDArray& query_points, const IdArray& query_offsets, const int k,
IdArray result, const std::string& algorithm) {
if (algorithm == std::string("bruteforce")) { if (algorithm == std::string("bruteforce")) {
impl::BruteForceKNNCuda<FloatType, IdType>( impl::BruteForceKNNCuda<FloatType, IdType>(
data_points, data_offsets, query_points, query_offsets, k, result); data_points, data_offsets, query_points, query_offsets, k, result);
} else if (algorithm == std::string("bruteforce-sharemem")) { } else if (algorithm == std::string("bruteforce-sharemem")) {
impl::BruteForceKNNSharedCuda<FloatType, IdType>( impl::BruteForceKNNSharedCuda<FloatType, IdType>(
data_points, data_offsets, query_points, query_offsets, k, result); data_points, data_offsets, query_points, query_offsets, k, result);
} else { } else {
LOG(FATAL) << "Algorithm " << algorithm << " is not supported on CUDA."; LOG(FATAL) << "Algorithm " << algorithm << " is not supported on CUDA.";
} }
} }
template <DGLDeviceType XPU, typename FloatType, typename IdType> template <DGLDeviceType XPU, typename FloatType, typename IdType>
void NNDescent(const NDArray& points, const IdArray& offsets, void NNDescent(
IdArray result, const int k, const int num_iters, const NDArray& points, const IdArray& offsets, IdArray result, const int k,
const int num_candidates, const double delta) { const int num_iters, const int num_candidates, const double delta) {
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
const auto& ctx = points->ctx; const auto& ctx = points->ctx;
auto device = runtime::DeviceAPI::Get(ctx); auto device = runtime::DeviceAPI::Get(ctx);
...@@ -847,66 +860,68 @@ void NNDescent(const NDArray& points, const IdArray& offsets, ...@@ -847,66 +860,68 @@ void NNDescent(const NDArray& points, const IdArray& offsets,
IdType* neighbors = central_nodes + k * num_nodes; IdType* neighbors = central_nodes + k * num_nodes;
uint64_t seed; uint64_t seed;
int warp_size = 0; int warp_size = 0;
CUDA_CALL(cudaDeviceGetAttribute( CUDA_CALL(
&warp_size, cudaDevAttrWarpSize, ctx.device_id)); cudaDeviceGetAttribute(&warp_size, cudaDevAttrWarpSize, ctx.device_id));
// We don't need large block sizes, since there's not much inter-thread communication // We don't need large block sizes, since there's not much inter-thread
// communication
int64_t block_size = warp_size; int64_t block_size = warp_size;
int64_t num_blocks = (num_nodes - 1) / block_size + 1; int64_t num_blocks = (num_nodes - 1) / block_size + 1;
// allocate space for candidates, distances and flags // allocate space for candidates, distances and flags
  // we use the first element of the candidate array to store its length // we use the first element of the candidate array to store its length
IdType* new_candidates = static_cast<IdType*>( IdType* new_candidates = static_cast<IdType*>(device->AllocWorkspace(
device->AllocWorkspace(ctx, num_nodes * (num_candidates + 1) * sizeof(IdType))); ctx, num_nodes * (num_candidates + 1) * sizeof(IdType)));
IdType* old_candidates = static_cast<IdType*>( IdType* old_candidates = static_cast<IdType*>(device->AllocWorkspace(
device->AllocWorkspace(ctx, num_nodes * (num_candidates + 1) * sizeof(IdType))); ctx, num_nodes * (num_candidates + 1) * sizeof(IdType)));
IdType* num_updates = static_cast<IdType*>( IdType* num_updates = static_cast<IdType*>(
device->AllocWorkspace(ctx, num_nodes * sizeof(IdType))); device->AllocWorkspace(ctx, num_nodes * sizeof(IdType)));
FloatType* distances = static_cast<FloatType*>( FloatType* distances = static_cast<FloatType*>(
device->AllocWorkspace(ctx, num_nodes * k * sizeof(IdType))); device->AllocWorkspace(ctx, num_nodes * k * sizeof(IdType)));
bool* flags = static_cast<bool*>( bool* flags = static_cast<bool*>(
device->AllocWorkspace(ctx, num_nodes * k * sizeof(IdType))); device->AllocWorkspace(ctx, num_nodes * k * sizeof(IdType)));
size_t sum_temp_size = 0; size_t sum_temp_size = 0;
IdType total_num_updates = 0; IdType total_num_updates = 0;
IdType* total_num_updates_d = static_cast<IdType*>( IdType* total_num_updates_d =
device->AllocWorkspace(ctx, sizeof(IdType))); static_cast<IdType*>(device->AllocWorkspace(ctx, sizeof(IdType)));
CUDA_CALL(cub::DeviceReduce::Sum( CUDA_CALL(cub::DeviceReduce::Sum(
nullptr, sum_temp_size, num_updates, total_num_updates_d, num_nodes, stream)); nullptr, sum_temp_size, num_updates, total_num_updates_d, num_nodes,
IdType* sum_temp_storage = static_cast<IdType*>( stream));
device->AllocWorkspace(ctx, sum_temp_size)); IdType* sum_temp_storage =
static_cast<IdType*>(device->AllocWorkspace(ctx, sum_temp_size));
// random initialize neighbors // random initialize neighbors
seed = RandomEngine::ThreadLocal()->RandInt<uint64_t>( seed = RandomEngine::ThreadLocal()->RandInt<uint64_t>(
std::numeric_limits<uint64_t>::max()); std::numeric_limits<uint64_t>::max());
CUDA_KERNEL_CALL( CUDA_KERNEL_CALL(
impl::RandomInitNeighborsKernel, num_blocks, block_size, 0, stream, impl::RandomInitNeighborsKernel, num_blocks, block_size, 0, stream,
points_data, offsets_data, central_nodes, neighbors, distances, flags, k, points_data, offsets_data, central_nodes, neighbors, distances, flags, k,
feature_size, batch_size, seed); feature_size, batch_size, seed);
for (int i = 0; i < num_iters; ++i) { for (int i = 0; i < num_iters; ++i) {
// select candidates // select candidates
seed = RandomEngine::ThreadLocal()->RandInt<uint64_t>( seed = RandomEngine::ThreadLocal()->RandInt<uint64_t>(
std::numeric_limits<uint64_t>::max()); std::numeric_limits<uint64_t>::max());
CUDA_KERNEL_CALL( CUDA_KERNEL_CALL(
impl::FindCandidatesKernel, num_blocks, block_size, 0, impl::FindCandidatesKernel, num_blocks, block_size, 0, stream,
stream, offsets_data, new_candidates, old_candidates, neighbors, offsets_data, new_candidates, old_candidates, neighbors, flags, seed,
flags, seed, batch_size, num_candidates, k); batch_size, num_candidates, k);
// update // update
CUDA_KERNEL_CALL( CUDA_KERNEL_CALL(
impl::UpdateNeighborsKernel, num_blocks, block_size, 0, stream, impl::UpdateNeighborsKernel, num_blocks, block_size, 0, stream,
points_data, offsets_data, neighbors, new_candidates, old_candidates, distances, points_data, offsets_data, neighbors, new_candidates, old_candidates,
flags, num_updates, batch_size, num_candidates, k, feature_size); distances, flags, num_updates, batch_size, num_candidates, k,
feature_size);
total_num_updates = 0; total_num_updates = 0;
CUDA_CALL(cub::DeviceReduce::Sum( CUDA_CALL(cub::DeviceReduce::Sum(
sum_temp_storage, sum_temp_size, num_updates, total_num_updates_d, num_nodes, sum_temp_storage, sum_temp_size, num_updates, total_num_updates_d,
stream)); num_nodes, stream));
device->CopyDataFromTo( device->CopyDataFromTo(
total_num_updates_d, 0, &total_num_updates, 0, total_num_updates_d, 0, &total_num_updates, 0, sizeof(IdType), ctx,
sizeof(IdType), ctx, DGLContext{kDGLCPU, 0}, DGLContext{kDGLCPU, 0}, offsets->dtype);
offsets->dtype);
if (total_num_updates <= static_cast<IdType>(delta * k * num_nodes)) { if (total_num_updates <= static_cast<IdType>(delta * k * num_nodes)) {
break; break;
...@@ -923,38 +938,34 @@ void NNDescent(const NDArray& points, const IdArray& offsets, ...@@ -923,38 +938,34 @@ void NNDescent(const NDArray& points, const IdArray& offsets,
} }
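// A minimal sketch of the iteration and early-abort structure above, with the
// per-iteration kernels abstracted as callbacks (names and signatures are
// illustrative): iterate at most num_iters times and stop once fewer than
// delta * k * num_nodes neighbor entries changed in a round.
#include <cstdint>
#include <functional>

void nn_descent_loop(const std::function<void()>& find_candidates,
                     const std::function<int64_t()>& update_neighbors,
                     int num_iters, int k, int64_t num_nodes, double delta) {
  for (int i = 0; i < num_iters; ++i) {
    find_candidates();
    const int64_t total_num_updates = update_neighbors();
    if (total_num_updates <= static_cast<int64_t>(delta * k * num_nodes))
      break;  // converged: too few changes in this iteration
  }
}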
template void KNN<kDGLCUDA, float, int32_t>( template void KNN<kDGLCUDA, float, int32_t>(
const NDArray& data_points, const IdArray& data_offsets, const NDArray& data_points, const IdArray& data_offsets,
const NDArray& query_points, const IdArray& query_offsets, const NDArray& query_points, const IdArray& query_offsets, const int k,
const int k, IdArray result, const std::string& algorithm); IdArray result, const std::string& algorithm);
template void KNN<kDGLCUDA, float, int64_t>( template void KNN<kDGLCUDA, float, int64_t>(
const NDArray& data_points, const IdArray& data_offsets, const NDArray& data_points, const IdArray& data_offsets,
const NDArray& query_points, const IdArray& query_offsets, const NDArray& query_points, const IdArray& query_offsets, const int k,
const int k, IdArray result, const std::string& algorithm); IdArray result, const std::string& algorithm);
template void KNN<kDGLCUDA, double, int32_t>( template void KNN<kDGLCUDA, double, int32_t>(
const NDArray& data_points, const IdArray& data_offsets, const NDArray& data_points, const IdArray& data_offsets,
const NDArray& query_points, const IdArray& query_offsets, const NDArray& query_points, const IdArray& query_offsets, const int k,
const int k, IdArray result, const std::string& algorithm); IdArray result, const std::string& algorithm);
template void KNN<kDGLCUDA, double, int64_t>( template void KNN<kDGLCUDA, double, int64_t>(
const NDArray& data_points, const IdArray& data_offsets, const NDArray& data_points, const IdArray& data_offsets,
const NDArray& query_points, const IdArray& query_offsets, const NDArray& query_points, const IdArray& query_offsets, const int k,
const int k, IdArray result, const std::string& algorithm); IdArray result, const std::string& algorithm);
template void NNDescent<kDGLCUDA, float, int32_t>( template void NNDescent<kDGLCUDA, float, int32_t>(
const NDArray& points, const IdArray& offsets, const NDArray& points, const IdArray& offsets, IdArray result, const int k,
IdArray result, const int k, const int num_iters, const int num_iters, const int num_candidates, const double delta);
const int num_candidates, const double delta);
template void NNDescent<kDGLCUDA, float, int64_t>( template void NNDescent<kDGLCUDA, float, int64_t>(
const NDArray& points, const IdArray& offsets, const NDArray& points, const IdArray& offsets, IdArray result, const int k,
IdArray result, const int k, const int num_iters, const int num_iters, const int num_candidates, const double delta);
const int num_candidates, const double delta);
template void NNDescent<kDGLCUDA, double, int32_t>( template void NNDescent<kDGLCUDA, double, int32_t>(
const NDArray& points, const IdArray& offsets, const NDArray& points, const IdArray& offsets, IdArray result, const int k,
IdArray result, const int k, const int num_iters, const int num_iters, const int num_candidates, const double delta);
const int num_candidates, const double delta);
template void NNDescent<kDGLCUDA, double, int64_t>( template void NNDescent<kDGLCUDA, double, int64_t>(
const NDArray& points, const IdArray& offsets, const NDArray& points, const IdArray& offsets, IdArray result, const int k,
IdArray result, const int k, const int num_iters, const int num_iters, const int num_candidates, const double delta);
const int num_candidates, const double delta);
} // namespace transform } // namespace transform
} // namespace dgl } // namespace dgl
...@@ -4,9 +4,11 @@ ...@@ -4,9 +4,11 @@
* \brief k-nearest-neighbor (KNN) interface * \brief k-nearest-neighbor (KNN) interface
*/ */
#include <dgl/runtime/registry.h>
#include <dgl/runtime/packed_func.h>
#include "knn.h" #include "knn.h"
#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/registry.h>
#include "../../array/check.h" #include "../../array/check.h"
using namespace dgl::runtime; using namespace dgl::runtime;
...@@ -14,57 +16,59 @@ namespace dgl { ...@@ -14,57 +16,59 @@ namespace dgl {
namespace transform { namespace transform {
DGL_REGISTER_GLOBAL("transform._CAPI_DGLKNN") DGL_REGISTER_GLOBAL("transform._CAPI_DGLKNN")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
const NDArray data_points = args[0]; const NDArray data_points = args[0];
const IdArray data_offsets = args[1]; const IdArray data_offsets = args[1];
const NDArray query_points = args[2]; const NDArray query_points = args[2];
const IdArray query_offsets = args[3]; const IdArray query_offsets = args[3];
const int k = args[4]; const int k = args[4];
IdArray result = args[5]; IdArray result = args[5];
const std::string algorithm = args[6]; const std::string algorithm = args[6];
aten::CheckContiguous( aten::CheckContiguous(
{data_points, data_offsets, query_points, query_offsets, result}, {data_points, data_offsets, query_points, query_offsets, result},
{"data_points", "data_offsets", "query_points", "query_offsets", "result"}); {"data_points", "data_offsets", "query_points", "query_offsets",
aten::CheckCtx( "result"});
data_points->ctx, {data_offsets, query_points, query_offsets, result}, aten::CheckCtx(
{"data_offsets", "query_points", "query_offsets", "result"}); data_points->ctx, {data_offsets, query_points, query_offsets, result},
{"data_offsets", "query_points", "query_offsets", "result"});
ATEN_XPU_SWITCH_CUDA(data_points->ctx.device_type, XPU, "KNN", { ATEN_XPU_SWITCH_CUDA(data_points->ctx.device_type, XPU, "KNN", {
ATEN_FLOAT_TYPE_SWITCH(data_points->dtype, FloatType, "data_points", { ATEN_FLOAT_TYPE_SWITCH(data_points->dtype, FloatType, "data_points", {
ATEN_ID_TYPE_SWITCH(result->dtype, IdType, { ATEN_ID_TYPE_SWITCH(result->dtype, IdType, {
KNN<XPU, FloatType, IdType>( KNN<XPU, FloatType, IdType>(
data_points, data_offsets, query_points, data_points, data_offsets, query_points, query_offsets, k,
query_offsets, k, result, algorithm); result, algorithm);
});
}); });
}); });
}); });
});
DGL_REGISTER_GLOBAL("transform._CAPI_DGLNNDescent") DGL_REGISTER_GLOBAL("transform._CAPI_DGLNNDescent")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
const NDArray points = args[0]; const NDArray points = args[0];
const IdArray offsets = args[1]; const IdArray offsets = args[1];
const IdArray result = args[2]; const IdArray result = args[2];
const int k = args[3]; const int k = args[3];
const int num_iters = args[4]; const int num_iters = args[4];
const int num_candidates = args[5]; const int num_candidates = args[5];
const double delta = args[6]; const double delta = args[6];
aten::CheckContiguous( aten::CheckContiguous(
{points, offsets, result}, {"points", "offsets", "result"}); {points, offsets, result}, {"points", "offsets", "result"});
aten::CheckCtx( aten::CheckCtx(
points->ctx, {points, offsets, result}, {"points", "offsets", "result"}); points->ctx, {points, offsets, result},
{"points", "offsets", "result"});
ATEN_XPU_SWITCH_CUDA(points->ctx.device_type, XPU, "NNDescent", { ATEN_XPU_SWITCH_CUDA(points->ctx.device_type, XPU, "NNDescent", {
ATEN_FLOAT_TYPE_SWITCH(points->dtype, FloatType, "points", { ATEN_FLOAT_TYPE_SWITCH(points->dtype, FloatType, "points", {
ATEN_ID_TYPE_SWITCH(result->dtype, IdType, { ATEN_ID_TYPE_SWITCH(result->dtype, IdType, {
NNDescent<XPU, FloatType, IdType>( NNDescent<XPU, FloatType, IdType>(
points, offsets, result, k, num_iters, num_candidates, delta); points, offsets, result, k, num_iters, num_candidates, delta);
});
}); });
}); });
}); });
});
} // namespace transform } // namespace transform
} // namespace dgl } // namespace dgl
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#define DGL_GRAPH_TRANSFORM_KNN_H_ #define DGL_GRAPH_TRANSFORM_KNN_H_
#include <dgl/array.h> #include <dgl/array.h>
#include <string> #include <string>
namespace dgl { namespace dgl {
...@@ -15,42 +16,45 @@ namespace transform { ...@@ -15,42 +16,45 @@ namespace transform {
/*! /*!
* \brief For each point in each segment in \a query_points, find \a k nearest * \brief For each point in each segment in \a query_points, find \a k nearest
* points in the same segment in \a data_points. \a data_offsets and \a query_offsets * points in the same segment in \a data_points. \a data_offsets and \a
* determine the start index of each segment in \a data_points and \a query_points. * query_offsets determine the start index of each segment in \a
* data_points and \a query_points.
* *
* \param data_points dataset points. * \param data_points dataset points.
* \param data_offsets offsets of point index in \a data_points. * \param data_offsets offsets of point index in \a data_points.
* \param query_points query points. * \param query_points query points.
* \param query_offsets offsets of point index in \a query_points. * \param query_offsets offsets of point index in \a query_points.
* \param k the number of nearest points. * \param k the number of nearest points.
* \param result output array. A 2D tensor indicating the index * \param result output array. A 2D tensor indicating the index relation
* relation between \a query_points and \a data_points. * between \a query_points and \a data_points.
* \param algorithm algorithm used to compute the k-nearest neighbors. * \param algorithm algorithm used to compute the k-nearest neighbors.
*/ */
template <DGLDeviceType XPU, typename FloatType, typename IdType> template <DGLDeviceType XPU, typename FloatType, typename IdType>
void KNN(const NDArray& data_points, const IdArray& data_offsets, void KNN(
const NDArray& query_points, const IdArray& query_offsets, const NDArray& data_points, const IdArray& data_offsets,
const int k, IdArray result, const std::string& algorithm); const NDArray& query_points, const IdArray& query_offsets, const int k,
IdArray result, const std::string& algorithm);
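// A concrete, made-up example of the segmented layout described above: three
// batches whose data segments hold 4, 2 and 3 points and whose query segments
// hold 2, 1 and 2 points. Batch b owns data points
// [data_offsets[b], data_offsets[b + 1]) and query points
// [query_offsets[b], query_offsets[b + 1]); queries are matched only against
// data points of their own batch.
#include <cstdint>
#include <vector>

const std::vector<int64_t> example_data_offsets = {0, 4, 6, 9};
const std::vector<int64_t> example_query_offsets = {0, 2, 3, 5};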
/*! /*!
* \brief For each input point, find \a k approximate nearest points in the same * \brief For each input point, find \a k approximate nearest points in the same
 * segment using the NN-descent algorithm. * segment using the NN-descent algorithm.
* *
* \param points input points. * \param points input points.
* \param offsets offsets of point index. * \param offsets offsets of point index.
* \param result output array. A 2D tensor indicating the index relation between points. * \param result output array. A 2D tensor indicating the index relation between
* points.
* \param k the number of nearest points. * \param k the number of nearest points.
* \param num_iters The maximum number of NN-descent iterations to perform. * \param num_iters The maximum number of NN-descent iterations to perform.
* \param num_candidates The maximum number of candidates to be considered during one iteration. * \param num_candidates The maximum number of candidates to be considered
* during one iteration.
 * \param delta A value that controls the early abort. * \param delta A value that controls the early abort.
*/ */
template <DGLDeviceType XPU, typename FloatType, typename IdType> template <DGLDeviceType XPU, typename FloatType, typename IdType>
void NNDescent(const NDArray& points, const IdArray& offsets, void NNDescent(
IdArray result, const int k, const int num_iters, const NDArray& points, const IdArray& offsets, IdArray result, const int k,
const int num_candidates, const double delta); const int num_iters, const int num_candidates, const double delta);
} // namespace transform } // namespace transform
} // namespace dgl } // namespace dgl
#endif // DGL_GRAPH_TRANSFORM_KNN_H_ #endif // DGL_GRAPH_TRANSFORM_KNN_H_
...@@ -4,12 +4,14 @@ ...@@ -4,12 +4,14 @@
* \brief Line graph implementation * \brief Line graph implementation
*/ */
#include <dgl/base_heterograph.h>
#include <dgl/transform.h>
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <dgl/packed_func_ext.h> #include <dgl/packed_func_ext.h>
#include <vector> #include <dgl/transform.h>
#include <utility> #include <utility>
#include <vector>
#include "../../c_api_common.h" #include "../../c_api_common.h"
#include "../heterograph.h" #include "../heterograph.h"
...@@ -21,26 +23,25 @@ using namespace dgl::aten; ...@@ -21,26 +23,25 @@ using namespace dgl::aten;
namespace transform { namespace transform {
/*! /*!
* \brief Create Line Graph * \brief Create Line Graph.
* \param hg Graph * \param hg Graph.
* \param backtracking whether the pair of (v, u) (u, v) edges are treated as linked * \param backtracking whether the pair of (v, u) (u, v) edges are treated as
* \return The Line Graph * linked.
* \return The Line Graph.
*/ */
HeteroGraphPtr CreateLineGraph( HeteroGraphPtr CreateLineGraph(HeteroGraphPtr hg, bool backtracking) {
HeteroGraphPtr hg,
bool backtracking) {
const auto hgp = std::dynamic_pointer_cast<HeteroGraph>(hg); const auto hgp = std::dynamic_pointer_cast<HeteroGraph>(hg);
return hgp->LineGraph(backtracking); return hgp->LineGraph(backtracking);
} }
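// A conceptual sketch of the line-graph construction (not DGL's
// implementation): every edge of the input graph becomes a node, and node i
// is connected to node j when edge i ends where edge j starts. With
// backtracking == false, following an edge (u, v) by its reverse (v, u) is
// skipped. Names below are illustrative.
#include <cstdint>
#include <utility>
#include <vector>

std::vector<std::pair<int64_t, int64_t>> line_graph_edges(
    const std::vector<std::pair<int64_t, int64_t>>& edges, bool backtracking) {
  std::vector<std::pair<int64_t, int64_t>> out;
  const int64_t n = static_cast<int64_t>(edges.size());
  for (int64_t i = 0; i < n; ++i) {
    for (int64_t j = 0; j < n; ++j) {
      if (i == j) continue;
      const bool connected = edges[i].second == edges[j].first;
      const bool is_backtrack = edges[i].first == edges[j].second &&
                                edges[i].second == edges[j].first;
      if (connected && (backtracking || !is_backtrack)) out.emplace_back(i, j);
    }
  }
  return out;
}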
DGL_REGISTER_GLOBAL("transform._CAPI_DGLHeteroLineGraph") DGL_REGISTER_GLOBAL("transform._CAPI_DGLHeteroLineGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
bool backtracking = args[1]; bool backtracking = args[1];
auto hgptr = CreateLineGraph(hg.sptr(), backtracking); auto hgptr = CreateLineGraph(hg.sptr(), backtracking);
*rv = HeteroGraphRef(hgptr); *rv = HeteroGraphRef(hgptr);
}); });
}; // namespace transform }; // namespace transform
}; // namespace dgl }; // namespace dgl
...@@ -19,14 +19,15 @@ namespace transform { ...@@ -19,14 +19,15 @@ namespace transform {
#if !defined(_WIN32) #if !defined(_WIN32)
IdArray MetisPartition(UnitGraphPtr g, int k, NDArray vwgt_arr, IdArray MetisPartition(
const std::string &mode, bool obj_cut) { UnitGraphPtr g, int k, NDArray vwgt_arr, const std::string &mode,
bool obj_cut) {
// Mode can only be "k-way" or "recursive" // Mode can only be "k-way" or "recursive"
CHECK(mode == "k-way" || mode == "recursive") CHECK(mode == "k-way" || mode == "recursive")
<< "mode can only be \"k-way\" or \"recursive\""; << "mode can only be \"k-way\" or \"recursive\"";
// The index type of Metis needs to be compatible with DGL index type. // The index type of Metis needs to be compatible with DGL index type.
CHECK_EQ(sizeof(idx_t), sizeof(int64_t)) CHECK_EQ(sizeof(idx_t), sizeof(int64_t))
<< "Metis only supports int64 graph for now"; << "Metis only supports int64 graph for now";
// This is a symmetric graph, so in-csr and out-csr are the same. // This is a symmetric graph, so in-csr and out-csr are the same.
const auto mat = g->GetCSCMatrix(0); const auto mat = g->GetCSCMatrix(0);
// const auto mat = g->GetInCSR()->ToCSRMatrix(); // const auto mat = g->GetInCSR()->ToCSRMatrix();
...@@ -42,16 +43,17 @@ IdArray MetisPartition(UnitGraphPtr g, int k, NDArray vwgt_arr, ...@@ -42,16 +43,17 @@ IdArray MetisPartition(UnitGraphPtr g, int k, NDArray vwgt_arr,
int64_t vwgt_len = vwgt_arr->shape[0]; int64_t vwgt_len = vwgt_arr->shape[0];
CHECK_EQ(sizeof(idx_t), vwgt_arr->dtype.bits / 8) CHECK_EQ(sizeof(idx_t), vwgt_arr->dtype.bits / 8)
<< "The vertex weight array doesn't have right type"; << "The vertex weight array doesn't have right type";
CHECK(vwgt_len % g->NumVertices(0) == 0) CHECK(vwgt_len % g->NumVertices(0) == 0)
<< "The vertex weight array doesn't have right number of elements"; << "The vertex weight array doesn't have right number of elements";
idx_t *vwgt = NULL; idx_t *vwgt = NULL;
if (vwgt_len > 0) { if (vwgt_len > 0) {
ncon = vwgt_len / g->NumVertices(0); ncon = vwgt_len / g->NumVertices(0);
vwgt = static_cast<idx_t *>(vwgt_arr->data); vwgt = static_cast<idx_t *>(vwgt_arr->data);
} }
auto partition_func = (mode == "k-way") ? METIS_PartGraphKway : METIS_PartGraphRecursive; auto partition_func =
(mode == "k-way") ? METIS_PartGraphKway : METIS_PartGraphRecursive;
idx_t options[METIS_NOPTIONS]; idx_t options[METIS_NOPTIONS];
METIS_SetDefaultOptions(options); METIS_SetDefaultOptions(options);
...@@ -67,21 +69,21 @@ IdArray MetisPartition(UnitGraphPtr g, int k, NDArray vwgt_arr, ...@@ -67,21 +69,21 @@ IdArray MetisPartition(UnitGraphPtr g, int k, NDArray vwgt_arr,
} }
int ret = partition_func( int ret = partition_func(
&nvtxs, // The number of vertices &nvtxs, // The number of vertices
&ncon, // The number of balancing constraints. &ncon, // The number of balancing constraints.
xadj, // indptr xadj, // indptr
adjncy, // indices adjncy, // indices
vwgt, // the weights of the vertices vwgt, // the weights of the vertices
NULL, // The size of the vertices for computing NULL, // The size of the vertices for computing
// the total communication volume // the total communication volume
NULL, // The weights of the edges NULL, // The weights of the edges
&nparts, // The number of partitions. &nparts, // The number of partitions.
NULL, // the desired weight for each partition and constraint NULL, // the desired weight for each partition and constraint
NULL, // the allowed load imbalance tolerance NULL, // the allowed load imbalance tolerance
options, // the array of options options, // the array of options
&objval, // the edge-cut or the total communication volume of &objval, // the edge-cut or the total communication volume of
// the partitioning solution // the partitioning solution
part); part);
if (obj_cut) { if (obj_cut) {
LOG(INFO) << "Partition a graph with " << g->NumVertices(0) << " nodes and " LOG(INFO) << "Partition a graph with " << g->NumVertices(0) << " nodes and "
...@@ -110,22 +112,22 @@ IdArray MetisPartition(UnitGraphPtr g, int k, NDArray vwgt_arr, ...@@ -110,22 +112,22 @@ IdArray MetisPartition(UnitGraphPtr g, int k, NDArray vwgt_arr,
#endif // !defined(_WIN32) #endif // !defined(_WIN32)
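// A tiny, made-up example of the CSR arrays handed to METIS above: an
// undirected triangle on vertices {0, 1, 2} (each edge stored in both
// directions) has
//   xadj   (indptr)  = {0, 2, 4, 6}
//   adjncy (indices) = {1, 2, 0, 2, 0, 1}
// so the neighbors of vertex v are adjncy[xadj[v] .. xadj[v + 1]). int64_t is
// used here because the code above checks that idx_t has the same width.
#include <cstdint>
#include <vector>

const std::vector<int64_t> example_xadj = {0, 2, 4, 6};
const std::vector<int64_t> example_adjncy = {1, 2, 0, 2, 0, 1};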
DGL_REGISTER_GLOBAL("partition._CAPI_DGLMetisPartition_Hetero") DGL_REGISTER_GLOBAL("partition._CAPI_DGLMetisPartition_Hetero")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroGraphRef g = args[0]; HeteroGraphRef g = args[0];
auto hgptr = std::dynamic_pointer_cast<HeteroGraph>(g.sptr()); auto hgptr = std::dynamic_pointer_cast<HeteroGraph>(g.sptr());
CHECK(hgptr) << "Invalid HeteroGraph object"; CHECK(hgptr) << "Invalid HeteroGraph object";
CHECK_EQ(hgptr->relation_graphs().size(), 1) CHECK_EQ(hgptr->relation_graphs().size(), 1)
<< "Metis partition only supports HomoGraph"; << "Metis partition only supports HomoGraph";
auto ugptr = hgptr->relation_graphs()[0]; auto ugptr = hgptr->relation_graphs()[0];
int k = args[1]; int k = args[1];
NDArray vwgt = args[2]; NDArray vwgt = args[2];
std::string mode = args[3]; std::string mode = args[3];
bool obj_cut = args[4]; bool obj_cut = args[4];
#if !defined(_WIN32) #if !defined(_WIN32)
*rv = MetisPartition(ugptr, k, vwgt, mode, obj_cut); *rv = MetisPartition(ugptr, k, vwgt, mode, obj_cut);
#else #else
LOG(FATAL) << "Metis partition does not support Windows."; LOG(FATAL) << "Metis partition does not support Windows.";
#endif // !defined(_WIN32) #endif // !defined(_WIN32)
}); });
} // namespace transform } // namespace transform
} // namespace dgl } // namespace dgl
...@@ -37,25 +37,28 @@ HeteroGraphPtr ReorderUnitGraph(UnitGraphPtr ug, IdArray new_order) { ...@@ -37,25 +37,28 @@ HeteroGraphPtr ReorderUnitGraph(UnitGraphPtr ug, IdArray new_order) {
if (format & CSC_CODE) { if (format & CSC_CODE) {
auto cscmat = ug->GetCSCMatrix(0); auto cscmat = ug->GetCSCMatrix(0);
auto new_cscmat = aten::CSRReorder(cscmat, new_order, new_order); auto new_cscmat = aten::CSRReorder(cscmat, new_order, new_order);
return UnitGraph::CreateFromCSC(ug->NumVertexTypes(), new_cscmat, ug->GetAllowedFormats()); return UnitGraph::CreateFromCSC(
ug->NumVertexTypes(), new_cscmat, ug->GetAllowedFormats());
} else if (format & CSR_CODE) { } else if (format & CSR_CODE) {
auto csrmat = ug->GetCSRMatrix(0); auto csrmat = ug->GetCSRMatrix(0);
auto new_csrmat = aten::CSRReorder(csrmat, new_order, new_order); auto new_csrmat = aten::CSRReorder(csrmat, new_order, new_order);
return UnitGraph::CreateFromCSR(ug->NumVertexTypes(), new_csrmat, ug->GetAllowedFormats()); return UnitGraph::CreateFromCSR(
ug->NumVertexTypes(), new_csrmat, ug->GetAllowedFormats());
} else { } else {
auto coomat = ug->GetCOOMatrix(0); auto coomat = ug->GetCOOMatrix(0);
auto new_coomat = aten::COOReorder(coomat, new_order, new_order); auto new_coomat = aten::COOReorder(coomat, new_order, new_order);
return UnitGraph::CreateFromCOO(ug->NumVertexTypes(), new_coomat, ug->GetAllowedFormats()); return UnitGraph::CreateFromCOO(
ug->NumVertexTypes(), new_coomat, ug->GetAllowedFormats());
} }
} }
HaloHeteroSubgraph GetSubgraphWithHalo(
    std::shared_ptr<HeteroGraph> hg, IdArray nodes, int num_hops) {
  CHECK_EQ(hg->NumBits(), 64) << "halo subgraph only supports 64bits graph";
  CHECK_EQ(hg->relation_graphs().size(), 1)
      << "halo subgraph only supports homogeneous graph";
  CHECK_EQ(nodes->dtype.bits, 64)
      << "halo subgraph only supports 64bits nodes tensor";
  const dgl_id_t *nid = static_cast<dgl_id_t *>(nodes->data);
  const auto id_len = nodes->shape[0];
  // A map that contains all nodes in the subgraph.
@@ -113,8 +116,8 @@ HaloHeteroSubgraph GetSubgraphWithHalo(std::shared_ptr<HeteroGraph> hg,
  const dgl_id_t *eid_data = static_cast<dgl_id_t *>(eid->data);
  for (int64_t i = 0; i < num_edges; i++) {
    auto it1 = orig_nodes.find(src_data[i]);
    // If the source node is in the partition, we already got this edge when
    // we iterated over the out-edges above.
    if (it1 == orig_nodes.end()) {
      edge_src.push_back(src_data[i]);
      edge_dst.push_back(dst_data[i]);
@@ -164,10 +167,10 @@ HaloHeteroSubgraph GetSubgraphWithHalo(std::shared_ptr<HeteroGraph> hg,
  }
  num_edges = edge_src.size();
  IdArray new_src = IdArray::Empty(
      {num_edges}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
  IdArray new_dst = IdArray::Empty(
      {num_edges}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
  dgl_id_t *new_src_data = static_cast<dgl_id_t *>(new_src->data);
  dgl_id_t *new_dst_data = static_cast<dgl_id_t *>(new_dst->data);
  for (size_t i = 0; i < edge_src.size(); i++) {

@@ -180,8 +183,8 @@ HaloHeteroSubgraph GetSubgraphWithHalo(std::shared_ptr<HeteroGraph> hg,
    dgl_id_t old_nid = old_node_ids[i];
    inner_nodes[i] = all_nodes[old_nid];
  }
  aten::COOMatrix coo(
      old_node_ids.size(), old_node_ids.size(), new_src, new_dst);
  HeteroGraphPtr ugptr = UnitGraph::CreateFromCOO(1, coo);
  HeteroGraphPtr subg = CreateHeteroGraph(hg->meta_graph(), {ugptr});
  HaloHeteroSubgraph halo_subg;

@@ -194,83 +197,83 @@ HaloHeteroSubgraph GetSubgraphWithHalo(std::shared_ptr<HeteroGraph> hg,
}
DGL_REGISTER_GLOBAL("partition._CAPI_DGLReorderGraph_Hetero") DGL_REGISTER_GLOBAL("partition._CAPI_DGLReorderGraph_Hetero")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroGraphRef g = args[0]; HeteroGraphRef g = args[0];
auto hgptr = std::dynamic_pointer_cast<HeteroGraph>(g.sptr()); auto hgptr = std::dynamic_pointer_cast<HeteroGraph>(g.sptr());
CHECK(hgptr) << "Invalid HeteroGraph object"; CHECK(hgptr) << "Invalid HeteroGraph object";
CHECK_EQ(hgptr->relation_graphs().size(), 1) CHECK_EQ(hgptr->relation_graphs().size(), 1)
<< "Reorder only supports HomoGraph"; << "Reorder only supports HomoGraph";
auto ugptr = hgptr->relation_graphs()[0]; auto ugptr = hgptr->relation_graphs()[0];
const IdArray new_order = args[1]; const IdArray new_order = args[1];
auto reorder_ugptr = ReorderUnitGraph(ugptr, new_order); auto reorder_ugptr = ReorderUnitGraph(ugptr, new_order);
std::vector<HeteroGraphPtr> rel_graphs = {reorder_ugptr}; std::vector<HeteroGraphPtr> rel_graphs = {reorder_ugptr};
*rv = HeteroGraphRef(std::make_shared<HeteroGraph>( *rv = HeteroGraphRef(std::make_shared<HeteroGraph>(
hgptr->meta_graph(), rel_graphs, hgptr->NumVerticesPerType())); hgptr->meta_graph(), rel_graphs, hgptr->NumVerticesPerType()));
}); });
DGL_REGISTER_GLOBAL("partition._CAPI_DGLPartitionWithHalo_Hetero") DGL_REGISTER_GLOBAL("partition._CAPI_DGLPartitionWithHalo_Hetero")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroGraphRef g = args[0]; HeteroGraphRef g = args[0];
auto hgptr = std::dynamic_pointer_cast<HeteroGraph>(g.sptr()); auto hgptr = std::dynamic_pointer_cast<HeteroGraph>(g.sptr());
CHECK(hgptr) << "Invalid HeteroGraph object"; CHECK(hgptr) << "Invalid HeteroGraph object";
CHECK_EQ(hgptr->relation_graphs().size(), 1) CHECK_EQ(hgptr->relation_graphs().size(), 1)
<< "Metis partition only supports HomoGraph"; << "Metis partition only supports HomoGraph";
auto ugptr = hgptr->relation_graphs()[0]; auto ugptr = hgptr->relation_graphs()[0];
IdArray node_parts = args[1]; IdArray node_parts = args[1];
int num_hops = args[2]; int num_hops = args[2];
CHECK_EQ(node_parts->dtype.bits, 64) CHECK_EQ(node_parts->dtype.bits, 64)
<< "Only supports 64bits tensor for now"; << "Only supports 64bits tensor for now";
const int64_t *part_data = static_cast<int64_t *>(node_parts->data); const int64_t *part_data = static_cast<int64_t *>(node_parts->data);
int64_t num_nodes = node_parts->shape[0]; int64_t num_nodes = node_parts->shape[0];
std::unordered_map<int, std::vector<int64_t>> part_map; std::unordered_map<int, std::vector<int64_t>> part_map;
for (int64_t i = 0; i < num_nodes; i++) { for (int64_t i = 0; i < num_nodes; i++) {
dgl_id_t part_id = part_data[i]; dgl_id_t part_id = part_data[i];
auto it = part_map.find(part_id); auto it = part_map.find(part_id);
if (it == part_map.end()) { if (it == part_map.end()) {
std::vector<int64_t> vec; std::vector<int64_t> vec;
vec.push_back(i); vec.push_back(i);
part_map[part_id] = vec; part_map[part_id] = vec;
} else { } else {
it->second.push_back(i); it->second.push_back(i);
}
} }
} std::vector<int> part_ids;
std::vector<int> part_ids; std::vector<std::vector<int64_t>> part_nodes;
std::vector<std::vector<int64_t>> part_nodes; int max_part_id = 0;
int max_part_id = 0; for (auto it = part_map.begin(); it != part_map.end(); it++) {
for (auto it = part_map.begin(); it != part_map.end(); it++) { max_part_id = std::max(it->first, max_part_id);
max_part_id = std::max(it->first, max_part_id); part_ids.push_back(it->first);
part_ids.push_back(it->first); part_nodes.push_back(it->second);
part_nodes.push_back(it->second); }
} // When we construct subgraphs, we need to access both in-edges and
// When we construct subgraphs, we need to access both in-edges and out-edges. // out-edges. We need to make sure the in-CSR and out-CSR exist.
// We need to make sure the in-CSR and out-CSR exist. Otherwise, we'll // Otherwise, we'll try to construct in-CSR and out-CSR in openmp for
// try to construct in-CSR and out-CSR in openmp for loop, which will lead // loop, which will lead to some unexpected results.
// to some unexpected results. ugptr->GetInCSR();
ugptr->GetInCSR(); ugptr->GetOutCSR();
ugptr->GetOutCSR(); std::vector<std::shared_ptr<HaloHeteroSubgraph>> subgs(max_part_id + 1);
std::vector<std::shared_ptr<HaloHeteroSubgraph>> subgs(max_part_id + 1); int num_partitions = part_nodes.size();
int num_partitions = part_nodes.size(); runtime::parallel_for(0, num_partitions, [&](int b, int e) {
runtime::parallel_for(0, num_partitions, [&](int b, int e) { for (auto i = b; i < e; i++) {
for (auto i = b; i < e; i++) { auto nodes = aten::VecToIdArray(part_nodes[i]);
auto nodes = aten::VecToIdArray(part_nodes[i]); HaloHeteroSubgraph subg = GetSubgraphWithHalo(hgptr, nodes, num_hops);
HaloHeteroSubgraph subg = GetSubgraphWithHalo(hgptr, nodes, num_hops); std::shared_ptr<HaloHeteroSubgraph> subg_ptr(
std::shared_ptr<HaloHeteroSubgraph> subg_ptr( new HaloHeteroSubgraph(subg));
new HaloHeteroSubgraph(subg)); int part_id = part_ids[i];
int part_id = part_ids[i]; subgs[part_id] = subg_ptr;
subgs[part_id] = subg_ptr; }
});
List<HeteroSubgraphRef> ret_list;
for (size_t i = 0; i < subgs.size(); i++) {
ret_list.push_back(HeteroSubgraphRef(subgs[i]));
} }
*rv = ret_list;
}); });
List<HeteroSubgraphRef> ret_list;
for (size_t i = 0; i < subgs.size(); i++) {
ret_list.push_back(HeteroSubgraphRef(subgs[i]));
}
*rv = ret_list;
});
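The ugptr->GetInCSR() / ugptr->GetOutCSR() calls above are there because UnitGraph materializes its sparse formats lazily; if the first access happened inside runtime::parallel_for, several worker threads could race to build the same format. The following generic sketch (not DGL code; LazyGraph and its members are invented for illustration) shows the same idiom of forcing lazy state into existence before a parallel region:

#include <memory>
#include <vector>

// A stand-in for a structure that builds an expensive view on first access.
// The lazy getter below is NOT thread-safe, which is exactly why the code
// above forces construction before entering the parallel loop.
struct LazyGraph {
  std::unique_ptr<std::vector<int>> csr_;
  const std::vector<int>& GetCSR() {
    if (!csr_) csr_.reset(new std::vector<int>(1000, 0));  // lazy build
    return *csr_;
  }
};

void ProcessPartitions(LazyGraph* g, int num_parts) {
  g->GetCSR();  // materialize once, single-threaded, before the parallel region
#pragma omp parallel for
  for (int i = 0; i < num_parts; ++i) {
    const std::vector<int>& csr = g->GetCSR();  // read-only access is now safe
    (void)csr;  // per-partition work would go here
  }
}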
template <class IdType>
struct EdgeProperty {
  IdType eid;
  int64_t idx;

@@ -280,98 +283,101 @@ struct EdgeProperty {
// Reassign edge IDs so that all edges in a partition have contiguous edge IDs.
// The original edge IDs are returned.
DGL_REGISTER_GLOBAL("partition._CAPI_DGLReassignEdges_Hetero") DGL_REGISTER_GLOBAL("partition._CAPI_DGLReassignEdges_Hetero")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroGraphRef g = args[0]; HeteroGraphRef g = args[0];
auto hgptr = std::dynamic_pointer_cast<HeteroGraph>(g.sptr()); auto hgptr = std::dynamic_pointer_cast<HeteroGraph>(g.sptr());
CHECK(hgptr) << "Invalid HeteroGraph object"; CHECK(hgptr) << "Invalid HeteroGraph object";
CHECK_EQ(hgptr->relation_graphs().size(), 1) CHECK_EQ(hgptr->relation_graphs().size(), 1)
<< "Reorder only supports HomoGraph"; << "Reorder only supports HomoGraph";
auto ugptr = hgptr->relation_graphs()[0]; auto ugptr = hgptr->relation_graphs()[0];
IdArray etype = args[1]; IdArray etype = args[1];
IdArray part_id = args[2]; IdArray part_id = args[2];
bool is_incsr = args[3]; bool is_incsr = args[3];
auto csrmat = is_incsr ? ugptr->GetCSCMatrix(0) : ugptr->GetCSRMatrix(0); auto csrmat = is_incsr ? ugptr->GetCSCMatrix(0) : ugptr->GetCSRMatrix(0);
int64_t num_edges = csrmat.data->shape[0]; int64_t num_edges = csrmat.data->shape[0];
int64_t num_rows = csrmat.indptr->shape[0] - 1; int64_t num_rows = csrmat.indptr->shape[0] - 1;
IdArray new_data = IdArray new_data =
IdArray::Empty({num_edges}, csrmat.data->dtype, csrmat.data->ctx); IdArray::Empty({num_edges}, csrmat.data->dtype, csrmat.data->ctx);
// Return the original edge Ids. // Return the original edge Ids.
*rv = new_data; *rv = new_data;
// Generate new edge Ids. // Generate new edge Ids.
ATEN_ID_TYPE_SWITCH(new_data->dtype, IdType, { ATEN_ID_TYPE_SWITCH(new_data->dtype, IdType, {
CHECK(etype->dtype.bits == sizeof(IdType) * 8); CHECK(etype->dtype.bits == sizeof(IdType) * 8);
CHECK(part_id->dtype.bits == sizeof(IdType) * 8); CHECK(part_id->dtype.bits == sizeof(IdType) * 8);
const IdType *part_id_data = static_cast<IdType *>(part_id->data); const IdType *part_id_data = static_cast<IdType *>(part_id->data);
const IdType *etype_data = static_cast<IdType *>(etype->data); const IdType *etype_data = static_cast<IdType *>(etype->data);
const IdType *indptr_data = static_cast<IdType *>(csrmat.indptr->data); const IdType *indptr_data = static_cast<IdType *>(csrmat.indptr->data);
IdType *typed_data = static_cast<IdType *>(csrmat.data->data); IdType *typed_data = static_cast<IdType *>(csrmat.data->data);
IdType *typed_new_data = static_cast<IdType *>(new_data->data); IdType *typed_new_data = static_cast<IdType *>(new_data->data);
std::vector<EdgeProperty<IdType>> indexed_eids(num_edges); std::vector<EdgeProperty<IdType>> indexed_eids(num_edges);
for (int64_t i = 0; i < num_rows; i++) { for (int64_t i = 0; i < num_rows; i++) {
for (int64_t j = indptr_data[i]; j < indptr_data[i + 1]; j++) { for (int64_t j = indptr_data[i]; j < indptr_data[i + 1]; j++) {
indexed_eids[j].eid = typed_data[j]; indexed_eids[j].eid = typed_data[j];
indexed_eids[j].idx = j; indexed_eids[j].idx = j;
indexed_eids[j].part_id = part_id_data[i]; indexed_eids[j].part_id = part_id_data[i];
}
} }
} auto comp = [etype_data](
auto comp = [etype_data](const EdgeProperty<IdType> &a, const EdgeProperty<IdType> &b) { const EdgeProperty<IdType> &a,
if (a.part_id == b.part_id) { const EdgeProperty<IdType> &b) {
return etype_data[a.eid] < etype_data[b.eid]; if (a.part_id == b.part_id) {
} else { return etype_data[a.eid] < etype_data[b.eid];
return a.part_id < b.part_id; } else {
return a.part_id < b.part_id;
}
};
// We only need to sort the edges if the input graph has multiple
// relations. If it's a homogeneous grap, we'll just assign edge Ids
// based on its previous order.
if (etype->shape[0] > 0) {
std::sort(indexed_eids.begin(), indexed_eids.end(), comp);
} }
}; for (int64_t new_eid = 0; new_eid < num_edges; new_eid++) {
// We only need to sort the edges if the input graph has multiple relations. int64_t orig_idx = indexed_eids[new_eid].idx;
// If it's a homogeneous grap, we'll just assign edge Ids based on its previous order. typed_new_data[new_eid] = typed_data[orig_idx];
if (etype->shape[0] > 0) { typed_data[orig_idx] = new_eid;
std::sort(indexed_eids.begin(), indexed_eids.end(), comp); }
} });
for (int64_t new_eid = 0; new_eid < num_edges; new_eid++) { ugptr->InvalidateCSR();
int64_t orig_idx = indexed_eids[new_eid].idx; ugptr->InvalidateCOO();
typed_new_data[new_eid] = typed_data[orig_idx];
typed_data[orig_idx] = new_eid;
}
}); });
ugptr->InvalidateCSR();
ugptr->InvalidateCOO();
});
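The re-assignment above works in place on the CSR data array, but the underlying idea is simply a stable renumbering: order the edges by (partition, edge type), hand out new IDs in that order, and remember the original ID of each new slot. Below is a standalone sketch of that idea; the function name and argument layout are invented for illustration and differ from the CAPI, which reads partition and type information from the graph arrays directly.

#include <algorithm>
#include <cstdint>
#include <vector>

// For every edge slot i, given its original edge ID, its partition, and its
// type, produce a new numbering in which each partition owns a contiguous
// range of edge IDs (edges of the same type kept together inside a partition).
// Returns old_of_new: old_of_new[new_id] is the original edge ID assigned
// new_id, mirroring what the CAPI above returns; new_of_old is the inverse.
std::vector<int64_t> ReassignContiguous(
    const std::vector<int64_t>& orig_eid,
    const std::vector<int64_t>& part_of_edge,
    const std::vector<int64_t>& type_of_edge,
    std::vector<int64_t>* new_of_old) {
  const int64_t n = static_cast<int64_t>(orig_eid.size());
  std::vector<int64_t> order(n);
  for (int64_t i = 0; i < n; ++i) order[i] = i;
  // Sort edge slots by partition first, then by edge type within a partition.
  std::sort(order.begin(), order.end(), [&](int64_t a, int64_t b) {
    if (part_of_edge[a] != part_of_edge[b])
      return part_of_edge[a] < part_of_edge[b];
    return type_of_edge[a] < type_of_edge[b];
  });
  std::vector<int64_t> old_of_new(n);
  new_of_old->assign(n, -1);
  for (int64_t new_id = 0; new_id < n; ++new_id) {
    const int64_t slot = order[new_id];
    old_of_new[new_id] = orig_eid[slot];
    (*new_of_old)[slot] = new_id;
  }
  return old_of_new;
}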
DGL_REGISTER_GLOBAL("partition._CAPI_GetHaloSubgraphInnerNodes_Hetero") DGL_REGISTER_GLOBAL("partition._CAPI_GetHaloSubgraphInnerNodes_Hetero")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroSubgraphRef g = args[0]; HeteroSubgraphRef g = args[0];
auto gptr = std::dynamic_pointer_cast<HaloHeteroSubgraph>(g.sptr()); auto gptr = std::dynamic_pointer_cast<HaloHeteroSubgraph>(g.sptr());
CHECK(gptr) << "The input graph has to be HaloHeteroSubgraph"; CHECK(gptr) << "The input graph has to be HaloHeteroSubgraph";
*rv = gptr->inner_nodes[0]; *rv = gptr->inner_nodes[0];
}); });
DGL_REGISTER_GLOBAL("partition._CAPI_DGLMakeSymmetric_Hetero") DGL_REGISTER_GLOBAL("partition._CAPI_DGLMakeSymmetric_Hetero")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroGraphRef g = args[0]; HeteroGraphRef g = args[0];
auto hgptr = std::dynamic_pointer_cast<HeteroGraph>(g.sptr()); auto hgptr = std::dynamic_pointer_cast<HeteroGraph>(g.sptr());
CHECK(hgptr) << "Invalid HeteroGraph object"; CHECK(hgptr) << "Invalid HeteroGraph object";
CHECK_EQ(hgptr->relation_graphs().size(), 1) CHECK_EQ(hgptr->relation_graphs().size(), 1)
<< "Metis partition only supports homogeneous graph"; << "Metis partition only supports homogeneous graph";
auto ugptr = hgptr->relation_graphs()[0]; auto ugptr = hgptr->relation_graphs()[0];
#if !defined(_WIN32) #if !defined(_WIN32)
// TODO(zhengda) should we get whatever CSR exists in the graph. // TODO(zhengda) should we get whatever CSR exists in the graph.
gk_csr_t *gk_csr = Convert2GKCsr(ugptr->GetCSCMatrix(0), true); gk_csr_t *gk_csr = Convert2GKCsr(ugptr->GetCSCMatrix(0), true);
gk_csr_t *sym_gk_csr = gk_csr_MakeSymmetric(gk_csr, GK_CSR_SYM_SUM); gk_csr_t *sym_gk_csr = gk_csr_MakeSymmetric(gk_csr, GK_CSR_SYM_SUM);
auto mat = Convert2DGLCsr(sym_gk_csr, true); auto mat = Convert2DGLCsr(sym_gk_csr, true);
gk_csr_Free(&gk_csr); gk_csr_Free(&gk_csr);
gk_csr_Free(&sym_gk_csr); gk_csr_Free(&sym_gk_csr);
auto new_ugptr = UnitGraph::CreateFromCSC(ugptr->NumVertexTypes(), mat, auto new_ugptr = UnitGraph::CreateFromCSC(
ugptr->GetAllowedFormats()); ugptr->NumVertexTypes(), mat, ugptr->GetAllowedFormats());
std::vector<HeteroGraphPtr> rel_graphs = {new_ugptr}; std::vector<HeteroGraphPtr> rel_graphs = {new_ugptr};
*rv = HeteroGraphRef(std::make_shared<HeteroGraph>( *rv = HeteroGraphRef(std::make_shared<HeteroGraph>(
hgptr->meta_graph(), rel_graphs, hgptr->NumVerticesPerType())); hgptr->meta_graph(), rel_graphs, hgptr->NumVerticesPerType()));
#else #else
LOG(FATAL) << "The fast version of making symmetric graph is not supported in Windows."; LOG(FATAL) << "The fast version of making symmetric graph is not "
"supported in Windows.";
#endif // !defined(_WIN32) #endif // !defined(_WIN32)
}); });
}  // namespace transform
}  // namespace dgl
@@ -4,15 +4,16 @@
 * \brief Remove edges.
 */

#include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/registry.h>
#include <dgl/transform.h>

#include <tuple>
#include <utility>
#include <vector>

namespace dgl {
@@ -21,8 +22,8 @@ using namespace dgl::aten;
namespace transform {

std::pair<HeteroGraphPtr, std::vector<IdArray>> RemoveEdges(
    const HeteroGraphPtr graph, const std::vector<IdArray> &eids) {
  std::vector<IdArray> induced_eids;
  std::vector<HeteroGraphPtr> rel_graphs;
  const int64_t num_etypes = graph->NumEdgeTypes();
@@ -40,23 +41,30 @@ RemoveEdges(const HeteroGraphPtr graph, const std::vector<IdArray> &eids) {
      const COOMatrix &coo = graph->GetCOOMatrix(etype);
      const COOMatrix &result = COORemove(coo, eids[etype]);
      new_rel_graph = CreateFromCOO(
          num_ntypes_rel, result.num_rows, result.num_cols, result.row,
          result.col);
      induced_eids_rel = result.data;
    } else if (fmt == SparseFormat::kCSR) {
      const CSRMatrix &csr = graph->GetCSRMatrix(etype);
      const CSRMatrix &result = CSRRemove(csr, eids[etype]);
      new_rel_graph = CreateFromCSR(
          num_ntypes_rel, result.num_rows, result.num_cols, result.indptr,
          result.indices,
          // TODO(BarclayII): make CSR support null eid array
          Range(
              0, result.indices->shape[0], result.indices->dtype.bits,
              result.indices->ctx));
      induced_eids_rel = result.data;
    } else if (fmt == SparseFormat::kCSC) {
      const CSRMatrix &csc = graph->GetCSCMatrix(etype);
      const CSRMatrix &result = CSRRemove(csc, eids[etype]);
      new_rel_graph = CreateFromCSC(
          num_ntypes_rel, result.num_rows, result.num_cols, result.indptr,
          result.indices,
          // TODO(BarclayII): make CSR support null eid array
          Range(
              0, result.indices->shape[0], result.indices->dtype.bits,
              result.indices->ctx));
      induced_eids_rel = result.data;
    }

@@ -70,24 +78,24 @@ RemoveEdges(const HeteroGraphPtr graph, const std::vector<IdArray> &eids) {
  }
DGL_REGISTER_GLOBAL("transform._CAPI_DGLRemoveEdges") DGL_REGISTER_GLOBAL("transform._CAPI_DGLRemoveEdges")
.set_body([] (DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
const HeteroGraphRef graph_ref = args[0]; const HeteroGraphRef graph_ref = args[0];
const std::vector<IdArray> &eids = ListValueToVector<IdArray>(args[1]); const std::vector<IdArray> &eids = ListValueToVector<IdArray>(args[1]);
HeteroGraphPtr new_graph; HeteroGraphPtr new_graph;
std::vector<IdArray> induced_eids; std::vector<IdArray> induced_eids;
std::tie(new_graph, induced_eids) = RemoveEdges(graph_ref.sptr(), eids); std::tie(new_graph, induced_eids) = RemoveEdges(graph_ref.sptr(), eids);
List<Value> induced_eids_ref; List<Value> induced_eids_ref;
for (IdArray &array : induced_eids) for (IdArray &array : induced_eids)
induced_eids_ref.push_back(Value(MakeValue(array))); induced_eids_ref.push_back(Value(MakeValue(array)));
List<ObjectRef> ret; List<ObjectRef> ret;
ret.push_back(HeteroGraphRef(new_graph)); ret.push_back(HeteroGraphRef(new_graph));
ret.push_back(induced_eids_ref); ret.push_back(induced_eids_ref);
*rv = ret; *rv = ret;
}); });
}; // namespace transform }; // namespace transform
......
@@ -19,16 +19,18 @@
#include "to_bipartite.h"

#include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <dgl/immutable_graph.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/registry.h>
#include <dgl/transform.h>

#include <tuple>
#include <utility>
#include <vector>

#include "../../array/cpu/array_utils.h"

namespace dgl {
@@ -42,11 +44,11 @@ namespace {
// Since partial specialization is not allowed for functions, use this as an
// intermediate for ToBlock where XPU = kDGLCPU.
template <typename IdType>
std::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlockCPU(
    HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,
    bool include_rhs_in_lhs, std::vector<IdArray> *const lhs_nodes_ptr) {
  std::vector<IdArray> &lhs_nodes = *lhs_nodes_ptr;
  const bool generate_lhs_nodes = lhs_nodes.empty();
  const int64_t num_etypes = graph->NumEdgeTypes();
@@ -54,28 +56,29 @@ ToBlockCPU(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,
  std::vector<EdgeArray> edge_arrays(num_etypes);
  CHECK(rhs_nodes.size() == static_cast<size_t>(num_ntypes))
      << "rhs_nodes not given for every node type";
  const std::vector<IdHashMap<IdType>> rhs_node_mappings(
      rhs_nodes.begin(), rhs_nodes.end());
  std::vector<IdHashMap<IdType>> lhs_node_mappings;
  if (generate_lhs_nodes) {
    // build lhs_node_mappings -- if we don't have them already
    if (include_rhs_in_lhs)
      lhs_node_mappings = rhs_node_mappings;  // copy
    else
      lhs_node_mappings.resize(num_ntypes);
  } else {
    lhs_node_mappings =
        std::vector<IdHashMap<IdType>>(lhs_nodes.begin(), lhs_nodes.end());
  }
  for (int64_t etype = 0; etype < num_etypes; ++etype) {
    const auto src_dst_types = graph->GetEndpointTypes(etype);
    const dgl_type_t srctype = src_dst_types.first;
    const dgl_type_t dsttype = src_dst_types.second;
    if (!aten::IsNullArray(rhs_nodes[dsttype])) {
      const EdgeArray &edges = graph->Edges(etype);
      if (generate_lhs_nodes) {
        lhs_node_mappings[srctype].Update(edges.src);
      }
@@ -89,8 +92,8 @@ ToBlockCPU(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,
  const auto meta_graph = graph->meta_graph();
  const EdgeArray etypes = meta_graph->Edges("eid");
  const IdArray new_dst = Add(etypes.dst, num_ntypes);
  const auto new_meta_graph =
      ImmutableGraph::CreateFromCOO(num_ntypes * 2, etypes.src, new_dst);
  for (int64_t ntype = 0; ntype < num_ntypes; ++ntype)
    num_nodes_per_type.push_back(lhs_node_mappings[ntype].Size());
@@ -108,8 +111,8 @@ ToBlockCPU(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,
    if (rhs_map.Size() == 0) {
      // No rhs nodes are given for this edge type. Create an empty graph.
      rel_graphs.push_back(CreateFromCOO(
          2, lhs_map.Size(), rhs_map.Size(), aten::NullArray(),
          aten::NullArray()));
      induced_edges.push_back(aten::NullArray());
    } else {
      IdArray new_src = lhs_map.Map(edge_arrays[etype].src, -1);
@@ -117,22 +120,22 @@ ToBlockCPU(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,
      // Check whether there are unmapped IDs and raise error.
      for (int64_t i = 0; i < new_dst->shape[0]; ++i)
        CHECK_NE(new_dst.Ptr<IdType>()[i], -1)
            << "Node " << edge_arrays[etype].dst.Ptr<IdType>()[i]
            << " does not exist"
            << " in `rhs_nodes`. Argument `rhs_nodes` must contain all the edge"
            << " destination nodes.";
      rel_graphs.push_back(
          CreateFromCOO(2, lhs_map.Size(), rhs_map.Size(), new_src, new_dst));
      induced_edges.push_back(edge_arrays[etype].id);
    }
  }
  const HeteroGraphPtr new_graph =
      CreateHeteroGraph(new_meta_graph, rel_graphs, num_nodes_per_type);
  if (generate_lhs_nodes) {
    CHECK_EQ(lhs_nodes.size(), 0) << "InternalError: lhs_nodes should be empty "
                                     "when generating it.";
    for (const IdHashMap<IdType> &lhs_map : lhs_node_mappings)
      lhs_nodes.push_back(lhs_map.Values());
  }

@@ -141,87 +144,83 @@ ToBlockCPU(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,
}  // namespace
template <>
std::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlock<kDGLCPU, int32_t>(
    HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,
    bool include_rhs_in_lhs, std::vector<IdArray> *const lhs_nodes) {
  return ToBlockCPU<int32_t>(graph, rhs_nodes, include_rhs_in_lhs, lhs_nodes);
}

template <>
std::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlock<kDGLCPU, int64_t>(
    HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,
    bool include_rhs_in_lhs, std::vector<IdArray> *const lhs_nodes) {
  return ToBlockCPU<int64_t>(graph, rhs_nodes, include_rhs_in_lhs, lhs_nodes);
}
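The ToBlockCPU helper above exists because C++ allows full but not partial specialization of function templates: ToBlock is parameterized on both a device enum and an ID type, so each full specialization simply forwards to a helper templated on the ID type alone. A minimal, self-contained illustration of the same pattern follows; Sum, SumCPU, and the DeviceType enum are invented stand-ins, not DGL code.

#include <cstdint>
#include <iostream>

// Device tags standing in for kDGLCPU / kDGLCUDA.
enum DeviceType { kCPU, kGPU };

// Helper templated only on the ID type: this is the role ToBlockCPU plays.
template <typename IdType>
int64_t SumCPU(const IdType* data, int64_t n) {
  int64_t s = 0;
  for (int64_t i = 0; i < n; ++i) s += data[i];
  return s;
}

// Primary template: declared but only defined through full specializations,
// because partial specialization of function templates is not allowed.
template <DeviceType XPU, typename IdType>
int64_t Sum(const IdType* data, int64_t n);

template <>
int64_t Sum<kCPU, int32_t>(const int32_t* data, int64_t n) {
  return SumCPU<int32_t>(data, n);
}

template <>
int64_t Sum<kCPU, int64_t>(const int64_t* data, int64_t n) {
  return SumCPU<int64_t>(data, n);
}

int main() {
  int32_t a[] = {1, 2, 3};
  std::cout << Sum<kCPU, int32_t>(a, 3) << std::endl;  // prints 6
  return 0;
}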
#ifdef DGL_USE_CUDA

// Forward declaration of GPU ToBlock implementations - actual implementation
// is in ./cuda/cuda_to_block.cu
// This is to get around the broken name mangling in VS2019 CL 16.5.5 +
// CUDA 11.3 which complains that the two template specializations have the
// same signature.
std::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlockGPU32(
    HeteroGraphPtr, const std::vector<IdArray> &, bool,
    std::vector<IdArray> *const);
std::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlockGPU64(
    HeteroGraphPtr, const std::vector<IdArray> &, bool,
    std::vector<IdArray> *const);

template <>
std::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlock<kDGLCUDA, int32_t>(
    HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,
    bool include_rhs_in_lhs, std::vector<IdArray> *const lhs_nodes) {
  return ToBlockGPU32(graph, rhs_nodes, include_rhs_in_lhs, lhs_nodes);
}

template <>
std::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlock<kDGLCUDA, int64_t>(
    HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,
    bool include_rhs_in_lhs, std::vector<IdArray> *const lhs_nodes) {
  return ToBlockGPU64(graph, rhs_nodes, include_rhs_in_lhs, lhs_nodes);
}

#endif  // DGL_USE_CUDA
DGL_REGISTER_GLOBAL("transform._CAPI_DGLToBlock") DGL_REGISTER_GLOBAL("transform._CAPI_DGLToBlock")
.set_body([] (DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
const HeteroGraphRef graph_ref = args[0]; const HeteroGraphRef graph_ref = args[0];
const std::vector<IdArray> &rhs_nodes = ListValueToVector<IdArray>(args[1]); const std::vector<IdArray> &rhs_nodes =
const bool include_rhs_in_lhs = args[2]; ListValueToVector<IdArray>(args[1]);
std::vector<IdArray> lhs_nodes = ListValueToVector<IdArray>(args[3]); const bool include_rhs_in_lhs = args[2];
std::vector<IdArray> lhs_nodes = ListValueToVector<IdArray>(args[3]);
HeteroGraphPtr new_graph;
std::vector<IdArray> induced_edges; HeteroGraphPtr new_graph;
std::vector<IdArray> induced_edges;
ATEN_XPU_SWITCH_CUDA(graph_ref->Context().device_type, XPU, "ToBlock", {
ATEN_ID_TYPE_SWITCH(graph_ref->DataType(), IdType, { ATEN_XPU_SWITCH_CUDA(graph_ref->Context().device_type, XPU, "ToBlock", {
std::tie(new_graph, induced_edges) = ToBlock<XPU, IdType>( ATEN_ID_TYPE_SWITCH(graph_ref->DataType(), IdType, {
graph_ref.sptr(), rhs_nodes, include_rhs_in_lhs, std::tie(new_graph, induced_edges) = ToBlock<XPU, IdType>(
&lhs_nodes); graph_ref.sptr(), rhs_nodes, include_rhs_in_lhs, &lhs_nodes);
});
}); });
});
List<Value> lhs_nodes_ref; List<Value> lhs_nodes_ref;
for (IdArray &array : lhs_nodes) for (IdArray &array : lhs_nodes)
lhs_nodes_ref.push_back(Value(MakeValue(array))); lhs_nodes_ref.push_back(Value(MakeValue(array)));
List<Value> induced_edges_ref; List<Value> induced_edges_ref;
for (IdArray &array : induced_edges) for (IdArray &array : induced_edges)
induced_edges_ref.push_back(Value(MakeValue(array))); induced_edges_ref.push_back(Value(MakeValue(array)));
List<ObjectRef> ret; List<ObjectRef> ret;
ret.push_back(HeteroGraphRef(new_graph)); ret.push_back(HeteroGraphRef(new_graph));
ret.push_back(lhs_nodes_ref); ret.push_back(lhs_nodes_ref);
ret.push_back(induced_edges_ref); ret.push_back(induced_edges_ref);
*rv = ret; *rv = ret;
}); });
}; // namespace transform }; // namespace transform
......
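Most of the work in ToBlockCPU above is ID compaction: IdHashMap assigns each distinct global node ID a contiguous local ID (Update) and later translates edge endpoints into that local space (Map), which is what turns the sampled frontier into a small bipartite block. Below is a toy sketch of that compaction idea; CompactIdMap is invented for illustration and is not DGL's IdHashMap implementation.

#include <cstdint>
#include <unordered_map>
#include <vector>

// Assign each distinct global node ID a compact local ID in order of first
// appearance (Update), then translate endpoints into that local space (Map).
struct CompactIdMap {
  std::unordered_map<int64_t, int64_t> to_local;
  std::vector<int64_t> to_global;  // local -> global, i.e. the induced nodes

  void Update(const std::vector<int64_t>& ids) {
    for (int64_t id : ids) {
      if (to_local.emplace(id, static_cast<int64_t>(to_global.size())).second)
        to_global.push_back(id);
    }
  }

  // Returns the local ID, or `missing` (-1 in the code above) if unseen.
  int64_t Map(int64_t id, int64_t missing) const {
    auto it = to_local.find(id);
    return it == to_local.end() ? missing : it->second;
  }
};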
@@ -44,10 +44,10 @@ namespace transform {
 *
 * @return The block and the induced edges.
 */
template <DGLDeviceType XPU, typename IdType>
std::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlock(
    HeteroGraphPtr graph, const std::vector<IdArray>& rhs_nodes,
    bool include_rhs_in_lhs, std::vector<IdArray>* lhs_nodes);

}  // namespace transform
}  // namespace dgl

...