Commit 9b68396b authored by rusty1s

broadcast csr

parent db777e5c
@@ -16,27 +16,6 @@
 #define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS
 #define FULL_MASK 0xffffffff

-// template <typename scalar_t, int TB>
-// __global__ void segment_add_csr_broadcast_kernel(const scalar_t *src_data,
-//                                                  const int64_t *indptr_data,
-//                                                  scalar_t *out_data,
-//                                                  size_t numel) {}
-
-// template <typename T, typename I> struct IndexPtrToOffset<T, I> {
-//   static inline __host__ __device__ I
-//   get(I idx, const at::cuda::detail::TensorInfo<T, I> &info) {
-//     return idx;
-//     I offset = idx % (info.sizes[info.dims - 1] - 1);
-//     idx /= info.sizes[info.dims - 1] - 1;
-//     for (int i = info.dims - 2; i >= 0; --i) {
-//       offset += (idx % info.sizes[i]) * info.strides[i];
-//       idx /= info.sizes[i];
-//     }
-//     return offset;
-//   }
-// };
-
 template <typename T, typename I> struct IndexPtrToOffset {
   static __host__ __device__ I
   get(I idx, const at::cuda::detail::TensorInfo<T, I> &info) {
@@ -58,33 +37,63 @@ __global__ void segment_add_csr_kernel(
     scalar_t *out_data, size_t N, size_t E) {

   int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
-  int warp_idx = thread_idx / TB;
+  int row_idx = thread_idx / TB;
   int lane_idx = thread_idx & (TB - 1);

-  if (warp_idx < N) {
-    auto offset = IndexPtrToOffset<int64_t, int>::get(warp_idx, indptr_info);
+  if (row_idx < N) {
+    auto offset = IndexPtrToOffset<int64_t, int>::get(row_idx, indptr_info);
     int row_start = __ldg(indptr_info.data + offset);
     int row_end = __ldg(indptr_info.data + offset +
                         indptr_info.strides[indptr_info.dims - 1]);

     scalar_t val = (scalar_t)0;
-    offset = (warp_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
+    offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
     for (int src_idx = row_start + lane_idx; src_idx < row_end; src_idx += TB) {
       val += src_data[offset + src_idx];
     }

 #pragma unroll
     for (int i = TB / 2; i > 0; i /= 2)
-      val += __shfl_down_sync(FULL_MASK, val, i); // Parallel reduction.
+      val += __shfl_down_sync(FULL_MASK, val, i); // Parallel reduction

     if (lane_idx == 0) {
-      out_data[warp_idx] = val;
+      out_data[row_idx] = val;
     }
   }
 }

+template <typename scalar_t>
+__global__ void segment_add_csr_broadcast_kernel(
+    const scalar_t *src_data,
+    const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
+    scalar_t *out_data, size_t N, size_t K, size_t E) {
+
+  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  int row_idx = thread_idx / K;
+  int lane_idx = thread_idx % K;
+
+  if (thread_idx < N * K) {
+    auto offset = IndexPtrToOffset<int64_t, int>::get(row_idx, indptr_info);
+    int row_start = __ldg(indptr_info.data + offset);
+    int row_end = __ldg(indptr_info.data + offset +
+                        indptr_info.strides[indptr_info.dims - 1]);
+
+    scalar_t val = (scalar_t)0;
+    offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K;
+    for (int src_idx = row_start; src_idx < row_end; src_idx++) {
+      // Coalesced read into `src_data`.
+      val += src_data[offset + K * src_idx + lane_idx];
+    }
+
+    out_data[thread_idx] = val; // Coalesced write into `out_data`
+  }
+}
+
 at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr) {
   AT_ASSERTM(src.dim() >= indptr.dim());
+  for (int i = 0; i < indptr.dim() - 1; i++)
+    AT_ASSERTM(src.size(i) == indptr.size(i));
+
   src = src.contiguous();

   auto reduce_dim = indptr.dim() - 1;
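For readers following the diff: the new broadcast kernel sums `src` over the segments delimited by consecutive `indptr` entries while keeping an inner feature dimension intact. Below is a minimal PyTorch sketch of those semantics for the simplest case (1-D `indptr`, 2-D `src`), useful only as a sanity check; `segment_add_csr_reference` is a hypothetical helper name, not part of this extension.

```python
import torch


def segment_add_csr_reference(src, indptr):
    # out[i] = src[indptr[i]:indptr[i + 1]].sum(dim=0); empty segments stay zero.
    num_segments = indptr.numel() - 1
    out = src.new_zeros((num_segments, ) + src.shape[1:])
    for i in range(num_segments):
        row_start, row_end = int(indptr[i]), int(indptr[i + 1])
        if row_end > row_start:
            out[i] = src[row_start:row_end].sum(dim=0)
    return out


src = torch.tensor([[1., 2.], [3., 4.], [5., 6.], [7., 8.], [9., 10.], [11., 12.]])
indptr = torch.tensor([0, 2, 5, 5, 6])
print(segment_add_csr_reference(src, indptr))
# tensor([[ 4.,  6.],
#         [21., 24.],
#         [ 0.,  0.],
#         [11., 12.]])
```

Each thread of `segment_add_csr_broadcast_kernel` produces one entry of this output: `row_idx` selects the segment and `lane_idx` the feature column, so consecutive threads touch consecutive columns and the global memory accesses stay coalesced.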
@@ -92,7 +101,8 @@ at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr) {
   sizes[reduce_dim] = indptr.size(reduce_dim) - 1;
   auto out = at::empty(sizes, src.options());

-  auto N = (indptr.size(-1) - 1) * (indptr.numel() / indptr.size(-1));
+  auto N = out.size(reduce_dim) * (indptr.numel() / indptr.size(-1));
+  auto K = out.numel() / N;
   auto E = src.size(reduce_dim);
   auto avg_length = (float)src.size(reduce_dim) / (float)out.size(reduce_dim);
@@ -102,20 +112,25 @@ at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr) {
     auto src_data = src.DATA_PTR<scalar_t>();
     auto out_data = out.DATA_PTR<scalar_t>();

-    if (avg_length <= 4)
+    if (K == 1 && avg_length <= 4) {
       segment_add_csr_kernel<scalar_t, 4><<<BLOCKS(4, N), THREADS, 0, stream>>>(
           src_data, indptr_info, out_data, N, E);
-    else if (avg_length <= 8)
+    } else if (K == 1 && avg_length <= 8) {
       segment_add_csr_kernel<scalar_t, 8><<<BLOCKS(8, N), THREADS, 0, stream>>>(
           src_data, indptr_info, out_data, N, E);
-    else if (avg_length <= 16)
+    } else if (K == 1 && avg_length <= 16) {
       segment_add_csr_kernel<scalar_t, 16>
           <<<BLOCKS(16, N), THREADS, 0, stream>>>(src_data, indptr_info,
                                                   out_data, N, E);
-    else
+    } else if (K == 1) {
       segment_add_csr_kernel<scalar_t, 32>
           <<<BLOCKS(32, N), THREADS, 0, stream>>>(src_data, indptr_info,
                                                   out_data, N, E);
+    } else {
+      segment_add_csr_broadcast_kernel<scalar_t>
+          <<<BLOCKS(1, N * K), THREADS, 0, stream>>>(src_data, indptr_info,
+                                                     out_data, N, K, E);
+    }
   });

   return out;
...
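The dispatcher's `N`, `K`, and `E` follow directly from the tensor shapes: `N` counts segments across all leading (batch) dimensions, `K` is the number of trailing elements broadcast per segment, and `E` is the length of the reduced dimension of `src`. Restated in Python with arbitrarily chosen sizes:

```python
import torch

src = torch.randn(2, 6, 8)                 # 2 batches, E = 6 entries, 8 features each
indptr = torch.tensor([[0, 2, 5, 5, 6],
                       [0, 3, 3, 4, 6]])   # one CSR pointer per batch

reduce_dim = indptr.dim() - 1
sizes = list(src.shape)
sizes[reduce_dim] = indptr.size(reduce_dim) - 1
out = torch.empty(sizes)                   # shape (2, 4, 8)

N = out.size(reduce_dim) * (indptr.numel() // indptr.size(-1))  # 4 segments * 2 batches = 8
K = out.numel() // N                                            # 8 features per segment
E = src.size(reduce_dim)                                        # 6
print(N, K, E)  # 8 8 6
```

With `K == 1` the warp-per-row `segment_add_csr_kernel` is selected based on the average segment length; otherwise the broadcast kernel is launched with one thread per output element, i.e. `N * K` threads in total.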
@@ -17,29 +17,30 @@ def test_forward(dtype, device):
     src = tensor([1, 2, 3, 4, 5, 6], dtype, device)
     index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
     out = segment_add(src, index, dim=0)
-    print('Thrust', out)
+    # print('Thrust', out)


 @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
 def test_forward2(dtype, device):
-    src = tensor([[1, 2, 3, 4, 5, 6], [1, 3, 5, 7, 9, 11]], dtype, device)
-    indptr = tensor([0, 2, 5, 5, 6], torch.long, device)
-    # indptr = indptr.view(1, -1).expand(2, -1).t().contiguous().t()
+    src = tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], dtype,
+                 device)
+    indptr = tensor([[0, 2, 5, 5, 6]], torch.long, device)
+    indptr = indptr.view(1, -1).expand(2, -1).t().contiguous().t()
     out = segment_add_csr(src, indptr)
-    print('CSR', out)
+    # print('CSR', out)

-    index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
-    out = segment_add_coo(src, index)
-    print('COO', out)
+    # index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
+    # out = segment_add_coo(src, index)
+    # print('COO', out)


 @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
 def test_benchmark(dtype, device):
     from torch_geometric.datasets import Planetoid, Reddit  # noqa

     # data = Planetoid('/tmp/Cora', 'Cora')[0].to(device)
-    # data = Planetoid('/tmp/PubMed', 'PubMed')[0].to(device)
-    data = Reddit('/tmp/Reddit')[0].to(device)
+    data = Planetoid('/tmp/PubMed', 'PubMed')[0].to(device)
+    # data = Reddit('/tmp/Reddit')[0].to(device)
     row, col = data.edge_index
     x = torch.randn(data.num_edges, device=device)
     print(row.size(0) / data.num_nodes)
@@ -93,3 +94,28 @@ def test_benchmark(dtype, device):
     print('COO', time.perf_counter() - t)

     assert torch.allclose(out1, out4, atol=1e-2)
+
+    x = torch.randn((data.num_edges, 1024), device=device)
+
+    torch.cuda.synchronize()
+    t = time.perf_counter()
+    for _ in range(100):
+        out5 = scatter_add(x, row, dim=0, dim_size=data.num_nodes)
+    torch.cuda.synchronize()
+    print('Scatter Row + Dim', time.perf_counter() - t)
+
+    torch.cuda.synchronize()
+    t = time.perf_counter()
+    for _ in range(100):
+        scatter_add(x, col, dim=0, dim_size=data.num_nodes)
+    torch.cuda.synchronize()
+    print('Scatter Col + Dim', time.perf_counter() - t)
+
+    torch.cuda.synchronize()
+    t = time.perf_counter()
+    for _ in range(100):
+        out6 = segment_add_csr(x, rowptr)
+    torch.cuda.synchronize()
+    print('CSR + Dim', time.perf_counter() - t)
+
+    assert torch.allclose(out5, out6, atol=1e-2)
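The new benchmark block reuses a `rowptr` built earlier in this test file, outside the shown hunk. One way such a CSR pointer can be derived from a sorted COO `row` vector is sketched below for illustration only; `build_rowptr` is a hypothetical helper, not the file's actual construction.

```python
import torch


def build_rowptr(row, num_nodes):
    # Per-node degree via scatter-add, then an exclusive prefix sum.
    deg = torch.zeros(num_nodes, dtype=torch.long, device=row.device)
    deg.scatter_add_(0, row, torch.ones_like(row))
    return torch.cat([deg.new_zeros(1), deg.cumsum(0)])


row = torch.tensor([0, 0, 1, 1, 1, 3])
print(build_rowptr(row, num_nodes=4))  # tensor([0, 2, 5, 5, 6])
```

With a pointer built this way, `segment_add_csr(x, rowptr)` and `scatter_add(x, row, dim=0, dim_size=data.num_nodes)` reduce the same segments, which is what the final `allclose` check between `out5` and `out6` relies on.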