Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
bacc9047
Unverified
Commit
bacc9047
authored
Sep 17, 2021
by
Rhett Ying
Committed by
GitHub
Sep 17, 2021
Browse files
[BugFix] initialize data if null when converting from row sorted coo to csr (#3360)
parent
2647afc9
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
58 additions
and
1 deletion
+58
-1
src/array/cpu/spmat_op_impl_coo.cc
src/array/cpu/spmat_op_impl_coo.cc
+1
-1
tests/compute/test_transform.py
tests/compute/test_transform.py
+10
-0
tests/cpp/test_spmat_coo.cc
tests/cpp/test_spmat_coo.cc
+47
-0
No files found.
src/array/cpu/spmat_op_impl_coo.cc
View file @
bacc9047
...
...
@@ -323,7 +323,7 @@ template <class IdType> CSRMatrix SortedCOOToCSR(const COOMatrix &coo) {
Bp
[
0
]
=
0
;
IdType
*
const
fill_data
=
data
?
nullptr
:
static_cast
<
IdType
*>
(
coo
.
data
->
data
);
data
?
nullptr
:
static_cast
<
IdType
*>
(
ret_
data
->
data
);
if
(
NNZ
>
0
)
{
auto
num_threads
=
omp_get_max_threads
();
...
...
tests/compute/test_transform.py
View file @
bacc9047
...
...
@@ -847,6 +847,16 @@ def test_to_simple(idtype):
assert
'h'
not
in
sg
.
nodes
[
'user'
].
data
assert
'hh'
not
in
sg
.
nodes
[
'user'
].
data
# verify DGLGraph.edge_ids() after dgl.to_simple()
# in case ids are not initialized in underlying coo2csr()
u
=
F
.
tensor
([
0
,
1
,
2
])
v
=
F
.
tensor
([
1
,
2
,
3
])
eids
=
F
.
tensor
([
0
,
1
,
2
])
g
=
dgl
.
graph
((
u
,
v
))
assert
F
.
array_equal
(
g
.
edge_ids
(
u
,
v
),
eids
)
sg
=
dgl
.
to_simple
(
g
)
assert
F
.
array_equal
(
sg
.
edge_ids
(
u
,
v
),
eids
)
@
parametrize_dtype
def
test_to_block
(
idtype
):
def
check
(
g
,
bg
,
ntype
,
etype
,
dst_nodes
,
include_dst_in_src
=
True
):
...
...
tests/cpp/test_spmat_coo.cc
View file @
bacc9047
...
...
@@ -148,6 +148,39 @@ bool isSparseCOO(const int64_t &num_threads, const int64_t &num_nodes,
// refer to COOToCSR<>() in ~dgl/src/array/cpu/spmat_op_impl_coo for details.
return
num_threads
*
num_nodes
>
4
*
num_edges
;
}
template
<
typename
IDX
>
aten
::
COOMatrix
RowSorted_NullData_COO
(
DLContext
ctx
=
CTX
)
{
// [[0, 1, 1, 0, 0],
// [1, 0, 0, 0, 0],
// [0, 0, 1, 1, 0],
// [0, 0, 0, 0, 0]]
// row : [0, 0, 1, 2, 2]
// col : [1, 2, 0, 2, 3]
return
aten
::
COOMatrix
(
4
,
5
,
aten
::
VecToIdArray
(
std
::
vector
<
IDX
>
({
0
,
0
,
1
,
2
,
2
}),
sizeof
(
IDX
)
*
8
,
ctx
),
aten
::
VecToIdArray
(
std
::
vector
<
IDX
>
({
1
,
2
,
0
,
2
,
3
}),
sizeof
(
IDX
)
*
8
,
ctx
),
aten
::
NullArray
(),
true
,
false
);
}
template
<
typename
IDX
>
aten
::
CSRMatrix
RowSorted_NullData_CSR
(
DLContext
ctx
=
CTX
)
{
// [[0, 1, 1, 0, 0],
// [1, 0, 0, 0, 0],
// [0, 0, 1, 1, 0],
// [0, 0, 0, 0, 0]]
// data: [0, 1, 2, 3, 4]
return
aten
::
CSRMatrix
(
4
,
5
,
aten
::
VecToIdArray
(
std
::
vector
<
IDX
>
({
0
,
2
,
3
,
5
,
5
}),
sizeof
(
IDX
)
*
8
,
ctx
),
aten
::
VecToIdArray
(
std
::
vector
<
IDX
>
({
1
,
2
,
0
,
2
,
3
}),
sizeof
(
IDX
)
*
8
,
ctx
),
aten
::
VecToIdArray
(
std
::
vector
<
IDX
>
({
0
,
1
,
2
,
3
,
4
}),
sizeof
(
IDX
)
*
8
,
ctx
),
false
);
}
}
// namespace
template
<
typename
IDX
>
...
...
@@ -192,6 +225,20 @@ void _TestCOOToCSR(DLContext ctx) {
ASSERT_TRUE
(
ArrayEQ
<
IDX
>
(
rs_tcsr
.
indices
,
rs_coo
.
col
));
ASSERT_TRUE
(
ArrayEQ
<
IDX
>
(
rs_tcsr
.
data
,
rs_coo
.
data
));
rs_coo
=
RowSorted_NullData_COO
<
IDX
>
(
ctx
);
ASSERT_TRUE
(
rs_coo
.
row_sorted
);
rs_csr
=
RowSorted_NullData_CSR
<
IDX
>
(
ctx
);
rs_tcsr
=
aten
::
COOToCSR
(
rs_coo
);
ASSERT_EQ
(
coo
.
num_rows
,
rs_tcsr
.
num_rows
);
ASSERT_EQ
(
rs_csr
.
num_rows
,
rs_tcsr
.
num_rows
);
ASSERT_EQ
(
coo
.
num_cols
,
rs_tcsr
.
num_cols
);
ASSERT_EQ
(
rs_csr
.
num_cols
,
rs_tcsr
.
num_cols
);
ASSERT_TRUE
(
ArrayEQ
<
IDX
>
(
rs_csr
.
indptr
,
rs_tcsr
.
indptr
));
ASSERT_TRUE
(
ArrayEQ
<
IDX
>
(
rs_csr
.
indices
,
rs_tcsr
.
indices
));
ASSERT_TRUE
(
ArrayEQ
<
IDX
>
(
rs_csr
.
data
,
rs_tcsr
.
data
));
ASSERT_TRUE
(
ArrayEQ
<
IDX
>
(
rs_coo
.
col
,
rs_tcsr
.
indices
));
ASSERT_FALSE
(
ArrayEQ
<
IDX
>
(
rs_coo
.
data
,
rs_tcsr
.
data
));
// Convert from col sorted coo
coo
=
COO1
<
IDX
>
(
ctx
);
auto
src_coo
=
aten
::
COOSort
(
coo
,
true
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment