"docs/vscode:/vscode.git/clone" did not exist on "14e3a28c120eea88093442eb0a2a3df35d21a22d"
Unverified commit 714f3e63, authored by Yineng Zhang and committed by GitHub

feat: support flashinfer mla with prefix cache (#3643)

parent c38f3aed
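At a glance, the commit routes MLA prefill through flashinfer's paged MLA wrapper whenever the radix/prefix cache is enabled, and keeps the existing ragged-prefill path when the cache is disabled. A minimal sketch of that routing follows, assuming only the two server flags visible in the diff; the helper itself is illustrative and not part of the patch.

# Illustrative only: summarizes which prefill kernel FlashInferAttnBackend
# ends up using after this commit, based on the two flags in the diff.
def prefill_path(enable_flashinfer_mla: bool, disable_radix_cache: bool) -> str:
    if not enable_flashinfer_mla:
        # unchanged non-MLA path
        return "BatchPrefillWithPagedKVCacheWrapper"
    if disable_radix_cache:
        # previous behaviour: MLA ragged prefill, no prefix reuse
        return "prefill_wrapper_ragged"
    # new in this commit: MLA paged prefill that can reuse cached prefix pages
    return "BatchMLAPagedAttentionWrapper"


if __name__ == "__main__":
    for mla in (False, True):
        for no_cache in (False, True):
            print(f"mla={mla}, disable_radix_cache={no_cache} -> "
                  f"{prefill_path(mla, no_cache)}")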
@@ -54,7 +54,9 @@ class DecodeMetadata:
 @dataclass
 class PrefillMetadata:
-    prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]
+    prefill_wrappers: List[
+        Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
+    ]
     use_ragged: bool
     extend_no_prefix: bool
@@ -160,6 +162,24 @@ class FlashInferAttnBackend(AttentionBackend):
         self.decode_wrappers = []
         for _ in range(self.num_wrappers):
             if not skip_prefill:
+                if (
+                    self.enable_flashinfer_mla
+                    and not global_server_args_dict["disable_radix_cache"]
+                ):
+                    # use mla paged prefill
+                    self.prefill_wrappers_paged.append(
+                        BatchMLAPagedAttentionWrapper(
+                            self.workspace_buffer,
+                            backend="fa2",
+                        )
+                    )
+                    self.prefill_wrappers_verify.append(
+                        BatchMLAPagedAttentionWrapper(
+                            self.workspace_buffer,
+                            backend="fa2",
+                        )
+                    )
+                else:
                     self.prefill_wrappers_paged.append(
                         BatchPrefillWithPagedKVCacheWrapper(
                             self.workspace_buffer,
@@ -168,7 +188,9 @@ class FlashInferAttnBackend(AttentionBackend):
                         )
                     )
                     self.prefill_wrappers_verify.append(
-                        BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
+                        BatchPrefillWithPagedKVCacheWrapper(
+                            self.workspace_buffer, "NHD"
+                        )
                    )
             if self.enable_flashinfer_mla:
                 self.decode_wrappers.append(
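The two MLA wrappers created above are planned once per batch (see the call_begin_forward hunk further down) and run once per layer in forward_extend. A rough end-to-end sketch of that lifecycle is below, using only calls that appear in this diff; the import path, tensor shapes, head count, and dtype are assumptions, and it needs a CUDA device with flashinfer installed.

import math
import torch
from flashinfer.mla import BatchMLAPagedAttentionWrapper  # assumed import path

# Construct once, sharing a workspace buffer (as in __init__ above).
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = BatchMLAPagedAttentionWrapper(workspace_buffer, backend="fa2")

# Plan per batch (mirrors call_begin_forward below): one request of 4 tokens,
# whose KV entries sit in pages 0..3 with page size 1.
qo_indptr = torch.tensor([0, 4], dtype=torch.int32, device="cuda")
kv_indptr = torch.tensor([0, 4], dtype=torch.int32, device="cuda")
kv_indices = torch.arange(4, dtype=torch.int32, device="cuda")
kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
num_qo_heads = 16  # assumed TP-local head count
wrapper.plan(
    qo_indptr, kv_indptr, kv_indices, kv_len_arr,
    num_qo_heads, 512, 64, 1, True, 1 / math.sqrt(192),
    torch.bfloat16, torch.bfloat16,
)

# Run per layer (mirrors forward_extend above): q and the cached K buffer are
# split into a 512-wide compressed part and a 64-wide RoPE part.
q_nope = torch.randn(4, num_qo_heads, 512, dtype=torch.bfloat16, device="cuda")
q_pe = torch.randn(4, num_qo_heads, 64, dtype=torch.bfloat16, device="cuda")
ckv_cache = torch.randn(4, 1, 512, dtype=torch.bfloat16, device="cuda")
kpe_cache = torch.randn(4, 1, 64, dtype=torch.bfloat16, device="cuda")
o = wrapper.run(q_nope, q_pe, ckv_cache, kpe_cache)  # -> [4, num_qo_heads, 512]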
@@ -237,7 +259,10 @@ class FlashInferAttnBackend(AttentionBackend):
         else:
             prefix_lens = forward_batch.extend_prefix_lens
-            if self.is_multimodal:
+            if self.is_multimodal or (
+                self.enable_flashinfer_mla
+                and not global_server_args_dict["disable_radix_cache"]
+            ):
                 use_ragged = False
                 extend_no_prefix = False
             else:
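This mirrors the existing multimodal special case: when the paged MLA prefill is active (MLA enabled and the radix cache not disabled), the ragged prefill path is skipped entirely, because the prefix is read from the paged KV pool rather than recomputed. Restated in isolation, with names following the diff and the helper purely illustrative:

def skip_ragged_prefill(is_multimodal: bool,
                        enable_flashinfer_mla: bool,
                        disable_radix_cache: bool) -> bool:
    # Same condition as the hunk above: multimodal batches already avoided the
    # ragged path; MLA with the radix/prefix cache enabled now does too.
    return is_multimodal or (enable_flashinfer_mla and not disable_radix_cache)


assert skip_ragged_prefill(False, True, False) is True   # MLA + prefix cache
assert skip_ragged_prefill(False, True, True) is False   # MLA, cache disabled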
@@ -419,7 +444,9 @@ class FlashInferAttnBackend(AttentionBackend):
             logits_soft_cap = layer.logit_cap
-            o1, _ = self.prefill_wrapper_ragged.forward_return_lse(
+            if global_server_args_dict["disable_radix_cache"]:
+                # use mla ragged prefill
+                o, _ = self.prefill_wrapper_ragged.forward_return_lse(
                     q.view(-1, layer.tp_q_head_num, layer.head_dim),
                     k.view(-1, layer.tp_k_head_num, layer.head_dim),
                     v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
@@ -428,8 +455,6 @@ class FlashInferAttnBackend(AttentionBackend):
                     logits_soft_cap=logits_soft_cap,
                 )
-            o = o1
                 if save_kv_cache:
                     forward_batch.token_to_kv_pool.set_kv_buffer(
                         layer,
@@ -437,6 +462,26 @@ class FlashInferAttnBackend(AttentionBackend):
                         k,
                         v,
                     )
+            else:
+                # use mla paged prefill
+                prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
+                    self._get_wrapper_idx(layer)
+                ]
+                if k is not None:
+                    assert v is not None
+                    if save_kv_cache:
+                        forward_batch.token_to_kv_pool.set_kv_buffer(
+                            layer, cache_loc, k, v
+                        )
+                qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+                k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+                o = prefill_wrapper_paged.run(
+                    qall[:, :, : layer.v_head_dim],
+                    qall[:, :, layer.v_head_dim :],
+                    k_buf[:, :, : layer.v_head_dim],
+                    k_buf[:, :, layer.v_head_dim :],
+                )
             return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
         else:
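The slicing at layer.v_head_dim separates the compressed ("nope") part from the RoPE part of both the query and the cached key buffer. Under the absorbed DeepSeek-V2 MLA layout this backend targets, that boundary is kv_lora_rank = 512 with qk_rope_head_dim = 64, giving head_dim = 576; those numbers come from the model config, not from this diff. A CPU-only sketch of the split:

import torch

# Assumed DeepSeek-V2 MLA dimensions (not stated in this diff):
kv_lora_rank = 512        # compressed KV width; equals layer.v_head_dim here
qk_rope_head_dim = 64     # rotary part appended to q and k
head_dim = kv_lora_rank + qk_rope_head_dim  # 576, the absorbed head_dim

num_tokens, num_heads = 8, 16
qall = torch.randn(num_tokens, num_heads, head_dim)
k_buf = torch.randn(num_tokens, 1, head_dim)  # MLA KV pool stores one shared head

# The same split the patch performs before calling wrapper.run(...):
q_nope, q_pe = qall[:, :, :kv_lora_rank], qall[:, :, kv_lora_rank:]
ckv, kpe = k_buf[:, :, :kv_lora_rank], k_buf[:, :, kv_lora_rank:]

assert q_nope.shape[-1] == 512 and q_pe.shape[-1] == 64
assert ckv.shape[-1] == 512 and kpe.shape[-1] == 64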
@@ -800,7 +845,9 @@ class FlashInferIndicesUpdaterPrefill:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
-        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
+        prefill_wrappers: List[
+            Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
+        ],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInfo],
@@ -814,7 +861,9 @@ class FlashInferIndicesUpdaterPrefill:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
-        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
+        prefill_wrappers: List[
+            Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
+        ],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInfo],
@@ -923,7 +972,9 @@ class FlashInferIndicesUpdaterPrefill:
     def call_begin_forward(
         self,
         wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper,
-        wrapper_paged: BatchPrefillWithPagedKVCacheWrapper,
+        wrapper_paged: Union[
+            BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
+        ],
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
         paged_kernel_lens_sum: int,
@@ -1004,6 +1055,26 @@ class FlashInferIndicesUpdaterPrefill:
                 custom_mask=custom_mask,
                 non_blocking=True,
             )
+        elif (
+            global_config.enable_flashinfer_mla
+            and not global_server_args_dict["disable_radix_cache"]
+        ):
+            # mla paged prefill
+            kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
+            wrapper_paged.plan(
+                qo_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_len_arr,
+                self.num_qo_heads,
+                512,
+                64,
+                1,
+                True,
+                1 / math.sqrt(192),
+                self.data_type,
+                self.data_type,
+            )
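The literals in this plan call are the MLA geometry rather than tunables: 512 and 64 are the compressed-KV and RoPE head dims, 1 is the page size of sglang's token-granular KV pool, True requests causal masking, and 1/sqrt(192) is the softmax scale of the pre-absorption attention, where 192 = qk_nope_head_dim + qk_rope_head_dim. The 128/64 split in the sketch below is assumed from the DeepSeek-V2 config; only the totals appear in the diff.

import math

# Assumed DeepSeek-V2 attention dimensions behind the literals in plan(...)
# (taken from the model config, not from this diff):
kv_lora_rank = 512        # head_dim of the compressed KV ("ckv") part
qk_rope_head_dim = 64     # head_dim of the rotary ("kpe") part
qk_nope_head_dim = 128    # per-head content dim of q/k before absorption
page_size = 1             # sglang's KV pool hands out one token per page

# Softmax scale of the original attention: 1/sqrt(d_qk) with
# d_qk = qk_nope_head_dim + qk_rope_head_dim = 192.
sm_scale = 1 / math.sqrt(qk_nope_head_dim + qk_rope_head_dim)
assert abs(sm_scale - 1 / math.sqrt(192)) < 1e-12

# With page_size == 1, the per-request KV length equals the number of pages,
# which is exactly what kv_indptr[1:] - kv_indptr[:-1] computes above.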
 class FlashInferMultiStepDraftBackend:
...
@@ -66,6 +66,7 @@ global_server_args_dict = {
     "enable_ep_moe": ServerArgs.enable_ep_moe,
     "device": ServerArgs.device,
     "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
+    "disable_radix_cache": ServerArgs.disable_radix_cache,
 }
 logger = logging.getLogger(__name__)
...
@@ -177,6 +177,7 @@ class ModelRunner:
                 "enable_ep_moe": server_args.enable_ep_moe,
                 "device": server_args.device,
                 "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
+                "disable_radix_cache": server_args.disable_radix_cache,
             }
         )
...
@@ -511,10 +511,13 @@ class DeepseekV2AttentionMLA(nn.Module):
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         if global_server_args_dict["enable_flashinfer_mla"]:
+            if global_server_args_dict["disable_radix_cache"]:
                 if forward_batch.forward_mode.is_extend():
                     return self.forward_normal(positions, hidden_states, forward_batch)
                 else:
                     return self.forward_absorb(positions, hidden_states, forward_batch)
+            else:
+                return self.forward_absorb(positions, hidden_states, forward_batch)
         else:
             # Triton: Use normal computation for prefill and use weight absorption for extend/decode
             if (
...
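Net effect of the DeepseekV2AttentionMLA change: with the radix cache disabled, behaviour is unchanged (normal attention for extend, weight absorption otherwise); with the radix cache enabled, every mode uses the absorbed path, matching the compressed-KV layout the new paged MLA prefill reads. A small illustrative restatement, with method names from the diff and the helper itself hypothetical:

def pick_mla_forward(enable_flashinfer_mla: bool,
                     disable_radix_cache: bool,
                     is_extend: bool) -> str:
    if enable_flashinfer_mla:
        if disable_radix_cache:
            # previous behaviour: non-absorbed attention for extend/prefill
            return "forward_normal" if is_extend else "forward_absorb"
        # new: with the prefix cache on, prefill also uses weight absorption,
        # so it can consume the compressed KV cache directly.
        return "forward_absorb"
    return "triton path (unchanged)"


assert pick_mla_forward(True, True, True) == "forward_normal"
assert pick_mla_forward(True, False, True) == "forward_absorb"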