Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
23709f94
Commit
23709f94
authored
Oct 18, 2021
by
rusty1s
Browse files
add test
parent
709f6837
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
1 deletion
+12
-1
test/test_sample.py
test/test_sample.py
+10
-1
torch_sparse/sample.py
torch_sparse/sample.py
+2
-0
No files found.
test/test_sample.py
View file @
23709f94
import
torch
import
torch
from
torch_sparse
import
SparseTensor
,
sample_adj
from
torch_sparse
import
SparseTensor
,
sample
,
sample_adj
def
test_sample
():
row
=
torch
.
tensor
([
0
,
0
,
2
,
2
])
col
=
torch
.
tensor
([
1
,
2
,
0
,
1
])
adj
=
SparseTensor
(
row
=
row
,
col
=
col
,
sparse_sizes
=
(
3
,
3
))
out
=
sample
(
adj
,
num_neighbors
=
1
)
assert
out
.
min
()
>=
0
and
out
.
max
()
<=
2
def
test_sample_adj
():
def
test_sample_adj
():
...
...
torch_sparse/sample.py
View file @
23709f94
...
@@ -13,6 +13,8 @@ def sample(src: SparseTensor, num_neighbors: int,
...
@@ -13,6 +13,8 @@ def sample(src: SparseTensor, num_neighbors: int,
if
subset
is
not
None
:
if
subset
is
not
None
:
rowcount
=
rowcount
[
subset
]
rowcount
=
rowcount
[
subset
]
rowptr
=
rowptr
[
subset
]
rowptr
=
rowptr
[
subset
]
else
:
rowptr
=
rowptr
[:
-
1
]
rand
=
torch
.
rand
((
rowcount
.
size
(
0
),
num_neighbors
),
device
=
col
.
device
)
rand
=
torch
.
rand
((
rowcount
.
size
(
0
),
num_neighbors
),
device
=
col
.
device
)
rand
.
mul_
(
rowcount
.
to
(
rand
.
dtype
).
view
(
-
1
,
1
))
rand
.
mul_
(
rowcount
.
to
(
rand
.
dtype
).
view
(
-
1
,
1
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment