Unverified commit c8452551, authored by ykcombat and committed by GitHub
Browse files

[Fix] Fix split prefill with fa3. (#11428)

parent bf3e7149
......@@ -137,7 +137,10 @@ class LogitsMetadata:
@classmethod
def from_forward_batch(cls, forward_batch: ForwardBatch):
 if (
-forward_batch.forward_mode.is_extend()
+(
+forward_batch.forward_mode.is_extend()
+or forward_batch.forward_mode.is_split_prefill()
+)
 and forward_batch.return_logprob
 and not forward_batch.forward_mode.is_target_verify()
 ):
......@@ -389,8 +392,8 @@ class LogitsProcessor(nn.Module):
input_logprob_indices = None
 elif (
 logits_metadata.forward_mode.is_extend()
-and not logits_metadata.extend_return_logprob
-):
+or logits_metadata.forward_mode.is_split_prefill()
+) and not logits_metadata.extend_return_logprob:
# Prefill without input logprobs.
if logits_metadata.padded_static_len < 0:
last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
......
......@@ -112,6 +112,7 @@ class ForwardMode(IntEnum):
 self == ForwardMode.EXTEND
 or self == ForwardMode.DRAFT_EXTEND
 or self == ForwardMode.MIXED
+or self == ForwardMode.SPLIT_PREFILL
 )
def is_cuda_graph(self):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment