Unverified Commit f44d1439 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Support target model verification in the attention backend (#2678)


Co-authored-by: default avataryukavio <kavioyu@gmail.com>
parent b6b57fc2
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Optional
from typing import TYPE_CHECKING, Optional
import torch
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.spec_info import SpecInfo
class AttentionBackend(ABC):
......@@ -22,9 +26,12 @@ class AttentionBackend(ABC):
def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
num_token: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor] = None,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
):
"""Init the metadata for a forward pass for capturing a cuda graph."""
raise NotImplementedError()
......@@ -35,7 +42,9 @@ class AttentionBackend(ABC):
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor] = None,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
):
"""Init the metadata for a forward pass for replying a cuda graph."""
raise NotImplementedError()
......
......@@ -3,7 +3,6 @@ from __future__ import annotations
from typing import TYPE_CHECKING
import torch
import torch.nn as nn
from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.managers.schedule_batch import global_server_args_dict
......@@ -52,8 +51,6 @@ class DoubleSparseAttnBackend(AttentionBackend):
self.forward_metadata = None
self.cuda_graph_max_seq_len = model_runner.model_config.context_len
def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Init auxiliary variables for triton attention backend."""
......@@ -115,55 +112,6 @@ class DoubleSparseAttnBackend(AttentionBackend):
ds_req_to_token,
)
def init_cuda_graph_state(self, max_bs: int):
# TODO(Andy): Support CUDA graph for double sparse attention
raise ValueError(
"Double sparse attention does not support CUDA graph for now. Please --disable-cuda-graph"
)
self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
self.cuda_graph_start_loc = torch.zeros(
(max_bs,), dtype=torch.int32, device="cuda"
)
self.cuda_graph_attn_logits = torch.empty(
(
self.num_head,
self.cuda_graph_max_total_num_tokens,
),
dtype=self.reduce_dtype,
device="cuda",
)
def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens=None,
):
# NOTE: encoder_lens expected to be zeros or None
self.forward_metadata = (
self.cuda_graph_start_loc,
self.cuda_graph_attn_logits,
self.cuda_graph_max_seq_len,
None,
)
def init_forward_metadata_replay_cuda_graph(
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens=None,
):
# NOTE: encoder_lens expected to be zeros or None
self.cuda_graph_start_loc.zero_()
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
def get_cuda_graph_seq_len_fill_value(self):
return 1
def forward_extend(
self,
q,
......
......@@ -10,7 +10,7 @@ Each backend supports two operators: extend (i.e. prefill with cached prefix) an
import os
from dataclasses import dataclass
from enum import Enum, auto
from typing import TYPE_CHECKING, List, Union
from typing import TYPE_CHECKING, List, Optional, Union
import torch
import triton
......@@ -18,12 +18,13 @@ import triton.language as tl
from sglang.global_config import global_config
from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import is_flashinfer_available
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.spec_info import SpecInfo
if is_flashinfer_available():
from flashinfer import (
......@@ -113,11 +114,15 @@ class FlashInferAttnBackend(AttentionBackend):
# Two wrappers: one for sliding window attention and one for full attention.
# Using two wrappers is unnecessary in the current PR, but are prepared for future PRs
self.prefill_wrappers_paged = []
self.prefill_wrappers_verify = []
self.decode_wrappers = []
for _ in range(self.num_wrappers):
self.prefill_wrappers_paged.append(
BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
)
self.prefill_wrappers_verify.append(
BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
)
self.decode_wrappers.append(
BatchDecodeWithPagedKVCacheWrapper(
self.workspace_buffer,
......@@ -135,6 +140,7 @@ class FlashInferAttnBackend(AttentionBackend):
# Other metadata
self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
self.decode_cuda_graph_metadata = {}
self.prefill_cuda_graph_metadata = {}
def init_forward_metadata(self, forward_batch: ForwardBatch):
if forward_batch.forward_mode.is_decode():
......@@ -144,8 +150,37 @@ class FlashInferAttnBackend(AttentionBackend):
forward_batch.seq_lens_sum,
decode_wrappers=self.decode_wrappers,
encoder_lens=forward_batch.encoder_lens,
spec_info=forward_batch.spec_info,
)
self.forward_metadata = DecodeMetadata(self.decode_wrappers)
elif forward_batch.forward_mode.is_draft_extend():
self.indices_updater_prefill.update(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
prefix_lens=None,
prefill_wrappers=self.prefill_wrappers_paged,
use_ragged=False,
encoder_lens=forward_batch.encoder_lens,
spec_info=forward_batch.spec_info,
)
self.forward_metadata = PrefillMetadata(
self.prefill_wrappers_paged, False, False
)
elif forward_batch.forward_mode.is_target_verify():
self.indices_updater_prefill.update(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
prefix_lens=None,
prefill_wrappers=self.prefill_wrappers_verify,
use_ragged=False,
encoder_lens=forward_batch.encoder_lens,
spec_info=forward_batch.spec_info,
)
self.forward_metadata = PrefillMetadata(
self.prefill_wrappers_verify, False, False
)
else:
prefix_lens = forward_batch.extend_prefix_lens
......@@ -165,6 +200,7 @@ class FlashInferAttnBackend(AttentionBackend):
prefill_wrappers=self.prefill_wrappers_paged,
use_ragged=use_ragged,
encoder_lens=forward_batch.encoder_lens,
spec_info=None,
)
self.forward_metadata = PrefillMetadata(
self.prefill_wrappers_paged, use_ragged, extend_no_prefix
......@@ -180,37 +216,80 @@ class FlashInferAttnBackend(AttentionBackend):
cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
]
self.cuda_graph_custom_mask = torch.zeros(
(max_bs * self.max_context_len),
dtype=torch.uint8,
device="cuda",
)
self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr]
self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr]
def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
num_token: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: torch.Tensor = None,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
):
decode_wrappers = []
for i in range(self.num_wrappers):
decode_wrappers.append(
BatchDecodeWithPagedKVCacheWrapper(
self.workspace_buffer,
"NHD",
use_cuda_graph=True,
use_tensor_cores=self.decode_use_tensor_cores,
paged_kv_indptr_buffer=self.kv_indptr[i][: bs + 1],
paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
paged_kv_last_page_len_buffer=self.kv_last_page_len[:bs],
if forward_mode.is_decode():
decode_wrappers = []
for i in range(self.num_wrappers):
decode_wrappers.append(
BatchDecodeWithPagedKVCacheWrapper(
self.workspace_buffer,
"NHD",
use_cuda_graph=True,
use_tensor_cores=self.decode_use_tensor_cores,
paged_kv_indptr_buffer=self.kv_indptr[i][: num_token + 1],
paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
paged_kv_last_page_len_buffer=self.kv_last_page_len[:num_token],
)
)
seq_lens_sum = seq_lens.sum().item()
self.indices_updater_decode.update(
req_pool_indices,
seq_lens,
seq_lens_sum,
decode_wrappers=decode_wrappers,
encoder_lens=encoder_lens,
spec_info=spec_info,
)
seq_lens_sum = seq_lens.sum().item()
self.indices_updater_decode.update(
req_pool_indices,
seq_lens,
seq_lens_sum,
decode_wrappers=decode_wrappers,
encoder_lens=encoder_lens,
)
self.decode_cuda_graph_metadata[bs] = decode_wrappers
self.forward_metadata = DecodeMetadata(decode_wrappers)
self.decode_cuda_graph_metadata[bs] = decode_wrappers
self.forward_metadata = DecodeMetadata(decode_wrappers)
elif forward_mode.is_target_verify():
prefill_wrappers = []
for i in range(self.num_wrappers):
prefill_wrappers.append(
BatchPrefillWithPagedKVCacheWrapper(
self.workspace_buffer,
"NHD",
use_cuda_graph=True,
qo_indptr_buf=self.cuda_graph_qo_indptr[i][: bs + 1],
paged_kv_indptr_buf=self.kv_indptr[i][: bs + 1],
paged_kv_indices_buf=self.cuda_graph_kv_indices[i],
paged_kv_last_page_len_buf=self.kv_last_page_len[:bs],
custom_mask_buf=self.cuda_graph_custom_mask,
qk_indptr_buf=self.cuda_graph_qk_indptr[i][: bs + 1],
)
)
seq_lens_sum = seq_lens.sum().item()
self.indices_updater_prefill.update(
req_pool_indices,
seq_lens,
seq_lens_sum,
prefix_lens=None,
prefill_wrappers=prefill_wrappers,
use_ragged=False,
encoder_lens=encoder_lens,
spec_info=spec_info,
)
self.prefill_cuda_graph_metadata[bs] = prefill_wrappers
self.forward_metadata = PrefillMetadata(prefill_wrappers, False, False)
else:
raise ValueError(f"Invalid mode: {forward_mode=}")
def init_forward_metadata_replay_cuda_graph(
self,
......@@ -218,24 +297,41 @@ class FlashInferAttnBackend(AttentionBackend):
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: torch.Tensor = None,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
):
self.indices_updater_decode.update(
req_pool_indices[:bs],
seq_lens[:bs],
seq_lens_sum,
decode_wrappers=self.decode_cuda_graph_metadata[bs],
encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
)
if forward_mode.is_decode():
self.indices_updater_decode.update(
req_pool_indices[:bs],
seq_lens[:bs],
seq_lens_sum,
decode_wrappers=self.decode_cuda_graph_metadata[bs],
encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
spec_info=spec_info,
)
elif forward_mode.is_target_verify():
self.indices_updater_prefill.update(
req_pool_indices[:bs],
seq_lens[:bs],
seq_lens_sum,
prefix_lens=None,
prefill_wrappers=self.prefill_cuda_graph_metadata[bs],
use_ragged=False,
encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
spec_info=spec_info,
)
else:
raise ValueError("Invalid forward mode")
def get_cuda_graph_seq_len_fill_value(self):
return 0
def forward_extend(
self,
q,
k,
v,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
......@@ -293,9 +389,9 @@ class FlashInferAttnBackend(AttentionBackend):
def forward_decode(
self,
q,
k,
v,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
......@@ -348,7 +444,6 @@ class FlashInferIndicesUpdaterDecode:
self.data_type = model_runner.kv_cache_dtype
self.q_data_type = model_runner.dtype
self.sliding_window_size = model_runner.sliding_window_size
self.attn_backend = attn_backend
# Buffers and wrappers
......@@ -371,7 +466,8 @@ class FlashInferIndicesUpdaterDecode:
seq_lens: torch.Tensor,
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
):
# Keep the signature for type checking. It will be assigned during runtime.
raise NotImplementedError()
......@@ -382,7 +478,8 @@ class FlashInferIndicesUpdaterDecode:
seq_lens: torch.Tensor,
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
):
decode_wrappers = decode_wrappers or self.decode_wrappers
self.call_begin_forward(
......@@ -392,6 +489,7 @@ class FlashInferIndicesUpdaterDecode:
seq_lens_sum,
self.kv_indptr[0],
None,
spec_info,
)
def update_sliding_window(
......@@ -400,7 +498,8 @@ class FlashInferIndicesUpdaterDecode:
seq_lens: torch.Tensor,
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
):
for wrapper_id in range(2):
if wrapper_id == 0:
......@@ -424,6 +523,7 @@ class FlashInferIndicesUpdaterDecode:
paged_kernel_lens_sum_tmp,
self.kv_indptr[wrapper_id],
kv_start_idx_tmp,
spec_info,
)
def update_cross_attention(
......@@ -432,7 +532,8 @@ class FlashInferIndicesUpdaterDecode:
seq_lens: torch.Tensor,
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
):
for wrapper_id in range(2):
if wrapper_id == 0:
......@@ -452,6 +553,7 @@ class FlashInferIndicesUpdaterDecode:
seq_lens_sum,
self.kv_indptr[wrapper_id],
kv_start_idx,
spec_info,
)
def call_begin_forward(
......@@ -462,23 +564,30 @@ class FlashInferIndicesUpdaterDecode:
paged_kernel_lens_sum: int,
kv_indptr: torch.Tensor,
kv_start_idx: torch.Tensor,
spec_info: Optional[SpecInfo],
):
bs = len(req_pool_indices)
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.empty(
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
)
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
paged_kernel_lens,
kv_indptr,
kv_start_idx,
kv_indices,
self.req_to_token.shape[1],
)
if spec_info is None:
bs = len(req_pool_indices)
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.empty(
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
)
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
paged_kernel_lens,
kv_indptr,
kv_start_idx,
kv_indices,
self.req_to_token.shape[1],
)
else:
bs, kv_indices, kv_indptr = spec_info.generate_attn_arg_decode(
req_pool_indices,
paged_kernel_lens,
self.req_to_token,
)
wrapper.end_forward()
wrapper.begin_forward(
......@@ -507,7 +616,6 @@ class FlashInferIndicesUpdaterPrefill:
self.data_type = model_runner.kv_cache_dtype
self.q_data_type = model_runner.dtype
self.sliding_window_size = model_runner.sliding_window_size
self.attn_backend = attn_backend
# Buffers and wrappers
......@@ -534,7 +642,8 @@ class FlashInferIndicesUpdaterPrefill:
prefix_lens: torch.Tensor,
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
):
# Keep the signature for type checking. It will be assigned during runtime.
raise NotImplementedError()
......@@ -547,7 +656,8 @@ class FlashInferIndicesUpdaterPrefill:
prefix_lens: torch.Tensor,
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
):
if use_ragged:
paged_kernel_lens = prefix_lens
......@@ -568,6 +678,7 @@ class FlashInferIndicesUpdaterPrefill:
self.kv_indptr[0],
self.qo_indptr[0],
use_ragged,
spec_info,
)
def update_sliding_window(
......@@ -578,7 +689,8 @@ class FlashInferIndicesUpdaterPrefill:
prefix_lens: torch.Tensor,
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
):
for wrapper_id in range(2):
if wrapper_id == 0:
......@@ -607,6 +719,7 @@ class FlashInferIndicesUpdaterPrefill:
self.kv_indptr[wrapper_id],
self.qo_indptr[wrapper_id],
use_ragged,
spec_info,
)
def update_cross_attention(
......@@ -617,7 +730,8 @@ class FlashInferIndicesUpdaterPrefill:
prefix_lens: torch.Tensor,
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
):
for wrapper_id in range(2):
if wrapper_id == 0:
......@@ -643,6 +757,7 @@ class FlashInferIndicesUpdaterPrefill:
self.kv_indptr[wrapper_id],
self.qo_indptr[wrapper_id],
use_ragged,
spec_info,
)
def call_begin_forward(
......@@ -658,25 +773,37 @@ class FlashInferIndicesUpdaterPrefill:
kv_indptr: torch.Tensor,
qo_indptr: torch.Tensor,
use_ragged: bool,
spec_info: Optional[SpecInfo],
):
bs = len(req_pool_indices)
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.empty(
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
)
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
paged_kernel_lens,
kv_indptr,
kv_start_idx,
kv_indices,
self.req_to_token.shape[1],
)
if spec_info is None:
# Normal extend
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.empty(
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
)
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
paged_kernel_lens,
kv_indptr,
kv_start_idx,
kv_indices,
self.req_to_token.shape[1],
)
qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
qo_indptr = qo_indptr[: bs + 1]
qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
qo_indptr = qo_indptr[: bs + 1]
custom_mask = None
else:
kv_indices, kv_indptr, qo_indptr, custom_mask = (
spec_info.generate_attn_arg_prefill(
req_pool_indices,
paged_kernel_lens,
self.req_to_token,
)
)
# extend part
if use_ragged:
......@@ -702,6 +829,7 @@ class FlashInferIndicesUpdaterPrefill:
self.head_dim,
1,
q_data_type=self.q_data_type,
custom_mask=custom_mask,
)
......
from __future__ import annotations
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING
import torch
from torch.nn.functional import scaled_dot_product_attention
......@@ -23,43 +23,6 @@ class TorchNativeAttnBackend(AttentionBackend):
"""Init the metadata for a forward pass."""
pass
def init_cuda_graph_state(self, max_bs: int):
# TODO: Support CUDA graph
raise ValueError(
"Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
)
def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor] = None,
):
# TODO: Support CUDA graph
raise ValueError(
"Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
)
def init_forward_metadata_replay_cuda_graph(
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor] = None,
):
# TODO: Support CUDA graph
raise ValueError(
"Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
)
def get_cuda_graph_seq_len_fill_value(self):
# TODO: Support CUDA graph
raise ValueError(
"Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
)
def _run_sdpa_forward_extend(
self,
query: torch.Tensor,
......
from __future__ import annotations
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
import torch
from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.spec_info import SpecInfo
class TritonAttnBackend(AttentionBackend):
......@@ -80,11 +81,17 @@ class TritonAttnBackend(AttentionBackend):
def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
num_token: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens=None,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
):
# NOTE: encoder_lens expected to be zeros or None
assert encoder_lens is None, "Not supported"
assert forward_mode.is_decode(), "Not supported"
assert spec_info is None, "Not supported"
self.forward_metadata = (
self.cuda_graph_attn_logits,
None,
......@@ -96,7 +103,9 @@ class TritonAttnBackend(AttentionBackend):
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens=None,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
):
# NOTE: encoder_lens expected to be zeros or None
self.cuda_graph_start_loc.zero_()
......@@ -107,9 +116,9 @@ class TritonAttnBackend(AttentionBackend):
def forward_extend(
self,
q,
k,
v,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
......@@ -146,9 +155,9 @@ class TritonAttnBackend(AttentionBackend):
def forward_decode(
self,
q,
k,
v,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
......
......@@ -25,14 +25,14 @@ from vllm.distributed import get_tensor_model_parallel_rank
from vllm.distributed.parallel_state import graph_capture
from vllm.model_executor.custom_op import CustomOp
from sglang.srt.layers.logits_processor import (
LogitsMetadata,
LogitsProcessor,
LogitsProcessorOutput,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
from sglang.srt.layers.torchao_utils import save_gemlite_cache
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
ForwardBatch,
ForwardMode,
)
from sglang.srt.utils import maybe_torch_compile, monkey_patch_vllm_all_gather
if TYPE_CHECKING:
......@@ -153,6 +153,10 @@ class CudaGraphRunner:
if bs <= model_runner.req_to_token_pool.size
and bs <= model_runner.server_args.cuda_graph_max_bs
]
self.capture_forward_mode = ForwardMode.DECODE
self.num_tokens_per_bs = 1
self.compile_bs = (
[
bs
......@@ -165,8 +169,8 @@ class CudaGraphRunner:
# Attention backend
self.max_bs = max(self.capture_bs)
self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
self.max_num_token = self.max_bs * self.num_tokens_per_bs
self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
self.seq_len_fill_value = (
self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
)
......@@ -179,12 +183,13 @@ class CudaGraphRunner:
# Common inputs
with torch.device("cuda"):
self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32)
self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int32)
self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
self.seq_lens = torch.full(
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
)
self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)
self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int32)
self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int32)
if self.is_encoder_decoder:
......@@ -229,6 +234,9 @@ class CudaGraphRunner:
self.model_runner.model.capture_mode = False
def can_run(self, forward_batch: ForwardBatch):
if not forward_batch.forward_mode.is_cuda_graph():
return False
if self.enable_dp_attention:
min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max(
forward_batch.global_num_tokens
......@@ -258,12 +266,12 @@ class CudaGraphRunner:
def capture(self):
with graph_capture() as graph_capture_context:
self.stream = graph_capture_context.stream
capture_bs = (
capture_range = (
tqdm.tqdm(self.capture_bs)
if get_tensor_model_parallel_rank() == 0
else self.capture_bs
)
for bs in capture_bs:
for bs in capture_range:
with patch_model(
self.model_runner.model,
bs in self.compile_bs,
......@@ -283,12 +291,15 @@ class CudaGraphRunner:
def capture_one_batch_size(self, bs: int, forward: Callable):
graph = torch.cuda.CUDAGraph()
stream = self.stream
num_token = bs * self.num_tokens_per_bs
# Common inputs
input_ids = self.input_ids[:bs]
input_ids = self.input_ids[:num_token]
req_pool_indices = self.req_pool_indices[:bs]
seq_lens = self.seq_lens[:bs]
out_cache_loc = self.out_cache_loc[:bs]
out_cache_loc = self.out_cache_loc[:num_token]
positions = self.positions[:num_token]
if self.is_encoder_decoder:
encoder_lens = self.encoder_lens[:bs]
else:
......@@ -304,37 +315,41 @@ class CudaGraphRunner:
global_num_tokens = None
gathered_buffer = None
forward_batch = ForwardBatch(
forward_mode=self.capture_forward_mode,
batch_size=bs,
input_ids=input_ids,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
req_to_token_pool=self.model_runner.req_to_token_pool,
token_to_kv_pool=self.model_runner.token_to_kv_pool,
attn_backend=self.model_runner.attn_backend,
out_cache_loc=out_cache_loc,
seq_lens_sum=seq_lens_sum,
encoder_lens=encoder_lens,
return_logprob=False,
top_logprobs_nums=[0] * num_token,
positions=positions,
global_num_tokens=global_num_tokens,
mrope_positions=mrope_positions,
gathered_buffer=gathered_buffer,
)
# Attention backend
self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
bs,
num_token,
req_pool_indices,
seq_lens,
encoder_lens,
forward_batch.forward_mode,
forward_batch.spec_info,
)
# Run and capture
def run_once():
forward_batch = ForwardBatch(
forward_mode=ForwardMode.DECODE,
batch_size=bs,
input_ids=input_ids,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
req_to_token_pool=self.model_runner.req_to_token_pool,
token_to_kv_pool=self.model_runner.token_to_kv_pool,
attn_backend=self.model_runner.attn_backend,
out_cache_loc=out_cache_loc,
seq_lens_sum=seq_lens_sum,
encoder_lens=encoder_lens,
return_logprob=False,
top_logprobs_nums=[0] * bs,
positions=clamp_position(seq_lens),
mrope_positions=mrope_positions,
global_num_tokens=global_num_tokens,
gathered_buffer=gathered_buffer,
)
logits_output = forward(input_ids, forward_batch.positions, forward_batch)
return logits_output.next_token_logits
return logits_output.next_token_logits, logits_output.hidden_states
for _ in range(2):
torch.cuda.synchronize()
......@@ -360,6 +375,9 @@ class CudaGraphRunner:
def replay(self, forward_batch: ForwardBatch):
assert forward_batch.out_cache_loc is not None
raw_bs = forward_batch.batch_size
# In normal decoding case, raw_bs == raw_num_token
# But in speculative decoding, raw_num_token is raw_bs * self.num_tokens_per_bs
raw_num_token = forward_batch.input_ids.numel()
# Pad
if self.enable_dp_attention:
......@@ -374,10 +392,13 @@ class CudaGraphRunner:
self.out_cache_loc.zero_()
# Common inputs
self.input_ids[:raw_bs].copy_(forward_batch.input_ids)
self.input_ids[:raw_num_token].copy_(forward_batch.input_ids)
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc)
self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
positions = clamp_position(forward_batch.seq_lens)
self.positions[:raw_num_token].copy_(positions)
if self.is_encoder_decoder:
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
if forward_batch.mrope_positions is not None:
......@@ -390,13 +411,18 @@ class CudaGraphRunner:
self.seq_lens,
forward_batch.seq_lens_sum + (bs - raw_bs),
self.encoder_lens,
forward_batch.forward_mode,
forward_batch.spec_info,
)
# Replay
self.graphs[bs].replay()
next_token_logits = self.output_buffers[bs][:raw_bs]
next_token_logits, hidden_states = self.output_buffers[bs]
logits_output = LogitsProcessorOutput(
next_token_logits=next_token_logits,
next_token_logits=next_token_logits[:raw_num_token],
hidden_states=(
hidden_states[:raw_num_token] if hidden_states is not None else None
),
)
return logits_output
......@@ -96,7 +96,7 @@ class ForwardMode(IntEnum):
return self == ForwardMode.DRAFT_EXTEND
def is_cuda_graph(self):
return self in (ForwardMode.DECODE, ForwardMode.TARGET_VERIFY)
return self == ForwardMode.DECODE or self == ForwardMode.TARGET_VERIFY
def is_dummy_first(self):
return self == ForwardMode.DUMMY_FIRST
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment