Unverified Commit 7e412900 authored by Ke Bao's avatar Ke Bao Committed by GitHub
Browse files

Add draft extend CUDA graph for Triton backend (#6705)

parent c673727e
......@@ -128,6 +128,7 @@ class TritonAttnBackend(AttentionBackend):
)
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
self.num_head = (
model_runner.model_config.num_attention_heads // get_attention_tp_size()
......@@ -424,6 +425,34 @@ class TritonAttnBackend(AttentionBackend):
num_kv_splits = None
attn_logits = None
attn_lse = None
elif forward_mode.is_draft_extend():
num_tokens_per_bs = self.speculative_num_steps + 1
qo_indptr = self.qo_indptr[: bs + 1]
qo_indptr[: bs + 1] = torch.arange(
0,
bs * num_tokens_per_bs + 1,
step=num_tokens_per_bs,
dtype=torch.int32,
device=self.device,
)
kv_indptr = self.kv_indptr[: bs + 1]
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
kv_indices = self.cuda_graph_kv_indices
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
seq_lens,
kv_indptr,
None,
kv_indices,
self.req_to_token.stride(0),
)
custom_mask = None
mask_indptr = None
max_extend_len = num_tokens_per_bs
num_kv_splits = None
attn_logits = None
attn_lse = None
else:
raise ValueError(
f"Invalid forward mode: {forward_mode=} for CUDA Graph capture."
......@@ -504,6 +533,23 @@ class TritonAttnBackend(AttentionBackend):
seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
mask_indptr = self.mask_indptr[: bs + 1]
mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
elif forward_mode.is_draft_extend():
seq_lens = seq_lens[:bs]
accept_lens = spec_info.accept_length[:bs]
qo_indptr = self.qo_indptr[: bs + 1]
qo_indptr[1 : bs + 1] = torch.cumsum(accept_lens, dim=0)
kv_indptr = self.kv_indptr[: bs + 1]
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
kv_indices = self.cuda_graph_kv_indices
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
seq_lens,
kv_indptr,
None,
kv_indices,
self.req_to_token.stride(0),
)
else:
raise ValueError(
f"Invalid forward mode: {forward_mode=} for CUDA Graph replay."
......
......@@ -179,6 +179,7 @@ class EAGLEWorker(TpModelWorker):
self.has_prefill_wrapper_verify = True
elif self.server_args.attention_backend == "triton":
from sglang.srt.layers.attention.triton_backend import (
TritonAttnBackend,
TritonMultiStepDraftBackend,
)
......@@ -187,7 +188,10 @@ class EAGLEWorker(TpModelWorker):
self.topk,
self.speculative_num_steps,
)
self.draft_extend_attn_backend = None
self.draft_extend_attn_backend = TritonAttnBackend(
self.draft_model_runner,
skip_prefill=False,
)
self.padded_static_len = self.speculative_num_steps + 1
self.has_prefill_wrapper_verify = False
elif self.server_args.attention_backend == "fa3":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment