Commit 1b316a63 authored by rusty1s's avatar rusty1s
Browse files

basic segment_add functionality

parent d9565693
......@@ -2,15 +2,15 @@
// Abort with a readable error message when a tensor is not on a CUDA device.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")

// Forward declaration of the CUDA launcher (defined in the .cu translation
// unit). Returns the reduced output tensor together with the per-segment
// key tensor produced by the reduction.
std::tuple<at::Tensor, at::Tensor>
segment_add_cuda(at::Tensor src, at::Tensor index, at::Tensor out);

// Python-facing entry point: validates that every tensor lives on the GPU,
// then dispatches to the CUDA implementation.
// NOTE(review): this span of the diff contained both the pre-image (void,
// dim-taking) and post-image (tuple-returning) versions; only the post-image
// is kept here.
std::tuple<at::Tensor, at::Tensor> segment_add(at::Tensor src, at::Tensor index,
                                               at::Tensor out) {
  CHECK_CUDA(src);
  CHECK_CUDA(index);
  CHECK_CUDA(out);
  return segment_add_cuda(src, index, out);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
......
......@@ -8,16 +8,25 @@
#include "compat.cuh"
// Segment-sum of `src` grouped by runs of equal values in `index`, computed
// with thrust::reduce_by_key on the current CUDA stream.
// Returns (out, key):
//   out[i] — sum of the i-th run of equal index values;
//   key[i] — that run's index value, or -1 for slots no run was written to.
// NOTE(review): thrust::reduce_by_key groups *consecutive* equal keys, so
// `index` is presumably expected to be sorted — confirm against callers.
std::tuple<at::Tensor, at::Tensor>
segment_add_cuda(at::Tensor src, at::Tensor index, at::Tensor out) {
  cudaSetDevice(src.get_device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  auto allocator = THCThrustAllocator(at::globalContext().lazyInitCUDA());
  auto policy = thrust::cuda::par(allocator).on(stream);

  // Key output: same shape as `out`, pre-filled with -1 so that trailing
  // slots receiving no segment remain identifiable.
  auto key = at::full_like(out, -1, out.options().dtype(at::kLong));

  auto index_data = thrust::device_ptr<int64_t>(index.DATA_PTR<int64_t>());
  auto key_data = thrust::device_ptr<int64_t>(key.DATA_PTR<int64_t>());

  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_kernel", [&] {
    auto src_data = thrust::device_ptr<scalar_t>(src.DATA_PTR<scalar_t>());
    auto out_data = thrust::device_ptr<scalar_t>(out.DATA_PTR<scalar_t>());
    // Writes each run's key into `key` and each run's sum into `out`.
    thrust::reduce_by_key(policy, index_data, index_data + index.size(0),
                          src_data, key_data, out_data);
  });

  return std::make_tuple(out, key);
}
......@@ -13,7 +13,8 @@ devices = [torch.device('cuda')]
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_forward(dtype, device):
    """Smoke-test the tuple-returning segment_add on a small fixed input.

    NOTE(review): this diff span contained both the pre-image (index ends in
    2, single return value) and post-image (index ends in 3, (out, key)
    return) lines; only the post-image is kept. The index deliberately skips
    the value 2, so the gap should be visible in `key`.
    """
    src = tensor([1, 2, 3, 4, 5, 6], dtype, device)
    index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
    out, key = segment_add(src, index, dim=0)
    print(out)
    print(key)
......@@ -11,5 +11,5 @@ def segment_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
# NOTE(review): diff residue — this hunk is the tail of `segment_add`, whose
# `def` line (def segment_add(src, index, dim=-1, out=None, dim_size=None,
# fill_value=0)) appears only in the hunk header above. Both pre- and
# post-image lines are shown below; do not read this as straight-line code.
if src.size(dim) == 0:  # pragma: no cover
return out
assert src.is_cuda
# Pre-image (removed by this commit): the CUDA op mutated `out` in place and
# only `out` was returned.
torch_scatter.segment_cuda.segment_add(src, index, out, dim)
return out
# Post-image (added): the CUDA op now returns the (out, key) pair, which is
# propagated to the caller.
out, key = torch_scatter.segment_cuda.segment_add(src, index, out)
return out, key
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment