# cluster.py -- Python entry points that dispatch to the cluster_cpu / cluster_cuda extensions.
import torch

import cluster_cpu
import cluster_cuda


def grid(pos, size, start=None, end=None):
    """Run the extension's ``grid`` routine on the backend matching ``pos``.

    The CUDA extension is used when ``pos`` is a CUDA tensor, otherwise the
    CPU extension is used.

    Args:
        pos: Tensor of positions; assumed to be 2-D (points x dimensions),
            as in the demo at the bottom of this file -- TODO confirm.
        size: Passed through unchanged to the extension.
        start: Lower bound handed to the extension. When ``None``, the
            minimum of ``pos`` along its first dimension is used.
        end: Upper bound handed to the extension. When ``None``, the
            maximum of ``pos`` along its first dimension is used.

    Returns:
        Whatever the selected extension's ``grid`` function returns.
    """
    if pos.is_cuda:
        backend = cluster_cuda
    else:
        backend = cluster_cpu

    # Reducing the transposed tensor over dim=1 yields one value per column
    # of ``pos``; index 0 of the result holds the values (not the indices).
    if start is None:
        start = pos.t().min(dim=1)[0]
    if end is None:
        end = pos.t().max(dim=1)[0]

    return backend.grid(pos, size, start, end)


def graclus(row, col, num_nodes):
    """Run the extension's ``graclus`` routine on the backend matching ``row``.

    The CUDA extension is used when ``row`` is a CUDA tensor, otherwise the
    CPU extension is used.

    Args:
        row: Tensor passed through to the extension; its device decides
            which backend is called. Presumably the source indices of the
            edges -- verify against the extension.
        col: Tensor passed through to the extension. Presumably the target
            indices of the edges; it is assumed to live on the same device
            as ``row`` -- TODO confirm.
        num_nodes: Passed through unchanged to the extension.

    Returns:
        Whatever the selected extension's ``graclus`` function returns.
    """
    # Dispatch on an actual argument. The previous code inspected ``pos``,
    # which is not a parameter of this function: it either picked up an
    # unrelated module-level tensor or raised NameError.
    lib = cluster_cuda if row.is_cuda else cluster_cpu
    return lib.graclus(row, col, num_nodes)
# The statements below are an ad-hoc demo that runs when this file is executed or imported.


# --- grid demo -------------------------------------------------------------
# All tensors are created on the CUDA device, so the CUDA extension is the
# backend exercised here; a CUDA-capable machine is required.
device = torch.device('cuda')
pos = torch.tensor([[1, 1], [3, 3], [5, 5], [7, 7]], device=device)
size = torch.tensor([2, 2], device=device)
print('pos', pos.tolist())
print('size', size.tolist())
# start/end are omitted, so grid() derives them from the min/max of ``pos``.
cluster = grid(pos, size)
print('result', cluster.tolist(), cluster.dtype, cluster.device)

print('-----------------')

# --- graclus demo ----------------------------------------------------------
# ``row``/``col`` are equal-length index tensors over 4 nodes (the value 4 is
# passed as ``num_nodes`` below).
row = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 3, 3], device=device)
col = torch.tensor([1, 2, 0, 2, 3, 0, 1, 3, 1, 2], device=device)
print('row', row.tolist())
print('col', col.tolist())
cluster = graclus(row, col, 4)
print('result', cluster.tolist(), cluster.dtype, cluster.device)