Commit ac238727 authored by Lianmin Zheng


Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
parent 0194948f
from __future__ import annotations
import functools
from typing import TYPE_CHECKING, Union
import torch
import triton
import triton.language as tl
from sglang.srt.distributed import (
GroupCoordinator,
get_tensor_model_parallel_world_size,
get_tp_group,
tensor_model_parallel_all_reduce,
)
if TYPE_CHECKING:
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
_ATTN_TP_GROUP = None
_ATTN_TP_RANK = None
@@ -69,3 +84,129 @@ def get_attention_dp_rank():
def get_attention_dp_size():
assert _DP_SIZE is not None, "dp attention not initialized!"
return _DP_SIZE
def get_dp_local_info(forward_batch: ForwardBatch):
dp_rank = get_attention_dp_rank()
if forward_batch.dp_local_start_pos is None:
cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)
if dp_rank == 0:
local_start_pos = torch.zeros_like(cumtokens[0])
else:
local_start_pos = cumtokens[dp_rank - 1]
local_num_tokens = forward_batch.global_num_tokens_gpu[dp_rank]
forward_batch.dp_local_start_pos = local_start_pos
forward_batch.dp_local_num_tokens = local_num_tokens
return forward_batch.dp_local_start_pos, forward_batch.dp_local_num_tokens
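For reference, the cached values are just the exclusive prefix sum of the per-rank token counts; a minimal sketch of the arithmetic with made-up counts:

# Sketch only: assumed per-rank token counts [3, 5, 2] for DP ranks 0..2.
import torch
global_num_tokens = torch.tensor([3, 5, 2])
cumtokens = torch.cumsum(global_num_tokens, dim=0)  # tensor([3, 8, 10])
for dp_rank in range(3):
    start = 0 if dp_rank == 0 else int(cumtokens[dp_rank - 1])
    num = int(global_num_tokens[dp_rank])
    print(dp_rank, start, num)  # (0, 0, 3), (1, 3, 5), (2, 8, 2)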
@triton.jit
def memcpy_triton_kernel(
dst_ptr,
src_ptr,
offset_ptr,
sz_ptr,
offset_src,
chunk_size, # scale (elements per row) applied to offset and sz
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0).to(tl.int64)
offset = tl.load(offset_ptr).to(tl.int64) * chunk_size
sz = tl.load(sz_ptr).to(tl.int64) * chunk_size
start_index = pid * BLOCK_SIZE
offs = tl.arange(0, BLOCK_SIZE)
mask = start_index + offs < sz
if offset_src:
data = tl.load(src_ptr + offset + start_index + offs, mask=mask)
tl.store(dst_ptr + start_index + offs, data, mask=mask)
else:
data = tl.load(src_ptr + start_index + offs, mask=mask)
tl.store(dst_ptr + offset + start_index + offs, data, mask=mask)
def prod(x):
return functools.reduce(lambda a, b: a * b, x, 1)
def memcpy_triton(dst, src, dim, offset, sz, offset_src):
max_size = min(src.numel(), dst.numel())
assert dim == 0, "dim != 0 unsupported"
assert src.shape[1:] == dst.shape[1:], "src and dst must have the same trailing shape"
chunk_size = prod(src.shape[1:])
BLOCK_SIZE = 8192
grid = (triton.cdiv(max_size, BLOCK_SIZE),)
memcpy_triton_kernel[grid](dst, src, offset, sz, offset_src, chunk_size, BLOCK_SIZE)
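A hedged usage sketch of memcpy_triton (assumed shapes and values): the row offset and row count stay on the GPU as 0-dim tensors, so the copy avoids a device-to-host sync.

import torch
src = torch.randn(4, 8, device="cuda")
dst = torch.zeros(16, 8, device="cuda")
start = torch.tensor(5, device="cuda")     # destination row offset, kept on GPU
num_rows = torch.tensor(4, device="cuda")  # number of rows to copy, kept on GPU
memcpy_triton(dst, src, dim=0, offset=start, sz=num_rows, offset_src=False)
# Equivalent to dst[5:9] = src[0:4], but without reading start/num_rows back to the CPU.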
def dp_gather(
global_tokens: torch.Tensor,
local_tokens: torch.Tensor,
forward_batch: ForwardBatch,
layer_id: Union[str, int],
):
local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
global_tokens.fill_(0)
assert local_tokens.is_contiguous()
assert global_tokens.is_contiguous()
if local_tokens.shape[0] > 0 and (
layer_id != "embedding" or get_attention_tp_rank() == 0
):
assert (
global_tokens.storage().data_ptr() != local_tokens.storage().data_ptr()
), "aliasing between global_tokens and local_tokens not allowed"
memcpy_triton(
global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
)
# Input IDs are in int32. We should use inplace_all_reduce for the local case because of custom all-reduce.
NUM_GPUS_PER_NODE = 8
if (
not local_tokens.dtype.is_floating_point
and get_tensor_model_parallel_world_size() <= NUM_GPUS_PER_NODE
):
torch.ops.sglang.inplace_all_reduce(
global_tokens, group_name=get_tp_group().unique_name
)
else:
global_tokens = tensor_model_parallel_all_reduce(global_tokens)
def dp_scatter(
local_tokens: torch.Tensor, # output
global_tokens: torch.Tensor, # input
forward_batch: ForwardBatch,
):
# local_num_tokens is not necessarily the same as local_tokens.shape[0],
# since local_tokens may be padded for cuda graph
local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
local_tokens.fill_(0)
assert local_tokens.is_contiguous()
assert global_tokens.is_contiguous()
if local_tokens.shape[0] > 0:
assert (
local_tokens.untyped_storage().data_ptr()
!= global_tokens.untyped_storage().data_ptr()
), "aliasing between local_tokens and global_tokens not allowed"
memcpy_triton(
local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
)
def get_do_logits_dp_scatter(forward_batch: ForwardBatch):
def do_logits_dp_scatter(logits: torch.Tensor):
local_logits = torch.empty(
(forward_batch.input_ids.shape[0], *logits.shape[1:]),
dtype=logits.dtype,
device=logits.device,
)
dp_scatter(local_logits, logits, forward_batch)
return local_logits
return do_logits_dp_scatter
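A single-process emulation of what the dp_gather / dp_scatter pair achieves (made-up sizes; the real code performs the per-rank copy with memcpy_triton and combines ranks with an all-reduce over zero-initialized buffers):

import torch

global_num_tokens = [2, 3]                       # tokens owned by DP rank 0 and rank 1
local_chunks = [torch.randn(n, 4) for n in global_num_tokens]

# dp_gather: every rank writes its rows into a zeroed global buffer; summing the
# per-rank buffers (the all-reduce) yields the fully populated global tensor.
global_buf = torch.zeros(sum(global_num_tokens), 4)
start = 0
for chunk in local_chunks:
    global_buf[start : start + chunk.shape[0]] = chunk
    start += chunk.shape[0]

# dp_scatter: each rank slices its own rows back out of the global buffer.
start = 0
for chunk in local_chunks:
    assert torch.equal(global_buf[start : start + chunk.shape[0]], chunk)
    start += chunk.shape[0]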
@@ -69,7 +69,7 @@ class RMSNorm(CustomOp):
variance = x.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = (x * self.weight).to(orig_dtype)
if residual is None:
return x
else:
...
@@ -426,13 +426,14 @@ class ColumnParallelLinear(LinearBase):
from sglang.srt.layers.parameter import _ColumnvLLMParameter
if isinstance(param, _ColumnvLLMParameter):
param.load_column_parallel_weight(
loaded_weight,
tp_rank=self.tp_rank,
use_presharded_weights=self.use_presharded_weights,
)
else:
# FIXME: This branch is needed to load deepseek v3 awq.
# However, we should fix this and avoid the branching here.
param.load_column_parallel_weight(loaded_weight)
def forward(self, input_):
...
@@ -26,12 +26,19 @@ from sglang.srt.distributed import (
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
)
from sglang.srt.layers.dp_attention import (
dp_gather,
dp_scatter,
get_attention_dp_rank,
get_attention_dp_size,
)
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
ForwardBatch,
ForwardMode,
)
from sglang.srt.utils import dump_to_file
logger = logging.getLogger(__name__)
@@ -51,6 +58,9 @@ class LogitsProcessorOutput:
# The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
next_token_top_logprobs_val: Optional[List] = None
next_token_top_logprobs_idx: Optional[List] = None
# The logprobs and ids of the requested token ids in output positions. shape: [#seq, n] (n is the number of requested token ids)
next_token_token_ids_logprobs_val: Optional[List] = None
next_token_token_ids_logprobs_idx: Optional[List] = None
## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
# The logprobs of input tokens. shape: [#token]
@@ -58,6 +68,9 @@ class LogitsProcessorOutput:
# The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k]
input_top_logprobs_val: List = None
input_top_logprobs_idx: List = None
# The logprobs and ids of the requested token ids in input positions. shape: [#seq, n] (n is the number of requested token ids)
input_token_ids_logprobs_val: Optional[List] = None
input_token_ids_logprobs_idx: Optional[List] = None
@dataclasses.dataclass
@@ -67,43 +80,107 @@ class LogitsMetadata:
extend_return_logprob: bool = False
extend_return_top_logprob: bool = False
extend_token_ids_logprob: bool = False
extend_seq_lens: Optional[torch.Tensor] = None
extend_seq_lens_cpu: Optional[List[int]] = None
extend_logprob_start_lens_cpu: Optional[List[int]] = None
extend_logprob_pruned_lens_cpu: Optional[List[int]] = None
top_logprobs_nums: Optional[List[int]] = None
extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None
token_ids_logprobs: Optional[List[List[int]]] = None
# logits and logprobs post processing
temp_scaled_logprobs: bool = False
temperature: torch.Tensor = None
top_p_normalized_logprobs: bool = False
top_p: torch.Tensor = None
# DP attention metadata. Not needed when DP attention is not used.
# Number of tokens in the request.
global_num_tokens_gpu: Optional[torch.Tensor] = None
# The start position of local hidden states.
dp_local_start_pos: Optional[torch.Tensor] = None
dp_local_num_tokens: Optional[torch.Tensor] = None
gathered_buffer: Optional[torch.Tensor] = None
# Buffer to gather logits from all ranks.
forward_batch_gathered_buffer: Optional[torch.Tensor] = None
# Number of tokens to sample per DP rank
global_num_tokens_for_logprob_cpu: Optional[torch.Tensor] = None
global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
# for padding
padded_static_len: int = -1
@classmethod
def from_forward_batch(cls, forward_batch: ForwardBatch):
if (
forward_batch.forward_mode.is_extend()
and forward_batch.return_logprob
and not forward_batch.forward_mode.is_target_verify()
):
extend_return_top_logprob = any(
x > 0 for x in forward_batch.top_logprobs_nums
)
extend_token_ids_logprob = any(
x is not None for x in forward_batch.token_ids_logprobs
)
extend_return_logprob = False
extend_logprob_pruned_lens_cpu = []
for extend_len, start_len in zip(
forward_batch.extend_seq_lens_cpu,
forward_batch.extend_logprob_start_lens_cpu,
):
if extend_len - start_len > 0:
extend_return_logprob = True
extend_logprob_pruned_lens_cpu.append(extend_len - start_len)
else:
extend_return_logprob = extend_return_top_logprob = (
extend_token_ids_logprob
) = extend_logprob_pruned_lens_cpu = False
return cls(
forward_mode=forward_batch.forward_mode,
capture_hidden_mode=forward_batch.capture_hidden_mode,
extend_return_logprob=extend_return_logprob,
extend_return_top_logprob=extend_return_top_logprob,
extend_token_ids_logprob=extend_token_ids_logprob,
extend_seq_lens=forward_batch.extend_seq_lens,
extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu,
extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
top_logprobs_nums=forward_batch.top_logprobs_nums,
token_ids_logprobs=forward_batch.token_ids_logprobs,
extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu,
padded_static_len=forward_batch.padded_static_len,
)
def compute_dp_attention_metadata(self, hidden_states: torch.Tensor):
if self.global_num_tokens_for_logprob_cpu is None:
# we are capturing cuda graph
return
cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
dp_rank = get_attention_dp_rank()
if dp_rank == 0:
dp_local_start_pos = torch.zeros_like(
self.global_num_tokens_for_logprob_gpu[0]
)
else:
dp_local_start_pos = cumtokens[dp_rank - 1]
dp_local_num_tokens = self.global_num_tokens_for_logprob_gpu[dp_rank]
gathered_buffer = torch.zeros(
(
sum(self.global_num_tokens_for_logprob_cpu),
hidden_states.shape[1],
),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
self.dp_local_start_pos = dp_local_start_pos
self.dp_local_num_tokens = dp_local_num_tokens
self.gathered_buffer = gathered_buffer
class LogitsProcessor(nn.Module):
def __init__(
@@ -115,6 +192,9 @@ class LogitsProcessor(nn.Module):
self.do_tensor_parallel_all_gather = (
not skip_all_gather and get_tensor_model_parallel_world_size() > 1
)
self.do_tensor_parallel_all_gather_dp_attn = (
self.do_tensor_parallel_all_gather and get_attention_dp_size() != 1
)
self.final_logit_softcapping = getattr(
self.config, "final_logit_softcapping", None
)
@@ -124,6 +204,12 @@ class LogitsProcessor(nn.Module):
):
self.final_logit_softcapping = None
from sglang.srt.managers.schedule_batch import global_server_args_dict
self.debug_tensor_dump_output_folder = global_server_args_dict[
"debug_tensor_dump_output_folder"
]
def forward(
self,
input_ids,
@@ -141,30 +227,74 @@ class LogitsProcessor(nn.Module):
):
pruned_states = hidden_states
sample_indices = None
input_logprob_indices = None
elif (
logits_metadata.forward_mode.is_extend()
and not logits_metadata.extend_return_logprob
):
# Prefill without input logprobs.
if logits_metadata.padded_static_len < 0:
last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
else:
# If padded_static_len is 5 and extend_seq_lens is [2, 3],
# then our batch looks like [t00, t01, p, p, p, t10, t11, t12, p, p]
# and this retrieves t01 and t12, which are the valid last tokens
idx = torch.arange(
len(logits_metadata.extend_seq_lens),
device=logits_metadata.extend_seq_lens.device,
)
last_index = (
idx * logits_metadata.padded_static_len
+ logits_metadata.extend_seq_lens
- 1
)
pruned_states = hidden_states[last_index]
sample_indices = None
input_logprob_indices = None
else:
# Input logprobs are required.
# Find 3 different indices.
# 1. pruned_states: hidden states that we want logprobs from.
# 2. sample_indices: Indices that have sampled tokens.
# 3. input_logprob_indices: Indices that have input logprob tokens.
sample_index_pt = -1
sample_indices = []
input_logprob_indices_pt = 0
input_logprob_indices = []
pt, pruned_states = 0, []
for extend_logprob_start_len, extend_len in zip(
logits_metadata.extend_logprob_start_lens_cpu,
logits_metadata.extend_seq_lens_cpu,
):
# This can happen in chunked prefill: we still need to sample 1 token,
# but we don't want to include it in the input logprobs.
if extend_len == extend_logprob_start_len:
start_len = extend_logprob_start_len - 1
else:
start_len = extend_logprob_start_len
# We always need at least 1 token to sample because that's required
# by the caller.
assert extend_len > start_len
pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
pt += extend_len
sample_index_pt += extend_len - start_len
sample_indices.append(sample_index_pt)
input_logprob_indices.extend(
[
input_logprob_indices_pt + i
for i in range(extend_len - extend_logprob_start_len)
]
)
input_logprob_indices_pt += extend_len - start_len
pruned_states = torch.cat(pruned_states)
sample_indices = torch.tensor(
sample_indices, device=pruned_states.device, dtype=torch.int64
)
input_logprob_indices = torch.tensor(
input_logprob_indices, device=pruned_states.device, dtype=torch.int64
)
# Compute logits for both input and sampled tokens.
logits = self._get_logits(pruned_states, lm_head, logits_metadata)
@@ -172,28 +302,51 @@ class LogitsProcessor(nn.Module):
logits[sample_indices] if sample_indices is not None else logits
)
if self.debug_tensor_dump_output_folder:
assert (
not self.do_tensor_parallel_all_gather or get_attention_dp_size() == 1
), "dp attention + sharded lm_head doesn't support full logits"
full_logits = self._get_logits(hidden_states, lm_head, logits_metadata)
dump_to_file(self.debug_tensor_dump_output_folder, "logits", full_logits)
hidden_states_to_store: Optional[torch.Tensor] = None
if logits_metadata.capture_hidden_mode.need_capture():
if logits_metadata.capture_hidden_mode.is_full():
hidden_states_to_store = hidden_states
elif logits_metadata.capture_hidden_mode.is_last():
# Get the last token hidden states. If sample_indices is None,
# pruned states only contain the last tokens already.
hidden_states_to_store = (
pruned_states[sample_indices] if sample_indices is not None else pruned_states
)
else:
assert False, "Should never reach"
if not logits_metadata.extend_return_logprob:
# Decode mode or extend mode without return_logprob.
return LogitsProcessorOutput(
next_token_logits=sampled_logits,
hidden_states=hidden_states_to_store,
)
else:
input_logprobs = logits[input_logprob_indices]
del hidden_states, logits
# Normalize the logprob w/o temperature, top-p
pruned_lens = torch.tensor(
logits_metadata.extend_logprob_pruned_lens_cpu,
device=input_logprobs.device,
)
if logits_metadata.temp_scaled_logprobs:
logits_metadata.temperature = torch.repeat_interleave(
logits_metadata.temperature.view(-1),
pruned_lens,
).view(-1, 1)
if logits_metadata.top_p_normalized_logprobs:
logits_metadata.top_p = torch.repeat_interleave(
logits_metadata.top_p,
pruned_lens,
)
input_logprobs = self.compute_temp_top_p_normalized_logprobs(
input_logprobs, logits_metadata
)
@@ -207,14 +360,18 @@ class LogitsProcessor(nn.Module):
else:
input_top_logprobs_val = input_top_logprobs_idx = None
# Get the logprob of given token id
if logits_metadata.extend_token_ids_logprob:
(
input_token_ids_logprobs_val,
input_token_ids_logprobs_idx,
) = self.get_token_ids_logprobs(input_logprobs, logits_metadata)
else:
input_token_ids_logprobs_val = input_token_ids_logprobs_idx = None
input_token_logprobs = input_logprobs[
torch.arange(input_logprobs.shape[0], device=input_logprobs.device),
logits_metadata.extend_input_logprob_token_ids_gpu,
]
return LogitsProcessorOutput(
@@ -222,6 +379,9 @@ class LogitsProcessor(nn.Module):
input_token_logprobs=input_token_logprobs,
input_top_logprobs_val=input_top_logprobs_val,
input_top_logprobs_idx=input_top_logprobs_idx,
hidden_states=hidden_states_to_store,
input_token_ids_logprobs_val=input_token_ids_logprobs_val,
input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
)
def _get_logits(
@@ -231,10 +391,24 @@ class LogitsProcessor(nn.Module):
logits_metadata: LogitsMetadata,
embedding_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Get logits from hidden_states.
If sampled_logits_only is True, it means hidden_states only contain the
last position (e.g., extend without input logprobs). The caller should
guarantee the given hidden_states follow this constraint.
"""
if self.do_tensor_parallel_all_gather_dp_attn:
logits_metadata.compute_dp_attention_metadata(hidden_states)
hidden_states, local_hidden_states = (
logits_metadata.gathered_buffer,
hidden_states.clone(),
)
dp_gather(hidden_states, local_hidden_states, logits_metadata, "embedding")
if hasattr(lm_head, "weight"):
logits = torch.matmul(
hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
)
else:
# GGUF models
logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)
@@ -245,6 +419,17 @@ class LogitsProcessor(nn.Module):
if self.do_tensor_parallel_all_gather:
logits = tensor_model_parallel_all_gather(logits)
if self.do_tensor_parallel_all_gather_dp_attn:
logits, global_logits = (
torch.empty(
(local_hidden_states.shape[0], logits.shape[1]),
device=logits.device,
dtype=logits.dtype,
),
logits,
)
dp_scatter(logits, global_logits, logits_metadata)
logits = logits[:, : self.config.vocab_size].float()
if self.final_logit_softcapping:
@@ -272,21 +457,66 @@ class LogitsProcessor(nn.Module):
continue
input_top_logprobs_val.append(
[values[pt + j][:k] for j in range(pruned_len)]
)
input_top_logprobs_idx.append(
[indices[pt + j][:k] for j in range(pruned_len)]
)
pt += pruned_len
return input_top_logprobs_val, input_top_logprobs_idx
@staticmethod
def get_token_ids_logprobs(
all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata
):
input_token_ids_logprobs_val, input_token_ids_logprobs_idx = [], []
pt = 0
for token_ids, pruned_len in zip(
logits_metadata.token_ids_logprobs,
logits_metadata.extend_logprob_pruned_lens_cpu,
):
if pruned_len <= 0:
input_token_ids_logprobs_val.append([])
input_token_ids_logprobs_idx.append([])
continue
input_token_ids_logprobs_val.append(
[all_logprobs[pt + j, token_ids].tolist() for j in range(pruned_len)]
)
input_token_ids_logprobs_idx.append([token_ids for _ in range(pruned_len)])
pt += pruned_len
return input_token_ids_logprobs_val, input_token_ids_logprobs_idx
@staticmethod
def compute_temp_top_p_normalized_logprobs(
last_logits: torch.Tensor, logits_metadata: LogitsMetadata
) -> torch.Tensor:
"""
compute logprobs for the output token from the given logits.
Returns:
torch.Tensor: logprobs from logits
"""
# Scale logits if temperature scaling is enabled
if logits_metadata.temp_scaled_logprobs:
last_logits = last_logits / logits_metadata.temperature
# Normalize logprobs if top_p normalization is enabled
# NOTE: only normalize logprobs when top_p is set and not equal to 1.0
if (
logits_metadata.top_p_normalized_logprobs
and (logits_metadata.top_p != 1.0).any()
):
from sglang.srt.layers.sampler import top_p_normalize_probs_torch
probs = torch.softmax(last_logits, dim=-1)
del last_logits
probs = top_p_normalize_probs_torch(probs, logits_metadata.top_p)
return torch.log(probs)
else:
return torch.nn.functional.log_softmax(last_logits, dim=-1)
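A minimal numeric sketch of the temperature-scaled, top-p-renormalized logprob path above (made-up logits; the top-p renormalization is inlined here instead of calling top_p_normalize_probs_torch):

import torch

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
temperature = torch.tensor([[0.7]])
top_p = torch.tensor([0.9])

probs = torch.softmax(logits / temperature, dim=-1)

# Keep the smallest set of highest-probability tokens whose mass reaches top_p,
# zero out the rest, and renormalize (the nucleus-sampling mask).
sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
cum = sorted_probs.cumsum(dim=-1)
keep = (cum - sorted_probs) < top_p.unsqueeze(-1)
sorted_probs = sorted_probs * keep
probs = torch.zeros_like(probs).scatter_(-1, sorted_idx, sorted_probs)
probs = probs / probs.sum(dim=-1, keepdim=True)

logprobs = torch.log(probs)  # -inf for tokens outside the nucleus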
@triton.jit
...
@@ -144,6 +144,73 @@ def silu_and_mul_triton_kernel(
tl.store(down_input_ptr + offset, silu_mul_output, mask=mask)
@triton.jit
def tanh(x):
return 2 * tl.sigmoid(2 * x) - 1
@triton.jit
def gelu_and_mul_triton_kernel(
gateup_output,
down_input,
hidden_size,
reorder_topk_ids,
scales,
start_expert_id,
end_expert_id,
BLOCK_SIZE: tl.constexpr,
):
InDtype = gateup_output.dtype.element_ty
OutDtype = down_input.dtype.element_ty
half_hidden_size = hidden_size // 2
pid = tl.program_id(0)
expert_id = tl.load(reorder_topk_ids + pid)
if expert_id >= start_expert_id and expert_id <= end_expert_id:
gateup_output_ptr = gateup_output + pid * hidden_size
gate_output_ptr = gateup_output_ptr
up_output_ptr = gateup_output_ptr + half_hidden_size
down_input_ptr = down_input + pid * half_hidden_size
if scales is not None:
scale = tl.load(scales + expert_id - start_expert_id)
scale = (1 / scale).to(InDtype)
else:
scale = 1
for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE):
offset = start_offset + tl.arange(0, BLOCK_SIZE)
mask = offset < half_hidden_size
gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32)
up_output = tl.load(up_output_ptr + offset, mask=mask)
# gelu & mul & quantize
# https://pytorch.org/docs/stable/generated/torch.nn.GELU.html
# sqrt(2/pi)
kAlpha = 0.7978845608028654
gate_output = (
0.5
* gate_output
* (
1
+ tanh(
kAlpha
* (
gate_output
+ 0.044715 * gate_output * gate_output * gate_output
)
)
)
)
gate_output = gate_output.to(InDtype)
gelu_mul_output = gate_output * up_output * scale
gelu_mul_output = gelu_mul_output.to(OutDtype)
tl.store(down_input_ptr + offset, gelu_mul_output, mask=mask)
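For sanity-checking the kernel above, the tanh expression matches PyTorch's approximate GELU, so a reference for the gelu-and-mul step (scale/quantization omitted; shapes are assumed) is:

import torch

def gelu_and_mul_ref(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3))) * up
    return torch.nn.functional.gelu(gate.float(), approximate="tanh").to(gate.dtype) * up

gate = torch.randn(4, 16, dtype=torch.float16)
up = torch.randn(4, 16, dtype=torch.float16)
out = gelu_and_mul_ref(gate, up)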
@triton.jit
def post_reorder_triton_kernel(
down_output_ptr,
...
@@ -11,6 +11,7 @@ from sglang.srt.distributed import (
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.moe.ep_moe.kernels import (
gelu_and_mul_triton_kernel,
grouped_gemm_triton,
post_reorder_triton_kernel,
pre_reorder_triton_kernel,
@@ -296,6 +297,17 @@ class EPMoE(torch.nn.Module):
self.end_expert_id,
BLOCK_SIZE=512,
)
elif self.activation == "gelu":
gelu_and_mul_triton_kernel[(gateup_output.shape[0],)](
gateup_output,
down_input,
gateup_output.shape[1],
reorder_topk_ids,
self.w2_input_scale,
self.start_expert_id,
self.end_expert_id,
BLOCK_SIZE=512,
)
else:
raise ValueError(f"Unsupported activation: {self.activation=}")
...
@@ -24,6 +24,8 @@ def fused_moe_forward_native(
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
inplace: bool = True,
no_combine: bool = False,
) -> torch.Tensor:
topk_weights, topk_ids = select_experts(
hidden_states=x,
...
@@ -23,7 +23,7 @@ from sglang.srt.utils import (
is_hip,
)
is_hip_ = is_hip()
logger = logging.getLogger(__name__)
@@ -487,6 +487,7 @@ def invoke_fused_moe_kernel(
use_int8_w8a8: bool,
use_int8_w8a16: bool,
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
) -> None:
assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
@@ -646,7 +647,7 @@ def get_default_config(
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 2 if is_hip_ else 4,
}
if M <= E:
config = {
@@ -655,7 +656,7 @@ def get_default_config(
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2 if is_hip_ else 4,
}
else:
# Block-wise quant: BLOCK_SIZE_K must be divisible by block_shape[1]
@@ -665,7 +666,7 @@ def get_default_config(
"BLOCK_SIZE_K": block_shape[1],
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2 if is_hip_ else 3,
}
else:
config = {
@@ -814,6 +815,7 @@ def outplace_fused_experts(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
) -> torch.Tensor:
return fused_experts_impl(
hidden_states,
@@ -831,6 +833,7 @@ def outplace_fused_experts(
a1_scale,
a2_scale,
block_shape,
no_combine=no_combine,
)
@@ -849,6 +852,7 @@ def outplace_fused_experts_fake(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
@@ -877,8 +881,10 @@ def fused_experts(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
):
if inplace:
assert not no_combine, "no combine + inplace makes no sense"
torch.ops.sglang.inplace_fused_experts(
hidden_states,
w1,
@@ -912,6 +918,7 @@ def fused_experts(
a1_scale,
a2_scale,
block_shape,
no_combine=no_combine,
)
@@ -931,6 +938,7 @@ def fused_experts_impl(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
):
padded_size = padding_size
if not use_fp8_w8a8 or not use_int8_w8a8 or block_shape is not None:
@@ -987,7 +995,14 @@ def fused_experts_impl(
compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
if no_combine:
assert not inplace
out_hidden_states = torch.empty(
(num_tokens, topk_ids.shape[1], w2.shape[1]),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
elif inplace:
out_hidden_states = hidden_states
else:
out_hidden_states = torch.empty_like(hidden_states)
@@ -1057,7 +1072,11 @@ def fused_experts_impl(
invoke_fused_moe_kernel(
intermediate_cache2,
w2,
(
intermediate_cache3
if not no_combine and topk_ids.shape[1] != 1
else out_hidden_states[begin_chunk_idx:end_chunk_idx]
),
a2_scale,
w2_scale,
curr_topk_weights,
@@ -1075,16 +1094,16 @@ def fused_experts_impl(
block_shape=block_shape,
)
if no_combine:
pass
elif is_hip_:
ops.moe_sum(
intermediate_cache3.view(*intermediate_cache3.shape),
out_hidden_states[begin_chunk_idx:end_chunk_idx],
)
else:
if topk_ids.shape[1] == 1:
pass # we write directly into out_hidden_states
elif topk_ids.shape[1] == 2:
torch.add(
intermediate_cache3[:, 0],
@@ -1122,6 +1141,7 @@ def fused_moe(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -1191,4 +1211,5 @@ def fused_moe(
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
no_combine=no_combine,
)
@@ -125,6 +125,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
inplace: bool = True,
no_combine: bool = False,
) -> torch.Tensor:
return self.forward(
x=x,
@@ -138,6 +140,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
activation=activation,
inplace=inplace,
no_combine=no_combine,
)
def forward_cuda(
@@ -153,6 +157,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
inplace: bool = True,
no_combine: bool = False,
) -> torch.Tensor:
topk_weights, topk_ids = select_experts(
hidden_states=x,
@@ -171,6 +177,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
from aiter.fused_moe import fused_experts_ck
assert activation == "silu", f"{activation=} is not supported."
assert not no_combine, "unsupported"
return fused_experts_ck(
hidden_states=x,
@@ -186,8 +193,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=inplace and not no_combine,
activation=activation,
no_combine=no_combine,
)
def forward_cpu(
@@ -202,6 +210,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
inplace: bool = True,
) -> torch.Tensor:
return moe_forward_native(
layer,
@@ -241,6 +250,7 @@ class FusedMoE(torch.nn.Module):
reduce_results: Whether to all_reduce on the output of the layer
renormalize: Whether to renormalize the logits in the fused_moe kernel
quant_config: Quantization config.
inplace: Suggestion to compute in place (modify the input activation).
"""
def __init__(
@@ -262,6 +272,8 @@ class FusedMoE(torch.nn.Module):
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
use_presharded_weights: bool = False,
inplace: bool = True,
no_combine: bool = False,
):
super().__init__()
@@ -285,6 +297,9 @@ class FusedMoE(torch.nn.Module):
self.custom_routing_function = custom_routing_function
self.correction_bias = correction_bias
self.activation = activation
self.use_presharded_weights = use_presharded_weights
self.inplace = inplace
self.no_combine = no_combine
if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = (
@@ -304,7 +319,6 @@ class FusedMoE(torch.nn.Module):
params_dtype=params_dtype,
weight_loader=self.weight_loader,
)
def _load_per_tensor_weight_scale(
self,
@@ -598,6 +612,8 @@ class FusedMoE(torch.nn.Module):
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
activation=self.activation,
inplace=self.inplace,
no_combine=self.no_combine,
)
if self.reduce_results and self.tp_size > 1:
...
@@ -771,6 +771,8 @@ class Fp8MoEMethod:
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
inplace: bool = True,
no_combine: bool = False,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.topk import select_experts
@@ -793,6 +795,7 @@ class Fp8MoEMethod:
from aiter.fused_moe import fused_experts_ck
assert activation == "silu", f"{activation=} is not supported."
assert not no_combine, f"{no_combine=} is not supported."
return fused_experts_ck(
x,
@@ -823,7 +826,7 @@ class Fp8MoEMethod:
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=inplace and not no_combine,
activation=activation,
use_fp8_w8a8=True,
w1_scale=(
@@ -839,6 +842,7 @@ class Fp8MoEMethod:
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
no_combine=no_combine,
)
...
@@ -707,7 +707,6 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
cos = freqs.cos() * self.mscale
sin = freqs.sin() * self.mscale
cache = torch.cat((cos, sin), dim=-1)
return cache
def forward(
...
import logging
from typing import List, Optional
import torch
import torch.distributed as dist
@@ -41,7 +41,21 @@ class Sampler(nn.Module):
sampling_info: SamplingBatchInfo,
return_logprob: bool,
top_logprobs_nums: List[int],
token_ids_logprobs: List[List[int]],
batch_next_token_ids: Optional[torch.Tensor] = None,
):
"""Run a sampler & compute logprobs and update logits_output accordingly.
Args:
logits_output: The logits from the model forward
sampling_info: Metadata for sampling
return_logprob: If set, store the output logprob information to
logits_output
top_logprobs_nums: Number of top logprobs per sequence in a batch
batch_next_token_ids: next token IDs. If set, skip sampling and only
compute output logprobs. It is used for speculative decoding, which
performs sampling in the draft workers.
"""
logits = logits_output.next_token_logits
# Apply the custom logit processors if registered in the sampling info.
@@ -58,13 +72,15 @@ class Sampler(nn.Module):
if sampling_info.is_all_greedy:
# Use torch.argmax if all requests use greedy sampling
if batch_next_token_ids is None:
batch_next_token_ids = torch.argmax(logits, -1)
if return_logprob:
logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
else:
# Post process logits
logits.div_(sampling_info.temperatures)
logits[:] = torch.softmax(logits, dim=-1)
probs = logits
del logits
if global_server_args_dict["sampling_backend"] == "flashinfer":
@@ -78,38 +94,43 @@ class Sampler(nn.Module):
top_p_normalize_probs_torch(probs, sampling_info.top_ps)
).clamp(min=torch.finfo(probs.dtype).min)
if batch_next_token_ids is None:
max_top_k_round, batch_size = 32, probs.shape[0]
uniform_samples = torch.rand(
(max_top_k_round, batch_size), device=probs.device
)
if sampling_info.need_min_p_sampling:
probs = top_k_renorm_prob(probs, sampling_info.top_ks)
probs = top_p_renorm_prob(probs, sampling_info.top_ps)
batch_next_token_ids = min_p_sampling_from_probs(
probs, uniform_samples, sampling_info.min_ps
)
else:
batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
probs,
uniform_samples,
sampling_info.top_ks,
sampling_info.top_ps,
filter_apply_order="joint",
)
if self.use_nan_detection and not torch.all(success):
logger.warning("Detected errors during sampling!")
batch_next_token_ids = torch.zeros_like(
batch_next_token_ids
)
elif global_server_args_dict["sampling_backend"] == "pytorch":
if batch_next_token_ids is None:
# A slower fallback implementation with torch native operations.
batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
probs,
sampling_info.top_ks,
sampling_info.top_ps,
sampling_info.min_ps,
sampling_info.need_min_p_sampling,
)
if return_logprob:
# clamp to avoid -inf
logprobs = torch.log(
@@ -128,6 +149,12 @@ class Sampler(nn.Module):
logits_output.next_token_top_logprobs_idx,
) = get_top_logprobs(logprobs, top_logprobs_nums)
if any(x is not None for x in token_ids_logprobs):
(
logits_output.next_token_token_ids_logprobs_val,
logits_output.next_token_token_ids_logprobs_idx,
) = get_token_ids_logprobs(logprobs, token_ids_logprobs)
logits_output.next_token_logprobs = logprobs[
torch.arange(len(batch_next_token_ids), device=sampling_info.device),
batch_next_token_ids,
@@ -223,6 +250,10 @@ def top_p_normalize_probs_torch(
def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
assert len(top_logprobs_nums) == logprobs.shape[0], (
len(top_logprobs_nums),
logprobs.shape[0],
)
max_k = max(top_logprobs_nums)
ret = logprobs.topk(max_k, dim=1)
values = ret.values.tolist()
@@ -234,3 +265,17 @@ def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
output_top_logprobs_val.append(values[i][:k])
output_top_logprobs_idx.append(indices[i][:k])
return output_top_logprobs_val, output_top_logprobs_idx
def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List[int]]):
output_token_ids_logprobs_val = []
output_token_ids_logprobs_idx = []
for i, token_ids in enumerate(token_ids_logprobs):
if token_ids is not None:
output_token_ids_logprobs_val.append(logprobs[i, token_ids].tolist())
output_token_ids_logprobs_idx.append(token_ids)
else:
output_token_ids_logprobs_val.append([])
output_token_ids_logprobs_idx.append([])
return output_token_ids_logprobs_val, output_token_ids_logprobs_idx
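A small usage sketch of get_token_ids_logprobs (made-up values): each request may ask for the logprobs of specific token ids, or pass None to skip.

import torch

logprobs = torch.log_softmax(torch.randn(2, 8), dim=-1)  # batch of 2, vocab of 8
token_ids_logprobs = [[3, 5], None]                       # req0 wants ids 3 and 5
vals, idxs = get_token_ids_logprobs(logprobs, token_ids_logprobs)
# vals -> [[logprob of id 3, logprob of id 5], []],  idxs -> [[3, 5], []]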
@@ -457,7 +457,7 @@ class VocabParallelEmbedding(torch.nn.Module):
assert loaded_weight.shape[output_dim] == (
self.org_vocab_size
// (self.tp_size if self.use_presharded_weights else 1)
), f"{self.org_vocab_size=} {self.use_presharded_weights=} {loaded_weight.shape[output_dim]=}"
# Copy the data.
if not self.use_presharded_weights:
...
@@ -28,6 +28,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--url", type=str, default="http://localhost:30000")
parser.add_argument("--log-requests", action="store_true")
parser.add_argument("--log-requests-level", type=int, default=2)
parser.add_argument(
"--dump-requests-folder", type=str, default="/tmp/sglang_request_dump"
)
@@ -38,7 +39,7 @@ if __name__ == "__main__":
args.url + "/configure_logging",
json={
"log_requests": args.log_requests,
"log_requests_level": args.log_requests_level,
"dump_requests_folder": args.dump_requests_folder,
"dump_requests_threshold": args.dump_requests_threshold,
},
...
@@ -198,6 +198,8 @@ class DataParallelController:
self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
self.max_req_input_len = scheduler_info[0]["max_req_input_len"]
print(f"{scheduler_info=}")
def round_robin_scheduler(self, req):
self.workers[self.round_robin_counter].send_pyobj(req)
self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
@@ -220,6 +222,7 @@ class DataParallelController:
TokenizedEmbeddingReqInput,
),
):
logger.info("dispatching")
self.dispatching(recv_req)
else:
# Send other control messages to first worker of tp group
...
@@ -14,6 +14,7 @@
"""DetokenizerManager is a process that detokenizes the token ids."""
import dataclasses
import json
import logging
import os
import signal
@@ -27,11 +28,16 @@ import zmq
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.io_struct import (
BatchEmbeddingOut,
BatchMultimodalDecodeReq,
BatchStrOut,
BatchTokenIDOut,
)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
configure_logger,
get_zmq_socket,
kill_itself_when_parent_died,
)
from sglang.utils import (
TypeBasedDispatcher,
find_printable_text,
@@ -86,14 +92,23 @@ class DetokenizerManager:
)
self.decode_status = LimitedCapacityDict(capacity=DETOKENIZER_MAX_STATES)
self.is_dummy = server_args.load_format == "dummy"
self._request_dispatcher = TypeBasedDispatcher( self._request_dispatcher = TypeBasedDispatcher(
[ [
(BatchEmbeddingOut, self.handle_batch_embedding_out), (BatchEmbeddingOut, self.handle_batch_embedding_out),
(BatchTokenIDOut, self.handle_batch_token_id_out), (BatchTokenIDOut, self.handle_batch_token_id_out),
(BatchMultimodalDecodeReq, self.handle_multimodal_decode_req),
] ]
) )
def event_loop(self):
"""The event loop that handles requests"""
while True:
recv_obj = self.recv_from_scheduler.recv_pyobj()
output = self._request_dispatcher(recv_obj)
self.send_to_tokenizer.send_pyobj(output)
def trim_matched_stop( def trim_matched_stop(
self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
): ):
...@@ -117,14 +132,6 @@ class DetokenizerManager: ...@@ -117,14 +132,6 @@ class DetokenizerManager:
return output[:-1] return output[:-1]
return output return output
def event_loop(self):
"""The event loop that handles requests"""
while True:
recv_obj = self.recv_from_scheduler.recv_pyobj()
output = self._request_dispatcher(recv_obj)
self.send_to_tokenizer.send_pyobj(output)
def handle_batch_embedding_out(self, recv_obj: BatchEmbeddingOut): def handle_batch_embedding_out(self, recv_obj: BatchEmbeddingOut):
# If it is embedding model, no detokenization is needed. # If it is embedding model, no detokenization is needed.
return recv_obj return recv_obj
...@@ -173,7 +180,6 @@ class DetokenizerManager: ...@@ -173,7 +180,6 @@ class DetokenizerManager:
# Incremental decoding # Incremental decoding
output_strs = [] output_strs = []
finished_reqs = []
for i in range(bs): for i in range(bs):
try: try:
s = self.decode_status[recv_obj.rids[i]] s = self.decode_status[recv_obj.rids[i]]
...@@ -196,8 +202,6 @@ class DetokenizerManager: ...@@ -196,8 +202,6 @@ class DetokenizerManager:
new_text = "" new_text = ""
else: else:
new_text = find_printable_text(new_text) new_text = find_printable_text(new_text)
else:
finished_reqs.append(recv_obj.rids[i])
output_strs.append( output_strs.append(
self.trim_matched_stop( self.trim_matched_stop(
...@@ -207,7 +211,7 @@ class DetokenizerManager: ...@@ -207,7 +211,7 @@ class DetokenizerManager:
) )
) )
out = BatchStrOut( return BatchStrOut(
rids=recv_obj.rids, rids=recv_obj.rids,
finished_reasons=recv_obj.finished_reasons, finished_reasons=recv_obj.finished_reasons,
output_strs=output_strs, output_strs=output_strs,
...@@ -223,14 +227,15 @@ class DetokenizerManager: ...@@ -223,14 +227,15 @@ class DetokenizerManager:
input_top_logprobs_idx=recv_obj.input_top_logprobs_idx, input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
output_top_logprobs_val=recv_obj.output_top_logprobs_val, output_top_logprobs_val=recv_obj.output_top_logprobs_val,
output_top_logprobs_idx=recv_obj.output_top_logprobs_idx, output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
input_token_ids_logprobs_val=recv_obj.input_token_ids_logprobs_val,
input_token_ids_logprobs_idx=recv_obj.input_token_ids_logprobs_idx,
output_token_ids_logprobs_val=recv_obj.output_token_ids_logprobs_val,
output_token_ids_logprobs_idx=recv_obj.output_token_ids_logprobs_idx,
output_hidden_states=recv_obj.output_hidden_states, output_hidden_states=recv_obj.output_hidden_states,
) )
# remove decodestatus for completed requests def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):
for rid in finished_reqs: raise NotImplementedError()
self.decode_status.pop(rid)
return out
class LimitedCapacityDict(OrderedDict): class LimitedCapacityDict(OrderedDict):
...@@ -250,6 +255,7 @@ def run_detokenizer_process( ...@@ -250,6 +255,7 @@ def run_detokenizer_process(
server_args: ServerArgs, server_args: ServerArgs,
port_args: PortArgs, port_args: PortArgs,
): ):
kill_itself_when_parent_died()
setproctitle.setproctitle("sglang::detokenizer") setproctitle.setproctitle("sglang::detokenizer")
configure_logger(server_args) configure_logger(server_args)
parent_process = psutil.Process().parent() parent_process = psutil.Process().parent()
......
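The relocated `event_loop` above simply receives an object, routes it by type through `TypeBasedDispatcher`, and forwards the handler's result to the tokenizer. A rough illustration of that dispatch-by-type pattern — not the actual `sglang.utils.TypeBasedDispatcher` implementation, and with stub classes in place of the real io structs:

```python
from typing import Any, Callable, List, Tuple


class SimpleTypeDispatcher:
    """Route an object to the first handler whose registered type matches."""

    def __init__(self, mapping: List[Tuple[type, Callable[[Any], Any]]]):
        self._mapping = mapping

    def __call__(self, obj: Any) -> Any:
        for ty, handler in self._mapping:
            if isinstance(obj, ty):
                return handler(obj)
        raise ValueError(f"No handler registered for {type(obj)}")


# Stubs standing in for BatchEmbeddingOut / BatchTokenIDOut.
class StubEmbeddingOut: ...
class StubTokenIDOut: ...


dispatcher = SimpleTypeDispatcher(
    [
        (StubEmbeddingOut, lambda o: "embedding passthrough"),
        (StubTokenIDOut, lambda o: "detokenized batch"),
    ]
)
print(dispatcher(StubTokenIDOut()))  # detokenized batch
```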
...@@ -16,10 +16,11 @@ The definition of objects transfered between different ...@@ -16,10 +16,11 @@ The definition of objects transfered between different
processes (TokenizerManager, DetokenizerManager, Controller). processes (TokenizerManager, DetokenizerManager, Controller).
""" """
import copy
import uuid import uuid
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from typing import Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from sglang.srt.managers.schedule_batch import BaseFinishReason from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.sampling.sampling_params import SamplingParams
...@@ -55,6 +56,8 @@ class GenerateReqInput: ...@@ -55,6 +56,8 @@ class GenerateReqInput:
logprob_start_len: Optional[Union[List[int], int]] = None logprob_start_len: Optional[Union[List[int], int]] = None
# If return logprobs, the number of top logprobs to return at each position. # If return logprobs, the number of top logprobs to return at each position.
top_logprobs_num: Optional[Union[List[int], int]] = None top_logprobs_num: Optional[Union[List[int], int]] = None
# If return logprobs, the token ids to return logprob for.
token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None
# Whether to detokenize tokens in text in the returned logprobs. # Whether to detokenize tokens in text in the returned logprobs.
return_text_in_logprobs: bool = False return_text_in_logprobs: bool = False
# Whether to stream output. # Whether to stream output.
...@@ -146,6 +149,8 @@ class GenerateReqInput: ...@@ -146,6 +149,8 @@ class GenerateReqInput:
self.logprob_start_len = -1 self.logprob_start_len = -1
if self.top_logprobs_num is None: if self.top_logprobs_num is None:
self.top_logprobs_num = 0 self.top_logprobs_num = 0
if not self.token_ids_logprob: # covers both None and []
self.token_ids_logprob = None
else: else:
if self.parallel_sample_num == 1: if self.parallel_sample_num == 1:
num = self.batch_size num = self.batch_size
...@@ -191,6 +196,17 @@ class GenerateReqInput: ...@@ -191,6 +196,17 @@ class GenerateReqInput:
else: else:
assert self.parallel_sample_num == 1 assert self.parallel_sample_num == 1
if not self.token_ids_logprob: # covers both None and []
self.token_ids_logprob = [None] * num
elif not isinstance(self.token_ids_logprob, list):
self.token_ids_logprob = [[self.token_ids_logprob] for _ in range(num)]
elif not isinstance(self.token_ids_logprob[0], list):
self.token_ids_logprob = [
copy.deepcopy(self.token_ids_logprob) for _ in range(num)
]
else:
assert self.parallel_sample_num == 1
if self.custom_logit_processor is None: if self.custom_logit_processor is None:
self.custom_logit_processor = [None] * num self.custom_logit_processor = [None] * num
elif not isinstance(self.custom_logit_processor, list): elif not isinstance(self.custom_logit_processor, list):
...@@ -198,6 +214,12 @@ class GenerateReqInput: ...@@ -198,6 +214,12 @@ class GenerateReqInput:
else: else:
assert self.parallel_sample_num == 1 assert self.parallel_sample_num == 1
# Other checks
if self.session_params is not None:
assert isinstance(self.session_params, dict) or isinstance(
self.session_params[0], dict
)
def regenerate_rid(self): def regenerate_rid(self):
self.rid = uuid.uuid4().hex self.rid = uuid.uuid4().hex
return self.rid return self.rid
...@@ -212,6 +234,7 @@ class GenerateReqInput: ...@@ -212,6 +234,7 @@ class GenerateReqInput:
return_logprob=self.return_logprob[i], return_logprob=self.return_logprob[i],
logprob_start_len=self.logprob_start_len[i], logprob_start_len=self.logprob_start_len[i],
top_logprobs_num=self.top_logprobs_num[i], top_logprobs_num=self.top_logprobs_num[i],
token_ids_logprob=self.token_ids_logprob[i],
return_text_in_logprobs=self.return_text_in_logprobs, return_text_in_logprobs=self.return_text_in_logprobs,
stream=self.stream, stream=self.stream,
log_metrics=self.log_metrics, log_metrics=self.log_metrics,
...@@ -244,6 +267,8 @@ class TokenizedGenerateReqInput: ...@@ -244,6 +267,8 @@ class TokenizedGenerateReqInput:
logprob_start_len: int logprob_start_len: int
# If return logprobs, the number of top logprobs to return at each position. # If return logprobs, the number of top logprobs to return at each position.
top_logprobs_num: int top_logprobs_num: int
# If return logprobs, the token ids to return logprobs for.
token_ids_logprob: List[int]
# Whether to stream output # Whether to stream output
stream: bool stream: bool
...@@ -378,10 +403,21 @@ class BatchTokenIDOut: ...@@ -378,10 +403,21 @@ class BatchTokenIDOut:
input_top_logprobs_idx: List[List] input_top_logprobs_idx: List[List]
output_top_logprobs_val: List[List] output_top_logprobs_val: List[List]
output_top_logprobs_idx: List[List] output_top_logprobs_idx: List[List]
input_token_ids_logprobs_val: List[List]
input_token_ids_logprobs_idx: List[List]
output_token_ids_logprobs_val: List[List]
output_token_ids_logprobs_idx: List[List]
# Hidden states
output_hidden_states: List[List[float]] output_hidden_states: List[List[float]]
@dataclass
class BatchMultimodalDecodeReq:
# The request id
rids: List[str]
@dataclass @dataclass
class BatchStrOut: class BatchStrOut:
# The request id # The request id
...@@ -406,10 +442,21 @@ class BatchStrOut: ...@@ -406,10 +442,21 @@ class BatchStrOut:
input_top_logprobs_idx: List[List] input_top_logprobs_idx: List[List]
output_top_logprobs_val: List[List] output_top_logprobs_val: List[List]
output_top_logprobs_idx: List[List] output_top_logprobs_idx: List[List]
input_token_ids_logprobs_val: List[List]
input_token_ids_logprobs_idx: List[List]
output_token_ids_logprobs_val: List[List]
output_token_ids_logprobs_idx: List[List]
# Hidden states
output_hidden_states: List[List[float]] output_hidden_states: List[List[float]]
@dataclass
class BatchMultimodalOut:
# The request id
rids: List[str]
@dataclass @dataclass
class BatchEmbeddingOut: class BatchEmbeddingOut:
# The request id # The request id
...@@ -439,6 +486,8 @@ class UpdateWeightFromDiskReqInput: ...@@ -439,6 +486,8 @@ class UpdateWeightFromDiskReqInput:
class UpdateWeightFromDiskReqOutput: class UpdateWeightFromDiskReqOutput:
success: bool success: bool
message: str message: str
# Number of paused requests during weight sync.
num_paused_requests: Optional[int] = 0
@dataclass @dataclass
...@@ -526,11 +575,57 @@ class AbortReq: ...@@ -526,11 +575,57 @@ class AbortReq:
rid: str rid: str
class ProfileReq(Enum): @dataclass
class GetInternalStateReq:
pass
@dataclass
class GetInternalStateReqOutput:
internal_state: Dict[Any, Any]
@dataclass
class SetInternalStateReq:
server_args: Dict[str, Any]
@dataclass
class SetInternalStateReqOutput:
updated: bool
server_args: Dict[str, Any]
@dataclass
class ProfileReqInput:
# The output directory
output_dir: Optional[str] = None
# If set, profiling runs for at most this many steps and stops automatically,
# so the caller does not need to call stop_profile.
num_steps: Optional[int] = None
activities: Optional[List[str]] = None
class ProfileReqType(Enum):
START_PROFILE = 1 START_PROFILE = 1
STOP_PROFILE = 2 STOP_PROFILE = 2
@dataclass
class ProfileReq:
type: ProfileReqType
output_dir: Optional[str] = None
num_steps: Optional[int] = None
activities: Optional[List[str]] = None
@dataclass
class ProfileReqOutput:
success: bool
message: str
@dataclass @dataclass
class ConfigureLoggingReq: class ConfigureLoggingReq:
log_requests: Optional[bool] = None log_requests: Optional[bool] = None
...@@ -556,6 +651,11 @@ class OpenSessionReqOutput: ...@@ -556,6 +651,11 @@ class OpenSessionReqOutput:
success: bool success: bool
@dataclass
class HealthCheckOutput:
pass
@dataclass @dataclass
class Function: class Function:
description: Optional[str] = None description: Optional[str] = None
......
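The new `token_ids_logprob` field is normalized the same way as the other per-request arguments in the hunk above: `None`/`[]` becomes a list of `None`s, a single token id or a flat list is broadcast to every request, and an already-nested list is kept per request. A standalone sketch of that normalization, where `num` is the effective batch size:

```python
import copy
from typing import List, Optional, Union


def normalize_token_ids_logprob(
    token_ids_logprob: Optional[Union[int, List[int], List[List[int]]]], num: int
) -> List[Optional[List[int]]]:
    """Broadcast token-id logprob targets to `num` requests, mirroring the diff above."""
    if not token_ids_logprob:  # covers both None and []
        return [None] * num
    if not isinstance(token_ids_logprob, list):
        # A single token id: wrap it and broadcast to every request.
        return [[token_ids_logprob] for _ in range(num)]
    if not isinstance(token_ids_logprob[0], list):
        # One flat list: every request asks for logprobs of the same token ids.
        return [copy.deepcopy(token_ids_logprob) for _ in range(num)]
    # Already one list per request.
    return token_ids_logprob


print(normalize_token_ids_logprob(None, 3))            # [None, None, None]
print(normalize_token_ids_logprob(42, 2))              # [[42], [42]]
print(normalize_token_ids_logprob([5, 7], 3))          # [[5, 7], [5, 7], [5, 7]]
print(normalize_token_ids_logprob([[5], [7], [9]], 3)) # unchanged
```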
...@@ -29,6 +29,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch ...@@ -29,6 +29,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
It contains low-level tensor data. Most of the data consists of GPU tensors. It contains low-level tensor data. Most of the data consists of GPU tensors.
""" """
import copy
import dataclasses import dataclasses
import logging import logging
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
...@@ -50,7 +51,10 @@ from sglang.srt.sampling.sampling_params import SamplingParams ...@@ -50,7 +51,10 @@ from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
...@@ -65,6 +69,8 @@ global_server_args_dict = { ...@@ -65,6 +69,8 @@ global_server_args_dict = {
"enable_dp_attention": ServerArgs.enable_dp_attention, "enable_dp_attention": ServerArgs.enable_dp_attention,
"enable_ep_moe": ServerArgs.enable_ep_moe, "enable_ep_moe": ServerArgs.enable_ep_moe,
"device": ServerArgs.device, "device": ServerArgs.device,
"speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
"speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
"enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla, "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
"disable_radix_cache": ServerArgs.disable_radix_cache, "disable_radix_cache": ServerArgs.disable_radix_cache,
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged, "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
...@@ -230,6 +236,7 @@ class Req: ...@@ -230,6 +236,7 @@ class Req:
sampling_params: SamplingParams, sampling_params: SamplingParams,
return_logprob: bool = False, return_logprob: bool = False,
top_logprobs_num: int = 0, top_logprobs_num: int = 0,
token_ids_logprob: List[int] = None,
stream: bool = False, stream: bool = False,
origin_input_ids_unpadded: Optional[Tuple[int]] = None, origin_input_ids_unpadded: Optional[Tuple[int]] = None,
lora_path: Optional[str] = None, lora_path: Optional[str] = None,
...@@ -256,17 +263,24 @@ class Req: ...@@ -256,17 +263,24 @@ class Req:
self.input_embeds = input_embeds self.input_embeds = input_embeds
# Sampling info # Sampling info
if isinstance(sampling_params.custom_params, dict):
sampling_params = copy.copy(sampling_params)
sampling_params.custom_params = sampling_params.custom_params | {
"__req__": self
}
self.sampling_params = sampling_params self.sampling_params = sampling_params
self.custom_logit_processor = custom_logit_processor self.custom_logit_processor = custom_logit_processor
self.return_hidden_states = return_hidden_states self.return_hidden_states = return_hidden_states
# Memory pool info # Memory pool info
self.req_pool_idx = None self.req_pool_idx: Optional[int] = None
# Check finish # Check finish
self.tokenizer = None self.tokenizer = None
self.finished_reason = None self.finished_reason = None
# If we want to abort the request in the middle of the event loop, set this to true
# Note: we should never set finished_reason in the middle; the req would get filtered and never respond
self.to_abort = False self.to_abort = False
self.stream = stream self.stream = stream
self.eos_token_ids = eos_token_ids self.eos_token_ids = eos_token_ids
...@@ -289,38 +303,56 @@ class Req: ...@@ -289,38 +303,56 @@ class Req:
self.image_inputs: Optional[ImageInputs] = None self.image_inputs: Optional[ImageInputs] = None
# Prefix info # Prefix info
# The indices to kv cache for the shared prefix.
self.prefix_indices = [] self.prefix_indices = []
# Tokens to run prefill. input_tokens - shared_prefix_tokens. # Number of tokens to run prefill.
# Updated if chunked.
self.extend_input_len = 0 self.extend_input_len = 0
# The relative logprob_start_len in an extend batch
self.extend_logprob_start_len = 0
self.last_node = None self.last_node = None
# Chunked prefill # Whether the request is chunked. The counter is incremented whenever
self.is_being_chunked = 0 # the request is chunked, and decremented whenever a chunked request is
# processed.
self.is_chunked = 0
# For retraction # For retraction
self.is_retracted = False self.is_retracted = False
# Logprobs (arguments) # Logprobs (arguments)
self.return_logprob = return_logprob self.return_logprob = return_logprob
# Start index to compute logprob from.
self.logprob_start_len = 0 self.logprob_start_len = 0
self.top_logprobs_num = top_logprobs_num self.top_logprobs_num = top_logprobs_num
self.token_ids_logprob = token_ids_logprob
# Logprobs (return values) # Logprobs (return values)
self.input_token_logprobs_val: Optional[List[float]] = None self.input_token_logprobs_val: Optional[List[float]] = None
self.input_token_logprobs_idx: Optional[List[int]] = None self.input_token_logprobs_idx: Optional[List[int]] = None
self.input_top_logprobs_val: Optional[List[float]] = None self.input_top_logprobs_val: Optional[List[float]] = None
self.input_top_logprobs_idx: Optional[List[int]] = None self.input_top_logprobs_idx: Optional[List[int]] = None
self.input_token_ids_logprobs_val: Optional[List[float]] = None
self.input_token_ids_logprobs_idx: Optional[List[int]] = None
# Temporary holder to store input_token_logprobs.
self.input_token_logprobs: Optional[List[Tuple[int]]] = None
self.temp_input_top_logprobs_val: Optional[List[torch.Tensor]] = None
self.temp_input_top_logprobs_idx: Optional[List[int]] = None
self.temp_input_token_ids_logprobs_val: Optional[List[float]] = None
self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None
if return_logprob: if return_logprob:
self.output_token_logprobs_val = [] self.output_token_logprobs_val = []
self.output_token_logprobs_idx = [] self.output_token_logprobs_idx = []
self.output_top_logprobs_val = [] self.output_top_logprobs_val = []
self.output_top_logprobs_idx = [] self.output_top_logprobs_idx = []
self.output_token_ids_logprobs_val = []
self.output_token_ids_logprobs_idx = []
else: else:
self.output_token_logprobs_val = self.output_token_logprobs_idx = ( self.output_token_logprobs_val = self.output_token_logprobs_idx = (
self.output_top_logprobs_val self.output_top_logprobs_val
) = self.output_top_logprobs_idx = None ) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = (
self.output_token_ids_logprobs_idx
) = None
self.hidden_states = [] self.hidden_states = []
# Logprobs (internal values) # Logprobs (internal values)
...@@ -345,6 +377,13 @@ class Req: ...@@ -345,6 +377,13 @@ class Req:
self.spec_verify_ct = 0 self.spec_verify_ct = 0
self.lora_path = lora_path self.lora_path = lora_path
# This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
self.to_abort_message: str = "Unknown error"
@property
def seqlen(self):
return len(self.origin_input_ids) + len(self.output_ids)
def extend_image_inputs(self, image_inputs): def extend_image_inputs(self, image_inputs):
if self.image_inputs is None: if self.image_inputs is None:
self.image_inputs = image_inputs self.image_inputs = image_inputs
...@@ -422,7 +461,9 @@ class Req: ...@@ -422,7 +461,9 @@ class Req:
return return
if self.to_abort: if self.to_abort:
self.finished_reason = FINISH_ABORT() self.finished_reason = FINISH_ABORT(
message=self.to_abort_message,
)
return return
if len(self.output_ids) >= self.sampling_params.max_new_tokens: if len(self.output_ids) >= self.sampling_params.max_new_tokens:
...@@ -517,6 +558,8 @@ class Req: ...@@ -517,6 +558,8 @@ class Req:
self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k] self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k]
self.output_top_logprobs_val = self.output_top_logprobs_val[:k] self.output_top_logprobs_val = self.output_top_logprobs_val[:k]
self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k] self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k]
self.output_token_ids_logprobs_val = self.output_token_ids_logprobs_val[:k]
self.output_token_ids_logprobs_idx = self.output_token_ids_logprobs_idx[:k]
self.logprob_start_len = prompt_tokens + k self.logprob_start_len = prompt_tokens + k
self.last_update_decode_tokens = len(self.output_ids) - k self.last_update_decode_tokens = len(self.output_ids) - k
...@@ -527,16 +570,19 @@ class Req: ...@@ -527,16 +570,19 @@ class Req:
self.last_node = None self.last_node = None
self.extend_input_len = 0 self.extend_input_len = 0
self.is_retracted = True self.is_retracted = True
self.input_token_logprobs = None
self.temp_input_top_logprobs_val = None
self.temp_input_top_logprobs_idx = None
self.extend_logprob_start_len = 0
self.is_chunked = 0
self.req_pool_idx = None
# For incremental logprobs
# TODO: Fix the `logprob_start_len`
self.last_update_decode_tokens = 0 self.last_update_decode_tokens = 0
self.logprob_start_len = 10**9
def __repr__(self): def __repr__(self):
return ( return (
f"rid(n={self.rid}, " f"Req(rid={self.rid}, "
f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}" f"input_ids={self.origin_input_ids}, output_ids={self.output_ids})"
) )
...@@ -576,11 +622,13 @@ class ScheduleBatch: ...@@ -576,11 +622,13 @@ class ScheduleBatch:
# For DP attention # For DP attention
global_num_tokens: Optional[List[int]] = None global_num_tokens: Optional[List[int]] = None
global_num_tokens_for_logprob: Optional[List[int]] = None
can_run_dp_cuda_graph: bool = False can_run_dp_cuda_graph: bool = False
# For processing logprobs # For processing logprobs
return_logprob: bool = False return_logprob: bool = False
top_logprobs_nums: Optional[List[int]] = None top_logprobs_nums: Optional[List[int]] = None
token_ids_logprobs: Optional[List[List[int]]] = None
# For extend and mixed chunked prefill # For extend and mixed chunked prefill
prefix_lens: List[int] = None prefix_lens: List[int] = None
...@@ -588,6 +636,8 @@ class ScheduleBatch: ...@@ -588,6 +636,8 @@ class ScheduleBatch:
extend_num_tokens: int = None extend_num_tokens: int = None
decoding_reqs: List[Req] = None decoding_reqs: List[Req] = None
extend_logprob_start_lens: List[int] = None extend_logprob_start_lens: List[int] = None
# This is an empty list if logprob is not required.
extend_input_logprob_token_ids: Optional[torch.Tensor] = None
# For encoder-decoder # For encoder-decoder
encoder_cached: Optional[List[bool]] = None encoder_cached: Optional[List[bool]] = None
...@@ -606,7 +656,7 @@ class ScheduleBatch: ...@@ -606,7 +656,7 @@ class ScheduleBatch:
# Speculative decoding # Speculative decoding
spec_algorithm: SpeculativeAlgorithm = None spec_algorithm: SpeculativeAlgorithm = None
spec_info: Optional[SpecInfo] = None spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None
# Enable custom logit processor # Enable custom logit processor
enable_custom_logit_processor: bool = False enable_custom_logit_processor: bool = False
...@@ -653,8 +703,10 @@ class ScheduleBatch: ...@@ -653,8 +703,10 @@ class ScheduleBatch:
req_pool_indices = self.req_to_token_pool.alloc(num_reqs) req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
if req_pool_indices is None: if req_pool_indices is None:
raise RuntimeError( raise RuntimeError(
"Out of memory. " "alloc_req_slots runs out of memory. "
"Please set a smaller number for `--max-running-requests`." "Please set a smaller number for `--max-running-requests`. "
f"{self.req_to_token_pool.available_size()=}, "
f"{num_reqs=}, "
) )
return req_pool_indices return req_pool_indices
...@@ -765,6 +817,7 @@ class ScheduleBatch: ...@@ -765,6 +817,7 @@ class ScheduleBatch:
out_cache_loc = self.alloc_token_slots(extend_num_tokens) out_cache_loc = self.alloc_token_slots(extend_num_tokens)
input_embeds = [] input_embeds = []
extend_input_logprob_token_ids = []
pt = 0 pt = 0
for i, req in enumerate(reqs): for i, req in enumerate(reqs):
...@@ -783,22 +836,64 @@ class ScheduleBatch: ...@@ -783,22 +836,64 @@ class ScheduleBatch:
# If req.input_embeds is already a list, append its content directly # If req.input_embeds is already a list, append its content directly
input_embeds.extend(req.input_embeds) # Use extend to avoid nesting input_embeds.extend(req.input_embeds) # Use extend to avoid nesting
if req.return_logprob:
# Compute the relative logprob_start_len in an extend batch
if req.logprob_start_len >= pre_len:
extend_logprob_start_len = min(
req.logprob_start_len - pre_len, req.extend_input_len - 1
)
else:
raise RuntimeError(
f"This should never happen. {req.logprob_start_len=}, {pre_len=}"
)
req.extend_logprob_start_len = extend_logprob_start_len
req.cached_tokens += pre_len - req.already_computed req.cached_tokens += pre_len - req.already_computed
req.already_computed = seq_len req.already_computed = seq_len
req.is_retracted = False req.is_retracted = False
pre_lens.append(pre_len) pre_lens.append(pre_len)
# Compute the relative logprob_start_len in an extend batch
if req.logprob_start_len >= pre_len:
req.extend_logprob_start_len = min(
req.logprob_start_len - pre_len,
req.extend_input_len,
req.seqlen - 1,
)
else:
req.extend_logprob_start_len = 0
if self.return_logprob:
# Find input logprob token ids.
# First, find a global index within origin_input_ids and shift it by 1
# to get the input-logprob targets, because the next token is needed
# to compute the logprob of each input position. E.g., (chunk size 2)
#
# input_logprobs = [1, 2, 3, 4]
# fill_ids = [1, 2]
# extend_input_logprob_token_id = [2, 3]
#
# Note that it can also overflow. In this case, we pad it with 0.
# input_logprobs = [1, 2, 3, 4]
# fill_ids = [3, 4]
# extend_input_logprob_token_id = [4, 0]
global_start_idx, global_end_idx = (
len(req.prefix_indices),
len(req.fill_ids),
)
# Apply logprob_start_len
if global_start_idx < req.logprob_start_len:
global_start_idx = req.logprob_start_len
logprob_token_ids = req.origin_input_ids[
global_start_idx + 1 : global_end_idx + 1
]
extend_input_logprob_token_ids.extend(logprob_token_ids)
# We need req.extend_input_len - req.extend_logprob_start_len tokens in total;
# logprob_token_ids covers the input logprobs, so pad the rest with 0.
extend_input_logprob_token_ids.extend(
[0]
* (
req.extend_input_len
- req.extend_logprob_start_len
- len(logprob_token_ids)
)
)
if self.return_logprob:
extend_input_logprob_token_ids = torch.tensor(
extend_input_logprob_token_ids
)
else:
extend_input_logprob_token_ids = None
# Set fields # Set fields
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to( self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
...@@ -821,10 +916,12 @@ class ScheduleBatch: ...@@ -821,10 +916,12 @@ class ScheduleBatch:
self.seq_lens_sum = sum(seq_lens) self.seq_lens_sum = sum(seq_lens)
if self.return_logprob: if self.return_logprob:
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs] self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]
self.extend_num_tokens = extend_num_tokens self.extend_num_tokens = extend_num_tokens
self.prefix_lens = [len(r.prefix_indices) for r in reqs] self.prefix_lens = [len(r.prefix_indices) for r in reqs]
self.extend_lens = [r.extend_input_len for r in reqs] self.extend_lens = [r.extend_input_len for r in reqs]
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs] self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
# Write to req_to_token_pool # Write to req_to_token_pool
pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to( pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to(
...@@ -860,7 +957,6 @@ class ScheduleBatch: ...@@ -860,7 +957,6 @@ class ScheduleBatch:
self.sampling_info = SamplingBatchInfo.from_schedule_batch( self.sampling_info = SamplingBatchInfo.from_schedule_batch(
self, self,
self.model_config.vocab_size, self.model_config.vocab_size,
enable_overlap_schedule=self.enable_overlap,
) )
def mix_with_running(self, running_batch: "ScheduleBatch"): def mix_with_running(self, running_batch: "ScheduleBatch"):
...@@ -905,25 +1001,43 @@ class ScheduleBatch: ...@@ -905,25 +1001,43 @@ class ScheduleBatch:
return False return False
def retract_decode(self): def retract_decode(self, server_args: ServerArgs):
"""Retract the decoding requests when there is not enough memory.""" """Retract the decoding requests when there is not enough memory."""
sorted_indices = [i for i in range(len(self.reqs))] sorted_indices = [i for i in range(len(self.reqs))]
# TODO(lsyin): improve retraction policy for radix cache # TODO(lsyin): improve retraction policy for radix cache
sorted_indices.sort( # For spec decoding, filter_batch API can only filter
key=lambda i: ( # requests from the back, so we can only retract from the back.
len(self.reqs[i].output_ids), # TODO(sang): Clean up finish path and support better retract
-len(self.reqs[i].origin_input_ids), # policy.
), if not server_args.speculative_algorithm:
reverse=True, sorted_indices.sort(
) key=lambda i: (
len(self.reqs[i].output_ids),
-len(self.reqs[i].origin_input_ids),
),
reverse=True,
)
def get_required_tokens(num_reqs: int):
headroom_for_spec_decode = 0
if server_args.speculative_algorithm:
headroom_for_spec_decode += (
num_reqs
* server_args.speculative_eagle_topk
* server_args.speculative_num_steps
+ num_reqs * server_args.speculative_num_draft_tokens
)
return (
num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
)
retracted_reqs = [] retracted_reqs = []
seq_lens_cpu = self.seq_lens.cpu().numpy() seq_lens_cpu = self.seq_lens.cpu().numpy()
first_iter = True first_iter = True
while ( while (
self.token_to_kv_pool.available_size() self.token_to_kv_pool.available_size()
< len(sorted_indices) * global_config.retract_decode_steps < get_required_tokens(len(sorted_indices))
or first_iter or first_iter
): ):
if len(sorted_indices) == 1: if len(sorted_indices) == 1:
...@@ -1048,17 +1162,40 @@ class ScheduleBatch: ...@@ -1048,17 +1162,40 @@ class ScheduleBatch:
self.sampling_info = SamplingBatchInfo.from_schedule_batch( self.sampling_info = SamplingBatchInfo.from_schedule_batch(
self, self,
self.model_config.vocab_size, self.model_config.vocab_size,
enable_overlap_schedule=self.enable_overlap,
) )
def prepare_for_decode(self): def prepare_for_decode(self):
self.forward_mode = ForwardMode.DECODE self.forward_mode = ForwardMode.DECODE
if self.spec_algorithm.is_eagle(): if self.spec_algorithm.is_eagle():
# if spec decoding is used, the decode batch is prepared inside
# `forward_batch_speculative_generation` after running draft models.
return return
if self.sampling_info.penalizer_orchestrator.is_required:
if self.enable_overlap:
# TODO: this can be slow, optimize this.
delayed_output_ids = torch.tensor(
[
(
req.output_ids[-1]
if len(req.output_ids)
else req.origin_input_ids[-1]
)
for req in self.reqs
],
dtype=torch.int64,
device=self.device,
)
self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
delayed_output_ids
)
else:
self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
self.output_ids.to(torch.int64)
)
self.input_ids = self.output_ids self.input_ids = self.output_ids
self.output_ids = None self.output_ids = None
self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(self.input_ids)
# Alloc mem # Alloc mem
bs = len(self.reqs) bs = len(self.reqs)
...@@ -1086,14 +1223,15 @@ class ScheduleBatch: ...@@ -1086,14 +1223,15 @@ class ScheduleBatch:
def filter_batch( def filter_batch(
self, self,
being_chunked_req: Optional[Req] = None, chunked_req_to_exclude: Optional[Req] = None,
keep_indices: Optional[List[int]] = None, keep_indices: Optional[List[int]] = None,
): ):
if keep_indices is None: if keep_indices is None:
keep_indices = [ keep_indices = [
i i
for i in range(len(self.reqs)) for i in range(len(self.reqs))
if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req if not self.reqs[i].finished()
and self.reqs[i] is not chunked_req_to_exclude
] ]
if keep_indices is None or len(keep_indices) == 0: if keep_indices is None or len(keep_indices) == 0:
...@@ -1105,31 +1243,34 @@ class ScheduleBatch: ...@@ -1105,31 +1243,34 @@ class ScheduleBatch:
# No need to filter # No need to filter
return return
keep_indices_device = torch.tensor(keep_indices, dtype=torch.int64).to(
self.device, non_blocking=True
)
if self.model_config.is_encoder_decoder: if self.model_config.is_encoder_decoder:
self.encoder_lens = self.encoder_lens[keep_indices] self.encoder_lens = self.encoder_lens[keep_indices_device]
self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices] self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
self.reqs = [self.reqs[i] for i in keep_indices] self.reqs = [self.reqs[i] for i in keep_indices]
new_indices = torch.tensor(keep_indices, dtype=torch.int64).to( self.req_pool_indices = self.req_pool_indices[keep_indices_device]
self.device, non_blocking=True self.seq_lens = self.seq_lens[keep_indices_device]
)
self.req_pool_indices = self.req_pool_indices[new_indices]
self.seq_lens = self.seq_lens[new_indices]
self.out_cache_loc = None self.out_cache_loc = None
self.seq_lens_sum = self.seq_lens.sum().item() self.seq_lens_sum = self.seq_lens.sum().item()
self.output_ids = self.output_ids[new_indices] self.output_ids = self.output_ids[keep_indices_device]
self.return_logprob = any(req.return_logprob for req in self.reqs) self.return_logprob = any(req.return_logprob for req in self.reqs)
if self.return_logprob: if self.return_logprob:
self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices] self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
self.token_ids_logprobs = [self.token_ids_logprobs[i] for i in keep_indices]
else: else:
self.top_logprobs_nums = None self.top_logprobs_nums = None
self.token_ids_logprobs = None
self.has_stream = any(req.stream for req in self.reqs) self.has_stream = any(req.stream for req in self.reqs)
self.has_grammar = any(req.grammar for req in self.reqs) self.has_grammar = any(req.grammar for req in self.reqs)
self.sampling_info.filter_batch(keep_indices, new_indices) self.sampling_info.filter_batch(keep_indices, keep_indices_device)
if self.spec_info: if self.spec_info:
self.spec_info.filter_batch(new_indices) self.spec_info.filter_batch(keep_indices_device)
def merge_batch(self, other: "ScheduleBatch"): def merge_batch(self, other: "ScheduleBatch"):
# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
...@@ -1152,10 +1293,13 @@ class ScheduleBatch: ...@@ -1152,10 +1293,13 @@ class ScheduleBatch:
self.output_ids = torch.concat([self.output_ids, other.output_ids]) self.output_ids = torch.concat([self.output_ids, other.output_ids])
if self.return_logprob and other.return_logprob: if self.return_logprob and other.return_logprob:
self.top_logprobs_nums.extend(other.top_logprobs_nums) self.top_logprobs_nums.extend(other.top_logprobs_nums)
self.token_ids_logprobs.extend(other.token_ids_logprobs)
elif self.return_logprob: elif self.return_logprob:
self.top_logprobs_nums.extend([0] * len(other.reqs)) self.top_logprobs_nums.extend([0] * len(other.reqs))
self.token_ids_logprobs.extend([None] * len(other.reqs))
elif other.return_logprob: elif other.return_logprob:
self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
self.reqs.extend(other.reqs) self.reqs.extend(other.reqs)
self.return_logprob |= other.return_logprob self.return_logprob |= other.return_logprob
...@@ -1192,7 +1336,9 @@ class ScheduleBatch: ...@@ -1192,7 +1336,9 @@ class ScheduleBatch:
seq_lens_sum=self.seq_lens_sum, seq_lens_sum=self.seq_lens_sum,
return_logprob=self.return_logprob, return_logprob=self.return_logprob,
top_logprobs_nums=self.top_logprobs_nums, top_logprobs_nums=self.top_logprobs_nums,
token_ids_logprobs=self.token_ids_logprobs,
global_num_tokens=self.global_num_tokens, global_num_tokens=self.global_num_tokens,
global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph, can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
extend_num_tokens=self.extend_num_tokens, extend_num_tokens=self.extend_num_tokens,
extend_seq_lens=extend_seq_lens, extend_seq_lens=extend_seq_lens,
...@@ -1219,6 +1365,7 @@ class ScheduleBatch: ...@@ -1219,6 +1365,7 @@ class ScheduleBatch:
else CaptureHiddenMode.NULL else CaptureHiddenMode.NULL
) )
), ),
extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
) )
def copy(self): def copy(self):
...@@ -1262,9 +1409,11 @@ class ModelWorkerBatch: ...@@ -1262,9 +1409,11 @@ class ModelWorkerBatch:
# For logprob # For logprob
return_logprob: bool return_logprob: bool
top_logprobs_nums: Optional[List[int]] top_logprobs_nums: Optional[List[int]]
token_ids_logprobs: Optional[List[List[int]]]
# For DP attention # For DP attention
global_num_tokens: Optional[List[int]] global_num_tokens: Optional[List[int]]
global_num_tokens_for_logprob: Optional[List[int]]
can_run_dp_cuda_graph: bool can_run_dp_cuda_graph: bool
# For extend # For extend
...@@ -1272,6 +1421,7 @@ class ModelWorkerBatch: ...@@ -1272,6 +1421,7 @@ class ModelWorkerBatch:
extend_seq_lens: Optional[List[int]] extend_seq_lens: Optional[List[int]]
extend_prefix_lens: Optional[List[int]] extend_prefix_lens: Optional[List[int]]
extend_logprob_start_lens: Optional[List[int]] extend_logprob_start_lens: Optional[List[int]]
extend_input_logprob_token_ids: Optional[torch.Tensor]
# For multimodal # For multimodal
image_inputs: Optional[List[ImageInputs]] image_inputs: Optional[List[ImageInputs]]
...@@ -1293,7 +1443,8 @@ class ModelWorkerBatch: ...@@ -1293,7 +1443,8 @@ class ModelWorkerBatch:
# Speculative decoding # Speculative decoding
spec_algorithm: SpeculativeAlgorithm = None spec_algorithm: SpeculativeAlgorithm = None
spec_info: Optional[SpecInfo] = None spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
# If set, the output of the batch contains the hidden states of the run.
capture_hidden_mode: CaptureHiddenMode = None capture_hidden_mode: CaptureHiddenMode = None
......
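One piece of the "penalty in overlap mode" support is visible in `prepare_for_decode` above: with overlap scheduling, `self.output_ids` may not be materialized yet, so the penalizer is fed the last known token of each request instead (the last generated token, or the last prompt token if nothing has been generated). A small sketch of that fallback, with a minimal stand-in for `Req` and CPU tensors only:

```python
import torch


class MiniReq:
    """Stand-in for Req with only the fields needed here."""

    def __init__(self, origin_input_ids, output_ids):
        self.origin_input_ids = origin_input_ids
        self.output_ids = output_ids


def delayed_output_ids(reqs, device="cpu"):
    # Mirror of the overlap-mode branch: use the last generated token when it
    # exists, otherwise fall back to the last prompt token.
    return torch.tensor(
        [
            req.output_ids[-1] if len(req.output_ids) else req.origin_input_ids[-1]
            for req in reqs
        ],
        dtype=torch.int64,
        device=device,
    )


reqs = [MiniReq([1, 2, 3], [9, 10]), MiniReq([4, 5, 6], [])]
print(delayed_output_ids(reqs))  # tensor([10,  6])
```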
...@@ -272,7 +272,7 @@ class PrefillAdder: ...@@ -272,7 +272,7 @@ class PrefillAdder:
self.req_states = None self.req_states = None
self.can_run_list = [] self.can_run_list = []
self.new_being_chunked_req = None self.new_chunked_req = None
self.log_hit_tokens = 0 self.log_hit_tokens = 0
self.log_input_tokens = 0 self.log_input_tokens = 0
...@@ -327,7 +327,7 @@ class PrefillAdder: ...@@ -327,7 +327,7 @@ class PrefillAdder:
self.log_hit_tokens += prefix_len self.log_hit_tokens += prefix_len
self.log_input_tokens += extend_input_len self.log_input_tokens += extend_input_len
def add_being_chunked_req(self, req: Req): def add_chunked_req(self, req: Req):
truncated = req.extend_input_len > self.rem_chunk_tokens truncated = req.extend_input_len > self.rem_chunk_tokens
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens) req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len] req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
...@@ -354,7 +354,7 @@ class PrefillAdder: ...@@ -354,7 +354,7 @@ class PrefillAdder:
finally: finally:
self.tree_cache.dec_lock_ref(last_node) self.tree_cache.dec_lock_ref(last_node)
def add_one_req_ignore_eos(self, req: Req): def add_one_req_ignore_eos(self, req: Req, has_chunked_req: bool):
def add_req_state(r, insert_sort=False): def add_req_state(r, insert_sort=False):
new_token_ratio = ( new_token_ratio = (
1.0 if r.sampling_params.ignore_eos else self.new_token_ratio 1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
...@@ -403,6 +403,7 @@ class PrefillAdder: ...@@ -403,6 +403,7 @@ class PrefillAdder:
self.rem_chunk_tokens is None self.rem_chunk_tokens is None
or req.extend_input_len <= self.rem_chunk_tokens or req.extend_input_len <= self.rem_chunk_tokens
): ):
# Non-chunked prefill
self.can_run_list.append(req) self.can_run_list.append(req)
self._prefill_one_req( self._prefill_one_req(
0, 0,
...@@ -418,14 +419,14 @@ class PrefillAdder: ...@@ -418,14 +419,14 @@ class PrefillAdder:
req.extend_input_len = trunc_len req.extend_input_len = trunc_len
req.fill_ids = req.fill_ids[:trunc_len] req.fill_ids = req.fill_ids[:trunc_len]
self.can_run_list.append(req) self.can_run_list.append(req)
self.new_being_chunked_req = req self.new_chunked_req = req
self._prefill_one_req(0, trunc_len, 0) self._prefill_one_req(0, trunc_len, 0)
return self.budget_state() return self.budget_state()
def add_one_req(self, req: Req): def add_one_req(self, req: Req, has_chunked_req: bool):
if req.sampling_params.ignore_eos and self.tree_cache.disable: if req.sampling_params.ignore_eos and self.tree_cache.disable:
return self.add_one_req_ignore_eos(req) return self.add_one_req_ignore_eos(req, has_chunked_req)
total_tokens = req.extend_input_len + min( total_tokens = req.extend_input_len + min(
req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
...@@ -443,14 +444,7 @@ class PrefillAdder: ...@@ -443,14 +444,7 @@ class PrefillAdder:
if total_tokens > self.rem_total_tokens: if total_tokens > self.rem_total_tokens:
return AddReqResult.NO_TOKEN return AddReqResult.NO_TOKEN
if ( if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
self.rem_chunk_tokens is None
or input_tokens <= self.rem_chunk_tokens
or (
req.return_logprob
and req.logprob_start_len != len(req.origin_input_ids) - 1
)
):
# Non-chunked prefill # Non-chunked prefill
self.can_run_list.append(req) self.can_run_list.append(req)
self.tree_cache.inc_lock_ref(req.last_node) self.tree_cache.inc_lock_ref(req.last_node)
...@@ -470,8 +464,9 @@ class PrefillAdder: ...@@ -470,8 +464,9 @@ class PrefillAdder:
req.extend_input_len = trunc_len req.extend_input_len = trunc_len
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len] req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
self.can_run_list.append(req) self.can_run_list.append(req)
self.new_being_chunked_req = req self.new_chunked_req = req
self.tree_cache.inc_lock_ref(req.last_node) self.tree_cache.inc_lock_ref(req.last_node)
self._prefill_one_req(prefix_len, trunc_len, 0) self._prefill_one_req(prefix_len, trunc_len, 0)
......
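The renamed `add_chunked_req` path above caps a request's `extend_input_len` at the remaining chunk budget and truncates `fill_ids` accordingly, so a long prompt is prefilled over several scheduler iterations. A simplified, standalone sketch of that chunking arithmetic, using plain ints and lists rather than the real `Req`/`PrefillAdder` objects:

```python
from typing import List, Tuple


def chunk_prefill(
    fill_ids: List[int], prefix_len: int, rem_chunk_tokens: int
) -> Tuple[List[int], bool]:
    """Return the fill_ids to prefill this round and whether the request was truncated."""
    extend_input_len = len(fill_ids) - prefix_len
    truncated = extend_input_len > rem_chunk_tokens
    extend_input_len = min(extend_input_len, rem_chunk_tokens)
    # Keep the already-cached prefix plus the tokens that fit into this chunk.
    return fill_ids[: prefix_len + extend_input_len], truncated


# A 10-token prompt with 2 cached prefix tokens and a 4-token chunk budget:
fill_ids = list(range(10))
this_round, truncated = chunk_prefill(fill_ids, prefix_len=2, rem_chunk_tokens=4)
print(this_round, truncated)  # [0, 1, 2, 3, 4, 5] True
```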
...@@ -17,10 +17,11 @@ import faulthandler ...@@ -17,10 +17,11 @@ import faulthandler
import logging import logging
import os import os
import signal import signal
import sys
import threading import threading
import time import time
import warnings import warnings
from collections import deque from collections import defaultdict, deque
from concurrent import futures from concurrent import futures
from dataclasses import dataclass from dataclasses import dataclass
from http import HTTPStatus from http import HTTPStatus
...@@ -41,20 +42,28 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput ...@@ -41,20 +42,28 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import ( from sglang.srt.managers.io_struct import (
AbortReq, AbortReq,
BatchEmbeddingOut, BatchEmbeddingOut,
BatchMultimodalDecodeReq,
BatchTokenIDOut, BatchTokenIDOut,
CloseSessionReqInput, CloseSessionReqInput,
FlushCacheReq, FlushCacheReq,
GetInternalStateReq,
GetInternalStateReqOutput,
GetWeightsByNameReqInput, GetWeightsByNameReqInput,
GetWeightsByNameReqOutput, GetWeightsByNameReqOutput,
HealthCheckOutput,
InitWeightsUpdateGroupReqInput, InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput, InitWeightsUpdateGroupReqOutput,
OpenSessionReqInput, OpenSessionReqInput,
OpenSessionReqOutput, OpenSessionReqOutput,
ProfileReq, ProfileReq,
ProfileReqOutput,
ProfileReqType,
ReleaseMemoryOccupationReqInput, ReleaseMemoryOccupationReqInput,
ReleaseMemoryOccupationReqOutput, ReleaseMemoryOccupationReqOutput,
ResumeMemoryOccupationReqInput, ResumeMemoryOccupationReqInput,
ResumeMemoryOccupationReqOutput, ResumeMemoryOccupationReqOutput,
SetInternalStateReq,
SetInternalStateReqOutput,
TokenizedEmbeddingReqInput, TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput, TokenizedGenerateReqInput,
UpdateWeightFromDiskReqInput, UpdateWeightFromDiskReqInput,
...@@ -95,6 +104,8 @@ from sglang.srt.utils import ( ...@@ -95,6 +104,8 @@ from sglang.srt.utils import (
crash_on_warnings, crash_on_warnings,
get_bool_env_var, get_bool_env_var,
get_zmq_socket, get_zmq_socket,
kill_itself_when_parent_died,
pyspy_dump_schedulers,
set_gpu_proc_affinity, set_gpu_proc_affinity,
set_random_seed, set_random_seed,
suppress_other_loggers, suppress_other_loggers,
...@@ -104,13 +115,16 @@ from sglang.utils import TypeBasedDispatcher, get_exception_traceback ...@@ -104,13 +115,16 @@ from sglang.utils import TypeBasedDispatcher, get_exception_traceback
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Test retract decode for debugging purposes # Test retract decode for debugging purposes
test_retract = get_bool_env_var("SGLANG_TEST_RETRACT") TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
@dataclass @dataclass
class GenerationBatchResult: class GenerationBatchResult:
logits_output: LogitsProcessorOutput logits_output: LogitsProcessorOutput
next_token_ids: List[int] next_token_ids: List[int]
extend_input_len_per_req: List[int]
extend_logprob_start_len_per_req: List[int]
bid: int bid: int
...@@ -142,15 +156,23 @@ class Scheduler: ...@@ -142,15 +156,23 @@ class Scheduler:
self.enable_overlap = not server_args.disable_overlap_schedule self.enable_overlap = not server_args.disable_overlap_schedule
self.skip_tokenizer_init = server_args.skip_tokenizer_init self.skip_tokenizer_init = server_args.skip_tokenizer_init
self.enable_metrics = server_args.enable_metrics self.enable_metrics = server_args.enable_metrics
self.stream_interval = server_args.stream_interval
self.spec_algorithm = SpeculativeAlgorithm.from_string( self.spec_algorithm = SpeculativeAlgorithm.from_string(
server_args.speculative_algorithm server_args.speculative_algorithm
) )
self.gpu_id = gpu_id
self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
self.decode_mem_cache_buf_multiplier = ( self.decode_mem_cache_buf_multiplier = (
self.server_args.speculative_num_draft_tokens (
self.server_args.speculative_num_draft_tokens
+ (
self.server_args.speculative_eagle_topk
* self.server_args.speculative_num_steps
)
)
if not self.spec_algorithm.is_none() if not self.spec_algorithm.is_none()
else 1 else 1
) )
self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
# Distributed rank info # Distributed rank info
self.dp_size = server_args.dp_size self.dp_size = server_args.dp_size
...@@ -246,7 +268,7 @@ class Scheduler: ...@@ -246,7 +268,7 @@ class Scheduler:
nccl_port=port_args.nccl_port, nccl_port=port_args.nccl_port,
) )
# Launch a worker for speculative decoding if needed # Launch a draft worker for speculative decoding
if self.spec_algorithm.is_eagle(): if self.spec_algorithm.is_eagle():
from sglang.srt.speculative.eagle_worker import EAGLEWorker from sglang.srt.speculative.eagle_worker import EAGLEWorker
...@@ -258,8 +280,10 @@ class Scheduler: ...@@ -258,8 +280,10 @@ class Scheduler:
target_worker=self.tp_worker, target_worker=self.tp_worker,
dp_rank=dp_rank, dp_rank=dp_rank,
) )
self.prefill_only_one_req = True
else: else:
self.draft_worker = None self.draft_worker = None
self.prefill_only_one_req = False
# Get token and memory info from the model worker # Get token and memory info from the model worker
( (
...@@ -280,6 +304,7 @@ class Scheduler: ...@@ -280,6 +304,7 @@ class Scheduler:
self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func() self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
global_server_args_dict.update(worker_global_server_args_dict) global_server_args_dict.update(worker_global_server_args_dict)
set_random_seed(self.random_seed) set_random_seed(self.random_seed)
# Print debug info # Print debug info
logger.info( logger.info(
f"max_total_num_tokens={self.max_total_num_tokens}, " f"max_total_num_tokens={self.max_total_num_tokens}, "
...@@ -301,19 +326,18 @@ class Scheduler: ...@@ -301,19 +326,18 @@ class Scheduler:
token_to_kv_pool=self.token_to_kv_pool, token_to_kv_pool=self.token_to_kv_pool,
) )
else: else:
self.tree_cache = ( if self.enable_hierarchical_cache:
HiRadixCache( self.tree_cache = HiRadixCache(
req_to_token_pool=self.req_to_token_pool, req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool=self.token_to_kv_pool, token_to_kv_pool=self.token_to_kv_pool,
) )
if self.enable_hierarchical_cache else:
else RadixCache( self.tree_cache = RadixCache(
req_to_token_pool=self.req_to_token_pool, req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool=self.token_to_kv_pool, token_to_kv_pool=self.token_to_kv_pool,
disable=server_args.disable_radix_cache, disable=server_args.disable_radix_cache,
) )
)
self.tree_cache_metrics = {"total": 0, "hit": 0}
self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache) self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
# Init running status # Init running status
...@@ -330,12 +354,23 @@ class Scheduler: ...@@ -330,12 +354,23 @@ class Scheduler:
self.num_generated_tokens = 0 self.num_generated_tokens = 0
self.spec_num_total_accepted_tokens = 0 self.spec_num_total_accepted_tokens = 0
self.spec_num_total_forward_ct = 0 self.spec_num_total_forward_ct = 0
self.cum_spec_accept_length = 0
self.cum_spec_accept_count = 0
self.last_decode_stats_tic = time.time() self.last_decode_stats_tic = time.time()
self.return_health_check_ct = 0
self.stream_interval = server_args.stream_interval self.stream_interval = server_args.stream_interval
self.current_stream = torch.get_device_module(self.device).current_stream() self.current_stream = torch.get_device_module(self.device).current_stream()
if self.device == "cpu": if self.device == "cpu":
self.current_stream.synchronize = lambda: None # No-op for CPU self.current_stream.synchronize = lambda: None # No-op for CPU
# For metrics only.
# The largest prefill length of a single request
self._largest_prefill_len: int = 0
# The largest context length (prefill + generation) of a single request
self._largest_prefill_decode_len: int = 0
self.last_gen_throughput: float = 0.0
self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
# Session info # Session info
self.sessions: Dict[str, Session] = {} self.sessions: Dict[str, Session] = {}
...@@ -343,7 +378,7 @@ class Scheduler: ...@@ -343,7 +378,7 @@ class Scheduler:
self.chunked_prefill_size = server_args.chunked_prefill_size self.chunked_prefill_size = server_args.chunked_prefill_size
if self.chunked_prefill_size <= 0: # -1 means disable if self.chunked_prefill_size <= 0: # -1 means disable
self.chunked_prefill_size = None self.chunked_prefill_size = None
self.being_chunked_req = None self.chunked_req = None
self.is_mixed_chunk = ( self.is_mixed_chunk = (
self.chunked_prefill_size is not None and server_args.enable_mixed_chunk self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
) )
...@@ -377,7 +412,7 @@ class Scheduler: ...@@ -377,7 +412,7 @@ class Scheduler:
) / global_config.default_new_token_ratio_decay_steps ) / global_config.default_new_token_ratio_decay_steps
self.new_token_ratio = self.init_new_token_ratio self.new_token_ratio = self.init_new_token_ratio
# Tells whether the current running batch is full so that we can skip # Tell whether the current running batch is full so that we can skip
# the check of whether to prefill new requests. # the check of whether to prefill new requests.
# This is an optimization to reduce the overhead of the prefill check. # This is an optimization to reduce the overhead of the prefill check.
self.batch_is_full = False self.batch_is_full = False
...@@ -388,26 +423,16 @@ class Scheduler: ...@@ -388,26 +423,16 @@ class Scheduler:
t.start() t.start()
self.parent_process = psutil.Process().parent() self.parent_process = psutil.Process().parent()
# Init memory saver
self.memory_saver_adapter = TorchMemorySaverAdapter.create( self.memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=server_args.enable_memory_saver enable=server_args.enable_memory_saver
) )
# Init profiler # Init profiler
if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "": self.torch_profiler = None
self.profiler = None self.torch_profiler_output_dir: Optional[str] = None
else: self.torch_profiler_activities: Optional[List[str]] = None
self.torch_profiler_trace_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR") self.profiler_target_forward_ct: Optional[int] = None
logger.info(
"Profiling enabled. Traces will be saved to: %s",
self.torch_profiler_trace_dir,
)
self.profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
with_stack=True,
)
# Init metrics stats # Init metrics stats
self.stats = SchedulerStats() self.stats = SchedulerStats()
...@@ -431,6 +456,8 @@ class Scheduler: ...@@ -431,6 +456,8 @@ class Scheduler:
(TokenizedEmbeddingReqInput, self.handle_embedding_request), (TokenizedEmbeddingReqInput, self.handle_embedding_request),
(FlushCacheReq, self.flush_cache_wrapped), (FlushCacheReq, self.flush_cache_wrapped),
(AbortReq, self.abort_request), (AbortReq, self.abort_request),
(OpenSessionReqInput, self.open_session),
(CloseSessionReqInput, self.close_session),
(UpdateWeightFromDiskReqInput, self.update_weights_from_disk), (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
(InitWeightsUpdateGroupReqInput, self.init_weights_update_group), (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
( (
...@@ -439,22 +466,16 @@ class Scheduler: ...@@ -439,22 +466,16 @@ class Scheduler:
), ),
(UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor), (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
(GetWeightsByNameReqInput, self.get_weights_by_name), (GetWeightsByNameReqInput, self.get_weights_by_name),
(ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
(ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
(ProfileReq, self.profile), (ProfileReq, self.profile),
(OpenSessionReqInput, self.open_session), (GetInternalStateReq, self.get_internal_state),
(CloseSessionReqInput, self.close_session), (SetInternalStateReq, self.set_internal_state),
(
ReleaseMemoryOccupationReqInput,
lambda _: self.release_memory_occupation(),
),
(
ResumeMemoryOccupationReqInput,
lambda _: self.resume_memory_occupation(),
),
] ]
) )
def watchdog_thread(self): def watchdog_thread(self):
"""A watch dog thread that will try to kill the server itself if one batch takes too long.""" """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
self.watchdog_last_forward_ct = 0 self.watchdog_last_forward_ct = 0
self.watchdog_last_time = time.time() self.watchdog_last_time = time.time()
...@@ -469,7 +490,18 @@ class Scheduler: ...@@ -469,7 +490,18 @@ class Scheduler:
self.watchdog_last_forward_ct = self.forward_ct self.watchdog_last_forward_ct = self.forward_ct
self.watchdog_last_time = current self.watchdog_last_time = current
time.sleep(self.watchdog_timeout // 2) time.sleep(self.watchdog_timeout // 2)
# Wait sometimes so that the parent process can print the error.
# Print batch size and memory pool info to check whether there are de-sync issues.
logger.error(
f"{self.cur_batch.batch_size()=}, "
f"{self.cur_batch.reqs=}, "
f"{self.token_to_kv_pool.available_size()=}, "
f"{self.tree_cache.evictable_size()=}, "
)
# Wait for some time so that the parent process can print the error.
pyspy_dump_schedulers()
print(file=sys.stderr, flush=True)
print(file=sys.stdout, flush=True)
time.sleep(5) time.sleep(5)
self.parent_process.send_signal(signal.SIGQUIT) self.parent_process.send_signal(signal.SIGQUIT)
...@@ -586,6 +618,13 @@ class Scheduler: ...@@ -586,6 +618,13 @@ class Scheduler:
def process_input_requests(self, recv_reqs: List): def process_input_requests(self, recv_reqs: List):
for recv_req in recv_reqs: for recv_req in recv_reqs:
# If it is a health check generation request and there are running requests, ignore it.
if is_health_check_generate_req(recv_req) and (
self.chunked_req is not None or self.running_batch is not None
):
self.return_health_check_ct += 1
continue
output = self._request_dispatcher(recv_req) output = self._request_dispatcher(recv_req)
if output is not None: if output is not None:
self.send_to_tokenizer.send_pyobj(output) self.send_to_tokenizer.send_pyobj(output)
...@@ -600,7 +639,6 @@ class Scheduler: ...@@ -600,7 +639,6 @@ class Scheduler:
or recv_req.session_params.id is None or recv_req.session_params.id is None
or recv_req.session_params.id not in self.sessions or recv_req.session_params.id not in self.sessions
): ):
if recv_req.input_embeds is not None: if recv_req.input_embeds is not None:
# Generate fake input_ids based on the length of input_embeds # Generate fake input_ids based on the length of input_embeds
seq_length = len(recv_req.input_embeds) seq_length = len(recv_req.input_embeds)
...@@ -627,6 +665,7 @@ class Scheduler: ...@@ -627,6 +665,7 @@ class Scheduler:
recv_req.sampling_params, recv_req.sampling_params,
return_logprob=recv_req.return_logprob, return_logprob=recv_req.return_logprob,
top_logprobs_num=recv_req.top_logprobs_num, top_logprobs_num=recv_req.top_logprobs_num,
token_ids_logprob=recv_req.token_ids_logprob,
stream=recv_req.stream, stream=recv_req.stream,
lora_path=recv_req.lora_path, lora_path=recv_req.lora_path,
input_embeds=recv_req.input_embeds, input_embeds=recv_req.input_embeds,
...@@ -643,14 +682,14 @@ class Scheduler: ...@@ -643,14 +682,14 @@ class Scheduler:
req.finished_reason = FINISH_ABORT( req.finished_reason = FINISH_ABORT(
f"Invalid request: session id {recv_req.session_params.id} does not exist" f"Invalid request: session id {recv_req.session_params.id} does not exist"
) )
self.waiting_queue.append(req) self._add_request_to_queue(req)
return return
else: else:
# Create a new request from a previous session # Create a new request from a previous session
session = self.sessions[recv_req.session_params.id] session = self.sessions[recv_req.session_params.id]
req = session.create_req(recv_req, self.tokenizer) req = session.create_req(recv_req, self.tokenizer)
if isinstance(req.finished_reason, FINISH_ABORT): if isinstance(req.finished_reason, FINISH_ABORT):
self.waiting_queue.append(req) self._add_request_to_queue(req)
return return
# Handle multimodal inputs # Handle multimodal inputs
...@@ -674,7 +713,7 @@ class Scheduler: ...@@ -674,7 +713,7 @@ class Scheduler:
req.finished_reason = FINISH_ABORT( req.finished_reason = FINISH_ABORT(
error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError" error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
) )
self.waiting_queue.append(req) self._add_request_to_queue(req)
return return
# Validate prompts length # Validate prompts length
...@@ -686,16 +725,26 @@ class Scheduler: ...@@ -686,16 +725,26 @@ class Scheduler:
if error_msg: if error_msg:
req.origin_input_ids = [0] req.origin_input_ids = [0]
req.sampling_params.max_new_tokens = 0 req.sampling_params.max_new_tokens = 0
self.waiting_queue.append(req) self._add_request_to_queue(req)
return return
# Copy more attributes # Copy more attributes
if recv_req.logprob_start_len == -1: if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
# By default, only return the logprobs for output tokens # By default, only return the logprobs for output tokens
req.logprob_start_len = len(req.origin_input_ids) - 1 req.logprob_start_len = len(req.origin_input_ids) - 1
else: else:
req.logprob_start_len = recv_req.logprob_start_len req.logprob_start_len = recv_req.logprob_start_len
if req.logprob_start_len >= len(req.origin_input_ids):
req.finished_reason = FINISH_ABORT(
f"logprob_start_len, ({req.logprob_start_len}) is higher than the number of input tokens ({len(req.origin_input_ids)}). Request with a lower logprob_start_len.",
HTTPStatus.BAD_REQUEST,
"BadRequestError",
)
req.logprob_start_len = len(req.origin_input_ids) - 1
self._add_request_to_queue(req)
return
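For readers new to logprob_start_len: -1 means "output-token logprobs only" (the default), any non-negative value marks the first prompt position whose logprob is returned, and anything at or beyond the prompt length is now rejected as a bad request. A small standalone check mirroring that rule (validate_logprob_start_len is a hypothetical helper, not part of the scheduler):

def validate_logprob_start_len(prompt_len: int, logprob_start_len: int) -> int:
    if logprob_start_len == -1:
        return prompt_len - 1          # default: output-token logprobs only
    if logprob_start_len >= prompt_len:
        raise ValueError(
            f"logprob_start_len ({logprob_start_len}) must be < prompt length ({prompt_len})"
        )
    return logprob_start_len

assert validate_logprob_start_len(6, -1) == 5   # 6-token prompt, default
assert validate_logprob_start_len(6, 2) == 2    # logprobs for positions 2..5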
req.sampling_params.max_new_tokens = min( req.sampling_params.max_new_tokens = min(
( (
req.sampling_params.max_new_tokens req.sampling_params.max_new_tokens
...@@ -731,7 +780,13 @@ class Scheduler: ...@@ -731,7 +780,13 @@ class Scheduler:
if add_to_grammar_queue: if add_to_grammar_queue:
self.grammar_queue.append(req) self.grammar_queue.append(req)
else: else:
self.waiting_queue.append(req) self._add_request_to_queue(req)
def _add_request_to_queue(self, req: Req):
self.waiting_queue.append(req)
def _extend_requests_to_queue(self, reqs: List[Req]):
self.waiting_queue.extend(reqs)
def handle_embedding_request( def handle_embedding_request(
self, self,
...@@ -752,61 +807,62 @@ class Scheduler: ...@@ -752,61 +807,62 @@ class Scheduler:
self.server_args.allow_auto_truncate, self.server_args.allow_auto_truncate,
) )
if error_msg: if error_msg:
self.waiting_queue.append(req) self._add_request_to_queue(req)
return return
# Copy more attributes # Copy more attributes
req.logprob_start_len = len(req.origin_input_ids) - 1 req.logprob_start_len = len(req.origin_input_ids) - 1
self.waiting_queue.append(req) self._add_request_to_queue(req)
def log_prefill_stats( def log_prefill_stats(
self, self,
adder: PrefillAdder, adder: PrefillAdder,
can_run_list: List[Req], can_run_list: List[Req],
running_bs: ScheduleBatch, running_bs: int,
has_being_chunked: bool,
): ):
self.tree_cache_metrics["total"] += (
adder.log_input_tokens + adder.log_hit_tokens
) / 10**9
self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
tree_cache_hit_rate = (
self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
)
num_used = self.max_total_num_tokens - ( num_used = self.max_total_num_tokens - (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size() self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
) )
self._largest_prefill_len = max(
self._largest_prefill_len, adder.log_input_tokens
)
logger.info( f = (
f"Prefill batch. " f"Prefill batch. "
f"#new-seq: {len(can_run_list)}, " f"#new-seq: {len(can_run_list)}, "
f"#new-token: {adder.log_input_tokens}, " f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, " f"#cached-token: {adder.log_hit_tokens}, "
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, " f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"#running-req: {running_bs}, " f"#running-req: {running_bs}, "
f"#queue-req: {len(self.waiting_queue) + has_being_chunked}" f"#queue-req: {len(self.waiting_queue)}, "
) )
logger.info(f)
if self.enable_metrics: if self.enable_metrics:
cache_hit_rate = adder.log_hit_tokens / (
adder.log_input_tokens + adder.log_hit_tokens
)
self.stats.num_running_reqs = running_bs self.stats.num_running_reqs = running_bs
self.stats.num_used_tokens = num_used self.stats.num_used_tokens = num_used
self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2) self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
self.stats.num_queue_reqs = len(self.waiting_queue) + has_being_chunked self.stats.num_queue_reqs = len(self.waiting_queue)
self.stats.cache_hit_rate = tree_cache_hit_rate self.stats.cache_hit_rate = cache_hit_rate
self.metrics_collector.log_stats(self.stats) self.metrics_collector.log_stats(self.stats)
def log_decode_stats(self): def log_decode_stats(self):
gap_latency = time.time() - self.last_decode_stats_tic
self.last_decode_stats_tic = time.time()
self.last_gen_throughput = self.num_generated_tokens / gap_latency
self.num_generated_tokens = 0
num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
num_used = self.max_total_num_tokens - ( num_used = self.max_total_num_tokens - (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size() self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
) )
gen_throughput = self.num_generated_tokens / (
time.time() - self.last_decode_stats_tic if RECORD_STEP_TIME:
) self.step_time_dict[num_running_reqs].append(
self.num_generated_tokens = 0 gap_latency / self.server_args.decode_log_interval
self.last_decode_stats_tic = time.time() )
num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
if self.spec_algorithm.is_none(): if self.spec_algorithm.is_none():
msg = ( msg = (
...@@ -814,14 +870,17 @@ class Scheduler: ...@@ -814,14 +870,17 @@ class Scheduler:
f"#running-req: {num_running_reqs}, " f"#running-req: {num_running_reqs}, "
f"#token: {num_used}, " f"#token: {num_used}, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, " f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"gen throughput (token/s): {gen_throughput:.2f}, " f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
f"#queue-req: {len(self.waiting_queue)}" f"largest-len: {self._largest_prefill_decode_len}, "
f"#queue-req: {len(self.waiting_queue)}, "
) )
spec_accept_length = 0 spec_accept_length = 0
else: else:
spec_accept_length = ( spec_accept_length = (
self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
) )
self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
self.cum_spec_accept_count += self.spec_num_total_forward_ct
self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0 self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
msg = ( msg = (
f"Decode batch. " f"Decode batch. "
...@@ -829,8 +888,9 @@ class Scheduler: ...@@ -829,8 +888,9 @@ class Scheduler:
f"#token: {num_used}, " f"#token: {num_used}, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, " f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"accept len: {spec_accept_length:.2f}, " f"accept len: {spec_accept_length:.2f}, "
f"gen throughput (token/s): {gen_throughput:.2f}, " f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
f"#queue-req: {len(self.waiting_queue)}" f"largest-len: {self._largest_prefill_decode_len}, "
f"#queue-req: {len(self.waiting_queue)}, "
) )
logger.info(msg) logger.info(msg)
...@@ -838,7 +898,8 @@ class Scheduler: ...@@ -838,7 +898,8 @@ class Scheduler:
self.stats.num_running_reqs = num_running_reqs self.stats.num_running_reqs = num_running_reqs
self.stats.num_used_tokens = num_used self.stats.num_used_tokens = num_used
self.stats.token_usage = num_used / self.max_total_num_tokens self.stats.token_usage = num_used / self.max_total_num_tokens
self.stats.gen_throughput = gen_throughput self.stats.cache_hit_rate = 0.0
self.stats.gen_throughput = self.last_gen_throughput
self.stats.num_queue_reqs = len(self.waiting_queue) self.stats.num_queue_reqs = len(self.waiting_queue)
self.stats.spec_accept_length = spec_accept_length self.stats.spec_accept_length = spec_accept_length
self.metrics_collector.log_stats(self.stats) self.metrics_collector.log_stats(self.stats)
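The gap_latency bookkeeping at the top of log_decode_stats amounts to interval throughput: tokens generated since the previous log divided by the elapsed wall-clock time, with the counter reset afterwards. A minimal mirror of that calculation (ThroughputMeter is an illustrative name only):

import time

class ThroughputMeter:
    def __init__(self):
        self.last_tic = time.time()
        self.num_generated_tokens = 0

    def tick(self) -> float:
        # Tokens per second over the interval since the last log line.
        now = time.time()
        gap = now - self.last_tic
        tps = self.num_generated_tokens / gap if gap > 0 else 0.0
        self.last_tic = now
        self.num_generated_tokens = 0
        return tps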
...@@ -872,21 +933,42 @@ class Scheduler: ...@@ -872,21 +933,42 @@ class Scheduler:
if crash_on_warnings(): if crash_on_warnings():
raise ValueError(msg) raise ValueError(msg)
if (
self.enable_metrics
and self.attn_tp_rank == 0
and time.time() > self.metrics_collector.last_log_time + 30
):
# During idle time, also collect metrics every 30 seconds.
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool.available_size()
+ self.tree_cache.evictable_size()
)
num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
self.stats.num_running_reqs = num_running_reqs
self.stats.num_used_tokens = num_used
self.stats.token_usage = num_used / self.max_total_num_tokens
self.stats.gen_throughput = 0
self.stats.num_queue_reqs = len(self.waiting_queue)
self.metrics_collector.log_stats(self.stats)
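The new idle-path block above is a simple time-gated flush, so metrics collectors keep receiving samples (with zero throughput) even when no batch is running. A rough standalone equivalent, assuming a 30-second interval and an emit callback (both names are illustrative):

import time

class ThrottledLogger:
    def __init__(self, interval_s: float = 30.0):
        self.interval_s = interval_s
        self.last_log_time = 0.0

    def maybe_log(self, emit, stats):
        # Emit at most once per interval; cheap enough to call every loop.
        now = time.time()
        if now > self.last_log_time + self.interval_s:
            emit(stats)
            self.last_log_time = now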
def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
# Merge the prefill batch into the running batch # Merge the prefill batch into the running batch
if self.last_batch and self.last_batch.forward_mode.is_extend(): if self.last_batch and self.last_batch.forward_mode.is_extend():
if self.being_chunked_req: if self.chunked_req:
# Move the chunked request out of the batch # Move the chunked request out of the batch so that we can merge
self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req) # only finished requests to running_batch.
self.tree_cache.cache_unfinished_req(self.being_chunked_req) self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
# being chunked request keeps its rid but will get a new req_pool_idx self.tree_cache.cache_unfinished_req(self.chunked_req)
self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx) # chunked request keeps its rid but will get a new req_pool_idx
self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
self.batch_is_full = False self.batch_is_full = False
self.last_batch.filter_batch()
if not self.last_batch.is_empty(): if not self.last_batch.is_empty():
if self.running_batch is None: if self.running_batch is None:
self.running_batch = self.last_batch self.running_batch = self.last_batch
else: else:
# merge running_batch with prefill batch
self.running_batch.merge_batch(self.last_batch) self.running_batch.merge_batch(self.last_batch)
new_batch = self.get_new_batch_prefill() new_batch = self.get_new_batch_prefill()
...@@ -915,7 +997,7 @@ class Scheduler: ...@@ -915,7 +997,7 @@ class Scheduler:
# Handle the cases where prefill is not allowed # Handle the cases where prefill is not allowed
if ( if (
self.batch_is_full or len(self.waiting_queue) == 0 self.batch_is_full or len(self.waiting_queue) == 0
) and self.being_chunked_req is None: ) and self.chunked_req is None:
return None return None
running_bs = len(self.running_batch.reqs) if self.running_batch else 0 running_bs = len(self.running_batch.reqs) if self.running_batch else 0
...@@ -937,10 +1019,10 @@ class Scheduler: ...@@ -937,10 +1019,10 @@ class Scheduler:
running_bs if self.is_mixed_chunk else 0, running_bs if self.is_mixed_chunk else 0,
) )
has_being_chunked = self.being_chunked_req is not None is_chunked = self.chunked_req is not None
if has_being_chunked: if is_chunked:
self.being_chunked_req.init_next_round_input() self.chunked_req.init_next_round_input()
self.being_chunked_req = adder.add_being_chunked_req(self.being_chunked_req) self.chunked_req = adder.add_chunked_req(self.chunked_req)
if self.lora_paths: if self.lora_paths:
lora_set = ( lora_set = (
...@@ -994,7 +1076,7 @@ class Scheduler: ...@@ -994,7 +1076,7 @@ class Scheduler:
self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid]) self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
del self.staging_reqs[req.rid] del self.staging_reqs[req.rid]
res = adder.add_one_req(req) res = adder.add_one_req(req, self.chunked_req)
if res != AddReqResult.CONTINUE: if res != AddReqResult.CONTINUE:
if res == AddReqResult.NO_TOKEN: if res == AddReqResult.NO_TOKEN:
if self.enable_hierarchical_cache: if self.enable_hierarchical_cache:
...@@ -1006,27 +1088,27 @@ class Scheduler: ...@@ -1006,27 +1088,27 @@ class Scheduler:
else: else:
self.batch_is_full = True self.batch_is_full = True
break break
if self.server_args.prefill_only_one_req: if self.prefill_only_one_req:
break break
# Update waiting queue # Update waiting queue
can_run_list = adder.can_run_list can_run_list: List[Req] = adder.can_run_list
if len(can_run_list) == 0: if len(can_run_list) == 0:
return None return None
self.waiting_queue = [ self.waiting_queue = [
x for x in self.waiting_queue if x not in set(can_run_list) x for x in self.waiting_queue if x not in set(can_run_list)
] ]
if adder.new_being_chunked_req is not None: if adder.new_chunked_req is not None:
assert self.being_chunked_req is None assert self.chunked_req is None
self.being_chunked_req = adder.new_being_chunked_req self.chunked_req = adder.new_chunked_req
if self.being_chunked_req: if self.chunked_req:
self.being_chunked_req.is_being_chunked += 1 self.chunked_req.is_chunked += 1
# Print stats # Print stats
if self.attn_tp_rank == 0: if self.attn_tp_rank == 0:
self.log_prefill_stats(adder, can_run_list, running_bs, has_being_chunked) self.log_prefill_stats(adder, can_run_list, running_bs)
# Create a new batch # Create a new batch
new_batch = ScheduleBatch.init_new( new_batch = ScheduleBatch.init_new(
...@@ -1062,8 +1144,6 @@ class Scheduler: ...@@ -1062,8 +1144,6 @@ class Scheduler:
def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]: def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
"""Update the current running decoding batch.""" """Update the current running decoding batch."""
global test_retract
initial_bs = batch.batch_size() initial_bs = batch.batch_size()
batch.filter_batch() batch.filter_batch()
...@@ -1073,11 +1153,11 @@ class Scheduler: ...@@ -1073,11 +1153,11 @@ class Scheduler:
# Check if decode out of memory # Check if decode out of memory
if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or ( if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
test_retract and batch.batch_size() > 10 TEST_RETRACT and batch.batch_size() > 10
): ):
old_ratio = self.new_token_ratio old_ratio = self.new_token_ratio
retracted_reqs, new_token_ratio = batch.retract_decode() retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
self.new_token_ratio = new_token_ratio self.new_token_ratio = new_token_ratio
if self.draft_worker: if self.draft_worker:
self.draft_worker.finish_request(retracted_reqs) self.draft_worker.finish_request(retracted_reqs)
...@@ -1087,7 +1167,7 @@ class Scheduler: ...@@ -1087,7 +1167,7 @@ class Scheduler:
f"#retracted_reqs: {len(retracted_reqs)}, " f"#retracted_reqs: {len(retracted_reqs)}, "
f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}" f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
) )
self.waiting_queue.extend(retracted_reqs) self._extend_requests_to_queue(retracted_reqs)
else: else:
self.new_token_ratio = max( self.new_token_ratio = max(
self.new_token_ratio - self.new_token_ratio_decay, self.new_token_ratio - self.new_token_ratio_decay,
...@@ -1097,7 +1177,7 @@ class Scheduler: ...@@ -1097,7 +1177,7 @@ class Scheduler:
# Check for jump-forward # Check for jump-forward
if not self.disable_jump_forward and batch.has_grammar: if not self.disable_jump_forward and batch.has_grammar:
jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func) jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
self.waiting_queue.extend(jump_forward_reqs) self._extend_requests_to_queue(jump_forward_reqs)
if batch.is_empty(): if batch.is_empty():
self.batch_is_full = False self.batch_is_full = False
return None return None
...@@ -1115,6 +1195,13 @@ class Scheduler: ...@@ -1115,6 +1195,13 @@ class Scheduler:
"""Run a batch.""" """Run a batch."""
self.forward_ct += 1 self.forward_ct += 1
# Check profiler
if (
self.profiler_target_forward_ct
and self.profiler_target_forward_ct <= self.forward_ct
):
self.stop_profile()
if self.is_generation: if self.is_generation:
if self.spec_algorithm.is_none(): if self.spec_algorithm.is_none():
model_worker_batch = batch.get_model_worker_batch() model_worker_batch = batch.get_model_worker_batch()
...@@ -1135,9 +1222,23 @@ class Scheduler: ...@@ -1135,9 +1222,23 @@ class Scheduler:
self.num_generated_tokens += num_accepted_tokens self.num_generated_tokens += num_accepted_tokens
batch.output_ids = next_token_ids batch.output_ids = next_token_ids
# These 2 values are needed for processing the output, but the values can be
# modified by overlap schedule. So we have to copy them here so that
# we can use the correct values in output processing.
if batch.return_logprob:
extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
extend_logprob_start_len_per_req = [
req.extend_logprob_start_len for req in batch.reqs
]
else:
extend_input_len_per_req = None
extend_logprob_start_len_per_req = None
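The snapshot above matters because, under the overlap scheduler, req.extend_input_len and req.extend_logprob_start_len may already describe the next chunk by the time output processing runs; copying them into plain Python lists freezes the values that belong to this forward pass. A condensed sketch of the same idea (the function name is illustrative):

def snapshot_for_output_processing(reqs, return_logprob: bool):
    # Capture per-request lengths before the overlap schedule can mutate them.
    if not return_logprob:
        return None, None
    extend_input_len_per_req = [r.extend_input_len for r in reqs]
    extend_logprob_start_len_per_req = [r.extend_logprob_start_len for r in reqs]
    return extend_input_len_per_req, extend_logprob_start_len_per_req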
ret = GenerationBatchResult( ret = GenerationBatchResult(
logits_output=logits_output, logits_output=logits_output,
next_token_ids=next_token_ids, next_token_ids=next_token_ids,
extend_input_len_per_req=extend_input_len_per_req,
extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
bid=model_worker_batch.bid, bid=model_worker_batch.bid,
) )
else: # embedding or reward model else: # embedding or reward model
...@@ -1171,6 +1272,13 @@ class Scheduler: ...@@ -1171,6 +1272,13 @@ class Scheduler:
self.current_stream.synchronize() self.current_stream.synchronize()
batch.next_batch_sampling_info.sampling_info_done.set() batch.next_batch_sampling_info.sampling_info_done.set()
if self.return_health_check_ct:
# Return some signal for the health check.
# This is used to prevent the health check signal being blocked by long context prefill.
# However, one minor issue is that this code path does not check the status of the detokenizer manager.
self.return_health_check_ct -= 1
self.send_to_tokenizer.send_pyobj(HealthCheckOutput())
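The return_health_check_ct counter implements a deferred acknowledgment: a health probe that arrives while the scheduler is busy is never enqueued as a real request; instead one HealthCheckOutput is emitted per scheduler iteration until the backlog of probes drains, so a long prefill cannot block the health endpoint. A self-contained sketch of the pattern, with illustrative names (MiniScheduler, outbox) rather than the actual SGLang classes:

from dataclasses import dataclass, field
from typing import List

@dataclass
class HealthCheckOutput:
    ok: bool = True

@dataclass
class MiniScheduler:
    busy: bool = False                      # a chunked or running batch exists
    return_health_check_ct: int = 0
    outbox: List[object] = field(default_factory=list)

    def on_request(self, rid: str):
        # Busy: do not enqueue the probe; just remember to answer it later.
        if rid.startswith("HEALTH_CHECK") and self.busy:
            self.return_health_check_ct += 1
            return
        ...  # normal request handling

    def after_batch(self):
        # Answer one pending probe per iteration.
        if self.return_health_check_ct:
            self.return_health_check_ct -= 1
            self.outbox.append(HealthCheckOutput())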
def process_batch_result_prefill( def process_batch_result_prefill(
self, self,
batch: ScheduleBatch, batch: ScheduleBatch,
...@@ -1182,10 +1290,14 @@ class Scheduler: ...@@ -1182,10 +1290,14 @@ class Scheduler:
( (
logits_output, logits_output,
next_token_ids, next_token_ids,
extend_input_len_per_req,
extend_logprob_start_len_per_req,
bid, bid,
) = ( ) = (
result.logits_output, result.logits_output,
result.next_token_ids, result.next_token_ids,
result.extend_input_len_per_req,
result.extend_logprob_start_len_per_req,
result.bid, result.bid,
) )
...@@ -1195,12 +1307,14 @@ class Scheduler: ...@@ -1195,12 +1307,14 @@ class Scheduler:
# Move next_token_ids and logprobs to cpu # Move next_token_ids and logprobs to cpu
next_token_ids = next_token_ids.tolist() next_token_ids = next_token_ids.tolist()
if batch.return_logprob: if batch.return_logprob:
logits_output.next_token_logprobs = ( if logits_output.next_token_logprobs is not None:
logits_output.next_token_logprobs.tolist() logits_output.next_token_logprobs = (
) logits_output.next_token_logprobs.tolist()
logits_output.input_token_logprobs = ( )
logits_output.input_token_logprobs.tolist() if logits_output.input_token_logprobs is not None:
) logits_output.input_token_logprobs = tuple(
logits_output.input_token_logprobs.tolist()
)
hidden_state_offset = 0 hidden_state_offset = 0
...@@ -1216,19 +1330,33 @@ class Scheduler: ...@@ -1216,19 +1330,33 @@ class Scheduler:
self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1]) self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1])
continue continue
if req.is_being_chunked <= 0: if req.is_chunked <= 0:
# req output_ids are set here
req.output_ids.append(next_token_id) req.output_ids.append(next_token_id)
req.check_finished() req.check_finished()
if req.finished(): if req.finished():
self.tree_cache.cache_finished_req(req) self.tree_cache.cache_finished_req(req)
elif not batch.decoding_reqs or req not in batch.decoding_reqs: elif not batch.decoding_reqs or req not in batch.decoding_reqs:
# This updates radix so others can match
self.tree_cache.cache_unfinished_req(req) self.tree_cache.cache_unfinished_req(req)
if req.return_logprob: if req.return_logprob:
logprob_pt += self.add_logprob_return_values( assert extend_logprob_start_len_per_req is not None
i, req, logprob_pt, next_token_ids, logits_output assert extend_input_len_per_req is not None
extend_logprob_start_len = extend_logprob_start_len_per_req[i]
extend_input_len = extend_input_len_per_req[i]
num_input_logprobs = extend_input_len - extend_logprob_start_len
self.add_logprob_return_values(
i,
req,
logprob_pt,
next_token_ids,
num_input_logprobs,
logits_output,
) )
logprob_pt += num_input_logprobs
if ( if (
req.return_hidden_states req.return_hidden_states
and logits_output.hidden_states is not None and logits_output.hidden_states is not None
...@@ -1249,12 +1377,31 @@ class Scheduler: ...@@ -1249,12 +1377,31 @@ class Scheduler:
req.grammar.finished = req.finished() req.grammar.finished = req.finished()
else: else:
# being chunked reqs' prefill is not finished # being chunked reqs' prefill is not finished
req.is_being_chunked -= 1 req.is_chunked -= 1
# There is only at most one request being currently chunked. # There is only at most one request being currently chunked.
# Because this request does not finish prefill, # Because this request does not finish prefill,
# we don't want to stream the request currently being chunked. # we don't want to stream the request currently being chunked.
skip_stream_req = req skip_stream_req = req
# Incrementally update input logprobs.
if req.return_logprob:
extend_logprob_start_len = extend_logprob_start_len_per_req[i]
extend_input_len = extend_input_len_per_req[i]
if extend_logprob_start_len < extend_input_len:
# Update input logprobs.
num_input_logprobs = (
extend_input_len - extend_logprob_start_len
)
self.add_input_logprob_return_values(
i,
req,
logits_output,
logprob_pt,
num_input_logprobs,
last_prefill_chunk=False,
)
logprob_pt += num_input_logprobs
if batch.next_batch_sampling_info: if batch.next_batch_sampling_info:
batch.next_batch_sampling_info.update_regex_vocab_mask() batch.next_batch_sampling_info.update_regex_vocab_mask()
self.current_stream.synchronize() self.current_stream.synchronize()
...@@ -1270,7 +1417,7 @@ class Scheduler: ...@@ -1270,7 +1417,7 @@ class Scheduler:
continue continue
req.embedding = embeddings[i] req.embedding = embeddings[i]
if req.is_being_chunked <= 0: if req.is_chunked <= 0:
# Dummy output token for embedding models # Dummy output token for embedding models
req.output_ids.append(0) req.output_ids.append(0)
req.check_finished() req.check_finished()
...@@ -1281,7 +1428,7 @@ class Scheduler: ...@@ -1281,7 +1428,7 @@ class Scheduler:
self.tree_cache.cache_unfinished_req(req) self.tree_cache.cache_unfinished_req(req)
else: else:
# being chunked reqs' prefill is not finished # being chunked reqs' prefill is not finished
req.is_being_chunked -= 1 req.is_chunked -= 1
self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req) self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
...@@ -1322,11 +1469,11 @@ class Scheduler: ...@@ -1322,11 +1469,11 @@ class Scheduler:
req.output_ids.append(next_token_id) req.output_ids.append(next_token_id)
req.check_finished() req.check_finished()
if req.finished(): if req.finished():
self.tree_cache.cache_finished_req(req) self.tree_cache.cache_finished_req(req)
if req.return_logprob: if req.return_logprob and batch.spec_algorithm.is_none():
# speculative worker handles logprob in speculative decoding
req.output_token_logprobs_val.append(next_token_logprobs[i]) req.output_token_logprobs_val.append(next_token_logprobs[i])
req.output_token_logprobs_idx.append(next_token_id) req.output_token_logprobs_idx.append(next_token_id)
if req.top_logprobs_num > 0: if req.top_logprobs_num > 0:
...@@ -1336,11 +1483,18 @@ class Scheduler: ...@@ -1336,11 +1483,18 @@ class Scheduler:
req.output_top_logprobs_idx.append( req.output_top_logprobs_idx.append(
logits_output.next_token_top_logprobs_idx[i] logits_output.next_token_top_logprobs_idx[i]
) )
if req.token_ids_logprob is not None:
req.output_token_ids_logprobs_val.append(
logits_output.next_token_token_ids_logprobs_val[i]
)
req.output_token_ids_logprobs_idx.append(
logits_output.next_token_token_ids_logprobs_idx[i]
)
if req.return_hidden_states and logits_output.hidden_states is not None: if req.return_hidden_states and logits_output.hidden_states is not None:
req.hidden_states.append(logits_output.hidden_states[i].cpu().clone()) req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())
if req.grammar is not None: if req.grammar is not None and batch.spec_algorithm.is_none():
req.grammar.accept_token(next_token_id) req.grammar.accept_token(next_token_id)
req.grammar.finished = req.finished() req.grammar.finished = req.finished()
...@@ -1360,48 +1514,156 @@ class Scheduler: ...@@ -1360,48 +1514,156 @@ class Scheduler:
): ):
self.log_decode_stats() self.log_decode_stats()
def add_logprob_return_values( def add_input_logprob_return_values(
self, self,
i: int, i: int,
req: Req, req: Req,
pt: int,
next_token_ids: List[int],
output: LogitsProcessorOutput, output: LogitsProcessorOutput,
logprob_pt: int,
num_input_logprobs: int,
last_prefill_chunk: bool, # If True, it means prefill is finished.
): ):
"""Attach logprobs to the return values.""" """Incrementally add input logprobs to `req`.
req.output_token_logprobs_val.append(output.next_token_logprobs[i])
req.output_token_logprobs_idx.append(next_token_ids[i]) Args:
i: The request index in a batch.
req: The request. Input logprobs inside req are modified as a
    consequence of this call.
output: Logits processor output used to compute input logprobs.
logprob_pt: Offset into output.input_token_logprobs where this
    request's chunk of input logprobs starts.
num_input_logprobs: Number of input logprobs to consume for this chunk.
last_prefill_chunk: True if it is the last prefill chunk (when chunked).
    Some input logprob operations should only happen at the last
    prefill chunk (e.g., assembling the input token logprobs).
"""
assert output.input_token_logprobs is not None
# It is for jump decoding that will be deprecated.
assert req.last_update_decode_tokens == 0
if req.input_token_logprobs is None:
req.input_token_logprobs = []
if req.temp_input_top_logprobs_val is None:
req.temp_input_top_logprobs_val = []
if req.temp_input_top_logprobs_idx is None:
req.temp_input_top_logprobs_idx = []
if req.temp_input_token_ids_logprobs_val is None:
req.temp_input_token_ids_logprobs_val = []
if req.temp_input_token_ids_logprobs_idx is None:
req.temp_input_token_ids_logprobs_idx = []
if req.input_token_logprobs_val is not None:
# The input logprobs have already been computed. This only happens
# upon retract.
if req.top_logprobs_num > 0:
assert req.input_token_logprobs_val is not None
return
# If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored. # Important for the performance.
num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len assert isinstance(output.input_token_logprobs, tuple)
input_token_logprobs: Tuple[int] = output.input_token_logprobs
input_token_logprobs = input_token_logprobs[
logprob_pt : logprob_pt + num_input_logprobs
]
req.input_token_logprobs.extend(input_token_logprobs)
if req.input_token_logprobs_val is None: if req.top_logprobs_num > 0:
input_token_logprobs_val = output.input_token_logprobs[ req.temp_input_top_logprobs_val.append(output.input_top_logprobs_val[i])
pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens req.temp_input_top_logprobs_idx.append(output.input_top_logprobs_idx[i])
]
input_token_logprobs_idx = req.fill_ids[ if req.token_ids_logprob is not None:
len(req.fill_ids) req.temp_input_token_ids_logprobs_val.append(
- num_input_logprobs output.input_token_ids_logprobs_val[i]
+ 1 : len(req.fill_ids) )
- req.last_update_decode_tokens req.temp_input_token_ids_logprobs_idx.append(
] output.input_token_ids_logprobs_idx[i]
)
if last_prefill_chunk:
input_token_logprobs = req.input_token_logprobs
req.input_token_logprobs = None
assert req.input_token_logprobs_val is None
assert req.input_token_logprobs_idx is None
assert req.input_top_logprobs_val is None
assert req.input_top_logprobs_idx is None
# Compute input_token_logprobs_val
# Always pad the first one with None.
req.input_token_logprobs_val = [None]
req.input_token_logprobs_val.extend(input_token_logprobs)
# The last input logprob is for sampling, so just pop it out.
req.input_token_logprobs_val.pop()
# Compute input_token_logprobs_idx
input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :]
# Clip the padded hash values from image tokens. # Clip the padded hash values from image tokens.
# Otherwise, it will lead to detokenization errors. # Otherwise, it will lead to detokenization errors.
input_token_logprobs_idx = [ input_token_logprobs_idx = [
x if x < self.model_config.vocab_size - 1 else 0 x if x < self.model_config.vocab_size - 1 else 0
for x in input_token_logprobs_idx for x in input_token_logprobs_idx
] ]
req.input_token_logprobs_idx = input_token_logprobs_idx
if ( if req.top_logprobs_num > 0:
req.logprob_start_len == 0 req.input_top_logprobs_val = [None]
): # The first token does not have logprob, pad it. req.input_top_logprobs_idx = [None]
input_token_logprobs_val = [None] + input_token_logprobs_val
input_token_logprobs_idx = [req.fill_ids[0]] + input_token_logprobs_idx
req.input_token_logprobs_val = input_token_logprobs_val for val, idx in zip(
req.input_token_logprobs_idx = input_token_logprobs_idx req.temp_input_top_logprobs_val,
req.temp_input_top_logprobs_idx,
strict=True,
):
req.input_top_logprobs_val.extend(val)
req.input_top_logprobs_idx.extend(idx)
# Last token is a sample token.
req.input_top_logprobs_val.pop()
req.input_top_logprobs_idx.pop()
req.temp_input_top_logprobs_idx = None
req.temp_input_top_logprobs_val = None
if req.token_ids_logprob is not None:
req.input_token_ids_logprobs_val = [None]
req.input_token_ids_logprobs_idx = [None]
for val, idx in zip(
req.temp_input_token_ids_logprobs_val,
req.temp_input_token_ids_logprobs_idx,
strict=True,
):
req.input_token_ids_logprobs_val.extend(val)
req.input_token_ids_logprobs_idx.extend(idx)
# Last token is a sample token.
req.input_token_ids_logprobs_val.pop()
req.input_token_ids_logprobs_idx.pop()
req.temp_input_token_ids_logprobs_idx = None
req.temp_input_token_ids_logprobs_val = None
if req.return_logprob:
relevant_tokens_len = len(req.origin_input_ids) - req.logprob_start_len
assert len(req.input_token_logprobs_val) == relevant_tokens_len
assert len(req.input_token_logprobs_idx) == relevant_tokens_len
if req.top_logprobs_num > 0:
assert len(req.input_top_logprobs_val) == relevant_tokens_len
assert len(req.input_top_logprobs_idx) == relevant_tokens_len
if req.token_ids_logprob is not None:
assert len(req.input_token_ids_logprobs_val) == relevant_tokens_len
assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len
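The None-padding and the final pop() above encode a one-position shift: the first kept prompt token has no conditional logprob, and per the comment in the code the last returned entry is used for sampling the first output token, so it is dropped from the input list. A tiny worked example under assumed numbers (the chunk values are made up; only the lengths matter):

prompt = [11, 12, 13, 14, 15]
logprob_start_len = 1
# One logprob per extended position, concatenated across two prefill chunks:
chunk_logprobs = [[-0.5, -0.7], [-0.2, -0.9]]
flat = [x for chunk in chunk_logprobs for x in chunk]
input_token_logprobs_val = [None] + flat       # pad the first kept token
input_token_logprobs_val.pop()                 # last entry is for sampling
input_token_logprobs_idx = prompt[logprob_start_len:]
relevant_tokens_len = len(prompt) - logprob_start_len
assert len(input_token_logprobs_val) == relevant_tokens_len
assert len(input_token_logprobs_idx) == relevant_tokens_len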
def add_logprob_return_values(
self,
i: int,
req: Req,
pt: int,
next_token_ids: List[int],
num_input_logprobs: int,
output: LogitsProcessorOutput,
):
"""Attach logprobs to the return values."""
req.output_token_logprobs_val.append(output.next_token_logprobs[i])
req.output_token_logprobs_idx.append(next_token_ids[i])
self.add_input_logprob_return_values(
i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
)
if req.last_update_decode_tokens != 0: if req.last_update_decode_tokens != 0:
# Some decode tokens are re-computed in an extend batch # Some decode tokens are re-computed in an extend batch
req.output_token_logprobs_val.extend( req.output_token_logprobs_val.extend(
...@@ -1422,13 +1684,6 @@ class Scheduler: ...@@ -1422,13 +1684,6 @@ class Scheduler:
) )
if req.top_logprobs_num > 0: if req.top_logprobs_num > 0:
if req.input_top_logprobs_val is None:
req.input_top_logprobs_val = output.input_top_logprobs_val[i]
req.input_top_logprobs_idx = output.input_top_logprobs_idx[i]
if req.logprob_start_len == 0:
req.input_top_logprobs_val = [None] + req.input_top_logprobs_val
req.input_top_logprobs_idx = [None] + req.input_top_logprobs_idx
if req.last_update_decode_tokens != 0: if req.last_update_decode_tokens != 0:
req.output_top_logprobs_val.extend( req.output_top_logprobs_val.extend(
output.input_top_logprobs_val[i][-req.last_update_decode_tokens :] output.input_top_logprobs_val[i][-req.last_update_decode_tokens :]
...@@ -1440,6 +1695,26 @@ class Scheduler: ...@@ -1440,6 +1695,26 @@ class Scheduler:
req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i]) req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i]) req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])
if req.token_ids_logprob is not None:
if req.last_update_decode_tokens != 0:
req.output_token_ids_logprobs_val.extend(
output.input_token_ids_logprobs_val[i][
-req.last_update_decode_tokens :
]
)
req.output_token_ids_logprobs_idx.extend(
output.input_token_ids_logprobs_idx[i][
-req.last_update_decode_tokens :
]
)
req.output_token_ids_logprobs_val.append(
output.next_token_token_ids_logprobs_val[i]
)
req.output_token_ids_logprobs_idx.append(
output.next_token_token_ids_logprobs_idx[i]
)
return num_input_logprobs return num_input_logprobs
def stream_output( def stream_output(
...@@ -1474,24 +1749,41 @@ class Scheduler: ...@@ -1474,24 +1749,41 @@ class Scheduler:
input_top_logprobs_idx = [] input_top_logprobs_idx = []
output_top_logprobs_val = [] output_top_logprobs_val = []
output_top_logprobs_idx = [] output_top_logprobs_idx = []
input_token_ids_logprobs_val = []
input_token_ids_logprobs_idx = []
output_token_ids_logprobs_val = []
output_token_ids_logprobs_idx = []
else: else:
input_token_logprobs_val = input_token_logprobs_idx = ( input_token_logprobs_val = input_token_logprobs_idx = (
output_token_logprobs_val output_token_logprobs_val
) = output_token_logprobs_idx = input_top_logprobs_val = ( ) = output_token_logprobs_idx = input_top_logprobs_val = (
input_top_logprobs_idx input_top_logprobs_idx
) = output_top_logprobs_val = output_top_logprobs_idx = None ) = output_top_logprobs_val = output_top_logprobs_idx = (
input_token_ids_logprobs_val
) = input_token_ids_logprobs_idx = output_token_ids_logprobs_val = (
output_token_ids_logprobs_idx
) = None
for req in reqs: for req in reqs:
if req is skip_req: if req is skip_req:
continue continue
# TODO(lianmin): revisit this for overlap + retract + stream # Multimodal partial stream chunks break the detokenizer, so drop aborted requests here.
if self.model_config.is_multimodal_gen and req.to_abort:
continue
if ( if (
req.finished() req.finished()
# If stream, follow the given stream_interval # If stream, follow the given stream_interval
or (req.stream and len(req.output_ids) % self.stream_interval == 0) or (req.stream and len(req.output_ids) % self.stream_interval == 0)
# If not stream, we still want to output some tokens to get the benefit of incremental decoding. # If not stream, we still want to output some tokens to get the benefit of incremental decoding.
or (not req.stream and len(req.output_ids) % 50 == 0) # TODO(lianmin): this is wrong for speculative decoding because len(req.output_ids) does not
# always increase one-by-one.
or (
not req.stream
and len(req.output_ids) % 50 == 0
and not self.model_config.is_multimodal_gen
)
): ):
if self.draft_worker and req.finished(): if self.draft_worker and req.finished():
self.draft_worker.finish_request(req) self.draft_worker.finish_request(req)
...@@ -1529,6 +1821,18 @@ class Scheduler: ...@@ -1529,6 +1821,18 @@ class Scheduler:
input_top_logprobs_idx.append(req.input_top_logprobs_idx) input_top_logprobs_idx.append(req.input_top_logprobs_idx)
output_top_logprobs_val.append(req.output_top_logprobs_val) output_top_logprobs_val.append(req.output_top_logprobs_val)
output_top_logprobs_idx.append(req.output_top_logprobs_idx) output_top_logprobs_idx.append(req.output_top_logprobs_idx)
input_token_ids_logprobs_val.append(
req.input_token_ids_logprobs_val
)
input_token_ids_logprobs_idx.append(
req.input_token_ids_logprobs_idx
)
output_token_ids_logprobs_val.append(
req.output_token_ids_logprobs_val
)
output_token_ids_logprobs_idx.append(
req.output_token_ids_logprobs_idx
)
if req.return_hidden_states: if req.return_hidden_states:
if output_hidden_states is None: if output_hidden_states is None:
...@@ -1537,6 +1841,9 @@ class Scheduler: ...@@ -1537,6 +1841,9 @@ class Scheduler:
# Send to detokenizer # Send to detokenizer
if rids: if rids:
if self.model_config.is_multimodal_gen:
raise NotImplementedError()
self.send_to_detokenizer.send_pyobj( self.send_to_detokenizer.send_pyobj(
BatchTokenIDOut( BatchTokenIDOut(
rids, rids,
...@@ -1561,6 +1868,10 @@ class Scheduler: ...@@ -1561,6 +1868,10 @@ class Scheduler:
input_top_logprobs_idx, input_top_logprobs_idx,
output_top_logprobs_val, output_top_logprobs_val,
output_top_logprobs_idx, output_top_logprobs_idx,
input_token_ids_logprobs_val,
input_token_ids_logprobs_idx,
output_token_ids_logprobs_val,
output_token_ids_logprobs_idx,
output_hidden_states, output_hidden_states,
) )
) )
...@@ -1668,7 +1979,7 @@ class Scheduler: ...@@ -1668,7 +1979,7 @@ class Scheduler:
].grammar.result() ].grammar.result()
num_ready_reqs = num_ready_reqs_max num_ready_reqs = num_ready_reqs_max
self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs]) self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
self.grammar_queue = self.grammar_queue[num_ready_reqs:] self.grammar_queue = self.grammar_queue[num_ready_reqs:]
def flush_cache_wrapped(self, recv_req: FlushCacheReq): def flush_cache_wrapped(self, recv_req: FlushCacheReq):
...@@ -1679,6 +1990,8 @@ class Scheduler: ...@@ -1679,6 +1990,8 @@ class Scheduler:
if len(self.waiting_queue) == 0 and ( if len(self.waiting_queue) == 0 and (
self.running_batch is None or len(self.running_batch.reqs) == 0 self.running_batch is None or len(self.running_batch.reqs) == 0
): ):
self.cur_batch = None
self.last_batch = None
self.tree_cache.reset() self.tree_cache.reset()
self.tree_cache_metrics = {"total": 0, "hit": 0} self.tree_cache_metrics = {"total": 0, "hit": 0}
if self.grammar_backend: if self.grammar_backend:
...@@ -1694,6 +2007,8 @@ class Scheduler: ...@@ -1694,6 +2007,8 @@ class Scheduler:
self.forward_ct_decode = 0 self.forward_ct_decode = 0
self.spec_num_total_accepted_tokens = 0 self.spec_num_total_accepted_tokens = 0
self.spec_num_total_forward_ct = 0 self.spec_num_total_forward_ct = 0
self.cum_spec_accept_length = 0
self.cum_spec_accept_count = 0
torch.cuda.empty_cache() torch.cuda.empty_cache()
logger.info("Cache flushed successfully!") logger.info("Cache flushed successfully!")
if_success = True if_success = True
...@@ -1706,6 +2021,49 @@ class Scheduler: ...@@ -1706,6 +2021,49 @@ class Scheduler:
if_success = False if_success = False
return if_success return if_success
def get_internal_state(self, recv_req: GetInternalStateReq):
ret = dict(global_server_args_dict)
ret["last_gen_throughput"] = self.last_gen_throughput
if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
ret["avg_spec_accept_length"] = (
self.cum_spec_accept_length / self.cum_spec_accept_count
)
if RECORD_STEP_TIME:
ret["step_time_dict"] = self.step_time_dict
return GetInternalStateReqOutput(
internal_state=ret,
)
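The avg_spec_accept_length reported above is just the ratio of two cumulative counters that grow with every speculative decode step and are reset on cache flush. A minimal mirror of that bookkeeping (class and method names are illustrative):

class SpecAcceptStats:
    def __init__(self):
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0

    def update(self, accepted_tokens: int, forward_ct: int):
        self.cum_spec_accept_length += accepted_tokens
        self.cum_spec_accept_count += forward_ct

    def avg(self):
        # Average number of accepted tokens per speculative forward pass.
        if self.cum_spec_accept_count == 0:
            return None
        return self.cum_spec_accept_length / self.cum_spec_accept_count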
def set_internal_state(self, recv_req: SetInternalStateReq):
server_args_dict = recv_req.server_args
args_allow_update = set(
[
"speculative_accept_threshold_single",
"speculative_accept_threshold_acc",
]
)
if_success = True
for k, v in server_args_dict.items():
if k not in args_allow_update:
logging.warning(f"Updating {k} is not supported.")
if_success = False
break
if if_success:
if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
avg_spec_accept_length = (
self.cum_spec_accept_length / self.cum_spec_accept_count
)
logger.info(f"{avg_spec_accept_length=}")
self.cum_spec_accept_length = self.cum_spec_accept_count = 0
for k, v in server_args_dict.items():
global_server_args_dict[k] = v
logger.info(f"Global server args updated! " f"{global_server_args_dict=}")
return SetInternalStateReqOutput(
updated=True,
server_args=global_server_args_dict,
)
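set_internal_state only lets a whitelisted subset of server args change at runtime and rejects the whole request if any key falls outside that set. The guard reduces to an allow-list check, sketched here with illustrative names:

ALLOWED_KEYS = {
    "speculative_accept_threshold_single",
    "speculative_accept_threshold_acc",
}

def apply_internal_state(updates: dict, target: dict) -> bool:
    # All-or-nothing: refuse the request if any key is not whitelisted.
    if any(k not in ALLOWED_KEYS for k in updates):
        return False
    target.update(updates)
    return True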
def abort_request(self, recv_req: AbortReq): def abort_request(self, recv_req: AbortReq):
# Delete requests in the waiting queue # Delete requests in the waiting queue
to_del = None to_del = None
...@@ -1735,7 +2093,7 @@ class Scheduler: ...@@ -1735,7 +2093,7 @@ class Scheduler:
assert flash_cache_success, "Cache flush failed after updating weights" assert flash_cache_success, "Cache flush failed after updating weights"
else: else:
logger.error(message) logger.error(message)
return UpdateWeightFromDiskReqOutput(success, message) return UpdateWeightFromDiskReqOutput(success, message, 0)
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput): def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
"""Initialize the online model parameter update group.""" """Initialize the online model parameter update group."""
...@@ -1771,7 +2129,7 @@ class Scheduler: ...@@ -1771,7 +2129,7 @@ class Scheduler:
parameter = self.tp_worker.get_weights_by_name(recv_req) parameter = self.tp_worker.get_weights_by_name(recv_req)
return GetWeightsByNameReqOutput(parameter) return GetWeightsByNameReqOutput(parameter)
def release_memory_occupation(self): def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
self.stashed_model_static_state = _export_static_state( self.stashed_model_static_state = _export_static_state(
self.tp_worker.worker.model_runner.model self.tp_worker.worker.model_runner.model
) )
...@@ -1779,7 +2137,7 @@ class Scheduler: ...@@ -1779,7 +2137,7 @@ class Scheduler:
self.flush_cache() self.flush_cache()
return ReleaseMemoryOccupationReqOutput() return ReleaseMemoryOccupationReqOutput()
def resume_memory_occupation(self): def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
self.memory_saver_adapter.resume() self.memory_saver_adapter.resume()
_import_static_state( _import_static_state(
self.tp_worker.worker.model_runner.model, self.stashed_model_static_state self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
...@@ -1788,24 +2146,96 @@ class Scheduler: ...@@ -1788,24 +2146,96 @@ class Scheduler:
return ResumeMemoryOccupationReqOutput() return ResumeMemoryOccupationReqOutput()
def profile(self, recv_req: ProfileReq): def profile(self, recv_req: ProfileReq):
if recv_req == ProfileReq.START_PROFILE: if recv_req.type == ProfileReqType.START_PROFILE:
self.start_profile() return self.start_profile(
recv_req.output_dir, recv_req.num_steps, recv_req.activities
)
else: else:
self.stop_profile() return self.stop_profile()
def start_profile(
self,
output_dir: Optional[str],
num_steps: Optional[int],
activities: Optional[List[str]],
) -> Optional[ProfileReqOutput]:
if self.torch_profiler_activities:
return ProfileReqOutput(
success=False,
message="Profiling is already in progress. Call /stop_profile first.",
)
if output_dir is None:
output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
if activities is None:
activities = ["CPU", "GPU"]
self.torch_profiler_output_dir = output_dir
self.torch_profiler_activities = activities
logger.info(
"Profiling starts. Traces will be saved to: %s",
self.torch_profiler_output_dir,
)
activity_map = {
"CPU": torch.profiler.ProfilerActivity.CPU,
"GPU": torch.profiler.ProfilerActivity.CUDA,
}
torchprof_activities = [
activity_map[a] for a in activities if a in activity_map
]
if torchprof_activities:
self.torch_profiler = torch.profiler.profile(
activities=torchprof_activities,
with_stack=True,
)
self.torch_profiler.start()
if "MEM" in activities:
torch.cuda.memory._record_memory_history(max_entries=100000)
def start_profile(self) -> None: if num_steps:
if self.profiler is None: self.profiler_target_forward_ct = self.forward_ct + num_steps
raise RuntimeError("Profiler is not enabled.") # The caller will be notified when reaching profiler_target_forward_ct
self.profiler.start() else:
self.profiler_target_forward_ct = None
return ProfileReqOutput(success=True, message="Succeeded")
def stop_profile(self) -> None: def stop_profile(self) -> None:
if self.profiler is None: if self.torch_profiler_activities is None:
raise RuntimeError("Profiler is not enabled.") return
self.profiler.stop()
self.profiler.export_chrome_trace( logger.info("Stop profiling...")
self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz" if self.torch_profiler is not None:
self.torch_profiler.stop()
self.torch_profiler.export_chrome_trace(
os.path.join(
self.torch_profiler_output_dir,
str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz",
)
)
if "MEM" in self.torch_profiler_activities:
memory_profile_path = os.path.join(
self.torch_profiler_output_dir,
str(time.time()) + f"-TP-{self.tp_rank}-memory" + ".pickle",
)
torch.cuda.memory._dump_snapshot(memory_profile_path)
torch.cuda.memory._record_memory_history(enabled=None)
logger.info(
"Profiling done. Traces are saved to: %s",
self.torch_profiler_output_dir,
) )
logger.info("Profiler is done") self.torch_profiler = None
self.torch_profiler_output_dir = None
self.torch_profiler_activities = None
if self.profiler_target_forward_ct:
self.send_to_tokenizer.send_pyobj(
ProfileReqOutput(success=True, message="Succeeded.")
)
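The START_PROFILE/STOP_PROFILE path above maps activity strings to torch.profiler activities, runs for an optional number of forward steps, and exports a Chrome trace on stop. A compact standalone version of that flow, with an illustrative helper name and output path (only the "CPU"/"GPU" activities are handled here; the "MEM" snapshot path is omitted):

import os
import time
import torch

def run_with_profile(fn, activities=("CPU", "GPU"), output_dir="/tmp"):
    activity_map = {
        "CPU": torch.profiler.ProfilerActivity.CPU,
        "GPU": torch.profiler.ProfilerActivity.CUDA,
    }
    prof = torch.profiler.profile(
        activities=[activity_map[a] for a in activities if a in activity_map],
        with_stack=True,
    )
    prof.start()
    try:
        fn()  # the workload to profile
    finally:
        prof.stop()
        prof.export_chrome_trace(
            os.path.join(output_dir, f"{time.time()}.trace.json.gz")
        )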
def open_session(self, recv_req: OpenSessionReqInput): def open_session(self, recv_req: OpenSessionReqInput):
# handle error # handle error
...@@ -1814,7 +2244,7 @@ class Scheduler: ...@@ -1814,7 +2244,7 @@ class Scheduler:
logger.warning(f"session id {session_id} already exist, cannot open.") logger.warning(f"session id {session_id} already exist, cannot open.")
return OpenSessionReqOutput(session_id, False) return OpenSessionReqOutput(session_id, False)
elif session_id is None: elif session_id is None:
logger.warning(f"session id is None, cannot open.") logger.warning("session id is None, cannot open.")
return OpenSessionReqOutput(session_id, False) return OpenSessionReqOutput(session_id, False)
else: else:
self.sessions[session_id] = Session( self.sessions[session_id] = Session(
...@@ -1831,6 +2261,10 @@ class Scheduler: ...@@ -1831,6 +2261,10 @@ class Scheduler:
del self.sessions[session_id] del self.sessions[session_id]
def is_health_check_generate_req(recv_req):
return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
def _export_static_state(model): def _export_static_state(model):
return dict( return dict(
buffers=[ buffers=[
...@@ -1853,8 +2287,11 @@ def run_scheduler_process( ...@@ -1853,8 +2287,11 @@ def run_scheduler_process(
dp_rank: Optional[int], dp_rank: Optional[int],
pipe_writer, pipe_writer,
): ):
setproctitle.setproctitle("sglang::scheduler") # Config the process
# kill_itself_when_parent_died() # This is disabled because it does not work for `--dp 2`
setproctitle.setproctitle(f"sglang::scheduler_{dp_rank}")
faulthandler.enable() faulthandler.enable()
parent_process = psutil.Process().parent()
# [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
if dp_rank is None and "SGLANG_DP_RANK" in os.environ: if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
...@@ -1862,9 +2299,10 @@ def run_scheduler_process( ...@@ -1862,9 +2299,10 @@ def run_scheduler_process(
# Configure the logger # Configure the logger
if dp_rank is None: if dp_rank is None:
configure_logger(server_args, prefix=f" TP{tp_rank}") prefix = f" TP{tp_rank}"
else: else:
configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}") prefix = f" DP{dp_rank} TP{tp_rank}"
configure_logger(server_args, prefix=prefix)
suppress_other_loggers() suppress_other_loggers()
# Set cpu affinity to this gpu process # Set cpu affinity to this gpu process
......