Commit 56ae40e7 authored by rusty1s's avatar rusty1s
Browse files

grid impl

parent ea62b22d
......@@ -3,8 +3,8 @@ import subprocess
import torch
from torch.utils.ffi import create_extension
headers = []
sources = []
headers = ['torch_cluster/src/cpu.h']
sources = ['torch_cluster/src/cpu.c']
include_dirs = ['torch_cluster/src']
define_macros = []
extra_objects = []
......
void cluster_grid_Float (int C, THLongTensor *output, THFloatTensor *position, THFloatTensor *size);
void cluster_grid_Double(int C, THLongTensor *output, THDoubleTensor *position, THDoubleTensor *size);
void cluster_grid_Byte (int C, THLongTensor *output, THByteTensor *position, THByteTensor *size);
void cluster_grid_Char (int C, THLongTensor *output, THCharTensor *position, THCharTensor *size);
void cluster_grid_Short (int C, THLongTensor *output, THShortTensor *position, THShortTensor *size);
void cluster_grid_Int (int C, THLongTensor *output, THIntTensor *position, THIntTensor *size);
void cluster_grid_Long (int C, THLongTensor *output, THLongTensor *position, THLongTensor *size);
void cluster_grid_Float (int C, THLongTensor *output, THFloatTensor *position, THFloatTensor *size, THLongTensor *count);
void cluster_grid_Double(int C, THLongTensor *output, THDoubleTensor *position, THDoubleTensor *size, THLongTensor *count);
void cluster_grid_Byte (int C, THLongTensor *output, THByteTensor *position, THByteTensor *size, THLongTensor *count);
void cluster_grid_Char (int C, THLongTensor *output, THCharTensor *position, THCharTensor *size, THLongTensor *count);
void cluster_grid_Short (int C, THLongTensor *output, THShortTensor *position, THShortTensor *size, THLongTensor *count);
void cluster_grid_Int (int C, THLongTensor *output, THIntTensor *position, THIntTensor *size, THLongTensor *count);
void cluster_grid_Long (int C, THLongTensor *output, THLongTensor *position, THLongTensor *size, THLongTensor *count);
......@@ -2,7 +2,19 @@
#define TH_GENERIC_FILE "generic/cpu.c"
#else
void cluster_(grid)(int C, THLongTensor *output, THTensor *position, THTensor *size) {
void cluster_(grid)(int C, THLongTensor *output, THTensor *position, THTensor *size, THLongTensor *count) {
real *size_data = size->storage->data + size->storageOffset;
int64_t *count_data = count->storage->data + count->storageOffset;
int64_t d, i, c, tmp;
d = THTensor_(size)(position, 1);
TH_TENSOR_DIM_APPLY2(int64_t, output, real, position, 1,
tmp = C; c = 0;
for (i = 0; i < d; i++) {
tmp = tmp / *(count_data + i);
c += tmp * (int64_t)floor((float)(*(position_data + i * position_stride) / *(size_data + i)));
}
output_data[0] = c;
)
}
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment