Commit 5a485e98 authored by rusty1s

cuda complete

parent 06d9038f
@@ -4,7 +4,6 @@
torch::Tensor graclus_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_weight) {
CHECK_CPU(rowptr);
CHECK_CPU(col);
CHECK_INPUT(rowptr.dim() == 1 && col.dim() == 1);
@@ -33,11 +32,9 @@ torch::Tensor graclus_cpu(torch::Tensor rowptr, torch::Tensor col,
out_data[u] = u;
int64_t row_start = rowptr_data[u], row_end = rowptr_data[u + 1];
auto edge_perm = torch::randperm(row_end - row_start, rowptr.options());
auto edge_perm_data = edge_perm.data_ptr<int64_t>();
for (auto e = 0; e < row_end - row_start; e++) {
auto v = col_data[row_start + e];
if (out_data[v] >= 0)
continue;
...
#pragma once
#include <ATen/ATen.h>
#include "compat.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
#define BLUE_PROB 0.53406
__device__ int64_t done;
__global__ void init_done_kernel() { done = 1; }
__global__ void colorize_kernel(int64_t *cluster, float *__restrict__ bernoulli,
size_t numel) {
const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (int64_t u = index; u < numel; u += stride) {
if (cluster[u] < 0) {
cluster[u] = (int64_t)bernoulli[u] - 2;
done = 0;
}
}
}
int64_t colorize(at::Tensor cluster) {
init_done_kernel<<<1, 1>>>();
auto numel = cluster.size(0);
auto props = at::full(numel, BLUE_PROB, cluster.options().dtype(at::kFloat));
auto bernoulli = props.bernoulli();
colorize_kernel<<<BLOCKS(numel), THREADS>>>(
cluster.DATA_PTR<int64_t>(), bernoulli.DATA_PTR<float>(), numel);
int64_t out;
cudaMemcpyFromSymbol(&out, done, sizeof(out), 0, cudaMemcpyDeviceToHost);
return out;
}
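The colorize step above marks every still-unmatched node (cluster value < 0) as either blue (-1) or red (-2) by drawing from a Bernoulli distribution with p = BLUE_PROB, and reports "done" only when no node needed coloring: the cast (int64_t)bernoulli[u] - 2 maps a draw of 1 to -1 and a draw of 0 to -2. A minimal sequential PyTorch sketch of the same step (plain Python, no CUDA kernel; the function name and the boolean return value are illustrative only):

import torch

BLUE_PROB = 0.53406

def colorize(cluster: torch.Tensor) -> bool:
    # Nodes with cluster[u] < 0 are still unmatched and get a new color.
    unmatched = cluster < 0
    if not bool(unmatched.any()):
        return True  # done: every node is already matched
    draws = torch.bernoulli(torch.full((cluster.numel(),), BLUE_PROB))
    # A draw of 1 becomes -1 (blue), a draw of 0 becomes -2 (red).
    cluster[unmatched] = draws[unmatched].long() - 2
    return False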
@@ -6,15 +6,11 @@
#define THREADS 1024
inline torch::Tensor get_dist(torch::Tensor x, int64_t idx) {
return (x - x[idx]).norm(2, 1);
}
template <typename scalar_t> struct Dist<scalar_t> {
static inline __device__ void compute(int64_t idx, int64_t start_idx,
int64_t end_idx, int64_t old,
scalar_t *best, int64_t *best_idx,
const scalar_t *src, scalar_t *dist,
scalar_t *tmp_dist, int64_t dim) {
for (int64_t n = start_idx + idx; n < end_idx; n += THREADS) {
@@ -23,7 +19,7 @@ template <typename scalar_t> struct Dist<scalar_t> {
__syncthreads();
for (int64_t i = start_idx * dim + idx; i < end_idx * dim; i += THREADS) {
scalar_t d = src[(old * dim) + (i % dim)] - src[i];
atomicAdd(&tmp_dist[i / dim], d * d);
}
@@ -39,7 +35,7 @@ template <typename scalar_t> struct Dist<scalar_t> {
};
template <typename scalar_t>
__global__ void fps_kernel(const scalar_t *src, const int64_t *ptr,
const int64_t *out_ptr, const int64_t *start,
scalar_t *dist, scalar_t *tmp_dist, int64_t *out,
int64_t dim) {
@@ -63,7 +59,7 @@ __global__ void fps_kernel(const scalar_t *x, const int64_t *ptr,
__syncthreads();
Dist<scalar_t, Dim>::compute(thread_idx, start_idx, end_idx, out[m - 1],
&best, &best_idx, src, dist, tmp_dist, dim);
best_dist[idx] = best;
best_dist_idx[idx] = best_idx;
@@ -94,6 +90,7 @@ torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, double ratio,
CHECK_CUDA(ptr);
CHECK_INPUT(ptr.dim() == 1);
AT_ASSERTM(ratio > 0 and ratio < 1, "Invalid input");
cudaSetDevice(src.get_device());
src = src.view({src.size(0), -1}).contiguous();
ptr = ptr.contiguous();
@@ -106,7 +103,7 @@ torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, double ratio,
torch::Tensor start;
if (random_start) {
start = torch::rand(batch_size, src.options());
start = (start * deg.toType(torch::kFloat)).toType(torch::kLong);
} else {
start = torch::zeros(batch_size, ptr.options());
@@ -118,7 +115,7 @@ torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, double ratio,
auto out_size = (int64_t *)malloc(sizeof(int64_t));
cudaMemcpy(out_size, out_ptr[-1].data_ptr<int64_t>(), sizeof(int64_t),
cudaMemcpyDeviceToHost);
auto out = torch::empty(out_size[0], out_ptr.options());
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "fps_kernel", [&] {
...
#include <torch/extension.h>
#define CHECK_CUDA(x) \
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
at::Tensor graclus_cuda(at::Tensor row, at::Tensor col, int64_t num_nodes);
at::Tensor weighted_graclus_cuda(at::Tensor row, at::Tensor col,
at::Tensor weight, int64_t num_nodes);
at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) {
CHECK_CUDA(row);
CHECK_CUDA(col);
return graclus_cuda(row, col, num_nodes);
}
at::Tensor weighted_graclus(at::Tensor row, at::Tensor col, at::Tensor weight,
int64_t num_nodes) {
CHECK_CUDA(row);
CHECK_CUDA(col);
CHECK_CUDA(weight);
return weighted_graclus_cuda(row, col, weight, num_nodes);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("graclus", &graclus, "Graclus (CUDA)");
m.def("weighted_graclus", &weighted_graclus, "Weighted Graclus (CUDA)");
}
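For context, the bindings above would be consumed from Python roughly as follows once the extension is compiled; the module name graclus_cuda and the example tensors are assumptions (the build configuration is not part of this diff):

import torch
import graclus_cuda  # assumed name of the compiled extension module

row = torch.tensor([0, 1, 1, 2], device='cuda')
col = torch.tensor([1, 0, 2, 1], device='cuda')

# Positional arguments, matching the m.def("graclus", ...) binding above.
cluster = graclus_cuda.graclus(row, col, 3)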
#include "graclus_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include "utils.h"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
#define BLUE_P 0.53406
// Forward declarations (defined below).
int64_t colorize(torch::Tensor out);
void propose(torch::Tensor out, torch::Tensor proposal, torch::Tensor rowptr,
torch::Tensor col, torch::optional<torch::Tensor> optional_weight);
void respond(torch::Tensor out, torch::Tensor proposal, torch::Tensor rowptr,
torch::Tensor col, torch::optional<torch::Tensor> optional_weight);
torch::Tensor graclus_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_weight) {
CHECK_CUDA(rowptr);
CHECK_CUDA(col);
CHECK_INPUT(rowptr.dim() == 1 && col.dim() == 1);
if (optional_weight.has_value()) {
CHECK_CUDA(optional_weight.value());
CHECK_INPUT(optional_weight.value().dim() == 1);
CHECK_INPUT(optional_weight.value().numel() == col.numel());
}
cudaSetDevice(rowptr.get_device());
int64_t num_nodes = rowptr.numel() - 1;
auto out = torch::full(num_nodes, -1, rowptr.options());
auto proposal = torch::full(num_nodes, -1, rowptr.options());
while (!colorize(out)) {
propose(out, proposal, rowptr, col, optional_weight);
respond(out, proposal, rowptr, col, optional_weight);
}
return out;
}
__device__ int64_t done_d;
__global__ void init_done_kernel() { done_d = 1; }
__global__ void colorize_kernel(int64_t *out, const float *bernoulli,
int64_t numel) {
const int64_t u = blockIdx.x * blockDim.x + threadIdx.x;
if (u < numel) {
if (out[u] < 0) {
out[u] = (int64_t)bernoulli[u] - 2;
done_d = 0;
}
}
}
int64_t colorize(torch::Tensor out) {
auto stream = at::cuda::getCurrentCUDAStream();
init_done_kernel<<<1, 1, 0, stream>>>();
auto numel = out.size(0);
auto props = torch::full(numel, BLUE_P, out.options().dtype(torch::kFloat));
auto bernoulli = props.bernoulli();
colorize_kernel<<<BLOCKS(numel), THREADS, 0, stream>>>(
out.data_ptr<int64_t>(), bernoulli.data_ptr<float>(), numel);
int64_t done_h;
cudaMemcpyFromSymbol(&done_h, done_d, sizeof(done_h), 0,
cudaMemcpyDeviceToHost);
return done_h;
}
__global__ void propose_kernel(int64_t *out, int64_t *proposal,
const int64_t *rowptr, const int64_t *col,
int64_t numel) {
const int64_t u = blockIdx.x * blockDim.x + threadIdx.x;
if (u < numel) {
if (out[u] != -1)
return; // Only visit blue nodes.
bool has_unmatched_neighbor = false;
for (int64_t i = rowptr[u]; i < rowptr[u + 1]; i++) {
auto v = col[i];
if (out[v] < 0)
has_unmatched_neighbor = true; // Unmatched neighbor found.
if (out[v] == -2) {
proposal[u] = v; // Propose to first red neighbor.
break;
}
}
if (!has_unmatched_neighbor)
out[u] = u;
}
}
template <typename scalar_t>
__global__ void weighted_propose_kernel(int64_t *out, int64_t *proposal,
const int64_t *rowptr,
const int64_t *col,
const scalar_t *weight, int64_t numel) {
const int64_t u = blockIdx.x * blockDim.x + threadIdx.x;
if (u < numel) {
if (out[u] != -1)
return; // Only visit blue nodes.
bool has_unmatched_neighbor = false;
int64_t v_max = -1;
scalar_t w_max = 0;
for (int64_t i = rowptr[u]; i < rowptr[u + 1]; i++) {
auto v = col[i];
if (out[v] < 0)
has_unmatched_neighbor = true; // Unmatched neighbor found.
// Find maximum weighted red neighbor.
if (out[v] == -2 && weight[i] >= w_max) {
v_max = v;
w_max = weight[i];
}
}
proposal[u] = v_max; // Propose.
if (!has_unmatched_neighbor)
out[u] = u;
}
}
void propose(torch::Tensor out, torch::Tensor proposal, torch::Tensor rowptr,
torch::Tensor col, torch::optional<torch::Tensor> optional_weight) {
auto stream = at::cuda::getCurrentCUDAStream();
if (!optional_weight.has_value()) {
propose_kernel<<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),
rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(), out.numel());
} else {
auto weight = optional_weight.value();
AT_DISPATCH_ALL_TYPES(weight.scalar_type(), "propose_kernel", [&] {
weighted_propose_kernel<scalar_t>
<<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),
rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
weight.data_ptr<scalar_t>(), out.numel());
});
}
}
__global__ void respond_kernel(int64_t *out, const int64_t *proposal,
const int64_t *rowptr, const int64_t *col,
int64_t numel) {
const int64_t u = blockIdx.x * blockDim.x + threadIdx.x;
if (u < numel) {
if (out[u] != -2)
return; // Only visit red nodes.
bool has_unmatched_neighbor = false;
for (int64_t i = rowptr[u]; i < rowptr[u + 1]; i++) {
auto v = col[i];
if (out[v] < 0)
has_unmatched_neighbor = true; // Unmatched neighbor found.
if (out[v] == -1 && proposal[v] == u) {
// Match first blue neighbor v which proposed to u.
out[u] = min(u, v);
out[v] = min(u, v);
break;
}
}
if (!has_unmatched_neighbor)
out[u] = u;
}
}
template <typename scalar_t>
__global__ void weighted_respond_kernel(int64_t *out, const int64_t *proposal,
const int64_t *rowptr,
const int64_t *col,
const scalar_t *weight, int64_t numel) {
const int64_t u = blockIdx.x * blockDim.x + threadIdx.x;
if (u < numel) {
if (out[u] != -2)
return; // Only visit red nodes.
bool has_unmatched_neighbor = false;
int64_t v_max = -1;
scalar_t w_max = 0;
for (int64_t i = rowptr[u]; i < rowptr[u + 1]; i++) {
auto v = col[i];
if (out[v] < 0)
has_unmatched_neighbor = true; // Unmatched neighbor found.
if (out[v] == -1 && proposal[v] == u && weight[i] >= w_max) {
// Find maximum weighted blue neighbor v which proposed to u.
v_max = v;
w_max = weight[i];
}
}
if (v_max >= 0) {
out[u] = min(u, v_max); // Match neighbors.
out[v_max] = min(u, v_max);
}
if (!has_unmatched_neighbor)
out[u] = u;
}
}
void respond(torch::Tensor out, torch::Tensor proposal, torch::Tensor rowptr,
torch::Tensor col, torch::optional<torch::Tensor> optional_weight) {
auto stream = at::cuda::getCurrentCUDAStream();
if (!optional_weight.has_value()) {
respond_kernel<<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),
rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(), out.numel());
} else {
auto weight = optional_weight.value();
AT_DISPATCH_ALL_TYPES(weight.scalar_type(), "respond_kernel", [&] {
weighted_respond_kernel<scalar_t>
<<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),
rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
weight.data_ptr<scalar_t>(), out.numel());
});
}
}
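For readers following the kernels above, here is a sequential Python sketch of one propose/respond round on a CSR graph (rowptr, col). The -1/-2 color encoding, the self-match of nodes without unmatched neighbors, and the min(u, v) tie-break mirror the unweighted kernels; the function names and the use of plain Python loops are illustrative only:

def propose(out, proposal, rowptr, col):
    # Blue nodes (-1) propose to their first red (-2) neighbor.
    for u in range(out.numel()):
        if out[u] != -1:
            continue
        has_unmatched_neighbor = False
        for i in range(int(rowptr[u]), int(rowptr[u + 1])):
            v = int(col[i])
            if out[v] < 0:
                has_unmatched_neighbor = True
            if out[v] == -2:
                proposal[u] = v
                break
        if not has_unmatched_neighbor:
            out[u] = u  # No unmatched neighbor left: match with itself.

def respond(out, proposal, rowptr, col):
    # Red nodes (-2) accept the first blue neighbor that proposed to them.
    for u in range(out.numel()):
        if out[u] != -2:
            continue
        has_unmatched_neighbor = False
        for i in range(int(rowptr[u]), int(rowptr[u + 1])):
            v = int(col[i])
            if out[v] < 0:
                has_unmatched_neighbor = True
            if out[v] == -1 and int(proposal[v]) == u:
                out[u] = out[v] = min(u, v)
                break
        if not has_unmatched_neighbor:
            out[u] = u

A node keeps a negative value while it still has an unmatched neighbor, which is why the while (!colorize(out)) loop in graclus_cuda keeps re-coloring and re-running these two steps until every node is matched.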
#pragma once
#include <torch/extension.h>
torch::Tensor graclus_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_weight);
#include <ATen/ATen.h>
#include "coloring.cuh"
#include "proposal.cuh"
#include "response.cuh"
#include "utils.cuh"
at::Tensor graclus_cuda(at::Tensor row, at::Tensor col, int64_t num_nodes) {
cudaSetDevice(row.get_device());
std::tie(row, col) = remove_self_loops(row, col);
std::tie(row, col) = rand(row, col);
std::tie(row, col) = to_csr(row, col, num_nodes);
auto cluster = at::full(num_nodes, -1, row.options());
auto proposal = at::full(num_nodes, -1, row.options());
while (!colorize(cluster)) {
propose(cluster, proposal, row, col);
respond(cluster, proposal, row, col);
}
return cluster;
}
at::Tensor weighted_graclus_cuda(at::Tensor row, at::Tensor col,
at::Tensor weight, int64_t num_nodes) {
cudaSetDevice(row.get_device());
std::tie(row, col, weight) = remove_self_loops(row, col, weight);
std::tie(row, col, weight) = to_csr(row, col, weight, num_nodes);
auto cluster = at::full(num_nodes, -1, row.options());
auto proposal = at::full(num_nodes, -1, row.options());
while (!colorize(cluster)) {
propose(cluster, proposal, row, col, weight);
respond(cluster, proposal, row, col, weight);
}
return cluster;
}
@@ -100,5 +100,5 @@ torch::Tensor knn_cuda(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
});
auto mask = col != -1;
return torch::stack({row.masked_select(mask), col.masked_select(mask)}, 0);
}
#pragma once
#include <ATen/ATen.h>
#include "compat.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
__global__ void propose_kernel(int64_t *__restrict__ cluster, int64_t *proposal,
int64_t *__restrict row,
int64_t *__restrict__ col, size_t numel) {
const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (int64_t u = index; u < numel; u += stride) {
if (cluster[u] != -1)
continue; // Only visit blue nodes.
bool has_unmatched_neighbor = false;
for (int64_t i = row[u]; i < row[u + 1]; i++) {
auto v = col[i];
if (cluster[v] < 0)
has_unmatched_neighbor = true; // Unmatched neighbor found.
if (cluster[v] == -2) {
proposal[u] = v; // Propose to first red neighbor.
break;
}
}
if (!has_unmatched_neighbor)
cluster[u] = u;
}
}
void propose(at::Tensor cluster, at::Tensor proposal, at::Tensor row,
at::Tensor col) {
propose_kernel<<<BLOCKS(cluster.numel()), THREADS>>>(
cluster.DATA_PTR<int64_t>(), proposal.DATA_PTR<int64_t>(),
row.DATA_PTR<int64_t>(), col.DATA_PTR<int64_t>(), cluster.numel());
}
template <typename scalar_t>
__global__ void propose_kernel(int64_t *__restrict__ cluster, int64_t *proposal,
int64_t *__restrict row,
int64_t *__restrict__ col,
scalar_t *__restrict__ weight, size_t numel) {
const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (int64_t u = index; u < numel; u += stride) {
if (cluster[u] != -1)
continue; // Only visit blue nodes.
bool has_unmatched_neighbor = false;
int64_t v_max = -1;
scalar_t w_max = 0;
for (int64_t i = row[u]; i < row[u + 1]; i++) {
auto v = col[i];
if (cluster[v] < 0)
has_unmatched_neighbor = true; // Unmatched neighbor found.
// Find maximum weighted red neighbor.
if (cluster[v] == -2 && weight[i] >= w_max) {
v_max = v;
w_max = weight[i];
}
}
proposal[u] = v_max; // Propose.
if (!has_unmatched_neighbor)
cluster[u] = u;
}
}
void propose(at::Tensor cluster, at::Tensor proposal, at::Tensor row,
at::Tensor col, at::Tensor weight) {
AT_DISPATCH_ALL_TYPES(weight.scalar_type(), "propose_kernel", [&] {
propose_kernel<scalar_t><<<BLOCKS(cluster.numel()), THREADS>>>(
cluster.DATA_PTR<int64_t>(), proposal.DATA_PTR<int64_t>(),
row.DATA_PTR<int64_t>(), col.DATA_PTR<int64_t>(),
weight.DATA_PTR<scalar_t>(), cluster.numel());
});
}
#pragma once
#include <ATen/ATen.h>
#include "compat.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
__global__ void respond_kernel(int64_t *__restrict__ cluster, int64_t *proposal,
int64_t *__restrict row,
int64_t *__restrict__ col, size_t numel) {
const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (int64_t u = index; u < numel; u += stride) {
if (cluster[u] != -2)
continue; // Only visit red nodes.
bool has_unmatched_neighbor = false;
for (int64_t i = row[u]; i < row[u + 1]; i++) {
auto v = col[i];
if (cluster[v] < 0)
has_unmatched_neighbor = true; // Unmatched neighbor found.
if (cluster[v] == -1 && proposal[v] == u) {
// Match first blue neighbor v which proposed to u.
cluster[u] = min(u, v);
cluster[v] = min(u, v);
break;
}
}
if (!has_unmatched_neighbor)
cluster[u] = u;
}
}
void respond(at::Tensor cluster, at::Tensor proposal, at::Tensor row,
at::Tensor col) {
respond_kernel<<<BLOCKS(cluster.numel()), THREADS>>>(
cluster.DATA_PTR<int64_t>(), proposal.DATA_PTR<int64_t>(),
row.DATA_PTR<int64_t>(), col.DATA_PTR<int64_t>(), cluster.numel());
}
template <typename scalar_t>
__global__ void respond_kernel(int64_t *__restrict__ cluster, int64_t *proposal,
int64_t *__restrict row,
int64_t *__restrict__ col,
scalar_t *__restrict__ weight, size_t numel) {
const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (int64_t u = index; u < numel; u += stride) {
if (cluster[u] != -2)
continue; // Only visit red nodes.
bool has_unmatched_neighbor = false;
int64_t v_max = -1;
scalar_t w_max = 0;
for (int64_t i = row[u]; i < row[u + 1]; i++) {
auto v = col[i];
if (cluster[v] < 0)
has_unmatched_neighbor = true; // Unmatched neighbor found.
if (cluster[v] == -1 && proposal[v] == u && weight[i] >= w_max) {
// Find maximum weighted blue neighbor v which proposed to u.
v_max = v;
w_max = weight[i];
}
}
if (v_max >= 0) {
cluster[u] = min(u, v_max); // Match neighbors.
cluster[v_max] = min(u, v_max);
}
if (!has_unmatched_neighbor)
cluster[u] = u;
}
}
void respond(at::Tensor cluster, at::Tensor proposal, at::Tensor row,
at::Tensor col, at::Tensor weight) {
AT_DISPATCH_ALL_TYPES(weight.scalar_type(), "respond_kernel", [&] {
respond_kernel<scalar_t><<<BLOCKS(cluster.numel()), THREADS>>>(
cluster.DATA_PTR<int64_t>(), proposal.DATA_PTR<int64_t>(),
row.DATA_PTR<int64_t>(), col.DATA_PTR<int64_t>(),
weight.DATA_PTR<scalar_t>(), cluster.numel());
});
}
@@ -5,62 +5,3 @@
#define CHECK_CUDA(x) \
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
#include <ATen/ATen.h>
std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row,
at::Tensor col) {
auto mask = row != col;
return std::make_tuple(row.masked_select(mask), col.masked_select(mask));
}
std::tuple<at::Tensor, at::Tensor, at::Tensor>
remove_self_loops(at::Tensor row, at::Tensor col, at::Tensor weight) {
auto mask = row != col;
return std::make_tuple(row.masked_select(mask), col.masked_select(mask),
weight.masked_select(mask));
}
std::tuple<at::Tensor, at::Tensor> rand(at::Tensor row, at::Tensor col) {
auto perm = at::empty(row.size(0), row.options());
at::randperm_out(perm, row.size(0));
return std::make_tuple(row.index_select(0, perm), col.index_select(0, perm));
}
std::tuple<at::Tensor, at::Tensor> sort_by_row(at::Tensor row, at::Tensor col) {
at::Tensor perm;
std::tie(row, perm) = row.sort();
return std::make_tuple(row, col.index_select(0, perm));
}
std::tuple<at::Tensor, at::Tensor, at::Tensor>
sort_by_row(at::Tensor row, at::Tensor col, at::Tensor weight) {
at::Tensor perm;
std::tie(row, perm) = row.sort();
return std::make_tuple(row, col.index_select(0, perm),
weight.index_select(0, perm));
}
at::Tensor degree(at::Tensor row, int64_t num_nodes) {
auto zero = at::zeros(num_nodes, row.options());
auto one = at::ones(row.size(0), row.options());
return zero.scatter_add_(0, row, one);
}
std::tuple<at::Tensor, at::Tensor> to_csr(at::Tensor row, at::Tensor col,
int64_t num_nodes) {
std::tie(row, col) = sort_by_row(row, col);
row = degree(row, num_nodes).cumsum(0);
row = at::cat({at::zeros(1, row.options()), row}, 0); // Prepend zero.
return std::make_tuple(row, col);
}
std::tuple<at::Tensor, at::Tensor, at::Tensor>
to_csr(at::Tensor row, at::Tensor col, at::Tensor weight, int64_t num_nodes) {
std::tie(row, col, weight) = sort_by_row(row, col, weight);
row = degree(row, num_nodes).cumsum(0);
row = at::cat({at::zeros(1, row.options()), row}, 0); // Prepend zero.
return std::make_tuple(row, col, weight);
}
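The to_csr helpers above sort the edges by source node, compute per-node degrees, and build row pointers by a cumulative sum with a prepended zero. A small PyTorch sketch of the same conversion (the standalone function is illustrative; it assumes row and col are 1-D long tensors):

import torch

def to_csr(row: torch.Tensor, col: torch.Tensor, num_nodes: int):
    # Sort edges by source node so each node's neighbors are contiguous.
    perm = torch.argsort(row)
    row, col = row[perm], col[perm]
    # Per-node out-degree; a prepended zero plus a cumulative sum
    # turns the degrees into row pointers.
    deg = torch.zeros(num_nodes, dtype=torch.long).scatter_add_(
        0, row, torch.ones_like(row))
    rowptr = torch.cat([torch.zeros(1, dtype=torch.long), deg.cumsum(0)])
    return rowptr, col

The neighbors of node u are then col[rowptr[u]:rowptr[u + 1]], which is exactly how the kernels above iterate over edges.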
@@ -32,7 +32,17 @@ def graclus_cluster(row: torch.Tensor, col: torch.Tensor,
    if num_nodes is None:
        num_nodes = max(int(row.max()), int(col.max())) + 1

    # Remove self-loops.
    mask = row != col
    row, col = row[mask], col[mask]

    # Randomly shuffle nodes.
    if weight is None:
        perm = torch.randperm(row.size(0), device=row.device)
        row, col = row[perm], col[perm]

    # To CSR.
    perm = torch.argsort(row)
    row, col = row[perm], col[perm]

    deg = row.new_zeros(num_nodes)
...
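Finally, a brief usage sketch of the public wrapper changed in this hunk; the import path torch_cluster and the edge weights are assumptions for illustration, and the exact cluster assignment depends on the random coloring:

import torch
from torch_cluster import graclus_cluster  # assumed import path

row = torch.tensor([0, 1, 1, 2])
col = torch.tensor([1, 0, 2, 1])
weight = torch.tensor([1., 1., 2., 2.])

cluster = graclus_cluster(row, col, weight)
# One cluster id per node; matched nodes share the smaller of their two ids,
# e.g. tensor([0, 1, 1]) when nodes 1 and 2 are matched.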