Commit 124c8115 authored by rusty1s's avatar rusty1s
Browse files

no need to enforce stride=1

parent d762df6a
...@@ -41,6 +41,7 @@ template <typename T, typename I> struct IndexPtrToOffset { ...@@ -41,6 +41,7 @@ template <typename T, typename I> struct IndexPtrToOffset {
static __host__ __device__ I static __host__ __device__ I
get(I idx, const at::cuda::detail::TensorInfo<T, I> &info) { get(I idx, const at::cuda::detail::TensorInfo<T, I> &info) {
I offset = idx % (info.sizes[info.dims - 1] - 1); I offset = idx % (info.sizes[info.dims - 1] - 1);
offset *= info.strides[info.dims - 1];
idx /= info.sizes[info.dims - 1] - 1; idx /= info.sizes[info.dims - 1] - 1;
for (int i = info.dims - 2; i >= 0; --i) { for (int i = info.dims - 2; i >= 0; --i) {
offset += (idx % info.sizes[i]) * info.strides[i]; offset += (idx % info.sizes[i]) * info.strides[i];
...@@ -63,7 +64,8 @@ __global__ void segment_add_csr_kernel( ...@@ -63,7 +64,8 @@ __global__ void segment_add_csr_kernel(
if (warp_idx < N) { if (warp_idx < N) {
auto offset = IndexPtrToOffset<int64_t, int>::get(warp_idx, indptr_info); auto offset = IndexPtrToOffset<int64_t, int>::get(warp_idx, indptr_info);
int row_start = __ldg(indptr_info.data + offset); int row_start = __ldg(indptr_info.data + offset);
int row_end = __ldg(indptr_info.data + offset + 1); int row_end = __ldg(indptr_info.data + offset +
indptr_info.strides[indptr_info.dims - 1]);
scalar_t val = (scalar_t)0; scalar_t val = (scalar_t)0;
offset = (warp_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E; offset = (warp_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
...@@ -82,9 +84,8 @@ __global__ void segment_add_csr_kernel( ...@@ -82,9 +84,8 @@ __global__ void segment_add_csr_kernel(
} }
at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr) { at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr) {
src = src.contiguous();
AT_ASSERTM(indptr.stride(-1) == 1);
AT_ASSERTM(src.dim() >= indptr.dim()); AT_ASSERTM(src.dim() >= indptr.dim());
src = src.contiguous();
auto reduce_dim = indptr.dim() - 1; auto reduce_dim = indptr.dim() - 1;
auto sizes = src.sizes().vec(); auto sizes = src.sizes().vec();
......
...@@ -25,8 +25,7 @@ def test_forward2(dtype, device): ...@@ -25,8 +25,7 @@ def test_forward2(dtype, device):
src = tensor([[1, 2, 3, 4, 5, 6], [1, 3, 5, 7, 9, 11]], dtype, device) src = tensor([[1, 2, 3, 4, 5, 6], [1, 3, 5, 7, 9, 11]], dtype, device)
indptr = tensor([[0, 2, 5, 5, 6]], torch.long, device) indptr = tensor([[0, 2, 5, 5, 6]], torch.long, device)
indptr = indptr.view(1, -1).expand(2, -1) indptr = indptr.view(1, -1).expand(2, -1).t().contiguous().t()
assert indptr.stride(-1) == 1
out = segment_add_csr(src, indptr) out = segment_add_csr(src, indptr)
print('CSR', out) print('CSR', out)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment