Commit 7d959770 authored by zhuwenwen
Browse files

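Update the VLLM_USE_TRITON_CAT path during the prefill phase: pass k_nope and the expanded k_pe to concat_helper as separate arguments instead of a single tuple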

parent 072d4638
...@@ -923,7 +923,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -923,7 +923,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
if envs.VLLM_USE_TRITON_CAT: if envs.VLLM_USE_TRITON_CAT:
k = concat_helper((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), k = concat_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)),
dim=-1) dim=-1)
else: else:
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
...@@ -983,7 +983,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -983,7 +983,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
if envs.VLLM_USE_TRITON_CAT: if envs.VLLM_USE_TRITON_CAT:
k = concat_helper((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) k = concat_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)), dim=-1)
else: else:
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment