Commit 70afd38e authored by rusty1s's avatar rusty1s
Browse files

fixed cluster start arg bugfix

parent 5b994648
......@@ -2,7 +2,7 @@ from os import path as osp
from setuptools import setup, find_packages
__version__ = '0.2.5'
__version__ = '0.2.6'
url = 'https://github.com/rusty1s/pytorch_cluster'
install_requires = ['cffi', 'torch-unique']
......
from .functions.grid import sparse_grid_cluster, dense_grid_cluster
__version__ = '0.2.5'
__version__ = '0.2.6'
__all__ = ['sparse_grid_cluster', 'dense_grid_cluster', '__version__']
......@@ -21,6 +21,7 @@ def _preprocess(position, size, batch=None, start=None):
min = []
for i in range(position.size(-1)):
min.append(position[:, i].min())
start = position.new(min)
position = position - position.new(min)
else:
assert start.numel() == size.numel(), (
......@@ -36,7 +37,7 @@ def _preprocess(position, size, batch=None, start=None):
position = torch.cat([batch, position], dim=-1)
size = torch.cat([size.new(1).fill_(1), size], dim=-1)
return position, size
return position, size, start
def _minimal_cluster_size(position, size):
......@@ -47,11 +48,11 @@ def _minimal_cluster_size(position, size):
return cluster_size
def _fixed_cluster_size(position, size, batch=None, end=None):
def _fixed_cluster_size(position, size, start, batch=None, end=None):
if end is None:
return _minimal_cluster_size(position, size)
end = end.type_as(size)
end = end.type_as(size) - start.type_as(size)
eps = 0.000001 # Simulate [start, end) interval.
if batch is None:
cluster_size = ((end / size).float() - eps).long() + 1
......@@ -76,7 +77,7 @@ def _grid_cluster(position, size, cluster_size):
def sparse_grid_cluster(position, size, batch=None, start=None):
position, size = _preprocess(position, size, batch, start)
position, size, start = _preprocess(position, size, batch, start)
cluster_size = _minimal_cluster_size(position, size)
cluster, C = _grid_cluster(position, size, cluster_size)
cluster, u = consecutive(cluster)
......@@ -89,7 +90,7 @@ def sparse_grid_cluster(position, size, batch=None, start=None):
def dense_grid_cluster(position, size, batch=None, start=None, end=None):
position, size = _preprocess(position, size, batch, start)
cluster_size = _fixed_cluster_size(position, size, batch, end)
position, size, start = _preprocess(position, size, batch, start)
cluster_size = _fixed_cluster_size(position, size, start, batch, end)
cluster, C = _grid_cluster(position, size, cluster_size)
return cluster, C
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment