sglang · Commit 1083e7e3 (unverified)

Deprecate `global_server_args_dict` (#11331)

Authored Oct 13, 2025 by Liangsheng Yin; committed via GitHub on Oct 13, 2025.
Parent: 2157d12a

Changes: 54 files in the commit; showing 20 changed files with 105 additions and 131 deletions (+105 −131).
python/sglang/srt/mem_cache/common.py               +4   −8
python/sglang/srt/model_executor/model_runner.py    +16  −21
python/sglang/srt/model_loader/loader.py            +6   −13
python/sglang/srt/models/apertus.py                 +2   −3
python/sglang/srt/models/arcee.py                   +2   −2
python/sglang/srt/models/bailing_moe.py             +6   −6
python/sglang/srt/models/bailing_moe_nextn.py       +3   −4
python/sglang/srt/models/deepseek_nextn.py          +2   −2
python/sglang/srt/models/deepseek_v2.py             +23  −21
python/sglang/srt/models/falcon_h1.py               +2   −2
python/sglang/srt/models/glm4_moe.py                +8   −12
python/sglang/srt/models/glm4_moe_nextn.py          +2   −2
python/sglang/srt/models/glm4v_moe.py               +5   −5
python/sglang/srt/models/gpt_oss.py                 +4   −4
python/sglang/srt/models/grok.py                    +5   −7
python/sglang/srt/models/llama.py                   +2   −2
python/sglang/srt/models/longcat_flash.py           +3   −7
python/sglang/srt/models/mllama4.py                 +2   −2
python/sglang/srt/models/qwen2_moe.py               +4   −4
python/sglang/srt/models/qwen3_moe.py               +4   −4
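The migration is mechanical across every file below: each read of the mutable module-level `global_server_args_dict` becomes an attribute access on the object returned by `get_global_server_args()`. A minimal before/after sketch of the pattern (illustrative lines, not taken verbatim from any one file):

    # Before: dict lookup; a mistyped key fails only at runtime with a KeyError,
    # and the dict held a hand-maintained subset of the real ServerArgs fields.
    backend = global_server_args_dict["attention_backend"]

    # After: attribute access on the live ServerArgs object; typos surface as
    # AttributeErrors and type checkers can see the field.
    backend = get_global_server_args().attention_backend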
python/sglang/srt/mem_cache/common.py

@@ -11,7 +11,7 @@ from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
-from sglang.srt.server_args import ServerArgs
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import support_triton

 if TYPE_CHECKING:
@@ -19,10 +19,6 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)

-GLOBAL_SERVER_ARGS_KEYS = ["attention_backend"]
-
-global_server_args_dict = {k: getattr(ServerArgs, k) for k in GLOBAL_SERVER_ARGS_KEYS}
-
 @triton.jit
 def write_req_to_token_pool_triton(
@@ -88,7 +84,7 @@ def write_cache_indices(
     prefix_tensors: list[torch.Tensor],
     req_to_token_pool: ReqToTokenPool,
 ):
-    if support_triton(global_server_args_dict.get("attention_backend")):
+    if support_triton(get_global_server_args().attention_backend):
         prefix_pointers = torch.tensor(
             [t.data_ptr() for t in prefix_tensors],
             device=req_to_token_pool.device,
@@ -129,8 +125,8 @@ def get_last_loc(
     prefix_lens_tensor: torch.Tensor,
 ) -> torch.Tensor:
     if (
-        global_server_args_dict["attention_backend"] != "ascend"
-        and global_server_args_dict["attention_backend"] != "torch_native"
+        get_global_server_args().attention_backend != "ascend"
+        and get_global_server_args().attention_backend != "torch_native"
     ):
         impl = get_last_loc_triton
     else:
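The hunk above also deletes the last module-level mirror (`GLOBAL_SERVER_ARGS_KEYS` and the dict comprehension over `ServerArgs` defaults). For orientation, the accessor being migrated to can be pictured as a small singleton module; the following is only a sketch under assumed names, not the actual `sglang.srt.server_args` source:

    # Hypothetical stand-in for sglang.srt.server_args: a process-wide
    # ServerArgs singleton behind the two functions the diffs import.
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class ServerArgs:  # placeholder; the real class carries many more fields
        attention_backend: str = "triton"
        prefill_attention_backend: Optional[str] = None
        decode_attention_backend: Optional[str] = None
        enable_dp_lm_head: bool = False

    _global_server_args: Optional[ServerArgs] = None

    def set_global_server_args_for_scheduler(server_args: ServerArgs) -> None:
        """Install the ServerArgs for this process; called once at scheduler startup."""
        global _global_server_args
        _global_server_args = server_args

    def get_global_server_args() -> ServerArgs:
        """Return the process-wide ServerArgs; assumes it has been set."""
        assert _global_server_args is not None, "global server args not initialized"
        return _global_server_args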
python/sglang/srt/model_executor/model_runner.py

@@ -83,10 +83,6 @@ from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.lora.lora_registry import LoRARef
-from sglang.srt.managers.schedule_batch import (
-    GLOBAL_SERVER_ARGS_KEYS,
-    global_server_args_dict,
-)
 from sglang.srt.mem_cache.allocator import (
     BaseTokenToKVPoolAllocator,
     PagedTokenToKVPoolAllocator,
@@ -125,7 +121,11 @@ from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.server_args import ServerArgs
+from sglang.srt.server_args import (
+    ServerArgs,
+    get_global_server_args,
+    set_global_server_args_for_scheduler,
+)
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     MultiprocessingSerializer,
@@ -275,15 +275,12 @@ class ModelRunner:
         # Model-specific adjustment
         self.model_specific_adjustment()

-        # Global vars
-        global_server_args_dict.update(
-            {k: getattr(server_args, k) for k in GLOBAL_SERVER_ARGS_KEYS}
-            | {
-                # TODO it is indeed not a "server args"
-                "use_mla_backend": self.use_mla_backend,
-                "speculative_algorithm": self.spec_algorithm,
-            }
-        )
+        # Set the global server_args in the scheduler process
+        set_global_server_args_for_scheduler(server_args)
+        global_server_args = get_global_server_args()
+        # FIXME: hacky set `use_mla_backend`
+        global_server_args.use_mla_backend = self.use_mla_backend

         # Init OpenMP threads binding for CPU
         if self.device == "cpu":
@@ -432,7 +429,7 @@ class ModelRunner:
             # In layered loading, torchao may have been applied
             if not torchao_applied:
                 apply_torchao_config_to_model(
-                    self.model, global_server_args_dict["torchao_config"]
+                    self.model, get_global_server_args().torchao_config
                 )

         # Apply torch TP if the model supports it
@@ -1838,12 +1835,10 @@ class ModelRunner:
             self.server_args.attention_backend
         )
-        global_server_args_dict.update(
-            {
-                "decode_attention_backend": self.decode_attention_backend_str,
-                "prefill_attention_backend": self.prefill_attention_backend_str,
-            }
-        )
+        (
+            get_global_server_args().prefill_attention_backend,
+            get_global_server_args().decode_attention_backend,
+        ) = (self.prefill_attention_backend_str, self.decode_attention_backend_str)
         return attn_backend

     def _get_attention_backend_from_str(self, backend_str: str):
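Two details here are worth calling out. First, the initialization order: `set_global_server_args_for_scheduler(server_args)` runs in `ModelRunner.__init__` before any model code calls `get_global_server_args()`, and `use_mla_backend` is grafted onto the object afterwards (the FIXME marks it as not really a server arg). Second, the last hunk replaces a dict `update` with a parallel attribute assignment. A usage sketch, reusing the placeholder module sketched after the first file above (not real sglang code):

    # Assumed placeholder ServerArgs from the earlier sketch.
    args = ServerArgs(attention_backend="flashinfer")
    set_global_server_args_for_scheduler(args)

    # Later, anywhere in the same process (model code, cache helpers, ...):
    ga = get_global_server_args()
    print(ga.attention_backend)  # -> "flashinfer"

    # The parallel assignment from the last hunk sets two fields at once:
    (ga.prefill_attention_backend, ga.decode_attention_backend) = ("fa3", "flashinfer")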
python/sglang/srt/model_loader/loader.py

@@ -4,7 +4,6 @@ from __future__ import annotations
 # ruff: noqa: SIM117
 import collections
-import concurrent
 import dataclasses
 import fnmatch
 import glob
@@ -12,12 +11,10 @@ import json
 import logging
 import math
 import os
-import re
 import socket
 import threading
 import time
 from abc import ABC, abstractmethod
-from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager, suppress
 from typing import (
     TYPE_CHECKING,
@@ -33,10 +30,10 @@ from typing import (
 import huggingface_hub
 import numpy as np
-import requests
-import safetensors.torch
 import torch

+from sglang.srt.server_args import get_global_server_args
+
 # Try to import accelerate (optional dependency)
 try:
     from accelerate import infer_auto_device_map, init_empty_weights
@@ -81,8 +78,6 @@ DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION = (
     0.8  # Reserve 20% GPU memory headroom for ModelOpt calibration
 )
 from sglang.srt.model_loader.weight_utils import (
-    _BAR_FORMAT,
-    default_weight_loader,
     download_safetensors_index_file_from_hf,
     download_weights_from_hf,
     filter_duplicate_safetensors_files,
@@ -445,10 +440,8 @@ class DefaultModelLoader(BaseModelLoader):
                 hf_weights_files,
             )
         elif use_safetensors:
-            from sglang.srt.managers.schedule_batch import global_server_args_dict
-
-            weight_loader_disable_mmap = global_server_args_dict.get(
-                "weight_loader_disable_mmap"
+            weight_loader_disable_mmap = (
+                get_global_server_args().weight_loader_disable_mmap
             )
             if extra_config.get("enable_multithread_load"):
@@ -616,9 +609,9 @@ class LayeredModelLoader(DefaultModelLoader):
         device_config: DeviceConfig,
     ) -> nn.Module:
         from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
-        from sglang.srt.managers.schedule_batch import global_server_args_dict
+        from sglang.srt.server_args import get_global_server_args

-        torchao_config = global_server_args_dict.get("torchao_config")
+        torchao_config = get_global_server_args().torchao_config
         target_device = torch.device(device_config.device)

         with set_default_torch_dtype(model_config.dtype):
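Note that both deleted call sites in this file imported `global_server_args_dict` inside the function body rather than at module top, presumably to dodge a circular import with `sglang.srt.managers.schedule_batch`; `sglang.srt.server_args` has no such back-edge, which is what lets the replacement import sit at the top of the module. The general idiom, as a schematic two-module sketch (hypothetical files, not the actual sglang import graph):

    # a.py -- executes `import b` at its top, so b starts loading midway through a
    import b
    X = 1

    # b.py
    # import a       # top-level back-edge would observe a half-initialized 'a'
    def use_a():
        import a     # deferred to call time: 'a' is complete before this runs
        return a.X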
python/sglang/srt/models/apertus.py

@@ -46,15 +46,14 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     kv_cache_scales_loader,
     maybe_remap_kv_scale_name,
 )
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import add_prefix, make_layers
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)
@@ -447,7 +446,7 @@ class ApertusForCausalLM(nn.Module):
                 config.hidden_size,
                 quant_config=quant_config,
                 prefix=add_prefix("lm_head", prefix),
-                use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+                use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
             )
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
python/sglang/srt/models/arcee.py

@@ -42,13 +42,13 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     kv_cache_scales_loader,
     maybe_remap_kv_scale_name,
 )
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import add_prefix, make_layers

 logger = logging.getLogger(__name__)
@@ -407,7 +407,7 @@ class ArceeForCausalLM(nn.Module):
                 config.hidden_size,
                 quant_config=quant_config,
                 prefix=add_prefix("lm_head", prefix),
-                use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+                use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
             )
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
python/sglang/srt/models/bailing_moe.py

@@ -17,7 +17,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""
-SGLang BailingMoE model."""
+"""SGLang BailingMoE model."""
 import logging
 from typing import Any, Dict, Iterable, Optional, Tuple, Union
@@ -68,7 +68,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -76,6 +75,7 @@ from sglang.srt.models.utils import (
     create_fused_set_kv_buffer_arg,
     enable_fused_set_kv_buffer,
 )
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty, make_layers

 LoraConfig = None
@@ -204,8 +204,8 @@ class BailingMoESparseMoeBlock(nn.Module):
         else:
             self.router_dtype = torch.bfloat16

-        # TODO global_server_args_dict["ep_num_redundant_experts"] is used for eplb, not supported now
-        assert global_server_args_dict["ep_num_redundant_experts"] == 0
+        # TODO global_server_args.ep_num_redundant_experts is used for eplb, not supported now
+        assert get_global_server_args().ep_num_redundant_experts == 0

         # check group topk
         self.num_expert_group = getattr(config, "n_group", 0)
         self.topk_group = getattr(config, "topk_group", 0)
@@ -220,7 +220,7 @@ class BailingMoESparseMoeBlock(nn.Module):
         self.use_grouped_topk = False
         self.num_experts = (
-            config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
+            config.num_experts + get_global_server_args().ep_num_redundant_experts
         )
         self.gate = BailingMoEGate(
@@ -824,7 +824,7 @@ class BailingMoEForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
python/sglang/srt/models/bailing_moe_nextn.py

@@ -17,7 +17,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""
-SGLang BailingMoENextN model."""
+"""SGLang BailingMoENextN model."""
 import logging
 from typing import Iterable, Optional, Tuple
@@ -29,15 +29,14 @@ from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.bailing_moe import BailingMoEBlock, BailingMoEForCausalLM
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import add_prefix

 LoraConfig = None
@@ -145,7 +144,7 @@ class BailingMoeForCausalLMNextN(BailingMoEForCausalLM):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("model.shared_head.head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
python/sglang/srt/models/deepseek_nextn.py

@@ -30,9 +30,9 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda

 logger = logging.getLogger(__name__)
@@ -152,7 +152,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("model.shared_head.head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
python/sglang/srt/models/deepseek_v2.py

@@ -35,7 +35,6 @@ from sglang.srt.configs.model_config import (
     get_nsa_index_topk,
     is_deepseek_nsa,
 )
-from sglang.srt.debug_utils.dumper import dumper
 from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
     get_pp_group,
@@ -108,10 +107,11 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.single_batch_overlap import SboFlags
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.two_batch_overlap import (
     MaybeTboDeepEPDispatcher,
     model_forward_maybe_tbo,
@@ -520,7 +520,7 @@ class DeepseekV2MoE(nn.Module):
         self.n_shared_experts = config.n_shared_experts
         self.num_fused_shared_experts = (
             0
-            if global_server_args_dict["disable_shared_experts_fusion"]
+            if get_global_server_args().disable_shared_experts_fusion
             else config.n_shared_experts
         )
         self.config = config
@@ -549,7 +549,7 @@ class DeepseekV2MoE(nn.Module):
         self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
-            + global_server_args_dict["ep_num_redundant_experts"],
+            + get_global_server_args().ep_num_redundant_experts,
             num_fused_shared_experts=self.num_fused_shared_experts,
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             hidden_size=config.hidden_size,
@@ -627,7 +627,7 @@ class DeepseekV2MoE(nn.Module):
         self.ep_size = get_moe_expert_parallel_world_size()
         self.num_experts = (
             config.n_routed_experts
-            + global_server_args_dict["ep_num_redundant_experts"]
+            + get_global_server_args().ep_num_redundant_experts
         )
         self.renormalize = config.norm_topk_prob
         self.topk_group = config.topk_group
@@ -1125,7 +1125,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             base=rope_theta,
             rope_scaling=rope_scaling,
             is_neox_style=False,
-            device=global_server_args_dict["device"],
+            device=get_global_server_args().device,
         )
         if rope_scaling:
@@ -1169,12 +1169,12 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.w_scale_v = None
         self.use_deep_gemm_bmm = False

-        self.flashinfer_mla_disable_ragged = global_server_args_dict[
-            "flashinfer_mla_disable_ragged"
-        ]
-        self.disable_chunked_prefix_cache = global_server_args_dict[
-            "disable_chunked_prefix_cache"
-        ]
+        self.flashinfer_mla_disable_ragged = (
+            get_global_server_args().flashinfer_mla_disable_ragged
+        )
+        self.disable_chunked_prefix_cache = (
+            get_global_server_args().disable_chunked_prefix_cache
+        )
         self.current_attention_backend = (
             None  # Attention backend used by current forward batch
@@ -1253,18 +1253,18 @@ class DeepseekV2AttentionMLA(nn.Module):
     ) -> AttnForwardMethod:
         # Determine attention backend used by current forward batch
         if forward_batch.forward_mode.is_decode_or_idle():
-            attention_backend = global_server_args_dict["decode_attention_backend"]
+            attention_backend = get_global_server_args().decode_attention_backend
         elif (
             forward_batch.forward_mode.is_target_verify()
             or forward_batch.forward_mode.is_draft_extend()
         ):
             # Use the specified backend for speculative operations (both verify and draft extend)
-            if global_server_args_dict["speculative_attention_mode"] == "decode":
-                attention_backend = global_server_args_dict["decode_attention_backend"]
+            if get_global_server_args().speculative_attention_mode == "decode":
+                attention_backend = get_global_server_args().decode_attention_backend
             else:  # default to prefill
-                attention_backend = global_server_args_dict["prefill_attention_backend"]
+                attention_backend = get_global_server_args().prefill_attention_backend
         else:
-            attention_backend = global_server_args_dict["prefill_attention_backend"]
+            attention_backend = get_global_server_args().prefill_attention_backend
         self.current_attention_backend = attention_backend
         handler = AttentionBackendRegistry.get_handler(attention_backend)
@@ -2365,7 +2365,9 @@ class DeepseekV2DecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-        self.speculative_algorithm = global_server_args_dict["speculative_algorithm"]
+        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
+            get_global_server_args().speculative_algorithm
+        )
         self.layer_id = layer_id
         self.is_nextn = is_nextn
         self.self_attn = DeepseekV2AttentionMLA(
@@ -2817,7 +2819,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
@@ -2837,7 +2839,7 @@ class DeepseekV2ForCausalLM(nn.Module):
         self, architecture: str = "DeepseekV3ForCausalLM"
     ):
         self.num_fused_shared_experts = 0
-        if global_server_args_dict["disable_shared_experts_fusion"]:
+        if get_global_server_args().disable_shared_experts_fusion:
             return

         # Only Deepseek V3/R1 can use shared experts fusion optimization now.
@@ -2856,7 +2858,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             disable_reason = "Deepseek V3/R1 W4AFP8 model uses different quant method for routed experts and shared experts."
         if disable_reason is not None:
-            global_server_args_dict["disable_shared_experts_fusion"] = True
+            get_global_server_args().disable_shared_experts_fusion = True
             self.num_fused_shared_experts = 0
             log_info_on_rank0(
                 logger,
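One call-site change in this file is more than a rename: the decoder layer previously read a ready-made `speculative_algorithm` object that `ModelRunner` had stuffed into the dict, whereas the `ServerArgs` field is the raw string, so the layer now converts it with `SpeculativeAlgorithm.from_string(...)`. A hedged stand-in for that conversion (the real enum lives in `sglang.srt.speculative.spec_info` and certainly has different members):

    from enum import Enum, auto

    class SpeculativeAlgorithm(Enum):  # illustrative members only
        NONE = auto()
        EAGLE = auto()

        @classmethod
        def from_string(cls, name):
            # server args carry the algorithm as a plain string (or None)
            if name is None:
                return cls.NONE
            return cls[name.upper()]

    print(SpeculativeAlgorithm.from_string("eagle"))  # SpeculativeAlgorithm.EAGLE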
python/sglang/srt/models/falcon_h1.py

@@ -33,9 +33,9 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import add_prefix, is_cuda, make_layers

 logger = logging.getLogger(__name__)
@@ -483,7 +483,7 @@ class FalconH1ForCausalLM(nn.Module):
             quant_config=quant_config,
             org_num_embeddings=config.vocab_size,
             prefix=add_prefix("lm_head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.lm_head = self.lm_head.float()
         self.lm_head_multiplier = config.lm_head_multiplier
python/sglang/srt/models/glm4_moe.py

@@ -56,18 +56,13 @@ from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.quantization.fp8_kernel import (
-    is_fp8_fnuz,
-    per_tensor_quant_mla_fp8,
-    per_token_group_quant_mla_deep_gemm_masked_fp8,
-)
+from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -77,6 +72,7 @@ from sglang.srt.models.deepseek_v2 import (
     DeepseekV2Model,
     DeepseekV2MoE,
 )
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
 from sglang.srt.utils import (
     BumpAllocator,
@@ -395,7 +391,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         self.n_shared_experts = config.n_shared_experts
         self.num_fused_shared_experts = (
             0
-            if global_server_args_dict["disable_shared_experts_fusion"]
+            if get_global_server_args().disable_shared_experts_fusion
             else config.n_shared_experts
         )
         self.config = config
@@ -432,7 +428,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
-            + global_server_args_dict["ep_num_redundant_experts"],
+            + get_global_server_args().ep_num_redundant_experts,
             num_fused_shared_experts=self.num_fused_shared_experts,
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             hidden_size=config.hidden_size,
@@ -476,7 +472,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         self.ep_size = get_moe_expert_parallel_world_size()
         self.num_experts = (
             config.n_routed_experts
-            + global_server_args_dict["ep_num_redundant_experts"]
+            + get_global_server_args().ep_num_redundant_experts
         )
         self.renormalize = config.norm_topk_prob
         self.topk_group = config.topk_group
@@ -758,7 +754,7 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
@@ -774,7 +770,7 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
         self, architecture: str = "Glm4MoeForCausalLM"
     ):
         self.num_fused_shared_experts = 0
-        if global_server_args_dict["disable_shared_experts_fusion"]:
+        if get_global_server_args().disable_shared_experts_fusion:
             return

         # Only Deepseek V3/R1 can use shared experts fusion optimization now.
@@ -790,7 +786,7 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
             disable_reason = "Deepseek and GLM-4.5 or GLM-4.6 can not use shared experts fusion optimization under expert parallelism."
         if disable_reason is not None:
-            global_server_args_dict["disable_shared_experts_fusion"] = True
+            get_global_server_args().disable_shared_experts_fusion = True
             self.num_fused_shared_experts = 0
             log_info_on_rank0(
                 logger,
python/sglang/srt/models/glm4_moe_nextn.py

@@ -30,9 +30,9 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.glm4_moe import Glm4MoeDecoderLayer, Glm4MoeForCausalLM
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import BumpAllocator, add_prefix

 logger = logging.getLogger(__name__)
@@ -145,7 +145,7 @@ class Glm4MoeForCausalLMNextN(Glm4MoeForCausalLM):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("model.shared_head.head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
python/sglang/srt/models/glm4v_moe.py

@@ -16,10 +16,10 @@ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.glm4_moe import Glm4MoeModel
 from sglang.srt.models.glm4v import Glm4vForConditionalGeneration, Glm4vVisionModel
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import add_prefix, is_cuda, log_info_on_rank0
 from sglang.srt.utils.hf_transformers_utils import get_processor
@@ -47,7 +47,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
         self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
         self.num_fused_shared_experts = (
             0
-            if global_server_args_dict["disable_shared_experts_fusion"]
+            if get_global_server_args().disable_shared_experts_fusion
             else config.n_shared_experts
         )
@@ -68,7 +68,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
@@ -81,7 +81,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
         self, architecture: str = "Glm4MoeForCausalLM"
     ):
         self.num_fused_shared_experts = 0
-        if global_server_args_dict["disable_shared_experts_fusion"]:
+        if get_global_server_args().disable_shared_experts_fusion:
             return

         # Only Deepseek V3/R1 can use shared experts fusion optimization now.
@@ -97,7 +97,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
             disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization under expert parallelism."
         if disable_reason is not None:
-            global_server_args_dict["disable_shared_experts_fusion"] = True
+            get_global_server_args().disable_shared_experts_fusion = True
             self.num_fused_shared_experts = 0
             log_info_on_rank0(
                 logger,
python/sglang/srt/models/gpt_oss.py

@@ -63,13 +63,13 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.utils import (
     create_fused_set_kv_buffer_arg,
     enable_fused_set_kv_buffer,
 )
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import (
     LazyValue,
     add_prefix,
@@ -138,7 +138,7 @@ class GptOssSparseMoeBlock(nn.Module):
         }
         self.experts = experts_type(
             num_experts=config.num_local_experts
-            + global_server_args_dict["ep_num_redundant_experts"],
+            + get_global_server_args().ep_num_redundant_experts,
             top_k=config.num_experts_per_tok,
             layer_id=layer_id,
             hidden_size=config.hidden_size,
@@ -259,7 +259,7 @@ class GptOssAttention(nn.Module):
         # Choose dtype of sinks based on attention backend: trtllm_mha requires float32,
         # others can use bfloat16
-        attn_backend = global_server_args_dict.get("attention_backend")
+        attn_backend = get_global_server_args().attention_backend
         sinks_dtype = torch.float32 if attn_backend == "trtllm_mha" else torch.bfloat16
         self.sinks = nn.Parameter(
             torch.empty(self.num_heads, dtype=sinks_dtype), requires_grad=False
@@ -591,7 +591,7 @@ class GptOssForCausalLM(nn.Module):
             config.hidden_size,
             # quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
         self.capture_aux_hidden_states = False
python/sglang/srt/models/grok.py

@@ -28,7 +28,6 @@ from torch import nn
 from transformers import PretrainedConfig

 from sglang.srt.distributed import (
-    get_moe_expert_parallel_world_size,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
@@ -36,7 +35,6 @@ from sglang.srt.distributed import (
 )
 from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.elementwise import (
-    experts_combine_triton,
     fused_dual_residual_rmsnorm,
     fused_rmsnorm,
     gelu_and_mul_triton,
@@ -64,10 +62,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.loader import DefaultModelLoader
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import add_prefix, dispose_tensor, dump_to_file

 logger = logging.getLogger(__name__)
@@ -869,10 +867,10 @@ class Grok1ForCausalLM(nn.Module):
         # Dump tensors for debugging
         global debug_tensor_dump_output_folder, debug_tensor_dump_inject
-        debug_tensor_dump_output_folder = global_server_args_dict[
-            "debug_tensor_dump_output_folder"
-        ]
-        debug_tensor_dump_inject = global_server_args_dict["debug_tensor_dump_inject"]
+        debug_tensor_dump_output_folder = (
+            get_global_server_args().debug_tensor_dump_output_folder
+        )
+        debug_tensor_dump_inject = get_global_server_args().debug_tensor_dump_inject
         warnings.filterwarnings("ignore", category=FutureWarning)

         if get_tensor_model_parallel_rank() == 0:
python/sglang/srt/models/llama.py

@@ -45,13 +45,13 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     kv_cache_scales_loader,
     maybe_remap_kv_scale_name,
 )
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import add_prefix, make_layers
 from sglang.utils import get_exception_traceback
@@ -433,7 +433,7 @@ class LlamaForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
python/sglang/srt/models/longcat_flash.py
View file @ 1083e7e3
...
@@ -32,14 +32,10 @@
 import concurrent.futures
 import logging
-import os
-from enum import IntEnum, auto
-from typing import Any, Dict, Iterable, Optional, Tuple, Union
+from typing import Iterable, Optional, Tuple

 import torch
-import torch.nn.functional as F
 from torch import nn
-from tqdm import tqdm

 from sglang.srt.configs import LongcatFlashConfig
 from sglang.srt.distributed import (
...
@@ -85,10 +81,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import DeepseekV2AttentionMLA
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import (
     BumpAllocator,
     LazyValue,
...
@@ -595,7 +591,7 @@ class LongcatFlashForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
...
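For readers following along outside the sglang tree: `get_global_server_args()` behaves as a process-wide singleton accessor over the parsed `ServerArgs`. A sketch of how such a getter is commonly wired up — this is an assumption about the shape of `sglang.srt.server_args`, not a copy of its actual code, and only the fields referenced in this commit are shown:

from dataclasses import dataclass
from typing import Optional


@dataclass
class ServerArgs:
    # Illustrative subset; the real class carries many more fields.
    enable_dp_lm_head: bool = False
    enable_multimodal: bool = False
    ep_num_redundant_experts: int = 0


_global_server_args: Optional[ServerArgs] = None


def set_global_server_args(args: ServerArgs) -> None:
    global _global_server_args
    _global_server_args = args


def get_global_server_args() -> ServerArgs:
    # Assumed contract: the server populates this exactly once at startup,
    # so model code can read it from anywhere without plumbing a dict through.
    assert _global_server_args is not None, "server args not initialized"
    return _global_server_args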
python/sglang/srt/models/mllama4.py
View file @ 1083e7e3
...
@@ -31,9 +31,9 @@ from sglang.srt.managers.schedule_batch import (
     Modality,
     MultimodalDataItem,
     MultimodalInputs,
-    global_server_args_dict,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import is_cpu

 _is_cpu = is_cpu()
...
@@ -448,7 +448,7 @@ class Llama4ForConditionalGeneration(nn.Module):
         )
         self.has_vision = (
-            self.has_vision_weights and global_server_args_dict["enable_multimodal"]
+            self.has_vision_weights and get_global_server_args().enable_multimodal
         )
         if self.has_vision:
...
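The mllama4 hunk shows the same accessor used as a feature gate: vision layers are built only when the checkpoint actually contains vision weights and the server was launched with multimodal support enabled. A hedged sketch of that gating logic, where `has_vision_weights` stands in for the model's own checkpoint probe:

from sglang.srt.server_args import get_global_server_args


def should_build_vision_tower(has_vision_weights: bool) -> bool:
    # Both conditions must hold: vision weights exist in the checkpoint,
    # and the operator opted in (e.g. via --enable-multimodal).
    return has_vision_weights and get_global_server_args().enable_multimodal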
python/sglang/srt/models/qwen2_moe.py
View file @ 1083e7e3
...
@@ -64,10 +64,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
 from sglang.srt.utils import add_prefix, is_cuda, make_layers
...
@@ -156,7 +156,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             layer_id=self.layer_id,
             top_k=config.num_experts_per_tok,
             num_experts=config.num_experts
-            + global_server_args_dict["ep_num_redundant_experts"],
+            + get_global_server_args().ep_num_redundant_experts,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
             quant_config=quant_config,
...
@@ -192,7 +192,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         # TODO: we will support tp < ep in the future
         self.ep_size = get_moe_expert_parallel_world_size()
         self.num_experts = (
-            config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
+            config.num_experts + get_global_server_args().ep_num_redundant_experts
         )
         self.top_k = config.num_experts_per_tok
...
@@ -643,7 +643,7 @@ class Qwen2MoeForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
         # For EAGLE3 support
...
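In both MoE hunks the accessor feeds directly into expert-count arithmetic: the runtime provisions `config.num_experts` plus `ep_num_redundant_experts` expert slots, so redundant copies of hot experts can be placed across expert-parallel ranks. A small worked sketch with hypothetical numbers:

from sglang.srt.server_args import get_global_server_args

# Hypothetical checkpoint: 64 routed experts in the model config.
config_num_experts = 64

# If the server was launched with --ep-num-redundant-experts 4, routing
# tables and weight buffers are sized for 64 + 4 = 68 expert slots.
num_experts = config_num_experts + get_global_server_args().ep_num_redundant_experts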
python/sglang/srt/models/qwen3_moe.py
View file @ 1083e7e3
...
@@ -54,7 +54,6 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding, get_rope
 from sglang.srt.layers.utils import get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
...
@@ -64,6 +63,7 @@ from sglang.srt.models.utils import (
     create_fused_set_kv_buffer_arg,
     enable_fused_set_kv_buffer,
 )
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import (
     add_prefix,
     is_cuda,
...
@@ -104,7 +104,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.num_experts
-            + global_server_args_dict["ep_num_redundant_experts"],
+            + get_global_server_args().ep_num_redundant_experts,
             top_k=config.num_experts_per_tok,
             layer_id=layer_id,
             hidden_size=config.hidden_size,
...
@@ -125,7 +125,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         # TODO: we will support tp < ep in the future
         self.ep_size = get_moe_expert_parallel_world_size()
         self.num_experts = (
-            config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
+            config.num_experts + get_global_server_args().ep_num_redundant_experts
         )
         self.top_k = config.num_experts_per_tok
...
@@ -693,7 +693,7 @@ class Qwen3MoeForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
         self.capture_aux_hidden_states = False
...
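Out-of-tree code that still performs dict-style lookups can bridge the deprecation with a thin adapter until it is ported to attribute access. A hedged sketch — `LegacyServerArgsView` is hypothetical and not part of sglang:

from sglang.srt.server_args import get_global_server_args


class LegacyServerArgsView:
    """Hypothetical shim: maps old dict-style lookups onto the new accessor."""

    def __getitem__(self, key: str):
        try:
            # Delegate "args_dict[key]" to "get_global_server_args().key".
            return getattr(get_global_server_args(), key)
        except AttributeError:
            # Preserve the old failure mode for unknown keys.
            raise KeyError(key) from None


# Old call sites keep working while they are migrated:
# global_server_args_dict = LegacyServerArgsView()
# global_server_args_dict["enable_dp_lm_head"]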