Commit 4056bf63 authored by rusty1s's avatar rusty1s
Browse files

added random cluster

parent 65846a61
......@@ -6,7 +6,8 @@ def test_random():
edge_index = torch.LongTensor([[0, 0, 0, 1, 2, 3, 3, 3, 4, 5, 5, 5, 6, 6],
[2, 3, 6, 5, 0, 0, 4, 5, 3, 1, 3, 6, 0, 3]])
# edge_attr = torch.Tensor([2, 2, 2, 1, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2])
node_rid = torch.arange(edge_index.max() + 1, out=edge_index.new())
edge_rid = torch.arange(edge_index.size(0), out=edge_index.new())
rid = torch.arange(edge_index.max() + 1, out=edge_index.new())
output = random_cluster(edge_index, rid, perm_edges=False)
random_cluster(edge_index, node_rid, edge_rid)
expected_output = [0, 1, 2, 0, 4, 1, 6]
assert output.tolist() == expected_output
import torch
def permute(edge_index, num_nodes, node_rid=None, edge_rid=None):
def permute(edge_index, num_nodes, rid=None, perm_edges=True):
row, col = edge_index
edge_rid = torch.randperm(row.size(0)) if edge_rid is None else edge_rid
row, col = row[edge_rid], col[edge_rid]
if perm_edges:
edge_rid = torch.randperm(row.size(0))
row, col = row[edge_rid], col[edge_rid]
node_rid = torch.randperm(num_nodes) if node_rid is None else node_rid
_, perm = node_rid[row].sort()
rid = torch.randperm(num_nodes) if rid is None else rid
_, perm = rid[row].sort()
row, col = row[perm], col[perm]
return row, col
return torch.stack([row, col], dim=0)
......@@ -3,9 +3,9 @@ from .degree import node_degree
from .permute import permute
def random_cluster(edge_index, node_rid=None, edge_rid=None, num_nodes=None):
def random_cluster(edge_index, rid=None, perm_edges=True, num_nodes=None):
num_nodes = edge_index.max() + 1 if num_nodes is None else num_nodes
row, col = permute(edge_index, num_nodes, node_rid, edge_rid)
row, col = permute(edge_index, num_nodes, rid, perm_edges)
degree = node_degree(row, num_nodes, out=row.new())
cluster = edge_index.new(num_nodes).fill_(-1)
......
......@@ -3,18 +3,36 @@
#define cluster_(NAME) TH_CONCAT_4(cluster_, NAME, _, Real)
void cluster_random(THLongTensor *output, THLongTensor *row, THLongTensor *col, THLongTensor *degree) {
/* int64_t *output_data = output->storage->data + output->storageOffset; */
/* int64_t *row_data = row->storage->data + row->storageOffset; */
/* int64_t *col_data = col->storage->data + col->storageOffset; */
/* int64_t e, E = THLongTensor_nElement(row), r, c, value; */
/* for (e = 0; e < E; e++) { */
/* r = row_data[e]; c = col_data[e]; */
/* if (output_data[r] == -1 && output_data[c] == -1) { */
/* value = r < c ? r : c; */
/* output_data[r] = value; */
/* output_data[c] = value; */
/* } */
/* } */
int64_t *output_data = output->storage->data + output->storageOffset;
int64_t *row_data = row->storage->data + row->storageOffset;
int64_t *col_data = col->storage->data + col->storageOffset;
int64_t *degree_data = degree->storage->data + degree->storageOffset;
int64_t e = 0, row_value, col_value, i;
while(e < THLongTensor_nElement(row)) {
row_value = row_data[e];
if (output_data[row_value] < 0) { // Node is unmatched.
// Find next unmatched neighbor.
col_value = -1;
for (i = 0; i < degree_data[row_value]; i++) {
col_value = col_data[e + i];
if (output_data[col_value] < 0) break; // Neighbor found.
else col_value = -1;
}
// Set output.
if (col_value < 0) {
output_data[row_value] = row_value;
}
else {
i = row_value < col_value ? row_value : col_value;
output_data[row_value] = i;
output_data[col_value] = i;
}
}
e += degree_data[row_value];
}
}
#include "generic/cpu.c"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment