Commit 4b3bcb50 authored by zhanghj2's avatar zhanghj2
Browse files

使用64位计算stride,避免越界

parent 84f83034
...@@ -366,7 +366,7 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp ...@@ -366,7 +366,7 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
} }
int block_index = token_index == -1 ? 0 : (int)((uint32_t)token_index/(uint32_t)page_block_size); // Use uint32_t division and mod to improve performance const int token_indexrel_idx_in_block = (token_index + page_block_size) % page_block_size; int block_index = token_index == -1 ? 0 : (int)((uint32_t)token_index/(uint32_t)page_block_size); // Use uint32_t division and mod to improve performance const int token_indexrel_idx_in_block = (token_index + page_block_size) % page_block_size;
int rel_idx_in_block = (uint32_t)token_index % (uint32_t)page_block_size; // NOTE When token_index is -1 (UINT_MAX), UINT_MAX%page_block_size < page_block_size, so there will be no illegal-memory-access error int rel_idx_in_block = (uint32_t)token_index % (uint32_t)page_block_size; // NOTE When token_index is -1 (UINT_MAX), UINT_MAX%page_block_size < page_block_size, so there will be no illegal-memory-access error
const index_t offset_k = block_index * k_block_stride; const index_t offset_k = (index_t)(block_index) * k_block_stride;
uint8_t* gK_base; uint8_t* gK_base;
float scales[NUM_SCALES]; float scales[NUM_SCALES];
if constexpr (MODEL_TYPE == ModelType::V32) { if constexpr (MODEL_TYPE == ModelType::V32) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment