Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
ff207a2f
Commit
ff207a2f
authored
Mar 28, 2018
by
rusty1s
Browse files
bugfix
parent
da65da5b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
19 additions
and
2 deletions
+19
-2
test/utils/test_ffi.py
test/utils/test_ffi.py
+16
-0
torch_cluster/functions/utils/ffi.py
torch_cluster/functions/utils/ffi.py
+2
-1
torch_cluster/src/serial_cpu.c
torch_cluster/src/serial_cpu.c
+1
-1
No files found.
test/utils/test_ffi.py
View file @
ff207a2f
import
torch
from
torch_cluster.functions.utils.ffi
import
ffi_serial
def
test_serial_cpu
():
row
=
torch
.
LongTensor
([
0
,
0
,
1
,
1
,
1
,
2
,
2
,
2
,
3
,
3
])
col
=
torch
.
LongTensor
([
1
,
2
,
0
,
2
,
3
,
0
,
1
,
3
,
1
,
2
])
degree
=
torch
.
LongTensor
([
2
,
3
,
3
,
2
])
cluster
=
ffi_serial
(
row
,
col
,
degree
)
expected_cluster
=
[
0
,
0
,
2
,
2
]
assert
cluster
.
tolist
()
==
expected_cluster
weight
=
torch
.
Tensor
([
1
,
2
,
1
,
3
,
2
,
2
,
3
,
3
,
2
,
3
])
cluster
=
ffi_serial
(
row
,
col
,
degree
,
weight
)
expected_cluster
=
[
0
,
1
,
0
,
1
]
assert
cluster
.
tolist
()
==
expected_cluster
torch_cluster/functions/utils/ffi.py
View file @
ff207a2f
...
...
@@ -12,7 +12,8 @@ def _get_typed_func(name, tensor):
return
getattr
(
ffi
,
'cluster_{}_{}{}'
.
format
(
name
,
cuda
,
typename
))
def
ffi_serial
(
output
,
row
,
col
,
degree
,
weight
=
None
):
def
ffi_serial
(
row
,
col
,
degree
,
weight
=
None
):
output
=
row
.
new
(
degree
.
size
(
0
)).
fill_
(
-
1
)
if
weight
is
None
:
func
=
_get_func
(
'serial'
,
row
)
func
(
output
,
row
,
col
,
degree
)
...
...
torch_cluster/src/serial_cpu.c
View file @
ff207a2f
...
...
@@ -11,7 +11,7 @@
int64_t e = 0, row_value, col_value, value; \
while(e < THLongTensor_nElement(row)) { \
row_value = row_data[e]; \
if (output_data[row_value]
>=
0) { \
if (output_data[row_value]
<
0) { \
col_value = -1; \
SELECT \
if (col_value < 0) { \
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment