Unverified Commit 0bfe1d14 authored by Xinyuan Tong, committed by GitHub

fa3 & trtllm_mha spec overlap (#11874)

parent 5f98b7fe
@@ -584,7 +584,9 @@ class FlashAttentionBackend(AttentionBackend):
                     metadata, metadata_expand
                 )
-            elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
+            elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed(
+                include_draft_extend_v2=True
+            ):
                 metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
                 metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
                 metadata.cu_seqlens_k = torch.nn.functional.pad(
@@ -594,10 +596,9 @@ class FlashAttentionBackend(AttentionBackend):
                     forward_batch.req_pool_indices, : metadata.max_seq_len_k
                 ]
-                if (
-                    any(forward_batch.extend_prefix_lens_cpu)
-                    or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
-                ):
+                if any(
+                    forward_batch.extend_prefix_lens_cpu
+                ) or forward_batch.forward_mode.is_draft_extend(include_v2=True):
                     extend_seq_lens = forward_batch.extend_seq_lens
                     metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
                     metadata.cu_seqlens_q = torch.nn.functional.pad(
@@ -826,7 +827,7 @@ class FlashAttentionBackend(AttentionBackend):
        if (
            forward_batch.attn_attend_prefix_cache is not None
            and not forward_batch.forward_mode.is_target_verify()
-           and not forward_batch.forward_mode.is_draft_extend()
+           and not forward_batch.forward_mode.is_draft_extend(include_v2=True)
        ):
            # Do multi-head attention with chunked prefix cache
            if forward_batch.attn_attend_prefix_cache:
@@ -488,10 +488,9 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
                ]
-               if (
-                   any(forward_batch.extend_prefix_lens_cpu)
-                   or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
-               ):
+               if any(
+                   forward_batch.extend_prefix_lens_cpu
+               ) or forward_batch.forward_mode.is_draft_extend(include_v2=True):
                    extend_seq_lens = forward_batch.extend_seq_lens
                    metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
                    metadata.cu_seqlens_q = torch.nn.functional.pad(
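For reference, a minimal sketch of the `pad(cumsum(...))` pattern that both patched branches above use to build the query metadata. The lengths below are made up for illustration; in the backends they come from `forward_batch.extend_seq_lens` / `forward_batch.extend_seq_lens_cpu`, and the results are stored on the attention metadata.

```python
import torch

# Illustrative per-request new-token counts for a 3-request extend/draft-extend batch
# (made-up values; the backends read these from forward_batch.extend_seq_lens).
extend_seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)

# Same pattern as in the hunks above: prefix-sum the lengths and prepend a zero,
# so the tokens of request i live in cu_seqlens_q[i] : cu_seqlens_q[i + 1].
cu_seqlens_q = torch.nn.functional.pad(
    torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
)
max_seq_len_q = max([5, 3, 7])  # mirrors max(forward_batch.extend_seq_lens_cpu)

print(cu_seqlens_q)   # tensor([ 0,  5,  8, 15], dtype=torch.int32)
print(max_seq_len_q)  # 7
```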
@@ -90,11 +90,7 @@ class ForwardMode(IntEnum):
            self == ForwardMode.EXTEND
            or self == ForwardMode.MIXED
            or self == ForwardMode.DRAFT_EXTEND
-           or (
-               self == ForwardMode.DRAFT_EXTEND_V2
-               if include_draft_extend_v2
-               else False
-           )
+           or (include_draft_extend_v2 and self == ForwardMode.DRAFT_EXTEND_V2)
            or self == ForwardMode.TARGET_VERIFY
            or self == ForwardMode.SPLIT_PREFILL
        )
@@ -115,22 +111,21 @@ class ForwardMode(IntEnum):
        return self == ForwardMode.TARGET_VERIFY

    def is_draft_extend(self, include_v2: bool = False):
-       if include_v2:
-           return (
-               self == ForwardMode.DRAFT_EXTEND_V2 or self == ForwardMode.DRAFT_EXTEND
-           )
-       return self == ForwardMode.DRAFT_EXTEND
+       return self == ForwardMode.DRAFT_EXTEND or (
+           include_v2 and self == ForwardMode.DRAFT_EXTEND_V2
+       )

    def is_draft_extend_v2(self):
        # For fixed shape logits output in v2 eagle worker
        return self == ForwardMode.DRAFT_EXTEND_V2

-   def is_extend_or_draft_extend_or_mixed(self):
+   def is_extend_or_draft_extend_or_mixed(self, include_draft_extend_v2: bool = False):
        return (
            self == ForwardMode.EXTEND
            or self == ForwardMode.DRAFT_EXTEND
            or self == ForwardMode.MIXED
            or self == ForwardMode.SPLIT_PREFILL
+           or (include_draft_extend_v2 and self == ForwardMode.DRAFT_EXTEND_V2)
        )

    def is_cuda_graph(self):
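And a self-contained sketch of how the refactored ForwardMode helpers behave after this commit. The trimmed-down enum below is only a hypothetical stand-in for sglang's ForwardMode (member values and the full member list differ); the assertions show that DRAFT_EXTEND_V2 is matched only when the caller opts in, which is what the fa3 and trtllm_mha backends above now do.

```python
from enum import IntEnum, auto


class ForwardMode(IntEnum):
    # Hypothetical stand-in for sglang's ForwardMode, trimmed to the members
    # used here; the real member values and full member list differ.
    EXTEND = auto()
    MIXED = auto()
    DRAFT_EXTEND = auto()
    DRAFT_EXTEND_V2 = auto()
    TARGET_VERIFY = auto()
    SPLIT_PREFILL = auto()

    def is_draft_extend(self, include_v2: bool = False):
        # Same shape as the patched helper: V2 only counts when the caller opts in.
        return self == ForwardMode.DRAFT_EXTEND or (
            include_v2 and self == ForwardMode.DRAFT_EXTEND_V2
        )

    def is_extend_or_draft_extend_or_mixed(self, include_draft_extend_v2: bool = False):
        return (
            self == ForwardMode.EXTEND
            or self == ForwardMode.DRAFT_EXTEND
            or self == ForwardMode.MIXED
            or self == ForwardMode.SPLIT_PREFILL
            or (include_draft_extend_v2 and self == ForwardMode.DRAFT_EXTEND_V2)
        )


mode = ForwardMode.DRAFT_EXTEND_V2
assert not mode.is_draft_extend()                     # default keeps the old behavior
assert mode.is_draft_extend(include_v2=True)          # opt-in, as the backends now do
assert not mode.is_extend_or_draft_extend_or_mixed()
assert mode.is_extend_or_draft_extend_or_mixed(include_draft_extend_v2=True)
```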