Unverified Commit f4941906 authored by Liangsheng Yin, committed by GitHub

Move args from `global_config` to `environ` (#11332)

parent 01e59e82
"""Global configurations"""
import os
# FIXME: deprecate this file and move all usage to sglang.srt.environ or sglang.__init__.py
class GlobalConfig:
@@ -20,27 +20,6 @@ class GlobalConfig:
         # Default backend of the language
         self.default_backend = None

-        # Runtime constants: New generation token ratio estimation
-        self.default_init_new_token_ratio = float(
-            os.environ.get("SGLANG_INIT_NEW_TOKEN_RATIO", 0.7)
-        )
-        self.default_min_new_token_ratio_factor = float(
-            os.environ.get("SGLANG_MIN_NEW_TOKEN_RATIO_FACTOR", 0.14)
-        )
-        self.default_new_token_ratio_decay_steps = float(
-            os.environ.get("SGLANG_NEW_TOKEN_RATIO_DECAY_STEPS", 600)
-        )
-        self.torch_empty_cache_interval = float(
-            os.environ.get(
-                "SGLANG_EMPTY_CACHE_INTERVAL", -1
-            )  # in seconds. Set if you observe high memory accumulation over a long serving period.
-        )
-
-        # Runtime constants: others
-        self.retract_decode_steps = 20
-        self.flashinfer_workspace_size = int(
-            os.environ.get("FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024)
-        )
-
         # Output tokenization configs
         self.skip_special_tokens_in_output = True
         self.spaces_between_special_tokens_in_out = True
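The call-site migration is mechanical: every `global_config.<attr>` read becomes a typed accessor on `envs`. Two behavioral details are visible in the hunks: `retract_decode_steps`, previously hard-coded to 20, becomes overridable through `SGLANG_RETRACT_DECODE_STEPS`, and the workspace-size variable gains the `SGLANG_` prefix (`FLASHINFER_WORKSPACE_SIZE` becomes `SGLANG_FLASHINFER_WORKSPACE_SIZE`). A minimal before/after sketch, using only names that appear in this diff:

```python
# Before (deprecated): plain attributes on the global_config singleton,
# resolved once when GlobalConfig.__init__ runs.
from sglang.global_config import global_config

steps = global_config.retract_decode_steps

# After: typed env accessors, parsed on demand and overridable in-process.
from sglang.srt.environ import envs

steps = envs.SGLANG_RETRACT_DECODE_STEPS.get()
```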
@@ -128,6 +128,14 @@ class Envs:
     SGLANG_SIMULATE_ACC_METHOD = EnvStr("multinomial")
     SGLANG_TORCH_PROFILER_DIR = EnvStr("/tmp")

+    # Scheduler: new token ratio hyperparameters
+    SGLANG_INIT_NEW_TOKEN_RATIO = EnvFloat(0.7)
+    SGLANG_MIN_NEW_TOKEN_RATIO_FACTOR = EnvFloat(0.14)
+    SGLANG_NEW_TOKEN_RATIO_DECAY_STEPS = EnvInt(600)
+    SGLANG_RETRACT_DECODE_STEPS = EnvInt(20)
+
+    # Scheduler: others
+    SGLANG_EMPTY_CACHE_INTERVAL = EnvFloat(-1)  # in seconds; set if you observe high memory accumulation over a long serving period
+
     # Test: pd-disaggregation
     SGLANG_TEST_PD_DISAGG_BACKEND = EnvStr("mooncake")
     SGLANG_TEST_PD_DISAGG_DEVICES = EnvStr(None)
@@ -159,6 +167,7 @@ class Envs:
     # Flashinfer
     SGLANG_IS_FLASHINFER_AVAILABLE = EnvBool(True)
     SGLANG_ENABLE_FLASHINFER_GEMM = EnvBool(False)
+    SGLANG_FLASHINFER_WORKSPACE_SIZE = EnvInt(384 * 1024 * 1024)

     # Triton
     SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS = EnvBool(False)
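`EnvStr`/`EnvInt`/`EnvFloat`/`EnvBool` are defined elsewhere in `sglang.srt.environ` and are not shown in this diff. As a rough mental model only, here is a hypothetical, self-contained reimplementation matching the `.get()`/`.set()` behavior the call sites below rely on (the real wrappers may differ):

```python
import os


class EnvField:
    """Hypothetical stand-in for sglang's env wrappers (not the real code)."""

    def __init__(self, default, parse):
        self.default = default
        self.parse = parse
        self.name = None

    def __set_name__(self, owner, name):
        # The class attribute name doubles as the environment variable name.
        self.name = name

    def get(self):
        raw = os.environ.get(self.name)
        return self.parse(raw) if raw is not None else self.default

    def set(self, value):
        # Process-local override, visible to every subsequent .get().
        os.environ[self.name] = str(value)


def EnvInt(default):
    return EnvField(default, int)


def EnvFloat(default):
    return EnvField(default, float)


class Envs:
    SGLANG_RETRACT_DECODE_STEPS = EnvInt(20)
    SGLANG_EMPTY_CACHE_INTERVAL = EnvFloat(-1)


envs = Envs()

assert envs.SGLANG_RETRACT_DECODE_STEPS.get() == 20  # default applies
envs.SGLANG_RETRACT_DECODE_STEPS.set(40)
assert envs.SGLANG_RETRACT_DECODE_STEPS.get() == 40  # override wins
```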
@@ -16,13 +16,7 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Union

 import torch

-if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
-    torch._logging.set_logs(dynamo=logging.ERROR)
-    torch._dynamo.config.suppress_errors = True
-
-logger = logging.getLogger(__name__)
-
-from sglang.global_config import global_config
+from sglang.srt.environ import envs
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
@@ -41,6 +35,12 @@ if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner

+logger = logging.getLogger(__name__)
+
+if envs.SGLANG_ENABLE_TORCH_COMPILE.get():
+    torch._logging.set_logs(dynamo=logging.ERROR)
+    torch._dynamo.config.suppress_errors = True
+
 if is_flashinfer_available():
     from flashinfer import (
@@ -160,7 +160,7 @@ class FlashInferAttnBackend(AttentionBackend):
             or "Qwen3ForCausalLM" in model_runner.model_config.hf_config.architectures
             or "MiMoForCausalLM" in model_runner.model_config.hf_config.architectures
         ):
-            global_config.flashinfer_workspace_size = 512 * 1024 * 1024
+            envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.set(512 * 1024 * 1024)

         # When deterministic inference is enabled, tensor cores should be used for decode
         # Also set split tile sizes for prefill and decode from environment variables, and disable kv split for cuda graph
@@ -180,13 +180,13 @@ class FlashInferAttnBackend(AttentionBackend):
                 "SGLANG_FLASHINFER_DECODE_SPLIT_TILE_SIZE", 2048
             )
             self.disable_cuda_graph_kv_split = True
-            global_config.flashinfer_workspace_size = 2048 * 1024 * 1024
+            envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.set(2048 * 1024 * 1024)

         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
             # different from flashinfer zero_init_global_workspace_buffer
-            global_workspace_size = global_config.flashinfer_workspace_size
+            global_workspace_size = envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get()
             global_workspace_buffer = torch.empty(
                 global_workspace_size,
                 dtype=torch.uint8,
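An ordering subtlety worth noting: `global_workspace_buffer` is allocated lazily and only once, from whatever value the accessor holds at first use, so the `.set()` overrides above must run before the first backend is constructed. A simplified sketch of that pattern (hypothetical helper name):

```python
import torch

_global_workspace_buffer = None  # shared across backend instances


def get_global_workspace(nbytes: int, device: str) -> torch.Tensor:
    """Allocate the shared uint8 workspace on first use; later calls
    return the original buffer even if nbytes has changed since."""
    global _global_workspace_buffer
    if _global_workspace_buffer is None:
        _global_workspace_buffer = torch.empty(
            nbytes, dtype=torch.uint8, device=device
        )
    return _global_workspace_buffer
```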
@@ -22,7 +22,7 @@ if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
     torch._logging.set_logs(dynamo=logging.ERROR)
     torch._dynamo.config.suppress_errors = True

-from sglang.global_config import global_config
+from sglang.srt.environ import envs
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.flashinfer_backend import (
     create_flashinfer_kv_indices_triton,
@@ -204,7 +204,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         if global_workspace_buffer is None:
             # different from flashinfer zero_init_global_workspace_buffer
             global_workspace_buffer = torch.empty(
-                global_config.flashinfer_workspace_size,
+                envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get(),
                 dtype=torch.uint8,
                 device=model_runner.device,
             )
@@ -37,7 +37,6 @@ import copy
 import dataclasses
 import logging
 import re
-import threading
 import time
 from enum import Enum, auto
 from http import HTTPStatus
@@ -47,7 +46,6 @@ from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
 import numpy as np
 import torch

-from sglang.global_config import global_config
 from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
 from sglang.srt.disaggregation.base import BaseKVSender
 from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
@@ -55,6 +53,7 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
 )
 from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
+from sglang.srt.environ import envs
 from sglang.srt.mem_cache.allocator import (
     BaseTokenToKVPoolAllocator,
     SWATokenToKVPoolAllocator,
@@ -1481,7 +1480,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

         new_estimate_ratio = (
-            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
+            total_decoded_tokens
+            + envs.SGLANG_RETRACT_DECODE_STEPS.get() * len(self.reqs)
         ) / total_max_new_tokens
         new_estimate_ratio = min(1.0, new_estimate_ratio)
@@ -1520,7 +1520,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 self.tree_cache.dec_lock_ref(req.last_node)

             # NOTE(lsyin): we should use the newly evictable memory instantly.
-            num_tokens = remaing_req_count * global_config.retract_decode_steps
+            num_tokens = remaing_req_count * envs.SGLANG_RETRACT_DECODE_STEPS.get()
             self._evict_tree_cache_if_needed(num_tokens)

             req.reset_for_retract()
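To make the retraction estimate concrete, here is the same arithmetic with hypothetical batch numbers (20 is the `SGLANG_RETRACT_DECODE_STEPS` default):

```python
retract_decode_steps = 20      # envs.SGLANG_RETRACT_DECODE_STEPS default
total_decoded_tokens = 1200    # hypothetical: tokens decoded before retraction
num_reqs = 8                   # hypothetical batch size
total_max_new_tokens = 4096    # hypothetical sum of per-request max_new_tokens

new_estimate_ratio = (
    total_decoded_tokens + retract_decode_steps * num_reqs
) / total_max_new_tokens
new_estimate_ratio = min(1.0, new_estimate_ratio)
print(new_estimate_ratio)      # (1200 + 20 * 8) / 4096 ≈ 0.332
```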
@@ -35,7 +35,6 @@ from torch.cuda import Stream as CudaStream
 from torch.cuda import StreamContext as CudaStreamContext
 from torch.distributed import barrier

-from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained.base_grammar_backend import (
     INVALID_GRAMMAR_OBJ,
@@ -61,6 +60,7 @@ from sglang.srt.disaggregation.utils import (
     prepare_abort,
 )
 from sglang.srt.distributed import get_pp_group, get_world_group
+from sglang.srt.environ import envs
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -556,18 +556,17 @@ class Scheduler(
             server_args.schedule_conservativeness >= 0
         ), "Invalid schedule_conservativeness"
         self.init_new_token_ratio = min(
-            global_config.default_init_new_token_ratio
+            envs.SGLANG_INIT_NEW_TOKEN_RATIO.get()
             * server_args.schedule_conservativeness,
             1.0,
         )
         self.min_new_token_ratio = min(
-            self.init_new_token_ratio
-            * global_config.default_min_new_token_ratio_factor,
+            self.init_new_token_ratio * envs.SGLANG_MIN_NEW_TOKEN_RATIO_FACTOR.get(),
             1.0,
         )
         self.new_token_ratio_decay = (
             self.init_new_token_ratio - self.min_new_token_ratio
-        ) / global_config.default_new_token_ratio_decay_steps
+        ) / envs.SGLANG_NEW_TOKEN_RATIO_DECAY_STEPS.get()
         self.new_token_ratio = self.init_new_token_ratio

         # Init watchdog thread
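Plugging the environ defaults (0.7, 0.14, 600) into the code above, with a hypothetical `schedule_conservativeness` of 1.0, gives the scheduler's starting state:

```python
schedule_conservativeness = 1.0  # hypothetical server_args value

init_new_token_ratio = min(0.7 * schedule_conservativeness, 1.0)  # 0.7
min_new_token_ratio = min(init_new_token_ratio * 0.14, 1.0)       # 0.098
new_token_ratio_decay = (
    init_new_token_ratio - min_new_token_ratio
) / 600                                                           # ~0.001 per step

print(init_new_token_ratio, min_new_token_ratio, new_token_ratio_decay)
```

So with the defaults, the estimated new-token ratio starts at 0.7 and decays toward 0.098 over roughly 600 scheduler steps.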
@@ -2897,12 +2896,13 @@ class IdleSleeper:
         for s in sockets:
             self.poller.register(s, zmq.POLLIN)

+        self.empty_cache_interval = envs.SGLANG_EMPTY_CACHE_INTERVAL.get()
+
     def maybe_sleep(self):
         self.poller.poll(1000)
         if (
-            global_config.torch_empty_cache_interval > 0
-            and time.time() - self.last_empty_time
-            > global_config.torch_empty_cache_interval
+            self.empty_cache_interval > 0
+            and time.time() - self.last_empty_time > self.empty_cache_interval
         ):
             self.last_empty_time = time.time()
             torch.cuda.empty_cache()
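Caching the interval in `__init__` keeps `maybe_sleep` cheap, at the cost that a later `SGLANG_EMPTY_CACHE_INTERVAL` override no longer affects an already-constructed sleeper. A standalone sketch of the same logic (hypothetical class name):

```python
import time

import torch


class IdleCacheFlusher:
    """Simplified model of IdleSleeper's cache-emptying logic.
    interval_s <= 0 (the SGLANG_EMPTY_CACHE_INTERVAL default of -1) disables it."""

    def __init__(self, interval_s: float):
        self.interval_s = interval_s
        self.last_empty_time = time.time()

    def maybe_flush(self) -> None:
        if (
            self.interval_s > 0
            and time.time() - self.last_empty_time > self.interval_s
        ):
            self.last_empty_time = time.time()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
```

For example, exporting `SGLANG_EMPTY_CACHE_INTERVAL=300` would flush the CUDA cache at most once every five minutes of idle polling.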