Commit ee355d86 authored by laibao's avatar laibao
Browse files

mla: 恢复 opt-cat 在 prefill 和 decode 的拼接路由

parent db85ab07
......@@ -1771,6 +1771,21 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
k[..., k_nope.shape[-1] :] = k_pe
return k
def _concat_k_nope_k_pe_prefill(
self, k_nope: torch.Tensor, k_pe: torch.Tensor
) -> torch.Tensor:
if envs.VLLM_USE_OPT_CAT and k_nope.shape[0] > 1024:
from vllm.v1.attention.backends.mla.test_concat import (
lightop_concat_prefill_helper,
)
return lightop_concat_prefill_helper(
k_nope,
k_pe.expand((*k_nope.shape[:-1], -1)),
dim=2,
)
return self._concat_k_nope_k_pe(k_nope, k_pe)
def _compute_prefill_context(
self,
q: torch.Tensor,
......@@ -1816,7 +1831,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
# k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
# else:
# k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
k = self._concat_k_nope_k_pe(k_nope, k_pe)
k = self._concat_k_nope_k_pe_prefill(k_nope, k_pe)
attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
prefill=prefill_metadata,
......@@ -1920,7 +1935,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
)
k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k = self._concat_k_nope_k_pe(k_nope, k_pe)
k = self._concat_k_nope_k_pe_prefill(k_nope, k_pe)
attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
prefill=prefill_metadata,
......@@ -1982,7 +1997,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
# dim=-1)
# else:
# k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
k = self._concat_k_nope_k_pe(k_nope, k_pe)
k = self._concat_k_nope_k_pe_prefill(k_nope, k_pe)
output_prefill = self._run_prefill_new_tokens(
prefill=attn_metadata.prefill,
......
......@@ -247,18 +247,16 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
# TODO
# if envs.VLLM_USE_OPT_CAT:
# if q_nope.shape[0] < 1024:
# from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
# q = concat_helper_decode(q_nope, q_pe, dim=2)\
# .unsqueeze(1)
# else:
# q = torch.cat([q_nope, q_pe], dim=-1)\
# .unsqueeze(1) # Add seqlen dim of 1 (decode)
if type(q) is tuple:
q = torch.cat(q, dim=-1)
if isinstance(q, tuple):
q_nope, q_pe = q
if envs.VLLM_USE_OPT_CAT and q_nope.shape[0] < 1024:
from vllm.v1.attention.backends.mla.test_concat import (
concat_helper_decode,
)
q = concat_helper_decode(q_nope, q_pe, dim=2)
else:
q = torch.cat((q_nope, q_pe), dim=-1)
# mypy assertion: q is now always a tensor
assert isinstance(q, torch.Tensor)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment