Unverified Commit 5a33c3aa authored by Liangsheng Yin, committed by GitHub

Optimize Triton Draft Backend (#11556)

parent 9767a1e4
 from __future__ import annotations

 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, List, Optional

 import torch
 import triton
@@ -12,6 +12,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
 from sglang.srt.utils import (
     get_bool_env_var,
     get_device_core_count,
@@ -423,6 +424,7 @@ class TritonAttnBackend(AttentionBackend):
         max_bs: int,
         max_num_tokens: int,
         kv_indices_buf: Optional[torch.Tensor] = None,
+        cuda_graph_num_kv_splits_buf: Optional[torch.Tensor] = None,
     ):
         self.cuda_graph_attn_logits = torch.zeros(
             (max_num_tokens, self.num_head, self.max_kv_splits, self.v_head_dim),
@@ -434,9 +436,17 @@ class TritonAttnBackend(AttentionBackend):
             dtype=torch.float32,
             device=self.device,
         )
-        self.cuda_graph_num_kv_splits = torch.full(
-            (max_num_tokens,), self.max_kv_splits, dtype=torch.int32, device=self.device
-        )
+        if cuda_graph_num_kv_splits_buf is None:
+            self.cuda_graph_num_kv_splits = torch.full(
+                (max_num_tokens,),
+                self.max_kv_splits,
+                dtype=torch.int32,
+                device=self.device,
+            )
+        else:
+            self.cuda_graph_num_kv_splits = cuda_graph_num_kv_splits_buf
         if kv_indices_buf is None:
             self.cuda_graph_kv_indices = torch.zeros(
                 (max_num_tokens * self.max_context_len),
@@ -683,9 +693,7 @@ class TritonAttnBackend(AttentionBackend):
                 )
             else:
-                kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
-                kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
-                num_token = spec_info.kv_indptr.shape[0] - 1
+                assert False, "Multi-step cuda graph init is not done here."
             self.get_num_kv_splits(num_kv_splits[:num_token], seq_lens[:bs])
         elif forward_mode.is_target_verify():
@@ -898,11 +906,8 @@ class TritonMultiStepDraftBackend:
         topk: int,
         speculative_num_steps: int,
     ):
-        from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
-
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
-        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
         max_bs = model_runner.req_to_token_pool.size * self.topk
         self.kv_indptr = torch.zeros(
             (
@@ -912,7 +917,7 @@ class TritonMultiStepDraftBackend:
             dtype=torch.int32,
             device=model_runner.device,
         )
-        self.attn_backends = []
+        self.attn_backends: List[TritonAttnBackend] = []
         for i in range(self.speculative_num_steps):
             self.attn_backends.append(
                 TritonAttnBackend(
@@ -931,13 +936,19 @@ class TritonMultiStepDraftBackend:
         self.page_size = model_runner.server_args.page_size

     def common_template(
-        self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
+        self,
+        forward_batch: ForwardBatch,
+        kv_indices_buffer: Optional[torch.Tensor],
+        call_fn: int,
     ):
+        if kv_indices_buffer is None:
+            kv_indices_buffer = self.cuda_graph_kv_indices
+
         num_seqs = forward_batch.batch_size
         bs = self.topk * num_seqs
         seq_lens_sum = forward_batch.seq_lens_sum
-        self.generate_draft_decode_kv_indices[
+        generate_draft_decode_kv_indices[
             (self.speculative_num_steps, num_seqs, self.topk)
         ](
             forward_batch.req_pool_indices,
@@ -955,6 +966,9 @@ class TritonMultiStepDraftBackend:
             self.page_size,
         )

+        if call_fn is None:
+            return
+
         for i in range(self.speculative_num_steps):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
             forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
@@ -989,9 +1003,18 @@ class TritonMultiStepDraftBackend:
             dtype=torch.int64,
             device=self.device,
         )
+        self.cuda_graph_num_kv_splits = torch.full(
+            (max_num_tokens,),
+            self.attn_backends[0].max_kv_splits,
+            dtype=torch.int32,
+            device=self.device,
+        )
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(
-                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
+                max_bs,
+                max_num_tokens,
+                kv_indices_buf=self.cuda_graph_kv_indices[i],
+                cuda_graph_num_kv_splits_buf=self.cuda_graph_num_kv_splits,
             )

     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
@@ -1006,24 +1029,24 @@ class TritonMultiStepDraftBackend:
                 spec_info=forward_batch.spec_info,
             )

-        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
+        self.common_template(forward_batch, None, call_fn)

     def init_forward_metadata_replay_cuda_graph(
         self, forward_batch: ForwardBatch, bs: int
     ):
-        def call_fn(i, forward_batch):
-            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
-                bs,
-                forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                seq_lens_sum=-1,
-                encoder_lens=None,
-                forward_mode=ForwardMode.DECODE,
-                spec_info=forward_batch.spec_info,
-                seq_lens_cpu=None,
-            )
-
-        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
+        self.common_template(forward_batch, None, None)
+
+        # NOTE: Multi-step's attention backends use the slice of
+        # - kv_indptr buffer (cuda graph and non-cuda graph)
+        # - kv_indices buffer (cuda graph only)
+        # So we don't need to assign the KV indices inside the attention backend.
+
+        # Compute num_kv_splits only once
+        num_token = forward_batch.batch_size * self.topk
+        self.attn_backends[-1].get_num_kv_splits(
+            self.attn_backends[-1].cuda_graph_num_kv_splits[:num_token],
+            forward_batch.seq_lens[:bs],
+        )
 @triton.jit
 ...
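
The heart of the optimization is visible in the last two hunks: TritonMultiStepDraftBackend now allocates a single cuda_graph_num_kv_splits buffer and hands it to every per-step TritonAttnBackend through the new cuda_graph_num_kv_splits_buf argument, and at CUDA graph replay it computes the split counts once through the last backend instead of re-running per-step metadata initialization (common_template is called with call_fn=None, which returns right after the KV-indices kernel launch). Below is a minimal, self-contained sketch of that buffer-aliasing pattern; StepBackend, MultiStepDraft, and the split heuristic are illustrative assumptions, not the sglang implementation.

import torch


class StepBackend:
    """Stands in for one per-step TritonAttnBackend (illustrative only)."""

    def __init__(self, max_kv_splits: int):
        self.max_kv_splits = max_kv_splits
        self.cuda_graph_num_kv_splits = None

    def init_cuda_graph_state(self, max_num_tokens, device, num_kv_splits_buf=None):
        if num_kv_splits_buf is None:
            # Standalone use: allocate a private buffer (the old behavior).
            self.cuda_graph_num_kv_splits = torch.full(
                (max_num_tokens,), self.max_kv_splits, dtype=torch.int32, device=device
            )
        else:
            # Multi-step use: alias the shared buffer instead of allocating.
            self.cuda_graph_num_kv_splits = num_kv_splits_buf

    def get_num_kv_splits(self, out, seq_lens):
        # Illustrative heuristic only: more splits for longer sequences.
        splits = torch.clamp(seq_lens // 256 + 1, max=self.max_kv_splits)
        out.copy_(splits.to(torch.int32))


class MultiStepDraft:
    """Stands in for TritonMultiStepDraftBackend (illustrative only)."""

    def __init__(self, num_steps, max_num_tokens, device="cpu", max_kv_splits=8):
        self.backends = [StepBackend(max_kv_splits) for _ in range(num_steps)]
        self.shared_num_kv_splits = torch.full(
            (max_num_tokens,), max_kv_splits, dtype=torch.int32, device=device
        )
        for backend in self.backends:
            backend.init_cuda_graph_state(
                max_num_tokens, device, num_kv_splits_buf=self.shared_num_kv_splits
            )

    def replay(self, seq_lens, topk):
        # One computation updates the tensor every step's backend aliases,
        # replacing a per-step call_fn loop with a single pass.
        num_token = seq_lens.numel() * topk
        self.backends[-1].get_num_kv_splits(
            self.backends[-1].cuda_graph_num_kv_splits[:num_token],
            seq_lens.repeat_interleave(topk),
        )


draft = MultiStepDraft(num_steps=3, max_num_tokens=64)
draft.replay(torch.tensor([100, 1000]), topk=2)
# Every step aliases the one shared buffer, so nothing is recomputed per step.
assert all(
    b.cuda_graph_num_kv_splits is draft.shared_num_kv_splits for b in draft.backends
)

Because the per-step backends only ever read slices of the shared kv_indptr/kv_indices/num_kv_splits buffers, writing those buffers once in the parent is sufficient; that is exactly what the NOTE in the replay hunk states.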