"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "a8c81018c538c215e94e185de957485ea63ab16d"
Unverified Commit ea8b93f9 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[BugFix] fix in_degree/out_degree computation logic (#3477)

* [BugFix] fix in/out degree computation

* add unit tests
parent f7360c3c
...@@ -1000,30 +1000,38 @@ EdgeArray UnitGraph::Edges(dgl_type_t etype, const std::string &order) const { ...@@ -1000,30 +1000,38 @@ EdgeArray UnitGraph::Edges(dgl_type_t etype, const std::string &order) const {
uint64_t UnitGraph::InDegree(dgl_type_t etype, dgl_id_t vid) const { uint64_t UnitGraph::InDegree(dgl_type_t etype, dgl_id_t vid) const {
SparseFormat fmt = SelectFormat(CSC_CODE); SparseFormat fmt = SelectFormat(CSC_CODE);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::kCSC) CHECK(fmt == SparseFormat::kCSC || fmt == SparseFormat::kCOO)
return ptr->OutDegree(etype, vid); << "In degree cannot be computed as neither CSC nor COO format is "
else "allowed for this graph. Please enable one of them at least.";
return ptr->InDegree(etype, vid); return fmt == SparseFormat::kCSC ? ptr->OutDegree(etype, vid)
: ptr->InDegree(etype, vid);
} }
DegreeArray UnitGraph::InDegrees(dgl_type_t etype, IdArray vids) const { DegreeArray UnitGraph::InDegrees(dgl_type_t etype, IdArray vids) const {
SparseFormat fmt = SelectFormat(CSC_CODE); SparseFormat fmt = SelectFormat(CSC_CODE);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::kCSC) CHECK(fmt == SparseFormat::kCSC || fmt == SparseFormat::kCOO)
return ptr->OutDegrees(etype, vids); << "In degree cannot be computed as neither CSC nor COO format is "
else "allowed for this graph. Please enable one of them at least.";
return ptr->InDegrees(etype, vids); return fmt == SparseFormat::kCSC ? ptr->OutDegrees(etype, vids)
: ptr->InDegrees(etype, vids);
} }
uint64_t UnitGraph::OutDegree(dgl_type_t etype, dgl_id_t vid) const { uint64_t UnitGraph::OutDegree(dgl_type_t etype, dgl_id_t vid) const {
SparseFormat fmt = SelectFormat(CSR_CODE); SparseFormat fmt = SelectFormat(CSR_CODE);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
CHECK(fmt == SparseFormat::kCSR || fmt == SparseFormat::kCOO)
<< "Out degree cannot be computed as neither CSR nor COO format is "
"allowed for this graph. Please enable one of them at least.";
return ptr->OutDegree(etype, vid); return ptr->OutDegree(etype, vid);
} }
DegreeArray UnitGraph::OutDegrees(dgl_type_t etype, IdArray vids) const { DegreeArray UnitGraph::OutDegrees(dgl_type_t etype, IdArray vids) const {
SparseFormat fmt = SelectFormat(CSR_CODE); SparseFormat fmt = SelectFormat(CSR_CODE);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
CHECK(fmt == SparseFormat::kCSR || fmt == SparseFormat::kCOO)
<< "Out degree cannot be computed as neither CSR nor COO format is "
"allowed for this graph. Please enable one of them at least.";
return ptr->OutDegrees(etype, vids); return ptr->OutDegrees(etype, vids);
} }
......
...@@ -376,6 +376,40 @@ def test_default_types(): ...@@ -376,6 +376,40 @@ def test_default_types():
assert dg.ntypes == g.ntypes assert dg.ntypes == g.ntypes
assert dg.etypes == g.etypes assert dg.etypes == g.etypes
def test_formats():
g = dgl.rand_graph(10, 20)
# in_degrees works if coo or csc available
# out_degrees works if coo or csr available
try:
g.in_degrees()
g.out_degrees()
g.formats('coo').in_degrees()
g.formats('coo').out_degrees()
g.formats('csc').in_degrees()
g.formats('csr').out_degrees()
fail = False
except DGLError:
fail = True
finally:
assert not fail
# in_degrees NOT works if csc available only
try:
g.formats('csc').out_degrees()
fail = True
except DGLError:
fail = False
finally:
assert not fail
# out_degrees NOT works if csr available only
try:
g.formats('csr').in_degrees()
fail = True
except DGLError:
fail = False
finally:
assert not fail
if __name__ == '__main__': if __name__ == '__main__':
test_query() test_query()
test_mutation() test_mutation()
...@@ -385,3 +419,4 @@ if __name__ == '__main__': ...@@ -385,3 +419,4 @@ if __name__ == '__main__':
test_hypersparse_query() test_hypersparse_query()
test_is_sorted() test_is_sorted()
test_default_types() test_default_types()
test_formats()
...@@ -67,6 +67,52 @@ aten::COOMatrix COO1(DLContext ctx) { ...@@ -67,6 +67,52 @@ aten::COOMatrix COO1(DLContext ctx) {
template aten::COOMatrix COO1<int32_t>(DLContext ctx); template aten::COOMatrix COO1<int32_t>(DLContext ctx);
template aten::COOMatrix COO1<int64_t>(DLContext ctx); template aten::COOMatrix COO1<int64_t>(DLContext ctx);
template <typename IdType> void _TestUnitGraph_InOutDegrees(DLContext ctx) {
/*
InDegree(s) is available only if COO or CSC formats permitted.
OutDegree(s) is available only if COO or CSR formats permitted.
*/
// COO
{
const aten::COOMatrix &coo = COO1<IdType>(ctx);
auto &&g = CreateFromCOO(2, coo, COO_CODE);
ASSERT_EQ(g->InDegree(0, 0), 1);
auto &&nids = aten::Range(0, g->NumVertices(0), g->NumBits(), g->Context());
ASSERT_TRUE(ArrayEQ<IdType>(
g->InDegrees(0, nids),
aten::VecToIdArray<IdType>({1, 2}, g->NumBits(), g->Context())));
ASSERT_EQ(g->OutDegree(0, 0), 2);
ASSERT_TRUE(ArrayEQ<IdType>(
g->OutDegrees(0, nids),
aten::VecToIdArray<IdType>({2, 1}, g->NumBits(), g->Context())));
}
// CSC
{
const aten::CSRMatrix &csr = CSR1<IdType>(ctx);
auto &&g = CreateFromCSC(2, csr, CSC_CODE);
ASSERT_EQ(g->InDegree(0, 0), 1);
auto &&nids = aten::Range(0, g->NumVertices(0), g->NumBits(), g->Context());
ASSERT_TRUE(ArrayEQ<IdType>(
g->InDegrees(0, nids),
aten::VecToIdArray<IdType>({1, 2, 1}, g->NumBits(), g->Context())));
EXPECT_ANY_THROW(g->OutDegree(0, 0));
EXPECT_ANY_THROW(g->OutDegrees(0, nids));
}
// CSR
{
const aten::CSRMatrix &csr = CSR1<IdType>(ctx);
auto &&g = CreateFromCSR(2, csr, CSR_CODE);
ASSERT_EQ(g->OutDegree(0, 0), 1);
auto &&nids = aten::Range(0, g->NumVertices(0), g->NumBits(), g->Context());
ASSERT_TRUE(ArrayEQ<IdType>(
g->OutDegrees(0, nids),
aten::VecToIdArray<IdType>({1, 2, 1, 2}, g->NumBits(), g->Context())));
EXPECT_ANY_THROW(g->InDegree(0, 0));
EXPECT_ANY_THROW(g->InDegrees(0, nids));
}
}
template <typename IdType> template <typename IdType>
void _TestUnitGraph(DLContext ctx) { void _TestUnitGraph(DLContext ctx) {
const aten::CSRMatrix &csr = CSR1<IdType>(ctx); const aten::CSRMatrix &csr = CSR1<IdType>(ctx);
...@@ -340,6 +386,15 @@ TEST(UniGraphTest, TestUnitGraph_CopyTo) { ...@@ -340,6 +386,15 @@ TEST(UniGraphTest, TestUnitGraph_CopyTo) {
#endif #endif
} }
TEST(UniGraphTest, TestUnitGraph_InOutDegrees) {
_TestUnitGraph_InOutDegrees<int32_t>(CPU);
_TestUnitGraph_InOutDegrees<int64_t>(CPU);
#ifdef DGL_USE_CUDA
_TestUnitGraph_InOutDegrees<int32_t>(GPU);
_TestUnitGraph_InOutDegrees<int64_t>(GPU);
#endif
}
TEST(UniGraphTest, TestUnitGraph_Create) { TEST(UniGraphTest, TestUnitGraph_Create) {
_TestUnitGraph<int32_t>(CPU); _TestUnitGraph<int32_t>(CPU);
_TestUnitGraph<int64_t>(CPU); _TestUnitGraph<int64_t>(CPU);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment