Unverified Commit 5a33c3aa authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Optimize Triton Draft Backend (#11556)

parent 9767a1e4
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, List, Optional
import torch
import triton
......@@ -12,6 +12,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.layers.radix_attention import AttentionType
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
from sglang.srt.utils import (
get_bool_env_var,
get_device_core_count,
......@@ -423,6 +424,7 @@ class TritonAttnBackend(AttentionBackend):
max_bs: int,
max_num_tokens: int,
kv_indices_buf: Optional[torch.Tensor] = None,
cuda_graph_num_kv_splits_buf: Optional[torch.Tensor] = None,
):
self.cuda_graph_attn_logits = torch.zeros(
(max_num_tokens, self.num_head, self.max_kv_splits, self.v_head_dim),
......@@ -434,9 +436,17 @@ class TritonAttnBackend(AttentionBackend):
dtype=torch.float32,
device=self.device,
)
self.cuda_graph_num_kv_splits = torch.full(
(max_num_tokens,), self.max_kv_splits, dtype=torch.int32, device=self.device
)
if cuda_graph_num_kv_splits_buf is None:
self.cuda_graph_num_kv_splits = torch.full(
(max_num_tokens,),
self.max_kv_splits,
dtype=torch.int32,
device=self.device,
)
else:
self.cuda_graph_num_kv_splits = cuda_graph_num_kv_splits_buf
if kv_indices_buf is None:
self.cuda_graph_kv_indices = torch.zeros(
(max_num_tokens * self.max_context_len),
......@@ -683,9 +693,7 @@ class TritonAttnBackend(AttentionBackend):
)
else:
kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
num_token = spec_info.kv_indptr.shape[0] - 1
assert False, "Multi-step cuda graph init is not done here."
self.get_num_kv_splits(num_kv_splits[:num_token], seq_lens[:bs])
elif forward_mode.is_target_verify():
......@@ -898,11 +906,8 @@ class TritonMultiStepDraftBackend:
topk: int,
speculative_num_steps: int,
):
from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
self.topk = topk
self.speculative_num_steps = speculative_num_steps
self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
max_bs = model_runner.req_to_token_pool.size * self.topk
self.kv_indptr = torch.zeros(
(
......@@ -912,7 +917,7 @@ class TritonMultiStepDraftBackend:
dtype=torch.int32,
device=model_runner.device,
)
self.attn_backends = []
self.attn_backends: List[TritonAttnBackend] = []
for i in range(self.speculative_num_steps):
self.attn_backends.append(
TritonAttnBackend(
......@@ -931,13 +936,19 @@ class TritonMultiStepDraftBackend:
self.page_size = model_runner.server_args.page_size
def common_template(
self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
self,
forward_batch: ForwardBatch,
kv_indices_buffer: Optional[torch.Tensor],
call_fn: int,
):
if kv_indices_buffer is None:
kv_indices_buffer = self.cuda_graph_kv_indices
num_seqs = forward_batch.batch_size
bs = self.topk * num_seqs
seq_lens_sum = forward_batch.seq_lens_sum
self.generate_draft_decode_kv_indices[
generate_draft_decode_kv_indices[
(self.speculative_num_steps, num_seqs, self.topk)
](
forward_batch.req_pool_indices,
......@@ -955,6 +966,9 @@ class TritonMultiStepDraftBackend:
self.page_size,
)
if call_fn is None:
return
for i in range(self.speculative_num_steps):
forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
......@@ -989,9 +1003,18 @@ class TritonMultiStepDraftBackend:
dtype=torch.int64,
device=self.device,
)
self.cuda_graph_num_kv_splits = torch.full(
(max_num_tokens,),
self.attn_backends[0].max_kv_splits,
dtype=torch.int32,
device=self.device,
)
for i in range(self.speculative_num_steps):
self.attn_backends[i].init_cuda_graph_state(
max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
max_bs,
max_num_tokens,
kv_indices_buf=self.cuda_graph_kv_indices[i],
cuda_graph_num_kv_splits_buf=self.cuda_graph_num_kv_splits,
)
def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
......@@ -1006,24 +1029,24 @@ class TritonMultiStepDraftBackend:
spec_info=forward_batch.spec_info,
)
self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
self.common_template(forward_batch, None, call_fn)
def init_forward_metadata_replay_cuda_graph(
self, forward_batch: ForwardBatch, bs: int
):
def call_fn(i, forward_batch):
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
bs,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
seq_lens_sum=-1,
encoder_lens=None,
forward_mode=ForwardMode.DECODE,
spec_info=forward_batch.spec_info,
seq_lens_cpu=None,
)
self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
self.common_template(forward_batch, None, None)
# NOTE: Multi-step's attention backends use the slice of
# - kv_indptr buffer (cuda graph and non-cuda graph)
# - kv_indices buffer (cuda graph only)
# So we don't need to assign the KV indices inside the attention backend.
# Compute num_kv_splits only once
num_token = forward_batch.batch_size * self.topk
self.attn_backends[-1].get_num_kv_splits(
self.attn_backends[-1].cuda_graph_num_kv_splits[:num_token],
forward_batch.seq_lens[:bs],
)
@triton.jit
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment