Unverified Commit 1a8f5f68 authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Super tiny rename environment variable (#6648)

parent 32cd7070
...@@ -31,7 +31,12 @@ from sglang.srt.disaggregation.base.conn import ( ...@@ -31,7 +31,12 @@ from sglang.srt.disaggregation.base.conn import (
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_free_port, get_ip, get_local_ip_by_remote from sglang.srt.utils import (
get_free_port,
get_int_env_var,
get_ip,
get_local_ip_by_remote,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -172,13 +177,11 @@ class MooncakeKVManager(BaseKVManager): ...@@ -172,13 +177,11 @@ class MooncakeKVManager(BaseKVManager):
# Determine the number of threads to use for kv sender # Determine the number of threads to use for kv sender
cpu_count = os.cpu_count() cpu_count = os.cpu_count()
self.executor = concurrent.futures.ThreadPoolExecutor( self.executor = concurrent.futures.ThreadPoolExecutor(
int( get_int_env_var(
os.getenv( "SGLANG_DISAGGREGATION_THREAD_POOL_SIZE",
"DISAGGREGATION_THREAD_POOL_SIZE",
min(max(1, cpu_count // 8), 8), min(max(1, cpu_count // 8), 8),
) )
) )
)
elif self.disaggregation_mode == DisaggregationMode.DECODE: elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.heartbeat_failures = {} self.heartbeat_failures = {}
self.session_pool = defaultdict(requests.Session) self.session_pool = defaultdict(requests.Session)
...@@ -187,11 +190,11 @@ class MooncakeKVManager(BaseKVManager): ...@@ -187,11 +190,11 @@ class MooncakeKVManager(BaseKVManager):
self.connection_lock = threading.Lock() self.connection_lock = threading.Lock()
# Heartbeat interval should be at least 2 seconds # Heartbeat interval should be at least 2 seconds
self.heartbeat_interval = max( self.heartbeat_interval = max(
float(os.getenv("DISAGGREGATION_HEARTBEAT_INTERVAL", 5.0)), 2.0 float(os.getenv("SGLANG_DISAGGREGATION_HEARTBEAT_INTERVAL", 5.0)), 2.0
) )
# Heartbeat failure should be at least 1 # Heartbeat failure should be at least 1
self.max_failures = max( self.max_failures = max(
int(os.getenv("DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2)), 1 int(os.getenv("SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2)), 1
) )
self.start_decode_thread() self.start_decode_thread()
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {} self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment