"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "9ebb75dad2bd0f1d1633b7af50b9cd03db379987"
Commit 0b3069fe authored by rusty1s's avatar rusty1s
Browse files

shared memory version

parent 4d7b32c5
...@@ -89,6 +89,48 @@ __global__ void segment_add_csr_broadcast_kernel( ...@@ -89,6 +89,48 @@ __global__ void segment_add_csr_broadcast_kernel(
} }
} }
// Segment-sum kernel over CSR-style row pointers (shared-memory variant).
//
// Each output row `row_idx` is handled by TB cooperating "lanes" laid out
// along the y dimension of the launch; the x dimension indexes the K columns.
// Every lane accumulates a strided slice of the segment
// [indptr[row], indptr[row + 1]) and the TB partial sums are then combined
// through shared memory, with lane 0 writing the result.
//
// Template parameters:
//   scalar_t - element type of `src_data` / `out_data`.
//   TB       - number of lanes per row. Must be a power of two, because the
//              lane index is derived with `& (TB - 1)`.
//
// Parameters:
//   src_data    - input values, read as src_data[batch_offset + K * i + col].
//   indptr_info - TensorInfo describing the row-pointer tensor.
//   out_data    - output, written as out_data[row_idx * K + col].
//   N           - number of output rows.
//   K           - number of columns.
//   E           - per-batch extent used for the source batch stride (E * K);
//                 presumably the size of the reduced dimension of `src`
//                 (TODO confirm against the launch site).
//
// Launch preconditions (derived from the shared-memory indexing below):
//   blockDim.x <= 32, blockDim.y <= 32 and blockDim.y % TB == 0, so that the
//   TB lanes of one row occupy consecutive threadIdx.y slots of one block.
template <typename scalar_t, int TB>
__global__ void segment_add_csr_broadcast_kernel2(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
    scalar_t *out_data, size_t N, size_t K, size_t E) {

  int thread_idx = blockIdx.y * blockDim.y + threadIdx.y;
  int row_idx = thread_idx / TB;
  int lane_idx = thread_idx & (TB - 1);
  int col_idx = blockIdx.x * blockDim.x + threadIdx.x;

  // One partial sum per (column, lane) pair of this block. The inner
  // dimension is padded by one element so that threads which differ only in
  // threadIdx.x do not all map to the same shared-memory bank.
  __shared__ scalar_t vals[32][33];

  // Out-of-range threads contribute zero but still take part in every
  // barrier below: __syncthreads() must be reached by all threads of the
  // block, so it cannot live inside the bounds check.
  scalar_t val = (scalar_t)0;
  if (row_idx < N && col_idx < K) {
    auto offset = IndexPtrToOffset<int64_t, int>::get(row_idx, indptr_info);
    int row_start = __ldg(indptr_info.data + offset);
    int row_end = __ldg(indptr_info.data + offset +
                        indptr_info.strides[indptr_info.dims - 1]);

    // Start of this row's batch in `src_data`; kept in size_t so the product
    // with E * K is not narrowed.
    size_t src_offset =
        (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K;

    // Lane-strided accumulation over the segment.
    for (int i = row_start + lane_idx; i < row_end; i += TB) {
      val += src_data[src_offset + K * i + col_idx];
    }
  }

  vals[threadIdx.x][threadIdx.y] = val;
  __syncthreads();

  // Tree reduction across the TB lanes of a row. In each step only lanes
  // [0, s) write, and they read lanes [s, 2s), so no slot is read and written
  // in the same step; the barrier separates consecutive steps.
#pragma unroll
  for (int s = TB / 2; s > 0; s /= 2) {
    if (lane_idx < s) {
      vals[threadIdx.x][threadIdx.y] += vals[threadIdx.x][threadIdx.y + s];
    }
    __syncthreads();
  }

  // Lane 0 now holds the full sum for (row_idx, col_idx).
  if (row_idx < N && col_idx < K && lane_idx == 0) {
    out_data[row_idx * K + col_idx] = vals[threadIdx.x][threadIdx.y];
  }
}
at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr) { at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr) {
AT_ASSERTM(src.dim() >= indptr.dim()); AT_ASSERTM(src.dim() >= indptr.dim());
for (int i = 0; i < indptr.dim() - 1; i++) for (int i = 0; i < indptr.dim() - 1; i++)
......
...@@ -28,7 +28,7 @@ def test_forward2(dtype, device): ...@@ -28,7 +28,7 @@ def test_forward2(dtype, device):
# indptr = indptr.view(1, -1).expand(2, -1).t().contiguous().t() # indptr = indptr.view(1, -1).expand(2, -1).t().contiguous().t()
out = segment_add_csr(src, indptr) out = segment_add_csr(src, indptr)
# print('CSR', out) print('CSR', out)
# index = tensor([0, 0, 1, 1, 1, 3], torch.long, device) # index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
# out = segment_add_coo(src, index) # out = segment_add_coo(src, index)
...@@ -95,7 +95,7 @@ def test_benchmark(dtype, device): ...@@ -95,7 +95,7 @@ def test_benchmark(dtype, device):
assert torch.allclose(out1, out4, atol=1e-2) assert torch.allclose(out1, out4, atol=1e-2)
x = torch.randn((data.num_edges, 1024), device=device) x = torch.randn((data.num_edges, 32), device=device)
torch.cuda.synchronize() torch.cuda.synchronize()
t = time.perf_counter() t = time.perf_counter()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment