"cacheflow/git@developer.sourcefind.cn:norm/vllm.git" did not exist on "709a69176ea86f60786acb87ede52e62d9efb036"
Commit b4030755 authored by rusty1s

torch script support

parent d14110e1
-#include <torch/extension.h>
+#include <torch/script.h>

 #include "compat.h"

 #define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")

-at::Tensor ind2ptr(at::Tensor ind, int64_t M) {
+torch::Tensor ind2ptr(torch::Tensor ind, int64_t M) {
   CHECK_CPU(ind);
-  auto out = at::empty(M + 1, ind.options());
+  auto out = torch::empty(M + 1, ind.options());
   auto ind_data = ind.DATA_PTR<int64_t>();
   auto out_data = out.DATA_PTR<int64_t>();
@@ -28,9 +28,9 @@ at::Tensor ind2ptr(at::Tensor ind, int64_t M) {
   return out;
 }

-at::Tensor ptr2ind(at::Tensor ptr, int64_t E) {
+torch::Tensor ptr2ind(torch::Tensor ptr, int64_t E) {
   CHECK_CPU(ptr);
-  auto out = at::empty(E, ptr.options());
+  auto out = torch::empty(E, ptr.options());
   auto ptr_data = ptr.DATA_PTR<int64_t>();
   auto out_data = out.DATA_PTR<int64_t>();
@@ -45,7 +45,6 @@ at::Tensor ptr2ind(at::Tensor ptr, int64_t E) {
   return out;
 }

-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("ind2ptr", &ind2ptr, "Ind2Ptr (CPU)");
-  m.def("ptr2ind", &ptr2ind, "Ptr2Ind (CPU)");
-}
+static auto registry =
+    torch::RegisterOperators("torch_sparse_cpu::ind2ptr", &ind2ptr)
+        .op("torch_sparse_cpu::ptr2ind", &ptr2ind);
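Switching from PYBIND11_MODULE to torch::RegisterOperators exposes these functions under the torch.ops namespace rather than as a plain Python extension module. A minimal sketch of calling the CPU op from Python, assuming the shared library has been built and sits at the path below:

import torch

torch.ops.load_library('torch_sparse/convert_cpu.so')

ind = torch.tensor([0, 0, 1, 3])                  # sorted row indices, M = 4
ptr = torch.ops.torch_sparse_cpu.ind2ptr(ind, 4)  # CSR row pointer
print(ptr)  # tensor([0, 2, 3, 3, 4])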
-#include <torch/extension.h>
+#include <torch/script.h>

 #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")

-at::Tensor ind2ptr_cuda(at::Tensor ind, int64_t M);
-at::Tensor ptr2ind_cuda(at::Tensor ptr, int64_t E);
+torch::Tensor ind2ptr_cuda(torch::Tensor ind, int64_t M);
+torch::Tensor ptr2ind_cuda(torch::Tensor ptr, int64_t E);

-at::Tensor ind2ptr(at::Tensor ind, int64_t M) {
+torch::Tensor ind2ptr(torch::Tensor ind, int64_t M) {
   CHECK_CUDA(ind);
   return ind2ptr_cuda(ind, M);
 }

-at::Tensor ptr2ind(at::Tensor ptr, int64_t E) {
+torch::Tensor ptr2ind(torch::Tensor ptr, int64_t E) {
   CHECK_CUDA(ptr);
   return ptr2ind_cuda(ptr, E);
 }

-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("ind2ptr", &ind2ptr, "Ind2Ptr (CUDA)");
-  m.def("ptr2ind", &ptr2ind, "Ptr2Ind (CUDA)");
-}
+static auto registry =
+    torch::RegisterOperators("torch_sparse_cuda::ind2ptr", &ind2ptr)
+        .op("torch_sparse_cuda::ptr2ind", &ptr2ind);
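The point of registering through torch::RegisterOperators is that, unlike pybind11 bindings, these ops are visible to the TorchScript compiler. A sketch of scripting a function that uses one of them (the function name here is illustrative):

import torch

torch.ops.load_library('torch_sparse/convert_cpu.so')

@torch.jit.script
def rowptr_to_row(rowptr: torch.Tensor, nnz: int) -> torch.Tensor:
    # Calls the custom op from inside TorchScript; a pybind11-bound
    # function would not be resolvable here.
    return torch.ops.torch_sparse_cpu.ptr2ind(rowptr, nnz)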
-#include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
-#include <cusparse.h>
+#include <torch/extension.h>

 #include "compat.cuh"
@@ -23,8 +22,8 @@ __global__ void ind2ptr_kernel(const int64_t *ind_data, int64_t *out_data,
   }
 }

-at::Tensor ind2ptr_cuda(at::Tensor ind, int64_t M) {
-  auto out = at::empty(M + 1, ind.options());
+torch::Tensor ind2ptr_cuda(torch::Tensor ind, int64_t M) {
+  auto out = torch::empty(M + 1, ind.options());
   auto ind_data = ind.DATA_PTR<int64_t>();
   auto out_data = out.DATA_PTR<int64_t>();

   auto stream = at::cuda::getCurrentCUDAStream();
@@ -46,8 +45,8 @@ __global__ void ptr2ind_kernel(const int64_t *ptr_data, int64_t *out_data,
   }
 }

-at::Tensor ptr2ind_cuda(at::Tensor ptr, int64_t E) {
-  auto out = at::empty(E, ptr.options());
+torch::Tensor ptr2ind_cuda(torch::Tensor ptr, int64_t E) {
+  auto out = torch::empty(E, ptr.options());
   auto ptr_data = ptr.DATA_PTR<int64_t>();
   auto out_data = out.DATA_PTR<int64_t>();

   auto stream = at::cuda::getCurrentCUDAStream();
...
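For reference, the semantics both the CPU and CUDA kernels implement: ind2ptr compresses sorted COO row indices into a CSR row pointer, and ptr2ind expands it back. A pure-Python equivalent, as a sketch for illustration rather than the kernel itself:

import torch

def ind2ptr_ref(ind: torch.Tensor, M: int) -> torch.Tensor:
    # ptr[i + 1] - ptr[i] counts how often row i occurs in ind.
    ptr = torch.zeros(M + 1, dtype=torch.long)
    ptr[1:] = torch.cumsum(torch.bincount(ind, minlength=M), dim=0)
    return ptr

def ptr2ind_ref(ptr: torch.Tensor) -> torch.Tensor:
    # Repeat each row index by its count to recover the sorted indices.
    counts = ptr[1:] - ptr[:-1]
    return torch.repeat_interleave(torch.arange(ptr.numel() - 1), counts)

# ind2ptr_ref(torch.tensor([0, 0, 1, 3]), 4)  -> tensor([0, 2, 3, 3, 4])
# ptr2ind_ref(torch.tensor([0, 2, 3, 3, 4]))  -> tensor([0, 0, 1, 3])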
@@ -5,6 +5,7 @@ from setuptools import setup, find_packages
 from sys import argv

 import torch
+from torch.utils.cpp_extension import BuildExtension
 from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME

 cxx_extra_compile_args = []
@@ -16,7 +17,9 @@ extra_compile_args = []
 if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
     cxx_extra_compile_args += ['-DVERSION_GE_1_3']
     nvcc_extra_compile_args += ['-DVERSION_GE_1_3']

-cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}
+cmdclass = {
+    'build_ext': BuildExtension.with_options(no_python_abi_suffix=True)
+}

 ext_modules = []
 exts = [e.split(osp.sep)[-1][:-4] for e in glob(osp.join('cpu', '*.cpp'))]
...
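no_python_abi_suffix=True makes build_ext emit plain library names (convert_cpu.so rather than, e.g., convert_cpu.cpython-37m-x86_64-linux-gnu.so), so the fixed paths passed to torch.ops.load_library below resolve regardless of the Python version. A trimmed sketch of such a setup() call; the extension name and source path here are illustrative, not the full setup.py:

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension

setup(
    name='torch_sparse',
    ext_modules=[
        CppExtension('torch_sparse.convert_cpu', ['cpu/convert.cpp']),
    ],
    cmdclass={
        # Drop the Python ABI tag so the output file is convert_cpu.so.
        'build_ext': BuildExtension.with_options(no_python_abi_suffix=True)
    },
)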
 import torch
 import torch_scatter

-from .utils.unique import unique
+from .unique import unique

 def coalesce(index, value, m, n, op='add', fill_value=0):
...
@@ -2,13 +2,7 @@ import warnings
 import torch
 from torch_scatter import segment_csr, scatter_add

+from .utils import ext
-from torch_sparse import convert_cpu
-try:
-    from torch_sparse import convert_cuda
-except ImportError:
-    convert_cuda = None

 __cache__ = {'enabled': True}
@@ -167,8 +161,8 @@ class SparseStorage(object):
     @property
     def row(self):
         if self._row is None:
-            func = convert_cuda if self.rowptr.is_cuda else convert_cpu
-            self._row = func.ptr2ind(self.rowptr, self.col.numel())
+            self._row = ext(self.col.is_cuda).ptr2ind(self.rowptr,
+                                                      self.col.numel())
         return self._row

     def has_rowptr(self):
@@ -177,8 +171,8 @@ class SparseStorage(object):
     @property
     def rowptr(self):
         if self._rowptr is None:
-            func = convert_cuda if self.row.is_cuda else convert_cpu
-            self._rowptr = func.ind2ptr(self.row, self.sparse_size[0])
+            self._rowptr = ext(self.col.is_cuda).ind2ptr(
+                self.row, self.sparse_size[0])
         return self._rowptr

     @property
@@ -279,8 +273,8 @@ class SparseStorage(object):
     @cached_property
     def colptr(self):
         if self.has_csr2csc():
-            func = convert_cuda if self.col.is_cuda else convert_cpu
-            return func.ind2ptr(self.col[self.csr2csc], self.sparse_size[1])
+            return ext(self.col.is_cuda).ind2ptr(self.col[self.csr2csc],
+                                                 self.sparse_size[1])
         else:
             colptr = self.col.new_zeros(self.sparse_size[1] + 1)
             torch.cumsum(self.colcount, dim=0, out=colptr[1:])
...
+import torch
+
+torch.ops.load_library('torch_sparse/convert_cpu.so')
+try:
+    torch.ops.load_library('torch_sparse/convert_cuda.so')
+except OSError:
+    pass
+
+
+def ext(is_cuda):
+    name = 'torch_sparse_cuda' if is_cuda else 'torch_sparse_cpu'
+    return getattr(torch.ops, name)
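Usage sketch for the new helper: ext picks the CPU or CUDA op namespace based on where the tensors live, which is how the SparseStorage properties above dispatch (assumes the shared libraries loaded successfully):

import torch
from torch_sparse.utils import ext

rowptr = torch.tensor([0, 2, 3, 3, 4])
row = ext(rowptr.is_cuda).ptr2ind(rowptr, 4)  # -> tensor([0, 0, 1, 3])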