Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
ae17185a
"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "4645e28355363d5dceb5f644a0c1ea5bdc2471f9"
Unverified
Commit
ae17185a
authored
Jun 13, 2023
by
keli-wen
Committed by
GitHub
Jun 13, 2023
Browse files
[Sparse] Update code and add unittest for `formats` (#5859)
parent
2446f2fd
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
54 additions
and
9 deletions
+54
-9
src/graph/unit_graph.cc
src/graph/unit_graph.cc
+19
-9
tests/python/common/test_heterograph-misc.py
tests/python/common/test_heterograph-misc.py
+35
-0
No files found.
src/graph/unit_graph.cc
View file @
ae17185a
...
...
@@ -1529,18 +1529,28 @@ HeteroGraphPtr UnitGraph::GetFormat(SparseFormat format) const {
}
HeteroGraphPtr
UnitGraph
::
GetGraphInFormat
(
dgl_format_code_t
formats
)
const
{
if
(
formats
==
ALL_CODE
)
// Get the created formats.
auto
created_formats
=
GetCreatedFormats
();
// Get the intersection of formats and created_formats.
auto
intersection
=
formats
&
created_formats
;
// If the intersection of formats and created_formats is not empty.
// The format(s) in the intersection will be retained.
if
(
intersection
!=
0
)
{
COOPtr
coo_ptr
=
COO_CODE
&
intersection
?
GetCOO
(
false
)
:
nullptr
;
CSRPtr
in_csr_ptr
=
CSC_CODE
&
intersection
?
GetInCSR
(
false
)
:
nullptr
;
CSRPtr
out_csr_ptr
=
CSR_CODE
&
intersection
?
GetOutCSR
(
false
)
:
nullptr
;
return
HeteroGraphPtr
(
// TODO(xiangsx) Make it as graph storage.Clone()
new
UnitGraph
(
meta_graph_
,
(
in_csr_
->
defined
())
?
CSRPtr
(
new
CSR
(
*
in_csr_
))
:
nullptr
,
(
out_csr_
->
defined
())
?
CSRPtr
(
new
CSR
(
*
out_csr_
))
:
nullptr
,
(
coo_
->
defined
())
?
COOPtr
(
new
COO
(
*
coo_
))
:
nullptr
,
formats
));
new
UnitGraph
(
meta_graph_
,
in_csr_ptr
,
out_csr_ptr
,
coo_ptr
,
formats
));
}
// If the intersection of formats and created_formats is empty.
// Create a format in the order of COO -> CSR -> CSC.
int64_t
num_vtypes
=
NumVertexTypes
();
if
(
formats
&
COO_CODE
)
if
(
COO_CODE
&
formats
)
return
CreateFromCOO
(
num_vtypes
,
GetCOO
(
false
)
->
adj
(),
formats
);
if
(
formats
&
CSR_CODE
)
if
(
CSR_CODE
&
formats
)
return
CreateFromCSR
(
num_vtypes
,
GetOutCSR
(
false
)
->
adj
(),
formats
);
return
CreateFromCSC
(
num_vtypes
,
GetInCSR
(
false
)
->
adj
(),
formats
);
}
...
...
tests/python/common/test_heterograph-misc.py
View file @
ae17185a
...
...
@@ -499,6 +499,41 @@ def test_formats():
finally
:
assert
not
fail
# If the intersection of created formats and allowed formats is
# not empty, then retain the intersection.
# Case1: intersection is not empty and intersected is equal to
# created formats.
g
=
g
.
formats
([
"coo"
,
"csr"
])
g
.
create_formats_
()
g
=
g
.
formats
([
"coo"
,
"csr"
,
"csc"
])
assert
sorted
(
g
.
formats
()[
"created"
])
==
sorted
([
"coo"
,
"csr"
])
assert
sorted
(
g
.
formats
()[
"not created"
])
==
sorted
([
"csc"
])
# Case2: intersection is not empty and intersected is not equal
# to created formats.
g
=
g
.
formats
([
"coo"
,
"csr"
])
g
.
create_formats_
()
g
=
g
.
formats
([
"coo"
,
"csc"
])
assert
sorted
(
g
.
formats
()[
"created"
])
==
sorted
([
"coo"
])
assert
sorted
(
g
.
formats
()[
"not created"
])
==
sorted
([
"csc"
])
# If the intersection of created formats and allowed formats is
# empty, then create a format in the order of `coo` -> `csr` ->
# `csc`.
# Case1: intersection is empty and just one format is allowed.
g
=
g
.
formats
([
"coo"
,
"csr"
])
g
.
create_formats_
()
g
=
g
.
formats
([
"csc"
])
assert
sorted
(
g
.
formats
()[
"created"
])
==
sorted
([
"csc"
])
assert
sorted
(
g
.
formats
()[
"not created"
])
==
sorted
([])
# Case2: intersection is empty and more than one format is allowed.
g
=
g
.
formats
(
"csc"
)
g
.
create_formats_
()
g
=
g
.
formats
([
"csr"
,
"coo"
])
assert
sorted
(
g
.
formats
()[
"created"
])
==
sorted
([
"coo"
])
assert
sorted
(
g
.
formats
()[
"not created"
])
==
sorted
([
"csr"
])
if
__name__
==
"__main__"
:
test_query
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment