Commit d4325fd1 authored by rusty1s's avatar rusty1s
Browse files

segment coo kernels

parent 0b3069fe
......@@ -20,15 +20,7 @@ void segment_add_coo(at::Tensor src, at::Tensor index, at::Tensor out) {
segment_add_coo_cuda(src, index, out);
}
void segment_add_thrust(at::Tensor src, at::Tensor index, at::Tensor out) {
CHECK_CUDA(src);
CHECK_CUDA(index);
CHECK_CUDA(out);
return segment_add_thrust_cuda(src, index, out);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("segment_add_csr", &segment_add_csr, "Segment Add CSR (CUDA)");
m.def("segment_add_coo", &segment_add_coo, "Segment Add COO (CUDA)");
m.def("segment_add_thrust", &segment_add_thrust, "Segment Add Thrust (CUDA)");
}
......@@ -3,19 +3,15 @@
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <THC/THCGeneral.h>
#include <THC/THCThrustAllocator.cuh>
#include <thrust/execution_policy.h>
#include "atomics.cuh"
#include "compat.cuh"
#include "index.cuh"
#define THREADS 256
#define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS
#define FULL_MASK 0xffffffff
// We need our own `IndexToOffset` implementation since we do not want to access
// the last element of the `indexptr`.
template <typename T, typename I> struct IndexPtrToOffset {
static __host__ __device__ I
get(I idx, const at::cuda::detail::TensorInfo<T, I> &info) {
......@@ -36,12 +32,15 @@ __global__ void segment_add_csr_kernel(
const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
scalar_t *out_data, size_t N, size_t E) {
// Each warp processes exactly `32/TB` rows. We usually set `TB=32` and only
// make use of it in case the average row length is less than 32.
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int row_idx = thread_idx / TB;
int lane_idx = thread_idx & (TB - 1);
if (row_idx < N) {
auto offset = IndexPtrToOffset<int64_t, int>::get(row_idx, indptr_info);
int offset = IndexPtrToOffset<int64_t, int>::get(row_idx, indptr_info);
int row_start = __ldg(indptr_info.data + offset);
int row_end = __ldg(indptr_info.data + offset +
indptr_info.strides[indptr_info.dims - 1]);
......@@ -49,15 +48,17 @@ __global__ void segment_add_csr_kernel(
offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
for (int src_idx = row_start + lane_idx; src_idx < row_end; src_idx += TB) {
val += src_data[offset + src_idx];
val += src_data[offset + src_idx]; // "Mostly" coalesced read.
}
#pragma unroll
for (int i = TB / 2; i > 0; i /= 2)
val += __shfl_down_sync(FULL_MASK, val, i); // Parallel reduction
for (int i = TB / 2; i > 0; i /= 2) {
// Parallel reduction inside a single warp.
val += __shfl_down_sync(FULL_MASK, val, i);
}
if (lane_idx == 0) {
out_data[row_idx] = val;
out_data[row_idx] = val; // "Mostly" coalesced write.
}
}
}
......@@ -68,12 +69,15 @@ __global__ void segment_add_csr_broadcast_kernel(
const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
scalar_t *out_data, size_t N, size_t K, size_t E) {
// Each thread processes exactly one row. It turned out that is more efficient
// than using shared memory due to avoiding synchronization barriers.
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int row_idx = thread_idx / K;
int lane_idx = thread_idx % K;
if (thread_idx < N * K) {
auto offset = IndexPtrToOffset<int64_t, int>::get(row_idx, indptr_info);
int offset = IndexPtrToOffset<int64_t, int>::get(row_idx, indptr_info);
int row_start = __ldg(indptr_info.data + offset);
int row_end = __ldg(indptr_info.data + offset +
indptr_info.strides[indptr_info.dims - 1]);
......@@ -81,53 +85,10 @@ __global__ void segment_add_csr_broadcast_kernel(
offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K;
for (int src_idx = row_start; src_idx < row_end; src_idx++) {
// Coalesced read into `src_data`.
val += src_data[offset + K * src_idx + lane_idx];
}
out_data[thread_idx] = val; // Coalesced write into `out_data`
}
}
template <typename scalar_t, int TB>
__global__ void segment_add_csr_broadcast_kernel2(
const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
scalar_t *out_data, size_t N, size_t K, size_t E) {
int thread_idx = blockIdx.y * blockDim.y + threadIdx.y;
int row_idx = thread_idx / TB;
int lane_idx = thread_idx & (TB - 1);
int col_idx = blockIdx.x * blockDim.x + threadIdx.x;
__shared__ scalar_t vals[32][32];
if (row_idx < N) {
auto offset = IndexPtrToOffset<int64_t, int>::get(row_idx, indptr_info);
int row_start = __ldg(indptr_info.data + offset);
int row_end = __ldg(indptr_info.data + offset +
indptr_info.strides[indptr_info.dims - 1]);
scalar_t val = (scalar_t)0;
offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K;
if (col_idx < K) {
for (int i = row_start + lane_idx; i < row_end; i += TB) {
val += src_data[offset + K * i + col_idx];
}
}
vals[threadIdx.x][threadIdx.y] = val;
__syncthreads();
#pragma unroll
for (int i = 1; i < TB; i *= 2) {
vals[threadIdx.x][threadIdx.y] += vals[threadIdx.x][threadIdx.y + i];
__syncthreads();
val += src_data[offset + K * src_idx + lane_idx]; // Coalesced read.
}
if (col_idx < K && lane_idx == 0) {
out_data[row_idx * K + col_idx] = vals[threadIdx.x][threadIdx.y];
}
out_data[thread_idx] = val; // Coalesced write.
}
}
......@@ -150,10 +111,12 @@ at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr) {
auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_csr_kernel", [&] {
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr_kernel", [&] {
auto src_data = src.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();
// Select the right kernel based on average row length and whether we need
// broadcasting capabilties (K > 1):
if (K == 1 && avg_length <= 4) {
segment_add_csr_kernel<scalar_t, 4><<<BLOCKS(4, N), THREADS, 0, stream>>>(
src_data, indptr_info, out_data, N, E);
......@@ -178,65 +141,120 @@ at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr) {
return out;
}
template <typename scalar_t, int TB>
__global__ void segment_add_coo_kernel(const scalar_t *src_data,
const int64_t *index_data,
scalar_t *out_data, size_t numel) {
template <typename scalar_t>
__global__ void segment_add_coo_kernel(
const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> index_info,
scalar_t *out_data, size_t E) {
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int lane_idx = thread_idx & (TB - 1);
// Each thread processes exactly one entry. Within a warp, we perform a
// parallel reduction across equal indices, and write the intermediate
// result via atomics.
int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
int lane_idx = row_idx & (32 - 1);
if (thread_idx < numel) {
auto idx = __ldg(index_data + thread_idx);
scalar_t val = src_data[thread_idx], tmp;
if (row_idx < E) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
row_idx, index_info);
int idx = index_info.data[offset], next_idx;
scalar_t val = src_data[row_idx], tmp;
#pragma unroll
for (int offset = 1; offset < TB; offset *= 2) {
tmp = __shfl_up_sync(FULL_MASK, val, offset);
int idx_next = __ldg(index_data + thread_idx - offset);
// AT_ASSERTM(lane_idx < offset || idx <= idx_next);
if (lane_idx >= offset && idx == idx_next) {
for (int i = 1; i < 32; i *= 2) {
tmp = __shfl_up_sync(FULL_MASK, val, i);
next_idx = __shfl_up_sync(FULL_MASK, idx, i);
if (lane_idx >= i && idx == next_idx)
val += tmp;
}
}
if (lane_idx == TB - 1 || idx != __ldg(index_data + thread_idx + 1)) {
next_idx = __shfl_down_sync(FULL_MASK, idx, 1);
if (lane_idx == 32 - 1 || idx != next_idx) {
atomAdd(out_data + idx, val);
}
}
}
void segment_add_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out) {
auto numel = src.numel();
auto avg_length = (float)numel / (float)out.numel();
template <typename scalar_t, int TB>
__global__ void segment_add_coo_broadcast_kernel(
const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> index_info,
scalar_t *out_data, size_t E, size_t K) {
auto index_data = index.DATA_PTR<int64_t>();
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_coo_kernel", [&] {
auto src_data = src.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();
// Each thread processes a single column and `TB` rows. Coalesced read and
// write is performed in column-major order. The intermediate results are
// written via atomics.
segment_add_coo_kernel<scalar_t, 32>
<<<BLOCKS(1, numel), THREADS, 0, stream>>>(src_data, index_data,
out_data, numel);
});
int row_start = (blockIdx.x * blockDim.y + threadIdx.y) * TB;
int col_idx = blockIdx.y * blockDim.x + threadIdx.x;
if (row_start < E && col_idx < K) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
row_start, index_info);
int idx1 = __ldg(index_info.data + offset);
scalar_t val = src_data[K * row_start + col_idx];
#pragma unroll
for (int i = 1; i < TB; i++) {
if (row_start + i >= E)
break;
int idx2 = __ldg(index_info.data + offset +
i * index_info.strides[index_info.dims - 1]);
if (idx1 == idx2) {
val += src_data[K * (row_start + i) + col_idx];
} else {
atomAdd(out_data + K * idx1 + col_idx, val);
val = src_data[K * (row_start + i) + col_idx];
}
idx1 = idx2;
}
atomAdd(out_data + K * idx1 + col_idx, val);
}
}
void segment_add_thrust_cuda(at::Tensor src, at::Tensor index, at::Tensor out) {
auto stream = at::cuda::getCurrentCUDAStream();
auto allocator = THCThrustAllocator(at::globalContext().lazyInitCUDA());
auto policy = thrust::cuda::par(allocator).on(stream);
void segment_add_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out) {
AT_ASSERTM(src.dim() >= index.dim());
for (int i = 0; i < index.dim(); i++)
AT_ASSERTM(src.size(i) == index.size(i));
src = src.contiguous();
auto reduce_dim = index.dim() - 1;
auto key = at::full_like(out, -1, out.options().dtype(at::kLong));
for (int i = 0; i < out.dim(); i++)
if (i != reduce_dim)
AT_ASSERTM(src.size(i) == out.size(i));
auto index_data = thrust::device_ptr<int64_t>(index.DATA_PTR<int64_t>());
auto key_data = thrust::device_ptr<int64_t>(key.DATA_PTR<int64_t>());
auto E = index.numel();
auto K = src.numel() / index.numel();
auto avg_length = (float)src.size(reduce_dim) / (float)out.size(reduce_dim);
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_thrust_kernel", [&] {
auto src_data = thrust::device_ptr<scalar_t>(src.DATA_PTR<scalar_t>());
auto out_data = thrust::device_ptr<scalar_t>(out.DATA_PTR<scalar_t>());
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_coo_kernel", [&] {
auto src_data = src.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();
thrust::reduce_by_key(policy, index_data, index_data + index.numel(),
src_data, key_data, out_data);
if (K == 1)
segment_add_coo_kernel<scalar_t><<<BLOCKS(1, E), THREADS, 0, stream>>>(
src_data, index_info, out_data, E);
else if (avg_length <= 8)
segment_add_coo_broadcast_kernel<scalar_t, 4>
<<<dim3(((E + (8 * 4) - 1) / (8 * 4)), (K + 31) / 32), dim3(32, 8), 0,
stream>>>(src_data, index_info, out_data, E, K);
else if (avg_length <= 16)
segment_add_coo_broadcast_kernel<scalar_t, 8>
<<<dim3(((E + (8 * 8) - 1) / (8 * 8)), (K + 31) / 32), dim3(32, 8), 0,
stream>>>(src_data, index_info, out_data, E, K);
else if (avg_length <= 32)
segment_add_coo_broadcast_kernel<scalar_t, 16>
<<<dim3(((E + (8 * 16) - 1) / (8 * 16)), (K + 31) / 32), dim3(32, 8),
0, stream>>>(src_data, index_info, out_data, E, K);
else
segment_add_coo_broadcast_kernel<scalar_t, 32>
<<<dim3(((E + (8 * 32) - 1) / (8 * 32)), (K + 31) / 32), dim3(32, 8),
0, stream>>>(src_data, index_info, out_data, E, K);
});
}
......@@ -14,25 +14,16 @@ devices = [torch.device('cuda')]
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_forward(dtype, device):
src = tensor([1, 2, 3, 4, 5, 6], dtype, device)
index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
out = segment_add(src, index, dim=0)
# print('Thrust', out)
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_forward2(dtype, device):
src = tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], dtype,
device)
indptr = tensor([0, 2, 5, 5, 6], torch.long, device)
# indptr = indptr.view(1, -1).expand(2, -1).t().contiguous().t()
out = segment_add_csr(src, indptr)
print('CSR', out)
# index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
# out = segment_add_coo(src, index)
# print('COO', out)
index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
out = segment_add_coo(src, index)
print('COO', out)
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
......@@ -40,15 +31,20 @@ def test_benchmark(dtype, device):
from torch_geometric.datasets import Planetoid, Reddit # noqa
# data = Planetoid('/tmp/Cora', 'Cora')[0].to(device)
data = Planetoid('/tmp/PubMed', 'PubMed')[0].to(device)
# data = Reddit('/tmp/Reddit')[0].to(device)
row, col = data.edge_index
x = torch.randn(data.num_edges, device=device)
print(data.num_edges)
print(row.size(0) / data.num_nodes)
num_repeats = 1
row = row.view(-1, 1).repeat(1, num_repeats).view(-1).contiguous()
col = col.view(-1, 1).repeat(1, num_repeats).view(-1).contiguous()
# Warmup
for _ in range(10):
torch.randn(100, 100, device=device).sum()
x = torch.randn(row.size(0), device=device)
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(100):
......@@ -63,16 +59,6 @@ def test_benchmark(dtype, device):
torch.cuda.synchronize()
print('Scatter Col', time.perf_counter() - t)
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(100):
out2 = segment_add(x, row, dim=0, dim_size=data.num_nodes)
torch.cuda.synchronize()
print('Thrust', time.perf_counter() - t)
assert torch.allclose(out1, out2, atol=1e-2)
rowcount = segment_add(torch.ones_like(row), row)
rowptr = torch.cat([rowcount.new_zeros(1), rowcount.cumsum(0)], dim=0)
torch.cuda.synchronize()
......@@ -84,8 +70,6 @@ def test_benchmark(dtype, device):
torch.cuda.synchronize()
print('CSR', time.perf_counter() - t)
assert torch.allclose(out1, out3, atol=1e-2)
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(100):
......@@ -93,9 +77,10 @@ def test_benchmark(dtype, device):
torch.cuda.synchronize()
print('COO', time.perf_counter() - t)
assert torch.allclose(out1, out3, atol=1e-2)
assert torch.allclose(out1, out4, atol=1e-2)
x = torch.randn((data.num_edges, 32), device=device)
x = torch.randn((row.size(0), 64), device=device)
torch.cuda.synchronize()
t = time.perf_counter()
......@@ -118,4 +103,12 @@ def test_benchmark(dtype, device):
torch.cuda.synchronize()
print('CSR + Dim', time.perf_counter() - t)
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(100):
out7 = segment_add_coo(x, row, dim_size=data.num_nodes)
torch.cuda.synchronize()
print('COO + Dim', time.perf_counter() - t)
assert torch.allclose(out5, out6, atol=1e-2)
assert torch.allclose(out5, out7, atol=1e-2)
......@@ -8,16 +8,7 @@ if torch.cuda.is_available():
def segment_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
if src.size(dim) == 0: # pragma: no cover
return out
if not src.is_cuda:
return scatter_add(src, index, dim, out, dim_size, fill_value)
torch_scatter.segment_cuda.segment_add_thrust(src, index, out)
return out
return scatter_add(src, index, dim, out, dim_size, fill_value)
def segment_add_csr(src, indptr):
......@@ -26,6 +17,8 @@ def segment_add_csr(src, indptr):
def segment_add_coo(src, index, dim_size=None):
dim_size = index.max().item() + 1 if dim_size is None else dim_size
out = src.new_zeros(dim_size)
size = list(src.size())
size[index.dim() - 1] = dim_size
out = src.new_zeros(size)
torch_scatter.segment_cuda.segment_add_coo(src, index, out)
return out
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment