Commit 7b83a608 authored by bowendeng
parents 5b9304b9 0e541cc9
#include "padding_cpu.h"
#include "utils.h"
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
std::vector<int64_t>, std::vector<int64_t>>
padded_index_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor rowcount, torch::Tensor binptr) {
CHECK_CPU(rowptr);
CHECK_CPU(col);
CHECK_CPU(rowcount);
CHECK_CPU(binptr);
CHECK_INPUT(rowptr.numel() == rowcount.numel() + 1);
ptrdiff_t B = binptr.numel() - 1;
ptrdiff_t N = rowcount.numel();
auto rowptr_data = rowptr.data_ptr<int64_t>();
auto col_data = col.data_ptr<int64_t>();
auto rowcount_data = rowcount.data_ptr<int64_t>();
auto binptr_data = binptr.data_ptr<int64_t>();
auto bin = torch::empty(N, col.options());
auto bin_data = bin.data_ptr<int64_t>();
auto idx = torch::empty(N, col.options());
auto idx_data = idx.data_ptr<int64_t>();
std::vector<int64_t> node_sizes(B), edge_sizes(B), max_degs(B),
node_offsets(B + 1), edge_offsets(B + 1);
  int64_t deg, bin_idx;
  for (ptrdiff_t n = 0; n < N; n++) {
    deg = rowcount_data[n];
    // Find the first bin whose boundary exceeds the node degree. `bin_idx` is
    // reset for every node; degrees beyond the last boundary fall into the
    // last bin.
    bin_idx = -1;
    for (ptrdiff_t b = 1; b <= B; b++) {
      if (deg < binptr_data[b]) {
        bin_idx = b - 1;
        break;
      }
    }
    if (bin_idx == -1) {
      bin_idx = B - 1;
    }
bin_data[n] = bin_idx;
idx_data[n] = node_sizes[bin_idx];
node_sizes[bin_idx] += 1;
max_degs[bin_idx] = std::max(max_degs[bin_idx], deg);
}
for (ptrdiff_t b = 0; b < B; b++) {
edge_sizes[b] = node_sizes[b] * max_degs[b];
node_offsets[b + 1] = node_offsets[b] + node_sizes[b];
edge_offsets[b + 1] = edge_offsets[b] + edge_sizes[b];
}
auto node_perm = torch::empty(N, col.options());
auto node_perm_data = node_perm.data_ptr<int64_t>();
auto E = edge_offsets[B];
auto row_perm = torch::empty(E, col.options());
auto row_perm_data = row_perm.data_ptr<int64_t>();
auto col_perm = torch::empty(E, col.options());
auto col_perm_data = col_perm.data_ptr<int64_t>();
auto edge_mask = torch::empty(E, col.options().dtype(torch::kBool));
auto edge_mask_data = edge_mask.data_ptr<bool>();
int64_t row_start = rowptr_data[0], row_end, edge_offset, offset;
for (ptrdiff_t n = 0; n < N; n++) {
bin_idx = bin_data[n];
offset = idx_data[n];
node_perm_data[node_offsets[bin_idx] + offset] = n;
row_end = rowptr_data[n + 1];
edge_offset = edge_offsets[bin_idx] + offset * max_degs[bin_idx];
for (ptrdiff_t e = 0; e < row_end - row_start; e++) {
row_perm_data[edge_offset + e] = n;
col_perm_data[edge_offset + e] = col_data[row_start + e];
edge_mask_data[edge_offset + e] = false;
}
    for (ptrdiff_t e = row_end - row_start; e < max_degs[bin_idx]; e++) {
row_perm_data[edge_offset + e] = -1;
col_perm_data[edge_offset + e] = -1;
edge_mask_data[edge_offset + e] = true;
}
row_start = row_end;
}
return std::make_tuple(node_perm, row_perm, col_perm, edge_mask, node_sizes,
edge_sizes);
}
torch::Tensor padded_index_select_cpu(torch::Tensor src, torch::Tensor index,
torch::Tensor fill_value) {
CHECK_CPU(src);
CHECK_CPU(index);
CHECK_INPUT(src.dim() == 2);
CHECK_INPUT(index.dim() == 1);
auto mask = index == -1;
auto out = src.index_select(0, index.masked_fill(mask, 0));
out.masked_fill_(mask.view({-1, 1}).expand_as(out), fill_value);
return out;
}
torch::Tensor padded_index_scatter_cpu(torch::Tensor src, torch::Tensor index,
int64_t N) {
CHECK_CPU(src);
CHECK_CPU(index);
CHECK_INPUT(src.dim() == 2);
CHECK_INPUT(index.dim() == 1);
CHECK_INPUT(src.size(0) == index.size(0));
auto mask = index == -1;
index = index.masked_fill(mask, N);
auto out = torch::zeros({N + 1, src.size(-1)}, src.options());
out.scatter_add_(0, index.view({-1, 1}).expand_as(src), src);
out = out.narrow(0, 0, N);
return out;
}
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
std::vector<int64_t>, std::vector<int64_t>>
padded_index_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor rowcount, torch::Tensor binptr);
torch::Tensor padded_index_select_cpu(torch::Tensor src, torch::Tensor index,
torch::Tensor fill_value);
torch::Tensor padded_index_scatter_cpu(torch::Tensor src, torch::Tensor index,
int64_t N);
#include "rw_cpu.h"
#include "utils.h"
torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor start, int64_t walk_length) {
CHECK_CPU(rowptr);
CHECK_CPU(col);
CHECK_CPU(start);
CHECK_INPUT(rowptr.dim() == 1);
CHECK_INPUT(col.dim() == 1);
CHECK_INPUT(start.dim() == 1);
auto rand = torch::rand({start.size(0), walk_length},
start.options().dtype(torch::kFloat));
auto L = walk_length + 1;
auto out = torch::full({start.size(0), L}, -1, start.options());
auto rowptr_data = rowptr.data_ptr<int64_t>();
auto col_data = col.data_ptr<int64_t>();
auto start_data = start.data_ptr<int64_t>();
auto rand_data = rand.data_ptr<float>();
auto out_data = out.data_ptr<int64_t>();
for (auto n = 0; n < start.size(0); n++) {
auto cur = start_data[n];
out_data[n * L] = cur;
int64_t row_start, row_end;
for (auto l = 0; l < walk_length; l++) {
row_start = rowptr_data[cur];
row_end = rowptr_data[cur + 1];
cur = col_data[row_start + int64_t(rand_data[n * walk_length + l] *
(row_end - row_start))];
out_data[n * L + l + 1] = cur;
}
}
return out;
}
#pragma once
#include <torch/extension.h>
torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor start, int64_t walk_length);
#include "saint_cpu.h"
#include "utils.h"
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
subgraph_cpu(torch::Tensor idx, torch::Tensor rowptr, torch::Tensor row,
torch::Tensor col) {
CHECK_CPU(idx);
CHECK_CPU(rowptr);
CHECK_CPU(col);
CHECK_INPUT(idx.dim() == 1);
CHECK_INPUT(rowptr.dim() == 1);
CHECK_INPUT(col.dim() == 1);
auto assoc = torch::full({rowptr.size(0) - 1}, -1, idx.options());
assoc.index_copy_(0, idx, torch::arange(idx.size(0), idx.options()));
auto idx_data = idx.data_ptr<int64_t>();
auto rowptr_data = rowptr.data_ptr<int64_t>();
auto col_data = col.data_ptr<int64_t>();
auto assoc_data = assoc.data_ptr<int64_t>();
std::vector<int64_t> rows, cols, indices;
int64_t v, w, w_new, row_start, row_end;
for (int64_t v_new = 0; v_new < idx.size(0); v_new++) {
v = idx_data[v_new];
row_start = rowptr_data[v];
row_end = rowptr_data[v + 1];
for (int64_t j = row_start; j < row_end; j++) {
w = col_data[j];
w_new = assoc_data[w];
if (w_new > -1) {
rows.push_back(v_new);
cols.push_back(w_new);
indices.push_back(j);
}
}
}
int64_t length = rows.size();
row = torch::from_blob(rows.data(), {length}, row.options()).clone();
col = torch::from_blob(cols.data(), {length}, row.options()).clone();
idx = torch::from_blob(indices.data(), {length}, row.options()).clone();
return std::make_tuple(row, col, idx);
}
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
subgraph_cpu(torch::Tensor idx, torch::Tensor rowptr, torch::Tensor row,
torch::Tensor col);
#pragma once
static inline __device__ void atomAdd(float *address, float val) {
atomicAdd(address, val);
}
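// Double-precision fallback: devices older than sm_60 (or CUDA < 8.0) lack a
// native double atomicAdd, so it is emulated with an atomicCAS loop on the
// value's 64-bit integer representation.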
static inline __device__ void atomAdd(double *address, double val) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
unsigned long long int *address_as_ull = (unsigned long long int *)address;
unsigned long long int old = *address_as_ull;
unsigned long long int assumed;
do {
assumed = old;
old = atomicCAS(address_as_ull, assumed,
__double_as_longlong(val + __longlong_as_double(assumed)));
} while (assumed != old);
#else
atomicAdd(address, val);
#endif
}
#include "padding_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include "atomics.cuh"
#include "utils.cuh"
#define THREADS 1024
#define FULL_MASK 0xffffffff
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)
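// CUDA counterpart of `padded_index_cpu`: `bin_kernel` assigns each node to a
// degree bin, `info_kernel` derives per-bin edge sizes and node/edge offsets,
// `node_perm_kernel` writes the node permutation, and `padded_index_kernel`
// fills the padded row/col permutations and the padding mask.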
__global__ void bin_kernel(const int64_t *__restrict__ rowcount,
const int64_t *__restrict__ binptr,
int64_t *__restrict__ bin, int64_t *__restrict__ idx,
int *__restrict__ node_size,
int *__restrict__ max_deg, const size_t B,
const size_t N) {
for (ptrdiff_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
thread_idx < N; thread_idx += gridDim.x * blockDim.x) {
int bin_idx = -1, deg = rowcount[thread_idx];
for (ptrdiff_t b = 1; b <= B; b++) {
if (deg < __ldg(binptr + b)) {
bin_idx = b - 1;
break;
}
}
if (bin_idx == -1) {
bin_idx = B - 1;
}
int old = atomicAdd(node_size + bin_idx, 1);
atomicMax(max_deg + bin_idx, deg);
bin[thread_idx] = bin_idx;
idx[thread_idx] = old;
}
}
__global__ void info_kernel(const int *__restrict__ node_size,
const int *__restrict__ max_deg,
int *__restrict__ edge_size,
int *__restrict__ node_offset,
int *__restrict__ edge_offset, const size_t B) {
int thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
int bin_idx = thread_idx / 32;
int lane_idx = thread_idx % 32;
if (bin_idx <= B) { // Computes `node_offset` and `edge_offset`.
int node_tmp = 0;
int edge_tmp = 0;
for (int i = lane_idx; i < bin_idx; i += 32) {
node_tmp += node_size[i];
edge_tmp += node_size[i] * max_deg[i];
}
for (int i = 32 / 2; i > 0; i /= 2) {
node_tmp += __shfl_down_sync(FULL_MASK, node_tmp, i);
edge_tmp += __shfl_down_sync(FULL_MASK, edge_tmp, i);
}
if (lane_idx == 0) {
node_offset[bin_idx] = node_tmp;
edge_offset[bin_idx] = edge_tmp;
}
} else if (bin_idx == B + 1) { // Computes `edge_size`.
for (int i = lane_idx; i < B; i += 32) {
edge_size[i] = node_size[i] * max_deg[i];
}
}
}
__global__ void node_perm_kernel(const int64_t *__restrict__ bin,
const int64_t *__restrict__ idx,
const int *__restrict__ node_offset,
int64_t *__restrict__ out, const size_t N) {
for (ptrdiff_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
thread_idx < N; thread_idx += gridDim.x * blockDim.x) {
out[__ldg(node_offset + bin[thread_idx]) + idx[thread_idx]] = thread_idx;
}
}
template <int TB>
__global__ void padded_index_kernel(
const int64_t *__restrict__ rowptr, const int64_t *__restrict__ col,
const int64_t *__restrict__ rowcount, const int64_t *__restrict__ bin,
const int64_t *__restrict__ idx, const int *__restrict__ max_deg,
const int *__restrict__ edge_offset, int64_t *__restrict__ row_perm,
int64_t *__restrict__ col_perm, bool *__restrict__ edge_mask,
const size_t B, const size_t N) {
for (ptrdiff_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
thread_idx < TB * N; thread_idx += gridDim.x * blockDim.x) {
int row_idx = thread_idx / TB;
int lane_idx = thread_idx % TB;
int64_t bin_idx = bin[row_idx];
int len = __ldg(max_deg + bin_idx);
int off = __ldg(edge_offset + bin_idx) + len * idx[row_idx];
int64_t row_start = rowptr[row_idx], deg = rowcount[row_idx];
int64_t row_tmp, col_tmp;
for (int i = lane_idx; i < len; i += TB) {
row_tmp = -1, col_tmp = -1;
if (i < deg) {
row_tmp = row_idx;
col_tmp = col[row_start + i];
}
row_perm[off + i] = row_tmp;
col_perm[off + i] = col_tmp;
edge_mask[off + i] = row_tmp == -1;
}
}
}
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
std::vector<int64_t>, std::vector<int64_t>>
padded_index_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor rowcount, torch::Tensor binptr) {
CHECK_CUDA(rowptr);
CHECK_CUDA(col);
CHECK_CUDA(rowcount);
CHECK_CUDA(binptr);
CHECK_INPUT(rowptr.numel() == rowcount.numel() + 1);
cudaSetDevice(rowcount.get_device());
auto stream = at::cuda::getCurrentCUDAStream();
size_t mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
size_t B = binptr.numel() - 1;
size_t N = rowcount.numel();
auto bin = torch::empty(N, col.options());
auto idx = torch::empty(N, col.options());
auto d_info = torch::zeros(5 * B + 2, col.options().dtype(torch::kInt));
auto d_node_size = d_info.narrow(0, 0, B);
auto d_edge_size = d_info.narrow(0, B, B);
auto d_max_deg = d_info.narrow(0, 2 * B, B);
auto d_node_offset = d_info.narrow(0, 3 * B, B + 1);
auto d_edge_offset = d_info.narrow(0, 4 * B + 1, B + 1);
bin_kernel<<<std::min(BLOCKS(N), mpc * 8), THREADS, 0, stream>>>(
rowcount.data_ptr<int64_t>(), binptr.data_ptr<int64_t>(),
bin.data_ptr<int64_t>(), idx.data_ptr<int64_t>(),
d_node_size.data_ptr<int>(), d_max_deg.data_ptr<int>(), B, N);
info_kernel<<<BLOCKS(32 * (B + 2)), THREADS, 0, stream>>>(
d_node_size.data_ptr<int>(), d_max_deg.data_ptr<int>(),
d_edge_size.data_ptr<int>(), d_node_offset.data_ptr<int>(),
d_edge_offset.data_ptr<int>(), B);
auto node_perm = torch::empty(N, col.options());
node_perm_kernel<<<std::min(BLOCKS(N), mpc * 8), THREADS, 0, stream>>>(
bin.data_ptr<int64_t>(), idx.data_ptr<int64_t>(),
d_node_offset.data_ptr<int>(), node_perm.data_ptr<int64_t>(), N);
auto h_info = torch::empty(
d_info.numel(), d_info.options().device(torch::kCPU).pinned_memory(true));
cudaMemcpy(h_info.data_ptr<int>(), d_info.data_ptr<int>(),
d_info.numel() * sizeof(int), cudaMemcpyDeviceToHost);
size_t E = h_info.data_ptr<int>()[5 * B + 1];
auto row_perm = torch::empty(E, col.options());
auto col_perm = torch::empty(E, col.options());
auto edge_mask = torch::empty(E, col.options().dtype(torch::kBool));
padded_index_kernel<8>
<<<std::min(BLOCKS(N * 8), mpc * 8), THREADS, 0, stream>>>(
rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
rowcount.data_ptr<int64_t>(), bin.data_ptr<int64_t>(),
idx.data_ptr<int64_t>(), d_max_deg.data_ptr<int>(),
d_edge_offset.data_ptr<int>(), row_perm.data_ptr<int64_t>(),
col_perm.data_ptr<int64_t>(), edge_mask.data_ptr<bool>(), B, N);
h_info = h_info.to(torch::kLong);
auto h_info_data = h_info.data_ptr<int64_t>();
std::vector<int64_t> node_sizes(h_info_data, h_info_data + B);
std::vector<int64_t> edge_sizes(h_info_data + B, h_info_data + 2 * B);
return std::make_tuple(node_perm, row_perm, col_perm, edge_mask, node_sizes,
edge_sizes);
}
template <typename scalar_t>
__global__ void padded_index_select_kernel(const scalar_t *__restrict__ src,
const int64_t *__restrict__ index,
scalar_t *__restrict__ out,
const scalar_t fill_value,
const size_t E, const size_t F) {
for (ptrdiff_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
thread_idx < E * F; thread_idx += gridDim.x * blockDim.x) {
int64_t row_idx = thread_idx / F;
int64_t lane_idx = thread_idx % F;
int64_t index_idx = __ldg(index + row_idx);
scalar_t tmp = fill_value;
if (index_idx != -1) {
tmp = src[index_idx * F + lane_idx];
}
out[thread_idx] = tmp;
}
}
torch::Tensor padded_index_select_cuda(torch::Tensor src, torch::Tensor index,
torch::Tensor fill_value) {
CHECK_CUDA(src);
CHECK_CUDA(index);
CHECK_INPUT(src.dim() == 2);
CHECK_INPUT(index.dim() == 1);
cudaSetDevice(src.get_device());
auto stream = at::cuda::getCurrentCUDAStream();
size_t mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
size_t E = index.numel();
size_t F = src.size(-1);
auto out = torch::empty({(int)E, (int)F}, src.options());
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "padded_index_select_kernel", [&] {
    // Read the scalar fill value on the host (it may live on the GPU).
    scalar_t fill;
    if (fill_value.is_cuda()) {
      cudaMemcpy(&fill, fill_value.data_ptr<scalar_t>(), sizeof(scalar_t),
                 cudaMemcpyDeviceToHost);
    } else {
      fill = fill_value.data_ptr<scalar_t>()[0];
    }
    padded_index_select_kernel<scalar_t>
        <<<std::min(BLOCKS(E * F), mpc * 8), THREADS, 0, stream>>>(
            src.data_ptr<scalar_t>(), index.data_ptr<int64_t>(),
            out.data_ptr<scalar_t>(), fill, E, F);
  });
return out;
}
template <typename scalar_t>
__global__ void padded_index_scatter_kernel(const scalar_t *__restrict__ src,
const int64_t *__restrict__ index,
scalar_t *__restrict__ out,
const size_t E, const size_t F) {
for (ptrdiff_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
thread_idx < E * F; thread_idx += gridDim.x * blockDim.x) {
int64_t index_idx = __ldg(index + thread_idx / F);
if (index_idx != -1) {
atomAdd(out + index_idx * F + thread_idx % F, src[thread_idx]);
}
}
}
torch::Tensor padded_index_scatter_cuda(torch::Tensor src, torch::Tensor index,
int64_t N) {
CHECK_CUDA(src);
CHECK_CUDA(index);
CHECK_INPUT(src.dim() == 2);
CHECK_INPUT(index.dim() == 1);
CHECK_INPUT(src.size(0) == index.size(0));
cudaSetDevice(src.get_device());
auto stream = at::cuda::getCurrentCUDAStream();
size_t mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
size_t E = index.numel();
size_t F = src.size(-1);
auto out = torch::zeros({N, (int)F}, src.options());
AT_DISPATCH_FLOATING_TYPES(
src.scalar_type(), "padded_index_scatter_kernel", [&] {
padded_index_scatter_kernel<scalar_t>
<<<std::min(BLOCKS(E * F), mpc * 8), THREADS, 0, stream>>>(
src.data_ptr<scalar_t>(), index.data_ptr<int64_t>(),
out.data_ptr<scalar_t>(), E, F);
});
return out;
}
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
std::vector<int64_t>, std::vector<int64_t>>
padded_index_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor rowcount, torch::Tensor binptr);
torch::Tensor padded_index_select_cuda(torch::Tensor src, torch::Tensor index,
torch::Tensor fill_value);
torch::Tensor padded_index_scatter_cuda(torch::Tensor src, torch::Tensor index,
int64_t N);
#include "rw_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include "utils.cuh"
#define THREADS 1024
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)
__global__ void uniform_random_walk_kernel(const int64_t *rowptr,
const int64_t *col,
const int64_t *start,
const float *rand, int64_t *out,
int64_t walk_length, int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_idx < numel) {
int64_t cur = start[thread_idx];
out[thread_idx] = cur;
int64_t row_start, row_end;
for (int64_t l = 0; l < walk_length; l++) {
row_start = rowptr[cur], row_end = rowptr[cur + 1];
cur = col[row_start +
int64_t(rand[l * numel + thread_idx] * (row_end - row_start))];
out[(l + 1) * numel + thread_idx] = cur;
}
}
}
torch::Tensor random_walk_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor start, int64_t walk_length) {
CHECK_CUDA(rowptr);
CHECK_CUDA(col);
CHECK_CUDA(start);
cudaSetDevice(rowptr.get_device());
CHECK_INPUT(rowptr.dim() == 1);
CHECK_INPUT(col.dim() == 1);
CHECK_INPUT(start.dim() == 1);
auto rand = torch::rand({walk_length, start.size(0)},
start.options().dtype(torch::kFloat));
auto out = torch::full({walk_length + 1, start.size(0)}, -1, start.options());
auto stream = at::cuda::getCurrentCUDAStream();
uniform_random_walk_kernel<<<BLOCKS(start.numel()), THREADS, 0, stream>>>(
rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
start.data_ptr<int64_t>(), rand.data_ptr<float>(),
out.data_ptr<int64_t>(), walk_length, start.numel());
return out.t().contiguous();
}
#pragma once
#include <torch/extension.h>
torch::Tensor random_walk_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor start, int64_t walk_length);
#include <Python.h>
#include <torch/script.h>
#include "cpu/padding_cpu.h"
#ifdef WITH_CUDA
#include "cuda/padding_cuda.h"
#endif
#ifdef _WIN32
PyMODINIT_FUNC PyInit__padding(void) { return NULL; }
#endif
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
std::vector<int64_t>, std::vector<int64_t>>
padded_index(torch::Tensor rowptr, torch::Tensor col, torch::Tensor rowcount,
torch::Tensor binptr) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA
return padded_index_cuda(rowptr, col, rowcount, binptr);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return padded_index_cpu(rowptr, col, rowcount, binptr);
}
}
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;
class PaddedIndexSelect : public torch::autograd::Function<PaddedIndexSelect> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index, Variable fill_value) {
ctx->saved_data["N"] = src.size(0);
torch::Tensor out;
if (src.device().is_cuda()) {
#ifdef WITH_CUDA
out = padded_index_select_cuda(src, index, fill_value);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
out = padded_index_select_cpu(src, index, fill_value);
}
ctx->save_for_backward({index});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto N = ctx->saved_data["N"].toInt();
torch::Tensor grad_in;
if (grad_out.device().is_cuda()) {
#ifdef WITH_CUDA
grad_in = padded_index_scatter_cuda(grad_out, index, N);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
grad_in = padded_index_scatter_cpu(grad_out, index, N);
}
return {grad_in, Variable(), Variable()};
}
};
torch::Tensor padded_index_select(torch::Tensor src, torch::Tensor index,
torch::Tensor fill_value) {
return PaddedIndexSelect::apply(src, index, fill_value)[0];
}
static auto registry =
torch::RegisterOperators()
.op("torch_sparse::padded_index", &padded_index)
.op("torch_sparse::padded_index_select", &padded_index_select);
#include <Python.h>
#include <torch/script.h>
#include "cpu/rw_cpu.h"
#ifdef WITH_CUDA
#include "cuda/rw_cuda.h"
#endif
#ifdef _WIN32
PyMODINIT_FUNC PyInit__rw(void) { return NULL; }
#endif
torch::Tensor random_walk(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor start, int64_t walk_length) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA
return random_walk_cuda(rowptr, col, start, walk_length);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return random_walk_cpu(rowptr, col, start, walk_length);
}
}
static auto registry =
torch::RegisterOperators().op("torch_sparse::random_walk", &random_walk);
#include <Python.h>
#include <torch/script.h>
#include "cpu/saint_cpu.h"
#ifdef _WIN32
PyMODINIT_FUNC PyInit__saint(void) { return NULL; }
#endif
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
subgraph(torch::Tensor idx, torch::Tensor rowptr, torch::Tensor row,
torch::Tensor col) {
if (idx.device().is_cuda()) {
#ifdef WITH_CUDA
AT_ERROR("No CUDA version supported");
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return subgraph_cpu(idx, rowptr, row, col);
}
}
static auto registry =
torch::RegisterOperators().op("torch_sparse::saint_subgraph", &subgraph);
from itertools import product
import pytest
import torch
from torch_sparse import SparseTensor, padded_index_select
from .utils import grad_dtypes, devices, tensor
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_padded_index_select(dtype, device):
row = torch.tensor([0, 0, 0, 0, 1, 1, 1, 2, 2, 3])
col = torch.tensor([0, 1, 2, 3, 0, 2, 3, 1, 3, 2])
adj = SparseTensor(row=row, col=col).to(device)
binptr = torch.tensor([0, 3, 5], device=device)
data = adj.padded_index(binptr)
node_perm, row_perm, col_perm, mask, node_size, edge_size = data
assert node_perm.tolist() == [2, 3, 0, 1]
assert row_perm.tolist() == [2, 2, 3, -1, 0, 0, 0, 0, 1, 1, 1, -1]
assert col_perm.tolist() == [1, 3, 2, -1, 0, 1, 2, 3, 0, 2, 3, -1]
assert mask.long().tolist() == [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1]
assert node_size == [2, 2]
assert edge_size == [4, 8]
x = tensor([0, 1, 2, 3], dtype, device).view(-1, 1).requires_grad_()
x_j = padded_index_select(x, col_perm)
assert x_j.flatten().tolist() == [1, 3, 2, 0, 0, 1, 2, 3, 0, 2, 3, 0]
grad_out = tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], dtype, device)
x_j.backward(grad_out.view(-1, 1))
assert x.grad.flatten().tolist() == [12, 5, 17, 18]
def test_padded_index_select_runtime():
    # Returning here skips the manual CUDA benchmark below; remove this line
    # to run it.
    return
from torch_geometric.datasets import Planetoid
device = torch.device('cuda')
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
dataset = Planetoid('/tmp/Planetoid', name='PubMed')
data = dataset[0]
row, col = data.edge_index.to(device)
adj = SparseTensor(row=row, col=col)
rowcount = adj.storage.rowcount().to(device)
rowptr = adj.storage.rowptr().to(device)
binptr = torch.tensor([0, 4, 11, 30, 50, 80, 120, 140, 2000]).to(device)
x = torch.randn(adj.size(0), 512).to(device)
data = torch.ops.torch_sparse.padded_index(rowptr, col, rowcount, binptr)
node_perm, row_perm, col_perm, mask, node_sizes, edge_sizes = data
out = torch.ops.torch_sparse.padded_index_select(x, col_perm,
torch.tensor(0.))
outs = out.split(edge_sizes)
for out, size in zip(outs, node_sizes):
print(out.view(size, -1, x.size(-1)).shape)
for i in range(110):
if i == 10:
start.record()
torch.ops.torch_sparse.padded_index(rowptr, col, rowcount, binptr)
end.record()
torch.cuda.synchronize()
print('padded index', start.elapsed_time(end))
for i in range(110):
if i == 10:
start.record()
out = torch.ops.torch_sparse.padded_index_select(
x, col_perm, torch.tensor(0.))
out.split(edge_sizes)
end.record()
torch.cuda.synchronize()
print('padded index select', start.elapsed_time(end))
for i in range(110):
if i == 10:
start.record()
x.index_select(0, col)
end.record()
torch.cuda.synchronize()
print('index_select', start.elapsed_time(end))
import torch
from torch_sparse.tensor import SparseTensor
def test_saint_subgraph():
row = torch.tensor([0, 0, 1, 1, 2, 2, 2, 3, 3, 4])
col = torch.tensor([1, 2, 0, 2, 0, 1, 3, 2, 4, 3])
adj = SparseTensor(row=row, col=col)
node_idx = torch.tensor([0, 1, 2])
adj, edge_index = adj.saint_subgraph(node_idx)
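    # Hedged check (not part of the original, truncated test): with the CSR
    # layout above, the subgraph induced by nodes {0, 1, 2} keeps exactly the
    # first six edges, so their original positions are returned in order.
    assert edge_index.tolist() == [0, 1, 2, 3, 4, 5]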
@@ -8,7 +8,8 @@ expected_torch_version = (1, 4)
 try:
     for library in [
-            '_version', '_convert', '_diag', '_spmm', '_spspmm', '_metis'
+            '_version', '_convert', '_diag', '_spmm', '_spspmm', '_metis',
+            '_rw', '_saint', '_padding'
     ]:
         torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
             library, [osp.dirname(__file__)]).origin)
@@ -54,7 +55,10 @@ from .mul import mul, mul_, mul_nnz, mul_nnz_  # noqa
 from .reduce import sum, mean, min, max  # noqa
 from .matmul import matmul  # noqa
 from .cat import cat, cat_diag  # noqa
+from .rw import random_walk  # noqa
 from .metis import partition  # noqa
+from .saint import saint_subgraph  # noqa
+from .padding import padded_index, padded_index_select  # noqa
 from .convert import to_torch_sparse, from_torch_sparse  # noqa
 from .convert import to_scipy, from_scipy  # noqa
@@ -94,7 +98,11 @@ __all__ = [
     'matmul',
     'cat',
     'cat_diag',
+    'random_walk',
     'partition',
+    'saint_subgraph',
+    'padded_index',
+    'padded_index_select',
     'to_torch_sparse',
     'from_torch_sparse',
     'to_scipy',
...
@@ -34,5 +34,4 @@ def partition(src: SparseTensor, num_parts: int, recursive: bool = False
     return out, partptr, perm
-SparseTensor.partition = lambda self, num_parts, recursive=False: partition(
-    self, num_parts, recursive)
+SparseTensor.partition = partition
from typing import Tuple, List
import torch
from torch_sparse.tensor import SparseTensor
def padded_index(src: SparseTensor, binptr: torch.Tensor
                 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
                            torch.Tensor, List[int], List[int]]:
return torch.ops.torch_sparse.padded_index(src.storage.rowptr(),
src.storage.col(),
src.storage.rowcount(), binptr)
def padded_index_select(src: torch.Tensor, index: torch.Tensor,
fill_value: float = 0.) -> torch.Tensor:
fill_value = torch.tensor(fill_value, dtype=src.dtype)
return torch.ops.torch_sparse.padded_index_select(src, index, fill_value)
SparseTensor.padded_index = padded_index
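# Illustrative sketch (not part of this module), mirroring the usage in the
# tests: bin the nodes of a SparseTensor `adj` by degree, gather features with
# padding, and reshape each bin into a dense [num_nodes, max_deg, F] block.
#
#     node_perm, row_perm, col_perm, mask, node_sizes, edge_sizes = \
#         adj.padded_index(binptr)
#     x_j = padded_index_select(x, col_perm)              # [sum(edge_sizes), F]
#     for block, n in zip(x_j.split(edge_sizes), node_sizes):
#         block = block.view(n, -1, x.size(-1))           # [n, max_deg, F]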
import torch
from torch_sparse.tensor import SparseTensor
def random_walk(src: SparseTensor, start: torch.Tensor,
walk_length: int) -> torch.Tensor:
rowptr, col, _ = src.csr()
return torch.ops.torch_sparse.random_walk(rowptr, col, start, walk_length)
SparseTensor.random_walk = random_walk
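# Illustrative sketch (not part of this module): sample 3-step uniform random
# walks from every node. Each row of `walks` starts at the corresponding start
# node and has walk_length + 1 entries.
#
#     start = torch.arange(adj.size(0))
#     walks = adj.random_walk(start, walk_length=3)       # shape [N, 4]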