Unverified commit 73d4a5f8 authored by Liangsheng Yin, committed by GitHub

Organize spec-related data structures (#10735)

parent 7fb551a7
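
This commit consolidates the speculative-decoding data structures, and the hunks below apply one mechanical pattern: `EagleDraftInput`/`EagleVerifyInput` move from `sglang.srt.speculative.eagle_utils` to `sglang.srt.speculative.eagle_info`, `generate_draft_decode_kv_indices` moves to `sglang.srt.speculative.spec_utils`, the `SpecInfo` annotation type is replaced by `SpecInput`, every `Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]]` annotation collapses to `Optional[SpecInput]`, and `isinstance` checks against concrete classes become `spec_info.is_draft_input()` calls.

A minimal sketch of the interface the backends now program against. The names `SpecInput`, `is_draft_input`, and `generate_attn_arg_prefill` are taken from the diff itself; the `SpecInputType` enum, `is_verify_input`, and the exact prefill signature are assumptions for illustration, not necessarily the committed code:

```python
from abc import ABC, abstractmethod
from enum import Enum, auto
from typing import Tuple

import torch


class SpecInputType(Enum):
    EAGLE_DRAFT = auto()   # inputs produced by the draft model
    EAGLE_VERIFY = auto()  # inputs consumed by the verify pass
    NGRAM_VERIFY = auto()  # n-gram based verification


class SpecInput(ABC):
    """Common base for speculative-decoding inputs, replacing the
    Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput] annotations."""

    def __init__(self, spec_input_type: SpecInputType):
        self.spec_input_type = spec_input_type

    def is_draft_input(self) -> bool:
        # Backends call this instead of isinstance(x, EagleDraftInput).
        return self.spec_input_type == SpecInputType.EAGLE_DRAFT

    def is_verify_input(self) -> bool:
        return not self.is_draft_input()

    @abstractmethod
    def generate_attn_arg_prefill(
        self,
        req_pool_indices: torch.Tensor,
        paged_kernel_lens: torch.Tensor,
        paged_kernel_lens_sum: int,
        req_to_token: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return (kv_indices, kv_indptr, qo_indptr, custom_mask) for prefill."""
```
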
......@@ -37,7 +37,7 @@ except ImportError:
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
-# Env var was set in sglang.srt.server_args.ServerArgs.__post__init__
+# Env var was set in sglang.srt.server_args.ServerArgs.__post_init__
DISABLE_DISK_CACHE = get_bool_env_var("SGLANG_DISABLE_OUTLINES_DISK_CACHE", "true")
logger = logging.getLogger(__name__)
......
......@@ -157,7 +157,7 @@ class ScheduleBatchDisaggregationDecodeMixin:
hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)
# local import to avoid circular import
-from sglang.srt.speculative.eagle_utils import EagleDraftInput
+from sglang.srt.speculative.eagle_info import EagleDraftInput
spec_info = EagleDraftInput(
topk_p=topk_p,
......
......@@ -4,18 +4,13 @@ from __future__ import annotations
end to end attention solution with aiter kernels
"""
import math
-import os
-from dataclasses import dataclass
-from enum import Enum, auto
-from functools import partial
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import TYPE_CHECKING, Optional
import torch
import triton
import triton.language as tl
-from sglang.global_config import global_config
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.dp_attention import (
......@@ -27,7 +22,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
-from sglang.srt.speculative.spec_info import SpecInfo
+from sglang.srt.speculative.spec_info import SpecInput
try:
from aiter import (
......@@ -374,7 +369,7 @@ class AiterAttnBackend(AttentionBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
-spec_info: Optional[SpecInfo],
+spec_info: Optional[SpecInput],
):
if forward_mode.is_decode_or_idle():
qo_indptr = None
......@@ -509,7 +504,7 @@ class AiterAttnBackend(AttentionBackend):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
-spec_info: Optional[SpecInfo],
+spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
):
if forward_mode.is_decode_or_idle():
......@@ -888,7 +883,7 @@ class AiterIndicesUpdaterPrefill:
seq_lens_sum: int,
prefix_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
-spec_info: Optional[SpecInfo],
+spec_info: Optional[SpecInput],
):
# Keep the signature for type checking. It will be assigned during runtime.
raise NotImplementedError()
......@@ -900,7 +895,7 @@ class AiterIndicesUpdaterPrefill:
seq_lens_sum: int,
prefix_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
-spec_info: Optional[SpecInfo],
+spec_info: Optional[SpecInput],
):
kv_start_idx = None
......@@ -984,7 +979,7 @@ class AiterMlaIndicesUpdaterPrefill:
extend_lens: torch.Tensor,
max_q_len: int,
max_kv_len: int,
-spec_info: Optional[SpecInfo],
+spec_info: Optional[SpecInput],
):
# Keep the signature for type checking. It will be assigned during runtime.
raise NotImplementedError()
......@@ -997,7 +992,7 @@ class AiterMlaIndicesUpdaterPrefill:
extend_lens: torch.Tensor,
max_q_len: int,
max_kv_len: int,
-spec_info: Optional[SpecInfo],
+spec_info: Optional[SpecInput],
):
bs = len(req_pool_indices)
......@@ -1054,7 +1049,7 @@ class AiterMultiStepDraftBackend:
topk: int,
speculative_num_steps: int,
):
-from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
self.topk = topk
self.speculative_num_steps = speculative_num_steps
......
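
Note the function-local import of `generate_draft_decode_kv_indices` inside `AiterMultiStepDraftBackend.__init__` above (the flashinfer and triton multi-step backends below do the same), presumably for the reason a comment earlier in this diff states: a module-level import would create a circular import between the attention backends and the speculative package. A hedged sketch of the pattern; `DraftBackendSketch` is a hypothetical stand-in, not a class from this commit:

```python
class DraftBackendSketch:
    def __init__(self, topk: int, speculative_num_steps: int):
        # Local import: deferring this until instantiation avoids a
        # module-level dependency cycle between the attention backends and
        # the speculative package (assumption based on the "local import to
        # avoid circular import" comment appearing earlier in this diff).
        from sglang.srt.speculative.spec_utils import (
            generate_draft_decode_kv_indices,
        )

        self.topk = topk
        self.speculative_num_steps = speculative_num_steps
        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
```
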
......@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, List, Optional
import torch
import torch_npu
from torch.nn.functional import scaled_dot_product_attention
-from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
......@@ -13,7 +12,8 @@ from sglang.srt.layers.attention.npu_ops.mla_preprocess import is_mla_preprocess
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.layers.radix_attention import AttentionType
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.spec_info import SpecInput
from sglang.srt.utils import get_bool_env_var
if TYPE_CHECKING:
......@@ -127,7 +127,7 @@ class AscendAttnBackend(AttentionBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+spec_info: Optional[SpecInput],
):
metadata = ForwardMetadata()
......@@ -147,7 +147,7 @@ class AscendAttnBackend(AttentionBackend):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
):
metadata = self.graph_metadata[bs]
......
......@@ -8,7 +8,7 @@ import torch
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+from sglang.srt.speculative.spec_info import SpecInput
class AttentionBackend(ABC):
......@@ -31,7 +31,7 @@ class AttentionBackend(ABC):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+spec_info: Optional[SpecInput],
):
"""Init the metadata for a forward pass for capturing a cuda graph."""
raise NotImplementedError()
......@@ -44,7 +44,7 @@ class AttentionBackend(ABC):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
):
"""Init the metadata for a forward pass for replaying a cuda graph."""
......
......@@ -20,7 +20,7 @@ from sglang.srt.utils import is_cuda
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
-from sglang.srt.speculative.spec_info import SpecInfo
+from sglang.srt.speculative.spec_info import SpecInput
_is_cuda = is_cuda()
if _is_cuda:
......@@ -151,7 +151,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
-spec_info: Optional[SpecInfo],
+spec_info: Optional[SpecInput],
):
if forward_mode.is_decode_or_idle():
if spec_info is None:
......@@ -190,7 +190,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
-spec_info: Optional[SpecInfo],
+spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
):
......
......@@ -11,9 +11,8 @@ import triton.language as tl
from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.mem_cache.memory_pool import SWAKVPool
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+from sglang.srt.speculative.spec_info import SpecInput
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
......@@ -1487,7 +1486,7 @@ class FlashAttentionBackend(AttentionBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+spec_info: Optional[SpecInput],
):
"""Initialize forward metadata for capturing CUDA graph."""
metadata = FlashAttentionMetadata()
......@@ -1722,7 +1721,7 @@ class FlashAttentionBackend(AttentionBackend):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
out_cache_loc: Optional[torch.Tensor] = None,
):
......@@ -2340,7 +2339,7 @@ class FlashAttentionMultiStepBackend:
forward_batch: ForwardBatch,
):
assert forward_batch.spec_info is not None
-assert isinstance(forward_batch.spec_info, EagleDraftInput)
+assert forward_batch.spec_info.is_draft_input()
for i in range(self.speculative_num_steps - 1):
self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
......@@ -2357,7 +2356,7 @@ class FlashAttentionMultiStepBackend:
self, forward_batch: ForwardBatch, bs: int
):
assert forward_batch.spec_info is not None
-assert isinstance(forward_batch.spec_info, EagleDraftInput)
+assert forward_batch.spec_info.is_draft_input()
for i in range(self.speculative_num_steps - 1):
# TODO: incrementally update the metadata for the later steps,
......
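
The two assert rewrites above are the recurring call-site change of this commit: backends stop importing `EagleDraftInput` just to type-check `forward_batch.spec_info` and instead ask the object itself. A small sketch of the resulting shape; the helper name `_assert_draft_input` is illustrative, not from the commit:

```python
from sglang.srt.speculative.spec_info import SpecInput  # new location per this diff


def _assert_draft_input(spec_info: SpecInput) -> None:
    # Replaces: assert isinstance(spec_info, EagleDraftInput)
    # The duck-typed check keeps the attention backends decoupled from the
    # concrete Eagle classes, which now live in eagle_info.
    assert spec_info is not None
    assert spec_info.is_draft_input()
```
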
......@@ -28,8 +28,8 @@ from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.layers.radix_attention import AttentionType
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.speculative.ngram_utils import NgramVerifyInput
+from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
+from sglang.srt.speculative.spec_info import SpecInput
from sglang.srt.utils import (
get_int_env_var,
is_flashinfer_available,
......@@ -344,7 +344,7 @@ class FlashInferAttnBackend(AttentionBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+spec_info: Optional[SpecInput],
):
if forward_mode.is_decode_or_idle():
decode_wrappers = []
......@@ -451,7 +451,7 @@ class FlashInferAttnBackend(AttentionBackend):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
):
if forward_mode.is_decode_or_idle():
......@@ -669,7 +669,7 @@ class FlashInferIndicesUpdaterDecode:
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+spec_info: Optional[SpecInput],
fixed_split_size: Optional[int] = None,
disable_split_kv: Optional[bool] = None,
):
......@@ -684,7 +684,7 @@ class FlashInferIndicesUpdaterDecode:
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+spec_info: Optional[SpecInput],
fixed_split_size: Optional[int] = None,
disable_split_kv: Optional[bool] = None,
):
......@@ -710,7 +710,7 @@ class FlashInferIndicesUpdaterDecode:
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+spec_info: Optional[SpecInput],
fixed_split_size: Optional[int] = None,
disable_split_kv: Optional[bool] = None,
):
......@@ -760,7 +760,7 @@ class FlashInferIndicesUpdaterDecode:
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+spec_info: Optional[SpecInput],
fixed_split_size: Optional[int] = None,
disable_split_kv: Optional[bool] = None,
):
......@@ -794,7 +794,7 @@ class FlashInferIndicesUpdaterDecode:
paged_kernel_lens_sum: int,
kv_indptr: torch.Tensor,
kv_start_idx: torch.Tensor,
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
use_sliding_window_kv_pool: bool = False,
fixed_split_size: Optional[int] = None,
......@@ -905,7 +905,7 @@ class FlashInferIndicesUpdaterPrefill:
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+spec_info: Optional[SpecInput],
fixed_split_size: Optional[int] = None,
):
# Keep the signature for type checking. It will be assigned during runtime.
......@@ -921,7 +921,7 @@ class FlashInferIndicesUpdaterPrefill:
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+spec_info: Optional[SpecInput],
fixed_split_size: Optional[int] = None,
):
if use_ragged:
......@@ -959,7 +959,7 @@ class FlashInferIndicesUpdaterPrefill:
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+spec_info: Optional[SpecInput],
fixed_split_size: Optional[int] = None,
):
for wrapper_id in range(2):
......@@ -1006,7 +1006,7 @@ class FlashInferIndicesUpdaterPrefill:
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+spec_info: Optional[SpecInput],
fixed_split_size: Optional[int] = None,
):
for wrapper_id in range(2):
......@@ -1049,7 +1049,7 @@ class FlashInferIndicesUpdaterPrefill:
kv_indptr: torch.Tensor,
qo_indptr: torch.Tensor,
use_ragged: bool,
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+spec_info: Optional[SpecInput],
use_sliding_window_kv_pool: bool = False,
fixed_split_size: Optional[int] = None,
):
......@@ -1077,9 +1077,7 @@ class FlashInferIndicesUpdaterPrefill:
qo_indptr = qo_indptr[: bs + 1]
custom_mask = None
else:
-assert isinstance(
-spec_info, (EagleDraftInput, EagleVerifyInput, NgramVerifyInput)
-)
+assert isinstance(spec_info, SpecInput)
kv_indices, kv_indptr, qo_indptr, custom_mask = (
spec_info.generate_attn_arg_prefill(
req_pool_indices,
......@@ -1138,7 +1136,7 @@ class FlashInferMultiStepDraftBackend:
topk: int,
speculative_num_steps: int,
):
-from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
self.topk = topk
self.speculative_num_steps = speculative_num_steps
......@@ -1202,7 +1200,7 @@ class FlashInferMultiStepDraftBackend:
)
assert forward_batch.spec_info is not None
-assert isinstance(forward_batch.spec_info, EagleDraftInput)
+assert forward_batch.spec_info.is_draft_input()
# Copy the kv_indptr once to avoid multiple device-to-host copies in flashinfer's plan.
indptr_cpu_whole = self.kv_indptr[:, : bs + 1].cpu()
......
......@@ -30,7 +30,7 @@ from sglang.srt.layers.attention.flashinfer_backend import (
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+from sglang.srt.speculative.spec_info import SpecInput
from sglang.srt.utils import (
is_flashinfer_available,
is_sm100_supported,
......@@ -40,7 +40,7 @@ from sglang.srt.utils import (
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
-from sglang.srt.speculative.spec_info import SpecInfo
+from sglang.srt.speculative.spec_info import SpecInput
if is_flashinfer_available():
from flashinfer import (
......@@ -361,7 +361,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
-spec_info: Optional[SpecInfo],
+spec_info: Optional[SpecInput],
):
if forward_mode.is_decode_or_idle():
decode_wrapper = BatchMLAPagedAttentionWrapper(
......@@ -441,7 +441,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
-spec_info: Optional[SpecInfo],
+spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
):
if forward_mode.is_decode_or_idle():
......@@ -663,7 +663,7 @@ class FlashInferMLAIndicesUpdaterDecode:
seq_lens_sum: int,
decode_wrapper: BatchMLAPagedAttentionWrapper,
init_metadata_replay: bool = False,
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
+spec_info: Optional[SpecInput] = None,
**fast_decode_kwargs,
):
decode_wrapper = decode_wrapper or self.decode_wrapper
......@@ -688,7 +688,7 @@ class FlashInferMLAIndicesUpdaterDecode:
q_indptr: torch.Tensor,
kv_indptr: torch.Tensor,
init_metadata_replay: bool = False,
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
+spec_info: Optional[SpecInput] = None,
**fast_decode_kwargs,
):
bs = len(req_pool_indices)
......@@ -776,7 +776,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
prefix_lens: torch.Tensor,
prefill_wrapper_paged: BatchMLAPagedAttentionWrapper,
use_ragged: bool,
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
+spec_info: Optional[SpecInput] = None,
):
if use_ragged:
paged_kernel_lens = prefix_lens
......@@ -811,7 +811,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
kv_indptr: torch.Tensor,
qo_indptr: torch.Tensor,
use_ragged: bool,
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
+spec_info: Optional[SpecInput] = None,
):
bs = len(seq_lens)
sm_scale = self.scaling
......@@ -838,9 +838,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
qo_indptr = qo_indptr[: bs + 1]
custom_mask = None
else:
-assert isinstance(spec_info, EagleDraftInput) or isinstance(
-spec_info, EagleVerifyInput
-)
+assert isinstance(spec_info, SpecInput)
# TODO: Support topk > 1 with custom mask
kv_indices, kv_indptr, qo_indptr, custom_mask = (
spec_info.generate_attn_arg_prefill(
......@@ -894,7 +892,7 @@ class FlashInferMLAMultiStepDraftBackend:
topk: int,
speculative_num_steps: int,
):
-from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
if topk > 1:
raise ValueError(
......@@ -963,7 +961,7 @@ class FlashInferMLAMultiStepDraftBackend:
)
assert forward_batch.spec_info is not None
-assert isinstance(forward_batch.spec_info, EagleDraftInput)
+assert forward_batch.spec_info.is_draft_input()
for i in range(self.speculative_num_steps - 1):
forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
......@@ -983,8 +981,6 @@ class FlashInferMLAMultiStepDraftBackend:
)
def call_fn(i, forward_batch):
-assert forward_batch.spec_info is not None
-assert isinstance(forward_batch.spec_info, EagleDraftInput)
forward_batch.spec_info.kv_indptr = (
forward_batch.spec_info.kv_indptr.clone()
)
......
......@@ -19,7 +19,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
-from sglang.srt.speculative.spec_info import SpecInfo
+from sglang.srt.speculative.spec_info import SpecInput
# FlashMLA only supports pagesize=64
......@@ -187,7 +187,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
-spec_info: Optional[SpecInfo],
+spec_info: Optional[SpecInput],
):
if forward_mode.is_decode_or_idle():
max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
......@@ -257,7 +257,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
-spec_info: Optional[SpecInfo],
+spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
):
......
......@@ -6,7 +6,7 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_executor.model_runner import ModelRunner
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+from sglang.srt.speculative.spec_info import SpecInput
class HybridAttnBackend(AttentionBackend):
......@@ -71,7 +71,7 @@ class HybridAttnBackend(AttentionBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+spec_info: Optional[SpecInput],
):
backend = self._select_backend(forward_mode)
backend.init_forward_metadata_capture_cuda_graph(
......@@ -92,7 +92,7 @@ class HybridAttnBackend(AttentionBackend):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
):
backend = self._select_backend(forward_mode)
......
......@@ -21,8 +21,8 @@ from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_executor.model_runner import ModelRunner
-from sglang.srt.models.qwen3_next import Qwen3HybridLinearDecoderLayer, fused_gdn_gating
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+from sglang.srt.models.qwen3_next import fused_gdn_gating
+from sglang.srt.speculative.spec_info import SpecInput
from sglang.srt.utils import is_cuda, is_npu
if is_cuda():
......@@ -134,7 +134,7 @@ class MambaAttnBackend(AttentionBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+spec_info: Optional[SpecInput],
):
if forward_mode.is_decode_or_idle():
self.query_start_loc_list[bs - 1].copy_(
......@@ -161,7 +161,7 @@ class MambaAttnBackend(AttentionBackend):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
):
num_padding = torch.count_nonzero(
......@@ -451,7 +451,7 @@ class HybridLinearAttnBackend(AttentionBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+spec_info: Optional[SpecInput],
):
for attn_backend in self.attn_backend_list:
attn_backend.init_forward_metadata_capture_cuda_graph(
......@@ -472,7 +472,7 @@ class HybridLinearAttnBackend(AttentionBackend):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
):
for attn_backend in self.attn_backend_list:
......
-from typing import TYPE_CHECKING, Callable, List, Optional, Union
+from typing import TYPE_CHECKING, Callable, List, Optional
import torch
from sglang.srt import two_batch_overlap
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+from sglang.srt.speculative.spec_info import SpecInput
if TYPE_CHECKING:
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
......@@ -46,7 +46,7 @@ class TboAttnBackend(AttentionBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: "ForwardMode",
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+spec_info: Optional[SpecInput],
):
self.primary.init_forward_metadata_capture_cuda_graph(
bs=bs,
......@@ -77,7 +77,7 @@ class TboAttnBackend(AttentionBackend):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: "ForwardMode",
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
):
self.primary.init_forward_metadata_replay_cuda_graph(
......@@ -112,7 +112,7 @@ class TboAttnBackend(AttentionBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: "ForwardMode",
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+spec_info: Optional[SpecInput],
# capture args
capture_num_tokens: int = None,
# replay args
......@@ -196,7 +196,7 @@ def _init_forward_metadata_cuda_graph_split(
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: "ForwardMode",
-spec_info: Optional[EagleVerifyInput],
+spec_info: Optional[SpecInput],
# capture args
capture_num_tokens: int = None,
# replay args
......
......@@ -22,7 +22,7 @@ from sglang.srt.utils import (
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+from sglang.srt.speculative.spec_info import SpecInput
def logit_capping_mod(logit_capping_method, logit_cap):
......@@ -482,7 +482,7 @@ class TritonAttnBackend(AttentionBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+spec_info: Optional[SpecInput],
):
assert encoder_lens is None, "Not supported"
window_kv_indptr = self.window_kv_indptr
......@@ -638,7 +638,7 @@ class TritonAttnBackend(AttentionBackend):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
):
# NOTE: encoder_lens expected to be zeros or None
......@@ -883,7 +883,7 @@ class TritonMultiStepDraftBackend:
topk: int,
speculative_num_steps: int,
):
-from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
self.topk = topk
self.speculative_num_steps = speculative_num_steps
......
......@@ -20,12 +20,10 @@ from sglang.srt.utils import is_flashinfer_available
if is_flashinfer_available():
import flashinfer
-from sglang.srt.speculative.eagle_utils import EagleDraftInput
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
-from sglang.srt.speculative.spec_info import SpecInfo
+from sglang.srt.speculative.spec_info import SpecInput
# Constants
DEFAULT_WORKSPACE_SIZE_MB = (
......@@ -201,7 +199,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
-spec_info: Optional[SpecInfo],
+spec_info: Optional[SpecInput],
):
"""Initialize metadata for CUDA graph capture."""
metadata = TRTLLMMHAMetadata()
......@@ -314,7 +312,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
-spec_info: Optional[SpecInfo],
+spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
):
"""Replay CUDA graph with new inputs."""
......@@ -661,7 +659,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
forward_batch: ForwardBatch,
):
assert forward_batch.spec_info is not None
-assert isinstance(forward_batch.spec_info, EagleDraftInput)
+assert forward_batch.spec_info.is_draft_input()
for i in range(self.speculative_num_steps - 1):
self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
......@@ -678,7 +676,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
self, forward_batch: ForwardBatch, bs: int
):
assert forward_batch.spec_info is not None
-assert isinstance(forward_batch.spec_info, EagleDraftInput)
+assert forward_batch.spec_info.is_draft_input()
for i in range(self.speculative_num_steps - 1):
......
......@@ -30,7 +30,7 @@ if is_flashinfer_available():
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
-from sglang.srt.speculative.spec_info import SpecInfo
+from sglang.srt.speculative.spec_info import SpecInput
_is_cuda = is_cuda()
......@@ -214,7 +214,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
-spec_info: Optional[SpecInfo],
+spec_info: Optional[SpecInput],
):
"""Initialize metadata for CUDA graph capture."""
......@@ -270,7 +270,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
-spec_info: Optional[SpecInfo],
+spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
):
"""Replay CUDA graph with new inputs."""
......
......@@ -2,7 +2,7 @@ from __future__ import annotations
import logging
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, Optional
import torch
import triton
......@@ -17,7 +17,7 @@ from sglang.srt.utils import get_bool_env_var, get_device_core_count
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+from sglang.srt.speculative.spec_info import SpecInput
logger = logging.getLogger(__name__)
......@@ -393,7 +393,7 @@ class WaveAttnBackend(AttentionBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+spec_info: Optional[SpecInput],
):
assert encoder_lens is None, "Not supported"
......@@ -477,7 +477,7 @@ class WaveAttnBackend(AttentionBackend):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
):
# NOTE: encoder_lens expected to be zeros or None
......
......@@ -11,12 +11,8 @@ from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_moe_tensor_parallel_rank,
get_moe_tensor_parallel_world_size,
-get_tp_group,
tensor_model_parallel_all_reduce,
)
-from sglang.srt.distributed.device_communicators.pynccl_allocator import (
-use_symmetric_memory,
-)
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
from sglang.srt.layers.moe import (
MoeRunnerConfig,
......@@ -24,7 +20,6 @@ from sglang.srt.layers.moe import (
should_use_flashinfer_trtllm_moe,
)
from sglang.srt.layers.moe.token_dispatcher.standard import (
-CombineInput,
StandardDispatcher,
StandardDispatchOutput,
)
......
......@@ -73,9 +73,7 @@ from sglang.srt.utils import flatten_nested_list, support_triton
if TYPE_CHECKING:
from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.speculative.ngram_utils import NgramVerifyInput
-from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
......@@ -957,9 +955,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# Speculative decoding
spec_algorithm: SpeculativeAlgorithm = None
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]] = (
-None
-)
+# spec_info: Optional[SpecInput] = None
+spec_info: Optional[SpecInput] = None
# Whether to return hidden states
return_hidden_states: bool = False
......@@ -1995,9 +1992,9 @@ class ModelWorkerBatch:
# Speculative decoding
spec_algorithm: SpeculativeAlgorithm = None
-spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput, NgramVerifyInput]] = (
-None
-)
+spec_info: Optional[SpecInput] = None
# If set, the output of the batch contains the hidden states of the run.
capture_hidden_mode: CaptureHiddenMode = None
hicache_consumer_index: int = -1
......
......@@ -607,7 +607,7 @@ class CPUGraphRunner:
def get_spec_info(self, num_tokens: int):
spec_info = None
if self.model_runner.spec_algorithm.is_eagle():
-from sglang.srt.speculative.eagle_utils import EagleVerifyInput
+from sglang.srt.speculative.eagle_info import EagleVerifyInput
if self.model_runner.is_draft_worker:
raise RuntimeError("This should not happen.")
......