Unverified Commit ae17185a authored by keli-wen's avatar keli-wen Committed by GitHub
Browse files

[Sparse] Update code and add unittest for `formats` (#5859)

parent 2446f2fd
......@@ -1529,18 +1529,28 @@ HeteroGraphPtr UnitGraph::GetFormat(SparseFormat format) const {
}
HeteroGraphPtr UnitGraph::GetGraphInFormat(dgl_format_code_t formats) const {
if (formats == ALL_CODE)
// Get the created formats.
auto created_formats = GetCreatedFormats();
// Get the intersection of formats and created_formats.
auto intersection = formats & created_formats;
// If the intersection of formats and created_formats is not empty.
// The format(s) in the intersection will be retained.
if (intersection != 0) {
COOPtr coo_ptr = COO_CODE & intersection ? GetCOO(false) : nullptr;
CSRPtr in_csr_ptr = CSC_CODE & intersection ? GetInCSR(false) : nullptr;
CSRPtr out_csr_ptr = CSR_CODE & intersection ? GetOutCSR(false) : nullptr;
return HeteroGraphPtr(
// TODO(xiangsx) Make it as graph storage.Clone()
new UnitGraph(
meta_graph_,
(in_csr_->defined()) ? CSRPtr(new CSR(*in_csr_)) : nullptr,
(out_csr_->defined()) ? CSRPtr(new CSR(*out_csr_)) : nullptr,
(coo_->defined()) ? COOPtr(new COO(*coo_)) : nullptr, formats));
new UnitGraph(meta_graph_, in_csr_ptr, out_csr_ptr, coo_ptr, formats));
}
// If the intersection of formats and created_formats is empty.
// Create a format in the order of COO -> CSR -> CSC.
int64_t num_vtypes = NumVertexTypes();
if (formats & COO_CODE)
if (COO_CODE & formats)
return CreateFromCOO(num_vtypes, GetCOO(false)->adj(), formats);
if (formats & CSR_CODE)
if (CSR_CODE & formats)
return CreateFromCSR(num_vtypes, GetOutCSR(false)->adj(), formats);
return CreateFromCSC(num_vtypes, GetInCSR(false)->adj(), formats);
}
......
......@@ -499,6 +499,41 @@ def test_formats():
finally:
assert not fail
# If the intersection of created formats and allowed formats is
# not empty, then retain the intersection.
# Case1: intersection is not empty and intersected is equal to
# created formats.
g = g.formats(["coo", "csr"])
g.create_formats_()
g = g.formats(["coo", "csr", "csc"])
assert sorted(g.formats()["created"]) == sorted(["coo", "csr"])
assert sorted(g.formats()["not created"]) == sorted(["csc"])
# Case2: intersection is not empty and intersected is not equal
# to created formats.
g = g.formats(["coo", "csr"])
g.create_formats_()
g = g.formats(["coo", "csc"])
assert sorted(g.formats()["created"]) == sorted(["coo"])
assert sorted(g.formats()["not created"]) == sorted(["csc"])
# If the intersection of created formats and allowed formats is
# empty, then create a format in the order of `coo` -> `csr` ->
# `csc`.
# Case1: intersection is empty and just one format is allowed.
g = g.formats(["coo", "csr"])
g.create_formats_()
g = g.formats(["csc"])
assert sorted(g.formats()["created"]) == sorted(["csc"])
assert sorted(g.formats()["not created"]) == sorted([])
# Case2: intersection is empty and more than one format is allowed.
g = g.formats("csc")
g.create_formats_()
g = g.formats(["csr", "coo"])
assert sorted(g.formats()["created"]) == sorted(["coo"])
assert sorted(g.formats()["not created"]) == sorted(["csr"])
if __name__ == "__main__":
test_query()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment