Unverified Commit 714f3e63 authored by Yineng Zhang's avatar Yineng Zhang Committed by GitHub
Browse files

feat: support flashinfer mla with prefix cache (#3643)

parent c38f3aed
...@@ -54,7 +54,9 @@ class DecodeMetadata: ...@@ -54,7 +54,9 @@ class DecodeMetadata:
@dataclass @dataclass
class PrefillMetadata: class PrefillMetadata:
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper] prefill_wrappers: List[
Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
]
use_ragged: bool use_ragged: bool
extend_no_prefix: bool extend_no_prefix: bool
...@@ -160,16 +162,36 @@ class FlashInferAttnBackend(AttentionBackend): ...@@ -160,16 +162,36 @@ class FlashInferAttnBackend(AttentionBackend):
self.decode_wrappers = [] self.decode_wrappers = []
for _ in range(self.num_wrappers): for _ in range(self.num_wrappers):
if not skip_prefill: if not skip_prefill:
self.prefill_wrappers_paged.append( if (
BatchPrefillWithPagedKVCacheWrapper( self.enable_flashinfer_mla
self.workspace_buffer, and not global_server_args_dict["disable_radix_cache"]
"NHD", ):
backend="fa2", # use mla paged prefill
self.prefill_wrappers_paged.append(
BatchMLAPagedAttentionWrapper(
self.workspace_buffer,
backend="fa2",
)
)
self.prefill_wrappers_verify.append(
BatchMLAPagedAttentionWrapper(
self.workspace_buffer,
backend="fa2",
)
)
else:
self.prefill_wrappers_paged.append(
BatchPrefillWithPagedKVCacheWrapper(
self.workspace_buffer,
"NHD",
backend="fa2",
)
)
self.prefill_wrappers_verify.append(
BatchPrefillWithPagedKVCacheWrapper(
self.workspace_buffer, "NHD"
)
) )
)
self.prefill_wrappers_verify.append(
BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
)
if self.enable_flashinfer_mla: if self.enable_flashinfer_mla:
self.decode_wrappers.append( self.decode_wrappers.append(
BatchMLAPagedAttentionWrapper(self.workspace_buffer, backend="fa2") BatchMLAPagedAttentionWrapper(self.workspace_buffer, backend="fa2")
...@@ -237,7 +259,10 @@ class FlashInferAttnBackend(AttentionBackend): ...@@ -237,7 +259,10 @@ class FlashInferAttnBackend(AttentionBackend):
else: else:
prefix_lens = forward_batch.extend_prefix_lens prefix_lens = forward_batch.extend_prefix_lens
if self.is_multimodal: if self.is_multimodal or (
self.enable_flashinfer_mla
and not global_server_args_dict["disable_radix_cache"]
):
use_ragged = False use_ragged = False
extend_no_prefix = False extend_no_prefix = False
else: else:
...@@ -419,23 +444,43 @@ class FlashInferAttnBackend(AttentionBackend): ...@@ -419,23 +444,43 @@ class FlashInferAttnBackend(AttentionBackend):
logits_soft_cap = layer.logit_cap logits_soft_cap = layer.logit_cap
o1, _ = self.prefill_wrapper_ragged.forward_return_lse( if global_server_args_dict["disable_radix_cache"]:
q.view(-1, layer.tp_q_head_num, layer.head_dim), # use mla ragged prefill
k.view(-1, layer.tp_k_head_num, layer.head_dim), o, _ = self.prefill_wrapper_ragged.forward_return_lse(
v.view(-1, layer.tp_v_head_num, layer.v_head_dim), q.view(-1, layer.tp_q_head_num, layer.head_dim),
causal=True, k.view(-1, layer.tp_k_head_num, layer.head_dim),
sm_scale=layer.scaling, v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
logits_soft_cap=logits_soft_cap, causal=True,
) sm_scale=layer.scaling,
logits_soft_cap=logits_soft_cap,
o = o1 )
if save_kv_cache: if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer( forward_batch.token_to_kv_pool.set_kv_buffer(
layer, layer,
cache_loc, cache_loc,
k, k,
v, v,
)
else:
# use mla paged prefill
prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
self._get_wrapper_idx(layer)
]
if k is not None:
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v
)
qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
o = prefill_wrapper_paged.run(
qall[:, :, : layer.v_head_dim],
qall[:, :, layer.v_head_dim :],
k_buf[:, :, : layer.v_head_dim],
k_buf[:, :, layer.v_head_dim :],
) )
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim) return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
...@@ -800,7 +845,9 @@ class FlashInferIndicesUpdaterPrefill: ...@@ -800,7 +845,9 @@ class FlashInferIndicesUpdaterPrefill:
seq_lens: torch.Tensor, seq_lens: torch.Tensor,
seq_lens_sum: int, seq_lens_sum: int,
prefix_lens: torch.Tensor, prefix_lens: torch.Tensor,
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper], prefill_wrappers: List[
Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
],
use_ragged: bool, use_ragged: bool,
encoder_lens: Optional[torch.Tensor], encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo], spec_info: Optional[SpecInfo],
...@@ -814,7 +861,9 @@ class FlashInferIndicesUpdaterPrefill: ...@@ -814,7 +861,9 @@ class FlashInferIndicesUpdaterPrefill:
seq_lens: torch.Tensor, seq_lens: torch.Tensor,
seq_lens_sum: int, seq_lens_sum: int,
prefix_lens: torch.Tensor, prefix_lens: torch.Tensor,
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper], prefill_wrappers: List[
Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
],
use_ragged: bool, use_ragged: bool,
encoder_lens: Optional[torch.Tensor], encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo], spec_info: Optional[SpecInfo],
...@@ -923,7 +972,9 @@ class FlashInferIndicesUpdaterPrefill: ...@@ -923,7 +972,9 @@ class FlashInferIndicesUpdaterPrefill:
def call_begin_forward( def call_begin_forward(
self, self,
wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper, wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper,
wrapper_paged: BatchPrefillWithPagedKVCacheWrapper, wrapper_paged: Union[
BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
],
req_pool_indices: torch.Tensor, req_pool_indices: torch.Tensor,
paged_kernel_lens: torch.Tensor, paged_kernel_lens: torch.Tensor,
paged_kernel_lens_sum: int, paged_kernel_lens_sum: int,
...@@ -1004,6 +1055,26 @@ class FlashInferIndicesUpdaterPrefill: ...@@ -1004,6 +1055,26 @@ class FlashInferIndicesUpdaterPrefill:
custom_mask=custom_mask, custom_mask=custom_mask,
non_blocking=True, non_blocking=True,
) )
elif (
global_config.enable_flashinfer_mla
and not global_server_args_dict["disable_radix_cache"]
):
# mla paged prefill
kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
wrapper_paged.plan(
qo_indptr,
kv_indptr,
kv_indices,
kv_len_arr,
self.num_qo_heads,
512,
64,
1,
True,
1 / math.sqrt(192),
self.data_type,
self.data_type,
)
class FlashInferMultiStepDraftBackend: class FlashInferMultiStepDraftBackend:
......
...@@ -66,6 +66,7 @@ global_server_args_dict = { ...@@ -66,6 +66,7 @@ global_server_args_dict = {
"enable_ep_moe": ServerArgs.enable_ep_moe, "enable_ep_moe": ServerArgs.enable_ep_moe,
"device": ServerArgs.device, "device": ServerArgs.device,
"enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla, "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
"disable_radix_cache": ServerArgs.disable_radix_cache,
} }
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
...@@ -177,6 +177,7 @@ class ModelRunner: ...@@ -177,6 +177,7 @@ class ModelRunner:
"enable_ep_moe": server_args.enable_ep_moe, "enable_ep_moe": server_args.enable_ep_moe,
"device": server_args.device, "device": server_args.device,
"enable_flashinfer_mla": server_args.enable_flashinfer_mla, "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
"disable_radix_cache": server_args.disable_radix_cache,
} }
) )
......
...@@ -511,8 +511,11 @@ class DeepseekV2AttentionMLA(nn.Module): ...@@ -511,8 +511,11 @@ class DeepseekV2AttentionMLA(nn.Module):
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
) -> torch.Tensor: ) -> torch.Tensor:
if global_server_args_dict["enable_flashinfer_mla"]: if global_server_args_dict["enable_flashinfer_mla"]:
if forward_batch.forward_mode.is_extend(): if global_server_args_dict["disable_radix_cache"]:
return self.forward_normal(positions, hidden_states, forward_batch) if forward_batch.forward_mode.is_extend():
return self.forward_normal(positions, hidden_states, forward_batch)
else:
return self.forward_absorb(positions, hidden_states, forward_batch)
else: else:
return self.forward_absorb(positions, hidden_states, forward_batch) return self.forward_absorb(positions, hidden_states, forward_batch)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment