cluster.py 533 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
import torch

import cluster_cpu


rusty1s's avatar
rusty1s committed
6
7
8
def grid_cluster(pos, size, start=None, end=None):
    """Call ``cluster_cpu.grid`` on ``pos``, filling in any missing bounds.

    Args:
        pos: Tensor of positions. It is transposed and reduced along
            ``dim=1`` to derive the default bounds, so it is presumably
            2-D with one row per point -- TODO confirm with callers.
        size: Forwarded unchanged to ``cluster_cpu.grid`` (presumably the
            grid cell size per dimension -- confirm against the extension).
        start: Lower bound forwarded to ``cluster_cpu.grid``. When ``None``,
            the minimum of ``pos.t()`` along ``dim=1`` is used.
        end: Upper bound forwarded to ``cluster_cpu.grid``. When ``None``,
            the maximum of ``pos.t()`` along ``dim=1`` is used.

    Returns:
        Whatever ``cluster_cpu.grid(pos, size, start, end)`` returns.
    """
    # ``min``/``max`` with ``dim`` return a (values, indices) pair; only the
    # values are needed for the bounds.
    if start is None:
        start, _ = pos.t().min(dim=1)
    if end is None:
        end, _ = pos.t().max(dim=1)

    return cluster_cpu.grid(pos, size, start, end)


rusty1s's avatar
rusty1s committed
12
# Ad-hoc demo: run grid_cluster on four 2-D points and print the result.
pos = torch.tensor([
    [1, 1],
    [3, 3],
    [5, 5],
    [7, 7],
])
size = torch.tensor([2, 2])

# NOTE: ``start`` and ``end`` are defined here but are not passed to the
# ``grid_cluster`` call below, so that call derives its bounds from ``pos``.
start = torch.tensor([0, 0])
end = torch.tensor([7, 7])

print('pos', pos.tolist())
print('size', size.tolist())

cluster = grid_cluster(pos=pos, size=size)
print('result', cluster.tolist(), cluster.dtype)