Commit c3a5b02a authored by zhanghj2
Browse files

加入打印使用的环境变量

parent fd2b2d8f
...@@ -121,7 +121,10 @@ dense_attn_decode_interface( ...@@ -121,7 +121,10 @@ dense_attn_decode_interface(
KU_CHECK_SHAPE(tile_scheduler_metadata, num_sm_parts, sizeof(DecodingSchedMeta)/sizeof(int)); KU_CHECK_SHAPE(tile_scheduler_metadata, num_sm_parts, sizeof(DecodingSchedMeta)/sizeof(int));
KU_CHECK_SHAPE(num_splits, batch_size+1); KU_CHECK_SHAPE(num_splits, batch_size+1);
} }
bool print_param = (std::getenv("FLASH_MLA_PRINT_PARAM") != nullptr); bool print_param = false;
if (const char* val = std::getenv("FLASH_MLA_PRINT_PARAM")) {
print_param = (std::string(val) == "1");
}
if (print_param) { if (print_param) {
fprintf(stderr, "[FlashMLA] [dense_attn_decode_interface] [%s] batch_size = %d seqlen_q_ori = %d " fprintf(stderr, "[FlashMLA] [dense_attn_decode_interface] [%s] batch_size = %d seqlen_q_ori = %d "
"num_heads_q = %d head_size_k = %d max_num_blocks_per_seq = %d num_blocks %d page_block_size = %d num_heads_k = %d \n", "num_heads_q = %d head_size_k = %d max_num_blocks_per_seq = %d num_blocks %d page_block_size = %d num_heads_k = %d \n",
......
...@@ -228,11 +228,15 @@ sparse_attn_decode_interface( ...@@ -228,11 +228,15 @@ sparse_attn_decode_interface(
} else { } else {
TORCH_CHECK(false, "Unsupported d_qk: ", d_qk); TORCH_CHECK(false, "Unsupported d_qk: ", d_qk);
} }
bool print_param = (std::getenv("FLASH_MLA_PRINT_PARAM") != nullptr); bool print_param = false;
if (const char* val = std::getenv("FLASH_MLA_PRINT_PARAM")) {
print_param = (std::string(val) == "1");
}
if (print_param) { if (print_param) {
fprintf(stderr, "[FlashMLA] [sparse_attn_decode_interface] b = %d s_q = %d h_q = %d d_qk = %d num_blocks = %d page_block_size = %d h_kv = %d topk = %d " fprintf(stderr, "[FlashMLA] [sparse_attn_decode_interface] [%s] b = %d s_q = %d h_q = %d d_qk = %d num_blocks = %d page_block_size = %d h_kv = %d topk = %d "
"bytes_per_token = %ld have_topk_length = %d " "bytes_per_token = %ld have_topk_length = %d "
"have_extra_kcache = %d have_extra_topk_length = %d have_attn_sink = %d \n", "have_extra_kcache = %d have_extra_topk_length = %d have_attn_sink = %d \n",
arch.archName.c_str(),
b, s_q, h_q, d_qk, num_blocks, page_block_size, h_kv, topk, kv.stride(1), b, s_q, h_q, d_qk, num_blocks, page_block_size, h_kv, topk, kv.stride(1),
have_topk_length, have_topk_length,
have_extra_kcache, have_extra_kcache,
......
...@@ -113,7 +113,19 @@ static std::vector<at::Tensor> sparse_attn_prefill_interface( ...@@ -113,7 +113,19 @@ static std::vector<at::Tensor> sparse_attn_prefill_interface(
KU_CHECK_CONTIGUOUS(out); KU_CHECK_CONTIGUOUS(out);
KU_CHECK_CONTIGUOUS(lse); KU_CHECK_CONTIGUOUS(lse);
KU_CHECK_CONTIGUOUS(max_logits); KU_CHECK_CONTIGUOUS(max_logits);
bool print_param = false;
if (const char* val = std::getenv("FLASH_MLA_PRINT_PARAM")) {
print_param = (std::string(val) == "1");
}
if (print_param) {
fprintf(stderr, "[FlashMLA] [sparse_attn_prefill_interface] [%s] "
"s_q = %d s_kv = %d h_q = %d h_kv = %d d_qk = %d "
"topk = %d have_topk_length = %d \n",
arch.archName.c_str(),
s_q, s_kv, h_q, h_kv, d_qk,
topk, have_topk_length
);
}
SparseAttnFwdParams params = { SparseAttnFwdParams params = {
s_q, s_kv, h_q, h_kv, d_qk, d_v, topk, s_q, s_kv, h_q, h_kv, d_qk, d_v, topk,
sm_scale, sm_scale * LOG_2_E, sm_scale, sm_scale * LOG_2_E,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment