Unverified Commit 1083e7e3 authored by Liangsheng Yin, committed by GitHub

Deprecate `global_server_args_dict` (#11331)

parent 2157d12a
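This commit replaces lookups on the module-level dictionary `global_server_args_dict[...]` with attribute access on a process-wide `ServerArgs` object returned by `get_global_server_args()`. As a rough mental model only (a minimal sketch, not the actual `sglang.srt.server_args` implementation; the fields shown are a small illustrative subset), the accessor pair behaves like this:

from dataclasses import dataclass
from typing import Optional

@dataclass
class ServerArgs:
    # Illustrative subset of fields; the real ServerArgs carries many more options.
    attention_backend: Optional[str] = None
    enable_dp_lm_head: bool = False
    ep_num_redundant_experts: int = 0

_GLOBAL_SERVER_ARGS: Optional[ServerArgs] = None

def set_global_server_args_for_scheduler(server_args: ServerArgs) -> None:
    # Installed once per scheduler process, before any model code reads the args.
    global _GLOBAL_SERVER_ARGS
    _GLOBAL_SERVER_ARGS = server_args

def get_global_server_args() -> ServerArgs:
    # Typed attribute access replaces the old global_server_args_dict[...] lookups.
    assert _GLOBAL_SERVER_ARGS is not None, "server args were never initialized"
    return _GLOBAL_SERVER_ARGS

Call sites accordingly change from `global_server_args_dict["attention_backend"]` to `get_global_server_args().attention_backend`, which is the pattern repeated in every hunk below.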
@@ -11,7 +11,7 @@ from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
-from sglang.srt.server_args import ServerArgs
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import support_triton
 if TYPE_CHECKING:
@@ -19,10 +19,6 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
-GLOBAL_SERVER_ARGS_KEYS = ["attention_backend"]
-global_server_args_dict = {k: getattr(ServerArgs, k) for k in GLOBAL_SERVER_ARGS_KEYS}
 @triton.jit
 def write_req_to_token_pool_triton(
@@ -88,7 +84,7 @@ def write_cache_indices(
     prefix_tensors: list[torch.Tensor],
     req_to_token_pool: ReqToTokenPool,
 ):
-    if support_triton(global_server_args_dict.get("attention_backend")):
+    if support_triton(get_global_server_args().attention_backend):
         prefix_pointers = torch.tensor(
             [t.data_ptr() for t in prefix_tensors],
             device=req_to_token_pool.device,
@@ -129,8 +125,8 @@ def get_last_loc(
     prefix_lens_tensor: torch.Tensor,
 ) -> torch.Tensor:
     if (
-        global_server_args_dict["attention_backend"] != "ascend"
-        and global_server_args_dict["attention_backend"] != "torch_native"
+        get_global_server_args().attention_backend != "ascend"
+        and get_global_server_args().attention_backend != "torch_native"
     ):
         impl = get_last_loc_triton
     else:
......
@@ -83,10 +83,6 @@ from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.lora.lora_registry import LoRARef
-from sglang.srt.managers.schedule_batch import (
-    GLOBAL_SERVER_ARGS_KEYS,
-    global_server_args_dict,
-)
 from sglang.srt.mem_cache.allocator import (
     BaseTokenToKVPoolAllocator,
     PagedTokenToKVPoolAllocator,
@@ -125,7 +121,11 @@ from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.server_args import ServerArgs
+from sglang.srt.server_args import (
+    ServerArgs,
+    get_global_server_args,
+    set_global_server_args_for_scheduler,
+)
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     MultiprocessingSerializer,
@@ -275,15 +275,12 @@ class ModelRunner:
         # Model-specific adjustment
         self.model_specific_adjustment()
-        # Global vars
-        global_server_args_dict.update(
-            {k: getattr(server_args, k) for k in GLOBAL_SERVER_ARGS_KEYS}
-            | {
-                # TODO it is indeed not a "server args"
-                "use_mla_backend": self.use_mla_backend,
-                "speculative_algorithm": self.spec_algorithm,
-            }
-        )
+        # Set the global server_args in the scheduler process
+        set_global_server_args_for_scheduler(server_args)
+        global_server_args = get_global_server_args()
+        # FIXME: hacky set `use_mla_backend`
+        global_server_args.use_mla_backend = self.use_mla_backend
         # Init OpenMP threads binding for CPU
         if self.device == "cpu":
@@ -432,7 +429,7 @@ class ModelRunner:
             # In layered loading, torchao may have been applied
             if not torchao_applied:
                 apply_torchao_config_to_model(
-                    self.model, global_server_args_dict["torchao_config"]
+                    self.model, get_global_server_args().torchao_config
                 )
         # Apply torch TP if the model supports it
@@ -1838,12 +1835,10 @@ class ModelRunner:
             self.server_args.attention_backend
         )
-        global_server_args_dict.update(
-            {
-                "decode_attention_backend": self.decode_attention_backend_str,
-                "prefill_attention_backend": self.prefill_attention_backend_str,
-            }
-        )
+        (
+            get_global_server_args().prefill_attention_backend,
+            get_global_server_args().decode_attention_backend,
+        ) = (self.prefill_attention_backend_str, self.decode_attention_backend_str)
         return attn_backend
     def _get_attention_backend_from_str(self, backend_str: str):
......
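Read together with the sketch above, the ModelRunner hunk boils down to the following order of operations (hypothetical usage; the field values are made up, and `use_mla_backend` is patched onto the object after the fact, exactly the hack the FIXME comment flags):

args = ServerArgs(attention_backend="triton")
set_global_server_args_for_scheduler(args)            # once, early in the scheduler process
get_global_server_args().use_mla_backend = True       # per-runner patch, not a real CLI option
backend = get_global_server_args().attention_backend  # any later reader sees the same object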
@@ -4,7 +4,6 @@ from __future__ import annotations
 # ruff: noqa: SIM117
 import collections
-import concurrent
 import dataclasses
 import fnmatch
 import glob
@@ -12,12 +11,10 @@ import json
 import logging
 import math
 import os
-import re
 import socket
 import threading
 import time
 from abc import ABC, abstractmethod
-from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager, suppress
 from typing import (
     TYPE_CHECKING,
@@ -33,10 +30,10 @@ from typing import (
 import huggingface_hub
 import numpy as np
-import requests
-import safetensors.torch
 import torch
+from sglang.srt.server_args import get_global_server_args
 # Try to import accelerate (optional dependency)
 try:
     from accelerate import infer_auto_device_map, init_empty_weights
@@ -81,8 +78,6 @@ DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION = (
     0.8  # Reserve 20% GPU memory headroom for ModelOpt calibration
 )
 from sglang.srt.model_loader.weight_utils import (
-    _BAR_FORMAT,
-    default_weight_loader,
     download_safetensors_index_file_from_hf,
     download_weights_from_hf,
     filter_duplicate_safetensors_files,
@@ -445,10 +440,8 @@ class DefaultModelLoader(BaseModelLoader):
                 hf_weights_files,
             )
         elif use_safetensors:
-            from sglang.srt.managers.schedule_batch import global_server_args_dict
-            weight_loader_disable_mmap = global_server_args_dict.get(
-                "weight_loader_disable_mmap"
+            weight_loader_disable_mmap = (
+                get_global_server_args().weight_loader_disable_mmap
             )
             if extra_config.get("enable_multithread_load"):
@@ -616,9 +609,9 @@ class LayeredModelLoader(DefaultModelLoader):
         device_config: DeviceConfig,
     ) -> nn.Module:
         from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
-        from sglang.srt.managers.schedule_batch import global_server_args_dict
+        from sglang.srt.server_args import get_global_server_args
-        torchao_config = global_server_args_dict.get("torchao_config")
+        torchao_config = get_global_server_args().torchao_config
         target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
......
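A behavioral nuance in the loader change: the old code read the flag with `dict.get(...)`, which silently yields `None` when a key was never registered, whereas attribute access on the `ServerArgs` object either returns the configured value or raises. A toy illustration (stand-in class, not sglang code):

d = {"attention_backend": "triton"}
print(d.get("weight_loader_disable_mmap"))  # None: a missing key slips through silently

class Args:
    attention_backend = "triton"

try:
    Args().weight_loader_disable_mmap
except AttributeError as err:
    print(err)  # a missing or misspelled option now fails loudly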
@@ -46,15 +46,14 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     kv_cache_scales_loader,
     maybe_remap_kv_scale_name,
 )
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import add_prefix, make_layers
-from sglang.utils import get_exception_traceback
 logger = logging.getLogger(__name__)
@@ -447,7 +446,7 @@ class ApertusForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
......
@@ -42,13 +42,13 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     kv_cache_scales_loader,
     maybe_remap_kv_scale_name,
 )
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import add_prefix, make_layers
 logger = logging.getLogger(__name__)
@@ -407,7 +407,7 @@ class ArceeForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
......
@@ -17,7 +17,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" SGLang BailingMoE model."""
+"""SGLang BailingMoE model."""
 import logging
 from typing import Any, Dict, Iterable, Optional, Tuple, Union
@@ -68,7 +68,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -76,6 +75,7 @@ from sglang.srt.models.utils import (
     create_fused_set_kv_buffer_arg,
     enable_fused_set_kv_buffer,
 )
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty, make_layers
 LoraConfig = None
@@ -204,8 +204,8 @@ class BailingMoESparseMoeBlock(nn.Module):
         else:
             self.router_dtype = torch.bfloat16
-        # TODO global_server_args_dict["ep_num_redundant_experts"] is used for eplb, not supported now
-        assert global_server_args_dict["ep_num_redundant_experts"] == 0
+        # TODO global_server_args.ep_num_redundant_experts is used for eplb, not supported now
+        assert get_global_server_args().ep_num_redundant_experts == 0
         # check group topk
         self.num_expert_group = getattr(config, "n_group", 0)
         self.topk_group = getattr(config, "topk_group", 0)
@@ -220,7 +220,7 @@ class BailingMoESparseMoeBlock(nn.Module):
             self.use_grouped_topk = False
         self.num_experts = (
-            config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
+            config.num_experts + get_global_server_args().ep_num_redundant_experts
         )
         self.gate = BailingMoEGate(
@@ -824,7 +824,7 @@ class BailingMoEForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
......
@@ -17,7 +17,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" SGLang BailingMoENextN model."""
+"""SGLang BailingMoENextN model."""
 import logging
 from typing import Iterable, Optional, Tuple
@@ -29,15 +29,14 @@ from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.bailing_moe import BailingMoEBlock, BailingMoEForCausalLM
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import add_prefix
 LoraConfig = None
@@ -145,7 +144,7 @@ class BailingMoeForCausalLMNextN(BailingMoEForCausalLM):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("model.shared_head.head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
......
@@ -30,9 +30,9 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda
 logger = logging.getLogger(__name__)
@@ -152,7 +152,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("model.shared_head.head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
......
@@ -35,7 +35,6 @@ from sglang.srt.configs.model_config import (
     get_nsa_index_topk,
     is_deepseek_nsa,
 )
-from sglang.srt.debug_utils.dumper import dumper
 from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
     get_pp_group,
@@ -108,10 +107,11 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.single_batch_overlap import SboFlags
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.two_batch_overlap import (
     MaybeTboDeepEPDispatcher,
     model_forward_maybe_tbo,
@@ -520,7 +520,7 @@ class DeepseekV2MoE(nn.Module):
         self.n_shared_experts = config.n_shared_experts
         self.num_fused_shared_experts = (
             0
-            if global_server_args_dict["disable_shared_experts_fusion"]
+            if get_global_server_args().disable_shared_experts_fusion
             else config.n_shared_experts
         )
         self.config = config
@@ -549,7 +549,7 @@ class DeepseekV2MoE(nn.Module):
         self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
-            + global_server_args_dict["ep_num_redundant_experts"],
+            + get_global_server_args().ep_num_redundant_experts,
             num_fused_shared_experts=self.num_fused_shared_experts,
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             hidden_size=config.hidden_size,
@@ -627,7 +627,7 @@ class DeepseekV2MoE(nn.Module):
             self.ep_size = get_moe_expert_parallel_world_size()
             self.num_experts = (
                 config.n_routed_experts
-                + global_server_args_dict["ep_num_redundant_experts"]
+                + get_global_server_args().ep_num_redundant_experts
             )
             self.renormalize = config.norm_topk_prob
             self.topk_group = config.topk_group
@@ -1125,7 +1125,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 base=rope_theta,
                 rope_scaling=rope_scaling,
                 is_neox_style=False,
-                device=global_server_args_dict["device"],
+                device=get_global_server_args().device,
             )
             if rope_scaling:
@@ -1169,12 +1169,12 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.w_scale_v = None
         self.use_deep_gemm_bmm = False
-        self.flashinfer_mla_disable_ragged = global_server_args_dict[
-            "flashinfer_mla_disable_ragged"
-        ]
-        self.disable_chunked_prefix_cache = global_server_args_dict[
-            "disable_chunked_prefix_cache"
-        ]
+        self.flashinfer_mla_disable_ragged = (
+            get_global_server_args().flashinfer_mla_disable_ragged,
+        )
+        self.disable_chunked_prefix_cache = (
+            get_global_server_args().disable_chunked_prefix_cache
+        )
         self.current_attention_backend = (
             None  # Attention backend used by current forward batch
@@ -1253,18 +1253,18 @@ class DeepseekV2AttentionMLA(nn.Module):
     ) -> AttnForwardMethod:
         # Determine attention backend used by current forward batch
         if forward_batch.forward_mode.is_decode_or_idle():
-            attention_backend = global_server_args_dict["decode_attention_backend"]
+            attention_backend = get_global_server_args().decode_attention_backend
         elif (
             forward_batch.forward_mode.is_target_verify()
             or forward_batch.forward_mode.is_draft_extend()
         ):
             # Use the specified backend for speculative operations (both verify and draft extend)
-            if global_server_args_dict["speculative_attention_mode"] == "decode":
-                attention_backend = global_server_args_dict["decode_attention_backend"]
+            if get_global_server_args().speculative_attention_mode == "decode":
+                attention_backend = get_global_server_args().decode_attention_backend
             else:  # default to prefill
-                attention_backend = global_server_args_dict["prefill_attention_backend"]
+                attention_backend = get_global_server_args().prefill_attention_backend
         else:
-            attention_backend = global_server_args_dict["prefill_attention_backend"]
+            attention_backend = get_global_server_args().prefill_attention_backend
         self.current_attention_backend = attention_backend
         handler = AttentionBackendRegistry.get_handler(attention_backend)
@@ -2365,7 +2365,9 @@ class DeepseekV2DecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-        self.speculative_algorithm = global_server_args_dict["speculative_algorithm"]
+        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
+            get_global_server_args().speculative_algorithm
+        )
         self.layer_id = layer_id
         self.is_nextn = is_nextn
         self.self_attn = DeepseekV2AttentionMLA(
@@ -2817,7 +2819,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
@@ -2837,7 +2839,7 @@ class DeepseekV2ForCausalLM(nn.Module):
         self, architecture: str = "DeepseekV3ForCausalLM"
     ):
         self.num_fused_shared_experts = 0
-        if global_server_args_dict["disable_shared_experts_fusion"]:
+        if get_global_server_args().disable_shared_experts_fusion:
             return
         # Only Deepseek V3/R1 can use shared experts fusion optimization now.
@@ -2856,7 +2858,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             disable_reason = "Deepseek V3/R1 W4AFP8 model uses different quant method for routed experts and shared experts."
         if disable_reason is not None:
-            global_server_args_dict["disable_shared_experts_fusion"] = True
+            get_global_server_args().disable_shared_experts_fusion = True
             self.num_fused_shared_experts = 0
             log_info_on_rank0(
                 logger,
......
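The largest behavioral surface in this file is the attention-backend dispatch in DeepseekV2AttentionMLA. Rewritten as a standalone function for clarity (a sketch: the enum, the args dataclass, and the backend strings below are simplified stand-ins rather than sglang's real types, but the branching mirrors the hunk above):

from dataclasses import dataclass
from enum import Enum, auto

class Mode(Enum):
    DECODE = auto()
    IDLE = auto()
    TARGET_VERIFY = auto()
    DRAFT_EXTEND = auto()
    EXTEND = auto()

@dataclass
class Args:
    decode_attention_backend: str = "backend_for_decode"
    prefill_attention_backend: str = "backend_for_prefill"
    speculative_attention_mode: str = "prefill"

def select_attention_backend(mode: Mode, args: Args) -> str:
    # Decode-style (and idle) batches use the decode backend.
    if mode in (Mode.DECODE, Mode.IDLE):
        return args.decode_attention_backend
    # Speculative verify / draft-extend batches follow speculative_attention_mode.
    if mode in (Mode.TARGET_VERIFY, Mode.DRAFT_EXTEND):
        if args.speculative_attention_mode == "decode":
            return args.decode_attention_backend
        return args.prefill_attention_backend
    # Everything else (normal prefill / extend) uses the prefill backend.
    return args.prefill_attention_backend

print(select_attention_backend(Mode.DECODE, Args()))         # backend_for_decode
print(select_attention_backend(Mode.TARGET_VERIFY, Args()))  # backend_for_prefill (defaults to prefill)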
@@ -33,9 +33,9 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import add_prefix, is_cuda, make_layers
 logger = logging.getLogger(__name__)
@@ -483,7 +483,7 @@ class FalconH1ForCausalLM(nn.Module):
             quant_config=quant_config,
             org_num_embeddings=config.vocab_size,
             prefix=add_prefix("lm_head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.lm_head = self.lm_head.float()
         self.lm_head_multiplier = config.lm_head_multiplier
......
@@ -56,18 +56,13 @@ from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.quantization.fp8_kernel import (
-    is_fp8_fnuz,
-    per_tensor_quant_mla_fp8,
-    per_token_group_quant_mla_deep_gemm_masked_fp8,
-)
+from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -77,6 +72,7 @@ from sglang.srt.models.deepseek_v2 import (
     DeepseekV2Model,
     DeepseekV2MoE,
 )
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
 from sglang.srt.utils import (
     BumpAllocator,
@@ -395,7 +391,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         self.n_shared_experts = config.n_shared_experts
         self.num_fused_shared_experts = (
             0
-            if global_server_args_dict["disable_shared_experts_fusion"]
+            if get_global_server_args().disable_shared_experts_fusion
             else config.n_shared_experts
         )
         self.config = config
@@ -432,7 +428,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
-            + global_server_args_dict["ep_num_redundant_experts"],
+            + get_global_server_args().ep_num_redundant_experts,
             num_fused_shared_experts=self.num_fused_shared_experts,
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             hidden_size=config.hidden_size,
@@ -476,7 +472,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             self.ep_size = get_moe_expert_parallel_world_size()
             self.num_experts = (
                 config.n_routed_experts
-                + global_server_args_dict["ep_num_redundant_experts"]
+                + get_global_server_args().ep_num_redundant_experts
             )
             self.renormalize = config.norm_topk_prob
             self.topk_group = config.topk_group
@@ -758,7 +754,7 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
@@ -774,7 +770,7 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
         self, architecture: str = "Glm4MoeForCausalLM"
     ):
         self.num_fused_shared_experts = 0
-        if global_server_args_dict["disable_shared_experts_fusion"]:
+        if get_global_server_args().disable_shared_experts_fusion:
             return
         # Only Deepseek V3/R1 can use shared experts fusion optimization now.
@@ -790,7 +786,7 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
             disable_reason = "Deepseek and GLM-4.5 or GLM-4.6 can not use shared experts fusion optimization under expert parallelism."
         if disable_reason is not None:
-            global_server_args_dict["disable_shared_experts_fusion"] = True
+            get_global_server_args().disable_shared_experts_fusion = True
             self.num_fused_shared_experts = 0
             log_info_on_rank0(
                 logger,
......
@@ -30,9 +30,9 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.glm4_moe import Glm4MoeDecoderLayer, Glm4MoeForCausalLM
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import BumpAllocator, add_prefix
 logger = logging.getLogger(__name__)
@@ -145,7 +145,7 @@ class Glm4MoeForCausalLMNextN(Glm4MoeForCausalLM):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("model.shared_head.head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
......
@@ -16,10 +16,10 @@ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.glm4_moe import Glm4MoeModel
 from sglang.srt.models.glm4v import Glm4vForConditionalGeneration, Glm4vVisionModel
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import add_prefix, is_cuda, log_info_on_rank0
 from sglang.srt.utils.hf_transformers_utils import get_processor
@@ -47,7 +47,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
         self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
         self.num_fused_shared_experts = (
             0
-            if global_server_args_dict["disable_shared_experts_fusion"]
+            if get_global_server_args().disable_shared_experts_fusion
             else config.n_shared_experts
         )
@@ -68,7 +68,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
@@ -81,7 +81,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
         self, architecture: str = "Glm4MoeForCausalLM"
     ):
         self.num_fused_shared_experts = 0
-        if global_server_args_dict["disable_shared_experts_fusion"]:
+        if get_global_server_args().disable_shared_experts_fusion:
             return
         # Only Deepseek V3/R1 can use shared experts fusion optimization now.
@@ -97,7 +97,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
             disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization under expert parallelism."
         if disable_reason is not None:
-            global_server_args_dict["disable_shared_experts_fusion"] = True
+            get_global_server_args().disable_shared_experts_fusion = True
             self.num_fused_shared_experts = 0
             log_info_on_rank0(
                 logger,
......
@@ -63,13 +63,13 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.utils import (
     create_fused_set_kv_buffer_arg,
     enable_fused_set_kv_buffer,
 )
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import (
     LazyValue,
     add_prefix,
@@ -138,7 +138,7 @@ class GptOssSparseMoeBlock(nn.Module):
         }
         self.experts = experts_type(
             num_experts=config.num_local_experts
-            + global_server_args_dict["ep_num_redundant_experts"],
+            + get_global_server_args().ep_num_redundant_experts,
             top_k=config.num_experts_per_tok,
             layer_id=layer_id,
             hidden_size=config.hidden_size,
@@ -259,7 +259,7 @@ class GptOssAttention(nn.Module):
         # Choose dtype of sinks based on attention backend: trtllm_mha requires float32,
         # others can use bfloat16
-        attn_backend = global_server_args_dict.get("attention_backend")
+        attn_backend = get_global_server_args().attention_backend
         sinks_dtype = torch.float32 if attn_backend == "trtllm_mha" else torch.bfloat16
         self.sinks = nn.Parameter(
             torch.empty(self.num_heads, dtype=sinks_dtype), requires_grad=False
@@ -591,7 +591,7 @@ class GptOssForCausalLM(nn.Module):
             config.hidden_size,
             # quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
         self.capture_aux_hidden_states = False
......
@@ -28,7 +28,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from sglang.srt.distributed import (
-    get_moe_expert_parallel_world_size,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
@@ -36,7 +35,6 @@ from sglang.srt.distributed import (
 )
 from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.elementwise import (
-    experts_combine_triton,
     fused_dual_residual_rmsnorm,
     fused_rmsnorm,
     gelu_and_mul_triton,
@@ -64,10 +62,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.loader import DefaultModelLoader
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import add_prefix, dispose_tensor, dump_to_file
 logger = logging.getLogger(__name__)
@@ -869,10 +867,10 @@ class Grok1ForCausalLM(nn.Module):
         # Dump tensors for debugging
         global debug_tensor_dump_output_folder, debug_tensor_dump_inject
-        debug_tensor_dump_output_folder = global_server_args_dict[
-            "debug_tensor_dump_output_folder"
-        ]
-        debug_tensor_dump_inject = global_server_args_dict["debug_tensor_dump_inject"]
+        debug_tensor_dump_output_folder = (
+            get_global_server_args().debug_tensor_dump_output_folder
+        )
+        debug_tensor_dump_inject = get_global_server_args().debug_tensor_dump_inject
         warnings.filterwarnings("ignore", category=FutureWarning)
         if get_tensor_model_parallel_rank() == 0:
......
@@ -45,13 +45,13 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     kv_cache_scales_loader,
     maybe_remap_kv_scale_name,
 )
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import add_prefix, make_layers
 from sglang.utils import get_exception_traceback
@@ -433,7 +433,7 @@ class LlamaForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
......
@@ -32,14 +32,10 @@
 import concurrent.futures
 import logging
-import os
-from enum import IntEnum, auto
-from typing import Any, Dict, Iterable, Optional, Tuple, Union
+from typing import Iterable, Optional, Tuple
 import torch
-import torch.nn.functional as F
 from torch import nn
-from tqdm import tqdm
 from sglang.srt.configs import LongcatFlashConfig
 from sglang.srt.distributed import (
@@ -85,10 +81,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import DeepseekV2AttentionMLA
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import (
     BumpAllocator,
     LazyValue,
@@ -595,7 +591,7 @@ class LongcatFlashForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
......
@@ -31,9 +31,9 @@ from sglang.srt.managers.schedule_batch import (
     Modality,
     MultimodalDataItem,
     MultimodalInputs,
-    global_server_args_dict,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import is_cpu
 _is_cpu = is_cpu()
@@ -448,7 +448,7 @@ class Llama4ForConditionalGeneration(nn.Module):
         )
         self.has_vision = (
-            self.has_vision_weights and global_server_args_dict["enable_multimodal"]
+            self.has_vision_weights and get_global_server_args().enable_multimodal
        )
        if self.has_vision:
......
@@ -64,10 +64,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
 from sglang.srt.utils import add_prefix, is_cuda, make_layers
@@ -156,7 +156,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             layer_id=self.layer_id,
             top_k=config.num_experts_per_tok,
             num_experts=config.num_experts
-            + global_server_args_dict["ep_num_redundant_experts"],
+            + get_global_server_args().ep_num_redundant_experts,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
             quant_config=quant_config,
@@ -192,7 +192,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             # TODO: we will support tp < ep in the future
             self.ep_size = get_moe_expert_parallel_world_size()
             self.num_experts = (
-                config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
+                config.num_experts + get_global_server_args().ep_num_redundant_experts
             )
             self.top_k = config.num_experts_per_tok
@@ -643,7 +643,7 @@ class Qwen2MoeForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
         # For EAGLE3 support
......
@@ -54,7 +54,6 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding, get_rope
 from sglang.srt.layers.utils import get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -64,6 +63,7 @@ from sglang.srt.models.utils import (
     create_fused_set_kv_buffer_arg,
     enable_fused_set_kv_buffer,
 )
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import (
     add_prefix,
     is_cuda,
@@ -104,7 +104,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.num_experts
-            + global_server_args_dict["ep_num_redundant_experts"],
+            + get_global_server_args().ep_num_redundant_experts,
             top_k=config.num_experts_per_tok,
             layer_id=layer_id,
             hidden_size=config.hidden_size,
@@ -125,7 +125,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             # TODO: we will support tp < ep in the future
             self.ep_size = get_moe_expert_parallel_world_size()
             self.num_experts = (
-                config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
+                config.num_experts + get_global_server_args().ep_num_redundant_experts
            )
            self.top_k = config.num_experts_per_tok
@@ -693,7 +693,7 @@ class Qwen3MoeForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
        )
        self.logits_processor = LogitsProcessor(config)
        self.capture_aux_hidden_states = False
......