"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "ff573ae245166d8b958950f68eb65ea766f07f69"
Unverified Commit 6ae93e5c authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Bug] Fix #1088 (#1089)

* [Bug] Fix #1088

* fix

* add comment
parent 48c7ec44
...@@ -468,12 +468,13 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetRelationGraph") ...@@ -468,12 +468,13 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetRelationGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
if (hg->NumEdgeTypes() == 1) { CHECK_LE(etype, hg->NumEdgeTypes()) << "invalid edge type " << etype;
CHECK_EQ(etype, 0); // Test if the heterograph is a unit graph. If so, return itself.
*rv = hg; auto bg = std::dynamic_pointer_cast<UnitGraph>(hg.sptr());
} else { if (bg != nullptr)
*rv = bg;
else
*rv = HeteroGraphRef(hg->GetRelationGraph(etype)); *rv = HeteroGraphRef(hg->GetRelationGraph(etype));
}
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetFlattenedGraph") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetFlattenedGraph")
......
...@@ -762,6 +762,14 @@ def test_local_scope(): ...@@ -762,6 +762,14 @@ def test_local_scope():
assert F.allclose(g.edata['w'], F.tensor([[1.], [0.]])) assert F.allclose(g.edata['w'], F.tensor([[1.], [0.]]))
foo(g) foo(g)
def test_issue_1088():
# This test ensures that message passing on a heterograph with one edge type
# would not crash (GitHub issue #1088).
import dgl.function as fn
g = dgl.heterograph({('U', 'E', 'V'): ([0, 1, 2], [1, 2, 3])})
g.nodes['U'].data['x'] = F.randn((3, 3))
g.update_all(fn.copy_u('x', 'm'), fn.sum('m', 'y'))
if __name__ == '__main__': if __name__ == '__main__':
test_nx_conversion() test_nx_conversion()
test_batch_setter_getter() test_batch_setter_getter()
...@@ -781,3 +789,4 @@ if __name__ == '__main__': ...@@ -781,3 +789,4 @@ if __name__ == '__main__':
test_group_apply_edges() test_group_apply_edges()
test_local_var() test_local_var()
test_local_scope() test_local_scope()
test_issue_1088()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment