Commit 0ad76a83 authored by rusty1s's avatar rusty1s
Browse files

warp parallel segment implementation

parent 1b316a63
...@@ -2,17 +2,23 @@ ...@@ -2,17 +2,23 @@
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor") #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
std::tuple<at::Tensor, at::Tensor> at::Tensor segment_add_cuda(at::Tensor src, at::Tensor indptr, int64_t dim);
segment_add_cuda(at::Tensor src, at::Tensor index, at::Tensor out); void segment_add_thrust_cuda(at::Tensor src, at::Tensor index, at::Tensor out);
std::tuple<at::Tensor, at::Tensor> segment_add(at::Tensor src, at::Tensor index, at::Tensor segment_add(at::Tensor src, at::Tensor indptr, int64_t dim) {
at::Tensor out) { CHECK_CUDA(src);
CHECK_CUDA(indptr);
return segment_add_cuda(src, indptr, dim);
}
// Segment-sum via thrust::reduce_by_key; the result is written into `out`
// in place.  All three tensors must be CUDA tensors.
void segment_add_thrust(at::Tensor src, at::Tensor index, at::Tensor out) {
  CHECK_CUDA(src);
  CHECK_CUDA(index);
  CHECK_CUDA(out);
  return segment_add_thrust_cuda(src, index, out);
}
// Python bindings for both segment-add variants.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("segment_add", &segment_add, "Segment Add (CUDA)");
  m.def("segment_add_thrust", &segment_add_thrust, "Segment Add Thrust (CUDA)");
}
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <THC/THCGeneral.h> #include <THC/THCGeneral.h>
#include <THC/THCThrustAllocator.cuh> #include <THC/THCThrustAllocator.cuh>
...@@ -8,10 +10,57 @@ ...@@ -8,10 +10,57 @@
#include "compat.cuh" #include "compat.cuh"
#define THREADS 256
#define FULL_MASK 0xffffffff

// Segment-sum over a CSR layout: one group of TB threads reduces one
// segment src[indptr[i]:indptr[i+1]] into out[i].
//
// Launch: 1-D grid covering TB * numel threads.  `numel` is the number of
// segments (indptr.numel() - 1), not the number of source elements.
// NOTE(review): the FULL_MASK shuffle is safe for TB == 32 because the
// `warp_idx < numel` guard is then uniform across a hardware warp; for
// TB < 32 confirm tail warps cannot diverge inside the shuffle.
template <typename scalar_t, int TB>
__global__ void segment_add_kernel(const scalar_t *src_data,
                                   const int64_t *indptr_data,
                                   scalar_t *out_data, size_t numel) {
  // 64-bit indexing: blockIdx.x * blockDim.x overflows 32 bits on large
  // launches, and indptr holds int64_t offsets that must not be truncated.
  int64_t thread_idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
  int64_t warp_idx = thread_idx / TB;    // Segment handled by this group.
  int lane_idx = thread_idx & (TB - 1);  // Lane within the group.

  if (warp_idx < (int64_t)numel) {
    int64_t row_start = __ldg(indptr_data + warp_idx);
    int64_t row_end = __ldg(indptr_data + warp_idx + 1);

    // Strided partial sum: lane k accumulates elements k, k+TB, k+2*TB, ...
    scalar_t val = (scalar_t)0;
    for (int64_t src_idx = row_start + lane_idx; src_idx < row_end;
         src_idx += TB) {
      val += __ldg(src_data + src_idx);
    }

#pragma unroll
    for (int offset = TB / 2; offset > 0; offset /= 2)
      val += __shfl_down_sync(FULL_MASK, val, offset); // Parallel reduction.

    if (lane_idx == 0) {
      out_data[warp_idx] = val;
    }
  }
}
// Host launcher for the warp-parallel segment sum.  `indptr` must be an
// int64 CSR offset tensor with at least one entry; returns a tensor of
// indptr.numel() - 1 segment sums.  `dim` is accepted for interface parity
// but the visible kernel only performs a 1-D reduction.
at::Tensor segment_add_cuda(at::Tensor src, at::Tensor indptr, int64_t dim) {
  AT_ASSERTM(indptr.numel() >= 1, "indptr must hold at least one offset");
  auto numel = indptr.numel() - 1;
  auto out = at::empty({numel}, src.options());
  if (numel == 0) {
    return out;  // Avoid launching a zero-sized grid (CUDA launch error).
  }
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_kernel", [&] {
    auto indptr_data = indptr.DATA_PTR<int64_t>();
    auto src_data = src.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();
    // One full warp (32 lanes) per segment.
    segment_add_kernel<scalar_t, 32>
        <<<(32 * numel + THREADS - 1) / THREADS, THREADS, 0, stream>>>(
            src_data, indptr_data, out_data, numel);
  });
  return out;
}
void segment_add_thrust_cuda(at::Tensor src, at::Tensor index, at::Tensor out) {
auto stream = at::cuda::getCurrentCUDAStream();
auto allocator = THCThrustAllocator(at::globalContext().lazyInitCUDA()); auto allocator = THCThrustAllocator(at::globalContext().lazyInitCUDA());
auto policy = thrust::cuda::par(allocator).on(stream); auto policy = thrust::cuda::par(allocator).on(stream);
...@@ -20,13 +69,11 @@ segment_add_cuda(at::Tensor src, at::Tensor index, at::Tensor out) { ...@@ -20,13 +69,11 @@ segment_add_cuda(at::Tensor src, at::Tensor index, at::Tensor out) {
auto index_data = thrust::device_ptr<int64_t>(index.DATA_PTR<int64_t>()); auto index_data = thrust::device_ptr<int64_t>(index.DATA_PTR<int64_t>());
auto key_data = thrust::device_ptr<int64_t>(key.DATA_PTR<int64_t>()); auto key_data = thrust::device_ptr<int64_t>(key.DATA_PTR<int64_t>());
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_kernel", [&] { AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_thrust_kernel", [&] {
auto src_data = thrust::device_ptr<scalar_t>(src.DATA_PTR<scalar_t>()); auto src_data = thrust::device_ptr<scalar_t>(src.DATA_PTR<scalar_t>());
auto out_data = thrust::device_ptr<scalar_t>(out.DATA_PTR<scalar_t>()); auto out_data = thrust::device_ptr<scalar_t>(out.DATA_PTR<scalar_t>());
thrust::reduce_by_key(policy, index_data, index_data + index.size(0), thrust::reduce_by_key(policy, index_data, index_data + index.numel(),
src_data, key_data, out_data); src_data, key_data, out_data);
}); });
return std::make_tuple(out, key);
} }
import time
from itertools import product from itertools import product
import pytest import pytest
import torch import torch
from torch_scatter import segment_add from torch_scatter import segment_add, scatter_add
from torch_scatter.segment import segment_add2
from .utils import tensor from .utils import tensor
...@@ -14,7 +16,61 @@ devices = [torch.device('cuda')] ...@@ -14,7 +16,61 @@ devices = [torch.device('cuda')]
def test_forward(dtype, device):
    # Thrust-based segment_add: `index` marks the target segment of each
    # element of `src`.
    src = tensor([1, 2, 3, 4, 5, 6], dtype, device)
    index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
    out = segment_add(src, index, dim=0)
    print('Thrust', out)
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_forward2(dtype, device):
    # Warp-parallel segment_add2: `indptr` gives CSR segment boundaries.
    src = tensor([1, 2, 3, 4, 5, 6], dtype, device)
    # indptr = tensor([0, 2, 5, 5, 6], torch.long, device)
    indptr = tensor([[0, 2, 5, 5, 6]], torch.long, device)
    out = segment_add2(src, indptr, dim=0)
    print('My', out)
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_benchmark(dtype, device):
    # Benchmark: scatter_add (atomics) vs. thrust-based segment_add vs. the
    # warp-parallel segment_add2, on a real graph's row index.
    from torch_geometric.datasets import Planetoid, Reddit  # noqa
    # NOTE(review): this Cora load is immediately overwritten by the Reddit
    # load below — presumably a leftover dataset toggle; confirm.
    data = Planetoid('/tmp/Cora', 'Cora')[0].to(device)
    # data = Planetoid('/tmp/PubMed', 'PubMed')[0].to(device)
    data = Reddit('/tmp/Reddit')[0].to(device)
    row, col = data.edge_index
    x = torch.randn(data.num_edges, device=device)
    print(row.size(0) / data.num_nodes)  # average degree (edges per node)

    # Warmup
    for _ in range(10):
        torch.randn(100, 100, device=device).sum()

    # CUDA kernels launch asynchronously: synchronize before reading the
    # clock and again before stopping it so wall-clock times are meaningful.
    torch.cuda.synchronize()
    t = time.perf_counter()
    for _ in range(100):
        out1 = scatter_add(x, row, dim=0, dim_size=data.num_nodes)
    torch.cuda.synchronize()
    print(time.perf_counter() - t)

    torch.cuda.synchronize()
    t = time.perf_counter()
    for _ in range(100):
        out2 = segment_add(x, row, dim=0, dim_size=data.num_nodes)
    torch.cuda.synchronize()
    print(time.perf_counter() - t)
    assert torch.allclose(out1, out2, atol=1e-2)

    # Build the CSR row pointer for the warp-parallel variant; assumes `row`
    # is sorted so each segment is contiguous — TODO confirm for this dataset.
    rowcount = segment_add(torch.ones_like(row), row)
    rowptr = torch.cat([rowcount.new_zeros(1), rowcount.cumsum(0)], dim=0)

    torch.cuda.synchronize()
    torch.cuda.synchronize()
    t = time.perf_counter()
    for _ in range(100):
        out3 = segment_add2(x, rowptr, dim=0)
    torch.cuda.synchronize()
    print(time.perf_counter() - t)
    assert torch.allclose(out1, out3, atol=1e-2)
import torch import torch
from torch_scatter.utils.gen import gen from torch_scatter.utils.gen import gen
from torch_scatter.add import scatter_add
if torch.cuda.is_available(): if torch.cuda.is_available():
import torch_scatter.segment_cuda import torch_scatter.segment_cuda
...@@ -10,6 +11,23 @@ def segment_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0): ...@@ -10,6 +11,23 @@ def segment_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value) src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
if src.size(dim) == 0: # pragma: no cover if src.size(dim) == 0: # pragma: no cover
return out return out
assert src.is_cuda
out, key = torch_scatter.segment_cuda.segment_add(src, index, out) if not src.is_cuda:
return out, key return scatter_add(src, index, dim, out, dim_size, fill_value)
# index = index.transpose(dim, -1).contiguous()
# src = src.transpose(dim, -1).contiguous()
# out = out.transpose(dim, -1).contiguous()
# print(index)
# print(src)
torch_scatter.segment_cuda.segment_add_thrust(src, index, out)
# out = out.transpose(dim, -1).contiguous()
# key = key.transpose(dim, -1).contiguous()
return out
def segment_add2(src, indptr, dim=-1):
    """Sum ``src`` over segments described by CSR-style ``indptr`` offsets.

    Thin wrapper around the compiled ``segment_add`` CUDA op.
    """
    result = torch_scatter.segment_cuda.segment_add(src, indptr, dim)
    return result
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment