Commit 072d4638 authored by zhuwenwen's avatar zhuwenwen
Browse files

use VLLM_USE_TRITON_CAT during the prefill phase

parent ff090f36
...@@ -982,6 +982,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -982,6 +982,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k_nope, v = kv_nope\ k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
if envs.VLLM_USE_TRITON_CAT:
k = concat_helper((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
else:
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
output = self._flash_attn_varlen_diff_headdims( output = self._flash_attn_varlen_diff_headdims(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment