Commit ce9a53d1 authored by rusty1s's avatar rusty1s
Browse files

return C with batch

parent 6a2e1a08
......@@ -2,7 +2,7 @@ from os import path as osp
from setuptools import setup, find_packages
__version__ = '0.2.0'
__version__ = '0.2.1'
url = 'https://github.com/rusty1s/pytorch_cluster'
install_requires = ['cffi', 'torch-unique']
......
......@@ -50,7 +50,7 @@ def test_grid_cluster_cpu(tensor):
output, C = grid_cluster(position, size, batch, fake_nodes=True)
expected = torch.LongTensor([0, 5, 1, 0, 2, 6, 11, 7, 6, 8])
assert output.tolist() == expected.tolist()
assert C == 6
assert C == 12
@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
......@@ -101,4 +101,4 @@ def test_grid_cluster_gpu(tensor): # pragma: no cover
output, C = grid_cluster(position, size, batch, fake_nodes=True)
expected = torch.LongTensor([0, 5, 1, 0, 2, 6, 11, 7, 6, 8])
assert output.cpu().tolist() == expected.tolist()
assert C == 6
assert C == 12
from .functions.grid import grid_cluster
__version__ = '0.2.0'
__version__ = '0.2.1'
__all__ = ['grid_cluster', '__version__']
......@@ -50,7 +50,7 @@ def grid_cluster(position, size, batch=None, origin=None, fake_nodes=False):
cluster = cluster.squeeze(dim=-1)
if fake_nodes:
return cluster, C // c_max[0]
return cluster, C
cluster, u = consecutive(cluster)
return cluster, None if batch is None else (u / (C // c_max[0])).long()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment