Unverified Commit 7e412900 authored by Ke Bao, committed by GitHub

Add draft extend CUDA graph for Triton backend (#6705)

parent c673727e
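This change extends the Triton attention backend's CUDA graph support to the draft-extend forward mode used in EAGLE speculative decoding. Two pieces move together: `TritonAttnBackend` gains an `is_draft_extend()` branch in both its capture-time and replay-time index preparation, and `EAGLEWorker` now constructs a `TritonAttnBackend` for the draft model's extend pass instead of leaving `draft_extend_attn_backend` as `None`.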
@@ -128,6 +128,7 @@ class TritonAttnBackend(AttentionBackend):
         )
         self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
+        self.speculative_num_steps = model_runner.server_args.speculative_num_steps
         self.num_head = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
@@ -424,6 +425,34 @@ class TritonAttnBackend(AttentionBackend):
             num_kv_splits = None
             attn_logits = None
             attn_lse = None
+        elif forward_mode.is_draft_extend():
+            num_tokens_per_bs = self.speculative_num_steps + 1
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[: bs + 1] = torch.arange(
+                0,
+                bs * num_tokens_per_bs + 1,
+                step=num_tokens_per_bs,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+            custom_mask = None
+            mask_indptr = None
+            max_extend_len = num_tokens_per_bs
+            num_kv_splits = None
+            attn_logits = None
+            attn_lse = None
         else:
             raise ValueError(
                 f"Invalid forward mode: {forward_mode=} for CUDA Graph capture."
@@ -504,6 +533,23 @@ class TritonAttnBackend(AttentionBackend):
             seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
             mask_indptr = self.mask_indptr[: bs + 1]
             mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
+        elif forward_mode.is_draft_extend():
+            seq_lens = seq_lens[:bs]
+            accept_lens = spec_info.accept_length[:bs]
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[1 : bs + 1] = torch.cumsum(accept_lens, dim=0)
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
         else:
             raise ValueError(
                 f"Invalid forward mode: {forward_mode=} for CUDA Graph replay."
...
@@ -179,6 +179,7 @@ class EAGLEWorker(TpModelWorker):
             self.has_prefill_wrapper_verify = True
         elif self.server_args.attention_backend == "triton":
             from sglang.srt.layers.attention.triton_backend import (
+                TritonAttnBackend,
                 TritonMultiStepDraftBackend,
             )
@@ -187,7 +188,10 @@ class EAGLEWorker(TpModelWorker):
                 self.topk,
                 self.speculative_num_steps,
             )
-            self.draft_extend_attn_backend = None
+            self.draft_extend_attn_backend = TritonAttnBackend(
+                self.draft_model_runner,
+                skip_prefill=False,
+            )
             self.padded_static_len = self.speculative_num_steps + 1
             self.has_prefill_wrapper_verify = False
         elif self.server_args.attention_backend == "fa3":
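With this wiring, `draft_extend_attn_backend` is no longer `None` on the Triton path, so the draft model's extend pass can run under a captured CUDA graph; `skip_prefill=False` presumably keeps the prefill/extend metadata that the draft-extend path needs. One plausible reading of the `padded_static_len = speculative_num_steps + 1` constant, as a standalone sketch with hypothetical values (not SGLang code): the graph is captured with a fixed per-request slot, and the variable number of accepted tokens at replay must fit inside it.

```python
# Minimal sketch (hypothetical values, not SGLang code): the captured graph
# reserves padded_static_len slots per request; only the first accept_lens[i]
# positions of each slot hold real tokens, the rest are padding.
import torch

speculative_num_steps = 2
padded_static_len = speculative_num_steps + 1   # slot size per request

accept_lens = torch.tensor([3, 1, 2])           # actual accepted tokens
bs = accept_lens.numel()

real_token_mask = (
    torch.arange(padded_static_len).expand(bs, -1) < accept_lens.unsqueeze(1)
)
print(bs * padded_static_len)  # 9 padded positions in total
print(real_token_mask)         # 6 of them (3 + 1 + 2) are real tokens
```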
...