Unverified Commit 4a634cf6 authored by Jay Zhou's avatar Jay Zhou Committed by GitHub
Browse files

[Feature] Allow specifying all ports to use in advance (#116)

parent a49dc52b
...@@ -6,7 +6,7 @@ import os ...@@ -6,7 +6,7 @@ import os
import sys import sys
import threading import threading
import time import time
from typing import List, Optional from typing import List, Optional, Union
# Fix a Python bug # Fix a Python bug
setattr(threading, "_register_atexit", lambda *args, **kwargs: None) setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
...@@ -47,7 +47,7 @@ from sglang.srt.managers.openai_protocol import ( ...@@ -47,7 +47,7 @@ from sglang.srt.managers.openai_protocol import (
from sglang.srt.managers.router.manager import start_router_process from sglang.srt.managers.router.manager import start_router_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import alloc_usable_network_port from sglang.srt.utils import alloc_usable_network_port, handle_port_init
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
...@@ -306,16 +306,17 @@ def launch_server(server_args, pipe_finish_writer): ...@@ -306,16 +306,17 @@ def launch_server(server_args, pipe_finish_writer):
global tokenizer_manager global tokenizer_manager
global chat_template_name global chat_template_name
# Allocate ports # Handle ports
can_use_ports = alloc_usable_network_port( server_args.port, server_args.additional_ports = handle_port_init(
num=4 + server_args.tp_size, used_list=(server_args.port,) server_args.port, server_args.additional_ports, server_args.tp_size
) )
port_args = PortArgs( port_args = PortArgs(
tokenizer_port=can_use_ports[0], tokenizer_port=server_args.additional_ports[0],
router_port=can_use_ports[1], router_port=server_args.additional_ports[1],
detokenizer_port=can_use_ports[2], detokenizer_port=server_args.additional_ports[2],
nccl_port=can_use_ports[3], nccl_port=server_args.additional_ports[3],
model_rpc_ports=can_use_ports[4:], model_rpc_ports=server_args.additional_ports[4:],
) )
# Load chat template if needed # Load chat template if needed
...@@ -435,14 +436,19 @@ class Runtime: ...@@ -435,14 +436,19 @@ class Runtime:
schedule_heuristic: str = "lpm", schedule_heuristic: str = "lpm",
random_seed: int = 42, random_seed: int = 42,
log_level: str = "error", log_level: str = "error",
port: Optional[int] = None,
additional_ports: Optional[Union[List[int], int]] = None,
): ):
host = "127.0.0.1" host = "127.0.0.1"
port = alloc_usable_network_port(1)[0] port, additional_ports = handle_port_init(
port, additional_ports, tp_size
)
self.server_args = ServerArgs( self.server_args = ServerArgs(
model_path=model_path, model_path=model_path,
tokenizer_path=tokenizer_path, tokenizer_path=tokenizer_path,
host=host, host=host,
port=port, port=port,
additional_ports=additional_ports,
load_format=load_format, load_format=load_format,
tokenizer_mode=tokenizer_mode, tokenizer_mode=tokenizer_mode,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
......
import argparse import argparse
import dataclasses import dataclasses
from typing import List, Optional from typing import List, Optional, Union
@dataclasses.dataclass @dataclasses.dataclass
...@@ -9,6 +9,7 @@ class ServerArgs: ...@@ -9,6 +9,7 @@ class ServerArgs:
tokenizer_path: Optional[str] = None tokenizer_path: Optional[str] = None
host: str = "127.0.0.1" host: str = "127.0.0.1"
port: int = 30000 port: int = 30000
additional_ports: Optional[Union[List[int], int]] = None
load_format: str = "auto" load_format: str = "auto"
tokenizer_mode: str = "auto" tokenizer_mode: str = "auto"
chat_template: Optional[str] = None chat_template: Optional[str] = None
...@@ -37,6 +38,10 @@ class ServerArgs: ...@@ -37,6 +38,10 @@ class ServerArgs:
self.mem_fraction_static = 0.85 self.mem_fraction_static = 0.85
else: else:
self.mem_fraction_static = 0.90 self.mem_fraction_static = 0.90
if isinstance(self.additional_ports, int):
self.additional_ports = [self.additional_ports]
elif self.additional_ports is None:
self.additional_ports = []
@staticmethod @staticmethod
def add_cli_args(parser: argparse.ArgumentParser): def add_cli_args(parser: argparse.ArgumentParser):
...@@ -54,6 +59,14 @@ class ServerArgs: ...@@ -54,6 +59,14 @@ class ServerArgs:
) )
parser.add_argument("--host", type=str, default=ServerArgs.host) parser.add_argument("--host", type=str, default=ServerArgs.host)
parser.add_argument("--port", type=int, default=ServerArgs.port) parser.add_argument("--port", type=int, default=ServerArgs.port)
# we want to be able to pass a list of ports
parser.add_argument(
"--additional-ports",
type=int,
nargs="*",
default=[],
help="Additional ports specified for launching server.",
)
parser.add_argument( parser.add_argument(
"--load-format", "--load-format",
type=str, type=str,
......
...@@ -99,6 +99,40 @@ def alloc_usable_network_port(num, used_list=()): ...@@ -99,6 +99,40 @@ def alloc_usable_network_port(num, used_list=()):
return None return None
def check_port(port):
    """Return True if a TCP socket can currently be bound to ``port``.

    The probe binds a throw-away IPv4 stream socket to ``("", port)`` and
    closes it again when the ``with`` block exits, so the answer is only a
    snapshot: the port is not reserved for the caller.

    Args:
        port: Port number to probe.

    Returns:
        True when the bind succeeds. False when the bind fails with an OS
        error, or when ``port`` is outside the 0-65535 range accepted by
        ``socket.bind``.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        try:
            s.bind(("", port))
        except (OSError, OverflowError):
            # OSError (of which socket.error is an alias) means the OS
            # refused the bind. OverflowError is what bind() raises for a
            # port number outside 0-65535; such a port is simply unusable.
            return False
        return True
def handle_port_init(port: Optional[int] = None, additional_ports: Optional[List[int]] = None, tp_size: int = 1):
    """Resolve the server port and the list of additional ports to use.

    The server port is kept if it can be bound, otherwise a replacement is
    allocated and a notice is printed. The additional ports requested by the
    caller are de-duplicated, stripped of the server port and of any port
    that cannot be bound, and then topped up with freshly allocated ports
    until ``4 + tp_size`` ports are available.

    Args:
        port: Preferred server port. ``None`` means 30000.
        additional_ports: Preferred additional ports. May be ``None`` (no
            preference), a single int, or a list of ints.
        tp_size: Tensor-parallel size; ``4 + tp_size`` additional ports are
            returned.

    Returns:
        A ``(port, additional_ports)`` tuple where ``additional_ports`` is a
        list of exactly ``4 + tp_size`` ports. Usable caller-specified ports
        come first, in the order the caller gave them, followed by any
        automatically allocated ones.

    Raises:
        RuntimeError: If ``alloc_usable_network_port`` cannot supply the
            ports that are still needed.
    """
    port = 30000 if port is None else port
    if additional_ports is None:
        additional_ports = []
    elif isinstance(additional_ports, int):
        additional_ports = [additional_ports]

    # First settle the main server port.
    if not check_port(port):
        replacement = alloc_usable_network_port(1, used_list=[port])
        if not replacement:
            # The allocator signals failure by returning None; fail with a
            # clear message instead of a TypeError on indexing.
            raise RuntimeError(
                f"Port {port} is not available and no replacement port could be allocated."
            )
        new_port = replacement[0]
        print(f"Port {port} is not available, using {new_port} instead.")
        port = new_port

    # Then settle the additional ports. De-duplicate with dict.fromkeys so
    # the caller's ordering survives: callers index the result positionally,
    # and a plain set() would hand the ports out in arbitrary order.
    requested_ports = [p for p in dict.fromkeys(additional_ports) if p != port]
    # Filter out ports that are already in use.
    can_use_ports = [p for p in requested_ports if check_port(p)]

    num_required = 4 + tp_size
    num_missing = num_required - len(can_use_ports)
    if num_missing > 0:
        extra_ports = alloc_usable_network_port(
            num=num_missing, used_list=can_use_ports + [port]
        )
        if extra_ports is None:
            raise RuntimeError(
                f"Could not allocate {num_missing} additional usable network port(s)."
            )
        can_use_ports.extend(extra_ports)

    # Drop any surplus caller-specified ports beyond what is needed.
    return port, can_use_ports[:num_required]
def get_exception_traceback(): def get_exception_traceback():
etype, value, tb = sys.exc_info() etype, value, tb = sys.exc_info()
err_str = "".join(traceback.format_exception(etype, value, tb)) err_str = "".join(traceback.format_exception(etype, value, tb))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment