Commit 7985cdd8 authored by rusty1s's avatar rusty1s
Browse files

removed self loops in graclus

parent d2ee1523
......@@ -2,7 +2,7 @@ from os import path as osp
from setuptools import setup, find_packages
__version__ = '1.0.1'
__version__ = '1.0.2'
url = 'https://github.com/rusty1s/pytorch_cluster'
install_requires = ['cffi']
......
from .graclus import graclus_cluster
from .grid import grid_cluster
__version__ = '1.0.1'
__version__ = '1.0.2'
__all__ = ['graclus_cluster', 'grid_cluster', '__version__']
from .utils.loop import remove_self_loops
from .utils.perm import randperm, sort_row, randperm_sort_row
from .utils.ffi import graclus
......@@ -19,7 +20,6 @@ def graclus_cluster(row, col, weight=None, num_nodes=None):
>>> weight = torch.Tensor([1, 1, 1, 1])
>>> cluster = graclus_cluster(row, col, weight)
"""
num_nodes = row.max() + 1 if num_nodes is None else num_nodes
if row.is_cuda: # pragma: no cover
......@@ -28,6 +28,7 @@ def graclus_cluster(row, col, weight=None, num_nodes=None):
row, col = randperm(row, col)
row, col = randperm_sort_row(row, col, num_nodes)
row, col = remove_self_loops(row, col)
cluster = row.new(num_nodes)
graclus(cluster, row, col, weight)
......
def remove_self_loops(row, col):
mask = row != col
return row[mask], col[mask]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment