Unverified Commit c8452551 authored by ykcombat's avatar ykcombat Committed by GitHub
Browse files

[Fix] Fix split prefill with fa3. (#11428)

parent bf3e7149
...@@ -137,7 +137,10 @@ class LogitsMetadata: ...@@ -137,7 +137,10 @@ class LogitsMetadata:
@classmethod @classmethod
def from_forward_batch(cls, forward_batch: ForwardBatch): def from_forward_batch(cls, forward_batch: ForwardBatch):
if ( if (
(
forward_batch.forward_mode.is_extend() forward_batch.forward_mode.is_extend()
or forward_batch.forward_mode.is_split_prefill()
)
and forward_batch.return_logprob and forward_batch.return_logprob
and not forward_batch.forward_mode.is_target_verify() and not forward_batch.forward_mode.is_target_verify()
): ):
...@@ -389,8 +392,8 @@ class LogitsProcessor(nn.Module): ...@@ -389,8 +392,8 @@ class LogitsProcessor(nn.Module):
input_logprob_indices = None input_logprob_indices = None
elif ( elif (
logits_metadata.forward_mode.is_extend() logits_metadata.forward_mode.is_extend()
and not logits_metadata.extend_return_logprob or logits_metadata.forward_mode.is_split_prefill()
): ) and not logits_metadata.extend_return_logprob:
# Prefill without input logprobs. # Prefill without input logprobs.
if logits_metadata.padded_static_len < 0: if logits_metadata.padded_static_len < 0:
last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1 last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
......
...@@ -112,6 +112,7 @@ class ForwardMode(IntEnum): ...@@ -112,6 +112,7 @@ class ForwardMode(IntEnum):
self == ForwardMode.EXTEND self == ForwardMode.EXTEND
or self == ForwardMode.DRAFT_EXTEND or self == ForwardMode.DRAFT_EXTEND
or self == ForwardMode.MIXED or self == ForwardMode.MIXED
or self == ForwardMode.SPLIT_PREFILL
) )
def is_cuda_graph(self): def is_cuda_graph(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment