Unverified commit 114bbc86 authored by Lianmin Zheng, committed by GitHub

Use ipc instead of tcp in zmq (#1566)

parent 32eb6e96
@@ -223,7 +223,6 @@ if __name__ == "__main__":
         model_path=args.model_path,  # "liuhaotian/llava-v1.6-vicuna-7b",
         tokenizer_path=tokenizer_path,
         port=cur_port,
-        additional_ports=[cur_port + 1, cur_port + 2, cur_port + 3, cur_port + 4],
         json_model_override_args=json.dumps(model_override_args),
         tp_size=1,
     )
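With `additional_ports` removed, example launchers no longer pre-compute a block of ports. A hedged sketch of the simplified `Runtime` construction, with placeholder values standing in for `args.model_path`, `tokenizer_path`, and `cur_port`:

```python
import json

import sglang as sgl

runtime = sgl.Runtime(
    model_path="liuhaotian/llava-v1.6-vicuna-7b",  # placeholder for args.model_path
    tokenizer_path="llava-hf/llava-1.5-7b-hf",     # placeholder tokenizer path
    port=30000,                                    # placeholder for cur_port; the only port callers pass now
    json_model_override_args=json.dumps({}),       # placeholder override dict
    tp_size=1,
)
# ... issue requests against runtime.url ...
runtime.shutdown()
```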
@@ -66,9 +66,8 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server import _set_envs_and_config
-from sglang.srt.server_args import ServerArgs
+from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
-    allocate_init_ports,
     configure_logger,
     kill_child_process,
     suppress_other_loggers,

@@ -127,11 +126,7 @@ def load_model(server_args, tp_rank):
     suppress_other_loggers()
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

-    server_args.port, server_args.additional_ports = allocate_init_ports(
-        server_args.port,
-        server_args.additional_ports,
-        server_args.dp_size,
-    )
+    port_args = PortArgs.init_new(server_args)
     model_config = ModelConfig(
         server_args.model_path,
         server_args.trust_remote_code,

@@ -143,7 +138,7 @@
         gpu_id=tp_rank,
         tp_rank=tp_rank,
         tp_size=server_args.tp_size,
-        nccl_port=server_args.additional_ports[-1],
+        nccl_port=port_args.nccl_ports[0],
         server_args=server_args,
     )
     rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
@@ -59,10 +59,10 @@ class DetokenizerManager:
         # Init inter-process communication
         context = zmq.Context(2)
         self.recv_from_scheduler = context.socket(zmq.PULL)
-        self.recv_from_scheduler.bind(f"tcp://127.0.0.1:{port_args.detokenizer_port}")
+        self.recv_from_scheduler.bind(f"ipc://{port_args.detokenizer_ipc_name}")

         self.send_to_tokenizer = context.socket(zmq.PUSH)
-        self.send_to_tokenizer.connect(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
+        self.send_to_tokenizer.connect(f"ipc://{port_args.tokenizer_ipc_name}")

         if server_args.skip_tokenizer_init:
             self.tokenizer = None
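The same PUSH/PULL pair can be exercised standalone to see what the `ipc://` transport changes. A minimal sketch, assuming `pyzmq` on a Unix-like system (the `ipc` transport is backed by Unix domain sockets and is not available on Windows); the temp-file endpoint mirrors how `PortArgs.init_new` builds names further down:

```python
import tempfile

import zmq

# An ipc endpoint is a filesystem path rather than a host:port pair,
# so no free TCP port has to be found or reserved.
ipc_name = tempfile.NamedTemporaryFile(delete=False).name

ctx = zmq.Context(2)

recv_sock = ctx.socket(zmq.PULL)
recv_sock.bind(f"ipc://{ipc_name}")  # libzmq removes any stale file at the path when binding

send_sock = ctx.socket(zmq.PUSH)
send_sock.connect(f"ipc://{ipc_name}")

send_sock.send_pyobj({"rid": "demo", "text": "hello"})
print(recv_sock.recv_pyobj())  # {'rid': 'demo', 'text': 'hello'}
```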
@@ -96,14 +96,10 @@ class Scheduler:
         if self.tp_rank == 0:
             self.recv_from_tokenizer = context.socket(zmq.PULL)
-            self.recv_from_tokenizer.bind(
-                f"tcp://127.0.0.1:{port_args.scheduler_input_port}"
-            )
+            self.recv_from_tokenizer.bind(f"ipc://{port_args.scheduler_input_ipc_name}")

             self.send_to_detokenizer = context.socket(zmq.PUSH)
-            self.send_to_detokenizer.connect(
-                f"tcp://127.0.0.1:{port_args.detokenizer_port}"
-            )
+            self.send_to_detokenizer.connect(f"ipc://{port_args.detokenizer_ipc_name}")
         else:
             self.recv_from_tokenizer = self.send_to_detokenizer = None
@@ -84,12 +84,10 @@ class TokenizerManager:
         # Init inter-process communication
         context = zmq.asyncio.Context(2)
         self.recv_from_detokenizer = context.socket(zmq.PULL)
-        self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
+        self.recv_from_detokenizer.bind(f"ipc://{port_args.tokenizer_ipc_name}")

         self.send_to_scheduler = context.socket(zmq.PUSH)
-        self.send_to_scheduler.connect(
-            f"tcp://127.0.0.1:{port_args.scheduler_input_port}"
-        )
+        self.send_to_scheduler.connect(f"ipc://{port_args.scheduler_input_ipc_name}")

         # Read model args
         self.model_path = server_args.model_path
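The `TokenizerManager` uses the asyncio flavor of the same sockets, so sends and receives become awaitables instead of blocking calls. A small self-contained sketch of that pattern, under the same assumptions as the previous example:

```python
import asyncio
import tempfile

import zmq
import zmq.asyncio

async def main():
    ipc_name = tempfile.NamedTemporaryFile(delete=False).name

    ctx = zmq.asyncio.Context(2)

    recv_sock = ctx.socket(zmq.PULL)
    recv_sock.bind(f"ipc://{ipc_name}")

    send_sock = ctx.socket(zmq.PUSH)
    send_sock.connect(f"ipc://{ipc_name}")

    await send_sock.send_pyobj("detokenized text")
    # On a zmq.asyncio context, recv_pyobj is awaited rather than blocking the event loop.
    print(await recv_sock.recv_pyobj())

asyncio.run(main())
```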
@@ -16,7 +16,6 @@ limitations under the License.
 """Memory pool."""

 import logging
-from abc import ABC, abstractmethod
 from typing import List, Tuple, Union

 import numpy as np

@@ -62,9 +61,11 @@ class BaseTokenToKVPool:
         self,
         size: int,
         dtype: torch.dtype,
+        device: str,
     ):
         self.size = size
         self.dtype = dtype
+        self.device = device
         if dtype == torch.float8_e5m2:
             # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
             self.store_dtype = torch.uint8

@@ -84,7 +85,7 @@ class BaseTokenToKVPool:
         select_index = self.free_slots[:need_size]
         self.free_slots = self.free_slots[need_size:]

-        return torch.tensor(select_index, dtype=torch.int32, device="cuda")
+        return torch.tensor(select_index, dtype=torch.int32, device=self.device)

     def free(self, free_index: torch.Tensor):
         self.free_slots = np.concatenate((self.free_slots, free_index.cpu().numpy()))
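The alloc/free pair above is a plain numpy free list; the only behavioral change is that allocated indices now land on `self.device` instead of a hard-coded `"cuda"`. A stripped-down sketch of the pattern (`ToyPool` is a hypothetical stand-in for `BaseTokenToKVPool`, using a CPU device so it runs anywhere):

```python
from typing import Optional

import numpy as np
import torch

class ToyPool:
    """Free-list slot allocator, in the style of BaseTokenToKVPool."""

    def __init__(self, size: int, device: str):
        self.size = size
        self.device = device
        # Slot 0 is reserved as the padded slot for dummy writes.
        self.free_slots = np.arange(1, size + 1)

    def alloc(self, need_size: int) -> Optional[torch.Tensor]:
        if need_size > len(self.free_slots):
            return None  # out of KV cache slots
        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        # Indices follow the pool's device instead of hard-coded "cuda".
        return torch.tensor(select_index, dtype=torch.int32, device=self.device)

    def free(self, free_index: torch.Tensor):
        self.free_slots = np.concatenate((self.free_slots, free_index.cpu().numpy()))

pool = ToyPool(size=8, device="cpu")  # use "cuda" on a GPU machine
idx = pool.alloc(3)
pool.free(idx)
```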
@@ -123,7 +124,7 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         layer_num: int,
         device: str,
     ):
-        super().__init__(size, dtype)
+        super().__init__(size, dtype, device)

         # [size, head_num, head_dim] for each layer
         # The padded slot 0 is used for writing dummy outputs from padded tokens.

@@ -187,7 +188,7 @@ class MLATokenToKVPool(BaseTokenToKVPool):
         layer_num: int,
         device: str,
     ):
-        super().__init__(size, dtype)
+        super().__init__(size, dtype, device)
         self.kv_lora_rank = kv_lora_rank

         # The padded slot 0 is used for writing dummy outputs from padded tokens.
@@ -24,6 +24,7 @@ import json
 import logging
 import multiprocessing as mp
 import os
+import random
 import threading
 import time
 from http import HTTPStatus

@@ -68,9 +69,9 @@ from sglang.srt.openai_api.protocol import ModelCard, ModelList
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     add_api_key_middleware,
-    allocate_init_ports,
     assert_pkg_version,
     configure_logger,
+    is_port_available,
     kill_child_process,
     maybe_set_triton_cache_manager,
     prepare_model_and_tokenizer,

@@ -302,18 +303,7 @@ def launch_server(
     _set_envs_and_config(server_args)

     # Allocate ports for inter-process communications
-    server_args.port, server_args.additional_ports = allocate_init_ports(
-        server_args.port,
-        server_args.additional_ports,
-        server_args.dp_size,
-    )
-    ports = server_args.additional_ports
-    port_args = PortArgs(
-        tokenizer_port=ports[0],
-        scheduler_input_port=ports[1],
-        detokenizer_port=ports[2],
-        nccl_ports=ports[3:],
-    )
+    port_args = PortArgs.init_new(server_args)
     logger.info(f"{server_args=}")

     # If using model from www.modelscope.cn, first download the model.

@@ -499,17 +489,16 @@ class Runtime:
         self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)

         # Pre-allocate ports
-        self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
-            self.server_args.port,
-            self.server_args.additional_ports,
-            self.server_args.dp_size,
-        )
+        for port in range(10000, 40000):
+            if is_port_available(port):
+                break
+            port += 1
+        self.server_args.port = port

         self.url = self.server_args.url()
-        self.generate_url = (
-            f"http://{self.server_args.host}:{self.server_args.port}/generate"
-        )
+        self.generate_url = self.url + "/generate"

+        # NOTE: We store pid instead of proc to fix some issues during __delete__
         self.pid = None
         pipe_reader, pipe_writer = mp.Pipe(duplex=False)
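`Runtime` now scans for a free HTTP port itself, and both this loop and `PortArgs.init_new` below lean on `is_port_available` from `sglang.srt.utils` (only its tail is visible in the last hunk of this commit). A typical bind-probe implementation looks roughly like this; treat it as a sketch, not the exact sglang body:

```python
import socket

def is_port_available(port: int) -> bool:
    """Probe a TCP port; it is considered available iff bind() succeeds."""
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", port))
            return True
    except OSError:
        return False
```

Note the inherent race in any probe of this kind: another process can grab the port between the check and the actual bind, which is part of why the zmq channels moved off TCP entirely.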
@@ -19,9 +19,10 @@ import argparse
 import dataclasses
 import logging
 import random
-from typing import List, Optional, Union
+import tempfile
+from typing import List, Optional

-from sglang.srt.utils import is_hip, is_ipv6
+from sglang.srt.utils import is_hip, is_ipv6, is_port_available

 logger = logging.getLogger(__name__)

@@ -46,7 +47,6 @@ class ServerArgs:
     # Port
     host: str = "127.0.0.1"
     port: int = 30000
-    additional_ports: Optional[Union[List[int], int]] = None

     # Memory and scheduling
     mem_fraction_static: Optional[float] = None

@@ -134,11 +134,6 @@
         else:
             self.mem_fraction_static = 0.88

-        if isinstance(self.additional_ports, int):
-            self.additional_ports = [self.additional_ports]
-        elif self.additional_ports is None:
-            self.additional_ports = []
-
         if self.random_seed is None:
             self.random_seed = random.randint(0, 1 << 30)

@@ -199,13 +194,6 @@
         parser.add_argument(
             "--port", type=int, default=ServerArgs.port, help="The port of the server."
         )
-        parser.add_argument(
-            "--additional-ports",
-            type=int,
-            nargs="*",
-            default=[],
-            help="The additional ports specified for the server.",
-        )
         parser.add_argument(
             "--tokenizer-mode",
             type=str,
@@ -625,16 +613,31 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:

 @dataclasses.dataclass
 class PortArgs:
-    # The port for tokenizer to receive inputs from detokenizer (zmq)
-    tokenizer_port: int
+    # The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
+    tokenizer_ipc_name: str

-    # The port for scheduler (rank 0) to receive inputs from tokenizer (zmq)
-    scheduler_input_port: int
+    # The ipc filename for scheduler (rank 0) to receive inputs from tokenizer (zmq)
+    scheduler_input_ipc_name: str

-    # The port for detokenizer to receive inputs from scheduler (zmq)
-    detokenizer_port: int
+    # The ipc filename for detokenizer to receive inputs from scheduler (zmq)
+    detokenizer_ipc_name: str

     # The port for nccl initialization for multiple TP groups (torch.dist)
     nccl_ports: List[int]

+    @classmethod
+    def init_new(cls, server_args):
+        port = server_args.port + 1
+        while True:
+            if is_port_available(port):
+                break
+            port += 1
+
+        return PortArgs(
+            tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
+            scheduler_input_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
+            detokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
+            nccl_ports=[port],
+        )
+

 class LoRAPathAction(argparse.Action):
     def __call__(self, parser, namespace, values, option_string=None):
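Putting the pieces together: `init_new` hands each zmq channel a fresh temp-file path to use as an `ipc://` endpoint and still reserves one real TCP port for NCCL initialization. A quick usage sketch, with a placeholder model path:

```python
from sglang.srt.server_args import PortArgs, ServerArgs

server_args = ServerArgs(model_path="meta-llama/Llama-3.1-8B-Instruct")  # placeholder
port_args = PortArgs.init_new(server_args)

# Three ipc endpoints (filesystem paths) plus one free TCP port for NCCL.
print(port_args.tokenizer_ipc_name)        # e.g. /tmp/tmpw3c9k1xu
print(port_args.scheduler_input_ipc_name)
print(port_args.detokenizer_ipc_name)
print(port_args.nccl_ports)                # e.g. [30001], first free port after server_args.port
```

Only NCCL still needs a numbered port; everything the tokenizer, scheduler, and detokenizer exchange now travels over Unix domain sockets, which removes the port-collision failure mode that `allocate_init_ports` (deleted in the hunk below) tried to work around.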
@@ -177,35 +177,6 @@ def is_port_available(port):
         return False


-def allocate_init_ports(
-    port: Optional[int] = None,
-    additional_ports: Optional[List[int]] = None,
-    dp_size: int = 1,
-):
-    """Allocate ports for all connections."""
-    if additional_ports:
-        ret_ports = [port] + additional_ports
-    else:
-        ret_ports = [port]
-
-    ret_ports = list(set(x for x in ret_ports if is_port_available(x)))
-    cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000
-
-    # HTTP + Tokenizer + Controller + Detokenizer + dp_size * 1 (nccl)
-    num_ports_needed = 4 + dp_size
-    while len(ret_ports) < num_ports_needed:
-        if cur_port not in ret_ports and is_port_available(cur_port):
-            ret_ports.append(cur_port)
-        cur_port += 1
-
-    if port is not None and ret_ports[0] != port:
-        logger.warning(
-            f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
-        )
-
-    return ret_ports[0], ret_ports[1:num_ports_needed]
-
-
 def is_multimodal_model(model_architectures):
     if (