Commit fe67ccbd authored by rusty1s's avatar rusty1s
Browse files

update with variable TB

parent 0ad76a83
......@@ -2,13 +2,21 @@
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
at::Tensor segment_add_cuda(at::Tensor src, at::Tensor indptr, int64_t dim);
at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr);
at::Tensor segment_add_coo_cuda(at::Tensor src, at::Tensor index);
void segment_add_thrust_cuda(at::Tensor src, at::Tensor index, at::Tensor out);
// Host-side entry point for the CSR segment-add: checks that both tensors
// live on the GPU, then forwards to the CUDA implementation.
// (The pre-rename `segment_add(src, indptr, dim)` signature was removed in
// this commit; this reconstruction drops the stale duplicate lines the diff
// render left behind.)
at::Tensor segment_add_csr(at::Tensor src, at::Tensor indptr) {
  CHECK_CUDA(src);
  CHECK_CUDA(indptr);
  return segment_add_csr_cuda(src, indptr);
}
// Host-side entry point for the COO segment-add: validates device placement
// of both operands, then dispatches to the CUDA implementation.
at::Tensor segment_add_coo(at::Tensor src, at::Tensor index) {
  // Both operands must already reside on the GPU.
  CHECK_CUDA(src);
  CHECK_CUDA(index);
  auto result = segment_add_coo_cuda(src, index);
  return result;
}
void segment_add_thrust(at::Tensor src, at::Tensor index, at::Tensor out) {
......@@ -19,6 +27,7 @@ void segment_add_thrust(at::Tensor src, at::Tensor index, at::Tensor out) {
}
// Python bindings for the compiled extension. The old `segment_add` binding
// is removed: its C++ wrapper no longer exists in this commit (it was renamed
// to `segment_add_csr`), so keeping the stale `m.def` would not compile.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("segment_add_csr", &segment_add_csr, "Segment Add CSR (CUDA)");
  m.def("segment_add_coo", &segment_add_coo, "Segment Add COO (CUDA)");
  m.def("segment_add_thrust", &segment_add_thrust, "Segment Add Thrust (CUDA)");
}
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <THC/THCGeneral.h>
#include <THC/THCThrustAllocator.cuh>
......@@ -11,12 +9,13 @@
#include "compat.cuh"
#define THREADS 256
#define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS
#define FULL_MASK 0xffffffff
template <typename scalar_t, int TB>
__global__ void segment_add_kernel(const scalar_t *src_data,
const int64_t *indptr_data,
scalar_t *out_data, size_t numel) {
__global__ void segment_add_csr_kernel(const scalar_t *src_data,
const int64_t *indptr_data,
scalar_t *out_data, size_t numel) {
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int warp_idx = thread_idx / TB;
......@@ -41,24 +40,43 @@ __global__ void segment_add_kernel(const scalar_t *src_data,
}
}
// Launches the CSR segment-add kernel, picking a threads-per-segment (TB)
// template value from the average segment length so short segments do not
// waste a full warp. The diff render had left the old unconditional
// `segment_add_kernel<scalar_t, 32>` launch and a shadowed duplicate
// `indptr_data` declaration interleaved with the new dispatch; this
// reconstruction keeps only the intended new-version body.
at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr) {
  // `indptr` holds numel+1 CSR offsets; output has one sum per segment.
  auto numel = indptr.numel() - 1;
  // Mean segment length drives the TB heuristic below.
  auto avg_length = (float)src.numel() / (float)numel;
  auto out = at::empty({numel}, src.options());

  auto indptr_data = indptr.DATA_PTR<int64_t>();
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_kernel", [&] {
    auto src_data = src.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    // Smallest TB in {4, 8, 16, 32} that covers the average segment.
    if (avg_length <= 4)
      segment_add_csr_kernel<scalar_t, 4>
          <<<BLOCKS(4, numel), THREADS, 0, stream>>>(src_data, indptr_data,
                                                     out_data, numel);
    else if (avg_length <= 8)
      segment_add_csr_kernel<scalar_t, 8>
          <<<BLOCKS(8, numel), THREADS, 0, stream>>>(src_data, indptr_data,
                                                     out_data, numel);
    else if (avg_length <= 16)
      segment_add_csr_kernel<scalar_t, 16>
          <<<BLOCKS(16, numel), THREADS, 0, stream>>>(src_data, indptr_data,
                                                      out_data, numel);
    else
      segment_add_csr_kernel<scalar_t, 32>
          <<<BLOCKS(32, numel), THREADS, 0, stream>>>(src_data, indptr_data,
                                                      out_data, numel);
  });

  return out;
}
// TODO(review): placeholder — the COO reduction is not implemented yet in
// this commit; `src` is returned unchanged and `index` is ignored.
at::Tensor segment_add_coo_cuda(at::Tensor src, at::Tensor index) {
return src;
}
void segment_add_thrust_cuda(at::Tensor src, at::Tensor index, at::Tensor out) {
auto stream = at::cuda::getCurrentCUDAStream();
auto allocator = THCThrustAllocator(at::globalContext().lazyInitCUDA());
......
......@@ -27,14 +27,14 @@ def test_forward2(dtype, device):
indptr = tensor([[0, 2, 5, 5, 6]], torch.long, device)
out = segment_add2(src, indptr, dim=0)
out = segment_add2(src, indptr)
print('My', out)
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_benchmark(dtype, device):
from torch_geometric.datasets import Planetoid, Reddit # noqa
data = Planetoid('/tmp/Cora', 'Cora')[0].to(device)
# data = Planetoid('/tmp/Cora', 'Cora')[0].to(device)
# data = Planetoid('/tmp/PubMed', 'PubMed')[0].to(device)
data = Reddit('/tmp/Reddit')[0].to(device)
row, col = data.edge_index
......@@ -69,7 +69,7 @@ def test_benchmark(dtype, device):
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(100):
out3 = segment_add2(x, rowptr, dim=0)
out3 = segment_add2(x, rowptr)
torch.cuda.synchronize()
print(time.perf_counter() - t)
......
......@@ -29,5 +29,5 @@ def segment_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
return out
def segment_add2(src, indptr):
    """Segment-sum ``src`` by CSR offsets ``indptr`` via the CUDA extension.

    The pre-rename ``segment_add2(src, indptr, dim=-1)`` variant was removed
    in this commit; this keeps only the new two-argument form, which forwards
    to the renamed ``segment_add_csr`` binding.
    """
    return torch_scatter.segment_cuda.segment_add_csr(src, indptr)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment