sglang / Commits / ea338676

Unverified commit ea338676, authored Sep 23, 2025 by Lianmin Zheng, committed via GitHub on Sep 23, 2025.
Parent: b06db198

Clean up server args (#10770)

Showing 1 changed file with 190 additions and 238 deletions:
python/sglang/srt/server_args.py (+190, -238)
@@ -19,8 +19,6 @@ import json
 import logging
 import os
 import random
-import socket
-import sys
 import tempfile
 from typing import List, Literal, Optional, Union
@@ -328,6 +326,10 @@ class ServerArgs:
     deepep_config: Optional[str] = None
     moe_dense_tp_size: Optional[int] = None

+    # Mamba cache
+    max_mamba_cache_size: Optional[int] = None
+    mamba_ssm_dtype: str = "float32"
+
     # Hierarchical cache
     enable_hierarchical_cache: bool = False
     hicache_ratio: float = 2.0
@@ -398,6 +400,7 @@ class ServerArgs:
     enable_return_hidden_states: bool = False
     scheduler_recv_interval: int = 1
     numa_node: Optional[List[int]] = None
+    enable_deterministic_inference: bool = False

     # Dynamic batch tokenizer
     enable_dynamic_batch_tokenizer: bool = False
@@ -419,15 +422,12 @@ class ServerArgs:
     disaggregation_prefill_pp: Optional[int] = 1
     disaggregation_ib_device: Optional[str] = None
     num_reserved_decode_tokens: int = 512  # used for decode kv cache offload in PD
     # FIXME: hack to reduce ITL when decode bs is small
     disaggregation_decode_polling_interval: int = 1

-    # For model weight update
+    # For model weight update and weight loading
     custom_weight_loader: Optional[List[str]] = None
     weight_loader_disable_mmap: bool = False
-
-    # Remote instance weight loading
     remote_instance_weight_loader_seed_instance_ip: Optional[str] = None
     remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None
     remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None
@@ -436,58 +436,84 @@ class ServerArgs:
     enable_pdmux: bool = False
     sm_group_num: int = 3

-    # Mamba cache
-    max_mamba_cache_size: Optional[int] = None
-    mamba_ssm_dtype: str = "float32"
-
-    # For deterministic inference
-    enable_deterministic_inference: bool = False
-
-    # Deprecated arguments
-    enable_ep_moe: bool = False
-    enable_deepep_moe: bool = False
-    enable_flashinfer_cutlass_moe: bool = False
-    enable_flashinfer_cutedsl_moe: bool = False
-    enable_flashinfer_trtllm_moe: bool = False
-    enable_triton_kernel_moe: bool = False
-    enable_flashinfer_mxfp4_moe: bool = False
+    def __post_init__(self):
+        """
+        Orchestrates the handling of various server arguments, ensuring proper configuration and validation.
+        """
+        # Handle deprecated arguments.
+        self._handle_deprecated_args()
+
+        # Set missing default values.
+        self._handle_missing_default_values()
+
+        # Get GPU memory capacity, which is a common dependency for several configuration steps.
+        gpu_mem = get_device_memory_capacity(self.device)
+
+        # Handle memory-related configurations.
+        self._handle_mem_fraction_static(gpu_mem)
+        self._handle_chunked_prefill_size(gpu_mem)
+
+        # Handle CUDA graph settings.
+        self._handle_cuda_graph_max_bs(gpu_mem)
+
+        # Handle device-specific backends.
+        self._handle_hpu_backends()
+        self._handle_cpu_backends()
+
+        # Apply model-specific adjustments.
+        self._handle_model_specific_adjustments()
+
+        # Set kernel backends.
+        self._handle_sampling_backend()
+        self._handle_attention_backend_compatibility()
+        self._handle_page_size()
+        self._handle_amd_specifics()
+        self._handle_grammar_backend()
+
+        # Handle data parallelism.
+        self._handle_data_parallelism()
+
+        # Handle MoE configurations.
+        self._handle_moe_kernel_config()
+        self._handle_deepep_moe()
+        self._handle_eplb_and_dispatch()
+        self._handle_expert_distribution_metrics()
+
+        # Handle pipeline parallelism.
+        self._handle_pipeline_parallelism()
+
+        # Handle Hicache settings.
+        self._handle_hicache()
+
+        # Handle speculative decoding logic.
+        self._handle_speculative_decoding()
+
+        # Handle model loading format.
+        self._handle_load_format()
+
+        # Handle PD disaggregation.
+        self._handle_disaggregation()
+
+        # Validate tokenizer settings.
+        self._handle_tokenizer_batching()
+
+        # Propagate environment variables.
+        self._handle_environment_variables()
+
+        # Validate cache settings.
+        self._handle_cache_compatibility()
+
+        # Validate metrics labels.
+        self._handle_metrics_labels()
+
+        # Handle deterministic inference.
+        self._handle_deterministic_inference()
+
+        # Handle any other necessary validations.
+        self._handle_other_validations()

     def _handle_deprecated_args(self):
-        if self.enable_ep_moe:
-            self.ep_size = self.tp_size
-            print_deprecated_warning(
-                "NOTE: --enable-ep-moe is deprecated. Please set `--ep-size` to the same value as `--tp-size` instead."
-            )
-        if self.enable_deepep_moe:
-            self.moe_a2a_backend = "deepep"
-            print_deprecated_warning(
-                "NOTE: --enable-deepep-moe is deprecated. Please set `--moe-a2a-backend` to 'deepep' instead."
-            )
-        if self.enable_triton_kernel_moe:
-            self.moe_runner_backend = "triton_kernel"
-            print_deprecated_warning(
-                "NOTE: --enable-triton-kernel-moe is deprecated. Please set `--moe-runner-backend` to 'triton_kernel' instead."
-            )
-        if self.enable_flashinfer_cutedsl_moe:
-            self.moe_runner_backend = "flashinfer_cutedsl"
-            print_deprecated_warning(
-                "NOTE: --enable-flashinfer-cutedsl-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_cutedsl' instead."
-            )
-        if self.enable_flashinfer_cutlass_moe:
-            self.moe_runner_backend = "flashinfer_cutlass"
-            print_deprecated_warning(
-                "NOTE: --enable-flashinfer-cutlass-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_cutlass' instead."
-            )
-        if self.enable_flashinfer_trtllm_moe:
-            self.moe_runner_backend = "flashinfer_trtllm"
-            print_deprecated_warning(
-                "NOTE: --enable-flashinfer-trtllm-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_trtllm' instead."
-            )
-        if self.enable_flashinfer_mxfp4_moe:
-            self.moe_runner_backend = "flashinfer_mxfp4"
-            print_deprecated_warning(
-                "NOTE: --enable-flashinfer-mxfp4-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_mxfp4' instead."
-            )
+        pass

     def _handle_missing_default_values(self):
         if self.tokenizer_path is None:
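The hunk above captures the core of the cleanup: __post_init__ now only dispatches, in a fixed order, to small private _handle_* methods, and the old runtime translation of the deprecated MoE flags shrinks to a stub because those flags are handled by DeprecatedAction at the argparse level (see the CLI hunk further down). The sketch below is not part of the commit; it illustrates the same dispatch pattern with toy fields and a hypothetical memory query standing in for the real helpers:

from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyArgs:
    # Toy stand-ins for real server arguments.
    mem_fraction_static: Optional[float] = None
    chunked_prefill_size: Optional[int] = None

    def __post_init__(self):
        # Dispatch to small handlers in a fixed order, mirroring ServerArgs.__post_init__.
        gpu_mem = self._get_device_memory_capacity()  # shared dependency, computed once
        self._handle_mem_fraction_static(gpu_mem)
        self._handle_chunked_prefill_size(gpu_mem)

    def _get_device_memory_capacity(self) -> int:
        # Hypothetical fixed value in MiB; the real code queries the device.
        return 81920

    def _handle_mem_fraction_static(self, gpu_mem: int):
        if self.mem_fraction_static is None:
            self.mem_fraction_static = 0.9 if gpu_mem >= 40960 else 0.85

    def _handle_chunked_prefill_size(self, gpu_mem: int):
        if self.chunked_prefill_size is None:
            self.chunked_prefill_size = 8192 if gpu_mem >= 40960 else 2048


print(ToyArgs())  # handlers fill in the unset defaults

Each handler owns a single concern, which keeps the post-init sequence readable and the individual steps testable in isolation.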
@@ -590,6 +616,84 @@ class ServerArgs:
             self.attention_backend = "intel_amx"
             self.sampling_backend = "pytorch"

+    def _handle_model_specific_adjustments(self):
+        if parse_connector_type(self.model_path) == ConnectorType.INSTANCE:
+            return
+
+        hf_config = self.get_hf_config()
+        model_arch = hf_config.architectures[0]
+        if model_arch in ["GptOssForCausalLM"]:
+            if self.attention_backend is None:
+                if is_cuda() and is_sm100_supported():
+                    self.attention_backend = "trtllm_mha"
+                elif is_cuda() and is_sm90_supported():
+                    self.attention_backend = "fa3"
+                else:
+                    self.attention_backend = "triton"
+            supported_backends = ["triton", "trtllm_mha", "fa3"]
+            logger.info(
+                f"Use {self.attention_backend} as attention backend for GptOssForCausalLM"
+            )
+            assert (
+                self.attention_backend in supported_backends
+            ), f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got '{self.attention_backend}'"
+
+            if is_sm100_supported():
+                if not self.enable_dp_attention:
+                    self.enable_flashinfer_allreduce_fusion = True
+                    logger.info(
+                        "Enable FlashInfer AllReduce Fusion on sm100 for GptOssForCausalLM"
+                    )
+            quantization_config = getattr(hf_config, "quantization_config", None)
+            is_mxfp4_quant_format = (
+                quantization_config is not None
+                and quantization_config.get("quant_method") == "mxfp4"
+            )
+
+            if is_sm100_supported() and is_mxfp4_quant_format:
+                self.moe_runner_backend = "flashinfer_mxfp4"
+                logger.warning(
+                    "Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
+                )
+            else:
+                if self.moe_runner_backend == "triton_kernel":
+                    assert (
+                        self.ep_size == 1
+                    ), "Triton kernel MoE is only supported when ep_size == 1"
+                if (
+                    self.moe_runner_backend == "auto"
+                    and self.ep_size == 1
+                    and is_triton_kernels_available()
+                ):
+                    self.moe_runner_backend = "triton_kernel"
+                    logger.warning(
+                        "Detected GPT-OSS model, enabling triton_kernels MOE kernel."
+                    )
+            self.disable_hybrid_swa_memory = True
+            if is_mxfp4_quant_format:
+                # use bf16 for mxfp4 triton kernels
+                self.dtype = "bfloat16"
+        elif "Llama4" in model_arch and self.device != "cpu":
+            assert self.attention_backend in {
+                "fa3",
+                "aiter",
+                "triton",
+            }, "fa3, aiter, or triton is required for Llama4 model"
+        elif model_arch in [
+            "Gemma2ForCausalLM",
+            "Gemma3ForCausalLM",
+            "Gemma3ForConditionalGeneration",
+            "Gemma3nForCausalLM",
+            "Gemma3nForConditionalGeneration",
+        ]:
+            # FIXME: https://github.com/sgl-project/sglang/pull/7367 is not compatible with gemma2 model.
+            # It failed at this test: https://github.com/sgl-project/sglang/actions/runs/16255155597/job/45890331952#step:4:736
+            logger.warning(
+                f"Disable hybrid SWA memory for {model_arch} as it is not yet supported."
+            )
+            self.disable_hybrid_swa_memory = True
+
     def _handle_sampling_backend(self):
         if self.sampling_backend is None:
             self.sampling_backend = (
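The new _handle_model_specific_adjustments above absorbs the old model_specific_adjustments body (removed near the end of this diff) and adds an early return for instance connectors. Its MXFP4 branch only inspects the Hugging Face config; the same check can be reproduced standalone, with a made-up SimpleNamespace in place of a real GPT-OSS checkpoint config (illustration only, not part of the commit):

from types import SimpleNamespace

# Hypothetical HF config for a GPT-OSS-style checkpoint quantized with MXFP4.
hf_config = SimpleNamespace(
    architectures=["GptOssForCausalLM"],
    quantization_config={"quant_method": "mxfp4"},
)

quantization_config = getattr(hf_config, "quantization_config", None)
is_mxfp4_quant_format = (
    quantization_config is not None
    and quantization_config.get("quant_method") == "mxfp4"
)
# On SM100 this condition flips moe_runner_backend to "flashinfer_mxfp4".
print(is_mxfp4_quant_format)  # True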
@@ -1014,83 +1118,6 @@ class ServerArgs:
     def _handle_other_validations(self):
         pass

-    def __post_init__(self):
-        """
-        Orchestrates the handling of various server arguments, ensuring proper configuration and validation.
-        """
-        # Step 1: Handle deprecated arguments.
-        self._handle_deprecated_args()
-
-        # Step 2: Set missing default values.
-        self._handle_missing_default_values()
-
-        # Get GPU memory capacity, which is a common dependency for several configuration steps.
-        gpu_mem = get_device_memory_capacity(self.device)
-
-        # Step 3: Handle memory-related configurations.
-        self._handle_mem_fraction_static(gpu_mem)
-        self._handle_chunked_prefill_size(gpu_mem)
-
-        # Step 4: Handle CUDA graph settings.
-        self._handle_cuda_graph_max_bs(gpu_mem)
-
-        # Step 5: Handle device-specific backends.
-        self._handle_hpu_backends()
-        self._handle_cpu_backends()
-
-        # Step 6: Apply model-specific adjustments.
-        if parse_connector_type(self.model_path) != ConnectorType.INSTANCE:
-            self.model_specific_adjustments()
-
-        # Step 7: Set kernel backends.
-        self._handle_sampling_backend()
-        self._handle_attention_backend_compatibility()
-        self._handle_page_size()
-        self._handle_amd_specifics()
-        self._handle_grammar_backend()
-
-        # Step 8: Handle data parallelism.
-        self._handle_data_parallelism()
-
-        # Step 9: Handle MoE configurations.
-        self._handle_moe_kernel_config()
-        self._handle_deepep_moe()
-        self._handle_eplb_and_dispatch()
-        self._handle_expert_distribution_metrics()
-
-        # Step 10: Handle pipeline parallelism.
-        self._handle_pipeline_parallelism()
-
-        # Step 11: Handle Hicache settings.
-        self._handle_hicache()
-
-        # Step 12: Handle speculative decoding logic.
-        self._handle_speculative_decoding()
-
-        # Step 13: Handle model loading format.
-        self._handle_load_format()
-
-        # Step 14: Handle PD disaggregation.
-        self._handle_disaggregation()
-
-        # Step 15: Validate tokenizer settings.
-        self._handle_tokenizer_batching()
-
-        # Step 16: Propagate environment variables.
-        self._handle_environment_variables()
-
-        # Step 17: Validate cache settings.
-        self._handle_cache_compatibility()
-
-        # Step 18: Validate metrics labels.
-        self._handle_metrics_labels()
-
-        # Step 19: Handle deterministic inference.
-        self._handle_deterministic_inference()
-
-        # Step 20: Handle any other necessary validations.
-        self._handle_other_validations()
-
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         # Model and tokenizer
@@ -1101,24 +1128,6 @@ class ServerArgs:
             help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
             required=True,
         )
-        parser.add_argument(
-            "--remote-instance-weight-loader-seed-instance-ip",
-            type=str,
-            default=ServerArgs.remote_instance_weight_loader_seed_instance_ip,
-            help="The ip of the seed instance for loading weights from remote instance.",
-        )
-        parser.add_argument(
-            "--remote-instance-weight-loader-seed-instance-service-port",
-            type=int,
-            default=ServerArgs.remote_instance_weight_loader_seed_instance_service_port,
-            help="The service port of the seed instance for loading weights from remote instance.",
-        )
-        parser.add_argument(
-            "--remote-instance-weight-loader-send-weights-group-ports",
-            type=json_list_type,
-            default=ServerArgs.remote_instance_weight_loader_send_weights_group_ports,
-            help="The communication group ports for loading weights from remote instance.",
-        )
         parser.add_argument(
             "--tokenizer-path",
             type=str,
@@ -2573,6 +2582,24 @@ class ServerArgs:
             action="store_true",
             help="Disable mmap while loading weight using safetensors.",
         )
+        parser.add_argument(
+            "--remote-instance-weight-loader-seed-instance-ip",
+            type=str,
+            default=ServerArgs.remote_instance_weight_loader_seed_instance_ip,
+            help="The ip of the seed instance for loading weights from remote instance.",
+        )
+        parser.add_argument(
+            "--remote-instance-weight-loader-seed-instance-service-port",
+            type=int,
+            default=ServerArgs.remote_instance_weight_loader_seed_instance_service_port,
+            help="The service port of the seed instance for loading weights from remote instance.",
+        )
+        parser.add_argument(
+            "--remote-instance-weight-loader-send-weights-group-ports",
+            type=json_list_type,
+            default=ServerArgs.remote_instance_weight_loader_send_weights_group_ports,
+            help="The communication group ports for loading weights from remote instance.",
+        )

         # For PD-Multiplexing
         parser.add_argument(
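The three remote-instance options above are removed from their old spot after --model-path (earlier hunk) and re-added here next to the other weight-loading flags, with their definitions unchanged. The group-ports flag expects a JSON list via json_list_type, whose implementation is outside this diff, so the sketch below (not part of the commit) uses a json.loads-based stand-in just to show the expected CLI shape:

import argparse
import json


def json_list_type(value):
    # Stand-in for sglang's helper; assumed to parse a JSON array from the CLI string.
    return json.loads(value)


parser = argparse.ArgumentParser()
parser.add_argument("--remote-instance-weight-loader-seed-instance-ip", type=str, default=None)
parser.add_argument("--remote-instance-weight-loader-seed-instance-service-port", type=int, default=None)
parser.add_argument("--remote-instance-weight-loader-send-weights-group-ports", type=json_list_type, default=None)

args = parser.parse_args([
    "--remote-instance-weight-loader-seed-instance-ip", "10.0.0.1",
    "--remote-instance-weight-loader-seed-instance-service-port", "30000",
    "--remote-instance-weight-loader-send-weights-group-ports", "[31000, 31001]",
])
print(args.remote_instance_weight_loader_send_weights_group_ports)  # [31000, 31001]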
@@ -2598,38 +2625,38 @@ class ServerArgs:
         # Deprecated arguments
         parser.add_argument(
             "--enable-ep-moe",
-            action="store_true",
-            help="(Deprecated) Enabling expert parallelism for moe. The ep size is equal to the tp size.",
+            action=DeprecatedAction,
+            help="NOTE: --enable-ep-moe is deprecated. Please set `--ep-size` to the same value as `--tp-size` instead.",
         )
         parser.add_argument(
             "--enable-deepep-moe",
-            action="store_true",
-            help="(Deprecated) Enabling DeepEP MoE implementation for EP MoE.",
+            action=DeprecatedAction,
+            help="NOTE: --enable-deepep-moe is deprecated. Please set `--moe-a2a-backend` to 'deepep' instead.",
         )
         parser.add_argument(
             "--enable-flashinfer-cutlass-moe",
-            action="store_true",
-            help="(Deprecated) Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
+            action=DeprecatedAction,
+            help="NOTE: --enable-flashinfer-cutlass-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_cutlass' instead.",
         )
         parser.add_argument(
             "--enable-flashinfer-cutedsl-moe",
-            action="store_true",
-            help="(Deprecated) Enable FlashInfer CuteDSL MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
+            action=DeprecatedAction,
+            help="NOTE: --enable-flashinfer-cutedsl-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_cutedsl' instead.",
         )
         parser.add_argument(
             "--enable-flashinfer-trtllm-moe",
-            action="store_true",
-            help="(Deprecated) Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP",
+            action=DeprecatedAction,
+            help="NOTE: --enable-flashinfer-trtllm-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_trtllm' instead.",
         )
         parser.add_argument(
             "--enable-triton-kernel-moe",
-            action="store_true",
-            help="(Deprecated) Use triton moe grouped gemm kernel.",
+            action=DeprecatedAction,
+            help="NOTE: --enable-triton-kernel-moe is deprecated. Please set `--moe-runner-backend` to 'triton_kernel' instead.",
        )
         parser.add_argument(
             "--enable-flashinfer-mxfp4-moe",
-            action="store_true",
-            help="(Deprecated) Enable FlashInfer MXFP4 MoE backend for modelopt_fp4 quant on Blackwell.",
+            action=DeprecatedAction,
+            help="NOTE: --enable-flashinfer-mxfp4-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_mxfp4' instead.",
         )

     @classmethod
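In the hunk above, the deprecated MoE flags keep their names but switch from action="store_true" to action=DeprecatedAction, and the old "(Deprecated) ..." help strings are replaced by the migration notes that _handle_deprecated_args used to print at runtime. DeprecatedAction itself is defined outside this diff; as a rough, assumed illustration (the real class may warn differently or reject the flag outright), a minimal argparse Action with that name could look like:

import argparse
import warnings


class DeprecatedAction(argparse.Action):
    """Hypothetical stand-in: emit the help text as a warning and store nothing."""

    def __init__(self, option_strings, dest, nargs=0, **kwargs):
        super().__init__(option_strings, dest, nargs=nargs, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        warnings.warn(self.help or f"{option_string} is deprecated.")


parser = argparse.ArgumentParser()
parser.add_argument(
    "--enable-ep-moe",
    action=DeprecatedAction,
    help="NOTE: --enable-ep-moe is deprecated. Please set `--ep-size` to the same value as `--tp-size` instead.",
)
parser.parse_args(["--enable-ep-moe"])  # emits the deprecation warning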
@@ -2862,81 +2889,6 @@ class ServerArgs:
             val >= 0 for val in bucket_values
         ), f"{arg_name} customer rule bucket values should be non-negative"

-    def model_specific_adjustments(self):
-        hf_config = self.get_hf_config()
-        model_arch = hf_config.architectures[0]
-        if model_arch in ["GptOssForCausalLM"]:
-            if self.attention_backend is None:
-                if is_cuda() and is_sm100_supported():
-                    self.attention_backend = "trtllm_mha"
-                elif is_cuda() and is_sm90_supported():
-                    self.attention_backend = "fa3"
-                else:
-                    self.attention_backend = "triton"
-            supported_backends = ["triton", "trtllm_mha", "fa3"]
-            logger.info(
-                f"Use {self.attention_backend} as attention backend for GptOssForCausalLM"
-            )
-            assert (
-                self.attention_backend in supported_backends
-            ), f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got '{self.attention_backend}'"
-
-            if is_sm100_supported():
-                if not self.enable_dp_attention:
-                    self.enable_flashinfer_allreduce_fusion = True
-                    logger.info(
-                        "Enable FlashInfer AllReduce Fusion on sm100 for GptOssForCausalLM"
-                    )
-            quantization_config = getattr(hf_config, "quantization_config", None)
-            is_mxfp4_quant_format = (
-                quantization_config is not None
-                and quantization_config.get("quant_method") == "mxfp4"
-            )
-
-            if is_sm100_supported() and is_mxfp4_quant_format:
-                self.moe_runner_backend = "flashinfer_mxfp4"
-                logger.warning(
-                    "Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
-                )
-            else:
-                if self.moe_runner_backend == "triton_kernel":
-                    assert (
-                        self.ep_size == 1
-                    ), "Triton kernel MoE is only supported when ep_size == 1"
-                if (
-                    self.moe_runner_backend == "auto"
-                    and self.ep_size == 1
-                    and is_triton_kernels_available()
-                ):
-                    self.moe_runner_backend = "triton_kernel"
-                    logger.warning(
-                        "Detected GPT-OSS model, enabling triton_kernels MOE kernel."
-                    )
-            self.disable_hybrid_swa_memory = True
-            if is_mxfp4_quant_format:
-                # use bf16 for mxfp4 triton kernels
-                self.dtype = "bfloat16"
-        elif "Llama4" in model_arch and self.device != "cpu":
-            assert self.attention_backend in {
-                "fa3",
-                "aiter",
-                "triton",
-            }, "fa3, aiter, or triton is required for Llama4 model"
-        elif model_arch in [
-            "Gemma2ForCausalLM",
-            "Gemma3ForCausalLM",
-            "Gemma3ForConditionalGeneration",
-            "Gemma3nForCausalLM",
-            "Gemma3nForConditionalGeneration",
-        ]:
-            # FIXME: https://github.com/sgl-project/sglang/pull/7367 is not compatible with gemma2 model.
-            # It failed at this test: https://github.com/sgl-project/sglang/actions/runs/16255155597/job/45890331952#step:4:736
-            logger.warning(
-                f"Disable hybrid SWA memory for {model_arch} as it is not yet supported."
-            )
-            self.disable_hybrid_swa_memory = True
-
     def adjust_mem_fraction_for_vlm(self, model_config):
         vision_config = getattr(model_config.hf_config, "vision_config", None)
         if vision_config is None: