Unverified Commit 0bfe1d14 authored by Xinyuan Tong, committed by GitHub

fa3 & trtllm_mha spec overlap (#11874)

parent 5f98b7fe
@@ -584,7 +584,9 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata, metadata_expand
             )
-        elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
+        elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed(
+            include_draft_extend_v2=True
+        ):
             metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
             metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
             metadata.cu_seqlens_k = torch.nn.functional.pad(
@@ -594,10 +596,9 @@ class FlashAttentionBackend(AttentionBackend):
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
             ]
-            if (
-                any(forward_batch.extend_prefix_lens_cpu)
-                or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
-            ):
+            if any(
+                forward_batch.extend_prefix_lens_cpu
+            ) or forward_batch.forward_mode.is_draft_extend(include_v2=True):
                 extend_seq_lens = forward_batch.extend_seq_lens
                 metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
                 metadata.cu_seqlens_q = torch.nn.functional.pad(
@@ -826,7 +827,7 @@ class FlashAttentionBackend(AttentionBackend):
             if (
                 forward_batch.attn_attend_prefix_cache is not None
                 and not forward_batch.forward_mode.is_target_verify()
-                and not forward_batch.forward_mode.is_draft_extend()
+                and not forward_batch.forward_mode.is_draft_extend(include_v2=True)
             ):
                 # Do multi-head attention with chunked prefix cache
                 if forward_batch.attn_attend_prefix_cache:
......
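
The practical effect of routing draft-extend batches (v1 and now v2) into this branch is on the query-side metadata: when any request has a cached prefix, or the batch is a draft extend, cu_seqlens_q is built from the per-request extend lengths rather than reused from the key side. Below is a minimal sketch of that distinction, not the backend code itself; the tensor names mirror the ForwardBatch fields above, but the values and the surrounding structure are made up for illustration.

import torch

seq_lens = torch.tensor([7, 5], dtype=torch.int32)         # total k-side length per request
extend_seq_lens = torch.tensor([3, 2], dtype=torch.int32)  # newly extended tokens per request
has_prefix_or_draft_extend = True  # any(extend_prefix_lens_cpu) or is_draft_extend(include_v2=True)

cu_seqlens_k = torch.nn.functional.pad(
    torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
)
if has_prefix_or_draft_extend:
    # q covers only the freshly extended tokens
    cu_seqlens_q = torch.nn.functional.pad(
        torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
    )
    max_seq_len_q = int(extend_seq_lens.max())
else:
    # no cached prefix: q and k describe the same tokens
    cu_seqlens_q = cu_seqlens_k
    max_seq_len_q = int(seq_lens.max())

print(cu_seqlens_q.tolist(), cu_seqlens_k.tolist(), max_seq_len_q)
# with has_prefix_or_draft_extend=True -> [0, 3, 5] [0, 7, 12] 3
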
@@ -488,10 +488,9 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
             ]
-            if (
-                any(forward_batch.extend_prefix_lens_cpu)
-                or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
-            ):
+            if any(
+                forward_batch.extend_prefix_lens_cpu
+            ) or forward_batch.forward_mode.is_draft_extend(include_v2=True):
                 extend_seq_lens = forward_batch.extend_seq_lens
                 metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
                 metadata.cu_seqlens_q = torch.nn.functional.pad(
......
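
The same guard is changed identically in both backends, and the only behavioral difference is which modes satisfy it: a plain equality test against ForwardMode.DRAFT_EXTEND can never match DRAFT_EXTEND_V2, while the helper with include_v2=True matches both. A hypothetical before/after check, using a stub enum whose member names follow the diff but whose values are illustrative:

from enum import IntEnum, auto

class ForwardMode(IntEnum):  # stub for illustration only
    EXTEND = auto()
    MIXED = auto()
    DRAFT_EXTEND = auto()
    DRAFT_EXTEND_V2 = auto()
    TARGET_VERIFY = auto()

def takes_extend_lens_path_old(mode, extend_prefix_lens_cpu):
    return any(extend_prefix_lens_cpu) or mode == ForwardMode.DRAFT_EXTEND

def takes_extend_lens_path_new(mode, extend_prefix_lens_cpu):
    # mirrors mode.is_draft_extend(include_v2=True)
    return any(extend_prefix_lens_cpu) or mode in (
        ForwardMode.DRAFT_EXTEND,
        ForwardMode.DRAFT_EXTEND_V2,
    )

for mode in ForwardMode:
    old = takes_extend_lens_path_old(mode, [0, 0])
    new = takes_extend_lens_path_new(mode, [0, 0])
    if old != new:
        print(mode.name, "old:", old, "new:", new)
# Only DRAFT_EXTEND_V2 differs: old: False, new: True
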
@@ -90,11 +90,7 @@ class ForwardMode(IntEnum):
             self == ForwardMode.EXTEND
             or self == ForwardMode.MIXED
             or self == ForwardMode.DRAFT_EXTEND
-            or (
-                self == ForwardMode.DRAFT_EXTEND_V2
-                if include_draft_extend_v2
-                else False
-            )
+            or (include_draft_extend_v2 and self == ForwardMode.DRAFT_EXTEND_V2)
             or self == ForwardMode.TARGET_VERIFY
             or self == ForwardMode.SPLIT_PREFILL
         )
@@ -115,22 +111,21 @@ class ForwardMode(IntEnum):
         return self == ForwardMode.TARGET_VERIFY

     def is_draft_extend(self, include_v2: bool = False):
-        if include_v2:
-            return (
-                self == ForwardMode.DRAFT_EXTEND_V2 or self == ForwardMode.DRAFT_EXTEND
-            )
-        return self == ForwardMode.DRAFT_EXTEND
+        return self == ForwardMode.DRAFT_EXTEND or (
+            include_v2 and self == ForwardMode.DRAFT_EXTEND_V2
+        )

     def is_draft_extend_v2(self):
         # For fixed shape logits output in v2 eagle worker
         return self == ForwardMode.DRAFT_EXTEND_V2

-    def is_extend_or_draft_extend_or_mixed(self):
+    def is_extend_or_draft_extend_or_mixed(self, include_draft_extend_v2: bool = False):
         return (
             self == ForwardMode.EXTEND
             or self == ForwardMode.DRAFT_EXTEND
             or self == ForwardMode.MIXED
             or self == ForwardMode.SPLIT_PREFILL
+            or (include_draft_extend_v2 and self == ForwardMode.DRAFT_EXTEND_V2)
         )

     def is_cuda_graph(self):
......
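
Both ForwardMode edits are pure refactors plus one new opt-in flag: (X if flag else False) collapses to (flag and X), and is_draft_extend keeps its default behavior while DRAFT_EXTEND_V2 only matches when include_v2=True. A small equivalence check against the replaced forms, again with a stub enum (member values illustrative, not the real ones):

from enum import IntEnum, auto

class ForwardMode(IntEnum):  # stub for illustration only
    EXTEND = auto()
    MIXED = auto()
    DRAFT_EXTEND = auto()
    DRAFT_EXTEND_V2 = auto()
    TARGET_VERIFY = auto()
    SPLIT_PREFILL = auto()

def is_draft_extend_old(mode, include_v2=False):
    if include_v2:
        return mode == ForwardMode.DRAFT_EXTEND_V2 or mode == ForwardMode.DRAFT_EXTEND
    return mode == ForwardMode.DRAFT_EXTEND

def is_draft_extend_new(mode, include_v2=False):
    return mode == ForwardMode.DRAFT_EXTEND or (
        include_v2 and mode == ForwardMode.DRAFT_EXTEND_V2
    )

for mode in ForwardMode:
    for include_v2 in (False, True):
        # refactored helper agrees with the old branchy form
        assert is_draft_extend_old(mode, include_v2) == is_draft_extend_new(
            mode, include_v2
        )
        # the (flag and X) form matches the old (X if flag else False) form
        assert (include_v2 and mode == ForwardMode.DRAFT_EXTEND_V2) == (
            (mode == ForwardMode.DRAFT_EXTEND_V2) if include_v2 else False
        )
print("refactored predicates agree with the originals for every mode")
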