Commit 124c8115 authored by rusty1s
Browse files

no need to enforce stride=1

parent d762df6a
......@@ -41,6 +41,7 @@ template <typename T, typename I> struct IndexPtrToOffset {
static __host__ __device__ I
get(I idx, const at::cuda::detail::TensorInfo<T, I> &info) {
I offset = idx % (info.sizes[info.dims - 1] - 1);
offset *= info.strides[info.dims - 1];
idx /= info.sizes[info.dims - 1] - 1;
for (int i = info.dims - 2; i >= 0; --i) {
offset += (idx % info.sizes[i]) * info.strides[i];
......@@ -63,7 +64,8 @@ __global__ void segment_add_csr_kernel(
if (warp_idx < N) {
auto offset = IndexPtrToOffset<int64_t, int>::get(warp_idx, indptr_info);
int row_start = __ldg(indptr_info.data + offset);
int row_end = __ldg(indptr_info.data + offset + 1);
int row_end = __ldg(indptr_info.data + offset +
indptr_info.strides[indptr_info.dims - 1]);
scalar_t val = (scalar_t)0;
offset = (warp_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
......@@ -82,9 +84,8 @@ __global__ void segment_add_csr_kernel(
}
at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr) {
src = src.contiguous();
AT_ASSERTM(indptr.stride(-1) == 1);
AT_ASSERTM(src.dim() >= indptr.dim());
src = src.contiguous();
auto reduce_dim = indptr.dim() - 1;
auto sizes = src.sizes().vec();
......
......@@ -25,8 +25,7 @@ def test_forward2(dtype, device):
src = tensor([[1, 2, 3, 4, 5, 6], [1, 3, 5, 7, 9, 11]], dtype, device)
indptr = tensor([[0, 2, 5, 5, 6]], torch.long, device)
indptr = indptr.view(1, -1).expand(2, -1)
assert indptr.stride(-1) == 1
indptr = indptr.view(1, -1).expand(2, -1).t().contiguous().t()
out = segment_add_csr(src, indptr)
print('CSR', out)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment