Unverified Commit 73d4a5f8 authored by Liangsheng Yin, committed by GitHub

Organize spec-related data structures (#10735)

parent 7fb551a7
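For orientation: this change moves the contents of `eagle_utils.py` into `eagle_info.py` (the Eagle input/output dataclasses) and `spec_utils.py` (the Triton kernels and helper functions), and introduces a shared `SpecInput` base class in `spec_info.py` that callers now type against instead of the concrete Eagle classes. A minimal sketch of the resulting hierarchy (method bodies elided, illustrative only):

```python
# Illustrative sketch of the class layout after this commit; bodies elided.
from abc import ABC, abstractmethod
from typing import List, Tuple


class SpecInput(ABC):  # sglang/srt/speculative/spec_info.py
    def is_draft_input(self) -> bool: ...
    def is_verify_input(self) -> bool: ...

    @abstractmethod
    def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]: ...

    def get_spec_adjusted_global_num_tokens(
        self, forward_batch
    ) -> Tuple[List[int], List[int]]: ...


class EagleDraftInput(SpecInput): ...   # sglang/srt/speculative/eagle_info.py
class EagleVerifyInput(SpecInput): ...  # sglang/srt/speculative/eagle_info.py
class NgramVerifyInput(SpecInput): ...  # sglang/srt/speculative/ngram_utils.py
```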
@@ -821,7 +821,7 @@ class CudaGraphRunner:
             self.model_runner.spec_algorithm.is_eagle()
             or self.model_runner.spec_algorithm.is_standalone()
         ):
-            from sglang.srt.speculative.eagle_utils import EagleVerifyInput
+            from sglang.srt.speculative.eagle_info import EagleVerifyInput
             if self.model_runner.is_draft_worker:
                 raise RuntimeError("This should not happen.")
......
@@ -45,13 +45,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_size,
     set_dp_buffer_len,
 )
-from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import (
-    flatten_nested_list,
-    get_compiler_backend,
-    is_npu,
-    support_triton,
-)
+from sglang.srt.utils import get_compiler_backend, is_npu, support_triton

 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -60,8 +54,7 @@ if TYPE_CHECKING:
     from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+    from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm

 _is_npu = is_npu()
@@ -293,7 +286,7 @@ class ForwardBatch:
     global_forward_mode: Optional[ForwardMode] = None

     # Speculative decoding
-    spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
+    spec_info: Optional[SpecInput] = None
     spec_algorithm: SpeculativeAlgorithm = None
     capture_hidden_mode: CaptureHiddenMode = None
@@ -364,33 +357,14 @@ class ForwardBatch:

         # For MLP sync
         if batch.global_num_tokens is not None:
-            from sglang.srt.speculative.eagle_utils import (
-                EagleDraftInput,
-                EagleVerifyInput,
-            )
-
             assert batch.global_num_tokens_for_logprob is not None
+
             # process global_num_tokens and global_num_tokens_for_logprob
             if batch.spec_info is not None:
-                if isinstance(batch.spec_info, EagleDraftInput):
-                    global_num_tokens = [
-                        x * batch.spec_info.num_tokens_per_batch
-                        for x in batch.global_num_tokens
-                    ]
-                    global_num_tokens_for_logprob = [
-                        x * batch.spec_info.num_tokens_for_logprob_per_batch
-                        for x in batch.global_num_tokens_for_logprob
-                    ]
-                else:
-                    assert isinstance(batch.spec_info, EagleVerifyInput)
-                    global_num_tokens = [
-                        x * batch.spec_info.draft_token_num
-                        for x in batch.global_num_tokens
-                    ]
-                    global_num_tokens_for_logprob = [
-                        x * batch.spec_info.draft_token_num
-                        for x in batch.global_num_tokens_for_logprob
-                    ]
+                spec_info: SpecInput = batch.spec_info
+                global_num_tokens, global_num_tokens_for_logprob = (
+                    spec_info.get_spec_adjusted_global_num_tokens(batch)
+                )
             else:
                 global_num_tokens = batch.global_num_tokens
                 global_num_tokens_for_logprob = batch.global_num_tokens_for_logprob
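The replacement call above performs the same element-wise scaling that the removed `isinstance` branches spelled out. A self-contained sketch of the arithmetic with assumed toy values (the real method takes a `ModelWorkerBatch`, not plain lists):

```python
from typing import List, Tuple


def adjusted(
    coeffs: Tuple[int, int],
    global_num_tokens: List[int],
    global_num_tokens_for_logprob: List[int],
) -> Tuple[List[int], List[int]]:
    # Same arithmetic as SpecInput.get_spec_adjusted_global_num_tokens:
    # each per-rank token count is multiplied by the speculative coefficient.
    c1, c2 = coeffs
    return (
        [x * c1 for x in global_num_tokens],
        [x * c2 for x in global_num_tokens_for_logprob],
    )


# EagleVerifyInput reports (draft_token_num, draft_token_num); with 4 draft
# tokens per request and per-rank counts [8, 6]:
assert adjusted((4, 4), [8, 6], [8, 6]) == ([32, 24], [32, 24])
```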
@@ -669,9 +643,6 @@ class ForwardBatch:
         )

     def prepare_mlp_sync_batch(self, model_runner: ModelRunner):
-        from sglang.srt.speculative.eagle_utils import EagleDraftInput
-
         assert self.global_num_tokens_cpu is not None
         assert self.global_num_tokens_for_logprob_cpu is not None
@@ -768,7 +739,8 @@ class ForwardBatch:
         if self.extend_seq_lens is not None:
             self.extend_seq_lens = self._pad_tensor_to_size(self.extend_seq_lens, bs)

-        if self.spec_info is not None and isinstance(self.spec_info, EagleDraftInput):
+        if self.spec_info is not None and self.spec_info.is_draft_input():
+            # FIXME(lsyin): remove this isinstance logic
             spec_info = self.spec_info
             self.output_cache_loc_backup = self.out_cache_loc
             self.hidden_states_backup = spec_info.hidden_states
......
@@ -20,7 +20,7 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.speculative.eagle_utils import EagleDraftInput
+from sglang.srt.speculative.eagle_info import EagleDraftInput
 from sglang.srt.utils import (
     require_attn_tp_gather,
     require_gathered_buffer,
......
@@ -21,7 +21,8 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
+from sglang.srt.speculative.eagle_info import EagleDraftInput
+from sglang.srt.speculative.spec_utils import fast_topk
 from sglang.srt.utils import (
     require_attn_tp_gather,
     require_gathered_buffer,
......
from __future__ import annotations
import copy
import logging import logging
import os from copy import copy
import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional from typing import List, Optional, Tuple
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import triton
import triton.language as tl
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.environ import envs
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import apply_custom_logit_processor from sglang.srt.layers.sampler import apply_custom_logit_processor
from sglang.srt.managers.schedule_batch import ( from sglang.srt.managers.schedule_batch import (
Req,
ScheduleBatch, ScheduleBatch,
get_last_loc, get_last_loc,
global_server_args_dict, global_server_args_dict,
) )
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
from sglang.srt.speculative.spec_info import SpecInput, SpecInputType
from sglang.srt.speculative.spec_utils import (
SIMULATE_ACC_LEN,
TREE_SPEC_KERNEL_AVAILABLE,
_generate_simulated_accept_index,
align_evict_mask_to_page_size,
assign_req_to_token_pool,
create_accept_length_filter,
create_extend_after_decode_spec_info,
filter_finished_cache_loc_kernel,
get_src_tgt_cache_loc,
get_target_cache_loc,
)
from sglang.srt.utils import is_cuda, is_hip, next_power_of_2 from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
if is_cuda(): if is_cuda():
from sgl_kernel import ( from sgl_kernel import (
fast_topk,
top_k_renorm_prob, top_k_renorm_prob,
top_p_renorm_prob, top_p_renorm_prob,
tree_speculative_sampling_target_only, tree_speculative_sampling_target_only,
verify_tree_greedy, verify_tree_greedy,
) )
elif is_hip(): elif is_hip():
from sgl_kernel import fast_topk, verify_tree_greedy from sgl_kernel import verify_tree_greedy
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Simulate acceptance length for benchmarking purposes
SIMULATE_ACC_LEN = envs.SGLANG_SIMULATE_ACC_LEN.get() # turn off if < 0
SIMULATE_ACC_METHOD = envs.SGLANG_SIMULATE_ACC_METHOD.get()
TREE_TRAVERSE_TIME_THRESHOLD = 1 # TODO: set this properly
TREE_SPEC_KERNEL_AVAILABLE = "tree_speculative_sampling_target_only" in globals()
@dataclass
class EagleDraftInput:
# The inputs for decode
# shape: (b, topk)
topk_p: torch.Tensor = None
topk_index: torch.Tensor = None
# shape: (b, hidden_size)
hidden_states: torch.Tensor = None
capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL
# Inputs for extend
# shape: (b,)
verified_id: torch.Tensor = None
accept_length: torch.Tensor = None
accept_length_cpu: List[int] = None
# Inputs for the attention backends
# shape: (b + 1,)
kv_indptr: torch.Tensor = None
kv_indices: torch.Tensor = None
# Shape info for padding
num_tokens_per_batch: int = -1
num_tokens_for_logprob_per_batch: int = -1
# Inputs for draft extend
# shape: (b,)
seq_lens_for_draft_extend: torch.Tensor = None
req_pool_indices_for_draft_extend: torch.Tensor = None
def prepare_for_extend(self, batch: ScheduleBatch):
if batch.forward_mode.is_idle():
return
# Prefill only generate 1 token.
assert len(self.verified_id) == len(batch.seq_lens)
pt = 0
for i, extend_len in enumerate(batch.extend_lens):
input_ids = batch.input_ids[pt : pt + extend_len]
batch.input_ids[pt : pt + extend_len] = torch.cat(
(input_ids[1:], self.verified_id[i].reshape(1))
)
pt += extend_len
@classmethod
def create_idle_input(
cls,
device: torch.device,
hidden_size: int,
dtype: torch.dtype,
topk: int,
capture_hidden_mode: CaptureHiddenMode,
):
return cls(
verified_id=torch.empty((0,), device=device, dtype=torch.int32),
hidden_states=torch.empty((0, hidden_size), device=device, dtype=dtype),
topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
capture_hidden_mode=capture_hidden_mode,
accept_length=torch.empty((0,), device=device, dtype=torch.int32),
accept_length_cpu=[],
)
def prepare_extend_after_decode(
self,
batch: ScheduleBatch,
speculative_num_steps: int,
):
if batch.forward_mode.is_idle():
return
batch.input_ids = self.verified_id
batch.extend_lens = [x + 1 for x in batch.spec_info.accept_length_cpu]
batch.extend_num_tokens = sum(batch.extend_lens)
batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
batch.return_logprob = False
batch.return_hidden_states = False
self.capture_hidden_mode = CaptureHiddenMode.LAST
self.accept_length.add_(1)
self.positions = torch.empty_like(batch.input_ids, dtype=torch.long)
self.verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
create_extend_after_decode_spec_info[(len(batch.seq_lens),)](
batch.input_ids,
batch.seq_lens,
self.accept_length,
self.positions,
self.verified_id,
next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))),
)
def generate_attn_arg_prefill(
self,
req_pool_indices: torch.Tensor,
paged_kernel_lens: torch.Tensor,
paged_kernel_lens_sum: int,
req_to_token: torch.Tensor,
):
bs = self.accept_length.numel()
qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
if paged_kernel_lens_sum is None:
paged_kernel_lens_sum = cum_kv_seq_len[-1]
kv_indices = torch.empty(
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
)
create_flashinfer_kv_indices_triton[(bs,)](
req_to_token,
req_pool_indices,
paged_kernel_lens,
cum_kv_seq_len,
None,
kv_indices,
req_to_token.size(1),
)
return kv_indices, cum_kv_seq_len, qo_indptr, None
def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
if has_been_filtered:
# in eagle_utils.py:verify, we have already filtered the batch by `unfinished_index`
# therefore, we don't need to filter the batch again in scheduler
if len(new_indices) != len(self.topk_p):
logger.warning(
f"length of new_indices: {len(new_indices)} != length of topk_p: {len(self.topk_p)}, this should not happen"
)
self.topk_p = self.topk_p[: len(new_indices)]
self.topk_index = self.topk_index[: len(new_indices)]
self.hidden_states = self.hidden_states[: len(new_indices)]
self.verified_id = self.verified_id[: len(new_indices)]
else:
# in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index`
self.topk_p = self.topk_p[new_indices]
self.topk_index = self.topk_index[new_indices]
self.hidden_states = self.hidden_states[new_indices]
self.verified_id = self.verified_id[new_indices]
def merge_batch(self, spec_info: EagleDraftInput):
if self.hidden_states is None:
self.hidden_states = spec_info.hidden_states
self.verified_id = spec_info.verified_id
self.topk_p = spec_info.topk_p
self.topk_index = spec_info.topk_index
return
if spec_info.hidden_states is None:
return
self.hidden_states = torch.cat(
[self.hidden_states, spec_info.hidden_states], axis=0
)
self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
self.topk_p = torch.cat([self.topk_p, spec_info.topk_p])
self.topk_index = torch.cat([self.topk_index, spec_info.topk_index])
@dataclass @dataclass
class EagleVerifyOutput: class EagleVerifyInput(SpecInput):
# Draft input batch
draft_input: EagleDraftInput
# Logit outputs from target worker
logits_output: LogitsProcessorOutput
# Accepted token ids including the bonus token
verified_id: torch.Tensor
# Accepted token length per sequence in a batch in CPU.
accept_length_per_req_cpu: List[int]
# Accepted indices from logits_output.next_token_logits
accepted_indices: torch.Tensor
@dataclass
class EagleVerifyInput:
draft_token: torch.Tensor draft_token: torch.Tensor
custom_mask: torch.Tensor custom_mask: torch.Tensor
positions: torch.Tensor positions: torch.Tensor
@@ -245,6 +62,12 @@ class EagleVerifyInput:
seq_lens_cpu: torch.Tensor seq_lens_cpu: torch.Tensor
grammar: BaseGrammarObject = None grammar: BaseGrammarObject = None
def __post_init__(self):
super().__init__(SpecInputType.EAGLE_VERIFY)
def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
return self.draft_token_num, self.draft_token_num
@classmethod @classmethod
def create_idle_input(cls, topk: int, spec_steps: int, num_verify_tokens: int): def create_idle_input(cls, topk: int, spec_steps: int, num_verify_tokens: int):
return cls( return cls(
@@ -724,574 +547,184 @@ class EagleVerifyInput:
) )
@triton.jit @dataclass
def create_extend_after_decode_spec_info( class EagleDraftInput(SpecInput):
verified_id, # The inputs for decode
seq_lens, # shape: (b, topk)
accept_lens, topk_p: torch.Tensor = None
positions, topk_index: torch.Tensor = None
new_verified_id, # shape: (b, hidden_size)
bs_upper: tl.constexpr, hidden_states: torch.Tensor = None
): capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL
pid = tl.program_id(axis=0)
offsets = tl.arange(0, bs_upper)
seq_length = tl.load(seq_lens + pid)
accept_length = tl.load(accept_lens + pid)
accept_len_cumsum = tl.sum(
tl.load(accept_lens + offsets, mask=offsets < pid, other=0)
)
positions_ptr = positions + accept_len_cumsum
mask = offsets < accept_length
tl.store(positions_ptr + offsets, seq_length - accept_length + offsets, mask)
accept_len_cumsum += accept_length - 1
verified_id_data = tl.load(verified_id + accept_len_cumsum)
tl.store(new_verified_id + pid, verified_id_data)
@triton.jit
def assign_req_to_token_pool(
req_pool_indices,
req_to_token,
start_offset,
end_offset,
out_cache_loc,
pool_len: tl.constexpr,
bs_upper: tl.constexpr,
):
BLOCK_SIZE: tl.constexpr = 32
pid = tl.program_id(axis=0)
kv_start = tl.load(start_offset + pid)
kv_end = tl.load(end_offset + pid)
token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
length_offset = tl.arange(0, bs_upper)
start = tl.load(start_offset + length_offset, mask=length_offset < pid, other=0)
end = tl.load(end_offset + length_offset, mask=length_offset < pid, other=0)
out_offset = tl.sum(end - start, axis=0)
out_cache_ptr = out_cache_loc + out_offset
save_offset = tl.arange(0, BLOCK_SIZE) + kv_start
load_offset = tl.arange(0, BLOCK_SIZE)
num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
for _ in range(num_loop):
mask = save_offset < kv_end
data = tl.load(out_cache_ptr + load_offset, mask=mask)
tl.store(token_pool + save_offset, data, mask=mask)
save_offset += BLOCK_SIZE
load_offset += BLOCK_SIZE
@triton.jit
def assign_draft_cache_locs(
req_pool_indices,
req_to_token,
seq_lens,
extend_lens,
num_new_pages_per_topk,
out_cache_loc,
pool_len: tl.constexpr,
topk: tl.constexpr,
speculative_num_steps: tl.constexpr,
page_size: tl.constexpr,
bs_upper: tl.constexpr,
iter_upper: tl.constexpr,
):
BLOCK_SIZE: tl.constexpr = 128
pid = tl.program_id(axis=0)
if page_size == 1 or topk == 1:
copy_len = topk * speculative_num_steps
out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
else:
bs_offset = tl.arange(0, bs_upper)
copy_len = tl.load(extend_lens + pid)
cum_copy_len = tl.sum(tl.load(extend_lens + bs_offset, mask=bs_offset < pid))
out_cache_ptr = out_cache_loc + cum_copy_len
# Part 1: Copy from out_cache_loc to req_to_token
kv_start = tl.load(seq_lens + pid)
token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
num_loop = tl.cdiv(copy_len, BLOCK_SIZE)
for i in range(num_loop):
copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
mask = copy_offset < copy_len
data = tl.load(out_cache_ptr + copy_offset, mask=mask)
tl.store(token_pool + kv_start + copy_offset, data, mask=mask)
if page_size == 1 or topk == 1:
return
# Part 2: Copy the indices for the last partial page
prefix_len = tl.load(seq_lens + pid)
last_page_len = prefix_len % page_size
offsets = tl.arange(0, page_size)
mask = offsets < last_page_len
num_new_pages_per_topk_ = tl.load(num_new_pages_per_topk + pid)
prefix_base = token_pool + prefix_len - last_page_len
for topk_id in range(topk):
value = tl.load(prefix_base + offsets, mask=mask)
tl.store(
prefix_base + topk_id * num_new_pages_per_topk_ * page_size + offsets,
value,
mask=mask,
)
# Part 3: Remove the padding in out_cache_loc
iter_offest = tl.arange(0, iter_upper)
for topk_id in range(topk):
indices = tl.load(
prefix_base
+ topk_id * num_new_pages_per_topk_ * page_size
+ last_page_len
+ iter_offest,
mask=iter_offest < speculative_num_steps,
)
tl.store(
out_cache_loc
+ pid * topk * speculative_num_steps
+ topk_id * speculative_num_steps
+ iter_offest,
indices,
mask=iter_offest < speculative_num_steps,
)
# Inputs for extend
# shape: (b,)
verified_id: torch.Tensor = None
accept_length: torch.Tensor = None
accept_length_cpu: List[int] = None
@triton.jit # Inputs for the attention backends
def generate_draft_decode_kv_indices( # shape: (b + 1,)
req_pool_indices, kv_indptr: torch.Tensor = None
req_to_token, kv_indices: torch.Tensor = None
paged_kernel_lens,
kv_indices,
kv_indptr,
positions,
pool_len: tl.constexpr,
kv_indices_stride: tl.constexpr,
kv_indptr_stride: tl.constexpr,
bs_upper: tl.constexpr,
iter_upper: tl.constexpr,
num_tokens_upper: tl.constexpr,
page_size: tl.constexpr,
):
BLOCK_SIZE: tl.constexpr = 128
iters = tl.program_id(axis=0)
bid = tl.program_id(axis=1)
topk_id = tl.program_id(axis=2)
num_steps = tl.num_programs(axis=0)
num_seqs = tl.num_programs(axis=1)
topk = tl.num_programs(axis=2)
kv_indices += kv_indices_stride * iters
kv_indptr += kv_indptr_stride * iters
iters += 1
load_offset = tl.arange(0, bs_upper)
seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid, other=0)
seq_len = tl.load(paged_kernel_lens + bid)
cum_seq_len = tl.sum(seq_lens)
# Update kv_indices
kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters)
kv_ptr = kv_indices + kv_offset
token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len
kv_offset = tl.arange(0, BLOCK_SIZE)
num_loop = tl.cdiv(seq_len, BLOCK_SIZE)
for _ in range(num_loop):
mask = kv_offset < seq_len
data = tl.load(token_pool_ptr + kv_offset, mask=mask)
tl.store(kv_ptr + kv_offset, data, mask=mask)
kv_offset += BLOCK_SIZE
extend_offset = tl.arange(0, iter_upper)
if page_size == 1 or topk == 1:
extend_data = tl.load(
token_pool_ptr + seq_len + topk_id * num_steps + tl.arange(0, iter_upper),
mask=extend_offset < iters,
)
else:
prefix_len = seq_len
last_page_len = prefix_len % page_size
num_new_pages_per_topk = (
last_page_len + num_steps + page_size - 1
) // page_size
prefix_base = seq_len // page_size * page_size
start = (
prefix_base + topk_id * num_new_pages_per_topk * page_size + last_page_len
)
extend_data = tl.load(
token_pool_ptr + start + extend_offset,
mask=extend_offset < iters,
)
tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters) # Shape info for padding
num_tokens_per_batch: int = -1
# Update kv_indptr num_tokens_for_logprob_per_batch: int = -1
bs_offset = tl.arange(0, num_tokens_upper)
zid = bid * topk + topk_id
if zid == 0:
zid = num_seqs * topk
positions = tl.load(positions + bs_offset, mask=bs_offset < zid, other=0)
base = tl.sum(positions)
tl.store(kv_indptr + zid, base + zid * iters)
@triton.jit
def align_evict_mask_to_page_size(
seq_lens,
evict_mask,
page_size: tl.constexpr,
num_draft_tokens: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
t_range = tl.arange(0, BLOCK_SIZE)
bid = tl.program_id(axis=0)
seq_len = tl.load(seq_lens + bid)
io_mask = t_range < num_draft_tokens
mask_row = tl.load(
evict_mask + bid * num_draft_tokens + t_range, mask=io_mask, other=0
)
num_trues = tl.sum(mask_row) # Inputs for draft extend
num_false = num_draft_tokens - num_trues # shape: (b,)
seq_lens_for_draft_extend: torch.Tensor = None
start = (seq_len + num_false - 1) // page_size * page_size - seq_len req_pool_indices_for_draft_extend: torch.Tensor = None
for i in range(max(start, 0), min(start + page_size, num_draft_tokens)):
tl.store(evict_mask + bid * num_draft_tokens + i, False)
@triton.jit
def get_target_cache_loc(
tgt_cache_loc,
to_free_slots,
accept_length,
to_free_num_slots,
out_cache_loc,
num_verify_tokens: tl.constexpr,
num_verify_tokens_upper: tl.constexpr,
bs_upper: tl.constexpr,
):
bid = tl.program_id(axis=0)
offset = tl.arange(0, num_verify_tokens_upper)
bs_offset = tl.arange(0, bs_upper)
# write the first part to tgt_cache_loc
accept_len_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
tgt_cache_loc_start = tl.sum(accept_len_all) + bid
copy_len = tl.load(accept_length + bid) + 1
out_cache_loc_row = tl.load(
out_cache_loc + bid * num_verify_tokens + offset, mask=offset < copy_len
)
tl.store(
tgt_cache_loc + tgt_cache_loc_start + offset,
out_cache_loc_row,
mask=offset < copy_len,
)
# write the second part to to_free_num_pages def __post_init__(self):
to_free_num_slots_all = tl.load(to_free_num_slots + bs_offset, mask=bs_offset < bid) super().__init__(SpecInputType.EAGLE_DRAFT)
to_free_num_slots_cur = tl.load(to_free_num_slots + bid)
out_cache_loc_start = num_verify_tokens - to_free_num_slots_cur
to_free_slots_start = tl.sum(to_free_num_slots_all)
copy_len = to_free_num_slots_cur def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
out_cache_loc_row = tl.load( return self.num_tokens_per_batch, self.num_tokens_for_logprob_per_batch
out_cache_loc + bid * num_verify_tokens + out_cache_loc_start + offset,
mask=offset < copy_len,
)
tl.store(
to_free_slots + to_free_slots_start + offset,
out_cache_loc_row,
mask=offset < copy_len,
)
def prepare_for_extend(self, batch: ScheduleBatch):
@torch.compile(dynamic=True) if batch.forward_mode.is_idle():
def get_src_tgt_cache_loc( return
seq_lens: torch.Tensor,
out_cache_loc: torch.Tensor,
accept_index: torch.Tensor,
accept_length: torch.Tensor,
draft_token_num: int,
page_size: int,
):
src_cache_loc = out_cache_loc[accept_index]
tgt_cache_loc = torch.empty_like(src_cache_loc)
extended_len = seq_lens + draft_token_num
keep_len = torch.minimum(
(seq_lens + accept_length + 1 + page_size - 1) // page_size * page_size,
extended_len,
)
to_free_num_slots = extended_len - keep_len
return src_cache_loc, tgt_cache_loc, to_free_num_slots
@triton.jit
def filter_finished_cache_loc_kernel(
out_cache_loc,
tgt_cache_loc,
accept_length,
accept_length_filter,
bs_upper: tl.constexpr,
num_verify_tokens_upper: tl.constexpr,
):
bid = tl.program_id(0)
bs_offset = tl.arange(0, bs_upper)
accept_length_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
old_start = tl.sum(accept_length_all) + bid
accept_length_filter_all = tl.load(
accept_length_filter + bs_offset, mask=bs_offset < bid
)
new_start = tl.sum(accept_length_filter_all)
copy_len = tl.load(accept_length_filter + bid) # Prefill only generate 1 token.
copy_offset = tl.arange(0, num_verify_tokens_upper) assert len(self.verified_id) == len(batch.seq_lens)
value = tl.load(
tgt_cache_loc + old_start + copy_offset, mask=copy_offset < copy_len
)
tl.store(
out_cache_loc + new_start + copy_offset, value, mask=copy_offset < copy_len
)
pt = 0
for i, extend_len in enumerate(batch.extend_lens):
input_ids = batch.input_ids[pt : pt + extend_len]
batch.input_ids[pt : pt + extend_len] = torch.cat(
(input_ids[1:], self.verified_id[i].reshape(1))
)
pt += extend_len
@torch.compile(dynamic=True) @classmethod
def create_accept_length_filter( def create_idle_input(
accept_length: torch.Tensor, cls,
unfinished_index_device: torch.Tensor, device: torch.device,
seq_lens: torch.Tensor, hidden_size: int,
): dtype: torch.dtype,
accept_length_filter = torch.zeros_like(accept_length) topk: int,
accept_length_filter[unfinished_index_device] = ( capture_hidden_mode: CaptureHiddenMode,
accept_length[unfinished_index_device] + 1 ):
) return cls(
seq_lens.add_(accept_length + 1) verified_id=torch.empty((0,), device=device, dtype=torch.int32),
return accept_length_filter hidden_states=torch.empty((0, hidden_size), device=device, dtype=dtype),
topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
@torch.compile(dynamic=True) capture_hidden_mode=capture_hidden_mode,
def select_top_k_tokens( accept_length=torch.empty((0,), device=device, dtype=torch.int32),
i: int, accept_length_cpu=[],
topk_p: torch.Tensor,
topk_index: torch.Tensor,
hidden_states: torch.Tensor,
scores: torch.Tensor,
topk: int,
):
if i == 0:
# The first step after extend
input_ids = topk_index.flatten()
hidden_states = hidden_states.repeat_interleave(topk, dim=0)
scores = topk_p # shape: (b, topk)
tree_info = (
topk_p.unsqueeze(1), # shape: (b, 1, topk)
topk_index, # shape: (b, topk)
torch.arange(-1, topk, dtype=torch.long, device="cuda")
.unsqueeze(0)
.repeat(topk_p.shape[0], 1), # shape: (b, topk + 1)
)
else:
# The later decode steps
expand_scores = torch.mul(
scores.unsqueeze(2), topk_p.reshape(-1, topk, topk)
) # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk)
topk_cs_p, topk_cs_index = fast_topk(
expand_scores.flatten(start_dim=1), topk, dim=-1
) # (b, topk)
scores = topk_cs_p # shape: (b, topk)
topk_index = topk_index.reshape(-1, topk**2)
input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten()
if hidden_states.shape[0] > 0:
selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
0, hidden_states.shape[0], step=topk, device="cuda"
).repeat_interleave(topk)
hidden_states = hidden_states[selected_input_index, :]
tree_info = (
expand_scores, # shape: (b, topk, topk)
topk_index, # shape: (b, topk * topk)
topk_cs_index + (topk**2 * (i - 1) + topk), # shape: (b, topk)
) )
return input_ids, hidden_states, scores, tree_info def prepare_extend_after_decode(
self,
batch: ScheduleBatch,
def _generate_simulated_accept_index( speculative_num_steps: int,
accept_index, ):
predict,
accept_length, if batch.forward_mode.is_idle():
bs, return
spec_steps,
simulate_acc_len: float = SIMULATE_ACC_LEN, batch.input_ids = self.verified_id
simulate_acc_method: str = SIMULATE_ACC_METHOD, batch.extend_lens = [x + 1 for x in batch.spec_info.accept_length_cpu]
): batch.extend_num_tokens = sum(batch.extend_lens)
assert simulate_acc_len > 0.0 batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
if simulate_acc_method == "multinomial": batch.return_logprob = False
simulated_values = torch.normal( batch.return_hidden_states = False
mean=simulate_acc_len,
std=1.0,
size=(1,),
device="cpu",
)
# clamp simulated values to be between 1 and self.spec_steps
simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps + 1)
simulate_acc_len = int(simulated_values.round().item())
elif simulate_acc_method == "match-expected":
# multinomial sampling does not match the expected length
# we keep it for the sake of compatibility of existing tests
# but it's better to use "match-expected" for the cases that need to
# match the expected length, One caveat is that this will only sample
# either round down or round up of the expected length
simulate_acc_len = max(1.0, min(spec_steps + 1, simulate_acc_len))
lower = int(simulate_acc_len // 1)
upper = lower + 1 if lower < spec_steps + 1 else lower
if lower == upper:
simulate_acc_len = lower
else:
weight_upper = simulate_acc_len - lower
weight_lower = 1.0 - weight_upper
probs = torch.tensor([weight_lower, weight_upper], device="cpu")
sampled_index = torch.multinomial(probs, num_samples=1)
simulate_acc_len = lower if sampled_index == 0 else upper
else:
raise ValueError(f"Invalid simulate_acc_method: {SIMULATE_ACC_METHOD}")
accept_indx_first_col = accept_index[:, 0].view(-1, 1)
sim_accept_index = torch.full(
(bs, spec_steps + 1), -1, dtype=torch.int32, device="cuda"
)
sim_accept_index[:, :simulate_acc_len] = accept_indx_first_col + torch.arange(
simulate_acc_len, device=accept_index.device
)
accept_length.fill_(simulate_acc_len - 1)
predict.fill_(100) # some legit token id
return sim_accept_index
def traverse_tree(
retrieve_next_token: torch.Tensor,
retrieve_next_sibling: torch.Tensor,
draft_tokens: torch.Tensor,
grammar: BaseGrammarObject,
allocate_token_bitmask: torch.Tensor,
):
"""
Traverse the tree constructed by the draft model to generate the logits mask.
"""
assert (
retrieve_next_token.shape == retrieve_next_sibling.shape == draft_tokens.shape
)
allocate_token_bitmask.fill_(0) self.capture_hidden_mode = CaptureHiddenMode.LAST
self.accept_length.add_(1)
self.positions = torch.empty_like(batch.input_ids, dtype=torch.long)
self.verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
def dfs( create_extend_after_decode_spec_info[(len(batch.seq_lens),)](
curr: int, batch.input_ids,
retrieve_next_token: torch.Tensor, batch.seq_lens,
retrieve_next_sibling: torch.Tensor, self.accept_length,
parent_pos: int, self.positions,
self.verified_id,
next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))),
)
def generate_attn_arg_prefill(
self,
req_pool_indices: torch.Tensor,
paged_kernel_lens: torch.Tensor,
paged_kernel_lens_sum: int,
req_to_token: torch.Tensor,
): ):
if curr == 0: bs = self.accept_length.numel()
# the first token generated by the target model, and thus it is always qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
# accepted from the previous iteration qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
accepted = True cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
else: cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
parent_bitmask = allocate_token_bitmask[parent_pos]
curr_token_id = draft_tokens[curr]
# 32 boolean bitmask values are packed into 32-bit integers
accepted = (
parent_bitmask[curr_token_id // 32] & (1 << (curr_token_id % 32))
) != 0
if accepted:
if curr != 0:
# Accept the current token
grammar.accept_token(draft_tokens[curr])
if not grammar.is_terminated():
# Generate the bitmask for the current token
grammar.fill_vocab_mask(allocate_token_bitmask, curr)
if retrieve_next_token[curr] != -1:
# Visit the child node
dfs(
retrieve_next_token[curr],
retrieve_next_token,
retrieve_next_sibling,
curr,
)
if curr != 0: if paged_kernel_lens_sum is None:
# Rollback the current token paged_kernel_lens_sum = cum_kv_seq_len[-1]
grammar.rollback(1)
if retrieve_next_sibling[curr] != -1:
# Visit the sibling node
dfs(
retrieve_next_sibling[curr],
retrieve_next_token,
retrieve_next_sibling,
parent_pos,
)
dfs(0, retrieve_next_token, retrieve_next_sibling, -1) kv_indices = torch.empty(
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
)
def generate_token_bitmask(
reqs: List[Req], create_flashinfer_kv_indices_triton[(bs,)](
verify_input: EagleVerifyInput, req_to_token,
retrieve_next_token_cpu: torch.Tensor, req_pool_indices,
retrieve_next_sibling_cpu: torch.Tensor, paged_kernel_lens,
draft_tokens_cpu: torch.Tensor, cum_kv_seq_len,
vocab_size: int, None,
): kv_indices,
""" req_to_token.size(1),
Generate the logit mask for structured output. )
Draft model's token can be either valid or invalid with respect to the grammar. return kv_indices, cum_kv_seq_len, qo_indptr, None
We need to perform DFS to
1. figure out which tokens are accepted by the grammar. def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
2. if so, what is the corresponding logit mask. if has_been_filtered:
""" # in eagle_utils.py:verify, we have already filtered the batch by `unfinished_index`
# therefore, we don't need to filter the batch again in scheduler
num_draft_tokens = draft_tokens_cpu.shape[-1] if len(new_indices) != len(self.topk_p):
allocate_token_bitmask = None
assert len(reqs) == retrieve_next_token_cpu.shape[0]
grammar = None
for i, req in enumerate(reqs):
if req.grammar is not None:
if allocate_token_bitmask is None:
allocate_token_bitmask = req.grammar.allocate_vocab_mask(
vocab_size=vocab_size,
batch_size=draft_tokens_cpu.numel(),
device="cpu",
)
grammar = req.grammar
s = time.perf_counter()
traverse_tree(
retrieve_next_token_cpu[i],
retrieve_next_sibling_cpu[i],
draft_tokens_cpu[i],
req.grammar,
allocate_token_bitmask[
i * num_draft_tokens : (i + 1) * num_draft_tokens
],
)
tree_traverse_time = time.perf_counter() - s
if tree_traverse_time > TREE_TRAVERSE_TIME_THRESHOLD:
logger.warning( logger.warning(
f"Bit mask generation took {tree_traverse_time} seconds with " f"length of new_indices: {len(new_indices)} != length of topk_p: {len(self.topk_p)}, this should not happen"
f"grammar: {req.grammar}"
) )
self.topk_p = self.topk_p[: len(new_indices)]
self.topk_index = self.topk_index[: len(new_indices)]
self.hidden_states = self.hidden_states[: len(new_indices)]
self.verified_id = self.verified_id[: len(new_indices)]
else:
# in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index`
self.topk_p = self.topk_p[new_indices]
self.topk_index = self.topk_index[new_indices]
self.hidden_states = self.hidden_states[new_indices]
self.verified_id = self.verified_id[new_indices]
def merge_batch(self, spec_info: "EagleDraftInput"):
if self.hidden_states is None:
self.hidden_states = spec_info.hidden_states
self.verified_id = spec_info.verified_id
self.topk_p = spec_info.topk_p
self.topk_index = spec_info.topk_index
return
if spec_info.hidden_states is None:
return
self.hidden_states = torch.cat(
[self.hidden_states, spec_info.hidden_states], axis=0
)
self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
self.topk_p = torch.cat([self.topk_p, spec_info.topk_p])
self.topk_index = torch.cat([self.topk_index, spec_info.topk_index])
verify_input.grammar = grammar @dataclass
return allocate_token_bitmask class EagleVerifyOutput:
# Draft input batch
draft_input: EagleDraftInput
# Logit outputs from target worker
logits_output: LogitsProcessorOutput
# Accepted token ids including the bonus token
verified_id: torch.Tensor
# Accepted token length per sequence in a batch in CPU.
accept_length_per_req_cpu: List[int]
# Accepted indices from logits_output.next_token_logits
accepted_indices: torch.Tensor
@@ -34,16 +34,18 @@ from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
 from sglang.srt.speculative.eagle_draft_extend_cuda_graph_runner import (
     EAGLEDraftExtendCudaGraphRunner,
 )
-from sglang.srt.speculative.eagle_utils import (
+from sglang.srt.speculative.eagle_info import (
     EagleDraftInput,
     EagleVerifyInput,
     EagleVerifyOutput,
+)
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.speculative.spec_utils import (
     assign_draft_cache_locs,
     fast_topk,
     generate_token_bitmask,
     select_top_k_tokens,
 )
-from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     empty_context,
     get_available_gpu_memory,
......
@@ -2,7 +2,7 @@ from __future__ import annotations

 import copy
 import logging
-from typing import Optional
+from typing import Optional, Tuple

 import torch
 import triton
@@ -13,6 +13,7 @@ from dataclasses import dataclass

 import torch.nn.functional as F

+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import apply_custom_logit_processor
 from sglang.srt.managers.schedule_batch import (
@@ -21,10 +22,10 @@ from sglang.srt.managers.schedule_batch import (
     global_server_args_dict,
 )
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.speculative.eagle_utils import (
+from sglang.srt.speculative.spec_info import SpecInput, SpecInputType
+from sglang.srt.speculative.spec_utils import (
     TREE_SPEC_KERNEL_AVAILABLE,
     assign_req_to_token_pool,
-    create_flashinfer_kv_indices_triton,
     get_src_tgt_cache_loc,
     get_target_cache_loc,
 )
@@ -42,7 +43,7 @@ elif is_hip():

 @dataclass
-class NgramVerifyInput:
+class NgramVerifyInput(SpecInput):
     def __init__(
         self,
         draft_token: torch.Tensor,
@@ -53,6 +54,7 @@ class NgramVerifyInput:
         retrive_next_sibling: torch.Tensor,
         draft_token_num: int,
     ):
+        super().__init__(SpecInputType.NGRAM_VERIFY)
         self.draft_token = draft_token
         self.custom_mask = tree_mask
         self.positions = positions
@@ -62,6 +64,9 @@ class NgramVerifyInput:
         self.draft_token_num = draft_token_num
         self.device = self.custom_mask.device

+    def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
+        return self.draft_token_num, self.draft_token_num
+
     def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
         if batch.forward_mode.is_idle():
             return
......
 import logging
-import os
-import threading
-import time
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import List, Optional

 import numpy as np
 import torch
@@ -15,7 +12,6 @@ from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.cpp_ngram.ngram_cache import NgramCache
 from sglang.srt.speculative.ngram_utils import NgramVerifyInput
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
-from sglang.srt.utils import broadcast_pyobj

 logger = logging.getLogger(__name__)
......
+from abc import ABC, abstractmethod
 from enum import IntEnum, auto
+from typing import List, Tuple
+
+from sglang.srt.managers.schedule_batch import ModelWorkerBatch


 class SpeculativeAlgorithm(IntEnum):
@@ -35,3 +39,41 @@ class SpeculativeAlgorithm(IntEnum):
         if name is not None:
             name = name.upper()
         return name_map[name]
+
+
+class SpecInputType(IntEnum):
+    # NOTE: introduced to distinguish the SpecInput types of multiple algorithms
+    # when asserting in attention backends. If all algorithms can share the same
+    # data structure for draft_input and verify_input, consider simplifying this.
+    EAGLE_DRAFT = auto()
+    EAGLE_VERIFY = auto()
+    NGRAM_VERIFY = auto()
+
+
+class SpecInput(ABC):
+    def __init__(self, spec_input_type: SpecInputType):
+        self.spec_input_type = spec_input_type
+
+    def is_draft_input(self) -> bool:
+        # FIXME: remove this function, which is only used for assertions,
+        # or use another variable name like `draft_input` in place of `spec_info`
+        return self.spec_input_type == SpecInputType.EAGLE_DRAFT
+
+    def is_verify_input(self) -> bool:
+        return self.spec_input_type in {
+            SpecInputType.EAGLE_VERIFY,
+            SpecInputType.NGRAM_VERIFY,
+        }
+
+    @abstractmethod
+    def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
+        pass
+
+    def get_spec_adjusted_global_num_tokens(
+        self, forward_batch: ModelWorkerBatch
+    ) -> Tuple[List[int], List[int]]:
+        c1, c2 = self.get_spec_adjust_token_coefficient()
+        global_num_tokens = [x * c1 for x in forward_batch.global_num_tokens]
+        global_num_tokens_for_logprob = [
+            x * c2 for x in forward_batch.global_num_tokens_for_logprob
+        ]
+        return global_num_tokens, global_num_tokens_for_logprob
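As a quick usage sketch of the contract above: a concrete `SpecInput` only needs to supply the coefficient pair; the type predicates and the batch scaling come from the base class. The subclass and stand-in batch below are hypothetical and assume the `sglang` package is importable:

```python
# Toy demonstration of the SpecInput contract (hypothetical subclass; the real
# implementations live in eagle_info.py and ngram_utils.py).
from types import SimpleNamespace
from typing import Tuple

from sglang.srt.speculative.spec_info import SpecInput, SpecInputType


class ToyVerifyInput(SpecInput):
    def __init__(self, draft_token_num: int):
        super().__init__(SpecInputType.EAGLE_VERIFY)
        self.draft_token_num = draft_token_num

    def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
        return self.draft_token_num, self.draft_token_num


info = ToyVerifyInput(draft_token_num=4)
assert info.is_verify_input() and not info.is_draft_input()

# The base-class method only reads the two list attributes, so a stand-in
# object works here in place of a real ModelWorkerBatch.
batch = SimpleNamespace(global_num_tokens=[8, 6], global_num_tokens_for_logprob=[8, 6])
assert info.get_spec_adjusted_global_num_tokens(batch) == ([32, 24], [32, 24])
```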
from __future__ import annotations
import logging
import os
import time
from typing import TYPE_CHECKING, List
import torch
import triton
import triton.language as tl
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.environ import envs
from sglang.srt.managers.schedule_batch import Req
from sglang.srt.utils import is_cuda, is_hip
if is_cuda():
from sgl_kernel import fast_topk
elif is_hip():
from sgl_kernel import fast_topk
if TYPE_CHECKING:
from sglang.srt.speculative.eagle_info import EagleVerifyInput
logger = logging.getLogger(__name__)
# Simulate acceptance length for benchmarking purposes
SIMULATE_ACC_LEN = envs.SGLANG_SIMULATE_ACC_LEN.get() # turn off if < 0
SIMULATE_ACC_METHOD = envs.SGLANG_SIMULATE_ACC_METHOD.get()
TREE_TRAVERSE_TIME_THRESHOLD = 1 # TODO: set this properly
TREE_SPEC_KERNEL_AVAILABLE = "tree_speculative_sampling_target_only" in globals()
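Since these constants are read from the environment when `spec_utils` is imported, a benchmarking run that wants simulated acceptance lengths has to set the variables before importing sglang. A minimal sketch (the values shown are examples, not defaults):

```python
# Example only: enable simulated acceptance lengths for benchmarking.
# Set the variables before any sglang import, because spec_utils reads them
# into module-level constants at import time.
import os

os.environ["SGLANG_SIMULATE_ACC_LEN"] = "3.0"  # expected accepted length; negative values disable it
os.environ["SGLANG_SIMULATE_ACC_METHOD"] = "match-expected"  # or "multinomial"

import sglang  # noqa: E402  (imported after the environment is set)
```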
@triton.jit
def create_extend_after_decode_spec_info(
verified_id,
seq_lens,
accept_lens,
positions,
new_verified_id,
bs_upper: tl.constexpr,
):
pid = tl.program_id(axis=0)
offsets = tl.arange(0, bs_upper)
seq_length = tl.load(seq_lens + pid)
accept_length = tl.load(accept_lens + pid)
accept_len_cumsum = tl.sum(
tl.load(accept_lens + offsets, mask=offsets < pid, other=0)
)
positions_ptr = positions + accept_len_cumsum
mask = offsets < accept_length
tl.store(positions_ptr + offsets, seq_length - accept_length + offsets, mask)
accept_len_cumsum += accept_length - 1
verified_id_data = tl.load(verified_id + accept_len_cumsum)
tl.store(new_verified_id + pid, verified_id_data)
@triton.jit
def assign_req_to_token_pool(
req_pool_indices,
req_to_token,
start_offset,
end_offset,
out_cache_loc,
pool_len: tl.constexpr,
bs_upper: tl.constexpr,
):
BLOCK_SIZE: tl.constexpr = 32
pid = tl.program_id(axis=0)
kv_start = tl.load(start_offset + pid)
kv_end = tl.load(end_offset + pid)
token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
length_offset = tl.arange(0, bs_upper)
start = tl.load(start_offset + length_offset, mask=length_offset < pid, other=0)
end = tl.load(end_offset + length_offset, mask=length_offset < pid, other=0)
out_offset = tl.sum(end - start, axis=0)
out_cache_ptr = out_cache_loc + out_offset
save_offset = tl.arange(0, BLOCK_SIZE) + kv_start
load_offset = tl.arange(0, BLOCK_SIZE)
num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
for _ in range(num_loop):
mask = save_offset < kv_end
data = tl.load(out_cache_ptr + load_offset, mask=mask)
tl.store(token_pool + save_offset, data, mask=mask)
save_offset += BLOCK_SIZE
load_offset += BLOCK_SIZE
@triton.jit
def assign_draft_cache_locs(
req_pool_indices,
req_to_token,
seq_lens,
extend_lens,
num_new_pages_per_topk,
out_cache_loc,
pool_len: tl.constexpr,
topk: tl.constexpr,
speculative_num_steps: tl.constexpr,
page_size: tl.constexpr,
bs_upper: tl.constexpr,
iter_upper: tl.constexpr,
):
BLOCK_SIZE: tl.constexpr = 128
pid = tl.program_id(axis=0)
if page_size == 1 or topk == 1:
copy_len = topk * speculative_num_steps
out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
else:
bs_offset = tl.arange(0, bs_upper)
copy_len = tl.load(extend_lens + pid)
cum_copy_len = tl.sum(tl.load(extend_lens + bs_offset, mask=bs_offset < pid))
out_cache_ptr = out_cache_loc + cum_copy_len
# Part 1: Copy from out_cache_loc to req_to_token
kv_start = tl.load(seq_lens + pid)
token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
num_loop = tl.cdiv(copy_len, BLOCK_SIZE)
for i in range(num_loop):
copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
mask = copy_offset < copy_len
data = tl.load(out_cache_ptr + copy_offset, mask=mask)
tl.store(token_pool + kv_start + copy_offset, data, mask=mask)
if page_size == 1 or topk == 1:
return
# Part 2: Copy the indices for the last partial page
prefix_len = tl.load(seq_lens + pid)
last_page_len = prefix_len % page_size
offsets = tl.arange(0, page_size)
mask = offsets < last_page_len
num_new_pages_per_topk_ = tl.load(num_new_pages_per_topk + pid)
prefix_base = token_pool + prefix_len - last_page_len
for topk_id in range(topk):
value = tl.load(prefix_base + offsets, mask=mask)
tl.store(
prefix_base + topk_id * num_new_pages_per_topk_ * page_size + offsets,
value,
mask=mask,
)
# Part 3: Remove the padding in out_cache_loc
iter_offest = tl.arange(0, iter_upper)
for topk_id in range(topk):
indices = tl.load(
prefix_base
+ topk_id * num_new_pages_per_topk_ * page_size
+ last_page_len
+ iter_offest,
mask=iter_offest < speculative_num_steps,
)
tl.store(
out_cache_loc
+ pid * topk * speculative_num_steps
+ topk_id * speculative_num_steps
+ iter_offest,
indices,
mask=iter_offest < speculative_num_steps,
)
@triton.jit
def generate_draft_decode_kv_indices(
req_pool_indices,
req_to_token,
paged_kernel_lens,
kv_indices,
kv_indptr,
positions,
pool_len: tl.constexpr,
kv_indices_stride: tl.constexpr,
kv_indptr_stride: tl.constexpr,
bs_upper: tl.constexpr,
iter_upper: tl.constexpr,
num_tokens_upper: tl.constexpr,
page_size: tl.constexpr,
):
BLOCK_SIZE: tl.constexpr = 128
iters = tl.program_id(axis=0)
bid = tl.program_id(axis=1)
topk_id = tl.program_id(axis=2)
num_steps = tl.num_programs(axis=0)
num_seqs = tl.num_programs(axis=1)
topk = tl.num_programs(axis=2)
kv_indices += kv_indices_stride * iters
kv_indptr += kv_indptr_stride * iters
iters += 1
load_offset = tl.arange(0, bs_upper)
seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid, other=0)
seq_len = tl.load(paged_kernel_lens + bid)
cum_seq_len = tl.sum(seq_lens)
# Update kv_indices
kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters)
kv_ptr = kv_indices + kv_offset
token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len
kv_offset = tl.arange(0, BLOCK_SIZE)
num_loop = tl.cdiv(seq_len, BLOCK_SIZE)
for _ in range(num_loop):
mask = kv_offset < seq_len
data = tl.load(token_pool_ptr + kv_offset, mask=mask)
tl.store(kv_ptr + kv_offset, data, mask=mask)
kv_offset += BLOCK_SIZE
extend_offset = tl.arange(0, iter_upper)
if page_size == 1 or topk == 1:
extend_data = tl.load(
token_pool_ptr + seq_len + topk_id * num_steps + tl.arange(0, iter_upper),
mask=extend_offset < iters,
)
else:
prefix_len = seq_len
last_page_len = prefix_len % page_size
num_new_pages_per_topk = (
last_page_len + num_steps + page_size - 1
) // page_size
prefix_base = seq_len // page_size * page_size
start = (
prefix_base + topk_id * num_new_pages_per_topk * page_size + last_page_len
)
extend_data = tl.load(
token_pool_ptr + start + extend_offset,
mask=extend_offset < iters,
)
tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters)
# Update kv_indptr
bs_offset = tl.arange(0, num_tokens_upper)
zid = bid * topk + topk_id
if zid == 0:
zid = num_seqs * topk
positions = tl.load(positions + bs_offset, mask=bs_offset < zid, other=0)
base = tl.sum(positions)
tl.store(kv_indptr + zid, base + zid * iters)
@triton.jit
def align_evict_mask_to_page_size(
seq_lens,
evict_mask,
page_size: tl.constexpr,
num_draft_tokens: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
t_range = tl.arange(0, BLOCK_SIZE)
bid = tl.program_id(axis=0)
seq_len = tl.load(seq_lens + bid)
io_mask = t_range < num_draft_tokens
mask_row = tl.load(
evict_mask + bid * num_draft_tokens + t_range, mask=io_mask, other=0
)
num_trues = tl.sum(mask_row)
num_false = num_draft_tokens - num_trues
start = (seq_len + num_false - 1) // page_size * page_size - seq_len
for i in range(max(start, 0), min(start + page_size, num_draft_tokens)):
tl.store(evict_mask + bid * num_draft_tokens + i, False)
@triton.jit
def get_target_cache_loc(
tgt_cache_loc,
to_free_slots,
accept_length,
to_free_num_slots,
out_cache_loc,
num_verify_tokens: tl.constexpr,
num_verify_tokens_upper: tl.constexpr,
bs_upper: tl.constexpr,
):
bid = tl.program_id(axis=0)
offset = tl.arange(0, num_verify_tokens_upper)
bs_offset = tl.arange(0, bs_upper)
# write the first part to tgt_cache_loc
accept_len_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
tgt_cache_loc_start = tl.sum(accept_len_all) + bid
copy_len = tl.load(accept_length + bid) + 1
out_cache_loc_row = tl.load(
out_cache_loc + bid * num_verify_tokens + offset, mask=offset < copy_len
)
tl.store(
tgt_cache_loc + tgt_cache_loc_start + offset,
out_cache_loc_row,
mask=offset < copy_len,
)
# write the second part to to_free_num_pages
to_free_num_slots_all = tl.load(to_free_num_slots + bs_offset, mask=bs_offset < bid)
to_free_num_slots_cur = tl.load(to_free_num_slots + bid)
out_cache_loc_start = num_verify_tokens - to_free_num_slots_cur
to_free_slots_start = tl.sum(to_free_num_slots_all)
copy_len = to_free_num_slots_cur
out_cache_loc_row = tl.load(
out_cache_loc + bid * num_verify_tokens + out_cache_loc_start + offset,
mask=offset < copy_len,
)
tl.store(
to_free_slots + to_free_slots_start + offset,
out_cache_loc_row,
mask=offset < copy_len,
)
@torch.compile(dynamic=True)
def get_src_tgt_cache_loc(
seq_lens: torch.Tensor,
out_cache_loc: torch.Tensor,
accept_index: torch.Tensor,
accept_length: torch.Tensor,
draft_token_num: int,
page_size: int,
):
src_cache_loc = out_cache_loc[accept_index]
tgt_cache_loc = torch.empty_like(src_cache_loc)
extended_len = seq_lens + draft_token_num
keep_len = torch.minimum(
(seq_lens + accept_length + 1 + page_size - 1) // page_size * page_size,
extended_len,
)
to_free_num_slots = extended_len - keep_len
return src_cache_loc, tgt_cache_loc, to_free_num_slots
@triton.jit
def filter_finished_cache_loc_kernel(
out_cache_loc,
tgt_cache_loc,
accept_length,
accept_length_filter,
bs_upper: tl.constexpr,
num_verify_tokens_upper: tl.constexpr,
):
bid = tl.program_id(0)
bs_offset = tl.arange(0, bs_upper)
accept_length_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
old_start = tl.sum(accept_length_all) + bid
accept_length_filter_all = tl.load(
accept_length_filter + bs_offset, mask=bs_offset < bid
)
new_start = tl.sum(accept_length_filter_all)
copy_len = tl.load(accept_length_filter + bid)
copy_offset = tl.arange(0, num_verify_tokens_upper)
value = tl.load(
tgt_cache_loc + old_start + copy_offset, mask=copy_offset < copy_len
)
tl.store(
out_cache_loc + new_start + copy_offset, value, mask=copy_offset < copy_len
)
@torch.compile(dynamic=True)
def create_accept_length_filter(
accept_length: torch.Tensor,
unfinished_index_device: torch.Tensor,
seq_lens: torch.Tensor,
):
accept_length_filter = torch.zeros_like(accept_length)
accept_length_filter[unfinished_index_device] = (
accept_length[unfinished_index_device] + 1
)
seq_lens.add_(accept_length + 1)
return accept_length_filter
@torch.compile(dynamic=True)
def select_top_k_tokens(
i: int,
topk_p: torch.Tensor,
topk_index: torch.Tensor,
hidden_states: torch.Tensor,
scores: torch.Tensor,
topk: int,
):
if i == 0:
# The first step after extend
input_ids = topk_index.flatten()
hidden_states = hidden_states.repeat_interleave(topk, dim=0)
scores = topk_p # shape: (b, topk)
tree_info = (
topk_p.unsqueeze(1), # shape: (b, 1, topk)
topk_index, # shape: (b, topk)
torch.arange(-1, topk, dtype=torch.long, device="cuda")
.unsqueeze(0)
.repeat(topk_p.shape[0], 1), # shape: (b, topk + 1)
)
else:
# The later decode steps
expand_scores = torch.mul(
scores.unsqueeze(2), topk_p.reshape(-1, topk, topk)
) # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk)
topk_cs_p, topk_cs_index = fast_topk(
expand_scores.flatten(start_dim=1), topk, dim=-1
) # (b, topk)
scores = topk_cs_p # shape: (b, topk)
topk_index = topk_index.reshape(-1, topk**2)
input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten()
if hidden_states.shape[0] > 0:
selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
0, hidden_states.shape[0], step=topk, device="cuda"
).repeat_interleave(topk)
hidden_states = hidden_states[selected_input_index, :]
tree_info = (
expand_scores, # shape: (b, topk, topk)
topk_index, # shape: (b, topk * topk)
topk_cs_index + (topk**2 * (i - 1) + topk), # shape: (b, topk)
)
return input_ids, hidden_states, scores, tree_info
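For reference, a small self-contained shape sketch of the `i == 0` branch of `select_top_k_tokens` above (CPU tensors and made-up sizes; the real code runs on CUDA and needs `sgl_kernel.fast_topk` only in the later decode steps):

```python
# Shape sketch for the first (i == 0) step after extend; illustrative only.
import torch

b, topk, hidden = 2, 4, 8
topk_p = torch.rand(b, topk)                     # (b, topk) draft probabilities
topk_index = torch.randint(0, 32000, (b, topk))  # (b, topk) draft token ids
hidden_states = torch.rand(b, hidden)            # (b, hidden_size)

input_ids = topk_index.flatten()                              # (b * topk,)
hidden_states = hidden_states.repeat_interleave(topk, dim=0)  # (b * topk, hidden_size)
scores = topk_p                                               # (b, topk)
assert input_ids.shape == (b * topk,) and hidden_states.shape == (b * topk, hidden)
```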
def _generate_simulated_accept_index(
accept_index,
predict,
accept_length,
bs,
spec_steps,
simulate_acc_len: float = SIMULATE_ACC_LEN,
simulate_acc_method: str = SIMULATE_ACC_METHOD,
):
assert simulate_acc_len > 0.0
if simulate_acc_method == "multinomial":
simulated_values = torch.normal(
mean=simulate_acc_len,
std=1.0,
size=(1,),
device="cpu",
)
# clamp simulated values to be between 1 and self.spec_steps
simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps + 1)
simulate_acc_len = int(simulated_values.round().item())
elif simulate_acc_method == "match-expected":
# multinomial sampling does not match the expected length
# we keep it for the sake of compatibility of existing tests
# but it's better to use "match-expected" for the cases that need to
# match the expected length, One caveat is that this will only sample
# either round down or round up of the expected length
simulate_acc_len = max(1.0, min(spec_steps + 1, simulate_acc_len))
lower = int(simulate_acc_len // 1)
upper = lower + 1 if lower < spec_steps + 1 else lower
if lower == upper:
simulate_acc_len = lower
else:
weight_upper = simulate_acc_len - lower
weight_lower = 1.0 - weight_upper
probs = torch.tensor([weight_lower, weight_upper], device="cpu")
sampled_index = torch.multinomial(probs, num_samples=1)
simulate_acc_len = lower if sampled_index == 0 else upper
else:
raise ValueError(f"Invalid simulate_acc_method: {SIMULATE_ACC_METHOD}")
accept_index_first_col = accept_index[:, 0].view(-1, 1)
sim_accept_index = torch.full(
(bs, spec_steps + 1), -1, dtype=torch.int32, device="cuda"
)
sim_accept_index[:, :simulate_acc_len] = accept_index_first_col + torch.arange(
simulate_acc_len, device=accept_index.device
)
accept_length.fill_(simulate_acc_len - 1)
predict.fill_(100) # some legit token id
return sim_accept_index
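To make the "match-expected" branch concrete, the following sketch (with an assumed target length of 2.3) shows why sampling only the floor or the ceiling still matches the expected length: the floor is drawn with weight 0.7 and the ceiling with weight 0.3, so the expectation is 2 * 0.7 + 3 * 0.3 = 2.3.

import torch

simulate_acc_len, spec_steps = 2.3, 4          # assumed values for illustration
lower = int(simulate_acc_len // 1)             # 2
upper = lower + 1 if lower < spec_steps + 1 else lower
weight_upper = simulate_acc_len - lower        # ~0.3
probs = torch.tensor([1.0 - weight_upper, weight_upper])
expected = lower * probs[0] + upper * probs[1]
print(float(expected))                         # ~2.3 (up to float rounding)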
def traverse_tree(
retrieve_next_token: torch.Tensor,
retrieve_next_sibling: torch.Tensor,
draft_tokens: torch.Tensor,
grammar: BaseGrammarObject,
allocate_token_bitmask: torch.Tensor,
):
"""
Traverse the tree constructed by the draft model to generate the logits mask.
"""
assert (
retrieve_next_token.shape == retrieve_next_sibling.shape == draft_tokens.shape
)
allocate_token_bitmask.fill_(0)
def dfs(
curr: int,
retrieve_next_token: torch.Tensor,
retrieve_next_sibling: torch.Tensor,
parent_pos: int,
):
if curr == 0:
# The root is the token generated and accepted by the target model in the
# previous iteration, so it is always treated as accepted.
accepted = True
else:
parent_bitmask = allocate_token_bitmask[parent_pos]
curr_token_id = draft_tokens[curr]
# 32 boolean bitmask values are packed into 32-bit integers
accepted = (
parent_bitmask[curr_token_id // 32] & (1 << (curr_token_id % 32))
) != 0
if accepted:
if curr != 0:
# Accept the current token
grammar.accept_token(draft_tokens[curr])
if not grammar.is_terminated():
# Generate the bitmask for the current token
grammar.fill_vocab_mask(allocate_token_bitmask, curr)
if retrieve_next_token[curr] != -1:
# Visit the child node
dfs(
retrieve_next_token[curr],
retrieve_next_token,
retrieve_next_sibling,
curr,
)
if curr != 0:
# Rollback the current token
grammar.rollback(1)
if retrieve_next_sibling[curr] != -1:
# Visit the sibling node
dfs(
retrieve_next_sibling[curr],
retrieve_next_token,
retrieve_next_sibling,
parent_pos,
)
dfs(0, retrieve_next_token, retrieve_next_sibling, -1)
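traverse_tree receives the draft tree in a first-child / next-sibling encoding: retrieve_next_token[i] is the first child of position i and retrieve_next_sibling[i] is its next sibling, with -1 meaning "none". A minimal sketch with assumed toy tensors shows the visiting order of the DFS above:

import torch

# Tree over 4 draft positions: node 0 has children {1, 3}, node 1 has child {2}.
retrieve_next_token = torch.tensor([1, 2, -1, -1])     # first child of each node
retrieve_next_sibling = torch.tensor([-1, 3, -1, -1])  # next sibling of each node

def walk(curr: int, depth: int = 0):
    # Visit a node, then its children (first-child chain), then its siblings,
    # matching the order in which dfs() above accepts and rolls back tokens.
    while curr != -1:
        print("  " * depth + f"node {curr}")
        walk(int(retrieve_next_token[curr]), depth + 1)
        curr = int(retrieve_next_sibling[curr])

walk(0)  # prints node 0, node 1, node 2, node 3 (node 3 at depth 1)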
def generate_token_bitmask(
reqs: List[Req],
verify_input: EagleVerifyInput,
retrieve_next_token_cpu: torch.Tensor,
retrieve_next_sibling_cpu: torch.Tensor,
draft_tokens_cpu: torch.Tensor,
vocab_size: int,
):
"""
Generate the logit mask for structured output.
Draft model's token can be either valid or invalid with respect to the grammar.
We need to perform DFS to
1. figure out which tokens are accepted by the grammar.
2. if so, what is the corresponding logit mask.
"""
num_draft_tokens = draft_tokens_cpu.shape[-1]
allocate_token_bitmask = None
assert len(reqs) == retrieve_next_token_cpu.shape[0]
grammar = None
for i, req in enumerate(reqs):
if req.grammar is not None:
if allocate_token_bitmask is None:
allocate_token_bitmask = req.grammar.allocate_vocab_mask(
vocab_size=vocab_size,
batch_size=draft_tokens_cpu.numel(),
device="cpu",
)
grammar = req.grammar
s = time.perf_counter()
traverse_tree(
retrieve_next_token_cpu[i],
retrieve_next_sibling_cpu[i],
draft_tokens_cpu[i],
req.grammar,
allocate_token_bitmask[
i * num_draft_tokens : (i + 1) * num_draft_tokens
],
)
tree_traverse_time = time.perf_counter() - s
if tree_traverse_time > TREE_TRAVERSE_TIME_THRESHOLD:
logger.warning(
f"Bit mask generation took {tree_traverse_time} seconds with "
f"grammar: {req.grammar}"
)
verify_input.grammar = grammar
return allocate_token_bitmask
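The mask returned here is bit-packed, as the check in dfs() above relies on: bit t of a word marks token id t as allowed, with 32 token ids packed into each 32-bit word. A small self-contained sketch (toy vocabulary size, hypothetical allow/is_allowed helpers) of that layout:

import torch

vocab_size = 70
mask = torch.zeros((vocab_size + 31) // 32, dtype=torch.int32)  # 3 words cover 70 ids

def allow(token_id: int) -> None:
    # Set the bit for token_id in its 32-bit word.
    mask[token_id // 32] |= 1 << (token_id % 32)

def is_allowed(token_id: int) -> bool:
    # Same test as the parent_bitmask check inside dfs().
    return bool(mask[token_id // 32] & (1 << (token_id % 32)))

allow(5)
allow(64)
print(is_allowed(5), is_allowed(6), is_allowed(64))  # True False True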
...@@ -30,7 +30,8 @@ from sglang.srt.model_executor.forward_batch_info import ( ...@@ -30,7 +30,8 @@ from sglang.srt.model_executor.forward_batch_info import (
) )
from sglang.srt.operations import execute_operations, execute_overlapped_operations from sglang.srt.operations import execute_operations, execute_overlapped_operations
from sglang.srt.operations_strategy import OperationsStrategy from sglang.srt.operations_strategy import OperationsStrategy
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpecInput
from sglang.srt.utils import BumpAllocator, empty_context, get_bool_env_var, is_hip from sglang.srt.utils import BumpAllocator, empty_context, get_bool_env_var, is_hip
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -48,7 +49,7 @@ logger = logging.getLogger(__name__) ...@@ -48,7 +49,7 @@ logger = logging.getLogger(__name__)
def get_token_num_per_seq( def get_token_num_per_seq(
forward_mode: ForwardMode, forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None, spec_info: Optional[SpecInput] = None,
): ):
if forward_mode.is_target_verify(): if forward_mode.is_target_verify():
return spec_info.draft_token_num return spec_info.draft_token_num
...@@ -273,7 +274,7 @@ def compute_split_token_index( ...@@ -273,7 +274,7 @@ def compute_split_token_index(
def compute_split_indices_for_cuda_graph_replay( def compute_split_indices_for_cuda_graph_replay(
forward_mode: ForwardMode, forward_mode: ForwardMode,
cuda_graph_num_tokens: int, cuda_graph_num_tokens: int,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], spec_info: Optional[SpecInput],
): ):
forward_mode_for_tbo_split = ( forward_mode_for_tbo_split = (
forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE
...@@ -333,7 +334,7 @@ class TboCudaGraphRunnerPlugin: ...@@ -333,7 +334,7 @@ class TboCudaGraphRunnerPlugin:
forward_mode: ForwardMode, forward_mode: ForwardMode,
bs: int, bs: int,
num_token_non_padded: int, num_token_non_padded: int,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], spec_info: Optional[SpecInput],
): ):
token_num_per_seq = get_token_num_per_seq( token_num_per_seq = get_token_num_per_seq(
forward_mode=forward_mode, spec_info=spec_info forward_mode=forward_mode, spec_info=spec_info
......
...@@ -7,7 +7,6 @@ or ...@@ -7,7 +7,6 @@ or
python3 test_forward_split_prefill.py python3 test_forward_split_prefill.py
""" """
import time
import unittest import unittest
import numpy as np import numpy as np
...@@ -16,7 +15,7 @@ import torch ...@@ -16,7 +15,7 @@ import torch
from sglang.srt.configs.model_config import ModelConfig from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
......