Commit 70afd38e authored by rusty1s's avatar rusty1s
Browse files

fixed cluster start arg bugfix

parent 5b994648
...@@ -2,7 +2,7 @@ from os import path as osp ...@@ -2,7 +2,7 @@ from os import path as osp
from setuptools import setup, find_packages from setuptools import setup, find_packages
__version__ = '0.2.5' __version__ = '0.2.6'
url = 'https://github.com/rusty1s/pytorch_cluster' url = 'https://github.com/rusty1s/pytorch_cluster'
install_requires = ['cffi', 'torch-unique'] install_requires = ['cffi', 'torch-unique']
......
from .functions.grid import sparse_grid_cluster, dense_grid_cluster from .functions.grid import sparse_grid_cluster, dense_grid_cluster
__version__ = '0.2.5' __version__ = '0.2.6'
__all__ = ['sparse_grid_cluster', 'dense_grid_cluster', '__version__'] __all__ = ['sparse_grid_cluster', 'dense_grid_cluster', '__version__']
...@@ -21,6 +21,7 @@ def _preprocess(position, size, batch=None, start=None): ...@@ -21,6 +21,7 @@ def _preprocess(position, size, batch=None, start=None):
min = [] min = []
for i in range(position.size(-1)): for i in range(position.size(-1)):
min.append(position[:, i].min()) min.append(position[:, i].min())
start = position.new(min)
position = position - position.new(min) position = position - position.new(min)
else: else:
assert start.numel() == size.numel(), ( assert start.numel() == size.numel(), (
...@@ -36,7 +37,7 @@ def _preprocess(position, size, batch=None, start=None): ...@@ -36,7 +37,7 @@ def _preprocess(position, size, batch=None, start=None):
position = torch.cat([batch, position], dim=-1) position = torch.cat([batch, position], dim=-1)
size = torch.cat([size.new(1).fill_(1), size], dim=-1) size = torch.cat([size.new(1).fill_(1), size], dim=-1)
return position, size return position, size, start
def _minimal_cluster_size(position, size): def _minimal_cluster_size(position, size):
...@@ -47,11 +48,11 @@ def _minimal_cluster_size(position, size): ...@@ -47,11 +48,11 @@ def _minimal_cluster_size(position, size):
return cluster_size return cluster_size
def _fixed_cluster_size(position, size, batch=None, end=None): def _fixed_cluster_size(position, size, start, batch=None, end=None):
if end is None: if end is None:
return _minimal_cluster_size(position, size) return _minimal_cluster_size(position, size)
end = end.type_as(size) end = end.type_as(size) - start.type_as(size)
eps = 0.000001 # Simulate [start, end) interval. eps = 0.000001 # Simulate [start, end) interval.
if batch is None: if batch is None:
cluster_size = ((end / size).float() - eps).long() + 1 cluster_size = ((end / size).float() - eps).long() + 1
...@@ -76,7 +77,7 @@ def _grid_cluster(position, size, cluster_size): ...@@ -76,7 +77,7 @@ def _grid_cluster(position, size, cluster_size):
def sparse_grid_cluster(position, size, batch=None, start=None): def sparse_grid_cluster(position, size, batch=None, start=None):
position, size = _preprocess(position, size, batch, start) position, size, start = _preprocess(position, size, batch, start)
cluster_size = _minimal_cluster_size(position, size) cluster_size = _minimal_cluster_size(position, size)
cluster, C = _grid_cluster(position, size, cluster_size) cluster, C = _grid_cluster(position, size, cluster_size)
cluster, u = consecutive(cluster) cluster, u = consecutive(cluster)
...@@ -89,7 +90,7 @@ def sparse_grid_cluster(position, size, batch=None, start=None): ...@@ -89,7 +90,7 @@ def sparse_grid_cluster(position, size, batch=None, start=None):
def dense_grid_cluster(position, size, batch=None, start=None, end=None): def dense_grid_cluster(position, size, batch=None, start=None, end=None):
position, size = _preprocess(position, size, batch, start) position, size, start = _preprocess(position, size, batch, start)
cluster_size = _fixed_cluster_size(position, size, batch, end) cluster_size = _fixed_cluster_size(position, size, start, batch, end)
cluster, C = _grid_cluster(position, size, cluster_size) cluster, C = _grid_cluster(position, size, cluster_size)
return cluster, C return cluster, C
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment