Unverified commit 73d4a5f8, authored by Liangsheng Yin, committed by GitHub

Organize spec-related data structures (#10735)

parent 7fb551a7
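For context: this commit replaces the ad-hoc `Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]` annotations (and the old `SpecInfo` alias) scattered across the attention backends with a single `SpecInput` interface, and splits the former `eagle_utils` module into `eagle_info` (concrete input classes) and `spec_utils` (shared kernels). The sketch below is a hypothetical illustration of the shape such a base class could take; the actual definition in `sglang/srt/speculative/spec_info.py` is not shown in this diff, and the `generate_attn_arg_prefill` parameters past `req_pool_indices` are guesses.

```python
# Hypothetical sketch only: the real SpecInput in
# sglang/srt/speculative/spec_info.py is not part of this diff.
from abc import ABC, abstractmethod
from enum import Enum, auto


class SpecInputType(Enum):
    EAGLE_DRAFT = auto()   # inputs produced by the EAGLE draft model
    EAGLE_VERIFY = auto()  # inputs consumed by the EAGLE verify pass
    NGRAM_VERIFY = auto()  # inputs for n-gram based verification


class SpecInput(ABC):
    def __init__(self, spec_input_type: SpecInputType):
        self.spec_input_type = spec_input_type

    def is_draft_input(self) -> bool:
        # Replaces isinstance(spec_info, EagleDraftInput) checks in the
        # backends below.
        return self.spec_input_type == SpecInputType.EAGLE_DRAFT

    def is_verify_input(self) -> bool:
        return self.spec_input_type in (
            SpecInputType.EAGLE_VERIFY,
            SpecInputType.NGRAM_VERIFY,
        )

    @abstractmethod
    def generate_attn_arg_prefill(self, req_pool_indices, *args):
        # Each concrete input builds its own
        # (kv_indices, kv_indptr, qo_indptr, custom_mask) tuple.
        ...
```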
@@ -37,7 +37,7 @@ except ImportError:
 IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
-# Env var was set in sglang.srt.server_args.ServerArgs.__post__init__
+# Env var was set in sglang.srt.server_args.ServerArgs.__post_init__
 DISABLE_DISK_CACHE = get_bool_env_var("SGLANG_DISABLE_OUTLINES_DISK_CACHE", "true")
 logger = logging.getLogger(__name__)
...
@@ -157,7 +157,7 @@ class ScheduleBatchDisaggregationDecodeMixin:
         hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)
         # local import to avoid circular import
-        from sglang.srt.speculative.eagle_utils import EagleDraftInput
+        from sglang.srt.speculative.eagle_info import EagleDraftInput
         spec_info = EagleDraftInput(
             topk_p=topk_p,
...
@@ -4,18 +4,13 @@ from __future__ import annotations
 end to end attention solution with aiter kernels
 """
-import math
-import os
 from dataclasses import dataclass
 from enum import Enum, auto
-from functools import partial
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import TYPE_CHECKING, Optional
 import torch
 import triton
-import triton.language as tl
-from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import (
@@ -27,7 +22,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import SpecInfo
+    from sglang.srt.speculative.spec_info import SpecInput
 try:
     from aiter import (
@@ -374,7 +369,7 @@ class AiterAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):
         if forward_mode.is_decode_or_idle():
             qo_indptr = None
@@ -509,7 +504,7 @@ class AiterAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         if forward_mode.is_decode_or_idle():
@@ -888,7 +883,7 @@ class AiterIndicesUpdaterPrefill:
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -900,7 +895,7 @@ class AiterIndicesUpdaterPrefill:
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):
         kv_start_idx = None
@@ -984,7 +979,7 @@ class AiterMlaIndicesUpdaterPrefill:
         extend_lens: torch.Tensor,
         max_q_len: int,
         max_kv_len: int,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -997,7 +992,7 @@ class AiterMlaIndicesUpdaterPrefill:
         extend_lens: torch.Tensor,
         max_q_len: int,
         max_kv_len: int,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):
         bs = len(req_pool_indices)
@@ -1054,7 +1049,7 @@ class AiterMultiStepDraftBackend:
         topk: int,
         speculative_num_steps: int,
     ):
-        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+        from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
...
@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, List, Optional
 import torch
 import torch_npu
-from torch.nn.functional import scaled_dot_product_attention
 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -13,7 +12,8 @@ from sglang.srt.layers.attention.npu_ops.mla_preprocess import is_mla_preprocess
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import get_bool_env_var
 if TYPE_CHECKING:
@@ -127,7 +127,7 @@ class AscendAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
     ):
         metadata = ForwardMetadata()
@@ -147,7 +147,7 @@ class AscendAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         metadata = self.graph_metadata[bs]
...
@@ -8,7 +8,7 @@ import torch
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+    from sglang.srt.speculative.spec_info import SpecInput
 class AttentionBackend(ABC):
@@ -31,7 +31,7 @@ class AttentionBackend(ABC):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
    ):
         """Init the metadata for a forward pass for capturing a cuda graph."""
         raise NotImplementedError()
@@ -44,7 +44,7 @@ class AttentionBackend(ABC):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         """Init the metadata for a forward pass for replaying a cuda graph."""
...
@@ -20,7 +20,7 @@ from sglang.srt.utils import is_cuda
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import SpecInfo
+    from sglang.srt.speculative.spec_info import SpecInput
 _is_cuda = is_cuda()
 if _is_cuda:
@@ -151,7 +151,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):
         if forward_mode.is_decode_or_idle():
             if spec_info is None:
@@ -190,7 +190,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
...
@@ -11,9 +11,8 @@ import triton.language as tl
 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.mem_cache.memory_pool import SWAKVPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+from sglang.srt.speculative.spec_info import SpecInput
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -1487,7 +1486,7 @@ class FlashAttentionBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
     ):
         """Initialize forward metadata for capturing CUDA graph."""
         metadata = FlashAttentionMetadata()
@@ -1722,7 +1721,7 @@ class FlashAttentionBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
         out_cache_loc: Optional[torch.Tensor] = None,
     ):
@@ -2340,7 +2339,7 @@ class FlashAttentionMultiStepBackend:
         forward_batch: ForwardBatch,
     ):
         assert forward_batch.spec_info is not None
-        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+        assert forward_batch.spec_info.is_draft_input()
         for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
@@ -2357,7 +2356,7 @@ class FlashAttentionMultiStepBackend:
         self, forward_batch: ForwardBatch, bs: int
     ):
         assert forward_batch.spec_info is not None
-        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+        assert forward_batch.spec_info.is_draft_input()
         for i in range(self.speculative_num_steps - 1):
             # TODO: incrementally update the metadata for the later steps,
...
@@ -28,8 +28,8 @@ from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.speculative.ngram_utils import NgramVerifyInput
+from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
+from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import (
     get_int_env_var,
     is_flashinfer_available,
@@ -344,7 +344,7 @@ class FlashInferAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[SpecInput],
     ):
         if forward_mode.is_decode_or_idle():
             decode_wrappers = []
@@ -451,7 +451,7 @@ class FlashInferAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         if forward_mode.is_decode_or_idle():
@@ -669,7 +669,7 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[SpecInput],
         fixed_split_size: Optional[int] = None,
         disable_split_kv: Optional[bool] = None,
     ):
@@ -684,7 +684,7 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[SpecInput],
         fixed_split_size: Optional[int] = None,
         disable_split_kv: Optional[bool] = None,
     ):
@@ -710,7 +710,7 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[SpecInput],
         fixed_split_size: Optional[int] = None,
         disable_split_kv: Optional[bool] = None,
     ):
@@ -760,7 +760,7 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[SpecInput],
         fixed_split_size: Optional[int] = None,
         disable_split_kv: Optional[bool] = None,
     ):
@@ -794,7 +794,7 @@ class FlashInferIndicesUpdaterDecode:
         paged_kernel_lens_sum: int,
         kv_indptr: torch.Tensor,
         kv_start_idx: torch.Tensor,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
         use_sliding_window_kv_pool: bool = False,
         fixed_split_size: Optional[int] = None,
@@ -905,7 +905,7 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[SpecInput],
         fixed_split_size: Optional[int] = None,
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
@@ -921,7 +921,7 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[SpecInput],
         fixed_split_size: Optional[int] = None,
     ):
         if use_ragged:
@@ -959,7 +959,7 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[SpecInput],
         fixed_split_size: Optional[int] = None,
     ):
         for wrapper_id in range(2):
@@ -1006,7 +1006,7 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[SpecInput],
         fixed_split_size: Optional[int] = None,
     ):
         for wrapper_id in range(2):
@@ -1049,7 +1049,7 @@ class FlashInferIndicesUpdaterPrefill:
         kv_indptr: torch.Tensor,
         qo_indptr: torch.Tensor,
         use_ragged: bool,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[SpecInput],
         use_sliding_window_kv_pool: bool = False,
         fixed_split_size: Optional[int] = None,
     ):
@@ -1077,9 +1077,7 @@ class FlashInferIndicesUpdaterPrefill:
             qo_indptr = qo_indptr[: bs + 1]
             custom_mask = None
         else:
-            assert isinstance(
-                spec_info, (EagleDraftInput, EagleVerifyInput, NgramVerifyInput)
-            )
+            assert isinstance(spec_info, SpecInput)
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
                     req_pool_indices,
@@ -1138,7 +1136,7 @@ class FlashInferMultiStepDraftBackend:
         topk: int,
         speculative_num_steps: int,
     ):
-        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+        from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
@@ -1202,7 +1200,7 @@ class FlashInferMultiStepDraftBackend:
         )
         assert forward_batch.spec_info is not None
-        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+        assert forward_batch.spec_info.is_draft_input()
         # Copy the kv_indptr once to avoid multiple device-to-host copies in flashinfer's plan.
         indptr_cpu_whole = self.kv_indptr[:, : bs + 1].cpu()
...
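With the Unions gone, the prefill indices updaters above only need the SpecInput contract: every concrete input supplies its own attention arguments through generate_attn_arg_prefill. A hedged sketch of the resulting call shape (argument names after req_pool_indices are inferred from context, not confirmed by this diff):

```python
from typing import Optional

from sglang.srt.speculative.spec_info import SpecInput


def build_prefill_args(spec_info: Optional[SpecInput], req_pool_indices,
                       paged_kernel_lens, paged_kernel_lens_sum, req_to_token):
    # Any SpecInput subclass (EAGLE draft/verify, n-gram verify) can answer;
    # no per-class isinstance dispatch is needed anymore.
    assert isinstance(spec_info, SpecInput)
    kv_indices, kv_indptr, qo_indptr, custom_mask = (
        spec_info.generate_attn_arg_prefill(
            req_pool_indices,
            paged_kernel_lens,
            paged_kernel_lens_sum,
            req_to_token,
        )
    )
    return kv_indices, kv_indptr, qo_indptr, custom_mask
```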
@@ -30,7 +30,7 @@ from sglang.srt.layers.attention.flashinfer_backend import (
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import (
     is_flashinfer_available,
     is_sm100_supported,
@@ -40,7 +40,7 @@ from sglang.srt.utils import (
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import SpecInfo
+    from sglang.srt.speculative.spec_info import SpecInput
 if is_flashinfer_available():
     from flashinfer import (
@@ -361,7 +361,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):
         if forward_mode.is_decode_or_idle():
             decode_wrapper = BatchMLAPagedAttentionWrapper(
@@ -441,7 +441,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         if forward_mode.is_decode_or_idle():
@@ -663,7 +663,7 @@ class FlashInferMLAIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrapper: BatchMLAPagedAttentionWrapper,
         init_metadata_replay: bool = False,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
+        spec_info: Optional[SpecInput] = None,
         **fast_decode_kwargs,
     ):
         decode_wrapper = decode_wrapper or self.decode_wrapper
@@ -688,7 +688,7 @@ class FlashInferMLAIndicesUpdaterDecode:
         q_indptr: torch.Tensor,
         kv_indptr: torch.Tensor,
         init_metadata_replay: bool = False,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
+        spec_info: Optional[SpecInput] = None,
         **fast_decode_kwargs,
     ):
         bs = len(req_pool_indices)
@@ -776,7 +776,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
         prefix_lens: torch.Tensor,
         prefill_wrapper_paged: BatchMLAPagedAttentionWrapper,
         use_ragged: bool,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
+        spec_info: Optional[SpecInput] = None,
     ):
         if use_ragged:
             paged_kernel_lens = prefix_lens
@@ -811,7 +811,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
         kv_indptr: torch.Tensor,
         qo_indptr: torch.Tensor,
         use_ragged: bool,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
+        spec_info: Optional[SpecInput] = None,
     ):
         bs = len(seq_lens)
         sm_scale = self.scaling
@@ -838,9 +838,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
             qo_indptr = qo_indptr[: bs + 1]
             custom_mask = None
         else:
-            assert isinstance(spec_info, EagleDraftInput) or isinstance(
-                spec_info, EagleVerifyInput
-            )
+            assert isinstance(spec_info, SpecInput)
             # TODO: Support topk > 1 with custom mask
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
@@ -894,7 +892,7 @@ class FlashInferMLAMultiStepDraftBackend:
         topk: int,
         speculative_num_steps: int,
     ):
-        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+        from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
         if topk > 1:
             raise ValueError(
@@ -963,7 +961,7 @@ class FlashInferMLAMultiStepDraftBackend:
         )
         assert forward_batch.spec_info is not None
-        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+        assert forward_batch.spec_info.is_draft_input()
         for i in range(self.speculative_num_steps - 1):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
@@ -983,8 +981,6 @@ class FlashInferMLAMultiStepDraftBackend:
         )
         def call_fn(i, forward_batch):
-            assert forward_batch.spec_info is not None
-            assert isinstance(forward_batch.spec_info, EagleDraftInput)
             forward_batch.spec_info.kv_indptr = (
                 forward_batch.spec_info.kv_indptr.clone()
             )
...
@@ -19,7 +19,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import SpecInfo
+    from sglang.srt.speculative.spec_info import SpecInput
 # FlashMLA only supports pagesize=64
@@ -187,7 +187,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):
         if forward_mode.is_decode_or_idle():
             max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
@@ -257,7 +257,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
...
@@ -6,7 +6,7 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+from sglang.srt.speculative.spec_info import SpecInput
 class HybridAttnBackend(AttentionBackend):
@@ -71,7 +71,7 @@ class HybridAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
     ):
         backend = self._select_backend(forward_mode)
         backend.init_forward_metadata_capture_cuda_graph(
@@ -92,7 +92,7 @@ class HybridAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         backend = self._select_backend(forward_mode)
...
@@ -21,8 +21,8 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
-from sglang.srt.models.qwen3_next import Qwen3HybridLinearDecoderLayer, fused_gdn_gating
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+from sglang.srt.models.qwen3_next import fused_gdn_gating
+from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import is_cuda, is_npu
 if is_cuda():
@@ -134,7 +134,7 @@ class MambaAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
     ):
         if forward_mode.is_decode_or_idle():
             self.query_start_loc_list[bs - 1].copy_(
@@ -161,7 +161,7 @@ class MambaAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         num_padding = torch.count_nonzero(
@@ -451,7 +451,7 @@ class HybridLinearAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
     ):
         for attn_backend in self.attn_backend_list:
             attn_backend.init_forward_metadata_capture_cuda_graph(
@@ -472,7 +472,7 @@ class HybridLinearAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         for attn_backend in self.attn_backend_list:
...
-from typing import TYPE_CHECKING, Callable, List, Optional, Union
+from typing import TYPE_CHECKING, Callable, List, Optional
 import torch
 from sglang.srt import two_batch_overlap
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+from sglang.srt.speculative.spec_info import SpecInput
 if TYPE_CHECKING:
     from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
@@ -46,7 +46,7 @@ class TboAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: "ForwardMode",
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
     ):
         self.primary.init_forward_metadata_capture_cuda_graph(
             bs=bs,
@@ -77,7 +77,7 @@ class TboAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: "ForwardMode",
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         self.primary.init_forward_metadata_replay_cuda_graph(
@@ -112,7 +112,7 @@ class TboAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: "ForwardMode",
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         # capture args
         capture_num_tokens: int = None,
         # replay args
@@ -196,7 +196,7 @@ def _init_forward_metadata_cuda_graph_split(
     seq_lens: torch.Tensor,
     encoder_lens: Optional[torch.Tensor],
     forward_mode: "ForwardMode",
-    spec_info: Optional[EagleVerifyInput],
+    spec_info: Optional[SpecInput],
     # capture args
     capture_num_tokens: int = None,
     # replay args
...
@@ -22,7 +22,7 @@ from sglang.srt.utils import (
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+    from sglang.srt.speculative.spec_info import SpecInput
 def logit_capping_mod(logit_capping_method, logit_cap):
@@ -482,7 +482,7 @@ class TritonAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
     ):
         assert encoder_lens is None, "Not supported"
         window_kv_indptr = self.window_kv_indptr
@@ -638,7 +638,7 @@ class TritonAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         # NOTE: encoder_lens expected to be zeros or None
@@ -883,7 +883,7 @@ class TritonMultiStepDraftBackend:
         topk: int,
         speculative_num_steps: int,
     ):
-        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+        from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
...
@@ -20,12 +20,10 @@ from sglang.srt.utils import is_flashinfer_available
 if is_flashinfer_available():
     import flashinfer
-from sglang.srt.speculative.eagle_utils import EagleDraftInput
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import SpecInfo
+    from sglang.srt.speculative.spec_info import SpecInput
 # Constants
 DEFAULT_WORKSPACE_SIZE_MB = (
@@ -201,7 +199,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):
         """Initialize metadata for CUDA graph capture."""
         metadata = TRTLLMMHAMetadata()
@@ -314,7 +312,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         """Replay CUDA graph with new inputs."""
@@ -661,7 +659,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
         forward_batch: ForwardBatch,
     ):
         assert forward_batch.spec_info is not None
-        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+        assert forward_batch.spec_info.is_draft_input()
         for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
@@ -678,7 +676,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
         self, forward_batch: ForwardBatch, bs: int
     ):
         assert forward_batch.spec_info is not None
-        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+        assert forward_batch.spec_info.is_draft_input()
         for i in range(self.speculative_num_steps - 1):
...
@@ -30,7 +30,7 @@ if is_flashinfer_available():
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import SpecInfo
+    from sglang.srt.speculative.spec_info import SpecInput
 _is_cuda = is_cuda()
@@ -214,7 +214,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):
         """Initialize metadata for CUDA graph capture."""
@@ -270,7 +270,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         """Replay CUDA graph with new inputs."""
...
@@ -2,7 +2,7 @@ from __future__ import annotations
 import logging
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, Optional
 import torch
 import triton
@@ -17,7 +17,7 @@ from sglang.srt.utils import get_bool_env_var, get_device_core_count
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+    from sglang.srt.speculative.spec_info import SpecInput
 logger = logging.getLogger(__name__)
@@ -393,7 +393,7 @@ class WaveAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
     ):
         assert encoder_lens is None, "Not supported"
@@ -477,7 +477,7 @@ class WaveAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         # NOTE: encoder_lens expected to be zeros or None
...
@@ -11,12 +11,8 @@ from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
     get_moe_tensor_parallel_rank,
     get_moe_tensor_parallel_world_size,
-    get_tp_group,
     tensor_model_parallel_all_reduce,
 )
-from sglang.srt.distributed.device_communicators.pynccl_allocator import (
-    use_symmetric_memory,
-)
 from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
 from sglang.srt.layers.moe import (
     MoeRunnerConfig,
@@ -24,7 +20,6 @@ from sglang.srt.layers.moe import (
     should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.token_dispatcher.standard import (
-    CombineInput,
     StandardDispatcher,
     StandardDispatchOutput,
 )
...
@@ -73,9 +73,7 @@ from sglang.srt.utils import flatten_nested_list, support_triton
 if TYPE_CHECKING:
     from sglang.srt.configs.model_config import ModelConfig
-    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-    from sglang.srt.speculative.ngram_utils import NgramVerifyInput
-    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+    from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
@@ -957,9 +955,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # Speculative decoding
     spec_algorithm: SpeculativeAlgorithm = None
-    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]] = (
-        None
-    )
+    # spec_info: Optional[SpecInput] = None
+    spec_info: Optional[SpecInput] = None
     # Whether to return hidden states
     return_hidden_states: bool = False
@@ -1995,9 +1992,9 @@ class ModelWorkerBatch:
     # Speculative decoding
     spec_algorithm: SpeculativeAlgorithm = None
-    spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput, NgramVerifyInput]] = (
-        None
-    )
+    spec_info: Optional[SpecInput] = None
     # If set, the output of the batch contains the hidden states of the run.
     capture_hidden_mode: CaptureHiddenMode = None
     hicache_consumer_index: int = -1
...
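Taken together, the import moves in this commit follow one pattern; a quick reference for downstream code, inferred from the hunks above (not exhaustive):

```python
# Concrete EAGLE input classes moved from eagle_utils to eagle_info:
from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput

# Shared draft-decode helpers moved from eagle_utils to spec_utils:
from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices

# The common interface, replacing SpecInfo and the
# EagleDraftInput/EagleVerifyInput/NgramVerifyInput Unions:
from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm
```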
@@ -607,7 +607,7 @@ class CPUGraphRunner:
     def get_spec_info(self, num_tokens: int):
         spec_info = None
         if self.model_runner.spec_algorithm.is_eagle():
-            from sglang.srt.speculative.eagle_utils import EagleVerifyInput
+            from sglang.srt.speculative.eagle_info import EagleVerifyInput
             if self.model_runner.is_draft_worker:
                 raise RuntimeError("This should not happen.")
...