Commit ce9a53d1 authored by rusty1s's avatar rusty1s
Browse files

return C with batch

parent 6a2e1a08
...@@ -2,7 +2,7 @@ from os import path as osp ...@@ -2,7 +2,7 @@ from os import path as osp
from setuptools import setup, find_packages from setuptools import setup, find_packages
__version__ = '0.2.0' __version__ = '0.2.1'
url = 'https://github.com/rusty1s/pytorch_cluster' url = 'https://github.com/rusty1s/pytorch_cluster'
install_requires = ['cffi', 'torch-unique'] install_requires = ['cffi', 'torch-unique']
......
...@@ -50,7 +50,7 @@ def test_grid_cluster_cpu(tensor): ...@@ -50,7 +50,7 @@ def test_grid_cluster_cpu(tensor):
output, C = grid_cluster(position, size, batch, fake_nodes=True) output, C = grid_cluster(position, size, batch, fake_nodes=True)
expected = torch.LongTensor([0, 5, 1, 0, 2, 6, 11, 7, 6, 8]) expected = torch.LongTensor([0, 5, 1, 0, 2, 6, 11, 7, 6, 8])
assert output.tolist() == expected.tolist() assert output.tolist() == expected.tolist()
assert C == 6 assert C == 12
@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA') @pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
...@@ -101,4 +101,4 @@ def test_grid_cluster_gpu(tensor): # pragma: no cover ...@@ -101,4 +101,4 @@ def test_grid_cluster_gpu(tensor): # pragma: no cover
output, C = grid_cluster(position, size, batch, fake_nodes=True) output, C = grid_cluster(position, size, batch, fake_nodes=True)
expected = torch.LongTensor([0, 5, 1, 0, 2, 6, 11, 7, 6, 8]) expected = torch.LongTensor([0, 5, 1, 0, 2, 6, 11, 7, 6, 8])
assert output.cpu().tolist() == expected.tolist() assert output.cpu().tolist() == expected.tolist()
assert C == 6 assert C == 12
from .functions.grid import grid_cluster from .functions.grid import grid_cluster
__version__ = '0.2.0' __version__ = '0.2.1'
__all__ = ['grid_cluster', '__version__'] __all__ = ['grid_cluster', '__version__']
...@@ -50,7 +50,7 @@ def grid_cluster(position, size, batch=None, origin=None, fake_nodes=False): ...@@ -50,7 +50,7 @@ def grid_cluster(position, size, batch=None, origin=None, fake_nodes=False):
cluster = cluster.squeeze(dim=-1) cluster = cluster.squeeze(dim=-1)
if fake_nodes: if fake_nodes:
return cluster, C // c_max[0] return cluster, C
cluster, u = consecutive(cluster) cluster, u = consecutive(cluster)
return cluster, None if batch is None else (u / (C // c_max[0])).long() return cluster, None if batch is None else (u / (C // c_max[0])).long()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment