Unverified Commit e62d60fe authored by Baizhou Zhang, committed by GitHub

[Fix] avoid stream sync and torch compile in prefill for fa3 backend (#4932)

parent 032f8faa
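The fa3 prefill path previously computed `metadata.max_seq_len_k` by calling `.max().item()` on a GPU-resident length tensor, which forces a device-to-host synchronization in the middle of metadata setup. The batch now carries a `seq_lens_cpu` copy created once when the model worker batch is built, and since that copy is no longer decode-only, the field is renamed from `decode_seq_lens_cpu` to `seq_lens_cpu` throughout. A minimal sketch (not SGLang code; tensor names are illustrative) of the synchronization being avoided:

```python
import torch

seq_lens_gpu = torch.randint(1, 4096, (256,), device="cuda")

# Reading a scalar straight off a CUDA tensor blocks the host until the
# stream has drained (an implicit synchronization):
max_len = seq_lens_gpu.max().item()

# Keeping a CPU copy around (made once, at a point where the transfer is
# cheap) lets later readers take the max without touching the GPU:
seq_lens_cpu = seq_lens_gpu.cpu()
max_len = seq_lens_cpu.max().item()  # pure host-side work, no stream sync
```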
@@ -79,7 +79,7 @@ class FlashAttentionBackend(AttentionBackend):
                 torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
             )
             # Precompute maximum sequence length
-            metadata.max_seq_len_k = seqlens_in_batch.max().item()
+            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
             # Precompute page table
             metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
@@ -797,7 +797,7 @@ class FlashInferMLAMultiStepDraftBackend:
                 encoder_lens=None,
                 forward_mode=ForwardMode.DECODE,
                 spec_info=forward_batch.spec_info,
-                seq_lens_cpu=forward_batch.decode_seq_lens_cpu,
+                seq_lens_cpu=forward_batch.seq_lens_cpu,
             )
         self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
@@ -92,7 +92,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         if forward_batch.forward_mode.is_decode_or_idle():
             if spec_info is None:
                 max_seqlen_pad = triton.cdiv(
-                    forward_batch.decode_seq_lens_cpu.max().item(), PAGE_SIZE
+                    forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE
                 )
                 block_kv_indices = torch.full(
                     (bs, max_seqlen_pad),
@@ -1398,21 +1398,22 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     def get_model_worker_batch(self) -> ModelWorkerBatch:
         if self.forward_mode.is_decode_or_idle():
-            if (
-                global_server_args_dict["enable_flashinfer_mla"]
-                or global_server_args_dict["enable_flashmla"]
-                or global_server_args_dict["attention_backend"] == "fa3"
-            ):
-                decode_seq_lens = self.seq_lens.cpu()
-            else:
-                decode_seq_lens = None
             extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
         else:
-            decode_seq_lens = None
             extend_seq_lens = self.extend_lens
             extend_prefix_lens = self.prefix_lens
             extend_logprob_start_lens = self.extend_logprob_start_lens
+        # Create seq_lens_cpu when needed
+        if (
+            global_server_args_dict["enable_flashinfer_mla"]
+            or global_server_args_dict["enable_flashmla"]
+            or global_server_args_dict["attention_backend"] == "fa3"
+        ):
+            seq_lens_cpu = self.seq_lens.cpu()
+        else:
+            seq_lens_cpu = None
         if self.sampling_info:
             if self.has_grammar:
                 self.sampling_info.grammars = [req.grammar for req in self.reqs]
@@ -1435,7 +1436,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             global_num_tokens=self.global_num_tokens,
             global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
             can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
-            decode_seq_lens=decode_seq_lens,
+            seq_lens_cpu=seq_lens_cpu,
             extend_num_tokens=self.extend_num_tokens,
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,
@@ -1496,6 +1497,7 @@ class ModelWorkerBatch:
     req_pool_indices: torch.Tensor
     # The sequence length
     seq_lens: torch.Tensor
+    seq_lens_cpu: Optional[torch.Tensor]
     # The indices of output tokens in the token_to_kv_pool_allocator
     out_cache_loc: torch.Tensor
@@ -1512,9 +1514,6 @@ class ModelWorkerBatch:
     global_num_tokens_for_logprob: Optional[List[int]]
     can_run_dp_cuda_graph: bool
-    # For decode
-    decode_seq_lens: Optional[torch.Tensor]
     # For extend
     extend_num_tokens: Optional[int]
     extend_seq_lens: Optional[List[int]]
@@ -491,10 +491,10 @@ class CudaGraphRunner:
         self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
         self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
         self.positions[:raw_num_token].copy_(forward_batch.positions)
-        if forward_batch.decode_seq_lens_cpu is not None:
+        if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
                 self.seq_lens_cpu.fill_(1)
-            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu)
+            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
         if self.is_encoder_decoder:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
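The CUDA graph replay hunk above pads the batch: a captured graph expects a fixed batch size `bs`, so when the live batch `raw_bs` is smaller, the persistent CPU length buffer is first filled with 1 (presumably so the padded slots still describe valid length-1 sequences) and the real lengths are copied into the leading entries. An illustrative sketch with hypothetical buffer names:

```python
import torch

bs, raw_bs = 8, 5                                        # padded vs. live batch size
seq_lens_cpu_buf = torch.zeros(bs, dtype=torch.int32)    # persistent replay buffer
live_seq_lens_cpu = torch.tensor([17, 3, 250, 9, 44], dtype=torch.int32)

if bs != raw_bs:
    seq_lens_cpu_buf.fill_(1)                        # give padded slots a valid length
seq_lens_cpu_buf[:raw_bs].copy_(live_seq_lens_cpu)   # real lengths for live requests

print(seq_lens_cpu_buf.tolist())  # [17, 3, 250, 9, 44, 1, 1, 1]
```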
@@ -39,7 +39,6 @@ import triton
 import triton.language as tl
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import get_compiler_backend
 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -148,6 +147,9 @@ class ForwardBatch:
     # The sum of all sequence lengths
     seq_lens_sum: int
+    # Optional seq_lens on cpu
+    seq_lens_cpu: Optional[torch.Tensor] = None
     # For logprob
     return_logprob: bool = False
     top_logprobs_nums: Optional[List[int]] = None
@@ -162,9 +164,6 @@ class ForwardBatch:
     # Position information
     positions: torch.Tensor = None
-    # For decode
-    decode_seq_lens_cpu: Optional[torch.Tensor] = None
     # For extend
     extend_num_tokens: Optional[int] = None
     extend_seq_lens: Optional[torch.Tensor] = None
@@ -293,12 +292,14 @@
         ):
             ret.positions = ret.spec_info.positions
+        # Get seq_lens_cpu if needed
+        if ret.seq_lens_cpu is None:
+            ret.seq_lens_cpu = batch.seq_lens_cpu
         # Init position information
         if ret.forward_mode.is_decode():
             if ret.positions is None:
-                ret.positions = clamp_position(batch.seq_lens)
-            if ret.decode_seq_lens_cpu is None:
-                ret.decode_seq_lens_cpu = batch.decode_seq_lens
+                ret.positions = torch.clamp((batch.seq_lens - 1), min=0).to(torch.int64)
         else:
             ret.extend_seq_lens = torch.tensor(
                 batch.extend_seq_lens, dtype=torch.int32
@@ -518,8 +519,3 @@ def compute_position_torch(
     extend_start_loc = torch.zeros_like(extend_seq_lens)
     extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
     return positions.to(torch.int64), extend_start_loc
-
-
-@torch.compile(dynamic=True, backend=get_compiler_backend())
-def clamp_position(seq_lens):
-    return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
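Dropping `clamp_position` also removes the only `@torch.compile` entry point in this module (and the now-unused `get_compiler_backend` import); the same clamp is written inline in `ForwardBatch.init_new`, as shown in the hunk above. A small sketch, with made-up lengths, of what the expression computes, namely the index of each request's last token, floored at 0:

```python
import torch

seq_lens = torch.tensor([5, 1, 0, 12])
positions = torch.clamp((seq_lens - 1), min=0).to(torch.int64)
print(positions.tolist())  # [4, 0, 0, 11]
```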
@@ -214,10 +214,10 @@ class EAGLEDraftCudaGraphRunner:
         forward_batch.positions = self.positions[:num_tokens]
         # Special handle for seq_len_cpu used when flashinfer mla is used
-        if (forward_batch.decode_seq_lens_cpu is not None) and (bs != raw_bs):
+        if (forward_batch.seq_lens_cpu is not None) and (bs != raw_bs):
             self.seq_lens_cpu.fill_(1)
-            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu)
-            forward_batch.decode_seq_lens_cpu = self.seq_lens_cpu[:bs]
+            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
+            forward_batch.seq_lens_cpu = self.seq_lens_cpu[:bs]
         self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
             forward_batch, bs
@@ -233,7 +233,7 @@ class EAGLEDraftCudaGraphRunner:
         forward_batch.positions = self.positions[:raw_num_token]
         forward_batch.seq_lens = self.seq_lens[:raw_bs]
         forward_batch.req_pool_indices = self.req_pool_indices[:raw_bs]
-        if forward_batch.decode_seq_lens_cpu is not None:
-            forward_batch.decode_seq_lens_cpu = self.seq_lens_cpu[:raw_bs]
+        if forward_batch.seq_lens_cpu is not None:
+            forward_batch.seq_lens_cpu = self.seq_lens_cpu[:raw_bs]
         return out