Unverified Commit 516738b0 authored by Liangsheng Yin, committed by GitHub

Deprecate `global_server_args_dict` (#11528)

parent 0b6f535f
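
The migration is mechanical: every lookup on the module-level `global_server_args_dict` (previously defined in `python/sglang/srt/managers/schedule_batch.py`) becomes attribute access on the object returned by `get_global_server_args()` from `sglang.srt.server_args`. A minimal before/after sketch of the pattern, using the `enable_symm_mem` flag that appears in the first hunks below (anything beyond the names visible in the diff is an assumption):

```python
# Before this change, call sites imported a module-level dict:
#
#     from sglang.srt.managers.schedule_batch import global_server_args_dict
#     enable_symm_mem = global_server_args_dict["enable_symm_mem"]
#
# After this change, they read the field straight off the global ServerArgs object
# (assuming the global server args have already been initialized by the runtime):
from sglang.srt.server_args import get_global_server_args

enable_symm_mem = get_global_server_args().enable_symm_mem
```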
......@@ -6,9 +6,6 @@
class GlobalConfig:
"""
Store some global constants.
See also python/sglang/srt/managers/schedule_batch.py::global_server_args_dict, which stores
many global runtime arguments as well.
"""
def __init__(self):
......
......@@ -5,7 +5,7 @@ from packaging import version
from torch.cuda.memory import CUDAPluggableAllocator
from sglang.srt.distributed.parallel_state import GroupCoordinator
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.server_args import get_global_server_args
nccl_allocator_source = """
#include <nccl.h>
......@@ -32,7 +32,7 @@ _graph_pool_id = None
def is_symmetric_memory_enabled():
return global_server_args_dict["enable_symm_mem"]
return get_global_server_args().enable_symm_mem
def set_graph_pool_id(graph_pool_id):
......
......@@ -18,7 +18,7 @@ from typing import Literal, Optional
import torch
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.server_args import get_global_server_args
@dataclass
......@@ -34,7 +34,7 @@ class ExpertLocationDispatchInfo:
@classmethod
def init_new(cls, layer_id: int):
ep_dispatch_algorithm = global_server_args_dict["ep_dispatch_algorithm"]
ep_dispatch_algorithm = get_global_server_args().ep_dispatch_algorithm
expert_location_metadata = get_global_expert_location_metadata()
assert expert_location_metadata is not None
......
......@@ -24,7 +24,7 @@ from sglang.srt.eplb.expert_location import (
ExpertLocationMetadata,
get_global_expert_location_metadata,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import get_bool_env_var
logger = logging.getLogger(__name__)
......@@ -97,7 +97,7 @@ def _update_expert_weights_with_canary(
canary_tensor = (
_get_canary_value(old_expert_location_metadata, layer_id)
.clone()
.to(device=global_server_args_dict["device"], non_blocking=True)
.to(device=get_global_server_args().device, non_blocking=True)
)
routed_experts_weights_of_layer[layer_id].append(canary_tensor)
......
......@@ -5,8 +5,8 @@ from typing import TYPE_CHECKING
import torch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import get_global_server_args
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
......@@ -42,7 +42,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
# TODO: Change the hard-coded block_seq_num
self.BLOCK_SEQ = 128
if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
if get_global_server_args().triton_attention_reduce_in_fp32:
self.reduce_dtype = torch.float32
else:
self.reduce_dtype = torch.float16
......
......@@ -11,8 +11,8 @@ import triton.language as tl
from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.radix_attention import AttentionType
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.server_args import get_global_server_args
from sglang.srt.speculative.spec_info import SpecInput
if TYPE_CHECKING:
......@@ -830,7 +830,7 @@ class FlashAttentionBackend(AttentionBackend):
):
# Do multi-head attention with chunked prefix cache
if forward_batch.attn_attend_prefix_cache:
assert not global_server_args_dict["disable_chunked_prefix_cache"]
assert not get_global_server_args().disable_chunked_prefix_cache
# MHA for chunked prefix kv cache when running model with MLA
assert forward_batch.prefix_chunk_idx is not None
assert forward_batch.prefix_chunk_cu_seq_lens is not None
......
......@@ -28,8 +28,8 @@ from sglang.srt.layers.attention.flashinfer_backend import (
create_flashinfer_kv_indices_triton,
)
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.server_args import get_global_server_args
from sglang.srt.speculative.spec_info import SpecInput
from sglang.srt.utils import (
is_flashinfer_available,
......@@ -193,9 +193,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
self.skip_prefill = skip_prefill
self.enable_chunk_kv = (
not skip_prefill
and global_server_args_dict["disaggregation_mode"] != "decode"
and not global_server_args_dict["disable_chunked_prefix_cache"]
and not global_server_args_dict["flashinfer_mla_disable_ragged"]
and get_global_server_args().disaggregation_mode != "decode"
and not get_global_server_args().disable_chunked_prefix_cache
and not get_global_server_args().flashinfer_mla_disable_ragged
)
self.page_size = model_runner.page_size
......@@ -306,7 +306,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
prefix_lens = forward_batch.extend_prefix_lens
extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
use_ragged = (
not global_server_args_dict["flashinfer_mla_disable_ragged"]
not get_global_server_args().flashinfer_mla_disable_ragged
and extend_no_prefix
)
......
......@@ -23,9 +23,9 @@ from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.rotary_embedding import get_rope_wrapper
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import get_global_server_args
if TYPE_CHECKING:
from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool
......@@ -162,7 +162,7 @@ class Indexer(CustomOp):
base=rope_theta, # type: ignore
rope_scaling=rope_scaling,
is_neox_style=False,
device=global_server_args_dict["device"],
device=get_global_server_args().device,
)
self.block_size = block_size
self.scale_fmt = scale_fmt
......
......@@ -2,7 +2,7 @@ import torch
import triton
import triton.language as tl
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import is_cuda, is_hip
_is_cuda = is_cuda()
......@@ -11,7 +11,7 @@ if _is_cuda:
_is_hip = is_hip()
if global_server_args_dict.get("attention_reduce_in_fp32", False):
if get_global_server_args().triton_attention_reduce_in_fp32:
REDUCE_TRITON_TYPE = tl.float32
REDUCE_TORCH_TYPE = torch.float32
else:
......
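
Note the key rename in the hunk above: the old code read `attention_reduce_in_fp32`, which is not among the `GLOBAL_SERVER_ARGS_KEYS` removed from `schedule_batch.py` further down, so the `.get(..., False)` presumably always fell back to its default; the new code reads the real `triton_attention_reduce_in_fp32` flag. Attribute access also fails loudly on a misspelled name instead of silently returning a default. A sketch of that contrast (the misspelled field is illustrative only):

```python
from sglang.srt.server_args import get_global_server_args

args = get_global_server_args()

# A typo in the old dict-based style was masked by the .get() default:
#     global_server_args_dict.get("attention_reduce_in_fp32", False)  # presumably always False
# The same mistake with attribute access raises immediately:
try:
    _ = args.attention_reduce_in_fp32          # misspelled field, illustrative only
except AttributeError:
    _ = args.triton_attention_reduce_in_fp32   # the actual flag used by the new code
```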
......@@ -20,8 +20,8 @@ from sglang.srt.layers.attention.utils import (
create_flashmla_kv_indices_triton,
)
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import is_cuda, is_flashinfer_available
if is_flashinfer_available():
......@@ -123,9 +123,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
self.forward_prefill_metadata: Optional[TRTLLMMLAPrefillMetadata] = None
self.forward_decode_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
self.disable_chunked_prefix_cache = global_server_args_dict[
"disable_chunked_prefix_cache"
]
self.disable_chunked_prefix_cache = (
get_global_server_args().disable_chunked_prefix_cache
)
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
......
......@@ -45,7 +45,7 @@ from sglang.srt.layers.linear import (
)
from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix
ROTARY_EMBED_CLASSES = {
......@@ -468,7 +468,7 @@ class VisionAttention(nn.Module):
_passed_backend = qkv_backend
qkv_backend = self._determine_attention_backend(_passed_backend)
if (
global_server_args_dict["mm_attention_backend"] is None
get_global_server_args().mm_attention_backend is None
and _passed_backend is None
):
print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
......@@ -528,7 +528,7 @@ class VisionAttention(nn.Module):
- CUDA: "triton_attn"
- Non-CUDA: "sdpa"
"""
override_backend = global_server_args_dict["mm_attention_backend"]
override_backend = get_global_server_args().mm_attention_backend
if override_backend is not None:
backend = override_backend
elif passed_backend is not None:
......
......@@ -40,8 +40,9 @@ from sglang.srt.layers.moe import (
get_moe_a2a_backend,
should_use_flashinfer_cutlass_moe_fp4_allgather,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import get_global_server_args
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import (
get_bool_env_var,
is_cuda,
......@@ -168,7 +169,7 @@ class LayerScatterModes:
def enable_moe_dense_fully_dp():
return global_server_args_dict["moe_dense_tp_size"] == 1
return get_global_server_args().moe_dense_tp_size == 1
class LayerCommunicator:
......@@ -314,7 +315,9 @@ class LayerCommunicator:
def should_fuse_mlp_allreduce_with_next_layer(
self, forward_batch: ForwardBatch
) -> bool:
speculative_algo = global_server_args_dict.get("speculative_algorithm", None)
speculative_algo = SpeculativeAlgorithm.from_string(
get_global_server_args().speculative_algorithm
)
if (
is_dp_attention_enabled()
and speculative_algo is not None
......@@ -333,7 +336,7 @@ class LayerCommunicator:
static_conditions_met = (
(not self.is_last_layer)
and (self._context.tp_size > 1)
and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
and get_global_server_args().enable_flashinfer_allreduce_fusion
and _is_flashinfer_available
)
......@@ -531,7 +534,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
(_is_sm100_supported or _is_sm90_supported)
and _is_flashinfer_available
and hasattr(layernorm, "forward_with_allreduce_fusion")
and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
and get_global_server_args().enable_flashinfer_allreduce_fusion
and hidden_states.shape[0] <= 4096
):
hidden_states, residual = layernorm.forward_with_allreduce_fusion(
......
......@@ -38,17 +38,15 @@ from sglang.srt.layers.dp_attention import (
get_dp_device,
get_dp_dtype,
get_dp_hidden_size,
get_global_dp_buffer,
get_local_attention_dp_size,
set_dp_buffer_len,
)
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
ForwardBatch,
ForwardMode,
)
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import dump_to_file, is_npu, use_intel_amx_backend
logger = logging.getLogger(__name__)
......@@ -230,8 +228,8 @@ class LogitsProcessor(nn.Module):
super().__init__()
self.config = config
self.logit_scale = logit_scale
self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
self.use_fp32_lm_head = global_server_args_dict["enable_fp32_lm_head"]
self.use_attn_tp_group = get_global_server_args().enable_dp_lm_head
self.use_fp32_lm_head = get_global_server_args().enable_fp32_lm_head
if self.use_attn_tp_group:
self.attn_tp_size = get_attention_tp_size()
self.do_tensor_parallel_all_gather = (
......@@ -254,8 +252,8 @@ class LogitsProcessor(nn.Module):
):
self.final_logit_softcapping = None
self.debug_tensor_dump_output_folder = global_server_args_dict.get(
"debug_tensor_dump_output_folder", None
self.debug_tensor_dump_output_folder = (
get_global_server_args().debug_tensor_dump_output_folder
)
def compute_logprobs_for_multi_item_scoring(
......@@ -372,9 +370,7 @@ class LogitsProcessor(nn.Module):
logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
# Check if multi-item scoring is enabled via server args (only for prefill-only requests)
multi_item_delimiter = global_server_args_dict.get(
"multi_item_scoring_delimiter"
)
multi_item_delimiter = get_global_server_args().multi_item_scoring_delimiter
if multi_item_delimiter is not None and logits_metadata.is_prefill_only:
return self.compute_logprobs_for_multi_item_scoring(
input_ids, hidden_states, lm_head, logits_metadata, multi_item_delimiter
......
......@@ -27,12 +27,10 @@ from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod
from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod
from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
from sglang.srt.utils import (
cpu_has_amx_support,
......
......@@ -31,7 +31,7 @@ from sglang.srt.layers.quantization.base_config import (
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import (
direct_register_custom_op,
is_cuda,
......@@ -265,9 +265,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
self.with_bias = False
self.use_flashinfer = get_moe_runner_backend().is_flashinfer_mxfp4()
self.flashinfer_mxfp4_moe_precision = global_server_args_dict[
"flashinfer_mxfp4_moe_precision"
]
self.flashinfer_mxfp4_moe_precision = (
get_global_server_args().flashinfer_mxfp4_moe_precision
)
self.triton_kernel_moe_forward = None
self.triton_kernel_moe_with_bias_forward = None
......
......@@ -11,8 +11,8 @@ from sglang.srt.layers.dp_attention import (
is_dp_attention_enabled,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda
if is_cuda():
......@@ -33,7 +33,7 @@ RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")
class Sampler(nn.Module):
def __init__(self):
super().__init__()
self.use_nan_detection = global_server_args_dict["enable_nan_detection"]
self.use_nan_detection = get_global_server_args().enable_nan_detection
self.tp_sync_group = get_tp_group().device_group
if is_dp_attention_enabled():
......@@ -103,7 +103,7 @@ class Sampler(nn.Module):
del logits
if True: # Keep this redundant check to simplify some internal code sync
if global_server_args_dict["sampling_backend"] == "flashinfer":
if get_global_server_args().sampling_backend == "flashinfer":
if sampling_info.need_min_p_sampling:
probs = top_k_renorm_prob(probs, sampling_info.top_ks)
probs = top_p_renorm_prob(probs, sampling_info.top_ps)
......@@ -118,7 +118,7 @@ class Sampler(nn.Module):
filter_apply_order="joint",
check_nan=self.use_nan_detection,
)
elif global_server_args_dict["sampling_backend"] == "pytorch":
elif get_global_server_args().sampling_backend == "pytorch":
# A slower fallback implementation with torch native operations.
batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
probs,
......@@ -131,7 +131,7 @@ class Sampler(nn.Module):
)
else:
raise ValueError(
f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
f"Invalid sampling backend: {get_global_server_args().sampling_backend}"
)
if return_logprob:
......
......@@ -16,10 +16,10 @@ from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
global_server_args_dict,
)
from sglang.srt.mem_cache.multimodal_cache import MultiModalCache
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import flatten_nested_list, is_npu, print_warning_once
from sglang.utils import logger
......@@ -428,7 +428,7 @@ def _adjust_embedding_length(
f"tokens from multimodal embeddings."
)
if num_mm_tokens_in_input_ids < num_mm_tokens_in_embedding:
chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
chunked_prefill_size = get_global_server_args().chunked_prefill_size
if chunked_prefill_size != -1:
logger.warning(
"You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked prefill"
......
......@@ -73,7 +73,7 @@ from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.server_args import ServerArgs, get_global_server_args
from sglang.srt.utils import flatten_nested_list
from sglang.srt.utils.common import next_power_of_2
......@@ -83,47 +83,6 @@ if TYPE_CHECKING:
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
GLOBAL_SERVER_ARGS_KEYS = [
"attention_backend",
"mm_attention_backend",
"debug_tensor_dump_inject",
"debug_tensor_dump_output_folder",
"chunked_prefill_size",
"device",
"disable_chunked_prefix_cache",
"disable_flashinfer_cutlass_moe_fp4_allgather",
"disable_radix_cache",
"enable_dp_lm_head",
"enable_fp32_lm_head",
"flashinfer_mxfp4_moe_precision",
"enable_flashinfer_allreduce_fusion",
"moe_dense_tp_size",
"ep_dispatch_algorithm",
"ep_num_redundant_experts",
"enable_nan_detection",
"flashinfer_mla_disable_ragged",
"pp_max_micro_batch_size",
"disable_shared_experts_fusion",
"sampling_backend",
"speculative_accept_threshold_single",
"speculative_accept_threshold_acc",
"speculative_attention_mode",
"torchao_config",
"triton_attention_reduce_in_fp32",
"num_reserved_decode_tokens",
"weight_loader_disable_mmap",
"enable_multimodal",
"enable_symm_mem",
"enable_custom_logit_processor",
"disaggregation_mode",
"enable_deterministic_inference",
"nsa_prefill",
"nsa_decode",
"multi_item_scoring_delimiter",
]
# Put some global args for easy access
global_server_args_dict = {k: getattr(ServerArgs, k) for k in GLOBAL_SERVER_ARGS_KEYS}
logger = logging.getLogger(__name__)
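
With the `GLOBAL_SERVER_ARGS_KEYS` list and the derived dict removed above, every consumer now depends on a process-wide accessor instead. Its real implementation lives in `sglang/srt/server_args.py` and is not shown in this diff; the sketch below is only the hypothetical shape of the singleton pattern the call sites assume (`set_global_server_args` and the placeholder `ServerArgs` body are illustrative names, not the actual API):

```python
from typing import Optional

class ServerArgs:  # stand-in for sglang.srt.server_args.ServerArgs
    enable_symm_mem: bool = False
    # ... many more CLI-derived fields

_global_server_args: Optional[ServerArgs] = None

def set_global_server_args(args: ServerArgs) -> None:
    # Hypothetical registration hook; the real entry point is not part of this diff.
    global _global_server_args
    _global_server_args = args

def get_global_server_args() -> ServerArgs:
    assert _global_server_args is not None, "server args have not been initialized"
    return _global_server_args
```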
......@@ -685,12 +644,9 @@ class Req:
def is_prefill_only(self) -> bool:
"""Check if this request is prefill-only (no token generation needed)."""
# NOTE: when spec is enabled, prefill_only optimizations are disabled
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
spec_alg = global_server_args_dict["speculative_algorithm"]
return self.sampling_params.max_new_tokens == 0 and (
spec_alg is None or spec_alg == SpeculativeAlgorithm.NONE
)
spec_alg = get_global_server_args().speculative_algorithm
return self.sampling_params.max_new_tokens == 0 and spec_alg is None
def add_latency(self, stage: RequestStage):
if self.metrics_collector is None:
......
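
The `is_prefill_only` hunk also changes what the speculative-decoding check compares against: the old dict entry could hold a parsed `SpeculativeAlgorithm` (hence the `NONE` comparison), whereas `ServerArgs.speculative_algorithm` is taken here to be the raw optional string from the CLI, so `None` alone means speculation is off. Call sites that still need the enum convert explicitly, as the `LayerCommunicator` hunk above does. A small sketch under that assumption:

```python
from sglang.srt.server_args import get_global_server_args
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

raw = get_global_server_args().speculative_algorithm   # e.g. None, or a name such as "EAGLE"

# Prefill-only fast path: an unset field means speculative decoding is disabled.
spec_disabled = raw is None

# Where the enum is required, parse the string explicitly (mirrors the
# SpeculativeAlgorithm.from_string(...) call in the LayerCommunicator hunk).
spec_algo = SpeculativeAlgorithm.from_string(raw)
```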
......@@ -122,7 +122,6 @@ from sglang.srt.managers.schedule_batch import (
Req,
RequestStage,
ScheduleBatch,
global_server_args_dict,
)
from sglang.srt.managers.schedule_policy import (
AddReqResult,
......@@ -151,7 +150,7 @@ from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.parser.reasoning_parser import ReasoningParser
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args
from sglang.srt.speculative.eagle_info import EagleDraftInput
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.tracing.trace import (
......@@ -448,13 +447,12 @@ class Scheduler(
self.max_req_input_len,
self.random_seed,
self.device,
worker_global_server_args_dict,
_,
_,
_,
) = self.tp_worker.get_worker_info()
if global_server_args_dict["pp_max_micro_batch_size"] is None:
global_server_args_dict["pp_max_micro_batch_size"] = max(
if get_global_server_args().pp_max_micro_batch_size is None:
get_global_server_args().pp_max_micro_batch_size = max(
self.max_running_requests // server_args.pp_size, 1
)
......@@ -466,7 +464,6 @@ class Scheduler(
self.world_group = get_world_group()
self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
global_server_args_dict.update(worker_global_server_args_dict)
set_random_seed(self.random_seed)
# Hybrid memory pool
......@@ -1942,7 +1939,7 @@ class Scheduler(
return ret
def get_num_allocatable_reqs(self, running_bs):
res = global_server_args_dict["pp_max_micro_batch_size"] - running_bs
res = get_global_server_args().pp_max_micro_batch_size - running_bs
if self.pp_size > 1:
res = min(res, self.req_to_token_pool.available_size())
return res
......@@ -2686,7 +2683,7 @@ class Scheduler(
)
def get_internal_state(self, recv_req: GetInternalStateReq):
ret = dict(global_server_args_dict)
ret = vars(get_global_server_args())
ret["last_gen_throughput"] = self.last_gen_throughput
ret["memory_usage"] = {
"weight": round(
......@@ -2742,11 +2739,11 @@ class Scheduler(
logger.info(f"{avg_spec_accept_length=}")
self.cum_spec_accept_length = self.cum_spec_accept_count = 0
for k, v in server_args_dict.items():
global_server_args_dict[k] = v
logger.info(f"Global server args updated! {global_server_args_dict=}")
setattr(get_global_server_args(), k, v)
logger.info(f"Global server args updated! {get_global_server_args()=}")
return SetInternalStateReqOutput(
updated=True,
server_args=global_server_args_dict,
server_args=vars(get_global_server_args()),
)
def handle_rpc_request(self, recv_req: RpcReqInput):
......
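
In the scheduler, the dict-shaped views that `get_internal_state` and `set_internal_state` used to build from `global_server_args_dict` are now derived from the `ServerArgs` object itself: `vars()` exposes its fields as a dict (this assumes `ServerArgs` is a plain dataclass-style object whose fields live in `__dict__`), and runtime overrides go through `setattr`. A sketch of the new flow:

```python
from sglang.srt.server_args import get_global_server_args

args = get_global_server_args()

# get_internal_state: dict snapshot of all server args (replaces dict(global_server_args_dict)).
snapshot = dict(vars(args))

# set_internal_state: apply runtime overrides (replaces global_server_args_dict[k] = v).
overrides = {"pp_max_micro_batch_size": 8}   # illustrative key/value only
for k, v in overrides.items():
    setattr(args, k, v)
```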
......@@ -33,7 +33,7 @@ from sglang.srt.managers.io_struct import (
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.scheduler import GenerationBatchResult
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
......@@ -190,7 +190,6 @@ class TpModelWorker:
self.max_req_input_len,
self.random_seed,
self.device,
global_server_args_dict,
self.model_runner.req_to_token_pool.size,
self.model_runner.req_to_token_pool.max_context_len,
self.model_runner.token_to_kv_pool.size,
......