Commit 58d0025d authored by rusty1s

coo segment impl

parent fe67ccbd
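This commit wires up a real COO-based segment sum: out[index[i]] += src[i] for every element i, with index assumed sorted so equal indices sit next to each other (the warp trick in the new kernel relies on this). A minimal pure-PyTorch sketch of those semantics, for reference only (the helper name is mine, not part of the commit):

import torch

def segment_add_coo_reference(src, index, dim_size=None):
    # Hypothetical reference helper: plain scatter-sum along dimension 0.
    dim_size = int(index.max()) + 1 if dim_size is None else dim_size
    out = src.new_zeros(dim_size)
    out.index_add_(0, index, src)  # out[index[i]] += src[i]
    return out

src = torch.tensor([1., 2., 3., 4., 5., 6.])
index = torch.tensor([0, 0, 1, 1, 1, 3])
print(segment_add_coo_reference(src, index))  # tensor([ 3., 12.,  0.,  6.])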
@@ -3,7 +3,7 @@
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr);
at::Tensor segment_add_coo_cuda(at::Tensor src, at::Tensor index);
void segment_add_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out);
void segment_add_thrust_cuda(at::Tensor src, at::Tensor index, at::Tensor out);
@@ -13,10 +13,11 @@ at::Tensor segment_add_csr(at::Tensor src, at::Tensor indptr) {
return segment_add_csr_cuda(src, indptr);
}
at::Tensor segment_add_coo(at::Tensor src, at::Tensor index) {
void segment_add_coo(at::Tensor src, at::Tensor index, at::Tensor out) {
CHECK_CUDA(src);
CHECK_CUDA(index);
return segment_add_coo_cuda(src, index);
CHECK_CUDA(out);
segment_add_coo_cuda(src, index, out);
}
void segment_add_thrust(at::Tensor src, at::Tensor index, at::Tensor out) {
......
@@ -6,6 +6,7 @@
#include <thrust/execution_policy.h>
#include "atomics.cuh"
#include "compat.cuh"
#define THREADS 256
@@ -41,14 +42,14 @@ __global__ void segment_add_csr_kernel(const scalar_t *src_data,
}
at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr) {
auto numel = indptr.numel() - 1;
auto numel = indptr.numel() - 1; // TODO
auto avg_length = (float)src.numel() / (float)numel;
auto out = at::empty({numel}, src.options());
auto indptr_data = indptr.DATA_PTR<int64_t>();
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_kernel", [&] {
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_csr_kernel", [&] {
auto src_data = src.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();
@@ -73,8 +74,45 @@ at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr) {
return out;
}
at::Tensor segment_add_coo_cuda(at::Tensor src, at::Tensor index) {
return src;
}
template <typename scalar_t>
__global__ void segment_add_coo_kernel(const scalar_t *src_data,
                                       const int64_t *index_data,
                                       scalar_t *out_data, size_t numel) {
  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int lane_idx = thread_idx & (32 - 1);

  if (thread_idx < numel) {
    auto idx = __ldg(index_data + thread_idx);
    scalar_t val = src_data[thread_idx], tmp;

    // Warp-level segmented inclusive scan: accumulate the values of
    // consecutive elements that carry the same (sorted) index.
#pragma unroll
    for (int offset = 1; offset < 32; offset *= 2) {
      tmp = __shfl_up_sync(FULL_MASK, val, offset);
      if (lane_idx >= offset &&
          idx == __ldg(index_data + thread_idx - offset)) {
        val += tmp;
      }
    }

    // Only the last thread of each run of equal indices (or of the warp)
    // writes, so a run costs a single atomic. The `numel - 1` check guards
    // the look-ahead read for the very last element.
    if (lane_idx == 31 || thread_idx == numel - 1 ||
        idx != __ldg(index_data + thread_idx + 1)) {
      atomAdd(out_data + idx, val);
    }
  }
}
void segment_add_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out) {
  // One thread per source element; `out` is expected to be zero-initialized.
  auto numel = src.numel();
auto index_data = index.DATA_PTR<int64_t>();
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_coo_kernel", [&] {
auto src_data = src.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();
segment_add_coo_kernel<scalar_t><<<BLOCKS(1, numel), THREADS, 0, stream>>>(
src_data, index_data, out_data, numel);
});
}
void segment_add_thrust_cuda(at::Tensor src, at::Tensor index, at::Tensor out) {
......
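The interesting part of segment_add_coo_kernel is its warp strategy: within each group of 32 consecutive elements it runs a shuffle-based segmented prefix sum over runs of equal indices, and only the last thread of each run issues an atomic add, so a run costs one atomic instead of one per element. A rough CPU sketch of that idea (my paraphrase, assuming sorted indices; not code from this commit):

def warp_segment_add(src, index, dim_size, warp=32):
    out = [0] * dim_size
    for start in range(0, len(src), warp):
        vals = list(src[start:start + warp])
        idxs = list(index[start:start + warp])
        # Emulate __shfl_up_sync: at each step every lane pulls the partial sum
        # from `offset` lanes below and adds it iff both lanes share an index.
        offset = 1
        while offset < warp:
            for lane in range(len(vals) - 1, offset - 1, -1):  # high-to-low = read pre-update values
                if idxs[lane] == idxs[lane - offset]:
                    vals[lane] += vals[lane - offset]
            offset *= 2
        # Only the last lane of each equal-index run writes its accumulated sum.
        for lane, i in enumerate(idxs):
            if lane == len(idxs) - 1 or idxs[lane + 1] != i:
                out[i] += vals[lane]  # the kernel does this with atomAdd
    return out

print(warp_segment_add([1, 2, 3, 4, 5, 6], [0, 0, 1, 1, 1, 3], 4))
# [3, 12, 0, 6] -- same result as the plain scatter-sum reference above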
@@ -4,7 +4,7 @@ from itertools import product
import pytest
import torch
from torch_scatter import segment_add, scatter_add
from torch_scatter.segment import segment_add2
from torch_scatter.segment import segment_add_csr, segment_add_coo
from .utils import tensor
@@ -23,20 +23,22 @@ def test_forward(dtype, device):
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_forward2(dtype, device):
src = tensor([1, 2, 3, 4, 5, 6], dtype, device)
# indptr = tensor([0, 2, 5, 5, 6], torch.long, device)
indptr = tensor([[0, 2, 5, 5, 6]], torch.long, device)
out = segment_add_csr(src, indptr)
print('CSR', out)
out = segment_add2(src, indptr)
print('My', out)
index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
out = segment_add_coo(src, index)
print('COO', out)
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_benchmark(dtype, device):
from torch_geometric.datasets import Planetoid, Reddit # noqa
# data = Planetoid('/tmp/Cora', 'Cora')[0].to(device)
data = Planetoid('/tmp/Cora', 'Cora')[0].to(device)
# data = Planetoid('/tmp/PubMed', 'PubMed')[0].to(device)
data = Reddit('/tmp/Reddit')[0].to(device)
# data = Reddit('/tmp/Reddit')[0].to(device)
row, col = data.edge_index
x = torch.randn(data.num_edges, device=device)
print(row.size(0) / data.num_nodes)
@@ -50,7 +52,14 @@ def test_benchmark(dtype, device):
for _ in range(100):
out1 = scatter_add(x, row, dim=0, dim_size=data.num_nodes)
torch.cuda.synchronize()
print(time.perf_counter() - t)
print('Scatter Row', time.perf_counter() - t)
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(100):
scatter_add(x, col, dim=0, dim_size=data.num_nodes)
torch.cuda.synchronize()
print('Scatter Col', time.perf_counter() - t)
torch.cuda.synchronize()
@@ -58,7 +67,7 @@ def test_benchmark(dtype, device):
for _ in range(100):
out2 = segment_add(x, row, dim=0, dim_size=data.num_nodes)
torch.cuda.synchronize()
print(time.perf_counter() - t)
print('Thrust', time.perf_counter() - t)
assert torch.allclose(out1, out2, atol=1e-2)
@@ -69,8 +78,17 @@ def test_benchmark(dtype, device):
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(100):
out3 = segment_add2(x, rowptr)
out3 = segment_add_csr(x, rowptr)
torch.cuda.synchronize()
print(time.perf_counter() - t)
print('CSR', time.perf_counter() - t)
assert torch.allclose(out1, out3, atol=1e-2)
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(100):
out4 = segment_add_coo(x, row, dim_size=data.num_nodes)
torch.cuda.synchronize()
print('COO', time.perf_counter() - t)
assert torch.allclose(out1, out4, atol=1e-2)
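Aside: the benchmark's rowptr is prepared in a part of the test not shown in this hunk. One common way to derive a CSR pointer from a sorted COO row vector, shown purely as an assumption about what such a helper could look like:

import torch

def rowptr_from_sorted_row(row, num_nodes):
    # Exclusive prefix sum over per-node degree: rowptr[0] = 0, rowptr[-1] = num_edges.
    deg = torch.bincount(row, minlength=num_nodes)
    return torch.cat([deg.new_zeros(1), deg.cumsum(0)])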
@@ -15,19 +15,17 @@ def segment_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
if not src.is_cuda:
return scatter_add(src, index, dim, out, dim_size, fill_value)
# index = index.transpose(dim, -1).contiguous()
# src = src.transpose(dim, -1).contiguous()
# out = out.transpose(dim, -1).contiguous()
# print(index)
# print(src)
torch_scatter.segment_cuda.segment_add_thrust(src, index, out)
# out = out.transpose(dim, -1).contiguous()
# key = key.transpose(dim, -1).contiguous()
return out
def segment_add2(src, indptr):
def segment_add_csr(src, indptr):
return torch_scatter.segment_cuda.segment_add_csr(src, indptr)
def segment_add_coo(src, index, dim_size=None):
dim_size = index.max().item() + 1 if dim_size is None else dim_size
out = src.new_zeros(dim_size)
torch_scatter.segment_cuda.segment_add_coo(src, index, out)
return out
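A small usage sketch for the two new wrappers, mirroring test_forward2; it assumes a CUDA build of this branch, since only CUDA kernels are bound:

import torch
from torch_scatter.segment import segment_add_csr, segment_add_coo

src = torch.tensor([1., 2., 3., 4., 5., 6.], device='cuda')
index = torch.tensor([0, 0, 1, 1, 1, 3], device='cuda')   # one segment id per element
indptr = torch.tensor([[0, 2, 5, 5, 6]], device='cuda')    # segment boundaries (CSR)

print(segment_add_coo(src, index))   # tensor([ 3., 12.,  0.,  6.], device='cuda:0')
print(segment_add_csr(src, indptr))  # per-segment sums via the CSR kernel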