cluster.py 610 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch
import cluster_cuda

dtype = torch.float
device = torch.device('cuda')


def grid_cluster(pos, size, start=None, end=None):
    start = pos.t().min(dim=1)[0] if start is None else start
    end = pos.t().max(dim=1)[0] if end is None else end
    return cluster_cuda.grid(pos, size, start, end)


pos = torch.tensor(
    [[1, 1], [3, 3], [5, 5], [7, 7]], dtype=dtype, device=device)
size = torch.tensor([2, 2, 1, 1, 4, 2, 1], dtype=dtype, device=device)
# print('pos', pos.tolist())
# print('size', size.tolist())
cluster = grid_cluster(pos, size)
print('result', cluster.tolist(), cluster.type())