Commit cca0044c authored by rusty1s
Browse files

added tbs

parent 58d0025d
......@@ -74,20 +74,20 @@ at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr) {
return out;
}
template <typename scalar_t>
template <typename scalar_t, int TB>
__global__ void segment_add_coo_kernel(const scalar_t *src_data,
const int64_t *index_data,
scalar_t *out_data, size_t numel) {
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int lane_idx = thread_idx & (32 - 1);
int lane_idx = thread_idx & (TB - 1);
if (thread_idx < numel) {
auto idx = __ldg(index_data + thread_idx);
scalar_t val = src_data[thread_idx], tmp;
#pragma unroll
for (int offset = 1; offset < 32; offset *= 2) {
for (int offset = 1; offset < TB; offset *= 2) {
tmp = __shfl_up_sync(FULL_MASK, val, offset);
if (lane_idx >= offset &&
idx == __ldg(index_data + thread_idx - offset)) {
......@@ -95,7 +95,7 @@ __global__ void segment_add_coo_kernel(const scalar_t *src_data,
}
}
if (lane_idx == 31 || idx != __ldg(index_data + thread_idx + 1)) {
if (lane_idx == TB - 1 || idx != __ldg(index_data + thread_idx + 1)) {
atomAdd(out_data + idx, val);
}
}
......@@ -103,6 +103,7 @@ __global__ void segment_add_coo_kernel(const scalar_t *src_data,
void segment_add_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out) {
auto numel = src.numel();
auto avg_length = (float)numel / (float)out.numel();
auto index_data = index.DATA_PTR<int64_t>();
auto stream = at::cuda::getCurrentCUDAStream();
......@@ -110,8 +111,9 @@ void segment_add_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out) {
auto src_data = src.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();
segment_add_coo_kernel<scalar_t><<<BLOCKS(1, numel), THREADS, 0, stream>>>(
src_data, index_data, out_data, numel);
segment_add_coo_kernel<scalar_t, 32>
<<<BLOCKS(1, numel), THREADS, 0, stream>>>(src_data, index_data,
out_data, numel);
});
}
......
......@@ -37,7 +37,7 @@ def test_forward2(dtype, device):
def test_benchmark(dtype, device):
from torch_geometric.datasets import Planetoid, Reddit # noqa
data = Planetoid('/tmp/Cora', 'Cora')[0].to(device)
# data = Planetoid('/tmp/PubMed', 'PubMed')[0].to(device)
data = Planetoid('/tmp/PubMed', 'PubMed')[0].to(device)
# data = Reddit('/tmp/Reddit')[0].to(device)
row, col = data.edge_index
x = torch.randn(data.num_edges, device=device)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment