Commit af7786d8 authored by rusty1s's avatar rusty1s
Browse files

added python impl

parent a2f2986a
from .functions.grid import grid_cluster
__version__ = '0.1.0'
__all__ = ['__version__']
__all__ = ['grid_cluster', '__version__']
import torch
from .utils import get_func
def grid_cluster(position, size, batch=None):
# TODO: Check types and sizes
print(batch.type())
print(position.type())
if batch is not None:
batch = batch.type_as(position)
position = torch.cat([position, batch], dim=position.dim() - 1)
size = torch.cat([size, size.new(1).fill_(1)], dim=0)
print(position)
# TODO: BATCH
# print(position[0])
# print(position[1])
dim = position.dim()
# Allow one-dimensional positions.
if dim == 1:
position = position.unsqueeze(1)
dim += 1
# Translate to minimal positive positions.
min = position.min(dim=dim - 2, keepdim=True)[0]
position = position - min
# Compute cluster count for each dimension.
max = position.max(dim=0)[0]
while max.dim() > 1:
max = max.max(dim=0)[0]
c_max = torch.ceil(max / size.type_as(max)).long()
C = c_max.prod()
# Generate cluster tensor.
s = list(position.size())
s[-1] = 1
cluster = c_max.new(torch.Size(s))
# Fill cluster tensor and reshape.
func = get_func('grid', position)
func(C, cluster, position, size, c_max)
cluster = cluster.squeeze(dim=dim - 1)
return cluster
from .._ext import ffi
def get_func(name, tensor):
typename = type(tensor).__name__.replace('Tensor', '')
cuda = 'cuda_' if tensor.is_cuda else ''
func = getattr(ffi, 'cluster_{}_{}{}'.format(name, cuda, typename))
return func
......@@ -5,9 +5,10 @@
void cluster_(grid)(int C, THLongTensor *output, THTensor *position, THTensor *size, THLongTensor *count) {
real *size_data = size->storage->data + size->storageOffset;
int64_t *count_data = count->storage->data + count->storageOffset;
int64_t d, i, c, tmp;
d = THTensor_(size)(position, 1);
TH_TENSOR_DIM_APPLY2(int64_t, output, real, position, 1,
int64_t D, d, i, c, tmp;
D = THTensor_(nDimension)(position);
d = THTensor_(size)(position, D - 1);
TH_TENSOR_DIM_APPLY2(int64_t, output, real, position, D - 1,
tmp = C; c = 0;
for (i = 0; i < d; i++) {
tmp = tmp / *(count_data + i);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment