Unverified Commit 5854ef5e authored by peizhou001, committed by GitHub
Browse files

[Enhancement] Speed up ToBlockCPU with concurrent id hash map (#5297)

parent cce31e9a
......@@ -34,7 +34,7 @@
#include <utility>
#include <vector>
#include "../../array/cpu/array_utils.h"
#include "../../array/cpu/concurrent_id_hash_map.h"
namespace dgl {
......@@ -45,104 +45,89 @@ namespace transform {
namespace {
// NOTE(review): This region is the rendered unified diff of commit 5854ef5e
// ("Speed up ToBlockCPU with concurrent id hash map") with the removed (old)
// and added (new) lines interleaved and the +/- gutter markers lost -- the
// "......@@" hunk markers above and below this region are still visible.
// The text is therefore NOT compilable as-is.  The old side is the
// self-contained, IdHashMap-based ToBlockCPU; the new side is the
// ConcurrentIdHashMap-based CPUIdsMapper functor.  Recover each side from
// version control before editing; do not attempt to "fix" this mix in place.
// Since partial specialization is not allowed for functions, use this as an
// intermediate for ToBlock where XPU = kDGLCPU.
// (Old/removed side: signature of the former self-contained ToBlockCPU.)
template <typename IdType>
std::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlockCPU(
HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,
bool include_rhs_in_lhs, std::vector<IdArray> *const lhs_nodes_ptr) {
// (New/added side: functor handed to ProcessToBlock to map global node IDs
// to block-local IDs on CPU.  Upstream this struct is templated on IdType;
// the "template" line above appears to have been attributed to the old side
// by the diff renderer.)
struct CPUIdsMapper {
std::tuple<std::vector<IdArray>, std::vector<IdArray>> operator()(
const HeteroGraphPtr &graph, bool include_rhs_in_lhs, int64_t num_ntypes,
const DGLContext &ctx, const std::vector<int64_t> &max_nodes_per_type,
const std::vector<EdgeArray> &edge_arrays,
const std::vector<IdArray> &src_nodes,
const std::vector<IdArray> &rhs_nodes,
std::vector<IdArray> *const lhs_nodes_ptr,
std::vector<int64_t> *const num_nodes_per_type_ptr) {
// Outputs are written through the two pointer parameters.
std::vector<IdArray> &lhs_nodes = *lhs_nodes_ptr;
// lhs_nodes empty => caller wants us to generate the lhs node lists.
const bool generate_lhs_nodes = lhs_nodes.empty();
// (Old side: the former function computed these itself; the new functor
// receives edge_arrays / num_ntypes as parameters above.)
const int64_t num_etypes = graph->NumEdgeTypes();
const int64_t num_ntypes = graph->NumVertexTypes();
std::vector<EdgeArray> edge_arrays(num_etypes);
CHECK(rhs_nodes.size() == static_cast<size_t>(num_ntypes))
<< "rhs_nodes not given for every node type";
// (Old side: sequential IdHashMap-based ID mappings, one per node type.)
const std::vector<IdHashMap<IdType>> rhs_node_mappings(
rhs_nodes.begin(), rhs_nodes.end());
std::vector<IdHashMap<IdType>> lhs_node_mappings;
std::vector<int64_t> &num_nodes_per_type = *num_nodes_per_type_ptr;
// NOTE(review): duplicate declaration of generate_lhs_nodes (also declared
// above) -- an artifact of the two diff sides being merged, not real code.
const bool generate_lhs_nodes = lhs_nodes.empty();
if (generate_lhs_nodes) {
// build lhs_node_mappings -- if we don't have them already
if (include_rhs_in_lhs)
lhs_node_mappings = rhs_node_mappings; // copy
else
lhs_node_mappings.resize(num_ntypes);
} else {
lhs_node_mappings =
std::vector<IdHashMap<IdType>>(lhs_nodes.begin(), lhs_nodes.end());
lhs_nodes.reserve(num_ntypes);
}
// (Old side: per-edge-type Update() of the sequential lhs mappings.)
for (int64_t etype = 0; etype < num_etypes; ++etype) {
const auto src_dst_types = graph->GetEndpointTypes(etype);
const dgl_type_t srctype = src_dst_types.first;
const dgl_type_t dsttype = src_dst_types.second;
if (!aten::IsNullArray(rhs_nodes[dsttype])) {
const EdgeArray &edges = graph->Edges(etype);
if (generate_lhs_nodes) {
lhs_node_mappings[srctype].Update(edges.src);
// (New side: concurrent hash maps, built in one pass per node type; the
// seed nodes are inserted first when rhs is included in lhs.)
std::vector<ConcurrentIdHashMap<IdType>> lhs_nodes_map(num_ntypes);
std::vector<ConcurrentIdHashMap<IdType>> rhs_nodes_map(num_ntypes);
for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
IdArray unique_ids =
aten::NullArray(DGLDataTypeTraits<IdType>::dtype, ctx);
if (!aten::IsNullArray(src_nodes[ntype])) {
auto num_seeds = include_rhs_in_lhs ? rhs_nodes[ntype]->shape[0] : 0;
unique_ids = lhs_nodes_map[ntype].Init(src_nodes[ntype], num_seeds);
}
edge_arrays[etype] = edges;
if (generate_lhs_nodes) {
num_nodes_per_type[ntype] = unique_ids->shape[0];
lhs_nodes.emplace_back(unique_ids);
}
}
// NOTE(review): num_nodes_per_type redeclared here (old side, local
// vector) -- shadows/conflicts with the reference bound above (new side).
std::vector<int64_t> num_nodes_per_type;
num_nodes_per_type.reserve(2 * num_ntypes);
// (Old side: building the block's 2*num_ntypes meta-graph.)
const auto meta_graph = graph->meta_graph();
const EdgeArray etypes = meta_graph->Edges("eid");
const IdArray new_dst = Add(etypes.dst, num_ntypes);
const auto new_meta_graph =
ImmutableGraph::CreateFromCOO(num_ntypes * 2, etypes.src, new_dst);
for (int64_t ntype = 0; ntype < num_ntypes; ++ntype)
num_nodes_per_type.push_back(lhs_node_mappings[ntype].Size());
for (int64_t ntype = 0; ntype < num_ntypes; ++ntype)
num_nodes_per_type.push_back(rhs_node_mappings[ntype].Size());
// Skip rhs mapping construction to save efforts when rhs is already
// contained in lhs.
if (!include_rhs_in_lhs) {
for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
if (!aten::IsNullArray(rhs_nodes[ntype])) {
rhs_nodes_map[ntype].Init(
rhs_nodes[ntype], rhs_nodes[ntype]->shape[0]);
}
}
}
// (Old side: relation graphs and induced edge IDs for the block.)
std::vector<HeteroGraphPtr> rel_graphs;
std::vector<IdArray> induced_edges;
// Map node numberings from global to local, and build pointer for CSR.
std::vector<IdArray> new_lhs;
std::vector<IdArray> new_rhs;
new_lhs.reserve(edge_arrays.size());
new_rhs.reserve(edge_arrays.size());
// NOTE(review): num_etypes redeclared here (new side) -- conflicts with the
// declaration near the top of this merged region (old side).
const int64_t num_etypes = static_cast<int64_t>(edge_arrays.size());
for (int64_t etype = 0; etype < num_etypes; ++etype) {
const EdgeArray &edges = edge_arrays[etype];
if (edges.id.defined() && !aten::IsNullArray(edges.src)) {
const auto src_dst_types = graph->GetEndpointTypes(etype);
const dgl_type_t srctype = src_dst_types.first;
const dgl_type_t dsttype = src_dst_types.second;
const IdHashMap<IdType> &lhs_map = lhs_node_mappings[srctype];
const IdHashMap<IdType> &rhs_map = rhs_node_mappings[dsttype];
if (rhs_map.Size() == 0) {
// No rhs nodes are given for this edge type. Create an empty graph.
rel_graphs.push_back(CreateFromCOO(
2, lhs_map.Size(), rhs_map.Size(), aten::NullArray(),
aten::NullArray()));
induced_edges.push_back(aten::NullArray());
// (New side: map edge endpoints through the concurrent hash maps; dst
// goes through the lhs map when rhs was folded into lhs.)
const int src_type = src_dst_types.first;
const int dst_type = src_dst_types.second;
new_lhs.emplace_back(lhs_nodes_map[src_type].MapIds(edges.src));
if (include_rhs_in_lhs) {
new_rhs.emplace_back(lhs_nodes_map[dst_type].MapIds(edges.dst));
} else {
IdArray new_src = lhs_map.Map(edge_arrays[etype].src, -1);
IdArray new_dst = rhs_map.Map(edge_arrays[etype].dst, -1);
// Check whether there are unmapped IDs and raise error.
for (int64_t i = 0; i < new_dst->shape[0]; ++i)
CHECK_NE(new_dst.Ptr<IdType>()[i], -1)
<< "Node " << edge_arrays[etype].dst.Ptr<IdType>()[i]
<< " does not exist"
<< " in `rhs_nodes`. Argument `rhs_nodes` must contain all the edge"
<< " destination nodes.";
rel_graphs.push_back(
CreateFromCOO(2, lhs_map.Size(), rhs_map.Size(), new_src, new_dst));
induced_edges.push_back(edge_arrays[etype].id);
new_rhs.emplace_back(rhs_nodes_map[dst_type].MapIds(edges.dst));
}
} else {
// Null edge array for this edge type: emit null placeholder mappings.
new_lhs.emplace_back(
aten::NullArray(DGLDataTypeTraits<IdType>::dtype, ctx));
new_rhs.emplace_back(
aten::NullArray(DGLDataTypeTraits<IdType>::dtype, ctx));
}
const HeteroGraphPtr new_graph =
CreateHeteroGraph(new_meta_graph, rel_graphs, num_nodes_per_type);
if (generate_lhs_nodes) {
// NOTE(review): "InteralError" is a typo for "InternalError" in this
// message (string literal left byte-identical here).
CHECK_EQ(lhs_nodes.size(), 0) << "InteralError: lhs_nodes should be empty "
"when generating it.";
for (const IdHashMap<IdType> &lhs_map : lhs_node_mappings)
lhs_nodes.push_back(lhs_map.Values());
}
// NOTE(review): two consecutive return statements -- the first is the old
// side's return, the second is the new side's; one per diff half.
return std::make_tuple(new_graph, induced_edges);
return std::tuple<std::vector<IdArray>, std::vector<IdArray>>(
std::move(new_lhs), std::move(new_rhs));
}
};
// Function templates cannot be partially specialized, so this free function
// acts as the intermediate ToBlock entry point for XPU = kDGLCPU.
template <typename IdType>
std::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlockCPU(
    HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,
    bool include_rhs_in_lhs, std::vector<IdArray> *const lhs_nodes_ptr) {
  // All of the real work lives in the shared ProcessToBlock driver; the
  // CPU-specific piece is the ID-mapping functor plugged into it.
  CPUIdsMapper<IdType> cpu_mapper;
  return dgl::transform::ProcessToBlock<IdType>(
      graph, rhs_nodes, include_rhs_in_lhs, lhs_nodes_ptr,
      std::move(cpu_mapper));
}
} // namespace
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment