Commit cca0044c, authored by rusty1s
Browse files

added tbs

parent 58d0025d
...@@ -74,20 +74,20 @@ at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr) { ...@@ -74,20 +74,20 @@ at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr) {
return out; return out;
} }
template <typename scalar_t> template <typename scalar_t, int TB>
__global__ void segment_add_coo_kernel(const scalar_t *src_data, __global__ void segment_add_coo_kernel(const scalar_t *src_data,
const int64_t *index_data, const int64_t *index_data,
scalar_t *out_data, size_t numel) { scalar_t *out_data, size_t numel) {
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int lane_idx = thread_idx & (32 - 1); int lane_idx = thread_idx & (TB - 1);
if (thread_idx < numel) { if (thread_idx < numel) {
auto idx = __ldg(index_data + thread_idx); auto idx = __ldg(index_data + thread_idx);
scalar_t val = src_data[thread_idx], tmp; scalar_t val = src_data[thread_idx], tmp;
#pragma unroll #pragma unroll
for (int offset = 1; offset < 32; offset *= 2) { for (int offset = 1; offset < TB; offset *= 2) {
tmp = __shfl_up_sync(FULL_MASK, val, offset); tmp = __shfl_up_sync(FULL_MASK, val, offset);
if (lane_idx >= offset && if (lane_idx >= offset &&
idx == __ldg(index_data + thread_idx - offset)) { idx == __ldg(index_data + thread_idx - offset)) {
...@@ -95,7 +95,7 @@ __global__ void segment_add_coo_kernel(const scalar_t *src_data, ...@@ -95,7 +95,7 @@ __global__ void segment_add_coo_kernel(const scalar_t *src_data,
} }
} }
if (lane_idx == 31 || idx != __ldg(index_data + thread_idx + 1)) { if (lane_idx == TB - 1 || idx != __ldg(index_data + thread_idx + 1)) {
atomAdd(out_data + idx, val); atomAdd(out_data + idx, val);
} }
} }
...@@ -103,6 +103,7 @@ __global__ void segment_add_coo_kernel(const scalar_t *src_data, ...@@ -103,6 +103,7 @@ __global__ void segment_add_coo_kernel(const scalar_t *src_data,
void segment_add_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out) { void segment_add_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out) {
auto numel = src.numel(); auto numel = src.numel();
auto avg_length = (float)numel / (float)out.numel();
auto index_data = index.DATA_PTR<int64_t>(); auto index_data = index.DATA_PTR<int64_t>();
auto stream = at::cuda::getCurrentCUDAStream(); auto stream = at::cuda::getCurrentCUDAStream();
...@@ -110,8 +111,9 @@ void segment_add_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out) { ...@@ -110,8 +111,9 @@ void segment_add_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out) {
auto src_data = src.DATA_PTR<scalar_t>(); auto src_data = src.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>(); auto out_data = out.DATA_PTR<scalar_t>();
segment_add_coo_kernel<scalar_t><<<BLOCKS(1, numel), THREADS, 0, stream>>>( segment_add_coo_kernel<scalar_t, 32>
src_data, index_data, out_data, numel); <<<BLOCKS(1, numel), THREADS, 0, stream>>>(src_data, index_data,
out_data, numel);
}); });
} }
......
...@@ -37,7 +37,7 @@ def test_forward2(dtype, device): ...@@ -37,7 +37,7 @@ def test_forward2(dtype, device):
def test_benchmark(dtype, device): def test_benchmark(dtype, device):
from torch_geometric.datasets import Planetoid, Reddit # noqa from torch_geometric.datasets import Planetoid, Reddit # noqa
data = Planetoid('/tmp/Cora', 'Cora')[0].to(device) data = Planetoid('/tmp/Cora', 'Cora')[0].to(device)
# data = Planetoid('/tmp/PubMed', 'PubMed')[0].to(device) data = Planetoid('/tmp/PubMed', 'PubMed')[0].to(device)
# data = Reddit('/tmp/Reddit')[0].to(device) # data = Reddit('/tmp/Reddit')[0].to(device)
row, col = data.edge_index row, col = data.edge_index
x = torch.randn(data.num_edges, device=device) x = torch.randn(data.num_edges, device=device)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment