Commit e046b382 authored by zhuwenwen
Browse files

update triton_mla.py

parent f54ad7b9
@@ -120,8 +120,9 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
match_seq_len = int((decode_meta.seq_lens_tensor.sum()/ max(1, B)).item()) match_seq_len = int((decode_meta.seq_lens_tensor.sum()/ max(1, B)).item())
else: else:
match_seq_len = max_seq_len match_seq_len = max_seq_len
best_config = self.attn_configs[min(self.attn_configs.keys(), key=lambda x: abs(int(x) - match_seq_len))] if envs.VLLM_USE_TRITON_OPT_MLA:
best_config = self.attn_configs[min(self.attn_configs.keys(), key=lambda x: abs(int(x) - match_seq_len))]
# Run MQA # Run MQA
decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o, decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment