Commit d9565693 authored by rusty1s's avatar rusty1s
Browse files

basic thrust boilerplate

parent 4ceb2d1a
// Pick the tensor raw-pointer accessor by PyTorch version. setup.py defines
// VERSION_GE_1_3 for torch > 1.2 (see the build diff below); presumably
// Tensor::data() was renamed to Tensor::data_ptr() in 1.3 -- confirm against
// the targeted torch releases.
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
#include <torch/extension.h>
// Raises a readable error (with the argument's name) when `x` is not a CUDA
// tensor.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
// Implemented in segment_kernel.cu.
void segment_add_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
int64_t dim);
// Python-facing wrapper: validates that all three tensors live on a CUDA
// device, then forwards to the CUDA implementation. `out` is written in
// place. NOTE(review): contiguity/dtype of `index` are not checked here --
// the kernel reads it as int64 (see segment_kernel.cu); confirm callers
// guarantee this.
void segment_add(at::Tensor src, at::Tensor index, at::Tensor out,
int64_t dim) {
CHECK_CUDA(src);
CHECK_CUDA(index);
CHECK_CUDA(out);
segment_add_cuda(src, index, out, dim);
}
// Binds segment_add into the extension module (built by setup.py as
// torch_scatter.segment_cuda, per the build diff below).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("segment_add", &segment_add, "Segment Add (CUDA)");
}
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCGeneral.h>
#include <THC/THCThrustAllocator.cuh>
#include <thrust/execution_policy.h>
#include "compat.cuh"
// Entry point for the segment-add CUDA implementation.
// NOTE(review): this is scaffolding from the "basic thrust boilerplate"
// commit -- no kernel or thrust algorithm is launched yet; the dispatch
// below only wraps the raw pointers, so `policy`, `index_data`, `src_data`
// and `out_data` are all (intentionally) unused for now.
void segment_add_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
int64_t dim) {
// Make src's device current so later allocations/launches target it.
// NOTE(review): the cudaError_t return value is ignored -- confirm this is
// acceptable, or check it.
cudaSetDevice(src.get_device());
// Use PyTorch's current stream (not the default stream) for thrust work.
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// Route thrust's temporary allocations through the THC allocator, and bind
// the execution policy to the stream obtained above.
auto allocator = THCThrustAllocator(at::globalContext().lazyInitCUDA());
auto policy = thrust::cuda::par(allocator).on(stream);
// Segment ids are read as int64 (matches the long `index` built by callers).
auto index_data = thrust::device_ptr<int64_t>(index.DATA_PTR<int64_t>());
// DATA_PTR expands to data_ptr()/data() depending on torch version
// (see compat.cuh).
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_kernel", [&] {
auto src_data = thrust::device_ptr<scalar_t>(src.DATA_PTR<scalar_t>());
auto out_data = thrust::device_ptr<scalar_t>(out.DATA_PTR<scalar_t>());
});
}
@@ -8,15 +8,14 @@
 TORCH_MINOR = int(torch.__version__.split('.')[1])

 extra_compile_args = []
-if platform.system() != 'Windows':
-    extra_compile_args += ['-Wno-unused-variable']
-
 if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
     extra_compile_args += ['-DVERSION_GE_1_3']

 ext_modules = [
-    CppExtension('torch_scatter.scatter_cpu', ['cpu/scatter.cpp'],
-                 extra_compile_args=extra_compile_args)
+    CppExtension(
+        'torch_scatter.scatter_cpu', ['cpu/scatter.cpp'],
+        extra_compile_args=extra_compile_args +
+        ['-Wno-unused-variable'] if platform.system() != 'Windows' else [])
 ]
 cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}
@@ -29,7 +28,11 @@
 if CUDA_HOME is not None and GPU:
     ext_modules += [
         CUDAExtension('torch_scatter.scatter_cuda',
-                      ['cuda/scatter.cpp', 'cuda/scatter_kernel.cu'])
+                      ['cuda/scatter.cpp', 'cuda/scatter_kernel.cu'],
+                      extra_compile_args=extra_compile_args),
+        CUDAExtension('torch_scatter.segment_cuda',
+                      ['cuda/segment.cpp', 'cuda/segment_kernel.cu'],
+                      extra_compile_args=extra_compile_args),
     ]

 __version__ = '1.4.0'
from itertools import product
import pytest
import torch
from torch_scatter import segment_add
from .utils import tensor
dtypes = [torch.float]
devices = [torch.device('cuda')]
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_forward(dtype, device):
    """segment_add sums `src` entries sharing a segment id in `index`.

    Fix over the committed version: the test printed `out` instead of
    asserting anything, so it could never fail. Pin the expected sums:
    ids 0 -> 1+2, 1 -> 3+4+5, 2 -> 6.
    """
    src = tensor([1, 2, 3, 4, 5, 6], dtype, device)
    index = tensor([0, 0, 1, 1, 1, 2], torch.long, device)
    out = segment_add(src, index, dim=0)
    assert out.tolist() == [3, 12, 6]
@@ -7,6 +7,9 @@
 from .max import scatter_max
 from .min import scatter_min
 from .logsumexp import scatter_logsumexp
+from .segment import segment_add

 import torch_scatter.composite

 __version__ = '1.4.0'
@@ -21,6 +24,7 @@
     'scatter_max',
     'scatter_min',
     'scatter_logsumexp',
+    'segment_add',
     'torch_scatter',
     '__version__',
 ]
import torch
from torch_scatter.utils.gen import gen
if torch.cuda.is_available():
import torch_scatter.segment_cuda
def segment_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
    """Dispatches the CUDA ``segment_add`` kernel, accumulating ``src``
    entries into ``out`` at the positions given by ``index`` along ``dim``.

    Args:
        src: source tensor; must live on a CUDA device.
        index: tensor of segment ids matched to ``src`` by ``gen``.
        dim (int): dimension along which to operate. Default: ``-1``.
        out: optional destination tensor; presumably created by ``gen`` when
            omitted -- confirm against ``torch_scatter.utils.gen``.
        dim_size: requested ``out`` size along ``dim`` when ``out`` is None.
        fill_value: initial value for a freshly created ``out``.

    Returns:
        The ``out`` tensor (written in place by the kernel).
    """
    src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
    # An empty source dimension leaves `out` untouched.
    if src.size(dim) != 0:
        assert src.is_cuda
        torch_scatter.segment_cuda.segment_add(src, index, out, dim)
    return out
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment