# cluster.py (436 bytes) — last committed by rusty1s.
import torch

import cluster_cpu


def grid_cluster(pos, size, start, end):
    """Cluster points onto a regular grid using the ``cluster_cpu`` extension.

    This is a thin pass-through: every argument is forwarded positionally and
    unchanged to ``cluster_cpu.grid``, and its result is returned as-is.

    Args:
        pos: Point coordinates. The demo in this module passes an (N, D)
            tensor; presumably one row per point (confirm against the
            extension).
        size: Grid cell size, presumably one entry per dimension (the demo
            passes a length-D tensor). TODO confirm.
        start: Presumably the lower grid bound per dimension. TODO confirm.
        end: Presumably the upper grid bound per dimension. TODO confirm.

    Returns:
        Whatever ``cluster_cpu.grid`` returns. The demo treats it as a tensor,
        calling ``.tolist()`` and ``.dtype`` on it.
    """
    grid_op = cluster_cpu.grid
    return grid_op(pos, size, start, end)


def main():
    """Run a small grid-clustering demo on four 2-D points and print the result.

    Previously these statements ran at module top level, so merely importing
    this module executed the extension and printed output. They now run only
    when the file is executed as a script.
    """
    # NOTE(review): pos is uint8, while size/start/end get torch's inferred
    # integer dtype (int64). This presumably exercises the extension's dtype
    # handling on purpose — confirm before "fixing" it.
    pos = torch.tensor([[1, 1], [3, 3], [5, 5], [7, 7]], dtype=torch.uint8)
    size = torch.tensor([2, 2])
    start = torch.tensor([0, 0])
    end = torch.tensor([7, 7])
    print('pos', pos.tolist())
    print('size', size.tolist())
    cluster = grid_cluster(pos, size, start, end)
    # Print the dtype too, so the extension's output type is visible.
    print('result', cluster.tolist(), cluster.dtype)


if __name__ == '__main__':
    main()