Unverified Commit 1083e7e3 authored by Liangsheng Yin, committed by GitHub

Deprecate `global_server_args_dict` (#11331)
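
This commit replaces every read of the mutable `global_server_args_dict` (owned by `schedule_batch`) with the typed accessor `get_global_server_args()` from `sglang.srt.server_args`. A minimal before/after sketch of the migration pattern, using a key and field that appear in this diff:

```python
# Before: values were read out of a mutable dict owned by schedule_batch.
from sglang.srt.managers.schedule_batch import global_server_args_dict

use_mla = global_server_args_dict["use_mla_backend"]

# After: fields are read off the ServerArgs instance registered by the scheduler.
from sglang.srt.server_args import get_global_server_args

use_mla = get_global_server_args().use_mla_backend
```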

parent 2157d12a
......@@ -39,7 +39,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import (
......@@ -47,6 +46,7 @@ from sglang.srt.model_loader.weight_utils import (
sharded_weight_loader,
)
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP, Qwen2MoeSparseMoeBlock
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import (
LazyValue,
add_prefix,
......@@ -905,7 +905,7 @@ class Qwen3NextForCausalLM(nn.Module):
quant_config=quant_config,
org_num_embeddings=config.vocab_size,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
)
self.lm_head = self.lm_head.float()
self.logits_processor = LogitsProcessor(config)
......
......@@ -21,14 +21,13 @@ from torch import nn
from transformers import PretrainedConfig
from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
from sglang.srt.layers.layernorm import GemmaRMSNorm, RMSNorm
from sglang.srt.layers.layernorm import GemmaRMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.qwen3_moe import Qwen3MoeModel
from sglang.srt.models.qwen3_next import Qwen3NextForCausalLM, Qwen3NextModel
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix
logger = logging.getLogger(__name__)
......@@ -69,7 +68,7 @@ class Qwen3NextForCausalLMMTP(Qwen3NextForCausalLM):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("model.shared_head.head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
)
self.logits_processor = LogitsProcessor(config)
......
......@@ -38,20 +38,12 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.utils import get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternMultimodalTokens,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import (
MultimodalDataItem,
MultimodalInputs,
global_server_args_dict,
)
from sglang.srt.managers.mm_utils import general_mm_embed_routine
from sglang.srt.managers.schedule_batch import MultimodalDataItem
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel
from sglang.srt.models.qwen3_moe import Qwen3MoeModel
from sglang.srt.models.qwen3_vl import (
Qwen3_VisionTransformer,
Qwen3VLForConditionalGeneration,
......
......@@ -57,7 +57,6 @@ from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
global_server_args_dict,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
......@@ -300,7 +299,7 @@ class Step3TextDecoderLayer(nn.Module):
# self.n_shared_experts = 1
# self.num_fused_shared_experts = (
# 0
# if global_server_args_dict["disable_shared_experts_fusion"]
# if global_server_args.disable_shared_experts_fusion
# else self.n_shared_experts
# )
self.num_fused_shared_experts = 0
......@@ -774,7 +773,7 @@ class Step3VLForConditionalGeneration(nn.Module):
# self.n_shared_experts = 1
# self.num_fused_shared_experts = (
# 0
# if global_server_args_dict["disable_shared_experts_fusion"]
# if global_server_args.disable_shared_experts_fusion
# else self.n_shared_experts
# )
self.num_fused_shared_experts = 0
......
......@@ -2,7 +2,6 @@ from __future__ import annotations
import dataclasses
import logging
import threading
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
import torch
......@@ -10,6 +9,7 @@ import torch
import sglang.srt.sampling.penaltylib as penaltylib
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
from sglang.srt.sampling.sampling_params import TOP_K_ALL
from sglang.srt.server_args import get_global_server_args
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import ScheduleBatch
......@@ -66,16 +66,10 @@ class SamplingBatchInfo:
# Handle logit bias
logit_bias: Optional[torch.Tensor] = None
@classmethod
def _get_global_server_args_dict(cls):
from sglang.srt.managers.schedule_batch import global_server_args_dict
return global_server_args_dict
@classmethod
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
global_server_args_dict = cls._get_global_server_args_dict()
enable_deterministic = global_server_args_dict["enable_deterministic_inference"]
global_server_args = get_global_server_args()
enable_deterministic = global_server_args.enable_deterministic_inference
reqs = batch.reqs
device = batch.device
......@@ -112,10 +106,9 @@ class SamplingBatchInfo:
logit_bias[i, int(key)] = value
# Check if any request has custom logit processor
has_custom_logit_processor = global_server_args_dict[
"enable_custom_logit_processor"
] and any( # check the flag first.
r.custom_logit_processor for r in reqs
has_custom_logit_processor = (
global_server_args.enable_custom_logit_processor
and any(r.custom_logit_processor for r in reqs) # check the flag first.
) # then check the requests.
if has_custom_logit_processor:
......
......@@ -53,6 +53,7 @@ from sglang.utils import is_in_ci
logger = logging.getLogger(__name__)
# Define constants
LOAD_FORMAT_CHOICES = [
"auto",
......@@ -3329,6 +3330,22 @@ class ServerArgs:
)
# NOTE: This is a global variable to hold the server args for scheduler.
_global_server_args: Optional[ServerArgs] = None


def set_global_server_args_for_scheduler(server_args: ServerArgs):
    global _global_server_args
    _global_server_args = server_args


def get_global_server_args() -> ServerArgs:
    if _global_server_args is None:
        raise ValueError("Global server args is not set yet!")
    return _global_server_args
def prepare_server_args(argv: List[str]) -> ServerArgs:
"""
Prepare the server arguments from the command line arguments.
......@@ -3363,8 +3380,8 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
raw_args = parser.parse_args(argv)
server_args = ServerArgs.from_cli_args(raw_args)
return server_args
return ServerArgs.from_cli_args(raw_args)
ZMQ_TCP_PORT_DELTA = 233
......
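
For reference, a short sketch of the new lifecycle: the scheduler registers its parsed `ServerArgs` once, and all later readers use the typed accessor. This borrows the `__post_init__` stub from the `TestLMHeadFP32` change further down, since constructing `ServerArgs` with a placeholder model path would otherwise run validation; with a real model path (as in the quantization test below) no stub is needed.

```python
from sglang.srt.server_args import (
    ServerArgs,
    get_global_server_args,
    set_global_server_args_for_scheduler,
)

# Skip ServerArgs validation so a placeholder model path can be used here.
ServerArgs.__post_init__ = lambda self: None

# The scheduler registers its parsed args once at startup...
set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))

# ...and any later caller reads typed fields instead of dict keys.
assert get_global_server_args().model_path == "dummy"
```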
......@@ -6,7 +6,6 @@ import torch
from sglang.srt.layers.moe import get_moe_runner_backend
from sglang.srt.layers.moe.utils import is_sbo_enabled
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import get_int_env_var
......
......@@ -11,7 +11,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import apply_custom_logit_processor
from sglang.srt.managers.overlap_utils import FutureIndices
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.common import (
alloc_paged_token_slots_extend,
......@@ -19,6 +19,7 @@ from sglang.srt.mem_cache.common import (
get_last_loc,
)
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
from sglang.srt.server_args import get_global_server_args
from sglang.srt.speculative.eagle_info_v2 import (
EagleDraftInputV2Mixin,
EagleVerifyInputV2Mixin,
......@@ -332,12 +333,8 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
uniform_samples_for_final_sampling=coins_for_final_sampling,
target_probs=target_probs,
draft_probs=draft_probs,
threshold_single=global_server_args_dict[
"speculative_accept_threshold_single"
],
threshold_acc=global_server_args_dict[
"speculative_accept_threshold_acc"
],
threshold_single=get_global_server_args().speculative_accept_threshold_single,
threshold_acc=get_global_server_args().speculative_accept_threshold_acc,
deterministic=True,
)
......
......@@ -11,7 +11,6 @@ import triton.language as tl
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.scheduler import global_server_args_dict
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
......@@ -19,6 +18,7 @@ from sglang.srt.model_executor.forward_batch_info import (
ForwardMode,
)
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import get_global_server_args
from sglang.srt.speculative.build_eagle_tree import TreeMaskMode
from sglang.srt.speculative.spec_utils import (
SIMULATE_ACC_LEN,
......@@ -265,12 +265,8 @@ class EagleVerifyInputV2Mixin:
uniform_samples_for_final_sampling=coins_for_final_sampling,
target_probs=target_probs,
draft_probs=draft_probs,
threshold_single=global_server_args_dict[
"speculative_accept_threshold_single"
],
threshold_acc=global_server_args_dict[
"speculative_accept_threshold_acc"
],
threshold_single=get_global_server_args().speculative_accept_threshold_single,
threshold_acc=get_global_server_args().speculative_accept_threshold_acc,
deterministic=True,
)
......
......@@ -14,7 +14,7 @@ from sglang.srt.distributed import (
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.managers.scheduler import GenerationBatchResult
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.mem_cache.common import (
......@@ -27,7 +27,7 @@ from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch,
ForwardMode,
)
from sglang.srt.server_args import ServerArgs
from sglang.srt.server_args import ServerArgs, get_global_server_args
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
EAGLEDraftCudaGraphRunner,
......@@ -261,7 +261,7 @@ class EAGLEWorker(TpModelWorker):
)
def _create_flashinfer_decode_backend(self):
if not global_server_args_dict["use_mla_backend"]:
if not get_global_server_args().use_mla_backend:
from sglang.srt.layers.attention.flashinfer_backend import (
FlashInferMultiStepDraftBackend,
)
......@@ -325,7 +325,7 @@ class EAGLEWorker(TpModelWorker):
)
def _create_trtllm_mla_decode_backend(self):
if not global_server_args_dict["use_mla_backend"]:
if not get_global_server_args().use_mla_backend:
raise ValueError(
"trtllm_mla backend requires MLA model (use_mla_backend=True)."
)
......@@ -340,7 +340,7 @@ class EAGLEWorker(TpModelWorker):
)
def _create_flashinfer_prefill_backend(self):
if not global_server_args_dict["use_mla_backend"]:
if not get_global_server_args().use_mla_backend:
from sglang.srt.layers.attention.flashinfer_backend import (
FlashInferAttnBackend,
)
......@@ -376,7 +376,7 @@ class EAGLEWorker(TpModelWorker):
return TRTLLMHAAttnBackend(self.draft_model_runner, skip_prefill=False)
def _create_trtllm_mla_prefill_backend(self):
if not global_server_args_dict["use_mla_backend"]:
if not get_global_server_args().use_mla_backend:
raise ValueError(
"trtllm_mla backend requires MLA model (use_mla_backend=True)."
)
......
......@@ -7,6 +7,8 @@ from typing import Optional, Tuple
import torch
import triton
from sglang.srt.server_args import get_global_server_args
logger = logging.getLogger(__name__)
from dataclasses import dataclass
......@@ -16,7 +18,7 @@ import torch.nn.functional as F
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import apply_custom_logit_processor
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.mem_cache.common import (
alloc_paged_token_slots_extend,
alloc_token_slots,
......@@ -350,10 +352,8 @@ class NgramVerifyInput(SpecInput):
uniform_samples_for_final_sampling=coins_for_final_sampling,
target_probs=target_probs,
draft_probs=draft_probs,
threshold_single=global_server_args_dict[
"speculative_accept_threshold_single"
],
threshold_acc=global_server_args_dict["speculative_accept_threshold_acc"],
threshold_single=get_global_server_args().speculative_accept_threshold_single,
threshold_acc=get_global_server_args().speculative_accept_threshold_acc,
deterministic=True,
)
......
......@@ -22,7 +22,7 @@ from sglang.srt.layers.moe import (
)
from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch,
ForwardMode,
......@@ -30,6 +30,7 @@ from sglang.srt.model_executor.forward_batch_info import (
)
from sglang.srt.operations import execute_operations, execute_overlapped_operations
from sglang.srt.operations_strategy import OperationsStrategy
from sglang.srt.server_args import get_global_server_args
from sglang.srt.speculative.spec_info import SpecInput
from sglang.srt.utils import BumpAllocator, empty_context, get_bool_env_var, is_hip
......@@ -153,7 +154,7 @@ def _update_device_and_sum_field_from_cpu_field(
cpu_value
if isinstance(cpu_value, torch.Tensor)
else torch.tensor(cpu_value, dtype=old_device_value.dtype)
).to(device=global_server_args_dict["device"], non_blocking=True)
).to(device=get_global_server_args().device, non_blocking=True)
setattr(batch, device_field, new_device_value)
if sum_field is not None:
......@@ -582,7 +583,7 @@ class TboForwardBatchPreparer:
sum_field=None,
)
_, child_b.extend_start_loc = compute_position(
global_server_args_dict["attention_backend"],
get_global_server_args().attention_backend,
child_b.extend_prefix_lens,
child_b.extend_seq_lens,
child_b.extend_num_tokens,
......@@ -687,7 +688,7 @@ class TboForwardBatchPreparer:
# TODO improve, e.g. unify w/ `init_raw`
if (
global_server_args_dict["moe_dense_tp_size"] == 1
get_global_server_args().moe_dense_tp_size == 1
and batch.global_dp_buffer_len is not None
):
sum_len = end_token_index - start_token_index
......@@ -755,7 +756,7 @@ class TboForwardBatchPreparer:
value_a = min(tbo_split_token_index, num_token_non_padded)
value_b = max(0, num_token_non_padded - tbo_split_token_index)
return torch.tensor([value_a, value_b], dtype=torch.int32).to(
device=global_server_args_dict["device"], non_blocking=True
device=get_global_server_args().device, non_blocking=True
)
@classmethod
......
......@@ -7,7 +7,11 @@ import torch.nn as nn
import torch.nn.functional as F
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.server_args import (
ServerArgs,
get_global_server_args,
set_global_server_args_for_scheduler,
)
class LMHeadStub(nn.Module):
......@@ -32,8 +36,10 @@ class TestLMHeadFP32(unittest.TestCase):
raise unittest.SkipTest("needs CUDA GPU")
def _make_logprocessor(self, vocab_size, enable_fp32):
global_server_args_dict["enable_dp_lm_head"] = False
global_server_args_dict["enable_fp32_lm_head"] = enable_fp32
ServerArgs.__post_init__ = lambda self: None # disable validation
set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))
get_global_server_args().enable_dp_lm_head = False
get_global_server_args().enable_fp32_lm_head = enable_fp32
cfg = SimpleNamespace(vocab_size=vocab_size, final_logit_softcapping=None)
return LogitsProcessor(cfg, skip_all_gather=True, logit_scale=None)
......
......@@ -4,6 +4,7 @@ import unittest
import requests
import torch
from sglang.srt.server_args import set_global_server_args_for_scheduler
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
......@@ -16,17 +17,15 @@ from sglang.test.test_utils import (
def check_quant_method(model_path: str, use_marlin_kernel: bool):
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.distributed import (
get_tp_group,
init_distributed_environment,
initialize_model_parallel,
set_custom_all_reduce,
)
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
from sglang.srt.layers.quantization.utils import get_dynamic_override
from sglang.srt.model_loader import get_model
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.server_args import ServerArgs
try:
init_distributed_environment(
......@@ -43,6 +42,7 @@ def check_quant_method(model_path: str, use_marlin_kernel: bool):
pass
server_args = ServerArgs(model_path=model_path, dtype=torch.float16)
set_global_server_args_for_scheduler(server_args)
model_config = ModelConfig.from_server_args(server_args)
load_config = LoadConfig()
......