Commit fff675ec authored by rusty1s's avatar rusty1s
Browse files

so much faster

parent ef96f7a1
import time
import torch
from torch_cluster import sparse_grid_cluster
n = 90000000
s = 1 / 64
print('GPU ===================')
t = time.perf_counter()
pos = torch.cuda.FloatTensor(n, 3).uniform_(0, 1)
size = torch.cuda.FloatTensor([s, s, s])
torch.cuda.synchronize()
print('Init:', time.perf_counter() - t)
t_all = time.perf_counter()
sparse_grid_cluster(pos, size)
torch.cuda.synchronize()
t_all = time.perf_counter() - t_all
print('All:', t_all)
print('CPU ===================')
pos = pos.cpu()
size = size.cpu()
t_all = time.perf_counter()
sparse_grid_cluster(pos, size)
t_all = time.perf_counter() - t_all
print('All:', t_all)
...@@ -2,7 +2,7 @@ from os import path as osp ...@@ -2,7 +2,7 @@ from os import path as osp
from setuptools import setup, find_packages from setuptools import setup, find_packages
__version__ = '0.2.3' __version__ = '0.2.4'
url = 'https://github.com/rusty1s/pytorch_cluster' url = 'https://github.com/rusty1s/pytorch_cluster'
install_requires = ['cffi', 'torch-unique'] install_requires = ['cffi', 'torch-unique']
......
from .functions.grid import sparse_grid_cluster, dense_grid_cluster from .functions.grid import sparse_grid_cluster, dense_grid_cluster
__version__ = '0.2.3' __version__ = '0.2.4'
__all__ = ['sparse_grid_cluster', 'dense_grid_cluster', '__version__'] __all__ = ['sparse_grid_cluster', 'dense_grid_cluster', '__version__']
...@@ -18,11 +18,12 @@ def _preprocess(position, size, batch=None, start=None): ...@@ -18,11 +18,12 @@ def _preprocess(position, size, batch=None, start=None):
# Translate to minimal positive positions if no start was passed. # Translate to minimal positive positions if no start was passed.
if start is None: if start is None:
position = position - position.min(dim=-2, keepdim=True)[0] min = []
else: for i in range(position.size(-1)):
min.append(position[:, i].min())
position = position - position.new(min)
elif start != 0:
position = position - start position = position - start
assert position.min() >= 0, (
'Passed origin resulting in unallowed negative positions')
# If given, append batch to position tensor. # If given, append batch to position tensor.
if batch is not None: if batch is not None:
...@@ -37,10 +38,10 @@ def _preprocess(position, size, batch=None, start=None): ...@@ -37,10 +38,10 @@ def _preprocess(position, size, batch=None, start=None):
def _minimal_cluster_size(position, size): def _minimal_cluster_size(position, size):
max = position.max(dim=0)[0] max = []
while max.dim() > 1: for i in range(position.size(-1)):
max = max.max(dim=0)[0] max.append(position[:, i].max())
cluster_size = (max / size).long() + 1 cluster_size = (size.new(max) / size).long() + 1
return cluster_size return cluster_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment