Unverified Commit 1bdd0102 authored by Cheng Wan, committed by GitHub

Revert "Deprecate `global_server_args_dict`" (#11520)

parent 6cd29694
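The hunks below swap the typed `get_global_server_args()` accessor back for the module-level dict. A minimal sketch of the restored flow, condensed from the schedule_batch.py and model_runner.py hunks in this commit (simplified, not the full modules; the `refresh()` wrapper is a hypothetical helper added only so the snippet stands alone):

from sglang.srt.server_args import ServerArgs

# schedule_batch.py seeds the dict with ServerArgs class defaults at import time
GLOBAL_SERVER_ARGS_KEYS = ["attention_backend"]
global_server_args_dict = {k: getattr(ServerArgs, k) for k in GLOBAL_SERVER_ARGS_KEYS}

# model_runner.py later overwrites it with the live values for this process;
# server_args stands for the ServerArgs instance the runner was built with
def refresh(server_args):
    global_server_args_dict.update(
        {k: getattr(server_args, k) for k in GLOBAL_SERVER_ARGS_KEYS}
    )

# call sites then read plain dict keys instead of accessor attributes
attention_backend = global_server_args_dict["attention_backend"]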
......@@ -11,7 +11,7 @@ from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
from sglang.srt.server_args import get_global_server_args
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import support_triton
if TYPE_CHECKING:
......@@ -19,6 +19,10 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
GLOBAL_SERVER_ARGS_KEYS = ["attention_backend"]
global_server_args_dict = {k: getattr(ServerArgs, k) for k in GLOBAL_SERVER_ARGS_KEYS}
@triton.jit
def write_req_to_token_pool_triton(
......@@ -84,7 +88,7 @@ def write_cache_indices(
prefix_tensors: list[torch.Tensor],
req_to_token_pool: ReqToTokenPool,
):
if support_triton(get_global_server_args().attention_backend):
if support_triton(global_server_args_dict.get("attention_backend")):
prefix_pointers = torch.tensor(
[t.data_ptr() for t in prefix_tensors],
device=req_to_token_pool.device,
......@@ -125,8 +129,8 @@ def get_last_loc(
prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
if (
get_global_server_args().attention_backend != "ascend"
and get_global_server_args().attention_backend != "torch_native"
global_server_args_dict["attention_backend"] != "ascend"
and global_server_args_dict["attention_backend"] != "torch_native"
):
impl = get_last_loc_triton
else:
......
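One small behavioral point in the restored call sites above: `write_cache_indices` reads the key with `global_server_args_dict.get("attention_backend")`, which returns `None` when the key is absent, while `get_last_loc` indexes it with `[...]`, which raises `KeyError`. A throwaway illustration of the difference (placeholder dict and key, not sglang code):

d = {"attention_backend": "triton"}
print(d.get("missing_key"))   # None, no exception
print(d["missing_key"])       # raises KeyError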
......@@ -83,6 +83,10 @@ from sglang.srt.layers.sampler import Sampler
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.managers.schedule_batch import (
GLOBAL_SERVER_ARGS_KEYS,
global_server_args_dict,
)
from sglang.srt.mem_cache.allocator import (
BaseTokenToKVPoolAllocator,
PagedTokenToKVPoolAllocator,
......@@ -121,11 +125,7 @@ from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
from sglang.srt.model_loader.utils import set_default_torch_dtype
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import (
ServerArgs,
get_global_server_args,
set_global_server_args_for_scheduler,
)
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import (
MultiprocessingSerializer,
......@@ -275,12 +275,15 @@ class ModelRunner:
# Model-specific adjustment
self.model_specific_adjustment()
# Set the global server_args in the scheduler process
set_global_server_args_for_scheduler(server_args)
global_server_args = get_global_server_args()
# FIXME: hacky set `use_mla_backend`
global_server_args.use_mla_backend = self.use_mla_backend
# Global vars
global_server_args_dict.update(
{k: getattr(server_args, k) for k in GLOBAL_SERVER_ARGS_KEYS}
| {
# TODO it is indeed not a "server args"
"use_mla_backend": self.use_mla_backend,
"speculative_algorithm": self.spec_algorithm,
}
)
# Init OpenMP threads binding for CPU
if self.device == "cpu":
......@@ -429,7 +432,7 @@ class ModelRunner:
# In layered loading, torchao may have been applied
if not torchao_applied:
apply_torchao_config_to_model(
self.model, get_global_server_args().torchao_config
self.model, global_server_args_dict["torchao_config"]
)
# Apply torch TP if the model supports it
......@@ -1835,10 +1838,12 @@ class ModelRunner:
self.server_args.attention_backend
)
(
get_global_server_args().prefill_attention_backend,
get_global_server_args().decode_attention_backend,
) = (self.prefill_attention_backend_str, self.decode_attention_backend_str)
global_server_args_dict.update(
{
"decode_attention_backend": self.decode_attention_backend_str,
"prefill_attention_backend": self.prefill_attention_backend_str,
}
)
return attn_backend
def _get_attention_backend_from_str(self, backend_str: str):
......
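The restored `global_server_args_dict.update(...)` call in the ModelRunner constructor merges two dicts with `|`, the dict-union operator (Python 3.9+), so runner-derived flags such as `use_mla_backend` ride along with the real server args. A small sketch of the merge semantics with placeholder values:

defaults = {"attention_backend": "flashinfer"}   # stands in for fields copied from server_args
extras = {"use_mla_backend": True}               # runner-derived flag, not a real server arg
merged = defaults | extras                       # dict union; right operand wins on key clashes
print(merged)                                    # {'attention_backend': 'flashinfer', 'use_mla_backend': True}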
......@@ -4,6 +4,7 @@ from __future__ import annotations
# ruff: noqa: SIM117
import collections
import concurrent
import dataclasses
import fnmatch
import glob
......@@ -11,10 +12,12 @@ import json
import logging
import math
import os
import re
import socket
import threading
import time
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager, suppress
from typing import (
TYPE_CHECKING,
......@@ -30,10 +33,10 @@ from typing import (
import huggingface_hub
import numpy as np
import requests
import safetensors.torch
import torch
from sglang.srt.server_args import get_global_server_args
# Try to import accelerate (optional dependency)
try:
from accelerate import infer_auto_device_map, init_empty_weights
......@@ -78,6 +81,8 @@ DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION = (
0.8 # Reserve 20% GPU memory headroom for ModelOpt calibration
)
from sglang.srt.model_loader.weight_utils import (
_BAR_FORMAT,
default_weight_loader,
download_safetensors_index_file_from_hf,
download_weights_from_hf,
filter_duplicate_safetensors_files,
......@@ -440,8 +445,10 @@ class DefaultModelLoader(BaseModelLoader):
hf_weights_files,
)
elif use_safetensors:
weight_loader_disable_mmap = (
get_global_server_args().weight_loader_disable_mmap
from sglang.srt.managers.schedule_batch import global_server_args_dict
weight_loader_disable_mmap = global_server_args_dict.get(
"weight_loader_disable_mmap"
)
if extra_config.get("enable_multithread_load"):
......@@ -609,9 +616,9 @@ class LayeredModelLoader(DefaultModelLoader):
device_config: DeviceConfig,
) -> nn.Module:
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
from sglang.srt.server_args import get_global_server_args
from sglang.srt.managers.schedule_batch import global_server_args_dict
torchao_config = get_global_server_args().torchao_config
torchao_config = global_server_args_dict.get("torchao_config")
target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
......
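Note that DefaultModelLoader and LayeredModelLoader import `global_server_args_dict` inside the function body rather than at module scope; presumably this sidesteps a circular import between the loader and schedule_batch (an assumption, not stated in the commit). A sketch of the deferred-import pattern, with a hypothetical helper name:

def read_mmap_flag():
    # Deferred import (assumption: avoids a loader <-> schedule_batch import cycle)
    from sglang.srt.managers.schedule_batch import global_server_args_dict
    return global_server_args_dict.get("weight_loader_disable_mmap")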
......@@ -46,14 +46,15 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
kv_cache_scales_loader,
maybe_remap_kv_scale_name,
)
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix, make_layers
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
......@@ -446,7 +447,7 @@ class ApertusForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
......
......@@ -42,13 +42,13 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
kv_cache_scales_loader,
maybe_remap_kv_scale_name,
)
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix, make_layers
logger = logging.getLogger(__name__)
......@@ -407,7 +407,7 @@ class ArceeForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
......
......@@ -17,7 +17,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SGLang BailingMoE model."""
""" SGLang BailingMoE model."""
import logging
from typing import Any, Dict, Iterable, Optional, Tuple, Union
......@@ -68,6 +68,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
......@@ -75,7 +76,6 @@ from sglang.srt.models.utils import (
create_fused_set_kv_buffer_arg,
enable_fused_set_kv_buffer,
)
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty, make_layers
LoraConfig = None
......@@ -204,8 +204,8 @@ class BailingMoESparseMoeBlock(nn.Module):
else:
self.router_dtype = torch.bfloat16
# TODO global_server_args.ep_num_redundant_experts is used for eplb, not supported now
assert get_global_server_args().ep_num_redundant_experts == 0
# TODO global_server_args_dict["ep_num_redundant_experts"] is used for eplb, not supported now
assert global_server_args_dict["ep_num_redundant_experts"] == 0
# check group topk
self.num_expert_group = getattr(config, "n_group", 0)
self.topk_group = getattr(config, "topk_group", 0)
......@@ -220,7 +220,7 @@ class BailingMoESparseMoeBlock(nn.Module):
self.use_grouped_topk = False
self.num_experts = (
config.num_experts + get_global_server_args().ep_num_redundant_experts
config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
)
self.gate = BailingMoEGate(
......@@ -824,7 +824,7 @@ class BailingMoEForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
......
......@@ -17,7 +17,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SGLang BailingMoENextN model."""
""" SGLang BailingMoENextN model."""
import logging
from typing import Iterable, Optional, Tuple
......@@ -29,14 +29,15 @@ from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.dp_attention import is_dp_attention_enabled
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.bailing_moe import BailingMoEBlock, BailingMoEForCausalLM
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix
LoraConfig = None
......@@ -144,7 +145,7 @@ class BailingMoeForCausalLMNextN(BailingMoEForCausalLM):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("model.shared_head.head", prefix),
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
......
......@@ -30,9 +30,9 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda
logger = logging.getLogger(__name__)
......@@ -152,7 +152,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("model.shared_head.head", prefix),
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
......
......@@ -35,6 +35,7 @@ from sglang.srt.configs.model_config import (
get_nsa_index_topk,
is_deepseek_nsa,
)
from sglang.srt.debug_utils.dumper import dumper
from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_pp_group,
......@@ -107,11 +108,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.server_args import get_global_server_args
from sglang.srt.single_batch_overlap import SboFlags
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.two_batch_overlap import (
MaybeTboDeepEPDispatcher,
model_forward_maybe_tbo,
......@@ -520,7 +520,7 @@ class DeepseekV2MoE(nn.Module):
self.n_shared_experts = config.n_shared_experts
self.num_fused_shared_experts = (
0
if get_global_server_args().disable_shared_experts_fusion
if global_server_args_dict["disable_shared_experts_fusion"]
else config.n_shared_experts
)
self.config = config
......@@ -549,7 +549,7 @@ class DeepseekV2MoE(nn.Module):
self.experts = get_moe_impl_class(quant_config)(
num_experts=config.n_routed_experts
+ self.num_fused_shared_experts
+ get_global_server_args().ep_num_redundant_experts,
+ global_server_args_dict["ep_num_redundant_experts"],
num_fused_shared_experts=self.num_fused_shared_experts,
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
hidden_size=config.hidden_size,
......@@ -627,7 +627,7 @@ class DeepseekV2MoE(nn.Module):
self.ep_size = get_moe_expert_parallel_world_size()
self.num_experts = (
config.n_routed_experts
+ get_global_server_args().ep_num_redundant_experts
+ global_server_args_dict["ep_num_redundant_experts"]
)
self.renormalize = config.norm_topk_prob
self.topk_group = config.topk_group
......@@ -1125,7 +1125,7 @@ class DeepseekV2AttentionMLA(nn.Module):
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=False,
device=get_global_server_args().device,
device=global_server_args_dict["device"],
)
if rope_scaling:
......@@ -1169,12 +1169,12 @@ class DeepseekV2AttentionMLA(nn.Module):
self.w_scale_v = None
self.use_deep_gemm_bmm = False
self.flashinfer_mla_disable_ragged = (
get_global_server_args().flashinfer_mla_disable_ragged,
)
self.disable_chunked_prefix_cache = (
get_global_server_args().disable_chunked_prefix_cache
)
self.flashinfer_mla_disable_ragged = global_server_args_dict[
"flashinfer_mla_disable_ragged"
]
self.disable_chunked_prefix_cache = global_server_args_dict[
"disable_chunked_prefix_cache"
]
self.current_attention_backend = (
None # Attention backend used by current forward batch
......@@ -1253,18 +1253,18 @@ class DeepseekV2AttentionMLA(nn.Module):
) -> AttnForwardMethod:
# Determine attention backend used by current forward batch
if forward_batch.forward_mode.is_decode_or_idle():
attention_backend = get_global_server_args().decode_attention_backend
attention_backend = global_server_args_dict["decode_attention_backend"]
elif (
forward_batch.forward_mode.is_target_verify()
or forward_batch.forward_mode.is_draft_extend()
):
# Use the specified backend for speculative operations (both verify and draft extend)
if get_global_server_args().speculative_attention_mode == "decode":
attention_backend = get_global_server_args().decode_attention_backend
if global_server_args_dict["speculative_attention_mode"] == "decode":
attention_backend = global_server_args_dict["decode_attention_backend"]
else: # default to prefill
attention_backend = get_global_server_args().prefill_attention_backend
attention_backend = global_server_args_dict["prefill_attention_backend"]
else:
attention_backend = get_global_server_args().prefill_attention_backend
attention_backend = global_server_args_dict["prefill_attention_backend"]
self.current_attention_backend = attention_backend
handler = AttentionBackendRegistry.get_handler(attention_backend)
......@@ -2365,9 +2365,7 @@ class DeepseekV2DecoderLayer(nn.Module):
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
get_global_server_args().speculative_algorithm
)
self.speculative_algorithm = global_server_args_dict["speculative_algorithm"]
self.layer_id = layer_id
self.is_nextn = is_nextn
self.self_attn = DeepseekV2AttentionMLA(
......@@ -2819,7 +2817,7 @@ class DeepseekV2ForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
......@@ -2839,7 +2837,7 @@ class DeepseekV2ForCausalLM(nn.Module):
self, architecture: str = "DeepseekV3ForCausalLM"
):
self.num_fused_shared_experts = 0
if get_global_server_args().disable_shared_experts_fusion:
if global_server_args_dict["disable_shared_experts_fusion"]:
return
# Only Deepseek V3/R1 can use shared experts fusion optimization now.
......@@ -2858,7 +2856,7 @@ class DeepseekV2ForCausalLM(nn.Module):
disable_reason = "Deepseek V3/R1 W4AFP8 model uses different quant method for routed experts and shared experts."
if disable_reason is not None:
get_global_server_args().disable_shared_experts_fusion = True
global_server_args_dict["disable_shared_experts_fusion"] = True
self.num_fused_shared_experts = 0
log_info_on_rank0(
logger,
......
......@@ -33,9 +33,9 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix, is_cuda, make_layers
logger = logging.getLogger(__name__)
......@@ -483,7 +483,7 @@ class FalconH1ForCausalLM(nn.Module):
quant_config=quant_config,
org_num_embeddings=config.vocab_size,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.lm_head = self.lm_head.float()
self.lm_head_multiplier = config.lm_head_multiplier
......
......@@ -56,13 +56,18 @@ from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
from sglang.srt.layers.quantization.fp8_kernel import (
is_fp8_fnuz,
per_tensor_quant_mla_fp8,
per_token_group_quant_mla_deep_gemm_masked_fp8,
)
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
......@@ -72,7 +77,6 @@ from sglang.srt.models.deepseek_v2 import (
DeepseekV2Model,
DeepseekV2MoE,
)
from sglang.srt.server_args import get_global_server_args
from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
from sglang.srt.utils import (
BumpAllocator,
......@@ -391,7 +395,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
self.n_shared_experts = config.n_shared_experts
self.num_fused_shared_experts = (
0
if get_global_server_args().disable_shared_experts_fusion
if global_server_args_dict["disable_shared_experts_fusion"]
else config.n_shared_experts
)
self.config = config
......@@ -428,7 +432,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
self.experts = get_moe_impl_class(quant_config)(
num_experts=config.n_routed_experts
+ self.num_fused_shared_experts
+ get_global_server_args().ep_num_redundant_experts,
+ global_server_args_dict["ep_num_redundant_experts"],
num_fused_shared_experts=self.num_fused_shared_experts,
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
hidden_size=config.hidden_size,
......@@ -472,7 +476,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
self.ep_size = get_moe_expert_parallel_world_size()
self.num_experts = (
config.n_routed_experts
+ get_global_server_args().ep_num_redundant_experts
+ global_server_args_dict["ep_num_redundant_experts"]
)
self.renormalize = config.norm_topk_prob
self.topk_group = config.topk_group
......@@ -754,7 +758,7 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
......@@ -770,7 +774,7 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
self, architecture: str = "Glm4MoeForCausalLM"
):
self.num_fused_shared_experts = 0
if get_global_server_args().disable_shared_experts_fusion:
if global_server_args_dict["disable_shared_experts_fusion"]:
return
# Only Deepseek V3/R1 can use shared experts fusion optimization now.
......@@ -786,7 +790,7 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
disable_reason = "Deepseek and GLM-4.5 or GLM-4.6 can not use shared experts fusion optimization under expert parallelism."
if disable_reason is not None:
get_global_server_args().disable_shared_experts_fusion = True
global_server_args_dict["disable_shared_experts_fusion"] = True
self.num_fused_shared_experts = 0
log_info_on_rank0(
logger,
......
......@@ -30,9 +30,9 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.glm4_moe import Glm4MoeDecoderLayer, Glm4MoeForCausalLM
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import BumpAllocator, add_prefix
logger = logging.getLogger(__name__)
......@@ -145,7 +145,7 @@ class Glm4MoeForCausalLMNextN(Glm4MoeForCausalLM):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("model.shared_head.head", prefix),
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
......
......@@ -16,10 +16,10 @@ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.glm4_moe import Glm4MoeModel
from sglang.srt.models.glm4v import Glm4vForConditionalGeneration, Glm4vVisionModel
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix, is_cuda, log_info_on_rank0
from sglang.srt.utils.hf_transformers_utils import get_processor
......@@ -47,7 +47,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
self.num_fused_shared_experts = (
0
if get_global_server_args().disable_shared_experts_fusion
if global_server_args_dict["disable_shared_experts_fusion"]
else config.n_shared_experts
)
......@@ -68,7 +68,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
......@@ -81,7 +81,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
self, architecture: str = "Glm4MoeForCausalLM"
):
self.num_fused_shared_experts = 0
if get_global_server_args().disable_shared_experts_fusion:
if global_server_args_dict["disable_shared_experts_fusion"]:
return
# Only Deepseek V3/R1 can use shared experts fusion optimization now.
......@@ -97,7 +97,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization under expert parallelism."
if disable_reason is not None:
get_global_server_args().disable_shared_experts_fusion = True
global_server_args_dict["disable_shared_experts_fusion"] = True
self.num_fused_shared_experts = 0
log_info_on_rank0(
logger,
......
......@@ -63,13 +63,13 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.utils import (
create_fused_set_kv_buffer_arg,
enable_fused_set_kv_buffer,
)
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import (
LazyValue,
add_prefix,
......@@ -138,7 +138,7 @@ class GptOssSparseMoeBlock(nn.Module):
}
self.experts = experts_type(
num_experts=config.num_local_experts
+ get_global_server_args().ep_num_redundant_experts,
+ global_server_args_dict["ep_num_redundant_experts"],
top_k=config.num_experts_per_tok,
layer_id=layer_id,
hidden_size=config.hidden_size,
......@@ -259,7 +259,7 @@ class GptOssAttention(nn.Module):
# Choose dtype of sinks based on attention backend: trtllm_mha requires float32,
# others can use bfloat16
attn_backend = get_global_server_args().attention_backend
attn_backend = global_server_args_dict.get("attention_backend")
sinks_dtype = torch.float32 if attn_backend == "trtllm_mha" else torch.bfloat16
self.sinks = nn.Parameter(
torch.empty(self.num_heads, dtype=sinks_dtype), requires_grad=False
......@@ -591,7 +591,7 @@ class GptOssForCausalLM(nn.Module):
config.hidden_size,
# quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
self.capture_aux_hidden_states = False
......
......@@ -28,6 +28,7 @@ from torch import nn
from transformers import PretrainedConfig
from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
......@@ -35,6 +36,7 @@ from sglang.srt.distributed import (
)
from sglang.srt.layers.activation import GeluAndMul
from sglang.srt.layers.elementwise import (
experts_combine_triton,
fused_dual_residual_rmsnorm,
fused_rmsnorm,
gelu_and_mul_triton,
......@@ -62,10 +64,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.loader import DefaultModelLoader
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix, dispose_tensor, dump_to_file
logger = logging.getLogger(__name__)
......@@ -864,10 +866,10 @@ class Grok1ForCausalLM(nn.Module):
# Dump tensors for debugging
global debug_tensor_dump_output_folder, debug_tensor_dump_inject
debug_tensor_dump_output_folder = (
get_global_server_args().debug_tensor_dump_output_folder
)
debug_tensor_dump_inject = get_global_server_args().debug_tensor_dump_inject
debug_tensor_dump_output_folder = global_server_args_dict[
"debug_tensor_dump_output_folder"
]
debug_tensor_dump_inject = global_server_args_dict["debug_tensor_dump_inject"]
warnings.filterwarnings("ignore", category=FutureWarning)
if get_tensor_model_parallel_rank() == 0:
......
......@@ -45,13 +45,13 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
kv_cache_scales_loader,
maybe_remap_kv_scale_name,
)
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix, make_layers
from sglang.utils import get_exception_traceback
......@@ -433,7 +433,7 @@ class LlamaForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
......
......@@ -32,10 +32,14 @@
import concurrent.futures
import logging
from typing import Iterable, Optional, Tuple
import os
from enum import IntEnum, auto
from typing import Any, Dict, Iterable, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm
from sglang.srt.configs import LongcatFlashConfig
from sglang.srt.distributed import (
......@@ -81,10 +85,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek_v2 import DeepseekV2AttentionMLA
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import (
BumpAllocator,
LazyValue,
......@@ -591,7 +595,7 @@ class LongcatFlashForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
......
......@@ -31,9 +31,9 @@ from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
global_server_args_dict,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import is_cpu
_is_cpu = is_cpu()
......@@ -448,7 +448,7 @@ class Llama4ForConditionalGeneration(nn.Module):
)
self.has_vision = (
self.has_vision_weights and get_global_server_args().enable_multimodal
self.has_vision_weights and global_server_args_dict["enable_multimodal"]
)
if self.has_vision:
......
......@@ -64,10 +64,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.server_args import get_global_server_args
from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
from sglang.srt.utils import add_prefix, is_cuda, make_layers
......@@ -156,7 +156,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
layer_id=self.layer_id,
top_k=config.num_experts_per_tok,
num_experts=config.num_experts
+ get_global_server_args().ep_num_redundant_experts,
+ global_server_args_dict["ep_num_redundant_experts"],
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
quant_config=quant_config,
......@@ -192,7 +192,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
# TODO: we will support tp < ep in the future
self.ep_size = get_moe_expert_parallel_world_size()
self.num_experts = (
config.num_experts + get_global_server_args().ep_num_redundant_experts
config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
)
self.top_k = config.num_experts_per_tok
......@@ -643,7 +643,7 @@ class Qwen2MoeForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
# For EAGLE3 support
......
......@@ -54,6 +54,7 @@ from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding, get_rope
from sglang.srt.layers.utils import get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
......@@ -63,7 +64,6 @@ from sglang.srt.models.utils import (
create_fused_set_kv_buffer_arg,
enable_fused_set_kv_buffer,
)
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import (
add_prefix,
is_cuda,
......@@ -104,7 +104,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
self.experts = get_moe_impl_class(quant_config)(
num_experts=config.num_experts
+ get_global_server_args().ep_num_redundant_experts,
+ global_server_args_dict["ep_num_redundant_experts"],
top_k=config.num_experts_per_tok,
layer_id=layer_id,
hidden_size=config.hidden_size,
......@@ -125,7 +125,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
# TODO: we will support tp < ep in the future
self.ep_size = get_moe_expert_parallel_world_size()
self.num_experts = (
config.num_experts + get_global_server_args().ep_num_redundant_experts
config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
)
self.top_k = config.num_experts_per_tok
......@@ -693,7 +693,7 @@ class Qwen3MoeForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
self.capture_aux_hidden_states = False
......