Commit 157a8fe1 authored by rusty1s's avatar rusty1s
Browse files

test ffi

parent ff207a2f
import torch
from torch_cluster.functions.utils.ffi import ffi_serial
from torch_cluster.functions.utils.ffi import ffi_serial, ffi_grid
def test_serial_cpu():
......@@ -14,3 +14,12 @@ def test_serial_cpu():
cluster = ffi_serial(row, col, degree, weight)
expected_cluster = [0, 1, 0, 1]
assert cluster.tolist() == expected_cluster
def test_grid_cpu():
position = torch.Tensor([[0, 0], [11, 9], [2, 8], [2, 2], [8, 3]])
size = torch.Tensor([5, 5])
count = torch.LongTensor([3, 2])
cluster = ffi_grid(position, size, count)
expected_cluster = [0, 5, 1, 0, 2]
assert cluster.tolist() == expected_cluster
......@@ -24,7 +24,10 @@ def ffi_serial(row, col, degree, weight=None):
return output
def ffi_grid(C, output, position, size, count):
def ffi_grid(position, size, count):
C = count.prod()
output = count.new(position.size(0), 1)
func = _get_typed_func('grid', position)
func(C, output, position, size, count)
output = output.squeeze(-1)
return output
......@@ -4,16 +4,17 @@
void cluster_(serial)(THLongTensor *output, THLongTensor *row, THLongTensor *col, THLongTensor *degree, THTensor *weight) {
real *weight_data = weight->storage->data + weight->storageOffset;
real max_weight, w;
real weight_value, w;
int64_t d, c;
SERIAL(output, row, col, degree,
max_weight = 0;
for (d = 0; d < degree_data[row_value]; d++) {
weight_value = 0;
for (d = 0; d < degree_data[row_value]; d++) { // Iterate over neighbors.
c = col_data[e + d];
w = weight_data[e + d];
if (output_data[c] < 0 && w >= max_weight) {
if (output_data[c] < 0 && w >= weight_value) {
// Neighbor is unmatched and edge has a higher weight.
col_value = c;
max_weight = w;
weight_value = w;
}
}
)
......
......@@ -8,7 +8,7 @@
int64_t *col_data = col->storage->data + col->storageOffset; \
int64_t *degree_data = degree->storage->data + degree->storageOffset; \
\
int64_t e = 0, row_value, col_value, value; \
int64_t e = 0, row_value, col_value, v; \
while(e < THLongTensor_nElement(row)) { \
row_value = row_data[e]; \
if (output_data[row_value] < 0) { \
......@@ -18,9 +18,9 @@
output_data[row_value] = row_value; \
} \
else { \
value = row_value < col_value ? row_value : col_value; \
output_data[row_value] = value; \
output_data[col_value] = value; \
v = row_value < col_value ? row_value : col_value; \
output_data[row_value] = v; \
output_data[col_value] = v; \
} \
} \
e += degree_data[row_value]; \
......@@ -30,9 +30,9 @@
void cluster_serial(THLongTensor *output, THLongTensor *row, THLongTensor *col, THLongTensor *degree) {
int64_t d, c;
SERIAL(output, row, col, degree,
for (d = 0; d < degree_data[row_value]; d++) {
for (d = 0; d < degree_data[row_value]; d++) { // Iterate over neighbors.
c = col_data[e + d];
if (output_data[c] < 0) {
if (output_data[c] < 0) { // Neighbor is unmatched.
col_value = c;
break;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment