# bernoulli.py (committed by rusty1s)
import time

import torch
from torch_cluster.functions.utils.ffi import _get_func

# Number of elements in each benchmark tensor (~2 GB as float32, ~4 GB as int64).
NUM_ELEMENTS = 500000000


def _time_cuda(fn, *args):
    """Return the wall-clock seconds taken by one call of ``fn(*args)``.

    CUDA work is queued asynchronously, so the device is synchronized before
    the clock starts (to drain previously queued work) and again after the
    call (so the launched kernels' execution time is actually counted).
    The return value of ``fn`` is discarded.
    """
    torch.cuda.synchronize()
    start = time.perf_counter()
    fn(*args)
    torch.cuda.synchronize()
    return time.perf_counter() - start


# --- Benchmark 1: torch.bernoulli with every probability set to 0.5. ---
output = torch.cuda.FloatTensor(NUM_ELEMENTS).fill_(0.5)
# Untimed warm-up: the first call of a CUDA op can include one-off costs
# (e.g. lazy initialization and the first allocation of its 2 GB result),
# which would otherwise be folded into the measured time.
torch.bernoulli(output)
print(_time_cuda(torch.bernoulli, output))

# --- Benchmark 2: torch_cluster's 'serial' FFI function on an int64 tensor
# filled with -1. ---
# Allocate the int64 tensor directly rather than `output.long().fill_(-1)`:
# the converted values would be overwritten by fill_(-1) anyway, and the
# conversion keeps the float and int64 tensors alive at the same time.
del output
output = torch.cuda.LongTensor(NUM_ELEMENTS).fill_(-1)
# NOTE(review): _get_func is project code not visible here; presumably it
# picks the compiled 'serial' implementation matching output's type/device —
# confirm in torch_cluster.functions.utils.ffi.
func = _get_func('serial', output)
# NOTE(review): the same tensor is passed for all four arguments, so any
# inputs and outputs of the kernel alias each other — confirm this is the
# intended benchmark input. No warm-up here: the call may write into
# `output` (verify), which would change the input seen by the timed run.
print(_time_cuda(func, output, output, output, output))