Commit b4030755 authored by rusty1s

TorchScript support

parent d14110e1
#include <torch/extension.h>
+#include <torch/script.h>

#include "compat.h"

#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")

-at::Tensor ind2ptr(at::Tensor ind, int64_t M) {
+torch::Tensor ind2ptr(torch::Tensor ind, int64_t M) {
  CHECK_CPU(ind);
-  auto out = at::empty(M + 1, ind.options());
+  auto out = torch::empty(M + 1, ind.options());
  auto ind_data = ind.DATA_PTR<int64_t>();
  auto out_data = out.DATA_PTR<int64_t>();
@@ -28,9 +28,9 @@ at::Tensor ind2ptr(at::Tensor ind, int64_t M) {
  return out;
}

-at::Tensor ptr2ind(at::Tensor ptr, int64_t E) {
+torch::Tensor ptr2ind(torch::Tensor ptr, int64_t E) {
  CHECK_CPU(ptr);
-  auto out = at::empty(E, ptr.options());
+  auto out = torch::empty(E, ptr.options());
  auto ptr_data = ptr.DATA_PTR<int64_t>();
  auto out_data = out.DATA_PTR<int64_t>();

@@ -45,7 +45,6 @@ at::Tensor ptr2ind(at::Tensor ptr, int64_t E) {
  return out;
}

-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("ind2ptr", &ind2ptr, "Ind2Ptr (CPU)");
-  m.def("ptr2ind", &ptr2ind, "Ptr2Ind (CPU)");
-}
+static auto registry =
+    torch::RegisterOperators("torch_sparse_cpu::ind2ptr", &ind2ptr)
+        .op("torch_sparse_cpu::ptr2ind", &ptr2ind);
#include <torch/extension.h>
+#include <torch/script.h>

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")

-at::Tensor ind2ptr_cuda(at::Tensor ind, int64_t M);
-at::Tensor ptr2ind_cuda(at::Tensor ptr, int64_t E);
+torch::Tensor ind2ptr_cuda(torch::Tensor ind, int64_t M);
+torch::Tensor ptr2ind_cuda(torch::Tensor ptr, int64_t E);

-at::Tensor ind2ptr(at::Tensor ind, int64_t M) {
+torch::Tensor ind2ptr(torch::Tensor ind, int64_t M) {
  CHECK_CUDA(ind);
  return ind2ptr_cuda(ind, M);
}

-at::Tensor ptr2ind(at::Tensor ptr, int64_t E) {
+torch::Tensor ptr2ind(torch::Tensor ptr, int64_t E) {
  CHECK_CUDA(ptr);
  return ptr2ind_cuda(ptr, E);
}

-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("ind2ptr", &ind2ptr, "Ind2Ptr (CUDA)");
-  m.def("ptr2ind", &ptr2ind, "Ptr2Ind (CUDA)");
-}
+static auto registry =
+    torch::RegisterOperators("torch_sparse_cuda::ind2ptr", &ind2ptr)
+        .op("torch_sparse_cuda::ptr2ind", &ptr2ind);
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cusparse.h>
#include <torch/extension.h>

#include "compat.cuh"
@@ -23,8 +22,8 @@ __global__ void ind2ptr_kernel(const int64_t *ind_data, int64_t *out_data,
  }
}

-at::Tensor ind2ptr_cuda(at::Tensor ind, int64_t M) {
-  auto out = at::empty(M + 1, ind.options());
+torch::Tensor ind2ptr_cuda(torch::Tensor ind, int64_t M) {
+  auto out = torch::empty(M + 1, ind.options());
  auto ind_data = ind.DATA_PTR<int64_t>();
  auto out_data = out.DATA_PTR<int64_t>();
  auto stream = at::cuda::getCurrentCUDAStream();
@@ -46,8 +45,8 @@ __global__ void ptr2ind_kernel(const int64_t *ptr_data, int64_t *out_data,
  }
}

-at::Tensor ptr2ind_cuda(at::Tensor ptr, int64_t E) {
-  auto out = at::empty(E, ptr.options());
+torch::Tensor ptr2ind_cuda(torch::Tensor ptr, int64_t E) {
+  auto out = torch::empty(E, ptr.options());
  auto ptr_data = ptr.DATA_PTR<int64_t>();
  auto out_data = out.DATA_PTR<int64_t>();
  auto stream = at::cuda::getCurrentCUDAStream();
...
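
Because the CPU and CUDA builds register the same two operators under separate namespaces, a quick cross-device consistency check becomes possible once both libraries are built; a sketch, assuming a CUDA device is available:

import torch

torch.ops.load_library('torch_sparse/convert_cpu.so')
torch.ops.load_library('torch_sparse/convert_cuda.so')

ind = torch.tensor([0, 0, 1, 3])
out_cpu = torch.ops.torch_sparse_cpu.ind2ptr(ind, 4)
out_gpu = torch.ops.torch_sparse_cuda.ind2ptr(ind.cuda(), 4)
assert torch.equal(out_cpu, out_gpu.cpu())
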
@@ -5,6 +5,7 @@ from setuptools import setup, find_packages
from sys import argv

import torch
+from torch.utils.cpp_extension import BuildExtension
from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME

cxx_extra_compile_args = []
@@ -16,7 +17,9 @@ extra_compile_args = []
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
    cxx_extra_compile_args += ['-DVERSION_GE_1_3']
    nvcc_extra_compile_args += ['-DVERSION_GE_1_3']

-cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}
+cmdclass = {
+    'build_ext': BuildExtension.with_options(no_python_abi_suffix=True)
+}

ext_modules = []
exts = [e.split(osp.sep)[-1][:-4] for e in glob(osp.join('cpu', '*.cpp'))]
...
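
The no_python_abi_suffix option matters for the torch.ops.load_library calls further below: without it, setuptools appends an ABI tag to the artifact name (e.g. convert_cpu.cpython-37m-x86_64-linux-gnu.so), which the hard-coded 'torch_sparse/convert_cpu.so' path would miss. A trimmed sketch of the relevant setup.py wiring; the extension name and source list are guesses based on the glob above:

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension

setup(
    name='torch_sparse',
    ext_modules=[
        CppExtension('torch_sparse.convert_cpu', ['cpu/convert.cpp']),
    ],
    cmdclass={
        # Keeps the built name stable ('convert_cpu.so') across Python versions.
        'build_ext': BuildExtension.with_options(no_python_abi_suffix=True)
    },
)
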
import torch
import torch_scatter

-from .utils.unique import unique
+from .unique import unique


def coalesce(index, value, m, n, op='add', fill_value=0):
...
@@ -2,13 +2,7 @@ import warnings
import torch
from torch_scatter import segment_csr, scatter_add

-from torch_sparse import convert_cpu
-try:
-    from torch_sparse import convert_cuda
-except ImportError:
-    convert_cuda = None
+from .utils import ext

__cache__ = {'enabled': True}
@@ -167,8 +161,8 @@ class SparseStorage(object):
    @property
    def row(self):
        if self._row is None:
-            func = convert_cuda if self.rowptr.is_cuda else convert_cpu
-            self._row = func.ptr2ind(self.rowptr, self.col.numel())
+            self._row = ext(self.col.is_cuda).ptr2ind(self.rowptr,
+                                                      self.col.numel())
        return self._row

    def has_rowptr(self):
@@ -177,8 +171,8 @@ class SparseStorage(object):
    @property
    def rowptr(self):
        if self._rowptr is None:
-            func = convert_cuda if self.row.is_cuda else convert_cpu
-            self._rowptr = func.ind2ptr(self.row, self.sparse_size[0])
+            self._rowptr = ext(self.col.is_cuda).ind2ptr(
+                self.row, self.sparse_size[0])
        return self._rowptr

    @property
@@ -279,8 +273,8 @@ class SparseStorage(object):
    @cached_property
    def colptr(self):
        if self.has_csr2csc():
-            func = convert_cuda if self.col.is_cuda else convert_cpu
-            return func.ind2ptr(self.col[self.csr2csc], self.sparse_size[1])
+            return ext(self.col.is_cuda).ind2ptr(self.col[self.csr2csc],
+                                                 self.sparse_size[1])
        else:
            colptr = self.col.new_zeros(self.sparse_size[1] + 1)
            torch.cumsum(self.colcount, dim=0, out=colptr[1:])
...
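
All three storage hunks apply the same pattern: derive a missing CSR field on first access through the device-dispatched ops, then cache it. A stripped-down sketch of that pattern, with the class reduced to the fields involved (not the real SparseStorage API):

import torch
from torch_sparse.utils import ext  # the helper defined in the file below

class LazyCSR(object):
    def __init__(self, row=None, rowptr=None, col=None, num_rows=0):
        self._row, self._rowptr = row, rowptr
        self._col, self._num_rows = col, num_rows

    @property
    def row(self):
        if self._row is None:
            # Recover per-entry row indices from the cached pointers.
            self._row = ext(self._col.is_cuda).ptr2ind(
                self._rowptr, self._col.numel())
        return self._row

    @property
    def rowptr(self):
        if self._rowptr is None:
            # Compress sorted row indices into CSR pointers.
            self._rowptr = ext(self._col.is_cuda).ind2ptr(
                self._row, self._num_rows)
        return self._rowptr
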
import torch

torch.ops.load_library('torch_sparse/convert_cpu.so')
try:
    torch.ops.load_library('torch_sparse/convert_cuda.so')
except OSError:
    pass


def ext(is_cuda):
    name = 'torch_sparse_cuda' if is_cuda else 'torch_sparse_cpu'
    return getattr(torch.ops, name)
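
A quick usage sketch of the helper above; the import path assumes this file lives at torch_sparse/utils.py, matching the `from .utils import ext` import in the storage diff:

import torch
from torch_sparse.utils import ext

ptr = torch.tensor([0, 2, 3, 3, 4])
row = ext(ptr.is_cuda).ptr2ind(ptr, 4)  # dispatches to torch_sparse_cpu here
# The same call on a CUDA tensor would dispatch to torch_sparse_cuda instead.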