Commit ff207a2f authored by rusty1s's avatar rusty1s
Browse files

bugfix

parent da65da5b
import torch
from torch_cluster.functions.utils.ffi import ffi_serial
def test_serial_cpu():
row = torch.LongTensor([0, 0, 1, 1, 1, 2, 2, 2, 3, 3])
col = torch.LongTensor([1, 2, 0, 2, 3, 0, 1, 3, 1, 2])
degree = torch.LongTensor([2, 3, 3, 2])
cluster = ffi_serial(row, col, degree)
expected_cluster = [0, 0, 2, 2]
assert cluster.tolist() == expected_cluster
weight = torch.Tensor([1, 2, 1, 3, 2, 2, 3, 3, 2, 3])
cluster = ffi_serial(row, col, degree, weight)
expected_cluster = [0, 1, 0, 1]
assert cluster.tolist() == expected_cluster
......@@ -12,7 +12,8 @@ def _get_typed_func(name, tensor):
return getattr(ffi, 'cluster_{}_{}{}'.format(name, cuda, typename))
def ffi_serial(output, row, col, degree, weight=None):
def ffi_serial(row, col, degree, weight=None):
output = row.new(degree.size(0)).fill_(-1)
if weight is None:
func = _get_func('serial', row)
func(output, row, col, degree)
......
......@@ -11,7 +11,7 @@
int64_t e = 0, row_value, col_value, value; \
while(e < THLongTensor_nElement(row)) { \
row_value = row_data[e]; \
if (output_data[row_value] >= 0) { \
if (output_data[row_value] < 0) { \
col_value = -1; \
SELECT \
if (col_value < 0) { \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment