Commit ac238727 authored by Lianmin Zheng

Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
parent 0194948f
from __future__ import annotations
import functools
from typing import TYPE_CHECKING, Union
import torch
import triton
import triton.language as tl
from sglang.srt.distributed import (
    GroupCoordinator,
    get_tensor_model_parallel_world_size,
    get_tp_group,
    tensor_model_parallel_all_reduce,
)
if TYPE_CHECKING:
    from sglang.srt.model_executor.forward_batch_info import ForwardBatch
_ATTN_TP_GROUP = None
_ATTN_TP_RANK = None
@@ -69,3 +84,129 @@ def get_attention_dp_rank():
def get_attention_dp_size():
    assert _DP_SIZE is not None, "dp attention not initialized!"
    return _DP_SIZE
def get_dp_local_info(forward_batch: ForwardBatch):
    dp_rank = get_attention_dp_rank()

    if forward_batch.dp_local_start_pos is None:
        cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)
        if dp_rank == 0:
            local_start_pos = torch.zeros_like(cumtokens[0])
        else:
            local_start_pos = cumtokens[dp_rank - 1]
        local_num_tokens = forward_batch.global_num_tokens_gpu[dp_rank]

        forward_batch.dp_local_start_pos = local_start_pos
        forward_batch.dp_local_num_tokens = local_num_tokens

    return forward_batch.dp_local_start_pos, forward_batch.dp_local_num_tokens
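
Illustrative sketch (not part of the commit): get_dp_local_info derives each attention-DP rank's slice of the gathered token buffer from a cumulative sum of the per-rank token counts. The same arithmetic on plain tensors, with made-up counts:

import torch

global_num_tokens = torch.tensor([4, 2, 3])  # tokens owned by DP ranks 0, 1, 2
cumtokens = torch.cumsum(global_num_tokens, dim=0)  # tensor([4, 6, 9])
for dp_rank in range(global_num_tokens.numel()):
    start = 0 if dp_rank == 0 else int(cumtokens[dp_rank - 1])
    length = int(global_num_tokens[dp_rank])
    print(f"rank {dp_rank}: rows [{start}:{start + length})")
# rank 0: rows [0:4)
# rank 1: rows [4:6)
# rank 2: rows [6:9)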
@triton.jit
def memcpy_triton_kernel(
    dst_ptr,
    src_ptr,
    offset_ptr,
    sz_ptr,
    offset_src,
    chunk_size,  # multiplied for offset and sz
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0).to(tl.int64)
    offset = tl.load(offset_ptr).to(tl.int64) * chunk_size
    sz = tl.load(sz_ptr).to(tl.int64) * chunk_size

    start_index = pid * BLOCK_SIZE
    offs = tl.arange(0, BLOCK_SIZE)
    mask = start_index + offs < sz

    if offset_src:
        data = tl.load(src_ptr + offset + start_index + offs, mask=mask)
        tl.store(dst_ptr + start_index + offs, data, mask=mask)
    else:
        data = tl.load(src_ptr + start_index + offs, mask=mask)
        tl.store(dst_ptr + offset + start_index + offs, data, mask=mask)
def prod(x):
    return functools.reduce(lambda a, b: a * b, x, 1)
def memcpy_triton(dst, src, dim, offset, sz, offset_src):
    max_size = min(src.numel(), dst.numel())
    assert dim == 0, "dim != 0 unsupported"
    assert src.shape[1:] == dst.shape[1:], "src and dst must have same shape"
    chunk_size = prod(src.shape[1:])
    BLOCK_SIZE = 8192
    grid = (triton.cdiv(max_size, BLOCK_SIZE),)

    memcpy_triton_kernel[grid](dst, src, offset, sz, offset_src, chunk_size, BLOCK_SIZE)
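
Rough eager-mode equivalent (illustrative only, not part of the commit): memcpy_triton copies sz rows along dim 0, reading offset and sz from device tensors so the copy size never has to be synced back to the CPU. The sketch below does the same copy with plain indexing, at the cost of that sync:

import torch

def memcpy_eager(dst, src, offset, sz, offset_src):
    # Plain-PyTorch stand-in for memcpy_triton (dim=0 only). Converting the
    # device tensors to Python ints forces a GPU->CPU sync, which is exactly
    # what the Triton kernel avoids.
    offset, sz = int(offset), int(sz)
    if offset_src:
        dst[:sz].copy_(src[offset:offset + sz])
    else:
        dst[offset:offset + sz].copy_(src[:sz])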
def dp_gather(
    global_tokens: torch.Tensor,
    local_tokens: torch.Tensor,
    forward_batch: ForwardBatch,
    layer_id: Union[str, int],
):
    local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)

    global_tokens.fill_(0)
    assert local_tokens.is_contiguous()
    assert global_tokens.is_contiguous()
    if local_tokens.shape[0] > 0 and (
        layer_id != "embedding" or get_attention_tp_rank() == 0
    ):
        assert (
            global_tokens.storage().data_ptr() != local_tokens.storage().data_ptr()
        ), "aliasing between global_tokens and local_tokens not allowed"
        memcpy_triton(
            global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
        )

    # Input IDs are in int32. We should use inplace_all_reduce for the local case because of custom all-reduce.
    NUM_GPUS_PER_NODE = 8
    if (
        not local_tokens.dtype.is_floating_point
        and get_tensor_model_parallel_world_size() <= NUM_GPUS_PER_NODE
    ):
        torch.ops.sglang.inplace_all_reduce(
            global_tokens, group_name=get_tp_group().unique_name
        )
    else:
        global_tokens = tensor_model_parallel_all_reduce(global_tokens)
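
Illustrative sketch (not part of the commit): dp_gather zero-fills the global buffer, writes this rank's rows at its offset, and relies on the all-reduce to sum the per-rank buffers into the full concatenation. Simulated on one process with two fake ranks:

import torch

hidden = 8
local = {0: torch.randn(4, hidden), 1: torch.randn(2, hidden)}  # per-rank tokens
starts, total = {0: 0, 1: 4}, 6

buffers = []
for rank in (0, 1):
    buf = torch.zeros(total, hidden)                  # global_tokens.fill_(0)
    n = local[rank].shape[0]
    buf[starts[rank]:starts[rank] + n] = local[rank]  # memcpy_triton(..., offset_src=False)
    buffers.append(buf)

global_tokens = buffers[0] + buffers[1]               # role of the all-reduce
assert torch.equal(global_tokens[:4], local[0])
assert torch.equal(global_tokens[4:], local[1])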
def dp_scatter(
    local_tokens: torch.Tensor,  # output
    global_tokens: torch.Tensor,  # input
    forward_batch: ForwardBatch,
):
    # local_num_tokens is not necessarily the same as local_tokens.shape[0],
    # since local_tokens may be padded for cuda graph
    local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)

    local_tokens.fill_(0)
    assert local_tokens.is_contiguous()
    assert global_tokens.is_contiguous()
    if local_tokens.shape[0] > 0:
        assert (
            local_tokens.untyped_storage().data_ptr()
            != global_tokens.untyped_storage().data_ptr()
        ), "aliasing between local_tokens and global_tokens not allowed"
        memcpy_triton(
            local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
        )
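
Counterpart sketch for dp_scatter (illustrative only, not part of the commit): each rank cuts its own rows back out of the gathered buffer; rows beyond local_num_tokens stay zero, which is why padding for CUDA graphs is harmless.

import torch

hidden = 8
global_tokens = torch.randn(6, hidden)
starts, lengths = {0: 0, 1: 4}, {0: 4, 1: 2}

rank, padded_rows = 1, 5                  # rank 1 padded from 2 to 5 rows
local = torch.zeros(padded_rows, hidden)  # local_tokens.fill_(0)
local[:lengths[rank]] = global_tokens[starts[rank]:starts[rank] + lengths[rank]]
assert torch.equal(local[:2], global_tokens[4:6])
assert torch.all(local[2:] == 0)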
def get_do_logits_dp_scatter(forward_batch: ForwardBatch):
    def do_logits_dp_scatter(logits: torch.Tensor):
        local_logits = torch.empty(
            (forward_batch.input_ids.shape[0], *logits.shape[1:]),
            dtype=logits.dtype,
            device=logits.device,
        )
        dp_scatter(local_logits, logits, forward_batch)
        return local_logits

    return do_logits_dp_scatter
@@ -69,7 +69,7 @@ class RMSNorm(CustomOp):
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
-       x = x.to(orig_dtype) * self.weight
        x = (x * self.weight).to(orig_dtype)
        if residual is None:
            return x
        else:
...
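Note on the RMSNorm hunk above (sketch, not part of the commit): the weight multiply now happens before the cast back to the original dtype, so the product is formed at the higher precision of the normalized activations. A standalone illustration, assuming fp32 activations and a bf16 weight:

import torch

x_fp32 = torch.randn(4, 8, dtype=torch.float32)   # normalized activations kept in fp32
w_bf16 = torch.randn(8, dtype=torch.bfloat16)     # RMSNorm weight in model dtype

old = x_fp32.to(torch.bfloat16) * w_bf16          # round x first, then multiply
new = (x_fp32 * w_bf16).to(torch.bfloat16)        # multiply in fp32, round once at the end
print((old.float() - new.float()).abs().max())    # small but typically nonzero difference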
@@ -426,13 +426,14 @@ class ColumnParallelLinear(LinearBase):
        from sglang.srt.layers.parameter import _ColumnvLLMParameter

        if isinstance(param, _ColumnvLLMParameter):
-           # FIXME: why would we need this special case?
            param.load_column_parallel_weight(
                loaded_weight,
                tp_rank=self.tp_rank,
                use_presharded_weights=self.use_presharded_weights,
            )
        else:
            # FIXME: This branch is needed to load deepseek v3 awq.
            # However, we should fix this and avoid the branching here.
            param.load_column_parallel_weight(loaded_weight)

    def forward(self, input_):
...
@@ -144,6 +144,73 @@ def silu_and_mul_triton_kernel(
    tl.store(down_input_ptr + offset, silu_mul_output, mask=mask)
@triton.jit
def tanh(x):
    return 2 * tl.sigmoid(2 * x) - 1


@triton.jit
def gelu_and_mul_triton_kernel(
    gateup_output,
    down_input,
    hidden_size,
    reorder_topk_ids,
    scales,
    start_expert_id,
    end_expert_id,
    BLOCK_SIZE: tl.constexpr,
):
    InDtype = gateup_output.dtype.element_ty
    OutDtype = down_input.dtype.element_ty

    half_hidden_size = hidden_size // 2

    pid = tl.program_id(0)
    expert_id = tl.load(reorder_topk_ids + pid)
    if expert_id >= start_expert_id and expert_id <= end_expert_id:
        gateup_output_ptr = gateup_output + pid * hidden_size
        gate_output_ptr = gateup_output_ptr
        up_output_ptr = gateup_output_ptr + half_hidden_size
        down_input_ptr = down_input + pid * half_hidden_size

        if scales is not None:
            scale = tl.load(scales + expert_id - start_expert_id)
            scale = (1 / scale).to(InDtype)
        else:
            scale = 1

        for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE):
            offset = start_offset + tl.arange(0, BLOCK_SIZE)
            mask = offset < half_hidden_size

            gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32)
            up_output = tl.load(up_output_ptr + offset, mask=mask)

            # gelu & mul & quantize
            # https://pytorch.org/docs/stable/generated/torch.nn.GELU.html
            # sqrt(2/pi)
            kAlpha = 0.7978845608028654
            gate_output = (
                0.5
                * gate_output
                * (
                    1
                    + tanh(
                        kAlpha
                        * (
                            gate_output
                            + 0.044715 * gate_output * gate_output * gate_output
                        )
                    )
                )
            )
            gate_output = gate_output.to(InDtype)

            gelu_mul_output = gate_output * up_output * scale
            gelu_mul_output = gelu_mul_output.to(OutDtype)
            tl.store(down_input_ptr + offset, gelu_mul_output, mask=mask)
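
Reference check (illustrative only, not part of the commit): the kernel above uses the tanh approximation of GELU, 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))), with tanh expressed as 2 * sigmoid(2x) - 1. A small PyTorch sketch comparing it against torch's built-in approximation:

import torch

def gelu_tanh_ref(x):
    # Same formula as the Triton kernel, including the sigmoid-based tanh.
    k_alpha = 0.7978845608028654  # sqrt(2 / pi)
    inner = k_alpha * (x + 0.044715 * x * x * x)
    tanh = 2 * torch.sigmoid(2 * inner) - 1
    return 0.5 * x * (1 + tanh)

x = torch.randn(1024, dtype=torch.float32)
ref = torch.nn.functional.gelu(x, approximate="tanh")
assert torch.allclose(gelu_tanh_ref(x), ref, atol=1e-5)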
@triton.jit
def post_reorder_triton_kernel(
    down_output_ptr,
...
@@ -11,6 +11,7 @@ from sglang.srt.distributed import (
    get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.moe.ep_moe.kernels import (
    gelu_and_mul_triton_kernel,
    grouped_gemm_triton,
    post_reorder_triton_kernel,
    pre_reorder_triton_kernel,
@@ -296,6 +297,17 @@ class EPMoE(torch.nn.Module):
                self.end_expert_id,
                BLOCK_SIZE=512,
            )
        elif self.activation == "gelu":
            gelu_and_mul_triton_kernel[(gateup_output.shape[0],)](
                gateup_output,
                down_input,
                gateup_output.shape[1],
                reorder_topk_ids,
                self.w2_input_scale,
                self.start_expert_id,
                self.end_expert_id,
                BLOCK_SIZE=512,
            )
        else:
            raise ValueError(f"Unsupported activation: {self.activation=}")
...
@@ -24,6 +24,8 @@ def fused_moe_forward_native(
    custom_routing_function: Optional[Callable] = None,
    correction_bias: Optional[torch.Tensor] = None,
    activation: str = "silu",
    inplace: bool = True,
    no_combine: bool = False,
) -> torch.Tensor:
    topk_weights, topk_ids = select_experts(
        hidden_states=x,
...
@@ -23,7 +23,7 @@ from sglang.srt.utils import (
    is_hip,
)

-is_hip_flag = is_hip()
is_hip_ = is_hip()

logger = logging.getLogger(__name__)
@@ -487,6 +487,7 @@ def invoke_fused_moe_kernel(
    use_int8_w8a8: bool,
    use_int8_w8a16: bool,
    block_shape: Optional[List[int]] = None,
    no_combine: bool = False,
) -> None:
    assert topk_weights.stride(1) == 1
    assert sorted_token_ids.stride(0) == 1
@@ -646,7 +647,7 @@ def get_default_config(
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 8,
-       "num_stages": 2 if is_hip_flag else 4,
        "num_stages": 2 if is_hip_ else 4,
    }
    if M <= E:
        config = {
@@ -655,7 +656,7 @@ def get_default_config(
            "BLOCK_SIZE_K": 128,
            "GROUP_SIZE_M": 1,
            "num_warps": 4,
-           "num_stages": 2 if is_hip_flag else 4,
            "num_stages": 2 if is_hip_ else 4,
        }
    else:
        # Block-wise quant: BLOCK_SIZE_K must be divisible by block_shape[1]
@@ -665,7 +666,7 @@ def get_default_config(
            "BLOCK_SIZE_K": block_shape[1],
            "GROUP_SIZE_M": 32,
            "num_warps": 4,
-           "num_stages": 2 if is_hip_flag else 3,
            "num_stages": 2 if is_hip_ else 3,
        }
    else:
        config = {
@@ -814,6 +815,7 @@ def outplace_fused_experts(
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[List[int]] = None,
    no_combine: bool = False,
) -> torch.Tensor:
    return fused_experts_impl(
        hidden_states,
@@ -831,6 +833,7 @@ def outplace_fused_experts(
        a1_scale,
        a2_scale,
        block_shape,
        no_combine=no_combine,
    )
@@ -849,6 +852,7 @@ def outplace_fused_experts_fake(
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[List[int]] = None,
    no_combine: bool = False,
) -> torch.Tensor:
    return torch.empty_like(hidden_states)
@@ -877,8 +881,10 @@ def fused_experts(
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[List[int]] = None,
    no_combine: bool = False,
):
    if inplace:
        assert not no_combine, "no combine + inplace makes no sense"
        torch.ops.sglang.inplace_fused_experts(
            hidden_states,
            w1,
@@ -912,6 +918,7 @@ def fused_experts(
            a1_scale,
            a2_scale,
            block_shape,
            no_combine=no_combine,
        )
@@ -931,6 +938,7 @@ def fused_experts_impl(
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[List[int]] = None,
    no_combine: bool = False,
):
    padded_size = padding_size
    if not use_fp8_w8a8 or not use_int8_w8a8 or block_shape is not None:
@@ -987,7 +995,14 @@ def fused_experts_impl(
    compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16

-   if inplace:
    if no_combine:
        assert not inplace
        out_hidden_states = torch.empty(
            (num_tokens, topk_ids.shape[1], w2.shape[1]),
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )
    elif inplace:
        out_hidden_states = hidden_states
    else:
        out_hidden_states = torch.empty_like(hidden_states)
@@ -1057,7 +1072,11 @@ def fused_experts_impl(
        invoke_fused_moe_kernel(
            intermediate_cache2,
            w2,
-           intermediate_cache3,
            (
                intermediate_cache3
                if not no_combine and topk_ids.shape[1] != 1
                else out_hidden_states[begin_chunk_idx:end_chunk_idx]
            ),
            a2_scale,
            w2_scale,
            curr_topk_weights,
@@ -1075,16 +1094,16 @@ def fused_experts_impl(
            block_shape=block_shape,
        )

-       if is_hip_flag:
        if no_combine:
            pass
        elif is_hip_:
            ops.moe_sum(
                intermediate_cache3.view(*intermediate_cache3.shape),
                out_hidden_states[begin_chunk_idx:end_chunk_idx],
            )
        else:
            if topk_ids.shape[1] == 1:
-               out_hidden_states[begin_chunk_idx:end_chunk_idx].copy_(
-                   intermediate_cache3[:, 0]
-               )
                pass  # we write directly into out_hidden_states
            elif topk_ids.shape[1] == 2:
                torch.add(
                    intermediate_cache3[:, 0],
@@ -1122,6 +1141,7 @@ def fused_moe(
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[List[int]] = None,
    no_combine: bool = False,
) -> torch.Tensor:
    """
    This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -1191,4 +1211,5 @@ def fused_moe(
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_shape=block_shape,
        no_combine=no_combine,
    )
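
Shape sketch for the new no_combine flag (illustrative only, sizes made up): with no_combine=True, fused_experts returns one output row per (token, selected expert) pair, i.e. shape (num_tokens, topk, hidden), instead of reducing over the selected experts. The skipped reduction looks roughly like this; whether the routing weights still need to be applied at this point depends on the kernel path, so the sketch applies them explicitly:

import torch

num_tokens, topk, hidden = 6, 2, 16
per_expert_out = torch.randn(num_tokens, topk, hidden)               # no_combine=True output
topk_weights = torch.softmax(torch.randn(num_tokens, topk), dim=-1)

combined = (per_expert_out * topk_weights.unsqueeze(-1)).sum(dim=1)  # the skipped combine
print(combined.shape)  # torch.Size([6, 16])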
@@ -125,6 +125,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
        custom_routing_function: Optional[Callable] = None,
        correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
        inplace: bool = True,
        no_combine: bool = False,
    ) -> torch.Tensor:
        return self.forward(
            x=x,
@@ -138,6 +140,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
            custom_routing_function=custom_routing_function,
            correction_bias=correction_bias,
            activation=activation,
            inplace=inplace,
            no_combine=no_combine,
        )

    def forward_cuda(
@@ -153,6 +157,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
        custom_routing_function: Optional[Callable] = None,
        correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
        inplace: bool = True,
        no_combine: bool = False,
    ) -> torch.Tensor:
        topk_weights, topk_ids = select_experts(
            hidden_states=x,
@@ -171,6 +177,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
            from aiter.fused_moe import fused_experts_ck

            assert activation == "silu", f"{activation=} is not supported."
            assert not no_combine, "unsupported"

            return fused_experts_ck(
                hidden_states=x,
@@ -186,8 +193,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
-               inplace=True,
                inplace=inplace and not no_combine,
                activation=activation,
                no_combine=no_combine,
            )
    def forward_cpu(
@@ -202,6 +210,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        correction_bias: Optional[torch.Tensor] = None,
        inplace: bool = True,
    ) -> torch.Tensor:
        return moe_forward_native(
            layer,
@@ -241,6 +250,7 @@ class FusedMoE(torch.nn.Module):
        reduce_results: Whether to all-reduce the output of the layer
        renormalize: Whether to renormalize the logits in the fused_moe kernel
        quant_config: Quantization configuration.
        inplace: suggestion to compute inplace (modify input activation).
    """
    def __init__(
@@ -262,6 +272,8 @@ class FusedMoE(torch.nn.Module):
        correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
        use_presharded_weights: bool = False,
        inplace: bool = True,
        no_combine: bool = False,
    ):
        super().__init__()
@@ -285,6 +297,9 @@ class FusedMoE(torch.nn.Module):
        self.custom_routing_function = custom_routing_function
        self.correction_bias = correction_bias
        self.activation = activation
        self.use_presharded_weights = use_presharded_weights
        self.inplace = inplace
        self.no_combine = no_combine

        if quant_config is None:
            self.quant_method: Optional[QuantizeMethodBase] = (
@@ -304,7 +319,6 @@ class FusedMoE(torch.nn.Module):
            params_dtype=params_dtype,
            weight_loader=self.weight_loader,
        )
-       self.use_presharded_weights = use_presharded_weights

    def _load_per_tensor_weight_scale(
        self,
@@ -598,6 +612,8 @@ class FusedMoE(torch.nn.Module):
            custom_routing_function=self.custom_routing_function,
            correction_bias=self.correction_bias,
            activation=self.activation,
            inplace=self.inplace,
            no_combine=self.no_combine,
        )

        if self.reduce_results and self.tp_size > 1:
...
@@ -771,6 +771,8 @@ class Fp8MoEMethod:
        custom_routing_function: Optional[Callable] = None,
        correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
        inplace: bool = True,
        no_combine: bool = False,
    ) -> torch.Tensor:
        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
        from sglang.srt.layers.moe.topk import select_experts
@@ -793,6 +795,7 @@ class Fp8MoEMethod:
            from aiter.fused_moe import fused_experts_ck

            assert activation == "silu", f"{activation=} is not supported."
            assert not no_combine, f"{no_combine=} is not supported."

            return fused_experts_ck(
                x,
@@ -823,7 +826,7 @@ class Fp8MoEMethod:
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
-           inplace=True,
            inplace=inplace and not no_combine,
            activation=activation,
            use_fp8_w8a8=True,
            w1_scale=(
@@ -839,6 +842,7 @@ class Fp8MoEMethod:
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            block_shape=self.quant_config.weight_block_size,
            no_combine=no_combine,
        )
...
@@ -707,7 +707,6 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
        cos = freqs.cos() * self.mscale
        sin = freqs.sin() * self.mscale
        cache = torch.cat((cos, sin), dim=-1)
-       print("Cache shape", cache.shape)
        return cache

    def forward(
...
import logging
-from typing import List
from typing import List, Optional

import torch
import torch.distributed as dist
@@ -41,7 +41,21 @@ class Sampler(nn.Module):
        sampling_info: SamplingBatchInfo,
        return_logprob: bool,
        top_logprobs_nums: List[int],
        token_ids_logprobs: List[List[int]],
        batch_next_token_ids: Optional[torch.Tensor] = None,
    ):
"""Run a sampler & compute logprobs and update logits_output accordingly.
Args:
logits_output: The logits from the model forward
sampling_info: Metadata for sampling
return_logprob: If set, store the output logprob information to
logits_output
top_logprobs_nums: Number of top lobprobs per sequence in a batch
batch_next_token_ids: next token IDs. If set, skip sampling and only
compute output logprobs It is used for speculative decoding which
performs sampling in draft workers.
"""
        logits = logits_output.next_token_logits

        # Apply the custom logit processors if registered in the sampling info.
@@ -58,13 +72,15 @@ class Sampler(nn.Module):
        if sampling_info.is_all_greedy:
            # Use torch.argmax if all requests use greedy sampling
            if batch_next_token_ids is None:
                batch_next_token_ids = torch.argmax(logits, -1)
            if return_logprob:
                logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
        else:
            # Post process logits
            logits.div_(sampling_info.temperatures)
-           probs = torch.softmax(logits, dim=-1)
            logits[:] = torch.softmax(logits, dim=-1)
            probs = logits
            del logits
            if global_server_args_dict["sampling_backend"] == "flashinfer":
@@ -78,6 +94,7 @@ class Sampler(nn.Module):
                    top_p_normalize_probs_torch(probs, sampling_info.top_ps)
                ).clamp(min=torch.finfo(probs.dtype).min)

                if batch_next_token_ids is None:
                    max_top_k_round, batch_size = 32, probs.shape[0]
                    uniform_samples = torch.rand(
                        (max_top_k_round, batch_size), device=probs.device
@@ -99,9 +116,12 @@ class Sampler(nn.Module):
                    if self.use_nan_detection and not torch.all(success):
                        logger.warning("Detected errors during sampling!")
-                       batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
                        batch_next_token_ids = torch.zeros_like(
                            batch_next_token_ids
                        )

            elif global_server_args_dict["sampling_backend"] == "pytorch":
                if batch_next_token_ids is None:
                    # A slower fallback implementation with torch native operations.
                    batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
                        probs,
@@ -110,6 +130,7 @@ class Sampler(nn.Module):
                        sampling_info.min_ps,
                        sampling_info.need_min_p_sampling,
                    )
        if return_logprob:
            # clamp to avoid -inf
            logprobs = torch.log(
@@ -128,6 +149,12 @@ class Sampler(nn.Module):
                    logits_output.next_token_top_logprobs_idx,
                ) = get_top_logprobs(logprobs, top_logprobs_nums)

            if any(x is not None for x in token_ids_logprobs):
                (
                    logits_output.next_token_token_ids_logprobs_val,
                    logits_output.next_token_token_ids_logprobs_idx,
                ) = get_token_ids_logprobs(logprobs, token_ids_logprobs)

            logits_output.next_token_logprobs = logprobs[
                torch.arange(len(batch_next_token_ids), device=sampling_info.device),
                batch_next_token_ids,
@@ -223,6 +250,10 @@ def top_p_normalize_probs_torch(
def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
    assert len(top_logprobs_nums) == logprobs.shape[0], (
        len(top_logprobs_nums),
        logprobs.shape[0],
    )
    max_k = max(top_logprobs_nums)
    ret = logprobs.topk(max_k, dim=1)
    values = ret.values.tolist()
@@ -234,3 +265,17 @@ def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
        output_top_logprobs_val.append(values[i][:k])
        output_top_logprobs_idx.append(indices[i][:k])
    return output_top_logprobs_val, output_top_logprobs_idx
def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List[int]]):
    output_token_ids_logprobs_val = []
    output_token_ids_logprobs_idx = []
    for i, token_ids in enumerate(token_ids_logprobs):
        if token_ids is not None:
            output_token_ids_logprobs_val.append(logprobs[i, token_ids].tolist())
            output_token_ids_logprobs_idx.append(token_ids)
        else:
            output_token_ids_logprobs_val.append([])
            output_token_ids_logprobs_idx.append([])

    return output_token_ids_logprobs_val, output_token_ids_logprobs_idx
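
Usage sketch (illustrative only, assuming the get_token_ids_logprobs above is in scope; values made up):

import torch

logprobs = torch.log_softmax(torch.randn(2, 5), dim=-1)  # batch of 2, vocab of 5
requested = [[1, 3], None]                               # request 0 wants ids 1 and 3

vals, idxs = get_token_ids_logprobs(logprobs, requested)
print(vals[0])  # log-probs of token ids 1 and 3 for request 0, as a plain list
print(vals[1])  # [] because request 1 asked for nothing
print(idxs)     # [[1, 3], []]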
@@ -457,7 +457,7 @@ class VocabParallelEmbedding(torch.nn.Module):
        assert loaded_weight.shape[output_dim] == (
            self.org_vocab_size
            // (self.tp_size if self.use_presharded_weights else 1)
-       )
        ), f"{self.org_vocab_size=} {self.use_presharded_weights=} {loaded_weight.shape[output_dim]=}"

        # Copy the data.
        if not self.use_presharded_weights:
...
@@ -28,6 +28,7 @@ if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--url", type=str, default="http://localhost:30000")
    parser.add_argument("--log-requests", action="store_true")
    parser.add_argument("--log-requests-level", type=int, default=2)
    parser.add_argument(
        "--dump-requests-folder", type=str, default="/tmp/sglang_request_dump"
    )
@@ -38,7 +39,7 @@ if __name__ == "__main__":
        args.url + "/configure_logging",
        json={
            "log_requests": args.log_requests,
-           "log_requests_level": 1,  # Log full requests
            "log_requests_level": args.log_requests_level,  # Log full requests
            "dump_requests_folder": args.dump_requests_folder,
            "dump_requests_threshold": args.dump_requests_threshold,
        },
...
@@ -198,6 +198,8 @@ class DataParallelController:
        self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
        self.max_req_input_len = scheduler_info[0]["max_req_input_len"]

        print(f"{scheduler_info=}")

    def round_robin_scheduler(self, req):
        self.workers[self.round_robin_counter].send_pyobj(req)
        self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
@@ -220,6 +222,7 @@ class DataParallelController:
                    TokenizedEmbeddingReqInput,
                ),
            ):
                logger.info("dispatching")
                self.dispatching(recv_req)
            else:
                # Send other control messages to first worker of tp group
...
@@ -14,6 +14,7 @@
"""DetokenizerManager is a process that detokenizes the token ids."""

import dataclasses
import json
import logging
import os
import signal
@@ -27,11 +28,16 @@ import zmq
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.io_struct import (
    BatchEmbeddingOut,
    BatchMultimodalDecodeReq,
    BatchStrOut,
    BatchTokenIDOut,
)
from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import configure_logger, get_zmq_socket
from sglang.srt.utils import (
    configure_logger,
    get_zmq_socket,
    kill_itself_when_parent_died,
)
from sglang.utils import (
    TypeBasedDispatcher,
    find_printable_text,
@@ -86,14 +92,23 @@ class DetokenizerManager:
        )
        self.decode_status = LimitedCapacityDict(capacity=DETOKENIZER_MAX_STATES)
        self.is_dummy = server_args.load_format == "dummy"

        self._request_dispatcher = TypeBasedDispatcher(
            [
                (BatchEmbeddingOut, self.handle_batch_embedding_out),
                (BatchTokenIDOut, self.handle_batch_token_id_out),
                (BatchMultimodalDecodeReq, self.handle_multimodal_decode_req),
            ]
        )
    def event_loop(self):
        """The event loop that handles requests"""
        while True:
            recv_obj = self.recv_from_scheduler.recv_pyobj()
            output = self._request_dispatcher(recv_obj)
            self.send_to_tokenizer.send_pyobj(output)

    def trim_matched_stop(
        self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
    ):
@@ -117,14 +132,6 @@ class DetokenizerManager:
            return output[:-1]
        return output

-   def event_loop(self):
-       """The event loop that handles requests"""
-       while True:
-           recv_obj = self.recv_from_scheduler.recv_pyobj()
-           output = self._request_dispatcher(recv_obj)
-           self.send_to_tokenizer.send_pyobj(output)

    def handle_batch_embedding_out(self, recv_obj: BatchEmbeddingOut):
        # If it is embedding model, no detokenization is needed.
        return recv_obj
@@ -173,7 +180,6 @@ class DetokenizerManager:
        # Incremental decoding
        output_strs = []
-       finished_reqs = []
        for i in range(bs):
            try:
                s = self.decode_status[recv_obj.rids[i]]
@@ -196,8 +202,6 @@ class DetokenizerManager:
                    new_text = ""
                else:
                    new_text = find_printable_text(new_text)
-           else:
-               finished_reqs.append(recv_obj.rids[i])

            output_strs.append(
                self.trim_matched_stop(
@@ -207,7 +211,7 @@ class DetokenizerManager:
                )
            )

-       out = BatchStrOut(
        return BatchStrOut(
            rids=recv_obj.rids,
            finished_reasons=recv_obj.finished_reasons,
            output_strs=output_strs,
@@ -223,14 +227,15 @@ class DetokenizerManager:
            input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
            output_top_logprobs_val=recv_obj.output_top_logprobs_val,
            output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
            input_token_ids_logprobs_val=recv_obj.input_token_ids_logprobs_val,
            input_token_ids_logprobs_idx=recv_obj.input_token_ids_logprobs_idx,
            output_token_ids_logprobs_val=recv_obj.output_token_ids_logprobs_val,
            output_token_ids_logprobs_idx=recv_obj.output_token_ids_logprobs_idx,
            output_hidden_states=recv_obj.output_hidden_states,
        )

-       # remove decodestatus for completed requests
-       for rid in finished_reqs:
-           self.decode_status.pop(rid)
-       return out

    def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):
        raise NotImplementedError()


class LimitedCapacityDict(OrderedDict):
@@ -250,6 +255,7 @@ def run_detokenizer_process(
    server_args: ServerArgs,
    port_args: PortArgs,
):
    kill_itself_when_parent_died()
    setproctitle.setproctitle("sglang::detokenizer")
    configure_logger(server_args)
    parent_process = psutil.Process().parent()
...
@@ -16,10 +16,11 @@ The definition of objects transferred between different
processes (TokenizerManager, DetokenizerManager, Controller).
"""

import copy
import uuid
from dataclasses import dataclass, field
from enum import Enum
-from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union

from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.sampling.sampling_params import SamplingParams
@@ -55,6 +56,8 @@ class GenerateReqInput:
    logprob_start_len: Optional[Union[List[int], int]] = None
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: Optional[Union[List[int], int]] = None
    # If return logprobs, the token ids to return logprob for.
    token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None
    # Whether to detokenize tokens in text in the returned logprobs.
    return_text_in_logprobs: bool = False
    # Whether to stream output.
@@ -146,6 +149,8 @@ class GenerateReqInput:
            self.logprob_start_len = -1
            if self.top_logprobs_num is None:
                self.top_logprobs_num = 0
            if not self.token_ids_logprob:  # covers both None and []
                self.token_ids_logprob = None
        else:
            if self.parallel_sample_num == 1:
                num = self.batch_size
@@ -191,6 +196,17 @@ class GenerateReqInput:
            else:
                assert self.parallel_sample_num == 1

            if not self.token_ids_logprob:  # covers both None and []
                self.token_ids_logprob = [None] * num
            elif not isinstance(self.token_ids_logprob, list):
                self.token_ids_logprob = [[self.token_ids_logprob] for _ in range(num)]
            elif not isinstance(self.token_ids_logprob[0], list):
                self.token_ids_logprob = [
                    copy.deepcopy(self.token_ids_logprob) for _ in range(num)
                ]
            else:
                assert self.parallel_sample_num == 1

            if self.custom_logit_processor is None:
                self.custom_logit_processor = [None] * num
            elif not isinstance(self.custom_logit_processor, list):
@@ -198,6 +214,12 @@ class GenerateReqInput:
            else:
                assert self.parallel_sample_num == 1

        # Other checks
        if self.session_params is not None:
            assert isinstance(self.session_params, dict) or isinstance(
                self.session_params[0], dict
            )
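
Illustrative restatement of the token_ids_logprob normalization above (not part of the commit; it drops the parallel_sample_num assert for brevity):

import copy

def normalize_token_ids_logprob(value, num):
    if not value:                        # covers None and []
        return [None] * num
    if not isinstance(value, list):      # a single token id
        return [[value] for _ in range(num)]
    if not isinstance(value[0], list):   # one flat list, shared by every request
        return [copy.deepcopy(value) for _ in range(num)]
    return value                         # already one list per request

assert normalize_token_ids_logprob(None, 2) == [None, None]
assert normalize_token_ids_logprob(5, 2) == [[5], [5]]
assert normalize_token_ids_logprob([5, 7], 2) == [[5, 7], [5, 7]]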
    def regenerate_rid(self):
        self.rid = uuid.uuid4().hex
        return self.rid
@@ -212,6 +234,7 @@ class GenerateReqInput:
            return_logprob=self.return_logprob[i],
            logprob_start_len=self.logprob_start_len[i],
            top_logprobs_num=self.top_logprobs_num[i],
            token_ids_logprob=self.token_ids_logprob[i],
            return_text_in_logprobs=self.return_text_in_logprobs,
            stream=self.stream,
            log_metrics=self.log_metrics,
@@ -244,6 +267,8 @@ class TokenizedGenerateReqInput:
    logprob_start_len: int
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: int
    # If return logprobs, the token ids to return logprob for
    token_ids_logprob: List[int]
    # Whether to stream output
    stream: bool
@@ -378,10 +403,21 @@ class BatchTokenIDOut:
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    input_token_ids_logprobs_val: List[List]
    input_token_ids_logprobs_idx: List[List]
    output_token_ids_logprobs_val: List[List]
    output_token_ids_logprobs_idx: List[List]

    # Hidden states
    output_hidden_states: List[List[float]]


@dataclass
class BatchMultimodalDecodeReq:
    # The request id
    rids: List[str]


@dataclass
class BatchStrOut:
    # The request id
@@ -406,10 +442,21 @@ class BatchStrOut:
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    input_token_ids_logprobs_val: List[List]
    input_token_ids_logprobs_idx: List[List]
    output_token_ids_logprobs_val: List[List]
    output_token_ids_logprobs_idx: List[List]

    # Hidden states
    output_hidden_states: List[List[float]]


@dataclass
class BatchMultimodalOut:
    # The request id
    rids: List[str]


@dataclass
class BatchEmbeddingOut:
    # The request id
@@ -439,6 +486,8 @@ class UpdateWeightFromDiskReqInput:
class UpdateWeightFromDiskReqOutput:
    success: bool
    message: str
    # Number of paused requests during weight sync.
    num_paused_requests: Optional[int] = 0


@dataclass
@@ -526,11 +575,57 @@ class AbortReq:
    rid: str
-class ProfileReq(Enum):
@dataclass
class GetInternalStateReq:
    pass


@dataclass
class GetInternalStateReqOutput:
    internal_state: Dict[Any, Any]


@dataclass
class SetInternalStateReq:
    server_args: Dict[str, Any]


@dataclass
class SetInternalStateReqOutput:
    updated: bool
    server_args: Dict[str, Any]


@dataclass
class ProfileReqInput:
    # The output directory
    output_dir: Optional[str] = None
    # If set, it profiles as many as this number of steps.
    # If it is set, profiling is automatically stopped after this step, and
    # the caller doesn't need to run stop_profile.
    num_steps: Optional[int] = None
    activities: Optional[List[str]] = None


class ProfileReqType(Enum):
    START_PROFILE = 1
    STOP_PROFILE = 2


@dataclass
class ProfileReq:
    type: ProfileReqType
    output_dir: Optional[str] = None
    num_steps: Optional[int] = None
    activities: Optional[List[str]] = None


@dataclass
class ProfileReqOutput:
    success: bool
    message: str


@dataclass
class ConfigureLoggingReq:
    log_requests: Optional[bool] = None
@@ -556,6 +651,11 @@ class OpenSessionReqOutput:
    success: bool


@dataclass
class HealthCheckOutput:
    pass
@dataclass
class Function:
    description: Optional[str] = None
...
@@ -272,7 +272,7 @@ class PrefillAdder:
        self.req_states = None
        self.can_run_list = []
-       self.new_being_chunked_req = None
        self.new_chunked_req = None
        self.log_hit_tokens = 0
        self.log_input_tokens = 0
@@ -327,7 +327,7 @@ class PrefillAdder:
        self.log_hit_tokens += prefix_len
        self.log_input_tokens += extend_input_len

-   def add_being_chunked_req(self, req: Req):
    def add_chunked_req(self, req: Req):
        truncated = req.extend_input_len > self.rem_chunk_tokens
        req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
        req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
@@ -354,7 +354,7 @@ class PrefillAdder:
        finally:
            self.tree_cache.dec_lock_ref(last_node)

-   def add_one_req_ignore_eos(self, req: Req):
    def add_one_req_ignore_eos(self, req: Req, has_chunked_req: bool):
        def add_req_state(r, insert_sort=False):
            new_token_ratio = (
                1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
@@ -403,6 +403,7 @@ class PrefillAdder:
            self.rem_chunk_tokens is None
            or req.extend_input_len <= self.rem_chunk_tokens
        ):
            # Non-chunked prefill
            self.can_run_list.append(req)
            self._prefill_one_req(
                0,
@@ -418,14 +419,14 @@ class PrefillAdder:
            req.extend_input_len = trunc_len
            req.fill_ids = req.fill_ids[:trunc_len]
            self.can_run_list.append(req)
-           self.new_being_chunked_req = req
            self.new_chunked_req = req

            self._prefill_one_req(0, trunc_len, 0)

            return self.budget_state()
-   def add_one_req(self, req: Req):
    def add_one_req(self, req: Req, has_chunked_req: bool):
        if req.sampling_params.ignore_eos and self.tree_cache.disable:
-           return self.add_one_req_ignore_eos(req)
            return self.add_one_req_ignore_eos(req, has_chunked_req)

        total_tokens = req.extend_input_len + min(
            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
@@ -443,14 +444,7 @@ class PrefillAdder:
        if total_tokens > self.rem_total_tokens:
            return AddReqResult.NO_TOKEN

-       if (
-           self.rem_chunk_tokens is None
-           or input_tokens <= self.rem_chunk_tokens
-           or (
-               req.return_logprob
-               and req.logprob_start_len != len(req.origin_input_ids) - 1
-           )
-       ):
        if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
            # Non-chunked prefill
            self.can_run_list.append(req)
            self.tree_cache.inc_lock_ref(req.last_node)
@@ -470,8 +464,9 @@ class PrefillAdder:
            req.extend_input_len = trunc_len
            req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
            self.can_run_list.append(req)
-           self.new_being_chunked_req = req
            self.new_chunked_req = req
            self.tree_cache.inc_lock_ref(req.last_node)

            self._prefill_one_req(prefix_len, trunc_len, 0)
...
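
Worked example for the chunked-prefill truncation above (illustrative only, numbers made up): a request with a 96-token prefix hit and 300 tokens left to extend, when only 128 chunk-budget tokens remain in this step.

prefix_len = 96
extend_input_len = 300
rem_chunk_tokens = 128

trunc_len = min(extend_input_len, rem_chunk_tokens)  # 128 tokens prefilled now
fill_len = prefix_len + trunc_len                    # fill_ids kept up to index 224
remaining = extend_input_len - trunc_len             # 172 tokens left for later chunks
print(trunc_len, fill_len, remaining)                # 128 224 172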