Commit 7a8722d7 authored by zhanghj2
Browse files

使用64位计算地址,避免大size类型溢出

parent d1c9d3fa
......@@ -24,7 +24,7 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
const int s_q_idx = blockIdx.x;
const int bidh = blockIdx.y;
const int lane_idx = tidx % 64;
const index_t row_offset_q = s_q_idx * params.stride_q_s_q + bidh * kBlockM * params.stride_q_h_q;
const index_t row_offset_q = s_q_idx * static_cast<index_t>(params.stride_q_s_q) + bidh * kBlockM * params.stride_q_h_q;
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.stride_q_h_q, _1{}));
......@@ -403,7 +403,7 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
}
Tensor lse = softmax.template normalize_softmax_lse_prefill<false>(acc_o, sRow_sum_reduce_buffer, params.sm_scale);
const index_t row_offset_o = s_q_idx * params.h_q * params.d_v + bidh * kBlockM * params.d_v;
const index_t row_offset_o = s_q_idx * static_cast<index_t>(params.h_q * params.d_v) + bidh * kBlockM * params.d_v;
Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.out) + row_offset_o),
Shape<Int<kBlockM>, Int<kHeadDimV>>{},
make_stride(params.d_v, _1{}));
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment