Commit 06df4d9b authored by rusty1s
Browse files

graclus fix

parent 8c8014b9
......@@ -2,61 +2,37 @@
#include <ATen/cuda/CUDAContext.h>
#include "utils.h"
#include "utils.cuh"
// Threads per block used by every kernel launch in this file.
#define THREADS 1024
// Ceil-division launch helper: number of blocks needed to cover N elements
// with THREADS threads per block. Both the argument and the full expansion
// are parenthesized so the macro stays correct inside larger expressions
// (e.g. BLOCKS(a + b) or x / BLOCKS(n) would misexpand otherwise).
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)
// Per-round probability that a still-unmatched node is drawn as "blue"
// (fed to bernoulli() in colorize).
#define BLUE_P 0.53406
// Greedy matching (graclus) on a CSR graph given by (rowptr, col), with
// optional per-edge weights aligned to col. Returns a tensor of num_nodes
// int64 cluster ids; matched node pairs end up sharing an id (the matching
// itself is done by the colorize/propose/respond kernels in this file).
torch::Tensor graclus_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_weight) {
// Inputs must live on the GPU and be one-dimensional; a weight tensor, if
// given, must hold exactly one entry per edge (numel == col.numel()).
CHECK_CUDA(rowptr);
CHECK_CUDA(col);
CHECK_INPUT(rowptr.dim() == 1 && col.dim() == 1);
if (optional_weight.has_value()) {
CHECK_CUDA(optional_weight.value());
CHECK_INPUT(optional_weight.value().dim() == 1);
CHECK_INPUT(optional_weight.value().numel() == col.numel());
}
// Run on the device that holds the input graph.
cudaSetDevice(rowptr.get_device());
// CSR convention: rowptr has num_nodes + 1 entries.
int64_t num_nodes = rowptr.numel() - 1;
// -1 marks a node as not yet matched; `proposal` holds each blue node's
// proposed red partner for the current round.
auto out = torch::full(num_nodes, -1, rowptr.options());
auto proposal = torch::full(num_nodes, -1, rowptr.options());
// Each iteration randomly two-colors the remaining unmatched nodes, lets
// blue nodes propose to red neighbors, and lets red nodes accept one
// proposal. colorize() returns true once no unmatched node remains.
while (!colorize(out)) {
propose(out, proposal, rowptr, col, optional_weight);
respond(out, proposal, rowptr, col, optional_weight);
}
return out;
}
// Device-global "all nodes matched?" flag for the colorize step. The diff
// residue previously declared this twice (int64_t and bool) and defined
// init_done_kernel twice, which does not compile; this is the post-commit
// bool version only.
__device__ bool done_d;

// Reset the completion flag before each colorize pass (launched <<<1, 1>>>).
__global__ void init_done_kernel() { done_d = true; }

// Randomly two-color every still-unmatched node. `out[i] < 0` means node i
// is unmatched; bernoulli[i] is 0.0 or 1.0, so the assignment yields
// -2 (red) or -1 (blue). Expects a 1D launch covering `numel` threads.
__global__ void colorize_kernel(int64_t *out, const float *bernoulli,
                                int64_t numel) {
  const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (thread_idx < numel) {
    if (out[thread_idx] < 0) {
      out[thread_idx] = (int64_t)bernoulli[thread_idx] - 2;
      done_d = false; // at least one node was still unmatched -> not done
    }
  }
}
int64_t colorize(torch::Tensor out) {
bool colorize(torch::Tensor out) {
auto stream = at::cuda::getCurrentCUDAStream();
init_done_kernel<<<1, 1, 0, stream>>>();
auto numel = cluster.size(0);
auto numel = out.size(0);
auto props = torch::full(numel, BLUE_P, out.options().dtype(torch::kFloat));
auto bernoulli = props.bernoulli();
colorize_kernel<<<BLOCKS(numel), THREADS, 0, stream>>>(
out.data_ptr<int64_t>(), bernoulli.data_ptr<float>(), numel);
int64_t done_h;
bool done_h;
cudaMemcpyFromSymbol(&done_h, done_d, sizeof(done_h), 0,
cudaMemcpyDeviceToHost);
return done_h;
......@@ -68,25 +44,25 @@ __global__ void propose_kernel(int64_t *out, int64_t *proposal,
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_idx < numel) {
if (out[u] != -1)
continue; // Only vist blue nodes.
if (out[thread_idx] != -1)
return; // Only vist blue nodes.
bool has_unmatched_neighbor = false;
for (int64_t i = rowptr[u]; i < rowptr[u + 1]; i++) {
for (int64_t i = rowptr[thread_idx]; i < rowptr[thread_idx + 1]; i++) {
auto v = col[i];
if (out[v] < 0)
has_unmatched_neighbor = true; // Unmatched neighbor found.
if (out[v] == -2) {
proposal[u] = v; // Propose to first red neighbor.
proposal[thread_idx] = v; // Propose to first red neighbor.
break;
}
}
if (!has_unmatched_neighbor)
out[u] = u;
out[thread_idx] = thread_idx;
}
}
......@@ -98,14 +74,14 @@ __global__ void weighted_propose_kernel(int64_t *out, int64_t *proposal,
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_idx < numel) {
if (out[u] != -1)
continue; // Only vist blue nodes.
if (out[thread_idx] != -1)
return; // Only vist blue nodes.
bool has_unmatched_neighbor = false;
int64_t v_max = -1;
scalar_t w_max = 0;
for (int64_t i = rowptr[u]; i < rowptr[u + 1]; i++) {
for (int64_t i = rowptr[thread_idx]; i < rowptr[thread_idx + 1]; i++) {
auto v = col[i];
if (out[v] < 0)
......@@ -118,24 +94,25 @@ __global__ void weighted_propose_kernel(int64_t *out, int64_t *proposal,
}
}
proposal[u] = v_max; // Propose.
proposal[thread_idx] = v_max; // Propose.
if (!has_unmatched_neighbor)
out[u] = u;
out[thread_idx] = thread_idx;
}
}
void propose(torch::Tensor out, torch::Tensor proposal, torch::Tensor rowptr,
torch::Tensor col, torch::optional<torch::Tensor> optional_weight) {
torch::Tensor col,
torch::optional<torch::Tensor> optional_weight) {
auto stream = at::cuda::getCurrentCUDAStream();
if (!optional_weight.has_value()) {
propose_kernel<<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),
row.data_ptr<int64_t>(), col.data_ptr<int64_t>(), out.numel());
rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(), out.numel());
} else {
auto = optional_weight.value();
auto weight = optional_weight.value();
AT_DISPATCH_ALL_TYPES(weight.scalar_type(), "propose_kernel", [&] {
weighted_propose_kernel<scalar_t>
<<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
......@@ -151,27 +128,27 @@ __global__ void respond_kernel(int64_t *out, const int64_t *proposal,
int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_idx < numel) {
if (out[u] != -2)
continue; // Only vist red nodes.
if (out[thread_idx] != -2)
return; // Only vist red nodes.
bool has_unmatched_neighbor = false;
for (int64_t i = rowptr[u]; i < rowptr[u + 1]; i++) {
for (int64_t i = rowptr[thread_idx]; i < rowptr[thread_idx + 1]; i++) {
auto v = col[i];
if (out[v] < 0)
has_unmatched_neighbor = true; // Unmatched neighbor found.
if (out[v] == -1 && proposal[v] == u) {
if (out[v] == -1 && proposal[v] == thread_idx) {
// Match first blue neighbhor v which proposed to u.
out[u] = min(u, v);
out[v] = min(u, v);
out[thread_idx] = min(thread_idx, v);
out[v] = min(thread_idx, v);
break;
}
}
if (!has_unmatched_neighbor)
cluster[u] = u;
out[thread_idx] = thread_idx;
}
}
......@@ -182,20 +159,20 @@ __global__ void weighted_respond_kernel(int64_t *out, const int64_t *proposal,
const scalar_t *weight, int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_idx < numel) {
if (out[u] != -2)
continue; // Only vist red nodes.
if (out[thread_idx] != -2)
return; // Only vist red nodes.
bool has_unmatched_neighbor = false;
int64_t v_max = -1;
scalar_t w_max = 0;
for (int64_t i = rowptr[u]; i < rowptr[u + 1]; i++) {
for (int64_t i = rowptr[thread_idx]; i < rowptr[thread_idx + 1]; i++) {
auto v = col[i];
if (out[v] < 0)
has_unmatched_neighbor = true; // Unmatched neighbor found.
if (out[v] == -1 && proposal[v] == u && weight[i] >= w_max) {
if (out[v] == -1 && proposal[v] == thread_idx && weight[i] >= w_max) {
// Find maximum weighted blue neighbhor v which proposed to u.
v_max = v;
w_max = weight[i];
......@@ -203,17 +180,18 @@ __global__ void weighted_respond_kernel(int64_t *out, const int64_t *proposal,
}
if (v_max >= 0) {
out[u] = min(u, v_max); // Match neighbors.
out[v_max] = min(u, v_max);
out[thread_idx] = min(thread_idx, v_max); // Match neighbors.
out[v_max] = min(thread_idx, v_max);
}
if (!has_unmatched_neighbor)
out[u] = u;
out[thread_idx] = thread_idx;
}
}
void respond(torch::Tensor out, torch::Tensor proposal, torch::Tensor rowptr,
torch::Tensor col, torch::optional<torch::Tensor> optional_weight) {
torch::Tensor col,
torch::optional<torch::Tensor> optional_weight) {
auto stream = at::cuda::getCurrentCUDAStream();
......@@ -222,12 +200,37 @@ void respond(torch::Tensor out, torch::Tensor proposal, torch::Tensor rowptr,
out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),
rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(), out.numel());
} else {
auto = optional_weight.value();
auto weight = optional_weight.value();
AT_DISPATCH_ALL_TYPES(weight.scalar_type(), "respond_kernel", [&] {
respond_kernel<scalar_t><<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
weighted_respond_kernel<scalar_t>
<<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),
rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
weight.data_ptr<scalar_t>(), out.numel());
});
}
}
// Greedy matching (graclus) on a CSR graph given by (rowptr, col), with
// optional per-edge weights aligned to col. Returns a tensor of num_nodes
// int64 cluster ids; matched node pairs end up sharing an id (the matching
// itself is done by the colorize/propose/respond kernels in this file).
torch::Tensor graclus_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_weight) {
// Inputs must live on the GPU and be one-dimensional; a weight tensor, if
// given, must hold exactly one entry per edge (numel == col.numel()).
CHECK_CUDA(rowptr);
CHECK_CUDA(col);
CHECK_INPUT(rowptr.dim() == 1 && col.dim() == 1);
if (optional_weight.has_value()) {
CHECK_CUDA(optional_weight.value());
CHECK_INPUT(optional_weight.value().dim() == 1);
CHECK_INPUT(optional_weight.value().numel() == col.numel());
}
// Run on the device that holds the input graph.
cudaSetDevice(rowptr.get_device());
// CSR convention: rowptr has num_nodes + 1 entries.
int64_t num_nodes = rowptr.numel() - 1;
// -1 marks a node as not yet matched; `proposal` holds each blue node's
// proposed red partner for the current round.
auto out = torch::full(num_nodes, -1, rowptr.options());
auto proposal = torch::full(num_nodes, -1, rowptr.options());
// Each iteration randomly two-colors the remaining unmatched nodes, lets
// blue nodes propose to red neighbors, and lets red nodes accept one
// proposal. colorize() returns true once no unmatched node remains.
while (!colorize(out)) {
propose(out, proposal, rowptr, col, optional_weight);
respond(out, proposal, rowptr, col, optional_weight);
}
return out;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment