Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
dcf16992
Unverified
Commit
dcf16992
authored
Jul 01, 2022
by
Rhett Ying
Committed by
GitHub
Jul 01, 2022
Browse files
[BugFix] check whether etype sorted when sampling (#4198)
parent
a9768cb3
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
30 additions
and
1 deletion
+30
-1
src/array/cpu/rowwise_pick.h
src/array/cpu/rowwise_pick.h
+4
-1
tests/compute/test_sampling.py
tests/compute/test_sampling.py
+26
-0
No files found.
src/array/cpu/rowwise_pick.h
View file @
dcf16992
...
@@ -277,7 +277,10 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows, IdArray etypes,
...
@@ -277,7 +277,10 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows, IdArray etypes,
int64_t
et_offset
=
0
;
int64_t
et_offset
=
0
;
int64_t
et_len
=
1
;
int64_t
et_len
=
1
;
for
(
int64_t
j
=
0
;
j
<
len
;
++
j
)
{
for
(
int64_t
j
=
0
;
j
<
len
;
++
j
)
{
if
((
j
+
1
==
len
)
||
cur_et
!=
et
[
et_idx
[
j
+
1
]])
{
CHECK
((
j
+
1
==
len
)
||
(
et
[
et_idx
[
j
]]
<=
et
[
et_idx
[
j
+
1
]]))
<<
"Edge type is not sorted. Please sort in advance or specify "
"'etype_sorted' as false."
;
if
((
j
+
1
==
len
)
||
cur_et
!=
et
[
et_idx
[
j
+
1
]])
{
// 1 end of the current etype
// 1 end of the current etype
// 2 end of the row
// 2 end of the row
// random pick for current etype
// random pick for current etype
...
...
tests/compute/test_sampling.py
View file @
dcf16992
...
@@ -858,6 +858,32 @@ def test_sample_neighbors_etype_homogeneous(format_, direction, replace):
...
@@ -858,6 +858,32 @@ def test_sample_neighbors_etype_homogeneous(format_, direction, replace):
h_g
,
seeds
,
dgl
.
ETYPE
,
fanouts
,
replace
=
replace
,
edge_dir
=
direction
)
h_g
,
seeds
,
dgl
.
ETYPE
,
fanouts
,
replace
=
replace
,
edge_dir
=
direction
)
check_num
(
h_g
,
all_src
,
all_dst
,
subg
,
replace
,
fanouts
,
direction
)
check_num
(
h_g
,
all_src
,
all_dst
,
subg
,
replace
,
fanouts
,
direction
)
@
unittest
.
skipIf
(
F
.
_default_context_str
==
'gpu'
,
reason
=
"GPU sample neighbors not implemented"
)
@
pytest
.
mark
.
parametrize
(
'format_'
,
[
'csr'
,
'csc'
])
@
pytest
.
mark
.
parametrize
(
'direction'
,
[
'in'
,
'out'
])
def
test_sample_neighbors_etype_sorted_homogeneous
(
format_
,
direction
):
rare_cnt
=
4
g
=
create_etype_test_graph
(
100
,
30
,
rare_cnt
)
h_g
=
dgl
.
to_homogeneous
(
g
)
seed_ntype
=
g
.
get_ntype_id
(
"u"
)
seeds
=
F
.
nonzero_1d
(
h_g
.
ndata
[
dgl
.
NTYPE
]
==
seed_ntype
)
fanouts
=
F
.
tensor
([
6
,
5
,
4
,
3
,
2
],
dtype
=
F
.
int64
)
h_g
=
h_g
.
formats
(
format_
)
if
(
direction
,
format_
)
in
[(
'in'
,
'csr'
),
(
'out'
,
'csc'
)]:
h_g
=
h_g
.
formats
([
'csc'
,
'csr'
,
'coo'
])
orig_etype
=
F
.
asnumpy
(
h_g
.
edata
[
dgl
.
ETYPE
])
h_g
.
edata
[
dgl
.
ETYPE
]
=
F
.
tensor
(
np
.
sort
(
orig_etype
)[::
-
1
].
tolist
(),
dtype
=
F
.
int64
)
try
:
dgl
.
sampling
.
sample_etype_neighbors
(
h_g
,
seeds
,
dgl
.
ETYPE
,
fanouts
,
edge_dir
=
direction
,
etype_sorted
=
True
)
fail
=
False
except
dgl
.
DGLError
:
fail
=
True
assert
fail
@
pytest
.
mark
.
parametrize
(
'dtype'
,
[
'int32'
,
'int64'
])
@
pytest
.
mark
.
parametrize
(
'dtype'
,
[
'int32'
,
'int64'
])
@
unittest
.
skipIf
(
F
.
_default_context_str
==
'gpu'
,
reason
=
"GPU sample neighbors not implemented"
)
@
unittest
.
skipIf
(
F
.
_default_context_str
==
'gpu'
,
reason
=
"GPU sample neighbors not implemented"
)
def
test_sample_neighbors_exclude_edges_heteroG
(
dtype
):
def
test_sample_neighbors_exclude_edges_heteroG
(
dtype
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment