Commit 2ff5a773 authored by zhanghj2
Browse files

加入print_param

parent 79096f6b
......@@ -121,7 +121,21 @@ dense_attn_decode_interface(
KU_CHECK_SHAPE(tile_scheduler_metadata, num_sm_parts, sizeof(DecodingSchedMeta)/sizeof(int));
KU_CHECK_SHAPE(num_splits, batch_size+1);
}
// Opt-in debug trace: if the FLASH_MLA_PRINT_PARAM environment variable is set
// (any value, even empty), write the dense-decode problem sizes to stderr.
bool print_param = (std::getenv("FLASH_MLA_PRINT_PARAM") != nullptr);
if (print_param) {
    // Every field uses the same "name = value" layout so the line stays easy to
    // grep/parse (the original omitted the '=' after num_blocks).
    // NOTE(review): all size arguments are formatted with %d; this assumes they
    // are declared as int — confirm against their declarations above this hunk.
    fprintf(stderr, "[FlashMLA] [dense_attn_decode_interface] [%s] batch_size = %d seqlen_q_ori = %d "
        "num_heads_q = %d head_size_k = %d max_num_blocks_per_seq = %d num_blocks = %d page_block_size = %d num_heads_k = %d \n",
        arch.archName.c_str(),
        batch_size,
        seqlen_q_ori,
        num_heads_q,
        head_size_k,
        max_num_blocks_per_seq,
        num_blocks,
        page_block_size,
        num_heads_k
    );
}
// Set the sizes
DenseAttnDecodeParams params;
params.b = batch_size;
......
......@@ -228,7 +228,18 @@ sparse_attn_decode_interface(
} else {
TORCH_CHECK(false, "Unsupported d_qk: ", d_qk);
}
// Opt-in debug trace: if the FLASH_MLA_PRINT_PARAM environment variable is set
// (any value, even empty), write the sparse-decode problem sizes and feature
// flags to stderr.
bool print_param = (std::getenv("FLASH_MLA_PRINT_PARAM") != nullptr);
if (print_param) {
    // bytes_per_token is reported as kv.stride(1). The value is cast to
    // long long and printed with %lld so the format specifier matches the
    // argument on every platform (a 64-bit stride passed to %ld is a mismatch
    // where long is 32 bits).
    // NOTE(review): the remaining arguments are formatted with %d; this assumes
    // they are int or bool — confirm against their declarations above this hunk.
    fprintf(stderr, "[FlashMLA] [sparse_attn_decode_interface] b = %d s_q = %d h_q = %d d_qk = %d num_blocks = %d page_block_size = %d h_kv = %d topk = %d "
        "bytes_per_token = %lld have_topk_length = %d "
        "have_extra_kcache = %d have_extra_topk_length = %d have_attn_sink = %d \n",
        b, s_q, h_q, d_qk, num_blocks, page_block_size, h_kv, topk,
        static_cast<long long>(kv.stride(1)),
        have_topk_length,
        have_extra_kcache,
        have_extra_topk_length,
        have_attn_sink
    );
}
std::vector<DecodeFeatures> features;
if (h_q <= 16) {
features.push_back(DecodeFeatures::HEAD_16);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment