Commit 9b68396b authored by rusty1s

broadcast csr

parent db777e5c
@@ -16,27 +16,6 @@
#define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS
#define FULL_MASK 0xffffffff
// template <typename scalar_t, int TB>
// __global__ void segment_add_csr_broadcast_kernel(const scalar_t *src_data,
// const int64_t *indptr_data,
// scalar_t *out_data,
// size_t numel) {}
// template <typename T, typename I> struct IndexPtrToOffset<T, I> {
// static inline __host__ __device__ I
// get(I idx, const at::cuda::detail::TensorInfo<T, I> &info) {
// return idx;
// I offset = idx % (info.sizes[info.dims - 1] - 1);
// idx /= info.sizes[info.dims - 1] - 1;
// for (int i = info.dims - 2; i >= 0; --i) {
// offset += (idx % info.sizes[i]) * info.strides[i];
// idx /= info.sizes[i];
// }
// return offset;
// }
// };
template <typename T, typename I> struct IndexPtrToOffset {
static __host__ __device__ I
get(I idx, const at::cuda::detail::TensorInfo<T, I> &info) {
@@ -58,33 +37,63 @@ __global__ void segment_add_csr_kernel(
scalar_t *out_data, size_t N, size_t E) {
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int warp_idx = thread_idx / TB;
int row_idx = thread_idx / TB;
int lane_idx = thread_idx & (TB - 1);
if (warp_idx < N) {
auto offset = IndexPtrToOffset<int64_t, int>::get(warp_idx, indptr_info);
if (row_idx < N) {
auto offset = IndexPtrToOffset<int64_t, int>::get(row_idx, indptr_info);
int row_start = __ldg(indptr_info.data + offset);
int row_end = __ldg(indptr_info.data + offset +
indptr_info.strides[indptr_info.dims - 1]);
scalar_t val = (scalar_t)0;
offset = (warp_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
for (int src_idx = row_start + lane_idx; src_idx < row_end; src_idx += TB) {
val += src_data[offset + src_idx];
}
#pragma unroll
for (int i = TB / 2; i > 0; i /= 2)
val += __shfl_down_sync(FULL_MASK, val, i); // Parallel reduction.
val += __shfl_down_sync(FULL_MASK, val, i); // Parallel reduction
if (lane_idx == 0) {
out_data[warp_idx] = val;
out_data[row_idx] = val;
}
}
}
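For intuition, the warp-segment kernel above has TB lanes each accumulate a strided partial sum over one indptr row and then combines the partials with a shuffle-down tree reduction, after which lane 0 writes the row result. A rough Python model of that reduction step, with a hypothetical helper name (the real exchange happens in warp registers via __shfl_down_sync):

def warp_tree_reduce(partials, TB):
    # partials[lane] holds the per-lane partial sum for one row; TB is a power of two.
    vals = list(partials)
    step = TB // 2
    while step > 0:
        for lane in range(step):
            vals[lane] += vals[lane + step]  # mirrors __shfl_down_sync(FULL_MASK, val, step)
        step //= 2
    return vals[0]  # lane 0 ends up with the full row sum

For example, warp_tree_reduce([1, 2, 3, 4], 4) returns 10, matching a serial sum over the segment.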
template <typename scalar_t>
__global__ void segment_add_csr_broadcast_kernel(
const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
scalar_t *out_data, size_t N, size_t K, size_t E) {
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int row_idx = thread_idx / K;
int lane_idx = thread_idx % K;
if (thread_idx < N * K) {
auto offset = IndexPtrToOffset<int64_t, int>::get(row_idx, indptr_info);
int row_start = __ldg(indptr_info.data + offset);
int row_end = __ldg(indptr_info.data + offset +
indptr_info.strides[indptr_info.dims - 1]);
scalar_t val = (scalar_t)0;
offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K;
for (int src_idx = row_start; src_idx < row_end; src_idx++) {
// Coalesced read into `src_data`.
val += src_data[offset + K * src_idx + lane_idx];
}
out_data[thread_idx] = val; // Coalesced write into `out_data`
}
}
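Semantically, the new broadcast kernel assigns one thread to every (row, feature) pair: thread (row_idx, lane_idx) sums src[indptr[row_idx]:indptr[row_idx + 1], lane_idx], so consecutive threads touch consecutive memory on both the read and the write side. A minimal pure-PyTorch sketch of the same reduction for a 1-D indptr and a single trailing feature dimension (hypothetical reference helper, not part of the extension):

import torch

def segment_add_csr_reference(src, indptr):
    # src: [E, K] values, indptr: [N + 1] monotonically increasing offsets into E.
    N, K = indptr.numel() - 1, src.size(1)
    out = src.new_zeros(N, K)
    for row in range(N):
        start, end = indptr[row].item(), indptr[row + 1].item()
        out[row] = src[start:end].sum(dim=0)  # empty segments stay zero
    return out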
at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr) {
AT_ASSERTM(src.dim() >= indptr.dim());
for (int i = 0; i < indptr.dim() - 1; i++)
AT_ASSERTM(src.size(i) == indptr.size(i));
src = src.contiguous();
auto reduce_dim = indptr.dim() - 1;
@@ -92,7 +101,8 @@ at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr) {
sizes[reduce_dim] = indptr.size(reduce_dim) - 1;
auto out = at::empty(sizes, src.options());
auto N = (indptr.size(-1) - 1) * (indptr.numel() / indptr.size(-1));
auto N = out.size(reduce_dim) * (indptr.numel() / indptr.size(-1));
auto K = out.numel() / N;
auto E = src.size(reduce_dim);
auto avg_length = (float)src.size(reduce_dim) / (float)out.size(reduce_dim);
@@ -102,20 +112,25 @@ at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr) {
auto src_data = src.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();
if (avg_length <= 4)
if (K == 1 && avg_length <= 4) {
segment_add_csr_kernel<scalar_t, 4><<<BLOCKS(4, N), THREADS, 0, stream>>>(
src_data, indptr_info, out_data, N, E);
else if (avg_length <= 8)
} else if (K == 1 && avg_length <= 8) {
segment_add_csr_kernel<scalar_t, 8><<<BLOCKS(8, N), THREADS, 0, stream>>>(
src_data, indptr_info, out_data, N, E);
else if (avg_length <= 16)
} else if (K == 1 && avg_length <= 16) {
segment_add_csr_kernel<scalar_t, 16>
<<<BLOCKS(16, N), THREADS, 0, stream>>>(src_data, indptr_info,
out_data, N, E);
else
} else if (K == 1) {
segment_add_csr_kernel<scalar_t, 32>
<<<BLOCKS(32, N), THREADS, 0, stream>>>(src_data, indptr_info,
out_data, N, E);
} else {
segment_add_csr_broadcast_kernel<scalar_t>
<<<BLOCKS(1, N * K), THREADS, 0, stream>>>(src_data, indptr_info,
out_data, N, K, E);
}
});
return out;
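The launch code above selects a kernel from K and the average segment length: with K == 1 it uses the warp-segment kernel and grows TB with the row length, otherwise it falls back to the broadcast kernel with one thread per output element. A small Python mirror of that heuristic, using the thresholds visible in the launch code (THREADS and the grid computation stay in the CUDA source):

def pick_kernel(avg_length, K):
    # Broadcast path whenever there are trailing feature dimensions.
    if K > 1:
        return 'segment_add_csr_broadcast_kernel'
    if avg_length <= 4:
        return 'segment_add_csr_kernel<TB=4>'
    if avg_length <= 8:
        return 'segment_add_csr_kernel<TB=8>'
    if avg_length <= 16:
        return 'segment_add_csr_kernel<TB=16>'
    return 'segment_add_csr_kernel<TB=32>'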
@@ -17,29 +17,30 @@ def test_forward(dtype, device):
src = tensor([1, 2, 3, 4, 5, 6], dtype, device)
index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
out = segment_add(src, index, dim=0)
print('Thrust', out)
# print('Thrust', out)
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_forward2(dtype, device):
src = tensor([[1, 2, 3, 4, 5, 6], [1, 3, 5, 7, 9, 11]], dtype, device)
src = tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], dtype,
device)
indptr = tensor([0, 2, 5, 5, 6], torch.long, device)
# indptr = indptr.view(1, -1).expand(2, -1).t().contiguous().t()
indptr = tensor([[0, 2, 5, 5, 6]], torch.long, device)
indptr = indptr.view(1, -1).expand(2, -1).t().contiguous().t()
out = segment_add_csr(src, indptr)
print('CSR', out)
# print('CSR', out)
index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
out = segment_add_coo(src, index)
print('COO', out)
# index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
# out = segment_add_coo(src, index)
# print('COO', out)
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_benchmark(dtype, device):
from torch_geometric.datasets import Planetoid, Reddit # noqa
# data = Planetoid('/tmp/Cora', 'Cora')[0].to(device)
# data = Planetoid('/tmp/PubMed', 'PubMed')[0].to(device)
data = Reddit('/tmp/Reddit')[0].to(device)
data = Planetoid('/tmp/PubMed', 'PubMed')[0].to(device)
# data = Reddit('/tmp/Reddit')[0].to(device)
row, col = data.edge_index
x = torch.randn(data.num_edges, device=device)
print(row.size(0) / data.num_nodes)
@@ -93,3 +94,28 @@ def test_benchmark(dtype, device):
print('COO', time.perf_counter() - t)
assert torch.allclose(out1, out4, atol=1e-2)
x = torch.randn((data.num_edges, 1024), device=device)
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(100):
out5 = scatter_add(x, row, dim=0, dim_size=data.num_nodes)
torch.cuda.synchronize()
print('Scatter Row + Dim', time.perf_counter() - t)
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(100):
scatter_add(x, col, dim=0, dim_size=data.num_nodes)
torch.cuda.synchronize()
print('Scatter Col + Dim', time.perf_counter() - t)
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(100):
out6 = segment_add_csr(x, rowptr)
torch.cuda.synchronize()
print('CSR + Dim', time.perf_counter() - t)
assert torch.allclose(out5, out6, atol=1e-2)