Unverified commit f4941906 authored by Liangsheng Yin, committed by GitHub

Move args from `global_config` to `environ` (#11332)

parent 01e59e82
"""Global configurations""" """Global configurations"""
import os # FIXME: deprecate this file and move all usage to sglang.srt.environ or sglang.__init__.py
class GlobalConfig: class GlobalConfig:
...@@ -20,27 +20,6 @@ class GlobalConfig: ...@@ -20,27 +20,6 @@ class GlobalConfig:
# Default backend of the language # Default backend of the language
self.default_backend = None self.default_backend = None
# Runtime constants: New generation token ratio estimation
self.default_init_new_token_ratio = float(
os.environ.get("SGLANG_INIT_NEW_TOKEN_RATIO", 0.7)
)
self.default_min_new_token_ratio_factor = float(
os.environ.get("SGLANG_MIN_NEW_TOKEN_RATIO_FACTOR", 0.14)
)
self.default_new_token_ratio_decay_steps = float(
os.environ.get("SGLANG_NEW_TOKEN_RATIO_DECAY_STEPS", 600)
)
self.torch_empty_cache_interval = float(
os.environ.get(
"SGLANG_EMPTY_CACHE_INTERVAL", -1
) # in seconds. Set if you observe high memory accumulation over a long serving period.
)
# Runtime constants: others
self.retract_decode_steps = 20
self.flashinfer_workspace_size = int(
os.environ.get("FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024)
)
# Output tokenization configs # Output tokenization configs
self.skip_special_tokens_in_output = True self.skip_special_tokens_in_output = True
self.spaces_between_special_tokens_in_out = True self.spaces_between_special_tokens_in_out = True
......
@@ -128,6 +128,14 @@ class Envs:
     SGLANG_SIMULATE_ACC_METHOD = EnvStr("multinomial")
     SGLANG_TORCH_PROFILER_DIR = EnvStr("/tmp")

+    # Scheduler: new token ratio hyperparameters
+    SGLANG_INIT_NEW_TOKEN_RATIO = EnvFloat(0.7)
+    SGLANG_MIN_NEW_TOKEN_RATIO_FACTOR = EnvFloat(0.14)
+    SGLANG_NEW_TOKEN_RATIO_DECAY_STEPS = EnvInt(600)
+    SGLANG_RETRACT_DECODE_STEPS = EnvInt(20)
+    # Scheduler: others
+    SGLANG_EMPTY_CACHE_INTERVAL = EnvFloat(-1)  # in seconds. Set if you observe high memory accumulation over a long serving period.

     # Test: pd-disaggregation
     SGLANG_TEST_PD_DISAGG_BACKEND = EnvStr("mooncake")
     SGLANG_TEST_PD_DISAGG_DEVICES = EnvStr(None)
@@ -159,6 +167,7 @@ class Envs:
     # Flashinfer
     SGLANG_IS_FLASHINFER_AVAILABLE = EnvBool(True)
     SGLANG_ENABLE_FLASHINFER_GEMM = EnvBool(False)
+    SGLANG_FLASHINFER_WORKSPACE_SIZE = EnvInt(384 * 1024 * 1024)
     # Triton
     SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS = EnvBool(False)
...
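For context, the scheduler and attention backends changed below read these values through typed accessors (`envs.X.get()` and `envs.X.set(...)`). The following is a minimal sketch of that accessor pattern; it is an assumption for illustration, not the actual `sglang.srt.environ` implementation.

# Minimal sketch of the typed env-var accessor pattern assumed by this diff.
# NOT the real sglang.srt.environ implementation, only an illustration.
import os


class EnvFloat:
    def __init__(self, default: float):
        self.default = default
        self.name = None  # filled in when the owning class body is created
        self._override = None

    def __set_name__(self, owner, name):
        # e.g. "SGLANG_INIT_NEW_TOKEN_RATIO"
        self.name = name

    def get(self) -> float:
        if self._override is not None:
            return self._override
        return float(os.environ.get(self.name, self.default))

    def set(self, value: float) -> None:
        # process-local override, as used for the FlashInfer workspace size below
        self._override = float(value)


class Envs:
    SGLANG_INIT_NEW_TOKEN_RATIO = EnvFloat(0.7)
    SGLANG_EMPTY_CACHE_INTERVAL = EnvFloat(-1)


envs = Envs()
print(envs.SGLANG_INIT_NEW_TOKEN_RATIO.get())  # 0.7 unless the env var overrides it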
@@ -16,13 +16,7 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Union
 import torch

-if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
-    torch._logging.set_logs(dynamo=logging.ERROR)
-    torch._dynamo.config.suppress_errors = True
-
-logger = logging.getLogger(__name__)
-
-from sglang.global_config import global_config
+from sglang.srt.environ import envs
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
@@ -41,6 +35,12 @@ if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner

+logger = logging.getLogger(__name__)
+
+if envs.SGLANG_ENABLE_TORCH_COMPILE.get():
+    torch._logging.set_logs(dynamo=logging.ERROR)
+    torch._dynamo.config.suppress_errors = True
+
 if is_flashinfer_available():
     from flashinfer import (
@@ -160,7 +160,7 @@ class FlashInferAttnBackend(AttentionBackend):
            or "Qwen3ForCausalLM" in model_runner.model_config.hf_config.architectures
            or "MiMoForCausalLM" in model_runner.model_config.hf_config.architectures
        ):
-            global_config.flashinfer_workspace_size = 512 * 1024 * 1024
+            envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.set(512 * 1024 * 1024)

        # When deterministic inference is enabled, tensor cores should be used for decode
        # Also set split tile sizes for prefill and decode from environment variables, and disable kv split for cuda graph
@@ -180,13 +180,13 @@ class FlashInferAttnBackend(AttentionBackend):
                "SGLANG_FLASHINFER_DECODE_SPLIT_TILE_SIZE", 2048
            )
            self.disable_cuda_graph_kv_split = True
-            envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.set(2048 * 1024 * 1024)
+            envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.set(2048 * 1024 * 1024)

        # Allocate buffers
        global global_workspace_buffer
        if global_workspace_buffer is None:
            # different from flashinfer zero_init_global_workspace_buffer
-            global_workspace_size = global_config.flashinfer_workspace_size
+            global_workspace_size = envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get()
            global_workspace_buffer = torch.empty(
                global_workspace_size,
                dtype=torch.uint8,
...
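With the workspace size now an `Envs` field, the 384 MiB default can be raised without patching `global_config`; the backend still bumps it internally for Qwen2/Qwen3/MiMo and for deterministic inference via `.set()`. A hedged usage sketch follows; the variable name comes from this diff, while the launch command is only an example.

# Example: raise the FlashInfer workspace to 512 MiB for one run by exporting
# the variable before the server process imports sglang.
import os

os.environ["SGLANG_FLASHINFER_WORKSPACE_SIZE"] = str(512 * 1024 * 1024)
# equivalently, from a shell (example invocation):
#   SGLANG_FLASHINFER_WORKSPACE_SIZE=536870912 python -m sglang.launch_server ...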
@@ -22,7 +22,7 @@ if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
     torch._logging.set_logs(dynamo=logging.ERROR)
     torch._dynamo.config.suppress_errors = True

-from sglang.global_config import global_config
+from sglang.srt.environ import envs
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.flashinfer_backend import (
     create_flashinfer_kv_indices_triton,
@@ -204,7 +204,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
        if global_workspace_buffer is None:
            # different from flashinfer zero_init_global_workspace_buffer
            global_workspace_buffer = torch.empty(
-                global_config.flashinfer_workspace_size,
+                envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get(),
                dtype=torch.uint8,
                device=model_runner.device,
            )
...
@@ -37,7 +37,6 @@ import copy
 import dataclasses
 import logging
 import re
-import threading
 import time
 from enum import Enum, auto
 from http import HTTPStatus
@@ -47,7 +46,6 @@ from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
 import numpy as np
 import torch

-from sglang.global_config import global_config
 from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
 from sglang.srt.disaggregation.base import BaseKVSender
 from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
@@ -55,6 +53,7 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
 )
 from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
+from sglang.srt.environ import envs
 from sglang.srt.mem_cache.allocator import (
     BaseTokenToKVPoolAllocator,
     SWATokenToKVPoolAllocator,
@@ -1481,7 +1480,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

        new_estimate_ratio = (
-            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
+            total_decoded_tokens
+            + envs.SGLANG_RETRACT_DECODE_STEPS.get() * len(self.reqs)
        ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)
@@ -1520,7 +1520,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                self.tree_cache.dec_lock_ref(req.last_node)

            # NOTE(lsyin): we should use the newly evictable memory instantly.
-            num_tokens = remaing_req_count * global_config.retract_decode_steps
+            num_tokens = remaing_req_count * envs.SGLANG_RETRACT_DECODE_STEPS.get()
            self._evict_tree_cache_if_needed(num_tokens)

            req.reset_for_retract()
...
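To make the retraction arithmetic concrete, here is a small worked example; the request count and token budgets are assumed values, and only the default of 20 comes from `SGLANG_RETRACT_DECODE_STEPS`.

# Illustrative only: 8 running requests, 1000 tokens decoded so far in total,
# and a combined max_new_tokens budget of 4096.
retract_decode_steps = 20  # SGLANG_RETRACT_DECODE_STEPS default
total_decoded_tokens, num_reqs, total_max_new_tokens = 1000, 8, 4096

new_estimate_ratio = (total_decoded_tokens + retract_decode_steps * num_reqs) / total_max_new_tokens
new_estimate_ratio = min(1.0, new_estimate_ratio)
print(new_estimate_ratio)  # (1000 + 160) / 4096 ≈ 0.283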
@@ -35,7 +35,6 @@ from torch.cuda import Stream as CudaStream
 from torch.cuda import StreamContext as CudaStreamContext
 from torch.distributed import barrier

-from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained.base_grammar_backend import (
     INVALID_GRAMMAR_OBJ,
@@ -61,6 +60,7 @@ from sglang.srt.disaggregation.utils import (
     prepare_abort,
 )
 from sglang.srt.distributed import get_pp_group, get_world_group
+from sglang.srt.environ import envs
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -556,18 +556,17 @@ class Scheduler(
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"
        self.init_new_token_ratio = min(
-            global_config.default_init_new_token_ratio
+            envs.SGLANG_INIT_NEW_TOKEN_RATIO.get()
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
-            self.init_new_token_ratio
-            * global_config.default_min_new_token_ratio_factor,
+            self.init_new_token_ratio * envs.SGLANG_MIN_NEW_TOKEN_RATIO_FACTOR.get(),
            1.0,
        )
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
-        ) / global_config.default_new_token_ratio_decay_steps
+        ) / envs.SGLANG_NEW_TOKEN_RATIO_DECAY_STEPS.get()
        self.new_token_ratio = self.init_new_token_ratio

        # Init watchdog thread
@@ -2897,12 +2896,13 @@ class IdleSleeper:
        for s in sockets:
            self.poller.register(s, zmq.POLLIN)
+        self.empty_cache_interval = envs.SGLANG_EMPTY_CACHE_INTERVAL.get()

    def maybe_sleep(self):
        self.poller.poll(1000)
        if (
-            global_config.torch_empty_cache_interval > 0
-            and time.time() - self.last_empty_time
-            > global_config.torch_empty_cache_interval
+            self.empty_cache_interval > 0
+            and time.time() - self.last_empty_time > self.empty_cache_interval
        ):
            self.last_empty_time = time.time()
            torch.cuda.empty_cache()
...
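For reference, with the default env values above and an assumed `schedule_conservativeness` of 1.0, the scheduler's new-token-ratio schedule works out as follows.

# Worked example with the defaults; schedule_conservativeness is assumed to be 1.0.
init_new_token_ratio = min(0.7 * 1.0, 1.0)                                   # SGLANG_INIT_NEW_TOKEN_RATIO
min_new_token_ratio = min(init_new_token_ratio * 0.14, 1.0)                  # SGLANG_MIN_NEW_TOKEN_RATIO_FACTOR
new_token_ratio_decay = (init_new_token_ratio - min_new_token_ratio) / 600   # SGLANG_NEW_TOKEN_RATIO_DECAY_STEPS
print(init_new_token_ratio, min_new_token_ratio, new_token_ratio_decay)
# ≈ 0.7, 0.098, 0.0010: the ratio starts at 0.7 and decays toward 0.098 over 600 steps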