"docs/source/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "308bd6f5b245929211f365396ca2007ac151b8e7"
Unverified Commit 24dc71fc authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[BUG] Fix #1409 (#1411)

* [BUG] Fix #1409

* fix test
parent d3560b71
...@@ -999,9 +999,27 @@ HeteroSubgraph UnitGraph::VertexSubgraph(const std::vector<IdArray>& vids) const ...@@ -999,9 +999,27 @@ HeteroSubgraph UnitGraph::VertexSubgraph(const std::vector<IdArray>& vids) const
// We prefer to generate a subgraph from out-csr. // We prefer to generate a subgraph from out-csr.
SparseFormat fmt = SelectFormat(SparseFormat::kCSR); SparseFormat fmt = SelectFormat(SparseFormat::kCSR);
HeteroSubgraph sg = GetFormat(fmt)->VertexSubgraph(vids); HeteroSubgraph sg = GetFormat(fmt)->VertexSubgraph(vids);
CSRPtr subcsr = std::dynamic_pointer_cast<CSR>(sg.graph);
HeteroSubgraph ret; HeteroSubgraph ret;
ret.graph = HeteroGraphPtr(new UnitGraph(meta_graph(), nullptr, subcsr, nullptr));
CSRPtr subcsr = nullptr;
CSRPtr subcsc = nullptr;
COOPtr subcoo = nullptr;
switch (fmt) {
case SparseFormat::kCSR:
subcsr = std::dynamic_pointer_cast<CSR>(sg.graph);
break;
case SparseFormat::kCSC:
subcsc = std::dynamic_pointer_cast<CSR>(sg.graph);
break;
case SparseFormat::kCOO:
subcoo = std::dynamic_pointer_cast<COO>(sg.graph);
break;
default:
LOG(FATAL) << "[BUG] unsupported format " << static_cast<int>(fmt);
return ret;
}
ret.graph = HeteroGraphPtr(new UnitGraph(meta_graph(), subcsc, subcsr, subcoo));
ret.induced_vertices = std::move(sg.induced_vertices); ret.induced_vertices = std::move(sg.induced_vertices);
ret.induced_edges = std::move(sg.induced_edges); ret.induced_edges = std::move(sg.induced_edges);
return ret; return ret;
...@@ -1011,9 +1029,27 @@ HeteroSubgraph UnitGraph::EdgeSubgraph( ...@@ -1011,9 +1029,27 @@ HeteroSubgraph UnitGraph::EdgeSubgraph(
const std::vector<IdArray>& eids, bool preserve_nodes) const { const std::vector<IdArray>& eids, bool preserve_nodes) const {
SparseFormat fmt = SelectFormat(SparseFormat::kCOO); SparseFormat fmt = SelectFormat(SparseFormat::kCOO);
auto sg = GetFormat(fmt)->EdgeSubgraph(eids, preserve_nodes); auto sg = GetFormat(fmt)->EdgeSubgraph(eids, preserve_nodes);
COOPtr subcoo = std::dynamic_pointer_cast<COO>(sg.graph);
HeteroSubgraph ret; HeteroSubgraph ret;
ret.graph = HeteroGraphPtr(new UnitGraph(meta_graph(), nullptr, nullptr, subcoo));
CSRPtr subcsr = nullptr;
CSRPtr subcsc = nullptr;
COOPtr subcoo = nullptr;
switch (fmt) {
case SparseFormat::kCSR:
subcsr = std::dynamic_pointer_cast<CSR>(sg.graph);
break;
case SparseFormat::kCSC:
subcsc = std::dynamic_pointer_cast<CSR>(sg.graph);
break;
case SparseFormat::kCOO:
subcoo = std::dynamic_pointer_cast<COO>(sg.graph);
break;
default:
LOG(FATAL) << "[BUG] unsupported format " << static_cast<int>(fmt);
return ret;
}
ret.graph = HeteroGraphPtr(new UnitGraph(meta_graph(), subcsc, subcsr, subcoo));
ret.induced_vertices = std::move(sg.induced_vertices); ret.induced_vertices = std::move(sg.induced_vertices);
ret.induced_edges = std::move(sg.induced_edges); ret.induced_edges = std::move(sg.induced_edges);
return ret; return ret;
......
...@@ -940,6 +940,18 @@ def test_subgraph(): ...@@ -940,6 +940,18 @@ def test_subgraph():
sg5 = g.edge_type_subgraph(['follows', 'plays', 'wishes']) sg5 = g.edge_type_subgraph(['follows', 'plays', 'wishes'])
_check_typed_subgraph1(g, sg5) _check_typed_subgraph1(g, sg5)
# Test for restricted format
for fmt in ['csr', 'csc', 'coo']:
g = dgl.graph([(0, 1), (1, 2)], restrict_format=fmt)
sg = g.subgraph({g.ntypes[0]: [1, 0]})
nids = F.asnumpy(sg.ndata[dgl.NID])
assert np.array_equal(nids, np.array([1, 0]))
src, dst = sg.all_edges(order='eid')
src = F.asnumpy(src)
dst = F.asnumpy(dst)
assert np.array_equal(src, np.array([1]))
assert np.array_equal(dst, np.array([0]))
def test_apply(): def test_apply():
def node_udf(nodes): def node_udf(nodes):
return {'h': nodes.data['h'] * 2} return {'h': nodes.data['h'] * 2}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment