Commit 852a49c5 authored by maxiao

adapt to dsv32 on dcu

parent 8f7453e3
......@@ -12,7 +12,7 @@
# limitations under the License.
# ==============================================================================
"""Inference-only GLM-4.5, GLM-4.6 model compatible with HuggingFace weights"""
"""Inference-only GLM-4.5 model compatible with HuggingFace weights"""
import logging
from typing import Any, Dict, Iterable, Optional, Tuple
......@@ -785,9 +785,9 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
or self.config.architectures[0] != architecture
or self.config.n_shared_experts != 1
):
disable_reason = "Only GLM-4.5 or GLM-4.6 on NV-platform with capability >= 80 can use shared experts fusion optimization."
disable_reason = "Only GLM-4.5 on NV-platform with capability >= 80 can use shared experts fusion optimization."
elif get_moe_expert_parallel_world_size() > 1:
disable_reason = "Deepseek and GLM-4.5 or GLM-4.6 can not use shared experts fusion optimization under expert parallelism."
disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization under expert parallelism."
if disable_reason is not None:
global_server_args_dict["disable_shared_experts_fusion"] = True
......
......@@ -12,7 +12,7 @@
# limitations under the License.
# ==============================================================================
"""Inference-only GLM-4.5, GLM-4.6 NextN Speculative Decoding."""
"""Inference-only GLM-4.5 NextN Speculative Decoding."""
import logging
from typing import Iterable, Optional, Tuple
......@@ -48,7 +48,7 @@ class Glm4MoeModelNextN(nn.Module):
super().__init__()
if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
logger.warning(
"Overriding Glm4MoeForCausalLMNextN quant config for modelopt_fp4 GLM-4.5 / GLM-4.6 model."
"Overriding Glm4MoeForCausalLMNextN quant config for modelopt_fp4 GLM-4.5 model."
)
quant_config = None
......
......@@ -66,10 +66,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.utils import (
create_fused_set_kv_buffer_arg,
enable_fused_set_kv_buffer,
)
from sglang.srt.utils import (
LazyValue,
add_prefix,
......@@ -197,6 +193,33 @@ class GptOssSparseMoeBlock(nn.Module):
return ans
def _enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
"""Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
# TODO maybe move to a model-common utils
def _create_fused_set_kv_buffer_arg(
value: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
):
layer_id = layer.layer_id
token_to_kv_pool = forward_batch.token_to_kv_pool
k_buffer = token_to_kv_pool.get_key_buffer(layer_id)
v_buffer = token_to_kv_pool.get_value_buffer(layer_id)
return FusedSetKVBufferArg(
value=value,
k_buffer=k_buffer.view(k_buffer.shape[0], -1),
v_buffer=v_buffer.view(v_buffer.shape[0], -1),
k_scale=layer.k_scale,
v_scale=layer.v_scale,
cache_loc=forward_batch.out_cache_loc,
)
class GptOssAttention(nn.Module):
def __init__(
self,
......@@ -314,12 +337,12 @@ class GptOssAttention(nn.Module):
q,
k,
fused_set_kv_buffer_arg=(
create_fused_set_kv_buffer_arg(
_create_fused_set_kv_buffer_arg(
value=v,
layer=self.attn,
forward_batch=forward_batch,
)
if enable_fused_set_kv_buffer(forward_batch)
if _enable_fused_set_kv_buffer(forward_batch)
else None
),
)
......@@ -333,7 +356,7 @@ class GptOssAttention(nn.Module):
attn_output = self.attn(
*inner_state,
sinks=self.sinks,
save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
save_kv_cache=not _enable_fused_set_kv_buffer(forward_batch),
)
output, _ = self.o_proj(attn_output)
return output
......
......@@ -291,7 +291,7 @@ class Llama4UnfoldConvolution(nn.Module):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.unfold(hidden_states)
hidden_states = hidden_states.permute(0, 2, 1).contiguous()
hidden_states = hidden_states.permute(0, 2, 1)
hidden_states, _ = self.linear(hidden_states)
return hidden_states
......@@ -446,20 +446,9 @@ class Llama4ForConditionalGeneration(nn.Module):
)
if self.has_vision:
# TODO: make this more general
ignore_quant_layers = getattr(config, "quantization_config", {}).get(
"ignore", {}
)
if (
"model.layers.vision_model*" in ignore_quant_layers
and "model.layers.multi_modal_projector*" in ignore_quant_layers
):
vision_quant_config = None
else:
vision_quant_config = quant_config
self.vision_model = Llama4VisionModel(
config.vision_config,
quant_config=vision_quant_config,
quant_config=quant_config,
prefix=add_prefix("vision_model", prefix),
)
......@@ -571,7 +560,7 @@ class Llama4ForConditionalGeneration(nn.Module):
forward_batch=forward_batch,
language_model=self.language_model,
data_embedding_funcs={
Modality.IMAGE: image_embedding_func,
Modality.IMAGE: self.get_image_feature,
},
positions=positions,
)
......
......@@ -454,6 +454,9 @@ class Qwen2ForCausalLM(nn.Module):
# For EAGLE3 support
self.capture_aux_hidden_states = False
# For EAGLE3 support
self.capture_aux_hidden_states = False
def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embedding(input_ids)
......
......@@ -265,7 +265,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
self.fullatt_block_indexes = vision_config.fullatt_block_indexes
self.window_size = vision_config.window_size
self.patch_size = vision_config.patch_size
mlp_hidden_size: int = ((vision_config.intermediate_size + 7) // 8) * 8
mlp_hidden_size: int = vision_config.intermediate_size
self.patch_embed = Qwen2_5_VisionPatchEmbed(
patch_size=patch_size,
temporal_patch_size=temporal_patch_size,
......
# Adapted from qwen2.py
import logging
from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Tuple
import torch
......@@ -29,19 +30,12 @@ from sglang.srt.model_loader.weight_utils import (
)
from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
from sglang.srt.models.qwen2 import Qwen2Model
from sglang.srt.utils import (
add_prefix,
get_cmo_stream,
is_cuda,
is_npu,
wait_cmo_stream,
)
from sglang.srt.utils import add_prefix, is_cuda
Qwen3Config = None
logger = logging.getLogger(__name__)
_is_cuda = is_cuda()
_is_npu = is_npu()
class Qwen3Attention(nn.Module):
......@@ -241,18 +235,9 @@ class Qwen3DecoderLayer(nn.Module):
# Fully Connected
hidden_states, residual = self.layer_communicator.prepare_mlp(
hidden_states,
residual,
forward_batch,
cache=(
[self.mlp.gate_up_proj.weight, self.mlp.down_proj.weight]
if _is_npu
else None
),
hidden_states, residual, forward_batch
)
hidden_states = self.mlp(hidden_states)
if _is_npu and get_cmo_stream():
wait_cmo_stream()
hidden_states, residual = self.layer_communicator.postprocess_layer(
hidden_states, residual, forward_batch
)
......
......@@ -60,10 +60,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTe
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
from sglang.srt.models.qwen2_moe import Qwen2MoeModel
from sglang.srt.models.utils import (
create_fused_set_kv_buffer_arg,
enable_fused_set_kv_buffer,
)
from sglang.srt.utils import (
add_prefix,
is_cuda,
......@@ -416,20 +412,7 @@ class Qwen3MoeAttention(nn.Module):
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self._apply_qk_norm(q, k)
q, k = self.rotary_emb(
positions,
q,
k,
fused_set_kv_buffer_arg=(
create_fused_set_kv_buffer_arg(
value=v,
layer=self.attn,
forward_batch=forward_batch,
)
if enable_fused_set_kv_buffer(forward_batch)
else None
),
)
q, k = self.rotary_emb(positions, q, k)
inner_state = q, k, v, forward_batch
return None, forward_batch, inner_state
......@@ -437,10 +420,7 @@ class Qwen3MoeAttention(nn.Module):
hidden_states, forward_batch, inner_state = intermediate_state
if inner_state is None:
return hidden_states
attn_output = self.attn(
*inner_state,
save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
)
attn_output = self.attn(*inner_state)
output, _ = self.o_proj(attn_output)
return output
......
# Copyright 2025 Qwen Team
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Inference-only Qwen3-VL model compatible with HuggingFace weights."""
import logging
from functools import lru_cache, partial
from typing import Callable, Iterable, List, Literal, Optional, Tuple, TypedDict, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers.activations import ACT2FN
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VisionRotaryEmbedding,
)
from sglang.srt.configs.qwen3_vl import Qwen3VLConfig, Qwen3VLVisionConfig
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternMultimodalTokens,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2_vl import Qwen2VLVideoInputs
from sglang.srt.models.qwen3 import Qwen3Model
from sglang.srt.utils import add_prefix
logger = logging.getLogger(__name__)
# === Vision Encoder === #
class Qwen3_VisionMLP(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: int,
bias: bool = True,
hidden_act="silu",
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.linear_fc1 = ColumnParallelLinear(
in_features,
hidden_features,
bias=bias,
quant_config=quant_config,
prefix=add_prefix("linear_fc1", prefix),
)
self.linear_fc2 = RowParallelLinear(
hidden_features,
in_features,
bias=bias,
quant_config=quant_config,
prefix=add_prefix("linear_fc2", prefix),
)
self.act = ACT2FN[hidden_act]
def forward(self, x: torch.Tensor):
x_fc1, _ = self.linear_fc1(x)
mlp_output, _ = self.linear_fc2(self.act(x_fc1))
return mlp_output
class Qwen3VLVisionPatchEmbed(nn.Module):
def __init__(self, config) -> None:
super().__init__()
self.patch_size = config.patch_size
self.temporal_patch_size = config.temporal_patch_size
self.in_channels = config.in_channels
self.embed_dim = config.hidden_size
kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
self.proj = nn.Conv3d(
self.in_channels,
self.embed_dim,
kernel_size=kernel_size,
stride=kernel_size,
bias=True,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
target_dtype = self.proj.weight.dtype
hidden_states = hidden_states.view(
-1,
self.in_channels,
self.temporal_patch_size,
self.patch_size,
self.patch_size,
)
hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(
-1, self.embed_dim
)
return hidden_states
class Qwen3_VisionBlock(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
intermediate_dim: int,
hidden_act="silu",
norm_layer: Optional[Callable[[int], nn.Module]] = None,
attn_implementation: Optional[str] = "sdpa",
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.norm1 = norm_layer(dim)
self.norm2 = norm_layer(dim)
if attn_implementation == "sdpa":
softmax_in_single_precision = False
qkv_backend = "sdpa"
flatten_batch = True
elif attn_implementation == "flash_attention_2":
softmax_in_single_precision = False
qkv_backend = "triton_attn"
flatten_batch = True
elif attn_implementation == "eager":
softmax_in_single_precision = True
qkv_backend = "sdpa"
flatten_batch = True
elif attn_implementation == "flash_attention_3":
softmax_in_single_precision = False
qkv_backend = "fa3"
flatten_batch = True
self.attn = VisionAttention(
embed_dim=dim,
num_heads=num_heads,
projection_size=dim,
use_qkv_parallel=True,
rotary_embed="normal",
proj_bias=True,
qkv_backend=qkv_backend,
softmax_in_single_precision=softmax_in_single_precision,
flatten_batch=flatten_batch,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)
self.mlp = Qwen3_VisionMLP(
dim,
intermediate_dim,
hidden_act=hidden_act,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
def forward(
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
position_embeddings: torch.Tensor,
) -> torch.Tensor:
hidden_states = self.norm1(x)
hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
attn = self.attn(
hidden_states,
cu_seqlens=cu_seqlens,
position_embeddings=position_embeddings,
)
attn = rearrange(attn, "b s ... -> s b ...")
x = x + attn
norm2 = self.norm2(x)
mlp = self.mlp(norm2)
x = x + mlp
return x
class Qwen3_VisionPatchMerger(nn.Module):
def __init__(
self,
dim: int,
context_dim: int,
norm_layer: Optional[Callable[[int], nn.Module]] = None,
spatial_merge_size: int = 2,
use_postshuffle_norm: bool = False,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = context_dim * (spatial_merge_size**2)
self.use_postshuffle_norm = use_postshuffle_norm
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.norm = norm_layer(
self.hidden_size if use_postshuffle_norm else context_dim
)
self.linear_fc1 = ColumnParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
quant_config=quant_config,
prefix=add_prefix("linear_fc1", prefix),
)
self.act_fn = nn.GELU()
self.linear_fc2 = RowParallelLinear(
self.hidden_size,
dim,
bias=True,
quant_config=quant_config,
prefix=add_prefix("linear_fc2", prefix),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.use_postshuffle_norm:
x = self.norm(x.view(-1, self.hidden_size))
else:
x = self.norm(x).view(-1, self.hidden_size)
x_parallel, _ = self.linear_fc1(x)
x_parallel = self.act_fn(x_parallel)
out, _ = self.linear_fc2(x_parallel)
return out
class Qwen3_VisionTransformer(nn.Module):
def __init__(
self,
vision_config: Qwen3VLVisionConfig,
norm_eps: float = 1e-6,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = vision_config.hidden_size
self.num_heads = vision_config.num_heads
self.num_position_embeddings = vision_config.num_position_embeddings
self.patch_size = vision_config.patch_size
self.spatial_merge_size = vision_config.spatial_merge_size
self.spatial_merge_unit = self.spatial_merge_size**2
self.temporal_patch_size = vision_config.temporal_patch_size
self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
self.patch_embed = Qwen3VLVisionPatchEmbed(config=vision_config)
self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size)
norm_layer = partial(nn.LayerNorm, eps=norm_eps)
head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
self.blocks = nn.ModuleList(
[
Qwen3_VisionBlock(
dim=self.hidden_size,
num_heads=self.num_heads,
intermediate_dim=vision_config.intermediate_size,
hidden_act=vision_config.hidden_act,
norm_layer=norm_layer,
attn_implementation="flash_attention_3",
quant_config=quant_config,
prefix=add_prefix(f"blocks.{layer_idx}", prefix),
)
for layer_idx in range(vision_config.depth)
]
)
self.merger = Qwen3_VisionPatchMerger(
dim=vision_config.out_hidden_size,
context_dim=self.hidden_size,
norm_layer=norm_layer,
spatial_merge_size=self.spatial_merge_size,
quant_config=quant_config,
prefix=add_prefix("merger", prefix),
)
self.deepstack_merger_list = nn.ModuleList(
[
Qwen3_VisionPatchMerger(
dim=vision_config.out_hidden_size,
context_dim=self.hidden_size,
spatial_merge_size=self.spatial_merge_size,
use_postshuffle_norm=True,
norm_layer=norm_layer,
quant_config=quant_config,
prefix=add_prefix(f"deepstack_merger_list.{layer_idx}", prefix),
)
for layer_idx in range(len(self.deepstack_visual_indexes))
]
)
@property
def dtype(self) -> torch.dtype:
return self.patch_embed.proj.weight.dtype
@property
def device(self) -> torch.device:
return self.patch_embed.proj.weight.device
def rot_pos_emb(self, grid_thw):
pos_ids = []
for t, h, w in grid_thw:
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
hpos_ids = hpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
)
hpos_ids = hpos_ids.permute(0, 2, 1, 3)
hpos_ids = hpos_ids.flatten()
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
wpos_ids = wpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
)
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
wpos_ids = wpos_ids.flatten()
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_thw[:, 1:].max()
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
def fast_pos_embed_interpolate(self, grid_thw):
num_grid_per_side = int(self.num_position_embeddings**0.5)
idx_list = [[] for _ in range(4)]
weight_list = [[] for _ in range(4)]
# TODO: use torch instead of np (a torch-only sketch appears after this file)
for t, h, w in grid_thw:
h_idxs = np.linspace(0, num_grid_per_side - 1, h)
w_idxs = np.linspace(0, num_grid_per_side - 1, w)
h_idxs_floor = h_idxs.astype(int)
w_idxs_floor = w_idxs.astype(int)
h_idxs_ceil = (h_idxs.astype(int) + 1).clip(max=num_grid_per_side - 1)
w_idxs_ceil = (w_idxs.astype(int) + 1).clip(max=num_grid_per_side - 1)
dh = h_idxs - h_idxs_floor
dw = w_idxs - w_idxs_floor
idx_list[0].extend(
((h_idxs_floor * num_grid_per_side)[None].T + w_idxs_floor[None])
.flatten()
.tolist()
* t
)
idx_list[1].extend(
((h_idxs_floor * num_grid_per_side)[None].T + w_idxs_ceil[None])
.flatten()
.tolist()
* t
)
idx_list[2].extend(
((h_idxs_ceil * num_grid_per_side)[None].T + w_idxs_floor[None])
.flatten()
.tolist()
* t
)
idx_list[3].extend(
((h_idxs_ceil * num_grid_per_side)[None].T + w_idxs_ceil[None])
.flatten()
.tolist()
* t
)
weight_list[0].extend(
((1 - dh)[None].T * (1 - dw)[None]).flatten().tolist() * t
)
weight_list[1].extend(((1 - dh)[None].T * dw[None]).flatten().tolist() * t)
weight_list[2].extend((dh[None].T * (1 - dw)[None]).flatten().tolist() * t)
weight_list[3].extend((dh[None].T * dw[None]).flatten().tolist() * t)
device = self.pos_embed.weight.device
dtype = self.pos_embed.weight.dtype
p0 = (
self.pos_embed(torch.tensor(idx_list[0], dtype=torch.long, device=device))
* torch.tensor(weight_list[0], dtype=dtype, device=device)[:, None]
)
p1 = (
self.pos_embed(torch.tensor(idx_list[1], dtype=torch.long, device=device))
* torch.tensor(weight_list[1], dtype=dtype, device=device)[:, None]
)
p2 = (
self.pos_embed(torch.tensor(idx_list[2], dtype=torch.long, device=device))
* torch.tensor(weight_list[2], dtype=dtype, device=device)[:, None]
)
p3 = (
self.pos_embed(torch.tensor(idx_list[3], dtype=torch.long, device=device))
* torch.tensor(weight_list[3], dtype=dtype, device=device)[:, None]
)
patch_pos_embeds = p0 + p1 + p2 + p3
patch_pos_embeds = patch_pos_embeds.split([t * h * w for t, h, w in grid_thw])
patch_pos_embeds_permute = []
m_size = self.spatial_merge_size
for pos_embed, (t, h, w) in zip(patch_pos_embeds, grid_thw):
pos_embed = (
pos_embed.view(t, h // m_size, m_size, w // m_size, m_size, -1)
.permute(0, 1, 3, 2, 4, 5)
.flatten(0, 4)
)
patch_pos_embeds_permute.append(pos_embed)
patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
return patch_pos_embeds
def forward(
self,
x: torch.Tensor,
grid_thw: torch.Tensor,
) -> torch.Tensor:
x = x.to(device=self.device, dtype=self.dtype)
x = self.patch_embed(x)
pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
x = x + pos_embeds
rotary_pos_emb = self.rot_pos_emb(grid_thw)
seq_len, _ = x.size()
rotary_pos_emb = rotary_pos_emb.to(x.device)
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
position_embeddings = (emb.cos(), emb.sin())
# compute cu_seqlens
cu_seqlens = torch.cat(
[
torch.tensor([0], device=grid_thw.device),
(grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).cumsum(dim=0),
]
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
# max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
x = x.unsqueeze(1)
deepstack_feature_lists = []
num_deepstack_captured = 0
for layer_num, blk in enumerate(self.blocks):
x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)
if layer_num in self.deepstack_visual_indexes:
deepstack_feature = self.deepstack_merger_list[num_deepstack_captured](
x
)
deepstack_feature_lists.append(deepstack_feature)
num_deepstack_captured += 1
x = self.merger(x)
hidden_states = torch.cat(
[x] + deepstack_feature_lists, dim=1
) # [seq_len, hidden_size * (1 + depth_of_deepstack)]
return hidden_states
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("attn.qkv.", "attn.q.", "q"),
("attn.qkv.", "attn.k.", "k"),
("attn.qkv.", "attn.v.", "v"),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: set[str] = set()
for name, loaded_weight in weights:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
cached_get_processor = lru_cache(get_processor)
class Qwen3LLMModel(Qwen3Model):
def __init__(
self,
*,
config: Qwen3VLConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__(config=config, quant_config=quant_config, prefix=prefix)
if not self.pp_group.is_first_rank:
assert self.start_layer >= len(
config.vision_config.deepstack_visual_indexes
), "start_layer should be greater than or equal to len(deepstack_visual_indexes)"
self.hidden_size = config.hidden_size
self.deepstack_embed_to_decoder_layer = range(
len(config.vision_config.deepstack_visual_indexes)
)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
input_deepstack_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, PPProxyTensors]:
if self.pp_group.is_first_rank:
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
else:
hidden_states = input_embeds
residual = None
else:
assert pp_proxy_tensors is not None
hidden_states = pp_proxy_tensors["hidden_states"]
residual = pp_proxy_tensors["residual"]
aux_hidden_states = []
for layer_idx, layer in enumerate(
self.layers[self.start_layer : self.end_layer]
):
layer_idx = layer_idx + self.start_layer
if layer_idx in self.layers_to_capture:
aux_hidden_states.append(
hidden_states + residual if residual is not None else hidden_states
)
hidden_states, residual = layer(
positions,
hidden_states,
forward_batch,
residual,
)
# process deepstack
if (
input_deepstack_embeds is not None
and layer_idx in self.deepstack_embed_to_decoder_layer
):
sep = self.hidden_size * layer_idx
hidden_states = (
hidden_states
+ input_deepstack_embeds[:, sep : sep + self.hidden_size]
)
if not self.pp_group.is_last_rank:
return PPProxyTensors(
{
"hidden_states": hidden_states,
"residual": residual,
}
)
else:
if hidden_states.shape[0] != 0:
if residual is None:
hidden_states = self.norm(hidden_states)
else:
hidden_states, _ = self.norm(hidden_states, residual)
if len(aux_hidden_states) == 0:
return hidden_states
return hidden_states, aux_hidden_states
class Qwen3VLForConditionalGeneration(nn.Module):
def __init__(
self,
config: Qwen3VLConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.visual = Qwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
# NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
# Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
quant_config=quant_config,
prefix=add_prefix("visual", prefix),
)
self.model = Qwen3LLMModel(
config=config,
quant_config=quant_config,
prefix=add_prefix("model", prefix),
)
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
)
self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
# e.g. {8: 0, 16: 1, 24: 2}: deepstack features captured at vision layers 8, 16, 24
# are merged into layers 0, 1, 2 of the decoder output hidden_states
# deepstack
self.deepstack_visual_indexes = self.visual.deepstack_visual_indexes
self.num_deepstack_embeddings = len(self.deepstack_visual_indexes)
@property
def use_deepstack(self) -> bool:
return hasattr(self, "deepstack_visual_indexes")
def separate_deepstack_embeds(self, embedding):
assert (
embedding.shape[-1] % (1 + self.num_deepstack_embeddings) == 0
), f"hidden_state of {embedding.shape} should be divisible by ({1 + self.num_deepstack_embeddings})"
separate_index = self.config.hidden_size
input_embeds = embedding[:, :separate_index]
input_deepstack_embeds = embedding[:, separate_index:]
return input_embeds, input_deepstack_embeds
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
pattern = MultiModalityDataPaddingPatternMultimodalTokens()
return pattern.pad_input_tokens(input_ids, mm_inputs)
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
# in qwen-vl, last dim is the same
pixel_values = torch.cat([item.feature for item in items], dim=0).type(
self.visual.dtype
)
image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
assert pixel_values.dim() == 2, pixel_values.dim()
assert image_grid_thw.dim() == 2, image_grid_thw.dim()
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
return image_embeds
def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
# in qwen-vl, last dim is the same
pixel_values = torch.cat([item.feature for item in items], dim=0).type(
self.visual.dtype
)
video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
assert pixel_values.dim() == 2, pixel_values.dim()
assert video_grid_thw.dim() == 2, video_grid_thw.dim()
video_embeds = self.visual(pixel_values, grid_thw=video_grid_thw)
return video_embeds
def get_input_embeddings(self):
return self.model.embed_tokens
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
get_embedding: bool = False,
):
"""Run forward pass for Qwen3-VL.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
positions: Flattened (concatenated) position ids corresponding to a
batch.
**NOTE**: If mrope is enabled (default setting for Qwen2-VL
opensource models), the shape will be `(3, seq_len)`,
otherwise it will be `(seq_len,)`.
(Use input_metadata.mrope_positions to replace it)
"""
if self.is_mrope_enabled:
positions = forward_batch.mrope_positions
if not (
forward_batch.forward_mode.is_decode()
or not forward_batch.contains_image_inputs()
):
if self.is_mrope_enabled:
assert positions.ndim == 2 and positions.size(0) == 3, (
"multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}"
)
hidden_states = general_mm_embed_routine(
input_ids=input_ids,
forward_batch=forward_batch,
language_model=self.model,
multimodal_model=self,
positions=positions,
use_deepstack=self.use_deepstack,
)
if not get_embedding:
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
)
else:
return self.pooler(hidden_states, forward_batch)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
("gate_up_proj", "up_proj", 1),
("gate_up_proj", "gate_proj", 0),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if "language_model" in name:
name = name.replace(r"model.language_model.", r"model.")
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
if "visual" in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
if "visual" in name:
# adapt to VisionAttention
name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
name = name.replace(r"model.visual.", r"visual.")
try:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
except KeyError:
print(params_dict.keys())
raise
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
EntryClass = Qwen3VLForConditionalGeneration
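# A hypothetical torch-only sketch addressing the "TODO: use torch instead of np"
# in Qwen3_VisionTransformer.fast_pos_embed_interpolate above. It reproduces only
# the bilinear index/weight computation; the final spatial-merge permutation
# performed by the original method would still need to follow.
import torch
import torch.nn as nn


def bilinear_pos_embed_torch(
    pos_embed: nn.Embedding, num_position_embeddings: int, grid_thw: torch.Tensor
) -> torch.Tensor:
    device = pos_embed.weight.device
    dtype = pos_embed.weight.dtype
    side = int(num_position_embeddings**0.5)

    outputs = []
    for t, h, w in grid_thw.tolist():
        h_idxs = torch.linspace(0, side - 1, h, device=device)
        w_idxs = torch.linspace(0, side - 1, w, device=device)
        h_floor, w_floor = h_idxs.long(), w_idxs.long()
        h_ceil = (h_floor + 1).clamp(max=side - 1)
        w_ceil = (w_floor + 1).clamp(max=side - 1)
        dh = (h_idxs - h_floor)[:, None]  # (h, 1) fractional row offsets
        dw = (w_idxs - w_floor)[None, :]  # (1, w) fractional column offsets

        def corner(hi: torch.Tensor, wi: torch.Tensor) -> torch.Tensor:
            # Embedding of one of the four neighboring grid corners, flattened to (h*w, C).
            idx = hi[:, None] * side + wi[None, :]
            return pos_embed(idx.reshape(-1))

        blended = (
            corner(h_floor, w_floor) * ((1 - dh) * (1 - dw)).reshape(-1, 1)
            + corner(h_floor, w_ceil) * ((1 - dh) * dw).reshape(-1, 1)
            + corner(h_ceil, w_floor) * (dh * (1 - dw)).reshape(-1, 1)
            + corner(h_ceil, w_ceil) * (dh * dw).reshape(-1, 1)
        )
        # Repeat over the temporal dimension, mirroring the `* t` in the original.
        outputs.append(blended.to(dtype).repeat(t, 1))
    return torch.cat(outputs)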
# Copyright 2025 Qwen Team
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Inference-only Qwen3-VL model compatible with HuggingFace weights."""
import logging
from functools import lru_cache, partial
from typing import Callable, Iterable, List, Literal, Optional, Tuple, TypedDict, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import BatchFeature
from transformers.activations import ACT2FN
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VisionRotaryEmbedding,
)
from sglang.srt.configs.qwen3_vl import Qwen3VLMoeConfig, Qwen3VLMoeVisionConfig
from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_pp_group,
get_tensor_model_parallel_rank,
)
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.utils import get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternMultimodalTokens,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import (
MultimodalDataItem,
MultimodalInputs,
global_server_args_dict,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel
from sglang.srt.models.qwen3_vl import (
Qwen3_VisionTransformer,
Qwen3VLForConditionalGeneration,
)
from sglang.srt.utils import add_prefix
logger = logging.getLogger(__name__)
cached_get_processor = lru_cache(get_processor)
class Qwen3MoeLLMModel(Qwen3MoeModel):
def __init__(
self,
*,
config: Qwen3VLMoeConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__(config=config, quant_config=quant_config, prefix=prefix)
self.hidden_size = config.hidden_size
def get_input_embeddings(self) -> nn.Embedding:
return self.embed_tokens
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
# in qwen-vl, last dim is the same
pixel_values = torch.cat([item.feature for item in items], dim=0).type(
self.visual.dtype
)
image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
assert pixel_values.dim() == 2, pixel_values.dim()
assert image_grid_thw.dim() == 2, image_grid_thw.dim()
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
return image_embeds
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
input_deepstack_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, PPProxyTensors]:
if self.pp_group.is_first_rank:
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
else:
hidden_states = input_embeds
residual = None
else:
assert pp_proxy_tensors is not None
hidden_states = pp_proxy_tensors["hidden_states"]
residual = pp_proxy_tensors["residual"]
aux_hidden_states = []
for layer_idx, layer in enumerate(
self.layers[self.start_layer : self.end_layer]
):
layer_idx = layer_idx + self.start_layer
if layer_idx in self.layers_to_capture:
aux_hidden_states.append(
hidden_states + residual if residual is not None else hidden_states
)
hidden_states, residual = layer(
positions,
hidden_states,
forward_batch,
residual,
)
# process deepstack
if input_deepstack_embeds is not None and layer_idx in range(3):
sep = self.hidden_size * layer_idx
hidden_states = (
hidden_states
+ input_deepstack_embeds[:, sep : sep + self.hidden_size]
)
if not self.pp_group.is_last_rank:
return PPProxyTensors(
{
"hidden_states": hidden_states,
"residual": residual,
}
)
else:
if hidden_states.shape[0] != 0:
if residual is None:
hidden_states = self.norm(hidden_states)
else:
hidden_states, _ = self.norm(hidden_states, residual)
if len(aux_hidden_states) == 0:
return hidden_states
return hidden_states, aux_hidden_states
class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
def __init__(
self,
*,
config: Qwen3VLMoeConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super(Qwen3VLForConditionalGeneration, self).__init__()
self.config = config
self.visual = Qwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
# NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
# Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
quant_config=quant_config,
prefix=add_prefix("visual", prefix),
)
self.model = Qwen3MoeLLMModel(
config=config,
quant_config=quant_config,
prefix=add_prefix("model", prefix),
)
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
)
self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
# deepstack
self.deepstack_visual_indexes = self.visual.deepstack_visual_indexes
self.num_deepstack_embeddings = len(self.deepstack_visual_indexes)
@property
def use_deepstack(self) -> bool:
return hasattr(self, "deepstack_visual_indexes")
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
get_embedding: bool = False,
):
"""Run forward pass for Qwen3-VL.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
positions: Flattened (concatenated) position ids corresponding to a
batch.
**NOTE**: If mrope is enabled (default setting for Qwen2-VL
opensource models), the shape will be `(3, seq_len)`,
otherwise it will be `(seq_len,)`.
(Use input_metadata.mrope_positions to replace it)
"""
if self.is_mrope_enabled:
positions = forward_batch.mrope_positions
if not (
forward_batch.forward_mode.is_decode()
or not forward_batch.contains_image_inputs()
):
if self.is_mrope_enabled:
assert positions.ndim == 2 and positions.size(0) == 3, (
"multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}"
)
hidden_states = general_mm_embed_routine(
input_ids=input_ids,
forward_batch=forward_batch,
language_model=self.model,
multimodal_model=self,
positions=positions,
use_deepstack=self.use_deepstack,
)
if not get_embedding:
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
)
else:
return self.pooler(hidden_states, forward_batch)
def load_fused_expert_weights(
self,
name: str,
params_dict: dict,
loaded_weight: torch.Tensor,
shard_id: str,
num_experts: int,
):
param = params_dict[name]
# weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
weight_loader = param.weight_loader
ep_rank = get_tensor_model_parallel_rank()
ep_size = get_moe_expert_parallel_world_size()
if ep_size == 1:
for expert_id in range(num_experts):
curr_expert_weight = loaded_weight[expert_id]
weight_loader(
param,
curr_expert_weight,
name,
shard_id,
expert_id,
)
else:
experts_per_ep = num_experts // ep_size
start_expert = ep_rank * experts_per_ep
end_expert = (
(ep_rank + 1) * experts_per_ep
if ep_rank != ep_size - 1
else num_experts
)
for idx, expert_id in enumerate(range(start_expert, end_expert)):
curr_expert_weight = loaded_weight[expert_id]
weight_loader(
param,
curr_expert_weight,
name,
shard_id,
idx,
)
return True
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
("gate_up_proj", "up_proj", 1),
("gate_up_proj", "gate_proj", 0),
]
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts,
)
# Skip loading extra parameters for GPTQ/modelopt models.
ignore_suffixes = (
".bias",
"_bias",
".k_scale",
"_k_scale",
".v_scale",
"_v_scale",
".weight_scale",
"_weight_scale",
".input_scale",
"_input_scale",
)
is_fused_expert = False
fused_expert_params_mapping = [
("experts.w13_weight", "experts.gate_up_proj", 0, "w1"),
("experts.w2_weight", "experts.down_proj", 0, "w2"),
]
num_experts = self.config.num_experts
# Cache params_dict to avoid repeated expensive traversal of model parameters
if not hasattr(self, "_cached_params_dict"):
self._cached_params_dict = dict(self.named_parameters())
params_dict = self._cached_params_dict
for name, loaded_weight in weights:
if "language_model" in name:
name = name.replace(r"model.language_model.", r"model.")
for param_name, weight_name, shard_id in stacked_params_mapping:
if "experts.gate_up_proj" in name or "experts.down_proj" in name:
is_fused_expert = True
expert_params_mapping = fused_expert_params_mapping
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
if "visual" in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if "mlp.experts" in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra parameters for GPTQ/modelopt models.
if name.endswith(ignore_suffixes) and name not in params_dict:
continue
# [TODO] Skip layers that are on other devices (check if sglang has a similar function)
# if is_pp_missing_parameter(name, self):
# continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Track if this is an expert weight to enable early skipping
is_expert_weight = False
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
if "visual" in name:
continue
# Anyway, this is an expert weight and should not be
# attempted to load as other weights later
is_expert_weight = True
name_mapped = name.replace(weight_name, param_name)
if is_fused_expert:
loaded_weight = loaded_weight.transpose(-1, -2) # no bias
if "experts.gate_up_proj" in name:
loaded_weight = loaded_weight.chunk(2, dim=-2)
self.load_fused_expert_weights(
name_mapped,
params_dict,
loaded_weight[0],
"w1",
num_experts,
)
self.load_fused_expert_weights(
name_mapped,
params_dict,
loaded_weight[1],
"w3",
num_experts,
)
else:
self.load_fused_expert_weights(
name_mapped,
params_dict,
loaded_weight,
shard_id,
num_experts,
)
else:
# Skip loading extra parameters for GPTQ/modelopt models.
if (
name_mapped.endswith(ignore_suffixes)
and name_mapped not in params_dict
):
continue
param = params_dict[name_mapped]
# We should ask the weight loader to return success or
# not here since otherwise we may skip experts with
# other available replicas.
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
name_mapped,
shard_id=shard_id,
expert_id=expert_id,
)
name = name_mapped
break
else:
if is_expert_weight:
# This is an expert weight but not mapped to this rank, skip all remaining processing
continue
if "visual" in name:
# adapt to VisionAttention
name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
name = name.replace(r"model.visual.", r"visual.")
# Skip loading extra parameters for GPTQ/modelopt models.
if name.endswith(ignore_suffixes) and name not in params_dict:
continue
if name in params_dict.keys():
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
else:
logger.warning(f"Parameter {name} not found in params_dict")
# TODO mimic deepseek
# Lazy initialization of expert weights cache to avoid slowing down load_weights
# if not hasattr(self, "routed_experts_weights_of_layer"):
# self.routed_experts_weights_of_layer = {
# layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
# for layer_id in range(self.start_layer, self.end_layer)
# if isinstance(self.model.layers[layer_id].mlp, Qwen3MoeSparseMoeBlock)
# }
EntryClass = Qwen3VLMoeForConditionalGeneration
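# A toy illustration of the expert partitioning used in load_fused_expert_weights
# above (hypothetical sizes): experts are split evenly across EP ranks, and the
# last rank also absorbs any remainder.
def expert_range(ep_rank: int, ep_size: int, num_experts: int) -> range:
    experts_per_ep = num_experts // ep_size
    start = ep_rank * experts_per_ep
    end = (ep_rank + 1) * experts_per_ep if ep_rank != ep_size - 1 else num_experts
    return range(start, end)


assert list(expert_range(0, 6, 64)) == list(range(0, 10))   # rank 0 loads experts 0-9
assert list(expert_range(5, 6, 64)) == list(range(50, 64))  # last rank takes 50-63, incl. remainder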
......@@ -66,8 +66,8 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix
tp_size: Optional[int] = None
tp_rank: Optional[int] = None
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
def gate_up_proj_weight_loader(
......@@ -341,13 +341,6 @@ class LlamaModel(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
global tp_size, tp_rank
if tp_size is None:
tp_size = get_tensor_model_parallel_world_size()
if tp_rank is None:
tp_rank = get_tensor_model_parallel_rank()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
......
# Copyright 2023-2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import torch
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import is_cuda
_is_cuda = is_cuda()
if _is_cuda:
from sgl_kernel import FusedSetKVBufferArg
def enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
"""Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
def create_fused_set_kv_buffer_arg(
value: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
):
layer_id = layer.layer_id
token_to_kv_pool = forward_batch.token_to_kv_pool
k_buffer = token_to_kv_pool.get_key_buffer(layer_id)
v_buffer = token_to_kv_pool.get_value_buffer(layer_id)
return FusedSetKVBufferArg(
value=value,
k_buffer=k_buffer.view(k_buffer.shape[0], -1),
v_buffer=v_buffer.view(v_buffer.shape[0], -1),
k_scale=layer.k_scale,
v_scale=layer.v_scale,
cache_loc=forward_batch.out_cache_loc,
)
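# A hypothetical sketch of how a decoder layer wires these helpers into its
# attention forward pass, mirroring the GptOssAttention / Qwen3MoeAttention call
# sites earlier in this commit (reuses the imports at the top of this module;
# `rotary_emb` and `attn` are supplied by the caller).
def rope_then_attend(
    rotary_emb,
    attn: RadixAttention,
    positions: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    forward_batch: ForwardBatch,
) -> torch.Tensor:
    fused = enable_fused_set_kv_buffer(forward_batch)
    q, k = rotary_emb(
        positions,
        q,
        k,
        # Fuse the KV-cache write into the RoPE kernel when supported
        # (CUDA with a bfloat16 KV cache); otherwise pass None.
        fused_set_kv_buffer_arg=(
            create_fused_set_kv_buffer_arg(value=v, layer=attn, forward_batch=forward_batch)
            if fused
            else None
        ),
    )
    # If the fused path already stored K/V, skip the separate cache write.
    return attn(q, k, v, forward_batch, save_kv_cache=not fused)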
......@@ -234,14 +234,7 @@ class BaseMultimodalProcessor(ABC):
and isinstance(processor.image_processor, BaseImageProcessorFast)
and not self.server_args.disable_fast_image_processor
):
if not _is_npu:
kwargs["device"] = "cuda"
elif processor.__class__.__name__ not in {
"Qwen2_5_VLProcessor",
"Qwen3VLProcessor",
}:
# Note: for qwen-vl, processor has some reshape issue because of dims restriction on Ascend.
kwargs["device"] = "npu"
kwargs["device"] = "cuda" if not _is_npu else "npu"
result = processor.__call__(
text=[input_text],
padding=True,
......
......@@ -12,8 +12,6 @@ from torchvision.transforms import InterpolationMode
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
from sglang.srt.models.qwen3_vl import Qwen3VLForConditionalGeneration
from sglang.srt.models.qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
from sglang.srt.multimodal.processors.base_processor import (
BaseMultimodalProcessor as SGLangBaseProcessor,
)
......@@ -211,12 +209,7 @@ async def preprocess_video(
# Compatible with Qwen2VL and Qwen2_5VL
class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
models = [
Qwen2VLForConditionalGeneration,
Qwen2_5_VLForConditionalGeneration,
Qwen3VLForConditionalGeneration,
Qwen3VLMoeForConditionalGeneration,
]
models = [Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
......
......@@ -17,18 +17,10 @@ import torch
from packaging import version
from torch.multiprocessing import reductions
from sglang.srt.utils import is_npu
_is_npu = is_npu()
def monkey_patch_torch_reductions():
"""Monkey patching before Torch https://github.com/pytorch/pytorch/pull/149248 is fixed"""
# Currently, NPU does not support UUID. This has been temporarily commented out, with support expected in the fourth quarter.
if _is_npu:
return
if hasattr(reductions, "_reduce_tensor_original"):
return
......
......@@ -19,6 +19,7 @@ from sglang.srt.utils import get_bool_env_var
_SAMPLING_EPS = 1e-6
TOP_K_ALL = 1 << 30
DEFAULT_SAMPLING_SEED = 42
class SamplingParams:
......@@ -55,7 +56,7 @@ class SamplingParams:
custom_params: Optional[Dict[str, Any]] = None,
stream_interval: Optional[int] = None,
logit_bias: Optional[Dict[str, float]] = None,
sampling_seed: int = 42,
sampling_seed: Optional[int] = None,
) -> None:
self.max_new_tokens = max_new_tokens
self.stop_strs = stop
......@@ -83,6 +84,13 @@ class SamplingParams:
self.custom_params = custom_params
self.stream_interval = stream_interval
self.logit_bias = logit_bias
# Used for deterministic sampling
if (
get_bool_env_var("SGLANG_ENABLE_DETERMINISTIC_INFERENCE")
and sampling_seed is None
):
# If deterministic inference is enabled and sampling_seed is not set, use the default seed
sampling_seed = DEFAULT_SAMPLING_SEED
self.sampling_seed = sampling_seed
# Process some special cases
......
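# A minimal standalone sketch of the seed fallback above (names here are
# illustrative, not part of sglang): the default seed only applies when
# deterministic inference is enabled and no explicit seed was provided.
import os

DEFAULT_SAMPLING_SEED = 42


def resolve_sampling_seed(sampling_seed=None):
    enabled = os.environ.get("SGLANG_ENABLE_DETERMINISTIC_INFERENCE", "false").lower() in ("1", "true")
    if enabled and sampling_seed is None:
        return DEFAULT_SAMPLING_SEED
    return sampling_seed


os.environ["SGLANG_ENABLE_DETERMINISTIC_INFERENCE"] = "1"
assert resolve_sampling_seed() == DEFAULT_SAMPLING_SEED  # falls back to the default seed
assert resolve_sampling_seed(7) == 7                     # an explicit seed always wins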
......@@ -19,6 +19,8 @@ import json
import logging
import os
import random
import socket
import sys
import tempfile
from typing import List, Literal, Optional, Union
......@@ -51,6 +53,7 @@ from sglang.utils import is_in_ci
logger = logging.getLogger(__name__)
# Define constants
LOAD_FORMAT_CHOICES = [
"auto",
......@@ -91,6 +94,7 @@ ATTENTION_BACKEND_CHOICES = [
"triton",
"torch_native",
"flex_attention",
"nsa",
# NVIDIA specific
"cutlass_mla",
"fa3",
......@@ -100,6 +104,7 @@ ATTENTION_BACKEND_CHOICES = [
"trtllm_mla",
"trtllm_mha",
"dual_chunk_flash_attn",
"hybrid_linear_attn",
# AMD specific
"aiter",
"wave",
......@@ -116,6 +121,11 @@ GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"]
DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3", "triton"]
NSA_CHOICES = ["flashmla_prefill", "flashmla_decode", "fa3", "tilelang"]
NSA_DEFAULT_PREFILL = "flashmla_prefill"
NSA_DEFAULT_DECODE = "fa3"
# Allow external code to add more choices
def add_load_format_choices(choices):
......@@ -167,7 +177,6 @@ class ServerArgs:
quantization: Optional[str] = None
quantization_param_path: Optional[str] = None
kv_cache_dtype: str = "auto"
enable_fp32_lm_head: bool = False
# Memory and scheduling
mem_fraction_static: Optional[float] = None
......@@ -212,8 +221,8 @@ class ServerArgs:
show_time_cost: bool = False
enable_metrics: bool = False
enable_metrics_for_all_schedulers: bool = False
tokenizer_metrics_custom_labels_header: str = "x-custom-labels"
tokenizer_metrics_allowed_custom_labels: Optional[List[str]] = None
tokenizer_metrics_custom_labels_header: str = "x-customer-labels"
tokenizer_metrics_allowed_customer_labels: Optional[List[str]] = None
bucket_time_to_first_token: Optional[List[float]] = None
bucket_inter_token_latency: Optional[List[float]] = None
bucket_e2e_request_latency: Optional[List[float]] = None
......@@ -286,14 +295,14 @@ class ServerArgs:
speculative_accept_threshold_acc: float = 1.0
speculative_token_map: Optional[str] = None
speculative_attention_mode: str = "prefill"
# For ngram only
speculative_ngram_min_match_window_size: int = 1
speculative_ngram_max_match_window_size: int = 12
speculative_ngram_min_bfs_breadth: int = 1
speculative_ngram_max_bfs_breadth: int = 10
speculative_ngram_match_type: Literal["BFS", "PROB"] = "BFS"
speculative_ngram_branch_length: int = 18
speculative_ngram_capacity: int = 10 * 1000 * 1000
# For lookahead only
speculative_lookahead_min_match_window_size: int = 1
speculative_lookahead_max_match_window_size: int = 12
speculative_lookahead_min_bfs_breadth: int = 1
speculative_lookahead_max_bfs_breadth: int = 10
speculative_lookahead_match_type: Literal["BFS", "PROB"] = "BFS"
speculative_lookahead_branch_length: int = 18
speculative_lookahead_capacity: int = 10 * 1000 * 1000
# Expert parallelism
ep_size: int = 1
......@@ -325,10 +334,6 @@ class ServerArgs:
deepep_config: Optional[str] = None
moe_dense_tp_size: Optional[int] = None
# Mamba cache
max_mamba_cache_size: Optional[int] = None
mamba_ssm_dtype: str = "float32"
# Hierarchical cache
enable_hierarchical_cache: bool = False
hicache_ratio: float = 2.0
......@@ -399,7 +404,7 @@ class ServerArgs:
enable_return_hidden_states: bool = False
scheduler_recv_interval: int = 1
numa_node: Optional[List[int]] = None
enable_deterministic_inference: bool = False
max_prefill_bs: Optional[int] = None
# Dynamic batch tokenizer
enable_dynamic_batch_tokenizer: bool = False
......@@ -420,14 +425,16 @@ class ServerArgs:
disaggregation_decode_dp: Optional[int] = None
disaggregation_prefill_pp: Optional[int] = 1
disaggregation_ib_device: Optional[str] = None
disaggregation_decode_enable_offload_kvcache: bool = False
num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD
# FIXME: hack to reduce ITL when decode bs is small
disaggregation_decode_polling_interval: int = 1
# For model weight update and weight loading
# For model weight update
custom_weight_loader: Optional[List[str]] = None
weight_loader_disable_mmap: bool = False
# Remote instance weight loading
remote_instance_weight_loader_seed_instance_ip: Optional[str] = None
remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None
remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None
......@@ -436,80 +443,62 @@ class ServerArgs:
enable_pdmux: bool = False
sm_group_num: int = 3
def __post_init__(self):
"""
Orchestrates the handling of various server arguments, ensuring proper configuration and validation.
"""
# Handle deprecated arguments.
self._handle_deprecated_args()
# Set missing default values.
self._handle_missing_default_values()
# Get GPU memory capacity, which is a common dependency for several configuration steps.
gpu_mem = get_device_memory_capacity(self.device)
# Handle memory-related, chunked prefill, and CUDA graph batch size configurations.
self._handle_gpu_memory_settings(gpu_mem)
# Handle device-specific backends.
self._handle_hpu_backends()
self._handle_cpu_backends()
# Apply model-specific adjustments.
self._handle_model_specific_adjustments()
# Set kernel backends.
self._handle_sampling_backend()
self._handle_attention_backend_compatibility()
self._handle_page_size()
self._handle_amd_specifics()
self._handle_grammar_backend()
# Handle data parallelism.
self._handle_data_parallelism()
# Handle MoE configurations.
self._handle_moe_kernel_config()
self._handle_deepep_moe()
self._handle_eplb_and_dispatch()
self._handle_expert_distribution_metrics()
# Handle pipeline parallelism.
self._handle_pipeline_parallelism()
# Handle Hicache settings.
self._handle_hicache()
# Handle speculative decoding logic.
self._handle_speculative_decoding()
# Handle model loading format.
self._handle_load_format()
# Handle PD disaggregation.
self._handle_disaggregation()
# Validate tokenizer settings.
self._handle_tokenizer_batching()
# Propagate environment variables.
self._handle_environment_variables()
# Validate cache settings.
self._handle_cache_compatibility()
# Mamba cache
max_mamba_cache_size: Optional[int] = None
mamba_ssm_dtype: str = "float32"
# Validate metrics labels.
self._handle_metrics_labels()
# For deterministic inference
enable_deterministic_inference: bool = False
# Handle deterministic inference.
self._handle_deterministic_inference()
# NSA attention backend
nsa_prefill: str = NSA_DEFAULT_PREFILL
nsa_decode: str = NSA_DEFAULT_DECODE
# Handle any other necessary validations.
self._handle_other_validations()
# Deprecated arguments
enable_ep_moe: bool = False
enable_deepep_moe: bool = False
enable_flashinfer_cutlass_moe: bool = False
enable_flashinfer_cutedsl_moe: bool = False
enable_flashinfer_trtllm_moe: bool = False
enable_triton_kernel_moe: bool = False
enable_flashinfer_mxfp4_moe: bool = False
def _handle_deprecated_args(self):
pass
if self.enable_ep_moe:
self.ep_size = self.tp_size
print_deprecated_warning(
"NOTE: --enable-ep-moe is deprecated. Please set `--ep-size` to the same value as `--tp-size` instead."
)
if self.enable_deepep_moe:
self.moe_a2a_backend = "deepep"
print_deprecated_warning(
"NOTE: --enable-deepep-moe is deprecated. Please set `--moe-a2a-backend` to 'deepep' instead."
)
if self.enable_triton_kernel_moe:
self.moe_runner_backend = "triton_kernel"
print_deprecated_warning(
"NOTE: --enable-triton-kernel-moe is deprecated. Please set `--moe-runner-backend` to 'triton_kernel' instead."
)
if self.enable_flashinfer_cutedsl_moe:
self.moe_runner_backend = "flashinfer_cutedsl"
print_deprecated_warning(
"NOTE: --enable-flashinfer-cutedsl-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_cutedsl' instead."
)
if self.enable_flashinfer_cutlass_moe:
self.moe_runner_backend = "flashinfer_cutlass"
print_deprecated_warning(
"NOTE: --enable-flashinfer-cutlass-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_cutlass' instead."
)
if self.enable_flashinfer_trtllm_moe:
self.moe_runner_backend = "flashinfer_trtllm"
print_deprecated_warning(
"NOTE: --enable-flashinfer-trtllm-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_trtllm' instead."
)
if self.enable_flashinfer_mxfp4_moe:
self.moe_runner_backend = "flashinfer_mxfp4"
print_deprecated_warning(
"NOTE: --enable-flashinfer-mxfp4-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_mxfp4' instead."
)
def _handle_missing_default_values(self):
if self.tokenizer_path is None:
......@@ -521,174 +510,85 @@ class ServerArgs:
if self.random_seed is None:
self.random_seed = random.randint(0, 1 << 30)
def _handle_gpu_memory_settings(self, gpu_mem):
"""
Configure GPU memory-dependent settings including
chunked_prefill_size, cuda_graph_max_bs, and mem_fraction_static.
Here are our heuristics:
- Set chunked_prefill_size and cuda_graph_max_bs based on the GPU memory capacity.
This is because GPUs with more memory are generally more powerful, so we need a larger
chunked_prefill_size and a larger cuda_graph_max_bs to fully utilize the GPU.
- Then set mem_fraction_static based on chunked_prefill_size and cuda_graph_max_bs.
GPU memory capacity = model weights + KV cache pool + activations + cuda graph buffers
The argument mem_fraction_static is defined as (model weights + KV cache pool) / GPU memory capacity,
or equivalently, mem_fraction_static = (GPU memory capacity - activations - cuda graph buffers) / GPU memory capacity.
In order to compute mem_fraction_static, we need to estimate the size of activations and cuda graph buffers.
The activation memory is proportional to the chunked_prefill_size.
The cuda graph memory is proportional to the cuda_graph_max_bs.
We use reserved_mem = chunked_prefill_size * 1.5 + cuda_graph_max_bs * 2 to estimate the size of activations and cuda graph buffers in MB (matching the units of gpu_mem),
and set mem_fraction_static = (GPU memory capacity - reserved_mem) / GPU memory capacity.
The coefficient 1.5 is a heuristic value; in the future, we can estimate this better by looking at the model type and hidden sizes, or even by doing a dummy run.
"""
def _handle_mem_fraction_static(self, gpu_mem):
if self.mem_fraction_static is None:
if gpu_mem is not None:
# GPU memory capacity = model weights + KV cache pool + activations + cuda graph buffers
# mem_fraction_static = (model weights + KV cache pool) / GPU memory capacity.
# We want mem_fraction_static to be as large as possible but still has enough room
# for activations and cuda graph buffers. We use the following heuristic to
# compute the needed size for activations and cuda graph buffers:
# - The size of the activation depends on the chunked_prefill_size and model size.
# - The size of cuda graph buffers depends on the cuda graph capture range and model size.
# For GPUs with more memory, we use a larger chunked_prefill_size and
# capture more cuda graphs, so they need to reserve more memory.
parallel_size = self.tp_size * self.pp_size
if gpu_mem < 20 * 1024:
# T4, 4080
# (chunked_prefill_size 2k, cuda_graph_max_bs 8)
if self.chunked_prefill_size is None:
self.chunked_prefill_size = 2048
if self.cuda_graph_max_bs is None:
self.cuda_graph_max_bs = 8
# T4, 4080. (chunked_prefill_size 2k, cuda_graph_max_bs 8)
reserved_mem = (2.8 + parallel_size / 10) * 1024
elif gpu_mem < 35 * 1024:
# A10, 4090, 5090
# (chunked_prefill_size 2k, cuda_graph_max_bs 16 if tp < 4 else 80)
if self.chunked_prefill_size is None:
self.chunked_prefill_size = 2048
if self.cuda_graph_max_bs is None:
# Based on detailed statistics, when serving TP1/TP2 models on lower-end GPUs with HBM < 35GB, you can either disable cuda graph or set `cuda_graph_max_bs` to a very small value to reduce the memory overhead of creating cuda graphs, with almost no impact on performance.
# However, when serving models with TP4 or TP8, we need to enable cuda graph to maintain high performance. In this case, we can set `cuda_graph_max_bs` to 80 (half of the default value 160) to reduce the memory overhead of creating cuda graphs. Looking at the logs
# from TP4 serving of qwen2-72b, a value of 80 is sufficient and can reduce the memory overhead of creating cuda graphs on lower-end GPUs compared to the original 160, avoiding OOM issues.
if self.tp_size < 4:
self.cuda_graph_max_bs = 16
else:
self.cuda_graph_max_bs = 80
elif gpu_mem < 60 * 1024:
# A100 (40GB), L40,
# (chunked_prefill_size 4k, cuda_graph_max_bs 32 if tp < 4 else 160)
if self.chunked_prefill_size is None:
self.chunked_prefill_size = 4096
if self.cuda_graph_max_bs is None:
if self.tp_size < 4:
self.cuda_graph_max_bs = 32
else:
self.cuda_graph_max_bs = 160
# A10, L40, 4090, 5090. (chunked_prefill_size 2k, cuda_graph_max_bs 8)
reserved_mem = (2.8 + parallel_size / 10) * 1024
elif gpu_mem < 90 * 1024:
# H100, A100
# (chunked_prefill_size 8k, cuda_graph_max_bs 256 if tp < 4 else 512)
if self.chunked_prefill_size is None:
self.chunked_prefill_size = 8192
if self.cuda_graph_max_bs is None:
if self.tp_size < 4:
self.cuda_graph_max_bs = 256
else:
self.cuda_graph_max_bs = 512
# H100, A100. (chunked_prefill_size 8k, cuda_graph_max_bs 160)
reserved_mem = (9.5 + parallel_size / 2) * 1024
elif gpu_mem < 100 * 1024:
# H20. (chunked_prefill_size 8k, cuda_graph_max_bs 256)
reserved_mem = (12 + parallel_size / 2) * 1024
elif gpu_mem < 160 * 1024:
# H20, H200
# (chunked_prefill_size 8k, cuda_graph_max_bs 256 if tp < 4 else 512)
if self.chunked_prefill_size is None:
self.chunked_prefill_size = 8192
if self.cuda_graph_max_bs is None:
if self.tp_size < 4:
self.cuda_graph_max_bs = 256
else:
self.cuda_graph_max_bs = 512
else:
# B200, MI300
# (chunked_prefill_size 16k, cuda_graph_max_bs 512)
if self.chunked_prefill_size is None:
self.chunked_prefill_size = 16384
if self.cuda_graph_max_bs is None:
self.cuda_graph_max_bs = 512
else:
# Fallback defaults when gpu_mem is None
if self.chunked_prefill_size is None:
self.chunked_prefill_size = 4096
if self.cuda_graph_max_bs is None:
self.cuda_graph_max_bs = 160
# Set cuda graph batch sizes
if self.cuda_graph_bs is None:
self.cuda_graph_bs = self._generate_cuda_graph_batch_sizes()
else:
self.cuda_graph_max_bs = max(self.cuda_graph_bs)
if self.mem_fraction_static is None:
# Constant meta data (e.g., from attention backend)
reserved_mem = 512
# For activation during large prefill
if self.chunked_prefill_size > 0:
reserved_mem += max(self.chunked_prefill_size, 2048) * 1.5
# H200. (chunked_prefill_size 8k, cuda_graph_max_bs 256)
reserved_mem = (12 + parallel_size / 2) * 1024
else:
reserved_mem += max(self.max_prefill_tokens, 2048) * 1.5
# For cuda graphs
reserved_mem += self.cuda_graph_max_bs * 2
# Some adjustments for large parallel size
reserved_mem += self.tp_size * self.pp_size / 8 * 1024
if self.enable_dp_attention:
# DP attention needs more padding for some operations
reserved_mem += self.cuda_graph_max_bs * self.dp_size * 3
# DP attention uses much more memory for large cuda graph max bs,
# likely due to some inefficiencies in torch allocator or our implementation.
# So we need to reserve more memory.
if self.cuda_graph_max_bs > 300:
reserved_mem += self.cuda_graph_max_bs * self.dp_size * 1.5
if gpu_mem > 60 * 1024:
reserved_mem = max(reserved_mem, 10 * 1024)
# B200, MI300. (chunked_prefill_size 16k, cuda_graph_max_bs 512)
reserved_mem = 32 * 1024
# draft model and larger cuda graph buffers
if self.speculative_algorithm is not None:
if self.speculative_algorithm == "STANDALONE":
# standalone draft model and cuda graphs
# Standalone speculative decoding needs more memory than other speculative
# decoding algorithms since the draft model is typically larger.
reserved_mem += 6 * 1024
elif self.speculative_algorithm != "NGRAM":
# eagle draft models and cuda graphs
elif self.speculative_algorithm != "LOOKAHEAD":
reserved_mem += 2 * 1024
if self.enable_dp_attention:
reserved_mem += 4 * 1024
self.mem_fraction_static = round((gpu_mem - reserved_mem) / gpu_mem, 3)
else:
self.mem_fraction_static = 0.88
# Lazy init to avoid circular import
# Multimodal models need more memory for the image processor
# Lazy init to avoid circular import.
from sglang.srt.configs.model_config import ModelConfig
model_config = ModelConfig.from_server_args(self)
if model_config.is_multimodal:
self.adjust_mem_fraction_for_vlm(model_config)
def _generate_cuda_graph_batch_sizes(self):
"""
Generate the list of batch sizes for CUDA graph capture based on cuda_graph_max_bs.
This integrates the logic from cuda_graph_runner.py.
"""
# Handle disable_cuda_graph_padding as the first condition for both spec and non-spec
if self.disable_cuda_graph_padding:
capture_bs = list(range(1, self.cuda_graph_max_bs + 1))
elif self.speculative_algorithm is None:
# Normal case: [1, 2, 4, 8, 12] + list(range(16, 257, 8)) + list(range(272, 512, 16)) + list(range(512, cuda_graph_max_bs + 1))
capture_bs = (
[1, 2, 4, 8, 12]
+ list(range(16, 257, 8))
+ list(range(272, 512, 16))
+ list(range(512, self.cuda_graph_max_bs + 1))
)
def _handle_chunked_prefill_size(self, gpu_mem):
if self.chunked_prefill_size is None:
if gpu_mem is not None:
# A10, L40, 4090
if gpu_mem < 35 * 1024:
self.chunked_prefill_size = 2048
# H100, H200, A100, H20
elif gpu_mem < 160 * 1024:
self.chunked_prefill_size = 8192
# B200, MI300
else:
# Spec decoding case: list(range(1, 9, 1)) + list(range(10, 33, 2)) + list(range(40, 64, 4)) + list(range(72, 257, 8))
capture_bs = (
list(range(1, 9, 1))
+ list(range(10, 33, 2))
+ list(range(40, 64, 4))
+ list(range(72, 257, 8))
+ list(range(272, self.cuda_graph_max_bs + 1, 16))
)
capture_bs = [bs for bs in capture_bs if bs <= self.cuda_graph_max_bs]
self.chunked_prefill_size = 16384
else:
self.chunked_prefill_size = 4096
return capture_bs
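# Hedged illustration of the list above (hypothetical setting, not from this diff): with
# disable_cuda_graph_padding=False, no speculative algorithm, and cuda_graph_max_bs=24, the
# candidate list [1, 2, 4, 8, 12] + range(16, 257, 8) + ... filtered to bs <= 24 yields
# capture_bs = [1, 2, 4, 8, 12, 16, 24].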
def _handle_cuda_graph_max_bs(self, gpu_mem):
# Based on detailed statistics, when serving TP1/TP2 models on lower-end GPUs with HBM<25G, you can either disable cuda graph or set `cuda_graph_max_bs` to a very small value to reduce the memory overhead of creating cuda graphs, with almost no impact on performance. However, when serving models with TP4 or TP8, we need to enable cuda graph to maintain high performance. In this case, we can set `cuda_graph_max_bs` to 80 (half of the default value 160) to reduce the memory overhead of creating cuda graphs. Looking at the logs from TP4 serving of qwen2-72b, a value of 80 is sufficient and can reduce the memory overhead of creating cuda graphs on lower-end GPUs compared to the original 160, avoiding OOM issues.
if self.cuda_graph_max_bs is None:
if gpu_mem is not None and gpu_mem < 35 * 1024:
if self.tp_size < 4:
self.cuda_graph_max_bs = 8
else:
self.cuda_graph_max_bs = 80
def _handle_hpu_backends(self):
if self.device == "hpu":
......@@ -701,84 +601,6 @@ class ServerArgs:
self.attention_backend = "intel_amx"
self.sampling_backend = "pytorch"
def _handle_model_specific_adjustments(self):
if parse_connector_type(self.model_path) == ConnectorType.INSTANCE:
return
hf_config = self.get_hf_config()
model_arch = hf_config.architectures[0]
if model_arch in ["GptOssForCausalLM"]:
if self.attention_backend is None:
if is_cuda() and is_sm100_supported():
self.attention_backend = "trtllm_mha"
elif is_cuda() and is_sm90_supported():
self.attention_backend = "fa3"
else:
self.attention_backend = "triton"
supported_backends = ["triton", "trtllm_mha", "fa3"]
logger.info(
f"Use {self.attention_backend} as attention backend for GptOssForCausalLM"
)
assert (
self.attention_backend in supported_backends
), f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got '{self.attention_backend}'"
if is_sm100_supported():
if not self.enable_dp_attention:
self.enable_flashinfer_allreduce_fusion = True
logger.info(
"Enable FlashInfer AllReduce Fusion on sm100 for GptOssForCausalLM"
)
quantization_config = getattr(hf_config, "quantization_config", None)
is_mxfp4_quant_format = (
quantization_config is not None
and quantization_config.get("quant_method") == "mxfp4"
)
if is_sm100_supported() and is_mxfp4_quant_format:
self.moe_runner_backend = "flashinfer_mxfp4"
logger.warning(
"Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
)
else:
if self.moe_runner_backend == "triton_kernel":
assert (
self.ep_size == 1
), "Triton kernel MoE is only supported when ep_size == 1"
if (
self.moe_runner_backend == "auto"
and self.ep_size == 1
and is_triton_kernels_available()
):
self.moe_runner_backend = "triton_kernel"
logger.warning(
"Detected GPT-OSS model, enabling triton_kernels MOE kernel."
)
self.disable_hybrid_swa_memory = True
if is_mxfp4_quant_format:
# use bf16 for mxfp4 triton kernels
self.dtype = "bfloat16"
elif "Llama4" in model_arch and self.device != "cpu":
assert self.attention_backend in {
"fa3",
"aiter",
"triton",
}, "fa3, aiter, or triton is required for Llama4 model"
elif model_arch in [
"Gemma2ForCausalLM",
"Gemma3ForCausalLM",
"Gemma3ForConditionalGeneration",
"Gemma3nForCausalLM",
"Gemma3nForConditionalGeneration",
]:
# FIXME: https://github.com/sgl-project/sglang/pull/7367 is not compatible with gemma2 model.
# It failed at this test: https://github.com/sgl-project/sglang/actions/runs/16255155597/job/45890331952#step:4:736
logger.warning(
f"Disable hybrid SWA memory for {model_arch} as it is not yet supported."
)
self.disable_hybrid_swa_memory = True
def _handle_sampling_backend(self):
if self.sampling_backend is None:
self.sampling_backend = (
......@@ -801,7 +623,7 @@ class ServerArgs:
self.speculative_algorithm is None
), "Speculative decoding is currently not supported with Flex Attention backend"
if is_npu() and self.attention_backend in ["ascend"]:
if is_npu() and self.attention_backend in ["ascend", "hybrid_linear_attn"]:
logger.warning(
"At this moment Ascend attention backend only supports a page_size of 128, change page_size to 128."
)
......@@ -964,15 +786,8 @@ class ServerArgs:
def _handle_hicache(self):
if self.hicache_storage_backend == "mooncake":
if self.hicache_mem_layout == "layer_first":
if self.hicache_io_backend == "direct":
self.hicache_mem_layout = "page_first_direct"
elif self.hicache_io_backend == "kernel":
self.hicache_io_backend = "kernel"
self.hicache_mem_layout = "page_first"
logger.warning(
f"Mooncake storage backend does not support layer_first layout, "
f"switching to {self.hicache_mem_layout} layout for {self.hicache_io_backend} io backend"
)
if self.hicache_mem_layout == "page_first_direct":
if self.hicache_io_backend != "direct":
......@@ -1007,6 +822,7 @@ class ServerArgs:
model_arch = self.get_hf_config().architectures[0]
if model_arch in [
"DeepseekV32ForCausalLM",
"DeepseekV3ForCausalLM",
"Glm4MoeForCausalLM",
"BailingMoeForCausalLM",
......@@ -1058,23 +874,23 @@ class ServerArgs:
"speculative_eagle_topk > 1 with page_size > 1 is unstable and produces incorrect results for paged attention backends. This combination is only supported for the 'flashinfer' backend."
)
if self.speculative_algorithm == "NGRAM":
if self.speculative_algorithm == "LOOKAHEAD":
if not self.device.startswith("cuda"):
raise ValueError(
"Ngram speculative decoding only supports CUDA device."
"Lookahead speculative decoding only supports CUDA device."
)
if self.max_running_requests is None:
self.max_running_requests = 48
self.disable_overlap_schedule = True
self.enable_mixed_chunk = False
self.speculative_eagle_topk = self.speculative_ngram_max_bfs_breadth
self.speculative_eagle_topk = self.speculative_lookahead_max_bfs_breadth
if self.speculative_num_draft_tokens is None:
self.speculative_num_draft_tokens = (
self.speculative_ngram_max_match_window_size
self.speculative_lookahead_max_match_window_size
)
logger.warning(
"The overlap scheduler and mixed chunked prefill are disabled because of "
"using ngram speculative decoding."
"using lookahead speculative decoding."
)
if (
......@@ -1086,9 +902,9 @@ class ServerArgs:
"speculative_eagle_topk > 1 with page_size > 1 is unstable and produces incorrect results for paged attention backends. This combination is only supported for the 'flashinfer' backend."
)
if self.enable_dp_attention:
# TODO: support dp attention for ngram speculative decoding
# TODO: support dp attention for lookahead speculative decoding
raise ValueError(
"Currently ngram speculative decoding does not support dp attention."
"Currently lookahead speculative decoding does not support dp attention."
)
def _handle_load_format(self):
......@@ -1166,55 +982,120 @@ class ServerArgs:
"and cannot be used at the same time. Please use only one of them."
)
if (
self.disaggregation_decode_enable_offload_kvcache
and self.disaggregation_mode != "decode"
):
raise ValueError(
"The argument disaggregation-decode-enable-offload-kvcache is only supported for decode side."
)
def _handle_metrics_labels(self):
if (
not self.tokenizer_metrics_custom_labels_header
and self.tokenizer_metrics_allowed_custom_labels
and self.tokenizer_metrics_allowed_customer_labels
):
raise ValueError(
"Please set --tokenizer-metrics-custom-labels-header when setting --tokenizer-metrics-allowed-custom-labels."
"Please set --tokenizer-metrics-custom-labels-header when setting --tokenizer-metrics-allowed-customer-labels."
)
def _handle_deterministic_inference(self):
if self.enable_deterministic_inference:
# Check sampling backend
self.sampling_backend = "pytorch"
logger.warning(
"Sampling backend is set to pytorch for deterministic inference."
)
import importlib
# Check attention backend
if self.attention_backend not in DETERMINISTIC_ATTENTION_BACKEND_CHOICES:
if not importlib.util.find_spec("batch_invariant_ops"):
raise ValueError(
f"Currently only {DETERMINISTIC_ATTENTION_BACKEND_CHOICES} attention backends are supported for deterministic inference."
"batch_invariant_ops is not installed. Please install it from https://github.com/thinking-machines-lab/batch_invariant_ops/."
)
# Check some settings
self.sampling_backend = "pytorch"
logger.warning(
"Sampling backend is set to pytorch for deterministic inference."
)
# Currently, only FA3 supports radix cache. Support for other backends is in progress
if self.attention_backend != "fa3":
self.disable_radix_cache = True
logger.warning(
f"Currently radix cache is not compatible with {self.attention_backend} attention backend for deterministic inference. It will be supported in the future."
"Currently radix cache is disabled for deterministic inference. It will be supported in the future."
)
# Check TP size
if self.tp_size > 1:
os.environ["NCCL_ALGO"] = "allreduce:tree"
self.disable_custom_all_reduce = True
logger.warning(
"NCCL_ALGO is set to 'allreduce:tree' and custom all reduce is disabled for deterministic inference when TP size > 1."
if self.attention_backend not in DETERMINISTIC_ATTENTION_BACKEND_CHOICES:
raise ValueError(
f"Currently only {DETERMINISTIC_ATTENTION_BACKEND_CHOICES} attention backends are supported for deterministic inference."
)
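# Hedged illustration: launching with --enable-deterministic-inference --tp-size 2 makes the
# code above force sampling_backend="pytorch", export NCCL_ALGO="allreduce:tree", and disable
# custom all-reduce; radix cache is additionally disabled unless the attention backend is fa3.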
def _handle_other_validations(self):
pass
def __post_init__(self):
"""
Orchestrates the handling of various server arguments, ensuring proper configuration and validation.
"""
# Step 1: Handle deprecated arguments.
self._handle_deprecated_args()
# Step 2: Set missing default values.
self._handle_missing_default_values()
# Get GPU memory capacity, which is a common dependency for several configuration steps.
gpu_mem = get_device_memory_capacity(self.device)
# Step 3: Handle memory-related configurations.
self._handle_mem_fraction_static(gpu_mem)
self._handle_chunked_prefill_size(gpu_mem)
# Step 4: Handle CUDA graph settings.
self._handle_cuda_graph_max_bs(gpu_mem)
# Step 5: Handle device-specific backends.
self._handle_hpu_backends()
self._handle_cpu_backends()
# Step 6: Apply model-specific adjustments.
if parse_connector_type(self.model_path) != ConnectorType.INSTANCE:
self.model_specific_adjustments()
# Step 7: Set kernel backends.
self._handle_sampling_backend()
self._handle_attention_backend_compatibility()
self._handle_page_size()
self._handle_amd_specifics()
self._handle_grammar_backend()
# Step 8: Handle data parallelism.
self._handle_data_parallelism()
# Step 9: Handle MoE configurations.
self._handle_moe_kernel_config()
self._handle_deepep_moe()
self._handle_eplb_and_dispatch()
self._handle_expert_distribution_metrics()
# Step 10: Handle pipeline parallelism.
self._handle_pipeline_parallelism()
# Step 11: Handle Hicache settings.
self._handle_hicache()
# Step 12: Handle speculative decoding logic.
self._handle_speculative_decoding()
# Step 13: Handle model loading format.
self._handle_load_format()
# Step 14: Handle PD disaggregation.
self._handle_disaggregation()
# Step 15: Validate tokenizer settings.
self._handle_tokenizer_batching()
# Step 16: Propagate environment variables.
self._handle_environment_variables()
# Step 17: Validate cache settings.
self._handle_cache_compatibility()
# Step 18: Validate metrics labels.
self._handle_metrics_labels()
# Step 19: Handle deterministic inference.
self._handle_deterministic_inference()
# Step 20: Handle any other necessary validations.
self._handle_other_validations()
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
# Model and tokenizer
......@@ -1225,6 +1106,24 @@ class ServerArgs:
help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
required=True,
)
parser.add_argument(
"--remote-instance-weight-loader-seed-instance-ip",
type=str,
default=ServerArgs.remote_instance_weight_loader_seed_instance_ip,
help="The ip of the seed instance for loading weights from remote instance.",
)
parser.add_argument(
"--remote-instance-weight-loader-seed-instance-service-port",
type=int,
default=ServerArgs.remote_instance_weight_loader_seed_instance_service_port,
help="The service port of the seed instance for loading weights from remote instance.",
)
parser.add_argument(
"--remote-instance-weight-loader-send-weights-group-ports",
type=json_list_type,
default=ServerArgs.remote_instance_weight_loader_send_weights_group_ports,
help="The communication group ports for loading weights from remote instance.",
)
parser.add_argument(
"--tokenizer-path",
type=str,
......@@ -1393,11 +1292,6 @@ class ServerArgs:
choices=["auto", "fp8_e5m2", "fp8_e4m3"],
help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.',
)
parser.add_argument(
"--enable-fp32-lm-head",
action="store_true",
help="If set, the LM head outputs (logits) are in FP32.",
)
# Memory and scheduling
parser.add_argument(
......@@ -1637,16 +1531,16 @@ class ServerArgs:
"--tokenizer-metrics-custom-labels-header",
type=str,
default=ServerArgs.tokenizer_metrics_custom_labels_header,
help="Specify the HTTP header for passing custom labels for tokenizer metrics.",
help="Specify the HTTP header for passing customer labels for tokenizer metrics.",
)
parser.add_argument(
"--tokenizer-metrics-allowed-custom-labels",
"--tokenizer-metrics-allowed-customer-labels",
type=str,
nargs="+",
default=ServerArgs.tokenizer_metrics_allowed_custom_labels,
help="The custom labels allowed for tokenizer metrics. The labels are specified via a dict in "
default=ServerArgs.tokenizer_metrics_allowed_customer_labels,
help="The customer labels allowed for tokenizer metrics. The labels are specified via a dict in "
"'--tokenizer-metrics-custom-labels-header' field in HTTP requests, e.g., {'label1': 'value1', 'label2': "
"'value2'} is allowed if '--tokenizer-metrics-allowed-custom-labels label1 label2' is set.",
"'value2'} is allowed if '--tokenizer-metrics-allowed-labels label1 label2' is set.",
)
parser.add_argument(
"--bucket-time-to-first-token",
......@@ -1678,8 +1572,8 @@ class ServerArgs:
bucket_rule = (
"Supports 3 rule types: 'default' uses predefined buckets; 'tse <middle> <base> <count>' "
"generates two sides exponential distributed buckets (e.g., 'tse 1000 2 8' generates buckets "
"[984.0, 992.0, 996.0, 998.0, 1000.0, 1002.0, 1004.0, 1008.0, 1016.0]).); 'custom <value1> "
"<value2> ...' uses custom bucket values (e.g., 'custom 10 50 100 500')."
"[984.0, 992.0, 996.0, 998.0, 1000.0, 1002.0, 1004.0, 1008.0, 1016.0]).); 'customer <value1> "
"<value2> ...' uses custom bucket values (e.g., 'customer 10 50 100 500')."
)
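# Hedged illustration of the 'tse' rule described above (derived from the example in the help
# text, not from a function in this file): buckets are middle +/- base**i for i = 1..count//2,
# plus middle itself, so 'tse 1000 2 8' uses offsets 2, 4, 8, 16 and produces
# [984, 992, 996, 998, 1000, 1002, 1004, 1008, 1016].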
parser.add_argument(
"--prompt-tokens-buckets",
......@@ -1951,7 +1845,7 @@ class ServerArgs:
parser.add_argument(
"--mm-attention-backend",
type=str,
choices=["sdpa", "fa3", "triton_attn", "ascend_attn"],
choices=["sdpa", "fa3", "triton_attn"],
default=ServerArgs.mm_attention_backend,
help="Set multimodal attention backend.",
)
......@@ -1960,7 +1854,7 @@ class ServerArgs:
parser.add_argument(
"--speculative-algorithm",
type=str,
choices=["EAGLE", "EAGLE3", "NEXTN", "STANDALONE", "NGRAM"],
choices=["EAGLE", "EAGLE3", "NEXTN", "STANDALONE", "LOOKAHEAD"],
help="Speculative algorithm.",
)
parser.add_argument(
......@@ -2020,49 +1914,49 @@ class ServerArgs:
help="Attention backend for speculative decoding operations (both target verify and draft extend). Can be one of 'prefill' (default) or 'decode'.",
default=ServerArgs.speculative_attention_mode,
)
# Ngram speculative decoding
# Lookahead speculative decoding
parser.add_argument(
"--speculative-ngram-min-match-window-size",
"--speculative-lookahead-min-match-window-size",
type=int,
default=ServerArgs.speculative_ngram_min_match_window_size,
help="The minimum window size for pattern matching in ngram speculative decoding.",
default=ServerArgs.speculative_lookahead_min_match_window_size,
help="The minimum window size for pattern matching in lookahead speculative decoding.",
)
parser.add_argument(
"--speculative-ngram-max-match-window-size",
"--speculative-lookahead-max-match-window-size",
type=int,
default=ServerArgs.speculative_ngram_max_match_window_size,
help="The maximum window size for pattern matching in ngram speculative decoding.",
default=ServerArgs.speculative_lookahead_max_match_window_size,
help="The maximum window size for pattern matching in lookahead speculative decoding.",
)
parser.add_argument(
"--speculative-ngram-min-bfs-breadth",
"--speculative-lookahead-min-bfs-breadth",
type=int,
default=ServerArgs.speculative_ngram_min_bfs_breadth,
help="The minimum breadth for BFS (Breadth-First Search) in ngram speculative decoding.",
default=ServerArgs.speculative_lookahead_min_bfs_breadth,
help="The minimum breadth for BFS (Breadth-First Search) in lookahead speculative decoding.",
)
parser.add_argument(
"--speculative-ngram-max-bfs-breadth",
"--speculative-lookahead-max-bfs-breadth",
type=int,
default=ServerArgs.speculative_ngram_max_bfs_breadth,
help="The maximum breadth for BFS (Breadth-First Search) in ngram speculative decoding.",
default=ServerArgs.speculative_lookahead_max_bfs_breadth,
help="The maximum breadth for BFS (Breadth-First Search) in lookahead speculative decoding.",
)
parser.add_argument(
"--speculative-ngram-match-type",
"--speculative-lookahead-match-type",
type=str,
choices=["BFS", "PROB"],
default=ServerArgs.speculative_ngram_match_type,
default=ServerArgs.speculative_lookahead_match_type,
help="The match type for cache tree.",
)
parser.add_argument(
"--speculative-ngram-branch-length",
"--speculative-lookahead-branch-length",
type=int,
default=ServerArgs.speculative_ngram_branch_length,
help="The branch length for ngram speculative decoding.",
default=ServerArgs.speculative_lookahead_branch_length,
help="The branch length for lookahead speculative decoding.",
)
parser.add_argument(
"--speculative-ngram-capacity",
"--speculative-lookahead-capacity",
type=int,
default=ServerArgs.speculative_ngram_capacity,
help="The cache capacity for ngram speculative decoding.",
default=ServerArgs.speculative_lookahead_capacity,
help="The cache capacity for lookahead speculative decoding.",
)
# Expert parallelism
......@@ -2256,12 +2150,9 @@ class ServerArgs:
parser.add_argument(
"--hicache-storage-backend",
type=str,
choices=["file", "mooncake", "hf3fs", "nixl", "aibrix", "dynamic"],
choices=["file", "mooncake", "hf3fs", "nixl"],
default=ServerArgs.hicache_storage_backend,
help="The storage backend for hierarchical KV cache. "
"Built-in backends: file, mooncake, hf3fs, nixl, aibrix. "
"For dynamic backend, use --hicache-storage-backend-extra-config to specify: "
"backend_name (custom name), module_path (Python module path), class_name (backend class name).",
help="The storage backend for hierarchical KV cache.",
)
parser.add_argument(
"--hicache-storage-prefetch-policy",
......@@ -2571,6 +2462,12 @@ class ServerArgs:
nargs="+",
help="Sets the numa node for the subprocesses. i-th element corresponds to i-th subprocess.",
)
parser.add_argument(
"--max-prefill-bs",
type=int,
default=ServerArgs.max_prefill_bs,
help="The maximum batch size for prefill requests.",
)
# Debug tensor dumps
parser.add_argument(
......@@ -2661,11 +2558,6 @@ class ServerArgs:
"or multiple comma-separated devices (e.g., --disaggregation-ib-device mlx5_0,mlx5_1). "
"Default is None, which triggers automatic device detection when mooncake backend is enabled.",
)
parser.add_argument(
"--disaggregation-decode-enable-offload-kvcache",
action="store_true",
help="Enable async KV cache offloading on decode server (PD mode).",
)
parser.add_argument(
"--num-reserved-decode-tokens",
type=int,
......@@ -2692,24 +2584,6 @@ class ServerArgs:
action="store_true",
help="Disable mmap while loading weight using safetensors.",
)
parser.add_argument(
"--remote-instance-weight-loader-seed-instance-ip",
type=str,
default=ServerArgs.remote_instance_weight_loader_seed_instance_ip,
help="The ip of the seed instance for loading weights from remote instance.",
)
parser.add_argument(
"--remote-instance-weight-loader-seed-instance-service-port",
type=int,
default=ServerArgs.remote_instance_weight_loader_seed_instance_service_port,
help="The service port of the seed instance for loading weights from remote instance.",
)
parser.add_argument(
"--remote-instance-weight-loader-send-weights-group-ports",
type=json_list_type,
default=ServerArgs.remote_instance_weight_loader_send_weights_group_ports,
help="The communication group ports for loading weights from remote instance.",
)
# For PD-Multiplexing
parser.add_argument(
......@@ -2732,48 +2606,56 @@ class ServerArgs:
help="Enable deterministic inference mode with batch invariant ops.",
)
# For NSA models
parser.add_argument(
"--nsa-prefill",
default=NSA_DEFAULT_PREFILL,
type=str,
choices=NSA_CHOICES,
)
parser.add_argument(
"--nsa-decode",
default=NSA_DEFAULT_DECODE,
type=str,
choices=NSA_CHOICES,
)
# Deprecated arguments
parser.add_argument(
"--enable-ep-moe",
action=DeprecatedAction,
help="NOTE: --enable-ep-moe is deprecated. Please set `--ep-size` to the same value as `--tp-size` instead.",
action="store_true",
help="(Deprecated) Enabling expert parallelism for moe. The ep size is equal to the tp size.",
)
parser.add_argument(
"--enable-deepep-moe",
action=DeprecatedAction,
help="NOTE: --enable-deepep-moe is deprecated. Please set `--moe-a2a-backend` to 'deepep' instead.",
action="store_true",
help="(Deprecated) Enabling DeepEP MoE implementation for EP MoE.",
)
parser.add_argument(
"--enable-flashinfer-cutlass-moe",
action=DeprecatedAction,
help="NOTE: --enable-flashinfer-cutlass-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_cutlass' instead.",
action="store_true",
help="(Deprecated) Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
)
parser.add_argument(
"--enable-flashinfer-cutedsl-moe",
action=DeprecatedAction,
help="NOTE: --enable-flashinfer-cutedsl-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_cutedsl' instead.",
action="store_true",
help="(Deprecated) Enable FlashInfer CuteDSL MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
)
parser.add_argument(
"--enable-flashinfer-trtllm-moe",
action=DeprecatedAction,
help="NOTE: --enable-flashinfer-trtllm-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_trtllm' instead.",
action="store_true",
help="(Deprecated) Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP",
)
parser.add_argument(
"--enable-triton-kernel-moe",
action=DeprecatedAction,
help="NOTE: --enable-triton-kernel-moe is deprecated. Please set `--moe-runner-backend` to 'triton_kernel' instead.",
action="store_true",
help="(Deprecated) Use triton moe grouped gemm kernel.",
)
parser.add_argument(
"--enable-flashinfer-mxfp4-moe",
action=DeprecatedAction,
help="NOTE: --enable-flashinfer-mxfp4-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_mxfp4' instead.",
)
# Configuration file support
parser.add_argument(
"--config",
type=str,
help="Read CLI options from a config file. Must be a YAML file with configuration options.",
action="store_true",
help="(Deprecated) Enable FlashInfer MXFP4 MoE backend for modelopt_fp4 quant on Blackwell.",
)
@classmethod
......@@ -2967,8 +2849,8 @@ class ServerArgs:
assert rule in [
"tse",
"default",
"custom",
], f"Unsupported {arg_name} rule type: '{rule}'. Must be one of: 'tse', 'default', 'custom'"
"customer",
], f"Unsupported {arg_name} rule type: '{rule}'. Must be one of: 'tse', 'default', 'customer'"
if rule == "tse":
assert (
......@@ -2991,20 +2873,116 @@ class ServerArgs:
len(buckets_rule) == 1
), f"{arg_name} default rule should only have one parameter: ['default'], got {len(buckets_rule)}"
elif rule == "custom":
elif rule == "customer":
assert (
len(buckets_rule) >= 2
), f"{arg_name} custom rule requires at least one bucket value: ['custom', value1, ...]"
), f"{arg_name} customer rule requires at least one bucket value: ['customer', value1, ...]"
try:
bucket_values = [float(x) for x in buckets_rule[1:]]
except ValueError:
assert False, f"{arg_name} custom rule bucket values must be numeric"
assert False, f"{arg_name} customer rule bucket values must be numeric"
assert len(set(bucket_values)) == len(
bucket_values
), f"{arg_name} custom rule bucket values should not contain duplicates"
), f"{arg_name} customer rule bucket values should not contain duplicates"
assert all(
val >= 0 for val in bucket_values
), f"{arg_name} custom rule bucket values should be non-negative"
), f"{arg_name} customer rule bucket values should be non-negative"
def model_specific_adjustments(self):
from sglang.srt.configs.model_config import is_deepseek_nsa
hf_config = self.get_hf_config()
model_arch = hf_config.architectures[0]
if model_arch in ["GptOssForCausalLM"]:
if self.attention_backend is None:
if is_cuda() and is_sm100_supported():
self.attention_backend = "trtllm_mha"
elif is_cuda() and is_sm90_supported():
self.attention_backend = "fa3"
else:
self.attention_backend = "triton"
supported_backends = ["triton", "trtllm_mha", "fa3"]
logger.info(
f"Use {self.attention_backend} as attention backend for GptOssForCausalLM"
)
assert (
self.attention_backend in supported_backends
), f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got '{self.attention_backend}'"
if is_sm100_supported():
if not self.enable_dp_attention:
self.enable_flashinfer_allreduce_fusion = True
logger.info(
"Enable FlashInfer AllReduce Fusion on sm100 for GptOssForCausalLM"
)
quantization_config = getattr(hf_config, "quantization_config", None)
is_mxfp4_quant_format = (
quantization_config is not None
and quantization_config.get("quant_method") == "mxfp4"
)
if is_sm100_supported() and is_mxfp4_quant_format:
self.moe_runner_backend = "flashinfer_mxfp4"
logger.warning(
"Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
)
else:
if self.moe_runner_backend == "triton_kernel":
assert (
self.ep_size == 1
), "Triton kernel MoE is only supported when ep_size == 1"
if (
self.moe_runner_backend == "auto"
and self.ep_size == 1
and is_triton_kernels_available()
):
self.moe_runner_backend = "triton_kernel"
logger.warning(
"Detected GPT-OSS model, enabling triton_kernels MOE kernel."
)
self.disable_hybrid_swa_memory = True
if is_mxfp4_quant_format:
# use bf16 for mxfp4 triton kernels
self.dtype = "bfloat16"
elif "Llama4" in model_arch and self.device != "cpu":
assert self.attention_backend in {
"fa3",
"aiter",
"triton",
}, "fa3, aiter, or triton is required for Llama4 model"
elif model_arch in [
"Gemma2ForCausalLM",
"Gemma3ForCausalLM",
"Gemma3ForConditionalGeneration",
"Gemma3nForCausalLM",
"Gemma3nForConditionalGeneration",
]:
# FIXME: https://github.com/sgl-project/sglang/pull/7367 is not compatible with gemma2 model.
# It failed at this test: https://github.com/sgl-project/sglang/actions/runs/16255155597/job/45890331952#step:4:736
logger.warning(
f"Disable hybrid SWA memory for {model_arch} as it is not yet supported."
)
self.disable_hybrid_swa_memory = True
elif is_deepseek_nsa(hf_config):
if (
self.attention_backend is None
and self.prefill_attention_backend is None
and self.decode_attention_backend is None
):
self.attention_backend = "nsa"
logger.warning("Set nsa attention backend for DeepSeek NSA.")
if not is_npu():
self.enable_dp_attention = True
self.dp_size = self.tp_size
logger.warning("DP attention is enabled for DeepSeek NSA.")
self.page_size = 64
logger.warning("Setting page size to 64 for DeepSeek NSA.")
self.max_prefill_bs = 1
logger.warning("Setting maximum prefill batch size to 1 for DeepSeek NSA.")
def adjust_mem_fraction_for_vlm(self, model_config):
vision_config = getattr(model_config.hf_config, "vision_config", None)
......@@ -3056,26 +3034,6 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
Returns:
The server arguments.
"""
# Import here to avoid circular imports
from sglang.srt.server_args_config_parser import ConfigArgumentMerger
# Check for config file and merge arguments if present
if "--config" in argv:
# Extract boolean actions from the parser to handle them correctly
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
# Get boolean action destinations
boolean_actions = []
for action in parser._actions:
if hasattr(action, "dest") and hasattr(action, "action"):
if action.action in ["store_true", "store_false"]:
boolean_actions.append(action.dest)
# Merge config file arguments with CLI arguments
config_merger = ConfigArgumentMerger(boolean_actions=boolean_actions)
argv = config_merger.merge_config_with_args(argv)
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
raw_args = parser.parse_args(argv)
......@@ -3217,6 +3175,7 @@ def auto_choose_speculative_params(self: ServerArgs):
# The default value for llama
return (5, 4, 8)
elif arch in [
"DeepseekV32ForCausalLM",
"DeepseekV3ForCausalLM",
"DeepseekV2ForCausalLM",
"GptOssForCausalLM",
......
"""
Configuration argument parser for command-line applications.
Handles merging of YAML configuration files with command-line arguments.
"""
import logging
from pathlib import Path
from typing import Any, Dict, List, Union
import yaml
logger = logging.getLogger(__name__)
class ConfigArgumentMerger:
"""Handles merging of configuration file arguments with command-line arguments."""
def __init__(self, boolean_actions: List[str] = None):
"""Initialize with list of boolean action destinations."""
self.boolean_actions = boolean_actions or []
def merge_config_with_args(self, cli_args: List[str]) -> List[str]:
"""
Merge configuration file arguments with command-line arguments.
Configuration arguments are inserted after the subcommand to maintain
proper precedence: CLI > Config > Defaults
Args:
cli_args: List of command-line arguments
Returns:
Merged argument list with config values inserted
Raises:
ValueError: If multiple config files specified or no config file provided
"""
config_file_path = self._extract_config_file_path(cli_args)
if not config_file_path:
return cli_args
config_args = self._parse_yaml_config(config_file_path)
return self._insert_config_args(cli_args, config_args, config_file_path)
def _extract_config_file_path(self, args: List[str]) -> str:
"""Extract the config file path from arguments."""
config_indices = [i for i, arg in enumerate(args) if arg == "--config"]
if len(config_indices) > 1:
raise ValueError("Multiple config files specified! Only one allowed.")
if not config_indices:
return None
config_index = config_indices[0]
if config_index == len(args) - 1:
raise ValueError("No config file specified after --config flag!")
return args[config_index + 1]
def _insert_config_args(
self, cli_args: List[str], config_args: List[str], config_file_path: str
) -> List[str]:
"""Insert configuration arguments into the CLI argument list."""
config_index = cli_args.index("--config")
# Split arguments around config file
before_config = cli_args[:config_index]
after_config = cli_args[config_index + 2 :] # Skip --config and file path
# Simple merge: config args + CLI args
return config_args + before_config + after_config
def _parse_yaml_config(self, file_path: str) -> List[str]:
"""
Parse YAML configuration file and convert to argument list.
Args:
file_path: Path to the YAML configuration file
Returns:
List of arguments in format ['--key', 'value', ...]
Raises:
ValueError: If file is not YAML or cannot be read
"""
self._validate_yaml_file(file_path)
try:
with open(file_path, "r") as file:
config_data = yaml.safe_load(file)
except Exception as e:
logger.error(f"Failed to read config file {file_path}: {e}")
raise
# Handle empty files or None content
if config_data is None:
config_data = {}
if not isinstance(config_data, dict):
raise ValueError("Config file must contain a dictionary at root level")
return self._convert_config_to_args(config_data)
def _validate_yaml_file(self, file_path: str) -> None:
"""Validate that the file is a YAML file."""
path = Path(file_path)
if path.suffix.lower() not in [".yaml", ".yml"]:
raise ValueError(f"Config file must be YAML format, got: {path.suffix}")
if not path.exists():
raise ValueError(f"Config file not found: {file_path}")
def _convert_config_to_args(self, config: Dict[str, Any]) -> List[str]:
"""Convert configuration dictionary to argument list."""
args = []
for key, value in config.items():
if isinstance(value, bool):
self._add_boolean_arg(args, key, value)
elif isinstance(value, list):
self._add_list_arg(args, key, value)
else:
self._add_scalar_arg(args, key, value)
return args
def _add_boolean_arg(self, args: List[str], key: str, value: bool) -> None:
"""Add boolean argument to the list."""
if key in self.boolean_actions:
# For boolean actions, always add the flag and value
args.extend([f"--{key}", str(value).lower()])
else:
# For regular booleans, only add flag if True
if value:
args.append(f"--{key}")
def _add_list_arg(self, args: List[str], key: str, value: List[Any]) -> None:
"""Add list argument to the list."""
if value: # Only add if list is not empty
args.append(f"--{key}")
args.extend(str(item) for item in value)
def _add_scalar_arg(self, args: List[str], key: str, value: Any) -> None:
"""Add scalar argument to the list."""
args.extend([f"--{key}", str(value)])