Commit 6e87043a authored by rusty1s

diag torch script support

parent b4030755
CPU extension:

-#include <torch/extension.h>
+#include <torch/script.h>

 #include "compat.h"

 #define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")

-at::Tensor non_diag_mask(at::Tensor row, at::Tensor col, int64_t M, int64_t N,
-                         int64_t k) {
+torch::Tensor non_diag_mask(torch::Tensor row, torch::Tensor col, int64_t M,
+                            int64_t N, int64_t k) {
   CHECK_CPU(row);
   CHECK_CPU(col);
@@ -15,7 +15,7 @@ at::Tensor non_diag_mask(at::Tensor row, at::Tensor col, int64_t M, int64_t N,
   auto row_data = row.DATA_PTR<int64_t>();
   auto col_data = col.DATA_PTR<int64_t>();
-  auto mask = at::zeros(E + num_diag, row.options().dtype(at::kBool));
+  auto mask = torch::zeros(E + num_diag, row.options().dtype(at::kBool));
   auto mask_data = mask.DATA_PTR<bool>();
   int64_t r, c;
@@ -48,6 +48,5 @@ at::Tensor non_diag_mask(at::Tensor row, at::Tensor col, int64_t M, int64_t N,
   return mask;
 }

-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("non_diag_mask", &non_diag_mask, "Non-Diagonal Mask (CPU)");
-}
+static auto registry =
+    torch::RegisterOperators("torch_sparse_cpu::non_diag_mask", &non_diag_mask);
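Replacing PYBIND11_MODULE with torch::RegisterOperators is what provides the TorchScript support named in the commit message: the op is now exposed under torch.ops, where the TorchScript compiler can see it. A minimal sketch of calling it once this library is built (the tensor values are illustrative):

import torch

# Load the shared library so its custom ops register themselves; the path
# matches the one used in the package's utils module below.
torch.ops.load_library('torch_sparse/diag_cpu.so')

row = torch.tensor([0, 0, 1])  # illustrative COO indices
col = torch.tensor([1, 2, 2])

# Eager call through the registered namespace:
mask = torch.ops.torch_sparse_cpu.non_diag_mask(row, col, 2, 3, 0)

# Unlike a pybind11 binding, the registered op can also be scripted:
@torch.jit.script
def scripted_mask(row: torch.Tensor, col: torch.Tensor) -> torch.Tensor:
    return torch.ops.torch_sparse_cpu.non_diag_mask(row, col, 2, 3, 0)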
CUDA extension:

-#include <torch/extension.h>
+#include <torch/script.h>

 #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")

-at::Tensor non_diag_mask_cuda(at::Tensor row, at::Tensor col, int64_t M,
-                              int64_t N, int64_t k);
+torch::Tensor non_diag_mask_cuda(torch::Tensor row, torch::Tensor col,
+                                 int64_t M, int64_t N, int64_t k);

-at::Tensor non_diag_mask(at::Tensor row, at::Tensor col, int64_t M, int64_t N,
-                         int64_t k) {
+torch::Tensor non_diag_mask(torch::Tensor row, torch::Tensor col, int64_t M,
+                            int64_t N, int64_t k) {
   CHECK_CUDA(row);
   CHECK_CUDA(col);
   return non_diag_mask_cuda(row, col, M, N, k);
 }

-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("non_diag_mask", &non_diag_mask, "Non-Diagonal Mask (CUDA)");
-}
+static auto registry = torch::RegisterOperators(
+    "torch_sparse_cuda::non_diag_mask", &non_diag_mask);
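The CUDA build registers the same schema under its own namespace, so callers select the backend by device. A sketch, assuming a CUDA-capable build of diag_cuda.so:

import torch

torch.ops.load_library('torch_sparse/diag_cuda.so')

# The CUDA variant requires CUDA tensors, per the CHECK_CUDA asserts above.
row = torch.tensor([0, 0, 1], device='cuda')
col = torch.tensor([1, 2, 2], device='cuda')
mask = torch.ops.torch_sparse_cuda.non_diag_mask(row, col, 2, 3, 0)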
CUDA kernel:

-#include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <torch/extension.h>

 #include "compat.cuh"
@@ -38,15 +38,15 @@ __global__ void non_diag_mask_kernel(const int64_t *row_data,
   }
 }

-at::Tensor non_diag_mask_cuda(at::Tensor row, at::Tensor col, int64_t M,
-                              int64_t N, int64_t k) {
+torch::Tensor non_diag_mask_cuda(torch::Tensor row, torch::Tensor col,
+                                 int64_t M, int64_t N, int64_t k) {
   int64_t E = row.size(0);
   int64_t num_diag = k < 0 ? std::min(M + k, N) : std::min(M, N - k);

   auto row_data = row.DATA_PTR<int64_t>();
   auto col_data = col.DATA_PTR<int64_t>();
-  auto mask = at::zeros(E + num_diag, row.options().dtype(at::kBool));
+  auto mask = torch::zeros(E + num_diag, row.options().dtype(at::kBool));
   auto mask_data = mask.DATA_PTR<bool>();

   auto stream = at::cuda::getCurrentCUDAStream();
...
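The num_diag expression visible in this hunk computes the length of the k-th diagonal of an M x N matrix (k < 0 selects a sub-diagonal, k > 0 a super-diagonal). A quick pure-Python check of that formula:

def num_diag(M: int, N: int, k: int) -> int:
    # Mirrors the C++: k < 0 ? std::min(M + k, N) : std::min(M, N - k)
    return min(M + k, N) if k < 0 else min(M, N - k)

assert num_diag(3, 3, 0) == 3   # main diagonal of a 3x3 matrix
assert num_diag(3, 3, 1) == 2   # first super-diagonal: entries (0,1), (1,2)
assert num_diag(4, 3, -2) == 2  # sub-diagonal of a tall matrix: (2,0), (3,1)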
Python diag module:

 import torch
-from torch_sparse import diag_cpu
-try:
-    from torch_sparse import diag_cuda
-except ImportError:
-    diag_cuda = None
+from .utils import ext

 def remove_diag(src, k=0):
@@ -44,8 +39,8 @@ def set_diag(src, values=None, k=0):
     row, col, value = src.coo()
-    func = diag_cuda if row.is_cuda else diag_cpu
-    mask = func.non_diag_mask(row, col, src.size(0), src.size(1), k)
+    mask = ext(row.is_cuda).non_diag_mask(row, col, src.size(0), src.size(1),
+                                          k)
     inv_mask = ~mask

     start, num_diag = -k if k < 0 else 0, mask.numel() - row.numel()
...
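The rest of set_diag is elided, but the visible lines pin down the mask's shape: one slot per original entry plus one per new diagonal element (num_diag = mask.numel() - row.numel()). A toy sketch of the bookkeeping this enables, under the assumption that True slots in the returned mask carry the original off-diagonal entries while the ~mask slots receive the inserted diagonal:

import torch

# Toy stand-in: E = 3 original entries, k = 0, so num_diag = 3 for a 3x3 matrix.
row = torch.tensor([0, 0, 1])
mask = torch.tensor([False, True, True, False, True, False])  # E + num_diag slots
k = 0
num_diag = mask.numel() - row.numel()
start = -k if k < 0 else 0

new_row = torch.empty(mask.numel(), dtype=torch.long)
new_row[mask] = row                                     # originals keep their order
new_row[~mask] = torch.arange(start, start + num_diag)  # rows of the new diagonal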
Python utils module:

 import torch

 torch.ops.load_library('torch_sparse/convert_cpu.so')
+torch.ops.load_library('torch_sparse/diag_cpu.so')

 try:
     torch.ops.load_library('torch_sparse/convert_cuda.so')
-except OSError:
-    pass
+    torch.ops.load_library('torch_sparse/diag_cuda.so')
+except OSError as e:
+    if torch.cuda.is_available():
+        raise e

 def ext(is_cuda):
...
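The body of ext is elided in this diff, but given its call site in the diag module (ext(row.is_cuda).non_diag_mask(...)) and the two namespaces registered above, it presumably maps the device flag to the matching torch.ops namespace, roughly:

def ext(is_cuda):
    # Assumed shape of the helper: select the op namespace by device.
    return torch.ops.torch_sparse_cuda if is_cuda else torch.ops.torch_sparse_cpu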