Commit ceb73a8c authored by rusty1s

padded_index_select

parent 8b77e547
#include "padding_cpu.h"
#include "utils.h"
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
std::vector<int64_t>, std::vector<int64_t>>
padded_index_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor rowcount, torch::Tensor binptr) {
std::vector<int64_t> bla = {1};
return std::make_tuple(col, col, col, col, bla, bla);
}
torch::Tensor padded_index_select_cpu(torch::Tensor src, torch::Tensor index,
                                      torch::Tensor fill_value) {
  CHECK_CPU(src);
  CHECK_CPU(index);
  CHECK_INPUT(src.dim() == 2);
  CHECK_INPUT(index.dim() == 1);

  // Entries equal to -1 denote padding slots: gather them from row 0 first
  // (any valid row works), then overwrite those rows with `fill_value`.
  auto mask = index == -1;
  auto out = src.index_select(0, index.masked_fill(mask, 0));
  out.masked_fill_(mask.view({-1, 1}).expand_as(out), fill_value);
  return out;
}
torch::Tensor padded_index_scatter_cpu(torch::Tensor src, torch::Tensor index,
                                       int64_t N) {
  CHECK_CPU(src);
  CHECK_CPU(index);
  CHECK_INPUT(src.dim() == 2);
  CHECK_INPUT(index.dim() == 1);
  CHECK_INPUT(src.size(0) == index.size(0));

  // Route padding entries (-1) into an extra scratch row N, scatter-add all
  // rows, and drop the scratch row afterwards so padded values are ignored.
  auto mask = index == -1;
  index = index.masked_fill(mask, N);
  auto out = torch::zeros({N + 1, src.size(-1)}, src.options());
  out.scatter_add_(0, index.view({-1, 1}).expand_as(src), src);
  out = out.narrow(0, 0, N);
  return out;
}
#pragma once

#include <torch/extension.h>

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
           std::vector<int64_t>, std::vector<int64_t>>
padded_index_cpu(torch::Tensor rowptr, torch::Tensor col,
                 torch::Tensor rowcount, torch::Tensor binptr);

torch::Tensor padded_index_select_cpu(torch::Tensor src, torch::Tensor index,
                                      torch::Tensor fill_value);

torch::Tensor padded_index_scatter_cpu(torch::Tensor src, torch::Tensor index,
                                       int64_t N);
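Taken together, padded_index_select_cpu and padded_index_scatter_cpu form a padded gather and its adjoint scatter: index entries equal to -1 select nothing and are filled with fill_value, and the scatter routes contributions for those padded slots into a scratch row that is dropped. Below is a minimal Python sketch of the same semantics written with plain PyTorch ops for illustration; the names padded_select and padded_scatter are placeholders, not the library's API.

import torch

# Placeholder helpers mirroring the CPU kernels above with plain torch ops.
def padded_select(src, index, fill_value=0.0):
    mask = index == -1                              # -1 marks padding slots
    out = src.index_select(0, index.masked_fill(mask, 0))
    out.masked_fill_(mask.view(-1, 1).expand_as(out), fill_value)
    return out

def padded_scatter(src, index, N):
    mask = index == -1
    idx = index.masked_fill(mask, N)                # route padding to scratch row N
    out = torch.zeros(N + 1, src.size(-1), dtype=src.dtype)
    out.scatter_add_(0, idx.view(-1, 1).expand_as(src), src)
    return out.narrow(0, 0, N)                      # drop the scratch row

src = torch.arange(6.).view(3, 2)
index = torch.tensor([2, -1, 0])
out = padded_select(src, index)                     # rows: src[2], fill, src[0]
back = padded_scatter(out, index, src.size(0))      # padded row is discarded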
 #include <Python.h>
 #include <torch/script.h>
+#include "cpu/padding_cpu.h"
 #ifdef WITH_CUDA
 #include "cuda/padding_cuda.h"
 #endif
@@ -13,7 +15,15 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
            std::vector<int64_t>, std::vector<int64_t>>
 padded_index(torch::Tensor rowptr, torch::Tensor col, torch::Tensor rowcount,
              torch::Tensor binptr) {
-  return padded_index_cuda(rowptr, col, rowcount, binptr);
+  if (rowptr.device().is_cuda()) {
+#ifdef WITH_CUDA
+    return padded_index_cuda(rowptr, col, rowcount, binptr);
+#else
+    AT_ERROR("Not compiled with CUDA support");
+#endif
+  } else {
+    return padded_index_cpu(rowptr, col, rowcount, binptr);
+  }
 }
 using torch::autograd::AutogradContext;
@@ -25,7 +35,17 @@ public:
   static variable_list forward(AutogradContext *ctx, Variable src,
                                Variable index, Variable fill_value) {
     ctx->saved_data["N"] = src.size(0);
-    auto out = padded_index_select_cuda(src, index, fill_value);
+    torch::Tensor out;
+    if (src.device().is_cuda()) {
+#ifdef WITH_CUDA
+      out = padded_index_select_cuda(src, index, fill_value);
+#else
+      AT_ERROR("Not compiled with CUDA support");
+#endif
+    } else {
+      out = padded_index_select_cpu(src, index, fill_value);
+    }
     ctx->save_for_backward({index});
     return {out};
   }
@@ -35,7 +55,16 @@ public:
     auto saved = ctx->get_saved_variables();
     auto index = saved[0];
     auto N = ctx->saved_data["N"].toInt();
-    auto grad_in = padded_index_scatter_cuda(grad_out, index, N);
+    torch::Tensor grad_in;
+    if (grad_out.device().is_cuda()) {
+#ifdef WITH_CUDA
+      grad_in = padded_index_scatter_cuda(grad_out, index, N);
+#else
+      AT_ERROR("Not compiled with CUDA support");
+#endif
+    } else {
+      grad_in = padded_index_scatter_cpu(grad_out, index, N);
+    }
     return {grad_in, Variable(), Variable()};
   }
 };
...
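For orientation, here is a hedged Python sketch of the forward/backward pairing the C++ autograd Function above implements: the forward pass is the padded gather, the backward pass scatter-adds incoming gradients back onto the source rows, and index and fill_value receive no gradient. PaddedIndexSelect below is a stand-in written with plain PyTorch ops, not the extension's registered operator.

import torch

# Stand-in for illustration only; not the library's real binding.
class PaddedIndexSelect(torch.autograd.Function):
    @staticmethod
    def forward(ctx, src, index, fill_value):
        ctx.N = src.size(0)
        ctx.save_for_backward(index)
        mask = index == -1
        out = src.index_select(0, index.masked_fill(mask, 0))
        return out.masked_fill_(mask.view(-1, 1).expand_as(out), fill_value)

    @staticmethod
    def backward(ctx, grad_out):
        index, = ctx.saved_tensors
        mask = index == -1
        idx = index.masked_fill(mask, ctx.N)        # padding -> scratch row N
        grad_in = torch.zeros(ctx.N + 1, grad_out.size(-1),
                              dtype=grad_out.dtype, device=grad_out.device)
        grad_in.scatter_add_(0, idx.view(-1, 1).expand_as(grad_out), grad_out)
        return grad_in.narrow(0, 0, ctx.N), None, None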
@@ -4,9 +4,7 @@ import pytest
 import torch
 from torch_sparse import SparseTensor, padded_index_select
-from .utils import grad_dtypes, tensor
+from .utils import grad_dtypes, devices, tensor
-devices = [torch.device('cuda')]
 @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
...
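A small usage sketch of the CPU path the test now covers, assuming the Python binding mirrors the C++ signature padded_index_select(src, index, fill_value); the exact Python-level signature is an assumption, not confirmed by this diff.

import torch
from torch_sparse import padded_index_select

src = torch.randn(4, 8, requires_grad=True)
index = torch.tensor([0, 2, -1, 3, -1])         # -1 marks padding slots
out = padded_index_select(src, index, torch.tensor(0.))  # assumed signature

assert out.shape == (5, 8)
assert bool((out[2] == 0).all()) and bool((out[4] == 0).all())

out.sum().backward()                            # gradients reach only selected rows
assert bool((src.grad[1] == 0).all())           # row 1 was never selected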