"""Smoke test / timing script for the ``THCCGreedy`` CUDA kernel.

Builds a small undirected graph on 4 nodes as a CUDA edge list. It then calls
the FFI-exported ``ffi.THCCGreedy`` kernel once and prints three things: the
kernel handle, the wall-clock time of the call, and the resulting ``cluster``
tensor.

Requires a CUDA device and the compiled ``torch_cluster._ext`` extension.
"""
import time

import torch
from torch_cluster._ext import ffi

# Output buffer with one entry per node (the graph below has 4 nodes).
# torch.cuda.LongTensor(4) allocates it uninitialized.
# NOTE(review): this assumes THCCGreedy writes every entry. Confirm that;
# otherwise initialize the buffer explicitly before the call.
cluster = torch.cuda.LongTensor(4)

# Edge list in COO form: edge i runs from row[i] to col[i]. Every edge appears
# in both directions, so the graph is undirected. Node degrees are
# [2, 3, 3, 2].
row = torch.cuda.LongTensor([0, 0, 1, 1, 1, 2, 2, 2, 3, 3])
col = torch.cuda.LongTensor([1, 2, 0, 2, 3, 0, 1, 3, 1, 2])

func = ffi.THCCGreedy
print(func)

# CUDA work is launched asynchronously. Synchronize before starting the clock
# and again before stopping it, so the elapsed time covers the kernel itself
# rather than just its launch. This is a single call with no warm-up, so the
# figure may include one-time startup cost.
torch.cuda.synchronize()
start = time.perf_counter()
func(cluster, row, col)
torch.cuda.synchronize()
print(time.perf_counter() - start)

print(cluster)