Unverified commit c64290dc authored by Lianmin Zheng, committed by GitHub

Use seq_len_fill_value in the cuda graph runners (#7233)

parent 8e2363dc
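
The change below makes the CUDA graph runners read the padding value for `seq_lens` from the attention backend via `get_cuda_graph_seq_len_fill_value()` instead of hardcoding `1`, and corrects `seq_lens_sum` for the padded batch slots accordingly. The following is a minimal sketch of that pattern, using simplified stand-in classes rather than the real sglang runner and backend:

```python
# Minimal sketch of the pattern applied in this commit. DummyAttnBackend and
# DummyCudaGraphRunner are simplified stand-ins, not the sglang classes.
import torch


class DummyAttnBackend:
    def get_cuda_graph_seq_len_fill_value(self) -> int:
        # Padded (unused) batch slots are replayed with this sequence length.
        return 1


class DummyCudaGraphRunner:
    def __init__(self, backend: DummyAttnBackend, max_bs: int):
        # Cache the backend-provided fill value once, as the real runners do.
        self.seq_len_fill_value = backend.get_cuda_graph_seq_len_fill_value()
        self.seq_lens = torch.full(
            (max_bs,), self.seq_len_fill_value, dtype=torch.int64
        )

    def replay_prepare(self, raw_seq_lens: torch.Tensor, bs: int) -> int:
        raw_bs = raw_seq_lens.numel()
        if bs != raw_bs:
            # Reset the buffer so the padded slots hold the fill value ...
            self.seq_lens.fill_(self.seq_len_fill_value)
        # ... then copy the real sequence lengths into the first raw_bs slots.
        self.seq_lens[:raw_bs].copy_(raw_seq_lens)
        # Each padded slot contributes seq_len_fill_value tokens to the total.
        return int(raw_seq_lens.sum()) + (bs - raw_bs) * self.seq_len_fill_value
```

Keeping the fill value in one place means a backend that needs a different padding length only has to change its own `get_cuda_graph_seq_len_fill_value()`.
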
@@ -1807,7 +1807,7 @@ class FlashAttentionBackend(AttentionBackend):
     def get_cuda_graph_seq_len_fill_value(self):
         """Get the fill value for sequence length in CUDA graph."""
-        return 0
+        return 1

     def _init_local_attn_metadata(self, metadata: FlashAttentionMetadata, device):
         """Centralized utility to initialize local_attn_metadata if chunked attention is enabled."""

@@ -440,7 +440,7 @@ class FlashInferAttnBackend(AttentionBackend):
             raise ValueError("Invalid forward mode")

     def get_cuda_graph_seq_len_fill_value(self):
-        return 0
+        return 1

     def forward_extend(
         self,

@@ -364,7 +364,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             raise ValueError(f"Invalid forward mode: {forward_mode=}")

     def get_cuda_graph_seq_len_fill_value(self):
-        return 0
+        return 1

     def forward_extend(
         self,

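The three backends above switch their fill value from 0 to 1 because the runners previously hardcoded `fill_(1)` for padded slots (see the removed lines below); returning 1 means the replayed graphs see the same padded values as before, now that the value is delegated to the backend. As a purely hypothetical illustration (not part of this commit), a backend that wanted zero-length padding would only need to override the hook:

```python
# Hypothetical illustration, not from this commit. DummyBackendBase stands in
# for the AttentionBackend base class shown in the hunks above.
class DummyBackendBase:
    def get_cuda_graph_seq_len_fill_value(self) -> int:
        return 1  # matches the backends in this diff


class ZeroPaddingBackend(DummyBackendBase):
    def get_cuda_graph_seq_len_fill_value(self) -> int:
        return 0  # padded CUDA graph slots would replay with sequence length 0
```
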
@@ -612,7 +612,7 @@ class CudaGraphRunner:
         index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
-            self.seq_lens.fill_(1)
+            self.seq_lens.fill_(self.seq_len_fill_value)
             self.out_cache_loc.zero_()

         # Common inputs
@@ -624,7 +624,7 @@ class CudaGraphRunner:
         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
-                self.seq_lens_cpu.fill_(1)
+                self.seq_lens_cpu.fill_(self.seq_len_fill_value)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)

         if pp_proxy_tensors:
@@ -652,7 +652,7 @@ class CudaGraphRunner:
             bs,
             self.req_pool_indices,
             self.seq_lens,
-            forward_batch.seq_lens_sum + (bs - raw_bs),
+            forward_batch.seq_lens_sum + (bs - raw_bs) * self.seq_len_fill_value,
             self.encoder_lens,
             forward_batch.forward_mode,
             forward_batch.spec_info,

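The `seq_lens_sum` adjustment above follows directly from the padding: the `bs - raw_bs` padded slots each hold `seq_len_fill_value`, so the total must grow by that many fill values rather than by a flat `bs - raw_bs` (which implicitly assumed a fill value of 1). A worked example with made-up numbers:

```python
# Worked example with made-up numbers: 3 real requests padded up to a
# captured CUDA graph batch size of 8.
raw_seq_lens = [17, 42, 5]      # real decode sequence lengths
raw_bs = len(raw_seq_lens)      # 3
bs = 8                          # nearest captured batch size
seq_len_fill_value = 1          # value written into the padded slots

raw_sum = sum(raw_seq_lens)                                # 64
padded_sum = raw_sum + (bs - raw_bs) * seq_len_fill_value  # 64 + 5 * 1 = 69

# Sanity check: the corrected sum equals the sum over the padded buffer.
padded_seq_lens = raw_seq_lens + [seq_len_fill_value] * (bs - raw_bs)
assert padded_sum == sum(padded_seq_lens)
```
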
@@ -187,9 +187,8 @@ class EAGLEDraftCudaGraphRunner:
         index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
-            self.seq_lens.fill_(1)
+            self.seq_lens.fill_(self.seq_len_fill_value)
             self.out_cache_loc.zero_()
-            self.positions.zero_()

         num_tokens = bs * self.num_tokens_per_bs
@@ -211,15 +210,15 @@ class EAGLEDraftCudaGraphRunner:
         forward_batch.req_pool_indices = self.req_pool_indices[:bs]
         forward_batch.positions = self.positions[:num_tokens]

         # Special handle for seq_len_cpu used when flashinfer mla is used
         if forward_batch.seq_lens_cpu is not None and bs != raw_bs:
-            self.seq_lens_cpu.fill_(1)
+            self.seq_lens_cpu.fill_(self.seq_len_fill_value)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
             forward_batch.seq_lens_cpu = self.seq_lens_cpu[:bs]

         self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
             forward_batch, bs
         )
+        # TODO: The forward_batch.seq_len_sum might need to be updated to reflect the padding in the cuda graph

         # Replay
         self.graphs[bs].replay()

@@ -207,9 +207,9 @@ class EAGLEDraftExtendCudaGraphRunner:
         index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs * self.num_tokens_per_bs != num_tokens:
-            self.seq_lens.fill_(1)
-            self.accept_length.fill_(1)
+            self.seq_lens.fill_(self.seq_len_fill_value)
             self.out_cache_loc.zero_()
+            self.accept_length.fill_(1)

         # Common inputs
         self.input_ids[:num_tokens].copy_(forward_batch.input_ids)
@@ -223,18 +223,19 @@ class EAGLEDraftExtendCudaGraphRunner:
         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
-                self.seq_lens_cpu.fill_(1)
+                self.seq_lens_cpu.fill_(self.seq_len_fill_value)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)

         if bs != raw_bs:
-            forward_batch.spec_info.positions = self.positions[:num_tokens]
             forward_batch.spec_info.accept_length = self.accept_length[:bs]
+            forward_batch.spec_info.positions = None

         self.eagle_worker.draft_extend_attn_backend.init_forward_metadata_replay_cuda_graph(
             bs=bs,
             req_pool_indices=self.req_pool_indices,
             seq_lens=self.seq_lens,
-            seq_lens_sum=forward_batch.seq_lens_sum + (bs - raw_bs),
+            seq_lens_sum=forward_batch.seq_lens_sum
+            + (bs - raw_bs) * self.seq_len_fill_value,
             encoder_lens=None,
             forward_mode=ForwardMode.DRAFT_EXTEND,
             spec_info=forward_batch.spec_info,

@@ -166,6 +166,10 @@ class EAGLEWorker(TpModelWorker):
     def init_attention_backend(self):
         # Create multi-step attn backends and cuda graph runners
+        self.has_prefill_wrapper_verify = False
+        self.draft_extend_attn_backend = None

         if self.server_args.attention_backend == "flashinfer":
             if not global_server_args_dict["use_mla_backend"]:
                 from sglang.srt.layers.attention.flashinfer_backend import (
@@ -213,7 +217,6 @@ class EAGLEWorker(TpModelWorker):
                 self.draft_model_runner,
                 skip_prefill=False,
             )
-            self.has_prefill_wrapper_verify = False
         elif self.server_args.attention_backend == "fa3":
             from sglang.srt.layers.attention.flashattention_backend import (
                 FlashAttentionBackend,
@@ -229,7 +232,6 @@ class EAGLEWorker(TpModelWorker):
                 self.draft_model_runner,
                 skip_prefill=False,
             )
-            self.has_prefill_wrapper_verify = False
         elif self.server_args.attention_backend == "flashmla":
             from sglang.srt.layers.attention.flashmla_backend import (
                 FlashMLAMultiStepDraftBackend,
@@ -240,8 +242,6 @@ class EAGLEWorker(TpModelWorker):
                 self.topk,
                 self.speculative_num_steps,
             )
-            self.draft_extend_attn_backend = None
-            self.has_prefill_wrapper_verify = False
         else:
             raise ValueError(
                 f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"

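The `EAGLEWorker` change is a small cleanup bundled in the same commit: `has_prefill_wrapper_verify` and `draft_extend_attn_backend` now get default values once at the top of `init_attention_backend`, and the per-backend branches only override them when they differ. A sketch of that "defaults first, override per branch" pattern, with placeholder branch names and values rather than the real sglang backends:

```python
# Sketch of the "set defaults first, override per branch" pattern; the branch
# names and assigned values are placeholders, not the sglang backends.
def init_attention_backend_sketch(backend_name: str):
    # Defaults that most branches previously repeated at their end.
    has_prefill_wrapper_verify = False
    draft_extend_attn_backend = None

    if backend_name == "backend_with_extend_support":
        draft_extend_attn_backend = "some_extend_backend"  # branch-specific override
    elif backend_name == "backend_using_defaults":
        pass  # the defaults are already correct, nothing to repeat
    else:
        raise ValueError(f"EAGLE is not supported in attention backend {backend_name}")

    return has_prefill_wrapper_verify, draft_extend_attn_backend
```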