Commit db777e5c authored by rusty1s's avatar rusty1s
Browse files

typo

parent 124c8115
...@@ -102,21 +102,20 @@ at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr) { ...@@ -102,21 +102,20 @@ at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr) {
auto src_data = src.DATA_PTR<scalar_t>(); auto src_data = src.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>(); auto out_data = out.DATA_PTR<scalar_t>();
if (avg_length <= 4) { if (avg_length <= 4)
segment_add_csr_kernel<scalar_t, 4><<<BLOCKS(4, N), THREADS, 0, stream>>>( segment_add_csr_kernel<scalar_t, 4><<<BLOCKS(4, N), THREADS, 0, stream>>>(
src_data, indptr_info, out_data, N, E); src_data, indptr_info, out_data, N, E);
} else if (avg_length <= 8) { else if (avg_length <= 8)
segment_add_csr_kernel<scalar_t, 8><<<BLOCKS(8, N), THREADS, 0, stream>>>( segment_add_csr_kernel<scalar_t, 8><<<BLOCKS(8, N), THREADS, 0, stream>>>(
src_data, indptr_info, out_data, N, E); src_data, indptr_info, out_data, N, E);
} else if (avg_length <= 16) { else if (avg_length <= 16)
segment_add_csr_kernel<scalar_t, 16> segment_add_csr_kernel<scalar_t, 16>
<<<BLOCKS(16, N), THREADS, 0, stream>>>(src_data, indptr_info, <<<BLOCKS(16, N), THREADS, 0, stream>>>(src_data, indptr_info,
out_data, N, E); out_data, N, E);
} else { else
segment_add_csr_kernel<scalar_t, 32> segment_add_csr_kernel<scalar_t, 32>
<<<BLOCKS(32, N), THREADS, 0, stream>>>(src_data, indptr_info, <<<BLOCKS(32, N), THREADS, 0, stream>>>(src_data, indptr_info,
out_data, N, E); out_data, N, E);
}
}); });
return out; return out;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment