Commit 98c4d2c6 authored by rusty1s

backward implementation

parent 354ef5e5
#pragma once

// Overload for float: hardware atomicAdd is available on all architectures.
static inline __device__ void atomAdd(float *address, float val) {
  atomicAdd(address, val);
}

// Overload for double: native atomicAdd(double *) requires compute capability
// 6.0 and CUDA 8. On older targets, emulate it with an atomicCAS loop that
// reinterprets the double as a 64-bit integer.
static inline __device__ void atomAdd(double *address, double val) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
  unsigned long long int *address_as_ull = (unsigned long long int *)address;
  unsigned long long int old = *address_as_ull;
  unsigned long long int assumed;
  do {
    assumed = old;
    old = atomicCAS(address_as_ull, assumed,
                    __double_as_longlong(val + __longlong_as_double(assumed)));
  } while (assumed != old);
#else
  atomicAdd(address, val);
#endif
}
@@ -9,8 +9,3 @@ std::tuple<torch::Tensor, torch::Tensor>
padded_index_select_cuda(torch::Tensor src, torch::Tensor rowptr,
                         torch::Tensor col, torch::Tensor index, int64_t length,
                         torch::Tensor fill_value);

// std::tuple<torch::Tensor, torch::Tensor> padded_index_select_cuda2(
//     torch::Tensor src, torch::Tensor rowptr, torch::Tensor col,
//     torch::Tensor bin, torch::Tensor index, std::vector<int64_t> node_counts,
//     std::vector<int64_t> lengths, torch::Tensor fill_value);
@@ -2,6 +2,7 @@
#include <ATen/cuda/CUDAContext.h>
#include "atomics.cuh"
#include "utils.cuh"
#define THREADS 1024
@@ -225,7 +226,7 @@ torch::Tensor padded_index_select_cuda(torch::Tensor src, torch::Tensor index,
  size_t E = index.numel();
  size_t F = src.size(-1);

  auto out = torch::empty(E * F, src.options());
  auto out = torch::empty({(int)E, (int)F}, src.options());

  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "padded_index_select_kernel", [&] {
    scalar_t *fill;
@@ -245,3 +246,47 @@ torch::Tensor padded_index_select_cuda(torch::Tensor src, torch::Tensor index,
  return out;
}
template <typename scalar_t>
__global__ void padded_index_scatter_kernel(const scalar_t *__restrict__ src,
                                            const int64_t *__restrict__ index,
                                            scalar_t *__restrict__ out,
                                            const size_t E, const size_t F) {
  // Grid-stride loop over all E * F entries. Each thread handles one
  // (slot, feature) pair: padding slots (index == -1) are skipped, and valid
  // entries are scattered back into `out` with an atomic add, since multiple
  // padded slots may point at the same source row.
  for (ptrdiff_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
       thread_idx < E * F; thread_idx += gridDim.x * blockDim.x) {
    int64_t index_idx = __ldg(index + thread_idx / F);
    if (index_idx != -1) {
      atomAdd(out + index_idx * F + thread_idx % F, src[thread_idx]);
    }
  }
}
torch::Tensor padded_index_scatter_cuda(torch::Tensor src, torch::Tensor index,
                                        int64_t N) {
  CHECK_CUDA(src);
  CHECK_CUDA(index);
  CHECK_INPUT(src.dim() == 2);
  CHECK_INPUT(index.dim() == 1);
  CHECK_INPUT(src.size(0) == index.size(0));
  cudaSetDevice(src.get_device());

  auto stream = at::cuda::getCurrentCUDAStream();
  size_t mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;

  size_t E = index.numel();
  size_t F = src.size(-1);

  // Accumulate into a zero-initialized [N, F] output.
  auto out = torch::zeros({N, (int)F}, src.options());

  AT_DISPATCH_FLOATING_TYPES(
      src.scalar_type(), "padded_index_scatter_kernel", [&] {
        padded_index_scatter_kernel<scalar_t>
            <<<std::min(BLOCKS(E * F), mpc * 8), THREADS, 0, stream>>>(
                src.data_ptr<scalar_t>(), index.data_ptr<int64_t>(),
                out.data_ptr<scalar_t>(), E, F);
      });

  return out;
}
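For reference only (not part of the commit): the scatter computed by `padded_index_scatter_cuda` matches the following PyTorch sketch, where rows whose index is -1 correspond to padding and are dropped, and all other rows are summed into their source row. The function name is illustrative.

```python
import torch

def padded_index_scatter_reference(src, index, N):
    # src: [E, F] gradient of the padded gather; index: [E], with -1 marking padding.
    out = torch.zeros(N, src.size(-1), dtype=src.dtype, device=src.device)
    mask = index != -1
    out.index_add_(0, index[mask], src[mask])
    return out
```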
@@ -9,3 +9,6 @@ padded_index_cuda(torch::Tensor rowptr, torch::Tensor col,

torch::Tensor padded_index_select_cuda(torch::Tensor src, torch::Tensor index,
                                       torch::Tensor fill_value);

torch::Tensor padded_index_scatter_cuda(torch::Tensor src, torch::Tensor index,
                                        int64_t N);
@@ -16,9 +16,33 @@ padded_index(torch::Tensor rowptr, torch::Tensor col, torch::Tensor rowcount,
  return padded_index_cuda(rowptr, col, rowcount, binptr);
}
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;

class PaddedIndexSelect : public torch::autograd::Function<PaddedIndexSelect> {
public:
  static variable_list forward(AutogradContext *ctx, Variable src,
                               Variable index, Variable fill_value) {
    // Remember the number of source rows so backward can size the gradient,
    // and save the index for the scatter.
    ctx->saved_data["N"] = src.size(0);
    auto out = padded_index_select_cuda(src, index, fill_value);
    ctx->save_for_backward({index});
    return {out};
  }

  static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
    auto grad_out = grad_outs[0];
    auto saved = ctx->get_saved_variables();
    auto index = saved[0];
    auto N = ctx->saved_data["N"].toInt();
    // Scatter-add the padded gradient back into a [N, F] tensor; `index` and
    // `fill_value` receive no gradients.
    auto grad_in = padded_index_scatter_cuda(grad_out, index, N);
    return {grad_in, Variable(), Variable()};
  }
};
torch::Tensor padded_index_select(torch::Tensor src, torch::Tensor index,
                                  torch::Tensor fill_value) {
  return padded_index_select_cuda(src, index, fill_value);
  return PaddedIndexSelect::apply(src, index, fill_value)[0];
}
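With `padded_index_select` now routed through `PaddedIndexSelect::apply`, gradients should flow back to `src` via the new scatter kernel. A minimal sketch of how this could be exercised from Python, assuming the torch_sparse extension is installed and the operator registration (truncated below) exposes this autograd-enabled version; the index tensor here is illustrative, with -1 marking padding slots:

```python
import torch
import torch_sparse  # assumed installed; loads the custom ops

x = torch.randn(4, 8, device='cuda', requires_grad=True)
index = torch.tensor([0, 2, -1, 1, 2, -1], device='cuda')  # illustrative padded index

out = torch.ops.torch_sparse.padded_index_select(x, index, torch.tensor(0.))
out.sum().backward()

# Each row's gradient should equal the number of padded slots that gathered it;
# padding slots (-1) contribute nothing.
expected = torch.zeros_like(x)
mask = index != -1
expected.index_add_(0, index[mask],
                    torch.ones(int(mask.sum()), x.size(-1), device='cuda'))
assert torch.allclose(x.grad, expected)
```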
static auto registry =
......
@@ -64,6 +64,17 @@ def test_padded_index_select(device):
    # print(mask[:10])
    # print(idx[:10])
    x = torch.randn(adj.size(0), 512).to(device)

    data = torch.ops.torch_sparse.padded_index(rowptr, col, rowcount, binptr)
    node_perm, row_perm, col_perm, mask, node_sizes, edge_sizes = data

    out = torch.ops.torch_sparse.padded_index_select(x, col_perm,
                                                     torch.tensor(0.))
    outs = out.split(edge_sizes)
    for out, size in zip(outs, node_sizes):
        print(out.view(size, -1, x.size(-1)).shape)

    for i in range(110):
        if i == 10:
            start.record()
@@ -71,15 +82,13 @@ def test_padded_index_select(device):
    end.record()
    torch.cuda.synchronize()
    print('padded index', start.elapsed_time(end))

    return

    x = torch.randn(data.num_nodes, 512).to(device)
    for i in range(110):
        if i == 10:
            start.record()
        torch.ops.torch_sparse.padded_index_select(x, col, idx,
                                                   torch.tensor(0.))
        out = torch.ops.torch_sparse.padded_index_select(
            x, col_perm, torch.tensor(0.))
        out.split(edge_sizes)
    end.record()
    torch.cuda.synchronize()
    print('padded index select', start.elapsed_time(end))
......