Unverified commit 5d15fb8c, authored by Tao He, committed by GitHub
Browse files

[bugfix] QWen-1M context support [2/3]: using current CUDA stream in the DCA's...


[bugfix] QWen-1M context support [2/3]: use the current CUDA stream when launching the DCA's kernels. (#8611)
Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
Co-authored-by: sa-buc <linzhu.ht@w32d09270.cloud.sqa.na131>
parent 016fd251
......@@ -3,6 +3,7 @@
// This file is for blocksparse attention utils cuda kernel.
#include <assert.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <torch/all.h>
......@@ -176,7 +177,8 @@ void convert_vertical_slash_indexes_64x64(
const dim3 dimBlock((int32_t)N_THREADS);
const dim3 dimGrid(
(int32_t)N_HEADS, (int32_t)BATCH_SIZE, ((int32_t)N_ROWS + (int32_t)N_THREADS - 1) / (int32_t)N_THREADS);
convert_vertical_slash_indexes_kernel<<<dimGrid, dimBlock>>>(
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
convert_vertical_slash_indexes_kernel<<<dimGrid, dimBlock, 0, stream>>>(
q_seqlens,
kv_seqlens,
vertical_indexes,
......@@ -393,7 +395,8 @@ void convert_vertical_slash_indexes_64x64_mergehead(
const int N_THREADS = 64;
const dim3 dimBlock(N_THREADS);
const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS);
convert_vertical_slash_indexes_kernel_mergehead<<<dimGrid, dimBlock>>>(
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
convert_vertical_slash_indexes_kernel_mergehead<<<dimGrid, dimBlock, 0, stream>>>(
q_seqlens,
kv_seqlens,
vertical_indexes,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment