Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
a5d21c2b
Unverified
Commit
a5d21c2b
authored
Oct 13, 2022
by
Rhett Ying
Committed by
GitHub
Oct 13, 2022
Browse files
[Sampling] handle fanout=-1 differently from fanout>0 in sample_etype_neighbors() (#4716)
parent
e452179c
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
3 additions
and
2 deletions
+3
-2
src/array/cpu/rowwise_pick.h
src/array/cpu/rowwise_pick.h
+2
-1
tests/compute/test_sampling.py
tests/compute/test_sampling.py
+1
-1
No files found.
src/array/cpu/rowwise_pick.h
View file @
a5d21c2b
...
@@ -284,7 +284,8 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows, IdArray etypes,
...
@@ -284,7 +284,8 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows, IdArray etypes,
// 1 end of the current etype
// 1 end of the current etype
// 2 end of the row
// 2 end of the row
// random pick for current etype
// random pick for current etype
if
(
et_len
<=
num_picks
[
cur_et
]
&&
!
replace
)
{
if
((
num_picks
[
cur_et
]
==
-
1
)
||
(
et_len
<=
num_picks
[
cur_et
]
&&
!
replace
))
{
// fast path, select all
// fast path, select all
for
(
int64_t
k
=
0
;
k
<
et_len
;
++
k
)
{
for
(
int64_t
k
=
0
;
k
<
et_len
;
++
k
)
{
rows
.
push_back
(
rid
);
rows
.
push_back
(
rid
);
...
...
tests/compute/test_sampling.py
View file @
a5d21c2b
...
@@ -882,7 +882,7 @@ def test_sample_neighbors_etype_sorted_homogeneous(format_, direction):
...
@@ -882,7 +882,7 @@ def test_sample_neighbors_etype_sorted_homogeneous(format_, direction):
h_g
=
dgl
.
to_homogeneous
(
g
)
h_g
=
dgl
.
to_homogeneous
(
g
)
seed_ntype
=
g
.
get_ntype_id
(
"u"
)
seed_ntype
=
g
.
get_ntype_id
(
"u"
)
seeds
=
F
.
nonzero_1d
(
h_g
.
ndata
[
dgl
.
NTYPE
]
==
seed_ntype
)
seeds
=
F
.
nonzero_1d
(
h_g
.
ndata
[
dgl
.
NTYPE
]
==
seed_ntype
)
fanouts
=
F
.
tensor
([
6
,
5
,
4
,
3
,
2
],
dtype
=
F
.
int64
)
fanouts
=
F
.
tensor
([
6
,
5
,
-
1
,
3
,
2
],
dtype
=
F
.
int64
)
h_g
=
h_g
.
formats
(
format_
)
h_g
=
h_g
.
formats
(
format_
)
if
(
direction
,
format_
)
in
[(
'in'
,
'csr'
),
(
'out'
,
'csc'
)]:
if
(
direction
,
format_
)
in
[(
'in'
,
'csr'
),
(
'out'
,
'csc'
)]:
h_g
=
h_g
.
formats
([
'csc'
,
'csr'
,
'coo'
])
h_g
=
h_g
.
formats
([
'csc'
,
'csr'
,
'coo'
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment