Commit 15ed27d7 (unverified), authored by bingps, committed by GitHub
Browse files

[Fix] `concat_mla_absorb_q_kernel` fails for long inputs (#12453)

parent 66fb9b13
@@ -18,11 +18,11 @@ __global__ void concat_mla_k_kernel(
 const nv_bfloat16* __restrict__ k_nope,
 const nv_bfloat16* __restrict__ k_rope,
 const int num_tokens,
-const int k_stride_0,
+const int64_t k_stride_0,
 const int k_stride_1,
-const int k_nope_stride_0,
+const int64_t k_nope_stride_0,
 const int k_nope_stride_1,
-const int k_rope_stride_0) {
+const int64_t k_rope_stride_0) {
 const int flat_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
 const int token_id = flat_warp_id / NUM_HEAD_CHUNKS;
 const int head_chunk_id = flat_warp_id % NUM_HEAD_CHUNKS;
@@ -126,11 +126,11 @@ __global__ void concat_mla_absorb_q_kernel(
 nv_bfloat16* out,
 const int num_items,
 const int dim_1,
-const int a_stride_0,
+const int64_t a_stride_0,
 const int a_stride_1,
-const int b_stride_0,
+const int64_t b_stride_0,
 const int b_stride_1,
-const int out_stride_0,
+const int64_t out_stride_0,
 const int out_stride_1) {
 const int flat_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
 const int lane_id = get_lane_id();
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment