from typing import Optional

import torch


def graclus_cluster(
    row: torch.Tensor,
    col: torch.Tensor,
    weight: Optional[torch.Tensor] = None,
    num_nodes: Optional[int] = None,
) -> torch.Tensor:
    """A greedy clustering algorithm of picking an unmarked vertex and matching
    it with one of its unmarked neighbors (that maximizes its edge weight).

    Self-loops are dropped before clustering. When no edge weights are given,
    the edges are randomly permuted first, so results may differ between
    calls unless the torch random seed is fixed.

    Args:
        row (LongTensor): Source nodes.
        col (LongTensor): Target nodes.
        weight (Tensor, optional): Edge weights. (default: :obj:`None`)
        num_nodes (int, optional): The number of nodes. If :obj:`None`, it is
            inferred as the largest index found in :obj:`row` or :obj:`col`
            plus one, or as zero when both are empty.
            (default: :obj:`None`)

    :rtype: :class:`LongTensor`

    .. code-block:: python

        import torch
        from torch_cluster import graclus_cluster

        row = torch.tensor([0, 1, 1, 2])
        col = torch.tensor([1, 0, 2, 1])
        weight = torch.tensor([1., 1., 1., 1.])
        cluster = graclus_cluster(row, col, weight)
    """
    if num_nodes is None:
        if row.numel() == 0 and col.numel() == 0:
            # `max()` raises on zero-element tensors; an empty edge list
            # carries no evidence of any node, so infer an empty graph.
            num_nodes = 0
        else:
            num_nodes = max(int(row.max()), int(col.max())) + 1

    # Remove self-loops.
    mask = row != col
    row, col = row[mask], col[mask]

    if weight is None:
        # Without weights, randomly shuffle the edges before the CSR sort.
        perm = torch.randperm(row.size(0), dtype=torch.long, device=row.device)
        row, col = row[perm], col[perm]
    else:
        weight = weight[mask]

    # To CSR: sort edges by source node and build the row pointer.
    perm = torch.argsort(row)
    row, col = row[perm], col[perm]

    if weight is not None:
        weight = weight[perm]

    # Per-node out-degree; its running sum fills rowptr[1:], rowptr[0] stays 0.
    deg = row.new_zeros(num_nodes)
    deg.scatter_add_(0, row, torch.ones_like(row))
    rowptr = row.new_zeros(num_nodes + 1)
    torch.cumsum(deg, 0, out=rowptr[1:])

    return torch.ops.torch_cluster.graclus(rowptr, col, weight)