Unverified Commit c05956e5 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Simplify port allocation (#447)

parent d75dc20f
...@@ -7,7 +7,7 @@ import zmq.asyncio ...@@ -7,7 +7,7 @@ import zmq.asyncio
from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_exception_traceback from sglang.utils import get_exception_traceback
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
......
...@@ -8,7 +8,7 @@ import zmq.asyncio ...@@ -8,7 +8,7 @@ import zmq.asyncio
from sglang.global_config import global_config from sglang.global_config import global_config
from sglang.srt.managers.router.model_rpc import ModelRpcClient from sglang.srt.managers.router.model_rpc import ModelRpcClient
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_exception_traceback from sglang.utils import get_exception_traceback
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
......
...@@ -31,11 +31,12 @@ from sglang.srt.managers.router.scheduler import Scheduler ...@@ -31,11 +31,12 @@ from sglang.srt.managers.router.scheduler import Scheduler
from sglang.srt.model_config import ModelConfig from sglang.srt.model_config import ModelConfig
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import ( from sglang.srt.utils import (
get_exception_traceback,
get_int_token_logit_bias, get_int_token_logit_bias,
is_multimodal_model, is_multimodal_model,
set_random_seed, set_random_seed,
) )
from sglang.utils import get_exception_traceback
logger = logging.getLogger("model_rpc") logger = logging.getLogger("model_rpc")
vllm_default_logger.setLevel(logging.WARN) vllm_default_logger.setLevel(logging.WARN)
......
...@@ -20,7 +20,6 @@ from sglang.srt.hf_transformers_utils import ( ...@@ -20,7 +20,6 @@ from sglang.srt.hf_transformers_utils import (
) )
from sglang.srt.managers.io_struct import ( from sglang.srt.managers.io_struct import (
BatchStrOut, BatchStrOut,
DetokenizeReqInput,
FlushCacheReq, FlushCacheReq,
GenerateReqInput, GenerateReqInput,
TokenizedGenerateReqInput, TokenizedGenerateReqInput,
...@@ -28,7 +27,8 @@ from sglang.srt.managers.io_struct import ( ...@@ -28,7 +27,8 @@ from sglang.srt.managers.io_struct import (
from sglang.srt.mm_utils import expand2square, process_anyres_image from sglang.srt.mm_utils import expand2square, process_anyres_image
from sglang.srt.sampling_params import SamplingParams from sglang.srt.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_exception_traceback, is_multimodal_model, load_image from sglang.srt.utils import is_multimodal_model, load_image
from sglang.utils import get_exception_traceback
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
......
...@@ -41,8 +41,9 @@ from sglang.srt.utils import ( ...@@ -41,8 +41,9 @@ from sglang.srt.utils import (
allocate_init_ports, allocate_init_ports,
assert_pkg_version, assert_pkg_version,
enable_show_time_cost, enable_show_time_cost,
get_exception_traceback,
) )
from sglang.utils import get_exception_traceback
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
......
"""Common utilities.""" """Common utilities."""
import base64 import base64
import logging
import os import os
import random import random
import socket import socket
...@@ -18,7 +19,9 @@ from packaging import version as pkg_version ...@@ -18,7 +19,9 @@ from packaging import version as pkg_version
from pydantic import BaseModel from pydantic import BaseModel
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
show_time_cost = False show_time_cost = False
time_infos = {} time_infos = {}
...@@ -124,31 +127,12 @@ def set_random_seed(seed: int) -> None: ...@@ -124,31 +127,12 @@ def set_random_seed(seed: int) -> None:
torch.cuda.manual_seed_all(seed) torch.cuda.manual_seed_all(seed)
def alloc_usable_network_port(num, used_list=()): def is_port_available(port):
port_list = []
for port in range(10000, 65536):
if port in used_list:
continue
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
try:
s.bind(("", port))
s.listen(1) # Attempt to listen on the port
port_list.append(port)
except socket.error:
pass # If any error occurs, this port is not usable
if len(port_list) == num:
return port_list
return None
def check_port(port):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
try: try:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.bind(("", port)) s.bind(("", port))
s.listen(1)
return True return True
except socket.error: except socket.error:
return False return False
...@@ -159,31 +143,23 @@ def allocate_init_ports( ...@@ -159,31 +143,23 @@ def allocate_init_ports(
additional_ports: Optional[List[int]] = None, additional_ports: Optional[List[int]] = None,
tp_size: int = 1, tp_size: int = 1,
): ):
port = 30000 if port is None else port if additional_ports:
additional_ports = [] if additional_ports is None else additional_ports ret_ports = [port] + additional_ports
additional_ports = ( else:
[additional_ports] if isinstance(additional_ports, int) else additional_ports ret_ports = [port]
)
# first check on server port ret_ports = list(set(x for x in ret_ports if is_port_available(x)))
if not check_port(port): cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000
new_port = alloc_usable_network_port(1, used_list=[port])[0]
print(f"WARNING: Port {port} is not available. Use {new_port} instead.") while len(ret_ports) < 5 + tp_size:
port = new_port if cur_port not in ret_ports and is_port_available(cur_port):
ret_ports.append(cur_port)
# then we check on additional ports cur_port += 1
additional_unique_ports = set(additional_ports) - {port}
# filter out ports that are already in use if port and ret_ports[0] != port:
can_use_ports = [port for port in additional_unique_ports if check_port(port)] logger.warn(f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead.")
num_specified_ports = len(can_use_ports)
if num_specified_ports < 4 + tp_size:
addtional_can_use_ports = alloc_usable_network_port(
num=4 + tp_size - num_specified_ports, used_list=can_use_ports + [port]
)
can_use_ports.extend(addtional_can_use_ports)
additional_ports = can_use_ports[: 4 + tp_size] return ret_ports[0], ret_ports[1:]
return port, additional_ports
def get_int_token_logit_bias(tokenizer, vocab_size): def get_int_token_logit_bias(tokenizer, vocab_size):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment