Unverified Commit 7b3a7b14 authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by GitHub
Browse files

[Optimization] Optimize CompactGraph (#1328)



* compact

* prealloc memory for hashtable

* Fix

* upd

* Reduce memory

* upd
Co-authored-by: Ubuntu <ubuntu@ip-172-31-51-214.ec2.internal>
parent 20e1bb45
...@@ -40,6 +40,10 @@ class IdHashMap { ...@@ -40,6 +40,10 @@ class IdHashMap {
// copy ctor // copy ctor
IdHashMap(const IdHashMap &other) = default; IdHashMap(const IdHashMap &other) = default;
// Pre-size the underlying table for `size` entries by forwarding the
// request to oldv2newv_.reserve(). Intended to be called before Update()
// so that bulk insertion does not trigger repeated rehashing.
// NOTE(review): `size` is signed and passed through unchanged; callers are
// presumably expected to supply a non-negative count -- confirm.
void Reserve(const int64_t size) { oldv2newv_.reserve(size); }
// Update the hashmap with given id array. // Update the hashmap with given id array.
// The id array could contain duplicates. // The id array could contain duplicates.
void Update(IdArray ids) { void Update(IdArray ids) {
......
...@@ -36,13 +36,38 @@ CompactGraphs( ...@@ -36,13 +36,38 @@ CompactGraphs(
std::vector<aten::IdHashMap<IdType>> hashmaps(num_ntypes); std::vector<aten::IdHashMap<IdType>> hashmaps(num_ntypes);
std::vector<std::vector<EdgeArray>> all_edges(graphs.size()); // all_edges[i][etype] std::vector<std::vector<EdgeArray>> all_edges(graphs.size()); // all_edges[i][etype]
for (size_t i = 0; i < always_preserve.size(); ++i) std::vector<int64_t> max_vertex_cnt(num_ntypes, 0);
for (size_t i = 0; i < graphs.size(); ++i) {
const HeteroGraphPtr curr_graph = graphs[i];
const int64_t num_etypes = curr_graph->NumEdgeTypes();
for (IdType etype = 0; etype < num_etypes; ++etype) {
IdType srctype, dsttype;
std::tie(srctype, dsttype) = curr_graph->GetEndpointTypes(etype);
const int64_t n_edges = curr_graph->NumEdges(etype);
max_vertex_cnt[srctype] += n_edges;
max_vertex_cnt[dsttype] += n_edges;
}
}
// Reserve space for the hash maps ahead of time to avoid rehashing
for (size_t i = 0; i < num_ntypes; ++i) {
if (i < always_preserve.size())
hashmaps[i].Reserve(always_preserve[i]->shape[0] + max_vertex_cnt[i]);
else
hashmaps[i].Reserve(max_vertex_cnt[i]);
}
for (size_t i = 0; i < always_preserve.size(); ++i) {
hashmaps[i].Update(always_preserve[i]); hashmaps[i].Update(always_preserve[i]);
}
for (size_t i = 0; i < graphs.size(); ++i) { for (size_t i = 0; i < graphs.size(); ++i) {
const HeteroGraphPtr curr_graph = graphs[i]; const HeteroGraphPtr curr_graph = graphs[i];
const int64_t num_etypes = curr_graph->NumEdgeTypes(); const int64_t num_etypes = curr_graph->NumEdgeTypes();
all_edges[i].reserve(num_etypes);
for (IdType etype = 0; etype < num_etypes; ++etype) { for (IdType etype = 0; etype < num_etypes; ++etype) {
IdType srctype, dsttype; IdType srctype, dsttype;
std::tie(srctype, dsttype) = curr_graph->GetEndpointTypes(etype); std::tie(srctype, dsttype) = curr_graph->GetEndpointTypes(etype);
......
...@@ -552,7 +552,6 @@ if __name__ == '__main__': ...@@ -552,7 +552,6 @@ if __name__ == '__main__':
test_laplacian_lambda_max() test_laplacian_lambda_max()
test_remove_self_loop() test_remove_self_loop()
test_add_self_loop() test_add_self_loop()
test_partition()
test_compact() test_compact()
test_to_simple() test_to_simple()
test_in_subgraph() test_in_subgraph()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment