Unverified commit 73d4a5f8, authored by Liangsheng Yin and committed by GitHub

Organize spec-related data structures (#10735)

parent 7fb551a7
@@ -821,7 +821,7 @@ class CudaGraphRunner:
             self.model_runner.spec_algorithm.is_eagle()
             or self.model_runner.spec_algorithm.is_standalone()
         ):
-            from sglang.srt.speculative.eagle_utils import EagleVerifyInput
+            from sglang.srt.speculative.eagle_info import EagleVerifyInput

             if self.model_runner.is_draft_worker:
                 raise RuntimeError("This should not happen.")
...
@@ -45,13 +45,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_size,
     set_dp_buffer_len,
 )
-from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import (
-    flatten_nested_list,
-    get_compiler_backend,
-    is_npu,
-    support_triton,
-)
+from sglang.srt.utils import get_compiler_backend, is_npu, support_triton

 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -60,8 +54,7 @@ if TYPE_CHECKING:
     from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+    from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm

 _is_npu = is_npu()
@@ -293,7 +286,7 @@ class ForwardBatch:
     global_forward_mode: Optional[ForwardMode] = None

     # Speculative decoding
-    spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
+    spec_info: Optional[SpecInput] = None
     spec_algorithm: SpeculativeAlgorithm = None
     capture_hidden_mode: CaptureHiddenMode = None
@@ -364,33 +357,14 @@ class ForwardBatch:
         # For MLP sync
         if batch.global_num_tokens is not None:
-            from sglang.srt.speculative.eagle_utils import (
-                EagleDraftInput,
-                EagleVerifyInput,
-            )
             assert batch.global_num_tokens_for_logprob is not None
             # process global_num_tokens and global_num_tokens_for_logprob
             if batch.spec_info is not None:
-                if isinstance(batch.spec_info, EagleDraftInput):
-                    global_num_tokens = [
-                        x * batch.spec_info.num_tokens_per_batch
-                        for x in batch.global_num_tokens
-                    ]
-                    global_num_tokens_for_logprob = [
-                        x * batch.spec_info.num_tokens_for_logprob_per_batch
-                        for x in batch.global_num_tokens_for_logprob
-                    ]
-                else:
-                    assert isinstance(batch.spec_info, EagleVerifyInput)
-                    global_num_tokens = [
-                        x * batch.spec_info.draft_token_num
-                        for x in batch.global_num_tokens
-                    ]
-                    global_num_tokens_for_logprob = [
-                        x * batch.spec_info.draft_token_num
-                        for x in batch.global_num_tokens_for_logprob
-                    ]
+                spec_info: SpecInput = batch.spec_info
+                global_num_tokens, global_num_tokens_for_logprob = (
+                    spec_info.get_spec_adjusted_global_num_tokens(batch)
+                )
             else:
                 global_num_tokens = batch.global_num_tokens
                 global_num_tokens_for_logprob = batch.global_num_tokens_for_logprob
@@ -669,9 +643,6 @@ class ForwardBatch:
         )

     def prepare_mlp_sync_batch(self, model_runner: ModelRunner):
-        from sglang.srt.speculative.eagle_utils import EagleDraftInput
-
         assert self.global_num_tokens_cpu is not None
         assert self.global_num_tokens_for_logprob_cpu is not None
@@ -768,7 +739,8 @@
         if self.extend_seq_lens is not None:
             self.extend_seq_lens = self._pad_tensor_to_size(self.extend_seq_lens, bs)

-        if self.spec_info is not None and isinstance(self.spec_info, EagleDraftInput):
+        if self.spec_info is not None and self.spec_info.is_draft_input():
+            # FIXME(lsyin): remove this isinstance logic
             spec_info = self.spec_info
             self.output_cache_loc_backup = self.out_cache_loc
             self.hidden_states_backup = spec_info.hidden_states
...
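The hunks above replace the per-algorithm `isinstance` branches with one call into the new `SpecInput` interface (defined further down in this commit). The following is a minimal standalone sketch of that call-site arithmetic; `ToyDraftInput`, its coefficient values, and the helper name `adjust` are made up for illustration and are not sglang code:

```python
from typing import List, Tuple


class ToyDraftInput:
    """Stand-in for an EagleDraftInput-like object; the values are arbitrary."""

    num_tokens_per_batch = 2
    num_tokens_for_logprob_per_batch = 1

    def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
        return self.num_tokens_per_batch, self.num_tokens_for_logprob_per_batch


def adjust(
    global_num_tokens: List[int],
    global_num_tokens_for_logprob: List[int],
    spec_info,
) -> Tuple[List[int], List[int]]:
    # Same arithmetic as SpecInput.get_spec_adjusted_global_num_tokens:
    # scale each per-rank token count by the algorithm-specific coefficients.
    c1, c2 = spec_info.get_spec_adjust_token_coefficient()
    return (
        [x * c1 for x in global_num_tokens],
        [x * c2 for x in global_num_tokens_for_logprob],
    )


tokens, tokens_for_logprob = adjust([8, 8], [8, 8], ToyDraftInput())
assert tokens == [16, 16] and tokens_for_logprob == [8, 8]
```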
@@ -20,7 +20,7 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.speculative.eagle_utils import EagleDraftInput
+from sglang.srt.speculative.eagle_info import EagleDraftInput
 from sglang.srt.utils import (
     require_attn_tp_gather,
     require_gathered_buffer,
...
@@ -21,7 +21,8 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
+from sglang.srt.speculative.eagle_info import EagleDraftInput
+from sglang.srt.speculative.spec_utils import fast_topk
 from sglang.srt.utils import (
     require_attn_tp_gather,
     require_gathered_buffer,
...
@@ -34,16 +34,18 @@ from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
 from sglang.srt.speculative.eagle_draft_extend_cuda_graph_runner import (
     EAGLEDraftExtendCudaGraphRunner,
 )
-from sglang.srt.speculative.eagle_utils import (
+from sglang.srt.speculative.eagle_info import (
     EagleDraftInput,
     EagleVerifyInput,
     EagleVerifyOutput,
+)
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.speculative.spec_utils import (
     assign_draft_cache_locs,
     fast_topk,
     generate_token_bitmask,
     select_top_k_tokens,
 )
-from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     empty_context,
     get_available_gpu_memory,
...
@@ -2,7 +2,7 @@ from __future__ import annotations

 import copy
 import logging
-from typing import Optional
+from typing import Optional, Tuple

 import torch
 import triton
@@ -13,6 +13,7 @@ from dataclasses import dataclass
 import torch.nn.functional as F

+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import apply_custom_logit_processor
 from sglang.srt.managers.schedule_batch import (
@@ -21,10 +22,10 @@ from sglang.srt.managers.schedule_batch import (
     global_server_args_dict,
 )
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.speculative.eagle_utils import (
+from sglang.srt.speculative.spec_info import SpecInput, SpecInputType
+from sglang.srt.speculative.spec_utils import (
     TREE_SPEC_KERNEL_AVAILABLE,
     assign_req_to_token_pool,
-    create_flashinfer_kv_indices_triton,
     get_src_tgt_cache_loc,
     get_target_cache_loc,
 )
@@ -42,7 +43,7 @@ elif is_hip():

 @dataclass
-class NgramVerifyInput:
+class NgramVerifyInput(SpecInput):
     def __init__(
         self,
         draft_token: torch.Tensor,
@@ -53,6 +54,7 @@ class NgramVerifyInput:
         retrive_next_sibling: torch.Tensor,
         draft_token_num: int,
     ):
+        super().__init__(SpecInputType.NGRAM_VERIFY)
         self.draft_token = draft_token
         self.custom_mask = tree_mask
         self.positions = positions
@@ -62,6 +64,9 @@ class NgramVerifyInput:
         self.draft_token_num = draft_token_num
         self.device = self.custom_mask.device

+    def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
+        return self.draft_token_num, self.draft_token_num
+
     def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
         if batch.forward_mode.is_idle():
             return
...
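With this change `NgramVerifyInput` joins the same `SpecInput` contract as the Eagle inputs: it returns `draft_token_num` for both coefficients, so during verification every rank's token count is scaled by the number of draft tokens. A tiny worked example with made-up numbers:

```python
# Hypothetical values: 4 draft tokens per request, three ranks contributing
# [2, 3, 1] tokens in normal decode; while those draft tokens are verified,
# each rank's count is multiplied by draft_token_num.
draft_token_num = 4
global_num_tokens = [2, 3, 1]
assert [x * draft_token_num for x in global_num_tokens] == [8, 12, 4]
```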
 import logging
-import os
-import threading
-import time
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import List, Optional

 import numpy as np
 import torch
@@ -15,7 +12,6 @@ from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.cpp_ngram.ngram_cache import NgramCache
 from sglang.srt.speculative.ngram_utils import NgramVerifyInput
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
-from sglang.srt.utils import broadcast_pyobj

 logger = logging.getLogger(__name__)
...
+from abc import ABC, abstractmethod
 from enum import IntEnum, auto
+from typing import List, Tuple
+
+from sglang.srt.managers.schedule_batch import ModelWorkerBatch


 class SpeculativeAlgorithm(IntEnum):
@@ -35,3 +39,41 @@ class SpeculativeAlgorithm(IntEnum):
         if name is not None:
             name = name.upper()
         return name_map[name]
+
+
+class SpecInputType(IntEnum):
+    # NOTE: introduced to distinguish the SpecInput types of the different algorithms
+    # when asserting in attention backends. If all algorithms can share the same
+    # data structure for draft_input and verify_input, consider simplifying this.
+    EAGLE_DRAFT = auto()
+    EAGLE_VERIFY = auto()
+    NGRAM_VERIFY = auto()
+
+
+class SpecInput(ABC):
+    def __init__(self, spec_input_type: SpecInputType):
+        self.spec_input_type = spec_input_type
+
+    def is_draft_input(self) -> bool:
+        # FIXME: remove this function (it is only used for assertions),
+        # or use another variable name like `draft_input` instead of `spec_info`
+        return self.spec_input_type == SpecInputType.EAGLE_DRAFT
+
+    def is_verify_input(self) -> bool:
+        return self.spec_input_type in {
+            SpecInputType.EAGLE_VERIFY,
+            SpecInputType.NGRAM_VERIFY,
+        }
+
+    @abstractmethod
+    def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
+        pass
+
+    def get_spec_adjusted_global_num_tokens(
+        self, forward_batch: ModelWorkerBatch
+    ) -> Tuple[List[int], List[int]]:
+        c1, c2 = self.get_spec_adjust_token_coefficient()
+        global_num_tokens = [x * c1 for x in forward_batch.global_num_tokens]
+        global_num_tokens_for_logprob = [
+            x * c2 for x in forward_batch.global_num_tokens_for_logprob
+        ]
+        return global_num_tokens, global_num_tokens_for_logprob
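To make the contract concrete: a subclass only has to tag its type in `__init__` and implement `get_spec_adjust_token_coefficient`; the `is_draft_input` / `is_verify_input` helpers and the global-token adjustment come from the base class. The sketch below is hypothetical (`ToyVerifyInput` does not exist in sglang) and assumes a checkout that includes this commit so the import resolves:

```python
from typing import Tuple

from sglang.srt.speculative.spec_info import SpecInput, SpecInputType


class ToyVerifyInput(SpecInput):
    """Hypothetical example; mirrors what NgramVerifyInput does in this commit."""

    def __init__(self, draft_token_num: int):
        super().__init__(SpecInputType.NGRAM_VERIFY)
        self.draft_token_num = draft_token_num

    def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
        # Verify-time inputs scale both token lists by the draft token count.
        return self.draft_token_num, self.draft_token_num


# Generic code (e.g. attention backends) can now check the tag instead of
# importing and isinstance-checking each concrete class.
info = ToyVerifyInput(draft_token_num=4)
assert info.is_verify_input() and not info.is_draft_input()
```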
This diff is collapsed.
@@ -30,7 +30,8 @@ from sglang.srt.model_executor.forward_batch_info import (
 )
 from sglang.srt.operations import execute_operations, execute_overlapped_operations
 from sglang.srt.operations_strategy import OperationsStrategy
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
+from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import BumpAllocator, empty_context, get_bool_env_var, is_hip

 if TYPE_CHECKING:
@@ -48,7 +49,7 @@ logger = logging.getLogger(__name__)

 def get_token_num_per_seq(
     forward_mode: ForwardMode,
-    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
+    spec_info: Optional[SpecInput] = None,
 ):
     if forward_mode.is_target_verify():
         return spec_info.draft_token_num
@@ -273,7 +274,7 @@ def compute_split_token_index(

 def compute_split_indices_for_cuda_graph_replay(
     forward_mode: ForwardMode,
     cuda_graph_num_tokens: int,
-    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+    spec_info: Optional[SpecInput],
 ):
     forward_mode_for_tbo_split = (
         forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE
@@ -333,7 +334,7 @@ class TboCudaGraphRunnerPlugin:
         forward_mode: ForwardMode,
         bs: int,
         num_token_non_padded: int,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
     ):
         token_num_per_seq = get_token_num_per_seq(
             forward_mode=forward_mode, spec_info=spec_info
...
@@ -7,7 +7,6 @@ or
 python3 test_forward_split_prefill.py
 """

-import time
 import unittest

 import numpy as np
@@ -16,7 +15,7 @@ import torch
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
...