Commit 3aad518a authored by rusty1s's avatar rusty1s
Browse files

added normalized cut implementation

parent af96c80a
import pytest
import torch
from torch_cluster import graclus_cluster
......@@ -7,5 +6,6 @@ def test_graclus():
edge_index = torch.LongTensor([[0, 0, 0, 1, 2, 3, 3, 3, 4, 5, 5, 5, 6, 6],
[2, 3, 6, 5, 0, 0, 4, 5, 3, 1, 3, 6, 0, 3]])
edge_attr = torch.Tensor([2, 2, 2, 1, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2])
rid = torch.LongTensor([0, 1, 2, 3, 4, 5, 6])
graclus_cluster(edge_index, edge_attr)
graclus_cluster(edge_index, edge_attr=edge_attr, rid=rid)
import torch
def node_degree(edge_index, num_nodes, out=None):
zero = torch.zeros(num_nodes, out=out)
one = torch.ones(edge_index.size(1), out=zero.new())
return zero.scatter_add_(0, edge_index[0], one)
def graclus_cluster(edge_index, edge_attr=None, batch=None):
pass
from __future__ import division
import torch
from .degree import node_degree
def graclus_cluster(edge_index,
num_nodes=None,
edge_attr=None,
batch=None,
rid=None):
num_nodes = edge_index.max() + 1 if num_nodes is None else num_nodes
rid = torch.randperm(num_nodes) if rid is None else rid
cut = normalized_cut(edge_index, num_nodes, edge_attr)
print(cut)
def normalized_cut(edge_index, num_nodes, edge_attr=None):
row, col = edge_index
out = edge_attr.new() if edge_attr is not None else torch.Tensor()
cut = node_degree(edge_index, num_nodes, out=out)
cut = 1 / cut
cut = cut[row] + cut[col]
if edge_attr is None:
return cut
else:
if edge_attr.dim() > 1 and edge_attr.size(1) > 1:
edge_attr = torch.norm(edge_attr, 2, 1)
return edge_attr.squeeze() * cut
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment