"docs/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "3bd5a9b6d11a74df6035ecdbdf5f71088eb2e901"
Unverified Commit 03024f95 authored by Xin Yao's avatar Xin Yao Committed by GitHub
Browse files

[Peformance] Remove unnecessary induced vertices in EdgeSubgraph (#3978)

* remove unnecessary induced vertices in EdgeSubgraph

* add unit test
parent 4c147814
...@@ -46,7 +46,6 @@ void SpMMCsr(const std::string& op, const std::string& reduce, ...@@ -46,7 +46,6 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
NDArray efeat, NDArray efeat,
NDArray out, NDArray out,
std::vector<NDArray> out_aux) { std::vector<NDArray> out_aux) {
int64_t feat_len = bcast.out_len;
bool is_scalar_efeat = efeat.NumElements() == csr.indices->shape[0]; bool is_scalar_efeat = efeat.NumElements() == csr.indices->shape[0];
bool use_efeat = op != "copy_lhs"; bool use_efeat = op != "copy_lhs";
......
...@@ -117,13 +117,11 @@ class DeviceNodeMap { ...@@ -117,13 +117,11 @@ class DeviceNodeMap {
cudaStream_t stream) : cudaStream_t stream) :
num_types_(num_nodes.size()), num_types_(num_nodes.size()),
rhs_offset_(offset), rhs_offset_(offset),
workspaces_(),
hash_tables_(), hash_tables_(),
ctx_(ctx) { ctx_(ctx) {
auto device = runtime::DeviceAPI::Get(ctx); auto device = runtime::DeviceAPI::Get(ctx);
hash_tables_.reserve(num_types_); hash_tables_.reserve(num_types_);
workspaces_.reserve(num_types_);
for (int64_t i = 0; i < num_types_; ++i) { for (int64_t i = 0; i < num_types_; ++i) {
hash_tables_.emplace_back( hash_tables_.emplace_back(
new OrderedHashTable<IdType>( new OrderedHashTable<IdType>(
...@@ -170,7 +168,6 @@ class DeviceNodeMap { ...@@ -170,7 +168,6 @@ class DeviceNodeMap {
private: private:
int64_t num_types_; int64_t num_types_;
size_t rhs_offset_; size_t rhs_offset_;
std::vector<void*> workspaces_;
std::vector<std::unique_ptr<OrderedHashTable<IdType>>> hash_tables_; std::vector<std::unique_ptr<OrderedHashTable<IdType>>> hash_tables_;
DGLContext ctx_; DGLContext ctx_;
......
...@@ -408,9 +408,9 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -408,9 +408,9 @@ class UnitGraph::COO : public BaseHeteroGraph {
IdArray new_src = aten::IndexSelect(adj_.row, eids[0]); IdArray new_src = aten::IndexSelect(adj_.row, eids[0]);
IdArray new_dst = aten::IndexSelect(adj_.col, eids[0]); IdArray new_dst = aten::IndexSelect(adj_.col, eids[0]);
subg.induced_vertices.emplace_back( subg.induced_vertices.emplace_back(
aten::Range(0, NumVertices(SrcType()), NumBits(), Context())); aten::NullArray(DLDataType{kDLInt, NumBits(), 1}, Context()));
subg.induced_vertices.emplace_back( subg.induced_vertices.emplace_back(
aten::Range(0, NumVertices(DstType()), NumBits(), Context())); aten::NullArray(DLDataType{kDLInt, NumBits(), 1}, Context()));
subg.graph = std::make_shared<COO>( subg.graph = std::make_shared<COO>(
meta_graph(), NumVertices(SrcType()), NumVertices(DstType()), new_src, new_dst); meta_graph(), NumVertices(SrcType()), NumVertices(DstType()), new_src, new_dst);
subg.induced_edges = eids; subg.induced_edges = eids;
......
...@@ -33,7 +33,18 @@ def test_edge_subgraph(): ...@@ -33,7 +33,18 @@ def test_edge_subgraph():
# Test when the graph has no node data and edge data. # Test when the graph has no node data and edge data.
g = generate_graph(add_data=False) g = generate_graph(add_data=False)
eid = [0, 2, 3, 6, 7, 9] eid = [0, 2, 3, 6, 7, 9]
# relabel=True
sg = g.edge_subgraph(eid) sg = g.edge_subgraph(eid)
assert F.array_equal(sg.ndata[dgl.NID], F.tensor([0, 2, 4, 5, 1, 9], g.idtype))
assert F.array_equal(sg.edata[dgl.EID], F.tensor(eid, g.idtype))
sg.ndata['h'] = F.arange(0, sg.number_of_nodes())
sg.edata['h'] = F.arange(0, sg.number_of_edges())
# relabel=False
sg = g.edge_subgraph(eid, relabel_nodes=False)
assert g.number_of_nodes() == sg.number_of_nodes()
assert F.array_equal(sg.edata[dgl.EID], F.tensor(eid, g.idtype))
sg.ndata['h'] = F.arange(0, sg.number_of_nodes()) sg.ndata['h'] = F.arange(0, sg.number_of_nodes())
sg.edata['h'] = F.arange(0, sg.number_of_edges()) sg.edata['h'] = F.arange(0, sg.number_of_edges())
...@@ -655,4 +666,5 @@ def test_uva_subgraph(idtype, device): ...@@ -655,4 +666,5 @@ def test_uva_subgraph(idtype, device):
g.unpin_memory_() g.unpin_memory_()
if __name__ == '__main__': if __name__ == '__main__':
test_uva_subgraph(F.int64, F.cpu()) test_edge_subgraph()
# test_uva_subgraph(F.int64, F.cpu())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment