Commit 920cc934 authored by rusty1s's avatar rusty1s
Browse files

graclus cuda, cleanup old code

parent d2cc3162
#include <ATen/ATen.h>
#include "color.cuh"
#include "common.cuh"
// Greedy graclus matching (unweighted): returns one cluster id per node.
at::Tensor graclus(at::Tensor row, at::Tensor col, int num_nodes) {
  // Remove self-loops.
  auto mask = row != col;
  row = row.masked_select(mask);
  // BUG FIX: masked_select returns a new tensor; the result was previously
  // discarded, leaving `col` misaligned with the filtered `row`.
  col = col.masked_select(mask);
  // Sort by row index.
  at::Tensor perm;
  std::tie(row, perm) = row.sort();
  col = col.index_select(0, perm);
  // Generate helper vectors (-1 marks "unmatched"/"no proposal").
  auto cluster = at::full(row.type(), {num_nodes}, -1);
  auto prop = at::full(row.type(), {num_nodes}, -1);
  auto deg = degree(row, num_nodes);
  auto cum_deg = deg.cumsum(0);
  color(cluster);
  /* while (!color(cluster)) { */
  /*   propose(cluster, prop, row, col, weight, deg, cum_deg); */
  /*   response(cluster, prop, row, col, weight, deg, cum_deg); */
  /* } */
  return cluster;
}
// Weighted greedy graclus matching: returns one cluster id per node.
at::Tensor weighted_graclus(at::Tensor row, at::Tensor col, at::Tensor weight,
                            int num_nodes) {
  // Drop self-loops so a node can never be matched with itself.
  auto keep = row != col;
  row = row.masked_select(keep);
  col = col.masked_select(keep);
  weight = weight.masked_select(keep);
  // Bring the edge list into row-major order.
  at::Tensor order;
  std::tie(row, order) = row.sort();
  col = col.index_select(0, order);
  weight = weight.index_select(0, order);
  // Allocate helper vectors (-1 marks "unmatched"/"no proposal").
  auto cluster = at::full(row.type(), {num_nodes}, -1);
  auto prop = at::full(row.type(), {num_nodes}, -1);
  auto deg = degree(row, num_nodes);
  auto cum_deg = deg.cumsum(0);
  color(cluster);
  /* while (!color(cluster)) { */
  /*   weighted_propose(cluster, prop, row, col, weight, deg, cum_deg); */
  /*   weighted_response(cluster, prop, row, col, weight, deg, cum_deg); */
  /* } */
  return cluster;
}
#include <ATen/ATen.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include "common.cuh"
// Assigns each node to a voxel of a regular grid.
// `pos` is a (num_nodes, dim) tensor; `size`, `start` and `end` hold the
// per-dimension voxel extent and bounding box.
template <typename scalar_t>
__global__ void
grid_kernel(int64_t *cluster, at::cuda::detail::TensorInfo<scalar_t, int> pos,
            scalar_t *__restrict__ size, scalar_t *__restrict__ start,
            scalar_t *__restrict__ end, size_t num_nodes) {
  const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  // Grid-stride loop over nodes.
  for (ptrdiff_t i = index; i < num_nodes; i += stride) {
    int64_t c = 0, k = 1;
    scalar_t tmp;
    for (ptrdiff_t d = 0; d < pos.sizes[1]; d++) {
      tmp = pos.data[i * pos.strides[0] + d * pos.strides[1]] - start[d];
      c += (int64_t)(tmp / size[d]) * k;
      // BUG FIX: `k` is the positional multiplier for the next dimension and
      // must be the *product* of the voxel counts of all previous dimensions
      // (each dimension has floor(extent / size) + 1 voxels), not their sum —
      // otherwise distinct voxels collide onto the same cluster id.
      k *= (int64_t)((end[d] - start[d]) / size[d]) + 1;
    }
    cluster[i] = c;
  }
}
// Host wrapper: computes an int64 grid-cluster id for every row of `pos`.
at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start,
                at::Tensor end) {
  auto cluster = at::empty(pos.type().toScalarType(at::kLong), {pos.size(0)});
  AT_DISPATCH_ALL_TYPES(pos.type(), "grid_kernel", [&] {
    // Convert once into named locals so the converted tensors provably
    // outlive the (asynchronous) kernel launch instead of being temporaries
    // of the launch expression.
    auto size_c = size.toType(pos.type());
    auto start_c = start.toType(pos.type());
    auto end_c = end.toType(pos.type());
    grid_kernel<scalar_t><<<BLOCKS(pos.size(0)), THREADS>>>(
        cluster.data<int64_t>(),
        at::cuda::detail::getTensorInfo<scalar_t, int>(pos),
        size_c.data<scalar_t>(), start_c.data<scalar_t>(),
        end_c.data<scalar_t>(), pos.size(0));
  });
  return cluster;
}
import glob

from setuptools import setup
import torch.cuda
# Import BuildExtension explicitly instead of relying on the
# `torch.utils.cpp_extension` attribute path being reachable via `torch`.
from torch.utils.cpp_extension import (BuildExtension, CppExtension,
                                       CUDAExtension)

# Always build the CPU extension; add the CUDA extension when a GPU build
# environment is available.
ext_modules = [CppExtension('cluster_cpu', ['cpu/cluster.cpp'])]
if torch.cuda.is_available():
    ext_modules += [
        CUDAExtension('cluster_cuda',
                      ['cuda/cluster.cpp'] + glob.glob('cuda/*.cu'))
    ]

setup(
    name='cluster',
    ext_modules=ext_modules,
    # BuildExtension injects the right compiler/nvcc flags for the extensions.
    cmdclass={'build_ext': BuildExtension},
)
......@@ -2,18 +2,10 @@
#include "utils.h"
// Iterates over the CSR neighborhood of NODE: for every outgoing edge `e`
// (ROW is the rowptr array, COL the column indices), binds the neighbor id
// to NAME and executes the trailing statements (__VA_ARGS__).
// NOTE: comments must stay outside the macro — the body lines end in `\`.
#define ITERATE_NEIGHBORS(NODE, NAME, ROW, COL, ...) \
{ \
for (int64_t e = ROW[NODE]; e < ROW[NODE + 1]; e++) { \
auto NAME = COL[e]; \
__VA_ARGS__; \
} \
}
at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) {
std::tie(row, col) = remove_self_loops(row, col);
std::tie(row, col) = rand(row, col);
std::tie(row, col) = to_csr(row, col);
std::tie(row, col) = to_csr(row, col, num_nodes);
auto row_data = row.data<int64_t>(), col_data = col.data<int64_t>();
auto perm = randperm(num_nodes);
......@@ -30,14 +22,16 @@ at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) {
cluster_data[u] = u;
ITERATE_NEIGHBORS(u, v, row_data, col_data, {
for (int64_t j = row_data[u]; j < row_data[u + 1]; j++) {
auto v = col_data[j];
if (cluster_data[v] >= 0)
continue;
cluster_data[u] = std::min(u, v);
cluster_data[v] = std::min(u, v);
break;
});
}
}
return cluster;
......@@ -45,8 +39,8 @@ at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) {
at::Tensor weighted_graclus(at::Tensor row, at::Tensor col, at::Tensor weight,
int64_t num_nodes) {
std::tie(row, col) = remove_self_loops(row, col, weight);
std::tie(row, col, weight) = to_csr(row, col, weight);
std::tie(row, col, weight) = remove_self_loops(row, col, weight);
std::tie(row, col, weight) = to_csr(row, col, weight, num_nodes);
auto row_data = row.data<int64_t>(), col_data = col.data<int64_t>();
auto perm = randperm(num_nodes);
......@@ -57,7 +51,6 @@ at::Tensor weighted_graclus(at::Tensor row, at::Tensor col, at::Tensor weight,
AT_DISPATCH_ALL_TYPES(weight.type(), "weighted_graclus", [&] {
auto weight_data = weight.data<scalar_t>();
auto weight_data = weight.data<scalar_t>();
for (int64_t i = 0; i < num_nodes; i++) {
auto u = perm_data[i];
......@@ -65,21 +58,20 @@ at::Tensor weighted_graclus(at::Tensor row, at::Tensor col, at::Tensor weight,
if (cluster_data[u] >= 0)
continue;
cluster_data[u] = u;
int64_t v_max;
int64_t v_max = u;
scalar_t w_max = 0;
ITERATE_NEIGHBORS(u, v, row_data, col_data, {
for (int64_t j = row_data[u]; j < row_data[u + 1]; j++) {
auto v = col_data[j];
if (cluster_data[v] >= 0)
continue;
auto w = weight_data[e];
if (w >= w_max) {
if (weight_data[j] >= w_max) {
v_max = v;
w_max = w;
w_max = weight_data[j];
}
});
}
cluster_data[u] = std::min(u, v_max);
cluster_data[v_max] = std::min(u, v_max);
......
......@@ -5,14 +5,14 @@
std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row,
at::Tensor col) {
auto mask = row != col;
return make_tuple(row.masked_select(mask), col.masked_select(mask));
return std::make_tuple(row.masked_select(mask), col.masked_select(mask));
}
std::tuple<at::Tensor, at::Tensor, at::Tensor>
remove_self_loops(at::Tensor row, at::Tensor col, at::Tensor weight) {
auto mask = row != col;
return make_tuple(row.masked_select(mask), col.masked_select(mask),
weight.masked_select(mask));
return std::make_tuple(row.masked_select(mask), col.masked_select(mask),
weight.masked_select(mask));
}
at::Tensor randperm(int64_t n) {
......@@ -23,28 +23,41 @@ at::Tensor randperm(int64_t n) {
std::tuple<at::Tensor, at::Tensor> rand(at::Tensor row, at::Tensor col) {
auto perm = randperm(row.size(0));
return make_tuple(row.index_select(perm), col.index_select(perm));
return std::make_tuple(row.index_select(0, perm), col.index_select(0, perm));
}
std::tuple<at::Tensor, at::Tensor> sort_by_row(at::Tensor row, at::Tensor col) {
Tensor perm;
tie(row, perm) = row.sort();
col = col.index_select(0, perm);
return stack({row, col}, 0);
at::Tensor perm;
std::tie(row, perm) = row.sort();
return std::make_tuple(row, col.index_select(0, perm));
}
inline Tensor degree(Tensor row, int64_t num_nodes) {
auto zero = zeros(num_nodes, row.type());
auto one = ones(row.size(0), row.type());
std::tuple<at::Tensor, at::Tensor, at::Tensor>
sort_by_row(at::Tensor row, at::Tensor col, at::Tensor weight) {
at::Tensor perm;
std::tie(row, perm) = row.sort();
return std::make_tuple(row, col.index_select(0, perm),
weight.index_select(0, perm));
}
at::Tensor degree(at::Tensor row, int64_t num_nodes) {
auto zero = zeros(num_nodes, row.options());
auto one = ones(row.size(0), row.options());
return zero.scatter_add_(0, row, one);
}
inline tuple<Tensor, Tensor> to_csr(Tensor index, int64_t num_nodes) {
index = sort_by_row(index);
auto row = degree(index[0], num_nodes).cumsum(0);
row = cat({zeros(1, row.type()), row}, 0); // Prepend zero.
return make_tuple(row, index[1]);
std::tuple<at::Tensor, at::Tensor> to_csr(at::Tensor row, at::Tensor col,
int64_t num_nodes) {
std::tie(row, col) = sort_by_row(row, col);
row = degree(row, num_nodes).cumsum(0);
row = at::cat({zeros(1, row.options()), row}, 0); // Prepend zero.
return std::make_tuple(row, col);
}
// std::tie(row, col) = randperm(row, col);
// std::tie(row, col) = to_csr(row, col);
std::tuple<at::Tensor, at::Tensor, at::Tensor>
to_csr(at::Tensor row, at::Tensor col, at::Tensor weight, int64_t num_nodes) {
std::tie(row, col, weight) = sort_by_row(row, col, weight);
row = degree(row, num_nodes).cumsum(0);
row = at::cat({zeros(1, row.options()), row}, 0); // Prepend zero.
return std::make_tuple(row, col, weight);
}
#pragma once
#include <ATen/ATen.h>
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
#define BLUE_PROB 0.53406
__device__ int64_t done;
__global__ void init_done_kernel() { done = 1; }
// Randomly colors all still-unmatched nodes: a Bernoulli draw of 1 maps to
// -1 ("blue"), 0 maps to -2 ("red").  Whenever at least one node had to be
// (re)colored, the device-global `done` flag is cleared so the host keeps
// iterating.
__global__ void colorize_kernel(int64_t *cluster, float *__restrict__ bernoulli,
size_t numel) {
// Grid-stride loop over all nodes.
const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (int64_t u = index; u < numel; u += stride) {
if (cluster[u] < 0) { // Negative entries are still unmatched.
cluster[u] = (int64_t)bernoulli[u] - 2; // 1 -> -1 (blue), 0 -> -2 (red).
done = 0; // Signal the host that matching is not finished.
}
}
}
// Host wrapper: recolors unmatched nodes and reports completion.
// Returns 1 iff no node needed recoloring (i.e. every node is matched).
int64_t colorize(at::Tensor cluster) {
init_done_kernel<<<1, 1>>>(); // Reset the device-global `done` flag to 1.
auto numel = cluster.size(0);
// One Bernoulli sample per node; BLUE_PROB is the probability of "blue".
auto props = at::full(numel, BLUE_PROB, cluster.options().dtype(at::kFloat));
auto bernoulli = props.bernoulli();
colorize_kernel<<<BLOCKS(numel), THREADS>>>(cluster.data<int64_t>(),
bernoulli.data<float>(), numel);
int64_t out;
// Blocking default-stream copy: also synchronizes with the launches above.
cudaMemcpyFromSymbol(&out, done, sizeof(out), 0, cudaMemcpyDeviceToHost);
return out;
}
#include <ATen/ATen.h>
#include "coloring.cuh"
#include "proposal.cuh"
#include "response.cuh"
#include "utils.cuh"
// Parallel greedy graclus matching on the GPU (handshaking scheme):
// repeatedly color unmatched nodes blue/red, let blue nodes propose to a red
// neighbor, and let red nodes accept one proposal, until all nodes are
// matched.  Returns one cluster id per node.
at::Tensor graclus_cuda(at::Tensor row, at::Tensor col, int64_t num_nodes) {
std::tie(row, col) = remove_self_loops(row, col);
std::tie(row, col) = rand(row, col); // Shuffle edges for random tie-breaking.
std::tie(row, col) = to_csr(row, col, num_nodes);
auto cluster = at::full(num_nodes, -1, row.options()); // -1 == unmatched.
auto proposal = at::full(num_nodes, -1, row.options());
// colorize() returns 1 once no unmatched node remains.
while (!colorize(cluster)) {
propose(cluster, proposal, row, col);
respond(cluster, proposal, row, col);
}
return cluster;
}
// Weighted variant of graclus_cuda: blue nodes propose to their maximum
// weighted unmatched red neighbor and red nodes accept the maximum weighted
// proposal.  Returns one cluster id per node.
// CLEANUP: this region previously interleaved a stub (which returned all -1
// clusters without doing any matching) with large blocks of dead,
// commented-out code; only the live implementation is kept.
at::Tensor weighted_graclus_cuda(at::Tensor row, at::Tensor col,
                                 at::Tensor weight, int64_t num_nodes) {
  std::tie(row, col, weight) = remove_self_loops(row, col, weight);
  std::tie(row, col, weight) = to_csr(row, col, weight, num_nodes);
  auto cluster = at::full(num_nodes, -1, row.options()); // -1 == unmatched.
  auto proposal = at::full(num_nodes, -1, row.options());
  // Iterate color/propose/respond rounds until every node is matched.
  while (!colorize(cluster)) {
    propose(cluster, proposal, row, col, weight);
    respond(cluster, proposal, row, col, weight);
  }
  return cluster;
}
#pragma once
#include <ATen/ATen.h>
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
// Blue nodes (-1) propose to their first red (-2) neighbor.  A blue node
// with no unmatched neighbor at all matches itself.
// CONSISTENCY FIX: `__restrict` on `row` normalized to `__restrict__` like
// every other pointer parameter in this file.
__global__ void propose_kernel(int64_t *__restrict__ cluster, int64_t *proposal,
                               int64_t *__restrict__ row,
                               int64_t *__restrict__ col, size_t numel) {
  const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  // Grid-stride loop over nodes; row/col are in CSR form.
  for (int64_t u = index; u < numel; u += stride) {
    if (cluster[u] != -1)
      continue; // Only visit blue nodes.
    bool has_unmatched_neighbor = false;
    for (int64_t i = row[u]; i < row[u + 1]; i++) {
      auto v = col[i];
      if (cluster[v] < 0)
        has_unmatched_neighbor = true; // Unmatched neighbor found.
      if (cluster[v] == -2) {
        proposal[u] = v; // Propose to first red neighbor.
        break;
      }
    }
    if (!has_unmatched_neighbor)
      cluster[u] = u; // No partner possible: match with itself.
  }
}
// Launches the unweighted proposal kernel over all nodes.
void propose(at::Tensor cluster, at::Tensor proposal, at::Tensor row,
             at::Tensor col) {
  const auto numel = cluster.numel();
  propose_kernel<<<BLOCKS(numel), THREADS>>>(
      cluster.data<int64_t>(), proposal.data<int64_t>(), row.data<int64_t>(),
      col.data<int64_t>(), numel);
}
// Blue nodes (-1) propose to their maximum weighted red (-2) neighbor.
// A blue node with no unmatched neighbor at all matches itself.
// CONSISTENCY FIX: `__restrict` on `row` normalized to `__restrict__`.
template <typename scalar_t>
__global__ void propose_kernel(int64_t *__restrict__ cluster, int64_t *proposal,
                               int64_t *__restrict__ row,
                               int64_t *__restrict__ col,
                               scalar_t *__restrict__ weight, size_t numel) {
  const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (int64_t u = index; u < numel; u += stride) {
    if (cluster[u] != -1)
      continue; // Only visit blue nodes.
    bool has_unmatched_neighbor = false;
    int64_t v_max = -1;
    // NOTE(review): w_max starts at 0, so red neighbors reachable only via
    // negative-weight edges are never selected — confirm edge weights are
    // assumed non-negative.
    scalar_t w_max = 0;
    for (int64_t i = row[u]; i < row[u + 1]; i++) {
      auto v = col[i];
      if (cluster[v] < 0)
        has_unmatched_neighbor = true; // Unmatched neighbor found.
      // Find maximum weighted red neighbor.
      if (cluster[v] == -2 && weight[i] >= w_max) {
        v_max = v;
        w_max = weight[i];
      }
    }
    proposal[u] = v_max; // Propose (-1 when no red neighbor exists).
    if (!has_unmatched_neighbor)
      cluster[u] = u; // No partner possible: match with itself.
  }
}
// Launches the weighted proposal kernel, dispatched on the weight dtype.
void propose(at::Tensor cluster, at::Tensor proposal, at::Tensor row,
             at::Tensor col, at::Tensor weight) {
  const auto numel = cluster.numel();
  AT_DISPATCH_ALL_TYPES(weight.type(), "propose_kernel", [&] {
    propose_kernel<scalar_t><<<BLOCKS(numel), THREADS>>>(
        cluster.data<int64_t>(), proposal.data<int64_t>(), row.data<int64_t>(),
        col.data<int64_t>(), weight.data<scalar_t>(), numel);
  });
}
#pragma once
#include <ATen/ATen.h>
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
// Red nodes (-2) match with the first blue neighbor that proposed to them.
// A red node with no unmatched neighbor at all matches itself.
// CONSISTENCY FIX: `__restrict` on `row` normalized to `__restrict__`.
__global__ void respond_kernel(int64_t *__restrict__ cluster, int64_t *proposal,
                               int64_t *__restrict__ row,
                               int64_t *__restrict__ col, size_t numel) {
  const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (int64_t u = index; u < numel; u += stride) {
    if (cluster[u] != -2)
      continue; // Only visit red nodes.
    bool has_unmatched_neighbor = false;
    for (int64_t i = row[u]; i < row[u + 1]; i++) {
      auto v = col[i];
      if (cluster[v] < 0)
        has_unmatched_neighbor = true; // Unmatched neighbor found.
      if (cluster[v] == -1 && proposal[v] == u) {
        // Match first blue neighbor v which proposed to u; both sides get
        // the smaller node id as cluster id.
        cluster[u] = min(u, v);
        cluster[v] = min(u, v);
        break;
      }
    }
    if (!has_unmatched_neighbor)
      cluster[u] = u; // No partner possible: match with itself.
  }
}
// Launches the unweighted response kernel over all nodes.
void respond(at::Tensor cluster, at::Tensor proposal, at::Tensor row,
             at::Tensor col) {
  const auto numel = cluster.numel();
  respond_kernel<<<BLOCKS(numel), THREADS>>>(
      cluster.data<int64_t>(), proposal.data<int64_t>(), row.data<int64_t>(),
      col.data<int64_t>(), numel);
}
// Red nodes (-2) accept the maximum weighted proposal among their blue
// neighbors.  A red node with no unmatched neighbor at all matches itself.
// CONSISTENCY FIX: `__restrict` on `row` normalized to `__restrict__`.
template <typename scalar_t>
__global__ void respond_kernel(int64_t *__restrict__ cluster, int64_t *proposal,
                               int64_t *__restrict__ row,
                               int64_t *__restrict__ col,
                               scalar_t *__restrict__ weight, size_t numel) {
  const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (int64_t u = index; u < numel; u += stride) {
    if (cluster[u] != -2)
      continue; // Only visit red nodes.
    bool has_unmatched_neighbor = false;
    int64_t v_max = -1;
    // NOTE(review): w_max starts at 0, so proposals over negative-weight
    // edges are ignored — confirm edge weights are assumed non-negative.
    scalar_t w_max = 0;
    for (int64_t i = row[u]; i < row[u + 1]; i++) {
      auto v = col[i];
      if (cluster[v] < 0)
        has_unmatched_neighbor = true; // Unmatched neighbor found.
      if (cluster[v] == -1 && proposal[v] == u && weight[i] >= w_max) {
        // Find maximum weighted blue neighbor v which proposed to u.
        v_max = v;
        w_max = weight[i];
      }
    }
    if (v_max >= 0) {
      cluster[u] = min(u, v_max); // Match neighbors.
      cluster[v_max] = min(u, v_max);
    }
    if (!has_unmatched_neighbor)
      cluster[u] = u; // No partner possible: match with itself.
  }
}
// Launches the weighted response kernel, dispatched on the weight dtype.
void respond(at::Tensor cluster, at::Tensor proposal, at::Tensor row,
             at::Tensor col, at::Tensor weight) {
  const auto numel = cluster.numel();
  AT_DISPATCH_ALL_TYPES(weight.type(), "respond_kernel", [&] {
    respond_kernel<scalar_t><<<BLOCKS(numel), THREADS>>>(
        cluster.data<int64_t>(), proposal.data<int64_t>(), row.data<int64_t>(),
        col.data<int64_t>(), weight.data<scalar_t>(), numel);
  });
}
#pragma once
#include <ATen/ATen.h>
// Filters out edges that connect a node to itself.
std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row,
                                                     at::Tensor col) {
  auto keep = row != col;
  row = row.masked_select(keep);
  col = col.masked_select(keep);
  return std::make_tuple(row, col);
}
// Filters out self-loop edges together with their weights.
std::tuple<at::Tensor, at::Tensor, at::Tensor>
remove_self_loops(at::Tensor row, at::Tensor col, at::Tensor weight) {
  auto keep = row != col;
  row = row.masked_select(keep);
  col = col.masked_select(keep);
  weight = weight.masked_select(keep);
  return std::make_tuple(row, col, weight);
}
// Applies one shared random permutation to both edge-index tensors.
std::tuple<at::Tensor, at::Tensor> rand(at::Tensor row, at::Tensor col) {
  const auto num_edges = row.size(0);
  auto perm = at::empty(num_edges, row.options());
  at::randperm_out(perm, num_edges);
  return std::make_tuple(row.index_select(0, perm), col.index_select(0, perm));
}
// Sorts edges by ascending row index, permuting `col` accordingly.
std::tuple<at::Tensor, at::Tensor> sort_by_row(at::Tensor row, at::Tensor col) {
  at::Tensor order;
  std::tie(row, order) = row.sort();
  col = col.index_select(0, order);
  return std::make_tuple(row, col);
}
// Sorts edges by ascending row index, permuting `col` and `weight` alike.
std::tuple<at::Tensor, at::Tensor, at::Tensor>
sort_by_row(at::Tensor row, at::Tensor col, at::Tensor weight) {
  at::Tensor order;
  std::tie(row, order) = row.sort();
  col = col.index_select(0, order);
  weight = weight.index_select(0, order);
  return std::make_tuple(row, col, weight);
}
// Counts, for every node id in [0, num_nodes), how many entries of `row`
// reference it (out-degree when `row` holds edge sources).
at::Tensor degree(at::Tensor row, int64_t num_nodes) {
  auto out = zeros(num_nodes, row.options());
  auto contrib = ones(row.size(0), row.options());
  return out.scatter_add_(0, row, contrib);
}
// Converts a COO edge list to CSR: the returned first tensor is the rowptr
// array of length num_nodes + 1, the second the sorted column indices.
std::tuple<at::Tensor, at::Tensor> to_csr(at::Tensor row, at::Tensor col,
                                          int64_t num_nodes) {
  std::tie(row, col) = sort_by_row(row, col);
  auto rowptr = degree(row, num_nodes).cumsum(0);
  rowptr = at::cat({zeros(1, rowptr.options()), rowptr}, 0); // Prepend zero.
  return std::make_tuple(rowptr, col);
}
// Weighted COO-to-CSR conversion: like to_csr/3 but keeps the edge weights
// aligned with the sorted column indices.
std::tuple<at::Tensor, at::Tensor, at::Tensor>
to_csr(at::Tensor row, at::Tensor col, at::Tensor weight, int64_t num_nodes) {
  std::tie(row, col, weight) = sort_by_row(row, col, weight);
  auto rowptr = degree(row, num_nodes).cumsum(0);
  rowptr = at::cat({zeros(1, rowptr.options()), rowptr}, 0); // Prepend zero.
  return std::make_tuple(rowptr, col, weight);
}
......@@ -15,10 +15,6 @@ tests = [{
'weight': [1, 2, 1, 3, 2, 2, 3, 1, 2, 1],
}]
devices = [torch.device('cpu')]
dtypes = [torch.float]
tests = [tests[0]]
def assert_correct(row, col, cluster):
row, col, cluster = row.to('cpu'), col.to('cpu'), cluster.to('cpu')
......@@ -51,5 +47,4 @@ def test_graclus_cluster(test, dtype, device):
weight = tensor(test.get('weight'), dtype, device)
cluster = graclus_cluster(row, col, weight)
print(cluster)
# assert_correct(row, col, cluster)
assert_correct(row, col, cluster)
# from .utils.loop import remove_self_loops
# from .utils.perm import randperm, sort_row, randperm_sort_row
# from .utils.ffi import graclus
import torch
import graclus_cpu
......@@ -38,9 +34,3 @@ def graclus_cluster(row, col, weight=None, num_nodes=None):
cluster = op.weighted_graclus(row, col, weight, num_nodes)
return cluster
# if row.is_cuda:
# row, col = sort_row(row, col)
# else:
# row, col = randperm(row, col)
# row, col = randperm_sort_row(row, col, num_nodes)
from .._ext import ffi
def get_func(name, is_cuda, tensor=None):
    # Resolve the generated FFI symbol '<prefix>_<name>', where the prefix
    # encodes CUDA-ness and, when a tensor is given, its concrete type name.
    prefix = 'THCC' if is_cuda else 'TH'
    prefix += 'Tensor' if tensor is None else tensor.type().split('.')[-1]
    symbol = '{}_{}'.format(prefix, name)
    return getattr(ffi, symbol)
def graclus(self, row, col, weight=None):
    # Dispatch to the (optionally weighted) in-place C implementation;
    # the FFI call fills `self` and returns nothing useful.
    func = get_func('graclus', self.is_cuda, weight)
    if weight is None:
        func(self, row, col)
    else:
        func(self, row, col, weight)
def grid(self, pos, size, count):
    # Dispatch to the C grid-clustering implementation matching pos' type;
    # the FFI call fills `self` in place.
    impl = get_func('grid', self.is_cuda, pos)
    impl(self, pos, size, count)
def remove_self_loops(row, col):
    # Keep only edges whose source and target differ.
    keep = row != col
    return row[keep], col[keep]
import torch
def randperm(row, col):
    # Shuffle the edge list with one shared random permutation so that
    # row/col stay aligned.
    order = torch.randperm(row.size(0))
    return row[order], col[order]
def sort_row(row, col):
    # Order edges by ascending row index, permuting col accordingly.
    row, order = row.sort()
    return row, col[order]
def randperm_sort_row(row, col, num_nodes):
    # Relabel rows with a random node permutation, sort edges by the new
    # labels (randomizing ties), then map the labels back to the original
    # node ids.  The edge set is unchanged; only the edge order is random.
    node_rid = torch.randperm(num_nodes)
    relabeled = node_rid[row]
    relabeled, col = sort_row(relabeled, col)
    # node_rid.sort()[1] is the inverse permutation of node_rid.
    row = node_rid.sort()[1][relabeled]
    return row, col
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment