"vscode:/vscode.git/clone" did not exist on "ebf3ab1477dd480df1b8dd5d97a7b4aa3822716b"
Commit 65846a61 authored by rusty1s's avatar rusty1s
Browse files

added random serial cluster boilerplate

parent 546db4dc
import torch
from torch_cluster import graclus_cluster
from torch_cluster import random_cluster
def test_graclus():
def test_random():
edge_index = torch.LongTensor([[0, 0, 0, 1, 2, 3, 3, 3, 4, 5, 5, 5, 6, 6],
[2, 3, 6, 5, 0, 0, 4, 5, 3, 1, 3, 6, 0, 3]])
edge_attr = torch.Tensor([2, 2, 2, 1, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2])
rid = torch.LongTensor([0, 1, 2, 3, 4, 5, 6])
# edge_attr = torch.Tensor([2, 2, 2, 1, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2])
node_rid = torch.arange(edge_index.max() + 1, out=edge_index.new())
edge_rid = torch.arange(edge_index.size(0), out=edge_index.new())
graclus_cluster(edge_index, edge_attr=edge_attr, rid=rid)
random_cluster(edge_index, node_rid, edge_rid)
from .functions.grid import sparse_grid_cluster, dense_grid_cluster
from .functions.graclus import graclus_cluster
from .functions.random import random_cluster
__version__ = '0.2.6'
__all__ = [
'sparse_grid_cluster', 'dense_grid_cluster', 'graclus_cluster',
'sparse_grid_cluster', 'dense_grid_cluster', 'random_cluster',
'__version__'
]
import torch
def node_degree(edge_index, num_nodes, out=None):
def node_degree(row, num_nodes, out=None):
zero = torch.zeros(num_nodes, out=out)
one = torch.ones(edge_index.size(1), out=zero.new())
return zero.scatter_add_(0, edge_index[0], one)
one = torch.ones(row.size(0), out=zero.new())
return zero.scatter_add_(0, row, one)
from __future__ import division
import torch
# from .utils import get_func
from .degree import node_degree
def graclus_cluster(edge_index, num_nodes=None, edge_attr=None, rid=None):
num_nodes = edge_index.max() + 1 if num_nodes is None else num_nodes
rid = torch.randperm(num_nodes) if rid is None else rid
row, col = edge_index
# Compute edge-wise normalized cut.
cut = normalized_cut(edge_index, num_nodes, edge_attr)
# Sort row and col indices based on the (possibly random) `rid`.
_, perm = rid[row].sort()
row, col, cut = row[perm], col[perm], cut[perm]
print(row, col)
cluster = edge_index.new(num_nodes).fill_(-1)
# func = get_func('graclus', cluster)
# func(cluster, row, col, cut)
return cluster
def normalized_cut(edge_index, num_nodes, edge_attr=None):
row, col = edge_index
out = edge_attr.new() if edge_attr is not None else torch.Tensor()
cut = node_degree(edge_index, num_nodes, out=out)
cut = 1 / cut
cut = cut[row] + cut[col]
if edge_attr is None:
return cut
else:
if edge_attr.dim() > 1 and edge_attr.size(1) > 1:
edge_attr = torch.norm(edge_attr, 2, 1)
return edge_attr.squeeze() * cut
......@@ -2,7 +2,7 @@ from __future__ import division
import torch
from .utils import get_func, consecutive
from .utils import get_dynamic_func, consecutive
def _preprocess(position, size, batch=None, start=None):
......@@ -69,7 +69,7 @@ def _grid_cluster(position, size, cluster_size):
cluster = cluster_size.new(torch.Size(list(position.size())[:-1]))
cluster = cluster.unsqueeze(dim=-1)
func = get_func('grid', position)
func = get_dynamic_func('grid', position)
func(C, cluster, position, size, cluster_size)
cluster = cluster.squeeze(dim=-1)
......
from __future__ import division
import torch
def normalized_cut(edge_index, num_nodes, degree, edge_attr=None):
row, col = edge_index
cut = 1 / degree
cut = cut[row] + cut[col]
if edge_attr is None:
return cut
else:
if edge_attr.dim() > 1 and edge_attr.size(1) > 1:
edge_attr = torch.norm(edge_attr, 2, 1)
return edge_attr.squeeze() * cut
import torch
def permute(edge_index, num_nodes, node_rid=None, edge_rid=None):
row, col = edge_index
edge_rid = torch.randperm(row.size(0)) if edge_rid is None else edge_rid
row, col = row[edge_rid], col[edge_rid]
node_rid = torch.randperm(num_nodes) if node_rid is None else node_rid
_, perm = node_rid[row].sort()
row, col = row[perm], col[perm]
return row, col
from .utils import get_func
from .degree import node_degree
from .permute import permute
def random_cluster(edge_index, node_rid=None, edge_rid=None, num_nodes=None):
num_nodes = edge_index.max() + 1 if num_nodes is None else num_nodes
row, col = permute(edge_index, num_nodes, node_rid, edge_rid)
degree = node_degree(row, num_nodes, out=row.new())
cluster = edge_index.new(num_nodes).fill_(-1)
func = get_func('random', cluster)
func(cluster, row, col, degree)
return cluster
......@@ -5,10 +5,14 @@ from .._ext import ffi
def get_func(name, tensor):
cuda = '_cuda' if tensor.is_cuda else ''
return getattr(ffi, 'cluster_{}{}'.format(name, cuda))
def get_dynamic_func(name, tensor):
typename = type(tensor).__name__.replace('Tensor', '')
cuda = 'cuda_' if tensor.is_cuda else ''
func = getattr(ffi, 'cluster_{}_{}{}'.format(name, cuda, typename))
return func
return getattr(ffi, 'cluster_{}_{}{}'.format(name, cuda, typename))
def get_type(max, cuda):
......
......@@ -2,19 +2,19 @@
#define cluster_(NAME) TH_CONCAT_4(cluster_, NAME, _, Real)
void cluster_serial(THLongTensor *output, THLongTensor *row, THLongTensor *col) {
int64_t *output_data = output->storage->data + output->storageOffset;
int64_t *row_data = row->storage->data + row->storageOffset;
int64_t *col_data = col->storage->data + col->storageOffset;
int64_t n, N = THLongTensor_nElement(output), r, c, value;
for (n = 0; n < N; n++) {
r = row_data[n]; c = col_data[c];
if (output_data[r] == -1 && output_data[c] == -1) {
value = r < c ? r : c;
output_data[r] = value;
output_data[c] = value;
}
}
void cluster_random(THLongTensor *output, THLongTensor *row, THLongTensor *col, THLongTensor *degree) {
/* int64_t *output_data = output->storage->data + output->storageOffset; */
/* int64_t *row_data = row->storage->data + row->storageOffset; */
/* int64_t *col_data = col->storage->data + col->storageOffset; */
/* int64_t e, E = THLongTensor_nElement(row), r, c, value; */
/* for (e = 0; e < E; e++) { */
/* r = row_data[e]; c = col_data[e]; */
/* if (output_data[r] == -1 && output_data[c] == -1) { */
/* value = r < c ? r : c; */
/* output_data[r] = value; */
/* output_data[c] = value; */
/* } */
/* } */
}
#include "generic/cpu.c"
......
......@@ -6,4 +6,4 @@ void cluster_grid_Short (int C, THLongTensor *output, THShortTensor *position,
void cluster_grid_Int (int C, THLongTensor *output, THIntTensor *position, THIntTensor *size, THLongTensor *count);
void cluster_grid_Long (int C, THLongTensor *output, THLongTensor *position, THLongTensor *size, THLongTensor *count);
void cluster_serial(THLongTensor *output, THLongTensor *row, THLongTensor *col);
void cluster_random(THLongTensor *output, THLongTensor *row, THLongTensor *col, THLongTensor *degree);
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment