Unverified Commit 73d4a5f8 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Organize spec-related data structures (#10735)

parent 7fb551a7
......@@ -821,7 +821,7 @@ class CudaGraphRunner:
self.model_runner.spec_algorithm.is_eagle()
or self.model_runner.spec_algorithm.is_standalone()
):
from sglang.srt.speculative.eagle_utils import EagleVerifyInput
from sglang.srt.speculative.eagle_info import EagleVerifyInput
if self.model_runner.is_draft_worker:
raise RuntimeError("This should not happen.")
......
......@@ -45,13 +45,7 @@ from sglang.srt.layers.dp_attention import (
get_attention_tp_size,
set_dp_buffer_len,
)
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.utils import (
flatten_nested_list,
get_compiler_backend,
is_npu,
support_triton,
)
from sglang.srt.utils import get_compiler_backend, is_npu, support_triton
if TYPE_CHECKING:
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
......@@ -60,8 +54,7 @@ if TYPE_CHECKING:
from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm
_is_npu = is_npu()
......@@ -293,7 +286,7 @@ class ForwardBatch:
global_forward_mode: Optional[ForwardMode] = None
# Speculative decoding
spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
spec_info: Optional[SpecInput] = None
spec_algorithm: SpeculativeAlgorithm = None
capture_hidden_mode: CaptureHiddenMode = None
......@@ -364,33 +357,14 @@ class ForwardBatch:
# For MLP sync
if batch.global_num_tokens is not None:
from sglang.srt.speculative.eagle_utils import (
EagleDraftInput,
EagleVerifyInput,
)
assert batch.global_num_tokens_for_logprob is not None
# process global_num_tokens and global_num_tokens_for_logprob
if batch.spec_info is not None:
if isinstance(batch.spec_info, EagleDraftInput):
global_num_tokens = [
x * batch.spec_info.num_tokens_per_batch
for x in batch.global_num_tokens
]
global_num_tokens_for_logprob = [
x * batch.spec_info.num_tokens_for_logprob_per_batch
for x in batch.global_num_tokens_for_logprob
]
else:
assert isinstance(batch.spec_info, EagleVerifyInput)
global_num_tokens = [
x * batch.spec_info.draft_token_num
for x in batch.global_num_tokens
]
global_num_tokens_for_logprob = [
x * batch.spec_info.draft_token_num
for x in batch.global_num_tokens_for_logprob
]
spec_info: SpecInput = batch.spec_info
global_num_tokens, global_num_tokens_for_logprob = (
spec_info.get_spec_adjusted_global_num_tokens(batch)
)
else:
global_num_tokens = batch.global_num_tokens
global_num_tokens_for_logprob = batch.global_num_tokens_for_logprob
......@@ -669,9 +643,6 @@ class ForwardBatch:
)
def prepare_mlp_sync_batch(self, model_runner: ModelRunner):
from sglang.srt.speculative.eagle_utils import EagleDraftInput
assert self.global_num_tokens_cpu is not None
assert self.global_num_tokens_for_logprob_cpu is not None
......@@ -768,7 +739,8 @@ class ForwardBatch:
if self.extend_seq_lens is not None:
self.extend_seq_lens = self._pad_tensor_to_size(self.extend_seq_lens, bs)
if self.spec_info is not None and isinstance(self.spec_info, EagleDraftInput):
if self.spec_info is not None and self.spec_info.is_draft_input():
# FIXME(lsyin): remove this isinstance logic
spec_info = self.spec_info
self.output_cache_loc_backup = self.out_cache_loc
self.hidden_states_backup = spec_info.hidden_states
......
......@@ -20,7 +20,7 @@ from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch,
ForwardMode,
)
from sglang.srt.speculative.eagle_utils import EagleDraftInput
from sglang.srt.speculative.eagle_info import EagleDraftInput
from sglang.srt.utils import (
require_attn_tp_gather,
require_gathered_buffer,
......
......@@ -21,7 +21,8 @@ from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch,
ForwardMode,
)
from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
from sglang.srt.speculative.eagle_info import EagleDraftInput
from sglang.srt.speculative.spec_utils import fast_topk
from sglang.srt.utils import (
require_attn_tp_gather,
require_gathered_buffer,
......
......@@ -34,16 +34,18 @@ from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
from sglang.srt.speculative.eagle_draft_extend_cuda_graph_runner import (
EAGLEDraftExtendCudaGraphRunner,
)
from sglang.srt.speculative.eagle_utils import (
from sglang.srt.speculative.eagle_info import (
EagleDraftInput,
EagleVerifyInput,
EagleVerifyOutput,
)
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.speculative.spec_utils import (
assign_draft_cache_locs,
fast_topk,
generate_token_bitmask,
select_top_k_tokens,
)
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import (
empty_context,
get_available_gpu_memory,
......
......@@ -2,7 +2,7 @@ from __future__ import annotations
import copy
import logging
from typing import Optional
from typing import Optional, Tuple
import torch
import triton
......@@ -13,6 +13,7 @@ from dataclasses import dataclass
import torch.nn.functional as F
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import apply_custom_logit_processor
from sglang.srt.managers.schedule_batch import (
......@@ -21,10 +22,10 @@ from sglang.srt.managers.schedule_batch import (
global_server_args_dict,
)
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.speculative.eagle_utils import (
from sglang.srt.speculative.spec_info import SpecInput, SpecInputType
from sglang.srt.speculative.spec_utils import (
TREE_SPEC_KERNEL_AVAILABLE,
assign_req_to_token_pool,
create_flashinfer_kv_indices_triton,
get_src_tgt_cache_loc,
get_target_cache_loc,
)
......@@ -42,7 +43,7 @@ elif is_hip():
@dataclass
class NgramVerifyInput:
class NgramVerifyInput(SpecInput):
def __init__(
self,
draft_token: torch.Tensor,
......@@ -53,6 +54,7 @@ class NgramVerifyInput:
retrive_next_sibling: torch.Tensor,
draft_token_num: int,
):
super().__init__(SpecInputType.NGRAM_VERIFY)
self.draft_token = draft_token
self.custom_mask = tree_mask
self.positions = positions
......@@ -62,6 +64,9 @@ class NgramVerifyInput:
self.draft_token_num = draft_token_num
self.device = self.custom_mask.device
def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
    """Return the (num_tokens, num_tokens_for_logprob) per-batch multipliers.

    For n-gram verification both coefficients equal `draft_token_num`: every
    request carries `draft_token_num` candidate tokens, and logprobs are
    computed for all of them (see SpecInput.get_spec_adjusted_global_num_tokens,
    which multiplies the batch token counts by this pair).
    """
    return self.draft_token_num, self.draft_token_num
def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
if batch.forward_mode.is_idle():
return
......
import logging
import os
import threading
import time
from typing import TYPE_CHECKING, List, Optional, Union
from typing import List, Optional
import numpy as np
import torch
......@@ -15,7 +12,6 @@ from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.cpp_ngram.ngram_cache import NgramCache
from sglang.srt.speculative.ngram_utils import NgramVerifyInput
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import broadcast_pyobj
logger = logging.getLogger(__name__)
......
from abc import ABC, abstractmethod
from enum import IntEnum, auto
from typing import List, Tuple
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
class SpeculativeAlgorithm(IntEnum):
......@@ -35,3 +39,41 @@ class SpeculativeAlgorithm(IntEnum):
if name is not None:
name = name.upper()
return name_map[name]
class SpecInputType(IntEnum):
    """Tags which speculative algorithm produced a SpecInput payload."""

    # NOTE: introduced to distinguish the SpecInput types of multiple algorithms
    # when asserting in attention backends. If all algorithms can share the same
    # data structure for draft_input and verify_input, consider simplifying it.
    EAGLE_DRAFT = auto()
    EAGLE_VERIFY = auto()
    NGRAM_VERIFY = auto()
class SpecInput(ABC):
    """Common interface for speculative-decoding input payloads.

    Every concrete payload (EAGLE draft/verify, n-gram verify) tags itself
    with a `SpecInputType` so that consumers can tell payloads apart without
    importing each algorithm's concrete class.
    """

    def __init__(self, spec_input_type: SpecInputType):
        self.spec_input_type = spec_input_type

    def is_draft_input(self) -> bool:
        # FIXME: only used for assertions; consider renaming `spec_info` to
        # something like `draft_input` and dropping this helper entirely.
        return self.spec_input_type == SpecInputType.EAGLE_DRAFT

    def is_verify_input(self) -> bool:
        verify_types = (SpecInputType.EAGLE_VERIFY, SpecInputType.NGRAM_VERIFY)
        return self.spec_input_type in verify_types

    @abstractmethod
    def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
        # Subclasses return (token multiplier, logprob-token multiplier).
        pass

    def get_spec_adjusted_global_num_tokens(
        self, forward_batch: ModelWorkerBatch
    ) -> Tuple[List[int], List[int]]:
        """Scale the batch's per-rank token counts by the speculative multipliers.

        Returns the adjusted (global_num_tokens, global_num_tokens_for_logprob)
        lists, each entry multiplied by the corresponding coefficient from
        `get_spec_adjust_token_coefficient`.
        """
        token_coef, logprob_coef = self.get_spec_adjust_token_coefficient()
        adjusted_tokens = [x * token_coef for x in forward_batch.global_num_tokens]
        adjusted_logprob_tokens = [
            x * logprob_coef for x in forward_batch.global_num_tokens_for_logprob
        ]
        return adjusted_tokens, adjusted_logprob_tokens
This diff is collapsed.
......@@ -30,7 +30,8 @@ from sglang.srt.model_executor.forward_batch_info import (
)
from sglang.srt.operations import execute_operations, execute_overlapped_operations
from sglang.srt.operations_strategy import OperationsStrategy
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpecInput
from sglang.srt.utils import BumpAllocator, empty_context, get_bool_env_var, is_hip
if TYPE_CHECKING:
......@@ -48,7 +49,7 @@ logger = logging.getLogger(__name__)
def get_token_num_per_seq(
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
spec_info: Optional[SpecInput] = None,
):
if forward_mode.is_target_verify():
return spec_info.draft_token_num
......@@ -273,7 +274,7 @@ def compute_split_token_index(
def compute_split_indices_for_cuda_graph_replay(
forward_mode: ForwardMode,
cuda_graph_num_tokens: int,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
spec_info: Optional[SpecInput],
):
forward_mode_for_tbo_split = (
forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE
......@@ -333,7 +334,7 @@ class TboCudaGraphRunnerPlugin:
forward_mode: ForwardMode,
bs: int,
num_token_non_padded: int,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
spec_info: Optional[SpecInput],
):
token_num_per_seq = get_token_num_per_seq(
forward_mode=forward_mode, spec_info=spec_info
......
......@@ -7,7 +7,6 @@ or
python3 test_forward_split_prefill.py
"""
import time
import unittest
import numpy as np
......@@ -16,7 +15,7 @@ import torch
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment