"torchvision/vscode:/vscode.git/clone" did not exist on "11e49de410ec84ec669293a91dfaa13a53c9bc47"
Unverified Commit 714f3e63 authored by Yineng Zhang's avatar Yineng Zhang Committed by GitHub
Browse files

feat: support flashinfer mla with prefix cache (#3643)

parent c38f3aed
......@@ -54,7 +54,9 @@ class DecodeMetadata:
@dataclass
class PrefillMetadata:
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]
prefill_wrappers: List[
Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
]
use_ragged: bool
extend_no_prefix: bool
......@@ -160,16 +162,36 @@ class FlashInferAttnBackend(AttentionBackend):
self.decode_wrappers = []
for _ in range(self.num_wrappers):
if not skip_prefill:
self.prefill_wrappers_paged.append(
BatchPrefillWithPagedKVCacheWrapper(
self.workspace_buffer,
"NHD",
backend="fa2",
if (
self.enable_flashinfer_mla
and not global_server_args_dict["disable_radix_cache"]
):
# use mla paged prefill
self.prefill_wrappers_paged.append(
BatchMLAPagedAttentionWrapper(
self.workspace_buffer,
backend="fa2",
)
)
self.prefill_wrappers_verify.append(
BatchMLAPagedAttentionWrapper(
self.workspace_buffer,
backend="fa2",
)
)
else:
self.prefill_wrappers_paged.append(
BatchPrefillWithPagedKVCacheWrapper(
self.workspace_buffer,
"NHD",
backend="fa2",
)
)
self.prefill_wrappers_verify.append(
BatchPrefillWithPagedKVCacheWrapper(
self.workspace_buffer, "NHD"
)
)
)
self.prefill_wrappers_verify.append(
BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
)
if self.enable_flashinfer_mla:
self.decode_wrappers.append(
BatchMLAPagedAttentionWrapper(self.workspace_buffer, backend="fa2")
......@@ -237,7 +259,10 @@ class FlashInferAttnBackend(AttentionBackend):
else:
prefix_lens = forward_batch.extend_prefix_lens
if self.is_multimodal:
if self.is_multimodal or (
self.enable_flashinfer_mla
and not global_server_args_dict["disable_radix_cache"]
):
use_ragged = False
extend_no_prefix = False
else:
......@@ -419,23 +444,43 @@ class FlashInferAttnBackend(AttentionBackend):
logits_soft_cap = layer.logit_cap
o1, _ = self.prefill_wrapper_ragged.forward_return_lse(
q.view(-1, layer.tp_q_head_num, layer.head_dim),
k.view(-1, layer.tp_k_head_num, layer.head_dim),
v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
causal=True,
sm_scale=layer.scaling,
logits_soft_cap=logits_soft_cap,
)
o = o1
if global_server_args_dict["disable_radix_cache"]:
# use mla ragged prefill
o, _ = self.prefill_wrapper_ragged.forward_return_lse(
q.view(-1, layer.tp_q_head_num, layer.head_dim),
k.view(-1, layer.tp_k_head_num, layer.head_dim),
v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
causal=True,
sm_scale=layer.scaling,
logits_soft_cap=logits_soft_cap,
)
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer,
cache_loc,
k,
v,
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer,
cache_loc,
k,
v,
)
else:
# use mla paged prefill
prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
self._get_wrapper_idx(layer)
]
if k is not None:
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v
)
qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
o = prefill_wrapper_paged.run(
qall[:, :, : layer.v_head_dim],
qall[:, :, layer.v_head_dim :],
k_buf[:, :, : layer.v_head_dim],
k_buf[:, :, layer.v_head_dim :],
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
......@@ -800,7 +845,9 @@ class FlashInferIndicesUpdaterPrefill:
seq_lens: torch.Tensor,
seq_lens_sum: int,
prefix_lens: torch.Tensor,
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
prefill_wrappers: List[
Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
......@@ -814,7 +861,9 @@ class FlashInferIndicesUpdaterPrefill:
seq_lens: torch.Tensor,
seq_lens_sum: int,
prefix_lens: torch.Tensor,
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
prefill_wrappers: List[
Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
......@@ -923,7 +972,9 @@ class FlashInferIndicesUpdaterPrefill:
def call_begin_forward(
self,
wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper,
wrapper_paged: BatchPrefillWithPagedKVCacheWrapper,
wrapper_paged: Union[
BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
],
req_pool_indices: torch.Tensor,
paged_kernel_lens: torch.Tensor,
paged_kernel_lens_sum: int,
......@@ -1004,6 +1055,26 @@ class FlashInferIndicesUpdaterPrefill:
custom_mask=custom_mask,
non_blocking=True,
)
elif (
global_config.enable_flashinfer_mla
and not global_server_args_dict["disable_radix_cache"]
):
# mla paged prefill
kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
wrapper_paged.plan(
qo_indptr,
kv_indptr,
kv_indices,
kv_len_arr,
self.num_qo_heads,
512,
64,
1,
True,
1 / math.sqrt(192),
self.data_type,
self.data_type,
)
class FlashInferMultiStepDraftBackend:
......
......@@ -66,6 +66,7 @@ global_server_args_dict = {
"enable_ep_moe": ServerArgs.enable_ep_moe,
"device": ServerArgs.device,
"enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
"disable_radix_cache": ServerArgs.disable_radix_cache,
}
logger = logging.getLogger(__name__)
......
......@@ -177,6 +177,7 @@ class ModelRunner:
"enable_ep_moe": server_args.enable_ep_moe,
"device": server_args.device,
"enable_flashinfer_mla": server_args.enable_flashinfer_mla,
"disable_radix_cache": server_args.disable_radix_cache,
}
)
......
......@@ -511,8 +511,11 @@ class DeepseekV2AttentionMLA(nn.Module):
forward_batch: ForwardBatch,
) -> torch.Tensor:
if global_server_args_dict["enable_flashinfer_mla"]:
if forward_batch.forward_mode.is_extend():
return self.forward_normal(positions, hidden_states, forward_batch)
if global_server_args_dict["disable_radix_cache"]:
if forward_batch.forward_mode.is_extend():
return self.forward_normal(positions, hidden_states, forward_batch)
else:
return self.forward_absorb(positions, hidden_states, forward_batch)
else:
return self.forward_absorb(positions, hidden_states, forward_batch)
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment