Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
91f7ff8e
Unverified
Commit
91f7ff8e
authored
Apr 06, 2020
by
Quan (Andy) Gan
Committed by
GitHub
Apr 06, 2020
Browse files
fix #1421 (#1422)
parent
7c47d8c9
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
12 deletions
+14
-12
src/array/cpu/rowwise_pick.h
src/array/cpu/rowwise_pick.h
+6
-8
tests/compute/test_sampling.py
tests/compute/test_sampling.py
+8
-4
No files found.
src/array/cpu/rowwise_pick.h
View file @
91f7ff8e
...
...
@@ -73,15 +73,13 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
IdxType
*
picked_idata
=
static_cast
<
IdxType
*>
(
picked_idx
->
data
);
bool
all_has_fanout
=
true
;
if
(
replace
)
{
all_has_fanout
=
true
;
}
else
{
#pragma omp parallel for reduction(&&:all_has_fanout)
for
(
int64_t
i
=
0
;
i
<
num_rows
;
++
i
)
{
const
IdxType
rid
=
rows_data
[
i
];
const
IdxType
len
=
indptr
[
rid
+
1
]
-
indptr
[
rid
];
all_has_fanout
=
all_has_fanout
&&
(
len
>=
num_picks
);
}
for
(
int64_t
i
=
0
;
i
<
num_rows
;
++
i
)
{
const
IdxType
rid
=
rows_data
[
i
];
const
IdxType
len
=
indptr
[
rid
+
1
]
-
indptr
[
rid
];
// If a node has no neighbor then all_has_fanout must be false even if replace is
// true.
all_has_fanout
=
all_has_fanout
&&
(
len
>=
(
replace
?
1
:
num_picks
));
}
#pragma omp parallel for
...
...
tests/compute/test_sampling.py
View file @
91f7ff8e
...
...
@@ -460,10 +460,14 @@ def test_sample_neighbors_topk_outedge():
@
unittest
.
skipIf
(
F
.
_default_context_str
==
'gpu'
,
reason
=
"GPU sample neighbors not implemented"
)
def
test_sample_neighbors_with_0deg
():
g
=
dgl
.
graph
([],
num_nodes
=
5
)
dgl
.
sampling
.
sample_neighbors
(
g
,
F
.
tensor
([
1
,
2
],
dtype
=
F
.
int64
),
2
,
edge_dir
=
'in'
,
replace
=
False
)
dgl
.
sampling
.
sample_neighbors
(
g
,
F
.
tensor
([
1
,
2
],
dtype
=
F
.
int64
),
2
,
edge_dir
=
'in'
,
replace
=
True
)
dgl
.
sampling
.
sample_neighbors
(
g
,
F
.
tensor
([
1
,
2
],
dtype
=
F
.
int64
),
2
,
edge_dir
=
'out'
,
replace
=
False
)
dgl
.
sampling
.
sample_neighbors
(
g
,
F
.
tensor
([
1
,
2
],
dtype
=
F
.
int64
),
2
,
edge_dir
=
'out'
,
replace
=
True
)
sg
=
dgl
.
sampling
.
sample_neighbors
(
g
,
F
.
tensor
([
1
,
2
],
dtype
=
F
.
int64
),
2
,
edge_dir
=
'in'
,
replace
=
False
)
assert
sg
.
number_of_edges
()
==
0
sg
=
dgl
.
sampling
.
sample_neighbors
(
g
,
F
.
tensor
([
1
,
2
],
dtype
=
F
.
int64
),
2
,
edge_dir
=
'in'
,
replace
=
True
)
assert
sg
.
number_of_edges
()
==
0
sg
=
dgl
.
sampling
.
sample_neighbors
(
g
,
F
.
tensor
([
1
,
2
],
dtype
=
F
.
int64
),
2
,
edge_dir
=
'out'
,
replace
=
False
)
assert
sg
.
number_of_edges
()
==
0
sg
=
dgl
.
sampling
.
sample_neighbors
(
g
,
F
.
tensor
([
1
,
2
],
dtype
=
F
.
int64
),
2
,
edge_dir
=
'out'
,
replace
=
True
)
assert
sg
.
number_of_edges
()
==
0
if
__name__
==
'__main__'
:
test_random_walk
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment