Unverified Commit 1083e7e3 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Deprecate `global_server_args_dict` (#11331)

parent 2157d12a
...@@ -39,7 +39,6 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ...@@ -39,7 +39,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import ( from sglang.srt.model_loader.weight_utils import (
...@@ -47,6 +46,7 @@ from sglang.srt.model_loader.weight_utils import ( ...@@ -47,6 +46,7 @@ from sglang.srt.model_loader.weight_utils import (
sharded_weight_loader, sharded_weight_loader,
) )
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP, Qwen2MoeSparseMoeBlock from sglang.srt.models.qwen2_moe import Qwen2MoeMLP, Qwen2MoeSparseMoeBlock
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import ( from sglang.srt.utils import (
LazyValue, LazyValue,
add_prefix, add_prefix,
...@@ -905,7 +905,7 @@ class Qwen3NextForCausalLM(nn.Module): ...@@ -905,7 +905,7 @@ class Qwen3NextForCausalLM(nn.Module):
quant_config=quant_config, quant_config=quant_config,
org_num_embeddings=config.vocab_size, org_num_embeddings=config.vocab_size,
prefix=add_prefix("lm_head", prefix), prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
) )
self.lm_head = self.lm_head.float() self.lm_head = self.lm_head.float()
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
......
...@@ -21,14 +21,13 @@ from torch import nn ...@@ -21,14 +21,13 @@ from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
from sglang.srt.layers.layernorm import GemmaRMSNorm, RMSNorm from sglang.srt.layers.layernorm import GemmaRMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.qwen3_moe import Qwen3MoeModel
from sglang.srt.models.qwen3_next import Qwen3NextForCausalLM, Qwen3NextModel from sglang.srt.models.qwen3_next import Qwen3NextForCausalLM, Qwen3NextModel
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix from sglang.srt.utils import add_prefix
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -69,7 +68,7 @@ class Qwen3NextForCausalLMMTP(Qwen3NextForCausalLM): ...@@ -69,7 +68,7 @@ class Qwen3NextForCausalLMMTP(Qwen3NextForCausalLM):
config.hidden_size, config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("model.shared_head.head", prefix), prefix=add_prefix("model.shared_head.head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
) )
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
......
...@@ -38,20 +38,12 @@ from sglang.srt.layers.logits_processor import LogitsProcessor ...@@ -38,20 +38,12 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.utils import get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.mm_utils import ( from sglang.srt.managers.mm_utils import general_mm_embed_routine
MultiModalityDataPaddingPatternMultimodalTokens, from sglang.srt.managers.schedule_batch import MultimodalDataItem
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import (
MultimodalDataItem,
MultimodalInputs,
global_server_args_dict,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel from sglang.srt.models.qwen3_moe import Qwen3MoeModel
from sglang.srt.models.qwen3_vl import ( from sglang.srt.models.qwen3_vl import (
Qwen3_VisionTransformer, Qwen3_VisionTransformer,
Qwen3VLForConditionalGeneration, Qwen3VLForConditionalGeneration,
......
...@@ -57,7 +57,6 @@ from sglang.srt.managers.schedule_batch import ( ...@@ -57,7 +57,6 @@ from sglang.srt.managers.schedule_batch import (
Modality, Modality,
MultimodalDataItem, MultimodalDataItem,
MultimodalInputs, MultimodalInputs,
global_server_args_dict,
) )
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
...@@ -300,7 +299,7 @@ class Step3TextDecoderLayer(nn.Module): ...@@ -300,7 +299,7 @@ class Step3TextDecoderLayer(nn.Module):
# self.n_shared_experts = 1 # self.n_shared_experts = 1
# self.num_fused_shared_experts = ( # self.num_fused_shared_experts = (
# 0 # 0
# if global_server_args_dict["disable_shared_experts_fusion"] # if global_server_args.disable_shared_experts_fusion
# else self.n_shared_experts # else self.n_shared_experts
# ) # )
self.num_fused_shared_experts = 0 self.num_fused_shared_experts = 0
...@@ -774,7 +773,7 @@ class Step3VLForConditionalGeneration(nn.Module): ...@@ -774,7 +773,7 @@ class Step3VLForConditionalGeneration(nn.Module):
# self.n_shared_experts = 1 # self.n_shared_experts = 1
# self.num_fused_shared_experts = ( # self.num_fused_shared_experts = (
# 0 # 0
# if global_server_args_dict["disable_shared_experts_fusion"] # if global_server_args.disable_shared_experts_fusion
# else self.n_shared_experts # else self.n_shared_experts
# ) # )
self.num_fused_shared_experts = 0 self.num_fused_shared_experts = 0
......
...@@ -2,7 +2,6 @@ from __future__ import annotations ...@@ -2,7 +2,6 @@ from __future__ import annotations
import dataclasses import dataclasses
import logging import logging
import threading
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
import torch import torch
...@@ -10,6 +9,7 @@ import torch ...@@ -10,6 +9,7 @@ import torch
import sglang.srt.sampling.penaltylib as penaltylib import sglang.srt.sampling.penaltylib as penaltylib
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
from sglang.srt.sampling.sampling_params import TOP_K_ALL from sglang.srt.sampling.sampling_params import TOP_K_ALL
from sglang.srt.server_args import get_global_server_args
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.managers.schedule_batch import ScheduleBatch
...@@ -66,16 +66,10 @@ class SamplingBatchInfo: ...@@ -66,16 +66,10 @@ class SamplingBatchInfo:
# Handle logit bias # Handle logit bias
logit_bias: Optional[torch.Tensor] = None logit_bias: Optional[torch.Tensor] = None
@classmethod
def _get_global_server_args_dict(cls):
from sglang.srt.managers.schedule_batch import global_server_args_dict
return global_server_args_dict
@classmethod @classmethod
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int): def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
global_server_args_dict = cls._get_global_server_args_dict() global_server_args = get_global_server_args()
enable_deterministic = global_server_args_dict["enable_deterministic_inference"] enable_deterministic = global_server_args.enable_deterministic_inference
reqs = batch.reqs reqs = batch.reqs
device = batch.device device = batch.device
...@@ -112,10 +106,9 @@ class SamplingBatchInfo: ...@@ -112,10 +106,9 @@ class SamplingBatchInfo:
logit_bias[i, int(key)] = value logit_bias[i, int(key)] = value
# Check if any request has custom logit processor # Check if any request has custom logit processor
has_custom_logit_processor = global_server_args_dict[ has_custom_logit_processor = (
"enable_custom_logit_processor" global_server_args.enable_custom_logit_processor
] and any( # check the flag first. and any(r.custom_logit_processor for r in reqs) # check the flag first.
r.custom_logit_processor for r in reqs
) # then check the requests. ) # then check the requests.
if has_custom_logit_processor: if has_custom_logit_processor:
......
...@@ -53,6 +53,7 @@ from sglang.utils import is_in_ci ...@@ -53,6 +53,7 @@ from sglang.utils import is_in_ci
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Define constants # Define constants
LOAD_FORMAT_CHOICES = [ LOAD_FORMAT_CHOICES = [
"auto", "auto",
...@@ -3329,6 +3330,22 @@ class ServerArgs: ...@@ -3329,6 +3330,22 @@ class ServerArgs:
) )
# NOTE: This is a global variable to hold the server args for scheduler.
_global_server_args: Optional[ServerArgs] = None
def set_global_server_args_for_scheduler(server_args: ServerArgs):
global _global_server_args
_global_server_args = server_args
def get_global_server_args() -> ServerArgs:
if _global_server_args is None:
raise ValueError("Global server args is not set yet!")
return _global_server_args
def prepare_server_args(argv: List[str]) -> ServerArgs: def prepare_server_args(argv: List[str]) -> ServerArgs:
""" """
Prepare the server arguments from the command line arguments. Prepare the server arguments from the command line arguments.
...@@ -3363,8 +3380,8 @@ def prepare_server_args(argv: List[str]) -> ServerArgs: ...@@ -3363,8 +3380,8 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser) ServerArgs.add_cli_args(parser)
raw_args = parser.parse_args(argv) raw_args = parser.parse_args(argv)
server_args = ServerArgs.from_cli_args(raw_args)
return server_args return ServerArgs.from_cli_args(raw_args)
ZMQ_TCP_PORT_DELTA = 233 ZMQ_TCP_PORT_DELTA = 233
......
...@@ -6,7 +6,6 @@ import torch ...@@ -6,7 +6,6 @@ import torch
from sglang.srt.layers.moe import get_moe_runner_backend from sglang.srt.layers.moe import get_moe_runner_backend
from sglang.srt.layers.moe.utils import is_sbo_enabled from sglang.srt.layers.moe.utils import is_sbo_enabled
from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import get_int_env_var from sglang.srt.utils import get_int_env_var
......
...@@ -11,7 +11,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito ...@@ -11,7 +11,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito
from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import apply_custom_logit_processor from sglang.srt.layers.sampler import apply_custom_logit_processor
from sglang.srt.managers.overlap_utils import FutureIndices from sglang.srt.managers.overlap_utils import FutureIndices
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.common import ( from sglang.srt.mem_cache.common import (
alloc_paged_token_slots_extend, alloc_paged_token_slots_extend,
...@@ -19,6 +19,7 @@ from sglang.srt.mem_cache.common import ( ...@@ -19,6 +19,7 @@ from sglang.srt.mem_cache.common import (
get_last_loc, get_last_loc,
) )
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
from sglang.srt.server_args import get_global_server_args
from sglang.srt.speculative.eagle_info_v2 import ( from sglang.srt.speculative.eagle_info_v2 import (
EagleDraftInputV2Mixin, EagleDraftInputV2Mixin,
EagleVerifyInputV2Mixin, EagleVerifyInputV2Mixin,
...@@ -332,12 +333,8 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin): ...@@ -332,12 +333,8 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
uniform_samples_for_final_sampling=coins_for_final_sampling, uniform_samples_for_final_sampling=coins_for_final_sampling,
target_probs=target_probs, target_probs=target_probs,
draft_probs=draft_probs, draft_probs=draft_probs,
threshold_single=global_server_args_dict[ threshold_single=get_global_server_args().speculative_accept_threshold_single,
"speculative_accept_threshold_single" threshold_acc=get_global_server_args().speculative_accept_threshold_acc,
],
threshold_acc=global_server_args_dict[
"speculative_accept_threshold_acc"
],
deterministic=True, deterministic=True,
) )
......
...@@ -11,7 +11,6 @@ import triton.language as tl ...@@ -11,7 +11,6 @@ import triton.language as tl
from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import ModelWorkerBatch from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.scheduler import global_server_args_dict
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ( from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode, CaptureHiddenMode,
...@@ -19,6 +18,7 @@ from sglang.srt.model_executor.forward_batch_info import ( ...@@ -19,6 +18,7 @@ from sglang.srt.model_executor.forward_batch_info import (
ForwardMode, ForwardMode,
) )
from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import get_global_server_args
from sglang.srt.speculative.build_eagle_tree import TreeMaskMode from sglang.srt.speculative.build_eagle_tree import TreeMaskMode
from sglang.srt.speculative.spec_utils import ( from sglang.srt.speculative.spec_utils import (
SIMULATE_ACC_LEN, SIMULATE_ACC_LEN,
...@@ -265,12 +265,8 @@ class EagleVerifyInputV2Mixin: ...@@ -265,12 +265,8 @@ class EagleVerifyInputV2Mixin:
uniform_samples_for_final_sampling=coins_for_final_sampling, uniform_samples_for_final_sampling=coins_for_final_sampling,
target_probs=target_probs, target_probs=target_probs,
draft_probs=draft_probs, draft_probs=draft_probs,
threshold_single=global_server_args_dict[ threshold_single=get_global_server_args().speculative_accept_threshold_single,
"speculative_accept_threshold_single" threshold_acc=get_global_server_args().speculative_accept_threshold_acc,
],
threshold_acc=global_server_args_dict[
"speculative_accept_threshold_acc"
],
deterministic=True, deterministic=True,
) )
......
...@@ -14,7 +14,7 @@ from sglang.srt.distributed import ( ...@@ -14,7 +14,7 @@ from sglang.srt.distributed import (
) )
from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.managers.scheduler import GenerationBatchResult from sglang.srt.managers.scheduler import GenerationBatchResult
from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.mem_cache.common import ( from sglang.srt.mem_cache.common import (
...@@ -27,7 +27,7 @@ from sglang.srt.model_executor.forward_batch_info import ( ...@@ -27,7 +27,7 @@ from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch, ForwardBatch,
ForwardMode, ForwardMode,
) )
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs, get_global_server_args
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import ( from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
EAGLEDraftCudaGraphRunner, EAGLEDraftCudaGraphRunner,
...@@ -261,7 +261,7 @@ class EAGLEWorker(TpModelWorker): ...@@ -261,7 +261,7 @@ class EAGLEWorker(TpModelWorker):
) )
def _create_flashinfer_decode_backend(self): def _create_flashinfer_decode_backend(self):
if not global_server_args_dict["use_mla_backend"]: if not get_global_server_args().use_mla_backend:
from sglang.srt.layers.attention.flashinfer_backend import ( from sglang.srt.layers.attention.flashinfer_backend import (
FlashInferMultiStepDraftBackend, FlashInferMultiStepDraftBackend,
) )
...@@ -325,7 +325,7 @@ class EAGLEWorker(TpModelWorker): ...@@ -325,7 +325,7 @@ class EAGLEWorker(TpModelWorker):
) )
def _create_trtllm_mla_decode_backend(self): def _create_trtllm_mla_decode_backend(self):
if not global_server_args_dict["use_mla_backend"]: if not get_global_server_args().use_mla_backend:
raise ValueError( raise ValueError(
"trtllm_mla backend requires MLA model (use_mla_backend=True)." "trtllm_mla backend requires MLA model (use_mla_backend=True)."
) )
...@@ -340,7 +340,7 @@ class EAGLEWorker(TpModelWorker): ...@@ -340,7 +340,7 @@ class EAGLEWorker(TpModelWorker):
) )
def _create_flashinfer_prefill_backend(self): def _create_flashinfer_prefill_backend(self):
if not global_server_args_dict["use_mla_backend"]: if not get_global_server_args().use_mla_backend:
from sglang.srt.layers.attention.flashinfer_backend import ( from sglang.srt.layers.attention.flashinfer_backend import (
FlashInferAttnBackend, FlashInferAttnBackend,
) )
...@@ -376,7 +376,7 @@ class EAGLEWorker(TpModelWorker): ...@@ -376,7 +376,7 @@ class EAGLEWorker(TpModelWorker):
return TRTLLMHAAttnBackend(self.draft_model_runner, skip_prefill=False) return TRTLLMHAAttnBackend(self.draft_model_runner, skip_prefill=False)
def _create_trtllm_mla_prefill_backend(self): def _create_trtllm_mla_prefill_backend(self):
if not global_server_args_dict["use_mla_backend"]: if not get_global_server_args().use_mla_backend:
raise ValueError( raise ValueError(
"trtllm_mla backend requires MLA model (use_mla_backend=True)." "trtllm_mla backend requires MLA model (use_mla_backend=True)."
) )
......
...@@ -7,6 +7,8 @@ from typing import Optional, Tuple ...@@ -7,6 +7,8 @@ from typing import Optional, Tuple
import torch import torch
import triton import triton
from sglang.srt.server_args import get_global_server_args
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
from dataclasses import dataclass from dataclasses import dataclass
...@@ -16,7 +18,7 @@ import torch.nn.functional as F ...@@ -16,7 +18,7 @@ import torch.nn.functional as F
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import apply_custom_logit_processor from sglang.srt.layers.sampler import apply_custom_logit_processor
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.mem_cache.common import ( from sglang.srt.mem_cache.common import (
alloc_paged_token_slots_extend, alloc_paged_token_slots_extend,
alloc_token_slots, alloc_token_slots,
...@@ -350,10 +352,8 @@ class NgramVerifyInput(SpecInput): ...@@ -350,10 +352,8 @@ class NgramVerifyInput(SpecInput):
uniform_samples_for_final_sampling=coins_for_final_sampling, uniform_samples_for_final_sampling=coins_for_final_sampling,
target_probs=target_probs, target_probs=target_probs,
draft_probs=draft_probs, draft_probs=draft_probs,
threshold_single=global_server_args_dict[ threshold_single=get_global_server_args().speculative_accept_threshold_single,
"speculative_accept_threshold_single" threshold_acc=get_global_server_args().speculative_accept_threshold_acc,
],
threshold_acc=global_server_args_dict["speculative_accept_threshold_acc"],
deterministic=True, deterministic=True,
) )
......
...@@ -22,7 +22,7 @@ from sglang.srt.layers.moe import ( ...@@ -22,7 +22,7 @@ from sglang.srt.layers.moe import (
) )
from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import ( from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch, ForwardBatch,
ForwardMode, ForwardMode,
...@@ -30,6 +30,7 @@ from sglang.srt.model_executor.forward_batch_info import ( ...@@ -30,6 +30,7 @@ from sglang.srt.model_executor.forward_batch_info import (
) )
from sglang.srt.operations import execute_operations, execute_overlapped_operations from sglang.srt.operations import execute_operations, execute_overlapped_operations
from sglang.srt.operations_strategy import OperationsStrategy from sglang.srt.operations_strategy import OperationsStrategy
from sglang.srt.server_args import get_global_server_args
from sglang.srt.speculative.spec_info import SpecInput from sglang.srt.speculative.spec_info import SpecInput
from sglang.srt.utils import BumpAllocator, empty_context, get_bool_env_var, is_hip from sglang.srt.utils import BumpAllocator, empty_context, get_bool_env_var, is_hip
...@@ -153,7 +154,7 @@ def _update_device_and_sum_field_from_cpu_field( ...@@ -153,7 +154,7 @@ def _update_device_and_sum_field_from_cpu_field(
cpu_value cpu_value
if isinstance(cpu_value, torch.Tensor) if isinstance(cpu_value, torch.Tensor)
else torch.tensor(cpu_value, dtype=old_device_value.dtype) else torch.tensor(cpu_value, dtype=old_device_value.dtype)
).to(device=global_server_args_dict["device"], non_blocking=True) ).to(device=get_global_server_args().device, non_blocking=True)
setattr(batch, device_field, new_device_value) setattr(batch, device_field, new_device_value)
if sum_field is not None: if sum_field is not None:
...@@ -582,7 +583,7 @@ class TboForwardBatchPreparer: ...@@ -582,7 +583,7 @@ class TboForwardBatchPreparer:
sum_field=None, sum_field=None,
) )
_, child_b.extend_start_loc = compute_position( _, child_b.extend_start_loc = compute_position(
global_server_args_dict["attention_backend"], get_global_server_args().attention_backend,
child_b.extend_prefix_lens, child_b.extend_prefix_lens,
child_b.extend_seq_lens, child_b.extend_seq_lens,
child_b.extend_num_tokens, child_b.extend_num_tokens,
...@@ -687,7 +688,7 @@ class TboForwardBatchPreparer: ...@@ -687,7 +688,7 @@ class TboForwardBatchPreparer:
# TODO improve, e.g. unify w/ `init_raw` # TODO improve, e.g. unify w/ `init_raw`
if ( if (
global_server_args_dict["moe_dense_tp_size"] == 1 get_global_server_args().moe_dense_tp_size == 1
and batch.global_dp_buffer_len is not None and batch.global_dp_buffer_len is not None
): ):
sum_len = end_token_index - start_token_index sum_len = end_token_index - start_token_index
...@@ -755,7 +756,7 @@ class TboForwardBatchPreparer: ...@@ -755,7 +756,7 @@ class TboForwardBatchPreparer:
value_a = min(tbo_split_token_index, num_token_non_padded) value_a = min(tbo_split_token_index, num_token_non_padded)
value_b = max(0, num_token_non_padded - tbo_split_token_index) value_b = max(0, num_token_non_padded - tbo_split_token_index)
return torch.tensor([value_a, value_b], dtype=torch.int32).to( return torch.tensor([value_a, value_b], dtype=torch.int32).to(
device=global_server_args_dict["device"], non_blocking=True device=get_global_server_args().device, non_blocking=True
) )
@classmethod @classmethod
......
...@@ -7,7 +7,11 @@ import torch.nn as nn ...@@ -7,7 +7,11 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.server_args import (
ServerArgs,
get_global_server_args,
set_global_server_args_for_scheduler,
)
class LMHeadStub(nn.Module): class LMHeadStub(nn.Module):
...@@ -32,8 +36,10 @@ class TestLMHeadFP32(unittest.TestCase): ...@@ -32,8 +36,10 @@ class TestLMHeadFP32(unittest.TestCase):
raise unittest.SkipTest("needs CUDA GPU") raise unittest.SkipTest("needs CUDA GPU")
def _make_logprocessor(self, vocab_size, enable_fp32): def _make_logprocessor(self, vocab_size, enable_fp32):
global_server_args_dict["enable_dp_lm_head"] = False ServerArgs.__post_init__ = lambda self: None # disable validation
global_server_args_dict["enable_fp32_lm_head"] = enable_fp32 set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))
get_global_server_args().enable_dp_lm_head = False
get_global_server_args().enable_fp32_lm_head = enable_fp32
cfg = SimpleNamespace(vocab_size=vocab_size, final_logit_softcapping=None) cfg = SimpleNamespace(vocab_size=vocab_size, final_logit_softcapping=None)
return LogitsProcessor(cfg, skip_all_gather=True, logit_scale=None) return LogitsProcessor(cfg, skip_all_gather=True, logit_scale=None)
......
...@@ -4,6 +4,7 @@ import unittest ...@@ -4,6 +4,7 @@ import unittest
import requests import requests
import torch import torch
from sglang.srt.server_args import set_global_server_args_for_scheduler
from sglang.srt.utils import kill_process_tree from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
...@@ -16,17 +17,15 @@ from sglang.test.test_utils import ( ...@@ -16,17 +17,15 @@ from sglang.test.test_utils import (
def check_quant_method(model_path: str, use_marlin_kernel: bool): def check_quant_method(model_path: str, use_marlin_kernel: bool):
from sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import AttentionArch, ModelConfig from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.distributed import ( from sglang.srt.distributed import (
get_tp_group,
init_distributed_environment, init_distributed_environment,
initialize_model_parallel, initialize_model_parallel,
set_custom_all_reduce,
) )
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
from sglang.srt.layers.quantization.utils import get_dynamic_override from sglang.srt.layers.quantization.utils import get_dynamic_override
from sglang.srt.model_loader import get_model from sglang.srt.model_loader import get_model
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import ServerArgs
try: try:
init_distributed_environment( init_distributed_environment(
...@@ -43,6 +42,7 @@ def check_quant_method(model_path: str, use_marlin_kernel: bool): ...@@ -43,6 +42,7 @@ def check_quant_method(model_path: str, use_marlin_kernel: bool):
pass pass
server_args = ServerArgs(model_path=model_path, dtype=torch.float16) server_args = ServerArgs(model_path=model_path, dtype=torch.float16)
set_global_server_args_for_scheduler(server_args)
model_config = ModelConfig.from_server_args(server_args) model_config = ModelConfig.from_server_args(server_args)
load_config = LoadConfig() load_config = LoadConfig()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment