"examples/advanced/cugraph/graphsage.py" did not exist on "b76d0ed1db605d05e2f167cebc57ce55f6acda96"
Commit c64290dc authored by Lianmin Zheng, committed by GitHub

Use seq_len_fill_value in the cuda graph runners (#7233)
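In summary (as reflected in the diff below): CUDA graphs are captured for a fixed set of batch sizes, so when the live batch is smaller than the captured one, the padded slots need a dummy sequence length. Previously the runners hardcoded 1 for those slots while the attention backends advertised 0 via get_cuda_graph_seq_len_fill_value(). This commit makes the backends return 1, has CudaGraphRunner and the EAGLE draft / draft-extend runners fill padded slots with self.seq_len_fill_value, adjusts the padded seq_lens_sum by (bs - raw_bs) * seq_len_fill_value, and hoists the default initialization of has_prefill_wrapper_verify and draft_extend_attn_backend in EAGLEWorker.init_attention_backend out of the per-backend branches.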

parent 8e2363dc
@@ -1807,7 +1807,7 @@ class FlashAttentionBackend(AttentionBackend):

     def get_cuda_graph_seq_len_fill_value(self):
         """Get the fill value for sequence length in CUDA graph."""
-        return 0
+        return 1

     def _init_local_attn_metadata(self, metadata: FlashAttentionMetadata, device):
         """Centralized utility to initialize local_attn_metadata if chunked attention is enabled."""
@@ -440,7 +440,7 @@ class FlashInferAttnBackend(AttentionBackend):
             raise ValueError("Invalid forward mode")

     def get_cuda_graph_seq_len_fill_value(self):
-        return 0
+        return 1

     def forward_extend(
         self,
@@ -364,7 +364,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             raise ValueError(f"Invalid forward mode: {forward_mode=}")

     def get_cuda_graph_seq_len_fill_value(self):
-        return 0
+        return 1

     def forward_extend(
         self,
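The three backend hunks above change the value each backend reports for padded slots. A minimal sketch of the resulting contract (not the sglang implementation; the init-time query shown here is an assumption, though seq_len_fill_value is the attribute the runner hunks below rely on):

```python
class DummyAttnBackend:
    """Stand-in for the FlashAttention/FlashInfer backends above."""

    def get_cuda_graph_seq_len_fill_value(self) -> int:
        # Matches the value the real backends return after this change.
        return 1


class DummyCudaGraphRunner:
    """Stand-in runner that reads the fill value once instead of hardcoding 1."""

    def __init__(self, attn_backend: DummyAttnBackend):
        self.seq_len_fill_value = attn_backend.get_cuda_graph_seq_len_fill_value()


runner = DummyCudaGraphRunner(DummyAttnBackend())
assert runner.seq_len_fill_value == 1
```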
@@ -612,7 +612,7 @@ class CudaGraphRunner:
         index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
-            self.seq_lens.fill_(1)
+            self.seq_lens.fill_(self.seq_len_fill_value)
             self.out_cache_loc.zero_()

         # Common inputs
@@ -624,7 +624,7 @@ class CudaGraphRunner:
         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
-                self.seq_lens_cpu.fill_(1)
+                self.seq_lens_cpu.fill_(self.seq_len_fill_value)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)

         if pp_proxy_tensors:
@@ -652,7 +652,7 @@ class CudaGraphRunner:
             bs,
             self.req_pool_indices,
             self.seq_lens,
-            forward_batch.seq_lens_sum + (bs - raw_bs),
+            forward_batch.seq_lens_sum + (bs - raw_bs) * self.seq_len_fill_value,
             self.encoder_lens,
             forward_batch.forward_mode,
             forward_batch.spec_info,
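To make the seq_lens_sum adjustment above concrete, here is an illustrative, self-contained sketch (hypothetical buffers and sizes, not sglang code) of replay-time padding: raw_bs is rounded up to the nearest captured batch size, padded slots get the fill value, and the sum passed to the backend is bumped by (bs - raw_bs) * seq_len_fill_value so it matches the padded tensor:

```python
import bisect

import torch

capture_bs = [1, 2, 4, 8]        # batch sizes captured as CUDA graphs
seq_len_fill_value = 1           # from get_cuda_graph_seq_len_fill_value()

raw_seq_lens = torch.tensor([5, 9, 3])   # 3 live requests
raw_bs = raw_seq_lens.numel()
bs = capture_bs[bisect.bisect_left(capture_bs, raw_bs)]   # rounds 3 up to 4

seq_lens = torch.empty(max(capture_bs), dtype=torch.int64)
if bs != raw_bs:
    seq_lens.fill_(seq_len_fill_value)   # padded slots get the fill value
seq_lens[:raw_bs].copy_(raw_seq_lens)    # live slots get the real lengths

# The padded total passed to the attention backend matches the padded tensor.
seq_lens_sum = int(raw_seq_lens.sum()) + (bs - raw_bs) * seq_len_fill_value
assert seq_lens_sum == int(seq_lens[:bs].sum())            # 17 + 1 == 18
```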
@@ -187,9 +187,8 @@ class EAGLEDraftCudaGraphRunner:
         index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
-            self.seq_lens.fill_(1)
+            self.seq_lens.fill_(self.seq_len_fill_value)
             self.out_cache_loc.zero_()
-            self.positions.zero_()

         num_tokens = bs * self.num_tokens_per_bs
@@ -211,15 +210,15 @@ class EAGLEDraftCudaGraphRunner:
         forward_batch.req_pool_indices = self.req_pool_indices[:bs]
         forward_batch.positions = self.positions[:num_tokens]

-        # Special handle for seq_len_cpu used when flashinfer mla is used
         if forward_batch.seq_lens_cpu is not None and bs != raw_bs:
-            self.seq_lens_cpu.fill_(1)
+            self.seq_lens_cpu.fill_(self.seq_len_fill_value)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
             forward_batch.seq_lens_cpu = self.seq_lens_cpu[:bs]

         self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
             forward_batch, bs
         )

+        # TODO: The forward_batch.seq_len_sum might need to be updated to reflect the padding in the cuda graph
         # Replay
         self.graphs[bs].replay()
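The draft runner also keeps a host-side copy of the sequence lengths for backends that plan on the CPU. A minimal sketch (illustrative names and sizes, not sglang code) of the pad-then-copy pattern used above: both the device tensor and its CPU mirror are filled with the backend value, the live lengths are copied into the first raw_bs slots, and the padded view is handed to the forward batch:

```python
import torch

seq_len_fill_value = 1
max_bs, raw_bs, bs = 8, 3, 4           # capture size, live batch, padded batch

seq_lens = torch.full((max_bs,), seq_len_fill_value, dtype=torch.int64)
seq_lens_cpu = torch.full((max_bs,), seq_len_fill_value, dtype=torch.int64)

live_lens = torch.tensor([5, 9, 3])
seq_lens[:raw_bs].copy_(live_lens)      # device tensor used by the graph
seq_lens_cpu[:raw_bs].copy_(live_lens)  # host mirror kept consistent

padded_view = seq_lens_cpu[:bs]         # what the attention backend will see
assert padded_view.tolist() == [5, 9, 3, 1]
```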
@@ -207,9 +207,9 @@ class EAGLEDraftExtendCudaGraphRunner:
         index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs * self.num_tokens_per_bs != num_tokens:
-            self.seq_lens.fill_(1)
-            self.accept_length.fill_(1)
+            self.seq_lens.fill_(self.seq_len_fill_value)
             self.out_cache_loc.zero_()
+            self.accept_length.fill_(1)

         # Common inputs
         self.input_ids[:num_tokens].copy_(forward_batch.input_ids)
@@ -223,18 +223,19 @@ class EAGLEDraftExtendCudaGraphRunner:
         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
-                self.seq_lens_cpu.fill_(1)
+                self.seq_lens_cpu.fill_(self.seq_len_fill_value)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)

         if bs != raw_bs:
-            forward_batch.spec_info.positions = self.positions[:num_tokens]
             forward_batch.spec_info.accept_length = self.accept_length[:bs]
+            forward_batch.spec_info.positions = None

         self.eagle_worker.draft_extend_attn_backend.init_forward_metadata_replay_cuda_graph(
             bs=bs,
             req_pool_indices=self.req_pool_indices,
             seq_lens=self.seq_lens,
-            seq_lens_sum=forward_batch.seq_lens_sum + (bs - raw_bs),
+            seq_lens_sum=forward_batch.seq_lens_sum
+            + (bs - raw_bs) * self.seq_len_fill_value,
             encoder_lens=None,
             forward_mode=ForwardMode.DRAFT_EXTEND,
             spec_info=forward_batch.spec_info,
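For the draft-extend runner, the padded rows also carry an accept_length: filling them with 1 presumably keeps every captured row contributing at least one token, and positions is reset to None rather than pointing at a padded view. A rough sketch under these assumptions (SpecInfoSketch is an illustrative stand-in; the real spec_info object has more fields):

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class SpecInfoSketch:
    """Illustrative stand-in for the real spec_info object."""

    accept_length: torch.Tensor
    positions: Optional[torch.Tensor] = None


max_bs, raw_bs, bs = 8, 3, 4
accept_length = torch.full((max_bs,), 1, dtype=torch.int64)   # fill_(1)
accept_length[:raw_bs] = torch.tensor([2, 3, 1])              # live accept counts

info = SpecInfoSketch(accept_length=accept_length[:bs], positions=None)
assert info.accept_length.tolist() == [2, 3, 1, 1]
```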
@@ -166,6 +166,10 @@ class EAGLEWorker(TpModelWorker):
     def init_attention_backend(self):
         # Create multi-step attn backends and cuda graph runners
+        self.has_prefill_wrapper_verify = False
+        self.draft_extend_attn_backend = None
+
         if self.server_args.attention_backend == "flashinfer":
             if not global_server_args_dict["use_mla_backend"]:
                 from sglang.srt.layers.attention.flashinfer_backend import (
@@ -213,7 +217,6 @@ class EAGLEWorker(TpModelWorker):
                 self.draft_model_runner,
                 skip_prefill=False,
             )
-            self.has_prefill_wrapper_verify = False
         elif self.server_args.attention_backend == "fa3":
             from sglang.srt.layers.attention.flashattention_backend import (
                 FlashAttentionBackend,
@@ -229,7 +232,6 @@ class EAGLEWorker(TpModelWorker):
                 self.draft_model_runner,
                 skip_prefill=False,
             )
-            self.has_prefill_wrapper_verify = False
         elif self.server_args.attention_backend == "flashmla":
             from sglang.srt.layers.attention.flashmla_backend import (
                 FlashMLAMultiStepDraftBackend,
@@ -240,8 +242,6 @@ class EAGLEWorker(TpModelWorker):
                 self.topk,
                 self.speculative_num_steps,
             )
-            self.draft_extend_attn_backend = None
-            self.has_prefill_wrapper_verify = False
         else:
             raise ValueError(
                 f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
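The EAGLEWorker hunks are a small refactor rather than a behavior change: the two attributes get safe defaults once, before the backend dispatch, so individual branches no longer repeat = False / = None. A heavily simplified sketch of the pattern (hypothetical worker class, not the real EAGLEWorker):

```python
class EagleWorkerSketch:
    def init_attention_backend(self, attention_backend: str):
        # Defaults that previously had to be re-set inside several branches.
        self.has_prefill_wrapper_verify = False
        self.draft_extend_attn_backend = None

        if attention_backend == "flashinfer":
            # Hypothetical override: a branch that needs different values
            # simply assigns them; all other branches inherit the defaults.
            self.has_prefill_wrapper_verify = True
        elif attention_backend in ("fa3", "flashmla"):
            pass  # defaults already apply
        else:
            raise ValueError(
                f"EAGLE is not supported in attention backend {attention_backend}"
            )


worker = EagleWorkerSketch()
worker.init_attention_backend("flashmla")
assert worker.draft_extend_attn_backend is None
```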