"vscode:/vscode.git/clone" did not exist on "5de43476632863dd2d3540e7b1d2e18c2fc14aec"
Unverified commit d774acad authored by Mingyi, committed by GitHub

Remove the dependency of rpyc (#646)

parent d93388da
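This commit removes rpyc from the srt runtime. Tensor-parallel ranks are no longer exposed as rpyc services on per-rank TCP ports; each rank runs as a plain multiprocessing process, and rank 0 forwards every batch of tokenized requests to the other ranks by broadcasting the pickled objects over a torch.distributed CPU group (see broadcast_recv_input in the diff below). Data-parallel workers likewise change from in-process worker threads to separate single-controller processes fed through a multiprocessing.Queue. The following is a minimal, self-contained sketch of the new inter-rank pattern; the gloo init address, the two-rank group, and the payload are illustrative, not taken from the diff.

import pickle

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int):
    # Every rank joins one CPU process group (the diff reuses the model runner's
    # existing tp_group; the address here is only for the sketch).
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29500", rank=rank, world_size=world_size
    )
    if rank == 0:
        # Rank 0 pickles the request batch and broadcasts the size, then the bytes.
        payload = pickle.dumps(["tokenized-request-1", "tokenized-request-2"])
        size = torch.tensor([len(payload)], dtype=torch.long)
        data = torch.ByteTensor(list(payload))
        dist.broadcast(size, src=0)
        dist.broadcast(data, src=0)
    else:
        # Other ranks first learn the size, then receive and unpickle the bytes.
        size = torch.tensor([0], dtype=torch.long)
        dist.broadcast(size, src=0)
        data = torch.empty(size.item(), dtype=torch.uint8)
        dist.broadcast(data, src=0)
        print(f"rank {rank} received:", pickle.loads(bytes(data.tolist())))
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)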
......@@ -21,7 +21,7 @@ dependencies = [
[project.optional-dependencies]
srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "packaging", "pillow",
"psutil", "pydantic", "rpyc", "torch", "uvicorn", "uvloop", "zmq", "vllm==0.5.1", "outlines>=0.0.44"]
"psutil", "pydantic", "torch", "uvicorn", "uvloop", "zmq", "vllm==0.5.1", "outlines>=0.0.44"]
openai = ["openai>=1.0", "tiktoken"]
anthropic = ["anthropic>=0.20.0"]
litellm = ["litellm>=1.0.0"]
......
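With rpyc dropped from the srt extra above, the extra keeps its name and the usual install command is unchanged; it simply no longer pulls rpyc in:

pip install "sglang[srt]"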
......@@ -11,4 +11,4 @@ if __name__ == "__main__":
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
launch_server(server_args, None)
launch_server(server_args)
"""Launch the inference server for Llava-video model."""
import argparse
import multiprocessing as mp
from sglang.srt.server import ServerArgs, launch_server
......@@ -27,6 +26,4 @@ if __name__ == "__main__":
server_args = ServerArgs.from_cli_args(args)
pipe_reader, pipe_writer = mp.Pipe(duplex=False)
launch_server(server_args, pipe_writer, model_overide_args)
launch_server(server_args, model_overide_args, None)
"""A data parallel worker thread."""
import asyncio
import logging
import queue
import threading
from typing import Callable, List
import uvloop
import zmq
from sglang.global_config import global_config
from sglang.srt.managers.controller.tp_worker import ModelTpClient
from sglang.srt.managers.io_struct import BatchTokenIDOut
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import kill_parent_process
from sglang.utils import get_exception_traceback
logger = logging.getLogger("srt.controller")
CHECKING_INTERVAL = 5
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
class DataParallelWorkerThread(threading.Thread):
def __init__(
self,
worker_id: int,
request_queue: queue.Queue,
detokenizer_port: int,
step_func: Callable,
):
super(DataParallelWorkerThread, self).__init__()
self.worker_id = worker_id
self.request_queue = request_queue
self.liveness = True
self.request_dependency_delay = global_config.request_dependency_delay
context = zmq.asyncio.Context()
self.send_to_detokenizer = context.socket(zmq.PUSH)
self.send_to_detokenizer.connect(f"tcp://127.0.0.1:{detokenizer_port}")
self.step = step_func
async def loop_for_forward(self):
while self.liveness:
requests = []
while not self.request_queue.empty():
requests.append(self.request_queue.get())
out_pyobjs: List[BatchTokenIDOut] = []
try:
out_pyobjs = await self.step(requests)
except Exception:
for r in requests:
self.request_queue.put(r)
logger.error(
f"Worker thread {self.worker_id}: "
f"failed to get back from Model Server\n"
f"{get_exception_traceback()}"
)
self.liveness = False
# Crash the whole server when there are any errors.
# TODO(lianmin): make this an option.
kill_parent_process()
return
for obj in out_pyobjs:
self.send_to_detokenizer.send_pyobj(obj)
# async sleep for receiving the subsequent request and avoiding cache miss
if len(out_pyobjs) != 0:
has_finished = any(
[obj.finished_reason is not None for obj in out_pyobjs]
)
if has_finished:
await asyncio.sleep(self.request_dependency_delay)
await asyncio.sleep(global_config.wait_for_new_request_delay)
async def monitoring(self):
while True:
await asyncio.sleep(CHECKING_INTERVAL)
# can plug in monitoring logic here
def run(self):
logger.info(f"DataParallelWorkerThread {self.worker_id} start")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.create_task(self.monitoring())
loop.run_until_complete(self.loop_for_forward())
def start_data_parallel_worker(
server_args: ServerArgs,
port_args: PortArgs,
model_overide_args,
gpu_ids: List[int],
worker_id: int,
):
model_tp_client = ModelTpClient(
gpu_ids,
server_args,
port_args.model_port_args[worker_id],
model_overide_args,
)
worker_thread = DataParallelWorkerThread(
worker_id=worker_id,
request_queue=queue.Queue(),
detokenizer_port=port_args.detokenizer_port,
step_func=model_tp_client.step,
)
worker_thread.start()
return worker_thread
......@@ -3,19 +3,17 @@ A controller that manages multiple data parallel workers.
Each data parallel worker can manage multiple tensor parallel workers.
"""
import asyncio
import dataclasses
import logging
from concurrent.futures import ThreadPoolExecutor
import multiprocessing
import os
from enum import Enum, auto
from typing import Dict
import numpy as np
import zmq
import zmq.asyncio
from sglang.global_config import global_config
from sglang.srt.managers.controller.dp_worker import (
DataParallelWorkerThread,
start_data_parallel_worker,
from sglang.srt.managers.controller.manager_single import (
start_controller_process as start_controller_process_single,
)
from sglang.srt.managers.io_struct import (
AbortReq,
......@@ -23,12 +21,14 @@ from sglang.srt.managers.io_struct import (
TokenizedGenerateReqInput,
)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import kill_parent_process
from sglang.utils import get_exception_traceback
logger = logging.getLogger("srt.controller")
class LoadBalanceMethod(Enum):
"""Load balance method."""
ROUND_ROBIN = auto()
SHORTEST_QUEUE = auto()
......@@ -41,155 +41,155 @@ class LoadBalanceMethod(Enum):
raise ValueError(f"Invalid load balance method: {method}") from exc
class Controller:
@dataclasses.dataclass
class WorkerHandle:
"""Store the handle of a data parallel worker."""
proc: multiprocessing.Process
queue: multiprocessing.Queue
class ControllerMulti:
"""A controller that manages multiple data parallel workers."""
def __init__(
self,
load_balance_method: str,
server_args: ServerArgs,
port_args: PortArgs,
model_overide_args,
):
self.load_balance_method = LoadBalanceMethod.from_str(load_balance_method)
# Parse args
self.server_args = server_args
self.port_args = port_args
self.model_overide_args = model_overide_args
self.load_balance_method = LoadBalanceMethod.from_str(
server_args.load_balance_method)
if self.load_balance_method == LoadBalanceMethod.ROUND_ROBIN:
self.round_robin_counter = 0
# Init communication
context = zmq.Context()
self.recv_from_tokenizer = context.socket(zmq.PULL)
self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.controller_port}")
self.dispatch_lookup = {
# Dispatch method
self.round_robin_counter = 0
dispatch_lookup = {
LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
}
self.dispatching = self.dispatch_lookup[self.load_balance_method]
# Init communication
context = zmq.asyncio.Context()
self.recv_from_tokenizer = context.socket(zmq.PULL)
self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
# Init status
self.recv_reqs = []
self.dispatching = dispatch_lookup[self.load_balance_method]
# Start data parallel workers
self.workers: Dict[int, DataParallelWorkerThread] = {}
tp_size = server_args.tp_size
def start_dp_worker(i):
try:
gpu_ids = list(range(i * tp_size, (i + 1) * tp_size))
worker_thread = start_data_parallel_worker(
server_args, port_args, model_overide_args, gpu_ids, i
)
self.workers[i] = worker_thread
except Exception:
logger.error(
f"Failed to start local worker {i}\n{get_exception_traceback()}"
)
self.workers = []
for i in range(server_args.dp_size):
start_dp_worker(i)
# Parallel launch is slower, probably due to the disk bandwidth limitations.
# with ThreadPoolExecutor(server_args.dp_size) as executor:
# executor.map(start_dp_worker, range(server_args.dp_size))
def have_any_live_worker(self):
return any(worker_thread.liveness for worker_thread in self.workers.values())
self.start_dp_worker(i)
def start_dp_worker(self, dp_worker_id: int):
tp_size = self.server_args.tp_size
pipe_controller_reader, pipe_controller_writer = multiprocessing.Pipe(duplex=False)
gpu_ids = list(range(dp_worker_id * tp_size, (dp_worker_id + 1) * tp_size))
queue = multiprocessing.Queue()
proc = multiprocessing.Process(
target=start_controller_process_single,
args=(
self.server_args,
self.port_args,
pipe_controller_writer,
self.model_overide_args,
True,
gpu_ids,
dp_worker_id,
queue,
)
)
proc.start()
def put_req_to_worker(self, worker_id, req):
self.workers[worker_id].request_queue.put(req)
controller_init_state = pipe_controller_reader.recv()
if controller_init_state != "init ok":
raise RuntimeError(
f"Initialization failed. controller_init_state: {controller_init_state}"
)
self.workers.append(WorkerHandle(
proc=proc,
queue=queue,
))
async def round_robin_scheduler(self, input_requests):
available_workers = list(self.workers.keys())
def round_robin_scheduler(self, input_requests):
for r in input_requests:
self.put_req_to_worker(available_workers[self.round_robin_counter], r)
self.workers[self.round_robin_counter].queue.put(r)
self.round_robin_counter = (self.round_robin_counter + 1) % len(
available_workers
self.workers
)
return
async def shortest_queue_scheduler(self, input_requests):
def shortest_queue_scheduler(self, input_requests):
for r in input_requests:
worker = min(
self.workers, key=lambda w: self.workers[w].request_queue.qsize()
)
self.put_req_to_worker(worker, r)
return
async def remove_dead_workers(self):
for i in list(self.workers.keys()):
worker_thread = self.workers[i]
if not worker_thread.liveness:
worker_thread.join()
# move unsuccessful requests back to the queue
while not worker_thread.request_queue.empty():
self.recv_reqs.append(worker_thread.request_queue.get())
del self.workers[i]
logger.info(f"Stale worker {i} removed")
async def loop_for_forward(self):
while True:
await self.remove_dead_workers()
queue_sizes = [worker.queue.qsize() for worker in self.workers]
wid = np.argmin(queue_sizes)
self.workers[wid].queue.put(r)
if self.have_any_live_worker():
next_step_input = list(self.recv_reqs)
self.recv_reqs = []
if next_step_input:
await self.dispatching(next_step_input)
# else:
# logger.error("There is no live worker.")
def loop_for_forward(self):
while True:
recv_reqs = self.recv_requests()
self.dispatching(recv_reqs)
await asyncio.sleep(global_config.wait_for_new_request_delay)
def recv_requests(self):
recv_reqs = []
async def loop_for_recv_requests(self):
while True:
recv_req = await self.recv_from_tokenizer.recv_pyobj()
try:
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
except zmq.ZMQError:
break
if isinstance(recv_req, FlushCacheReq):
# TODO(lsyin): apply more specific flushCacheReq
for worker_thread in self.workers.values():
worker_thread.request_queue.put(recv_req)
elif isinstance(recv_req, TokenizedGenerateReqInput):
self.recv_reqs.append(recv_req)
for worker in self.workers:
worker.queue.put(recv_req)
elif isinstance(recv_req, AbortReq):
in_queue = False
for i, req in enumerate(self.recv_reqs):
for i, req in enumerate(recv_reqs):
if req.rid == recv_req.rid:
self.recv_reqs[i] = recv_req
recv_reqs[i] = recv_req
in_queue = True
break
if not in_queue:
# Send abort req to all TP groups
for worker in list(self.workers.keys()):
self.put_req_to_worker(worker, recv_req)
for worker in self.workers:
worker.queue.put(recv_req)
elif isinstance(recv_req, TokenizedGenerateReqInput):
recv_reqs.append(recv_req)
else:
logger.error(f"Invalid object: {recv_req}")
return recv_reqs
def start_controller_process(
server_args: ServerArgs,
port_args: PortArgs,
pipe_writer,
model_overide_args=None,
model_overide_args: dict,
):
"""Start a controller process."""
logging.basicConfig(
level=getattr(logging, server_args.log_level.upper()),
format="%(message)s",
)
try:
controller = Controller(
server_args.load_balance_method, server_args, port_args, model_overide_args
)
controller = ControllerMulti(server_args, port_args, model_overide_args)
except Exception:
pipe_writer.send(get_exception_traceback())
raise
pipe_writer.send("init ok")
loop = asyncio.new_event_loop()
loop.set_default_executor(ThreadPoolExecutor(max_workers=256))
pipe_writer.send("init ok")
asyncio.set_event_loop(loop)
loop.create_task(controller.loop_for_recv_requests())
loop.run_until_complete(controller.loop_for_forward())
try:
controller.loop_for_forward()
except Exception:
logger.error("Exception in ControllerMulti:\n" + get_exception_traceback())
finally:
for w in controller.workers:
os.kill(w.proc.pid, 9)
kill_parent_process()
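In the new ControllerMulti above, each data-parallel worker is a full single-controller process: start_dp_worker spawns it, blocks on a one-shot Pipe until the child reports "init ok", and afterwards talks to it only through a multiprocessing.Queue. A stripped-down sketch of that spawn-and-handshake pattern, with a stand-in child_main instead of start_controller_process_single:

import multiprocessing as mp


def child_main(pipe_writer, request_queue):
    # Stand-in for start_controller_process_single: do the heavy initialization,
    # then report readiness and start draining the request queue.
    pipe_writer.send("init ok")
    while True:
        req = request_queue.get()
        if req is None:  # sentinel used only in this sketch
            break


if __name__ == "__main__":
    reader, writer = mp.Pipe(duplex=False)
    queue = mp.Queue()
    proc = mp.Process(target=child_main, args=(writer, queue))
    proc.start()
    # The parent blocks until the child has finished initializing.
    if reader.recv() != "init ok":
        raise RuntimeError("DP worker failed to initialize")
    queue.put("tokenized-request")  # dispatching mirrors worker.queue.put(r)
    queue.put(None)
    proc.join()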
......@@ -3,126 +3,61 @@
import logging
import multiprocessing
import os
import pickle
from typing import List
import torch
import torch.distributed as dist
import zmq
import zmq.asyncio
from sglang.srt.managers.controller.tp_worker import ModelTpServer
from sglang.srt.server_args import ModelPortArgs, PortArgs, ServerArgs
from sglang.srt.managers.controller.tp_worker import (
broadcast_recv_input, launch_tp_servers, ModelTpServer
)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import kill_parent_process
from sglang.utils import get_exception_traceback
logger = logging.getLogger("srt.controller")
def run_tp_server(
gpu_id: int,
tp_rank: int,
server_args: ServerArgs,
model_port_args: ModelPortArgs,
model_overide_args: dict,
):
"""Run a tp server."""
try:
model_server = ModelTpServer(
gpu_id,
tp_rank,
server_args,
model_port_args,
model_overide_args,
)
tp_cpu_group = model_server.model_runner.tp_group.cpu_group
while True:
recv_reqs = broadcast_recv_input(None, tp_rank, tp_cpu_group)
model_server.exposed_step(recv_reqs)
except Exception:
logger.error("Exception in run_tp_server:\n" + get_exception_traceback())
raise
def launch_tp_servers(
gpu_ids, tp_rank_range, server_args, model_port_args, model_overide_args
):
"""Launch multiple tp servers."""
procs = []
for i in tp_rank_range:
proc = multiprocessing.Process(
target=run_tp_server,
args=(gpu_ids[i], i, server_args, model_port_args, model_overide_args),
)
proc.start()
procs.append(proc)
return procs
def broadcast_recv_input(data, rank, dist_group):
"""Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
if rank == 0:
if len(data) == 0:
tensor_size = torch.tensor([0], dtype=torch.long)
dist.broadcast(tensor_size, src=0, group=dist_group)
else:
serialized_data = pickle.dumps(data)
size = len(serialized_data)
tensor_data = torch.ByteTensor(list(serialized_data))
tensor_size = torch.tensor([size], dtype=torch.long)
dist.broadcast(tensor_size, src=0, group=dist_group)
dist.broadcast(tensor_data, src=0, group=dist_group)
else:
tensor_size = torch.tensor([0], dtype=torch.long)
dist.broadcast(tensor_size, src=0, group=dist_group)
size = tensor_size.item()
if size == 0:
return []
tensor_data = torch.empty(size, dtype=torch.uint8)
dist.broadcast(tensor_data, src=0, group=dist_group)
serialized_data = bytes(tensor_data.tolist())
data = pickle.loads(serialized_data)
return data
class ControllerSingle:
"""A controller that manages a group of tensor parallel workers."""
def __init__(
self, server_args: ServerArgs, port_args: PortArgs, model_overide_args: dict
self,
server_args: ServerArgs,
port_args: PortArgs,
model_overide_args: dict,
gpu_ids: List[int],
is_data_parallel_worker: bool,
dp_worker_id: int,
mp_queue: multiprocessing.Queue,
):
# Parse args
self.server_args = server_args
self.tp_procs = []
self.tp_size = server_args.tp_size
self.is_dp_worker = is_data_parallel_worker
self.dp_worker_id = dp_worker_id
self.mp_queue = mp_queue
# Init communication
context = zmq.Context(2)
self.recv_from_tokenizer = context.socket(zmq.PULL)
self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
if not self.is_dp_worker:
self.recv_from_tokenizer = context.socket(zmq.PULL)
self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.controller_port}")
self.send_to_detokenizer = context.socket(zmq.PUSH)
self.send_to_detokenizer.connect(
f"tcp://127.0.0.1:{port_args.detokenizer_port}"
)
# Init model server
tp_size_local = server_args.tp_size // server_args.nnodes
gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
# Launch other tp ranks
tp_size_local = server_args.tp_size // server_args.nnodes
self.tp_procs = []
if tp_size_local > 1:
tp_rank_range = range(1, tp_size_local)
self.tp_procs = launch_tp_servers(
gpu_ids,
tp_rank_range,
server_args,
port_args.model_port_args[0],
port_args.nccl_ports[dp_worker_id],
model_overide_args,
)
......@@ -131,16 +66,19 @@ class ControllerSingle:
gpu_ids[0],
0,
server_args,
port_args.model_port_args[0],
port_args.nccl_ports[dp_worker_id],
model_overide_args,
)
self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group
def loop_for_forward(self):
while True:
recv_reqs = self.recv_requests()
if not self.is_dp_worker:
recv_reqs = self.recv_requests_from_zmq()
else:
recv_reqs = self.recv_requests_from_mp_queue()
if self.server_args.tp_size > 1:
if self.tp_size > 1:
broadcast_recv_input(recv_reqs, 0, self.tp_cpu_group)
out_pyobjs = self.tp_server.exposed_step(recv_reqs)
......@@ -148,27 +86,51 @@ class ControllerSingle:
for obj in out_pyobjs:
self.send_to_detokenizer.send_pyobj(obj)
def recv_requests(self):
def recv_requests_from_zmq(self):
recv_reqs = []
while True:
try:
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
recv_reqs.append(recv_req)
except zmq.ZMQError:
break
recv_reqs.append(recv_req)
return recv_reqs
def recv_requests_from_mp_queue(self):
recv_reqs = []
while not self.mp_queue.empty():
recv_reqs.append(self.mp_queue.get())
return recv_reqs
def start_controller_process(
server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args: dict
server_args: ServerArgs,
port_args: PortArgs,
pipe_writer: multiprocessing.connection.Connection,
model_overide_args: dict,
is_data_parallel_worker: bool = False,
gpu_ids: List[int] = None,
dp_worker_id: int = None,
queue: multiprocessing.connection.Connection = None,
):
"""Start a controller process."""
logging.basicConfig(
level=getattr(logging, server_args.log_level.upper()),
format="%(message)s",
)
if not is_data_parallel_worker:
tp_size_local = server_args.tp_size // server_args.nnodes
gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
dp_worker_id = 0
queue = None
try:
controller = ControllerSingle(server_args, port_args, model_overide_args)
controller = ControllerSingle(server_args, port_args, model_overide_args,
gpu_ids, is_data_parallel_worker,
dp_worker_id, queue)
except Exception:
pipe_writer.send(get_exception_traceback())
raise
......
"""A tensor parallel worker."""
import asyncio
import logging
import multiprocessing
import pickle
import time
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional
import rpyc
import torch
from rpyc.utils.classic import obtain
import torch.distributed as dist
from sglang.global_config import global_config
from sglang.srt.constrained.fsm_cache import FSMCache
......@@ -32,13 +31,11 @@ from sglang.srt.managers.io_struct import (
TokenizedGenerateReqInput,
)
from sglang.srt.model_config import ModelConfig
from sglang.srt.server_args import ModelPortArgs, ServerArgs
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
connect_rpyc_service,
get_int_token_logit_bias,
is_multimodal_model,
set_random_seed,
start_rpyc_service_process,
suppress_other_loggers,
)
from sglang.utils import get_exception_traceback
......@@ -52,10 +49,9 @@ class ModelTpServer:
gpu_id: int,
tp_rank: int,
server_args: ServerArgs,
model_port_args: ModelPortArgs,
nccl_port: int,
model_overide_args: dict,
):
server_args, model_port_args = obtain(server_args), obtain(model_port_args)
suppress_other_loggers()
# Copy arguments
......@@ -79,7 +75,7 @@ class ModelTpServer:
gpu_id=gpu_id,
tp_rank=tp_rank,
tp_size=server_args.tp_size,
nccl_port=model_port_args.nccl_port,
nccl_port=nccl_port,
server_args=server_args,
)
......@@ -178,9 +174,6 @@ class ModelTpServer:
self.new_token_ratio_recovery = global_config.new_token_ratio_recovery
def exposed_step(self, recv_reqs):
if not isinstance(recv_reqs, list):
recv_reqs = obtain(recv_reqs)
try:
# Recv requests
for recv_req in recv_reqs:
......@@ -425,12 +418,6 @@ class ModelTpServer:
f"#running-req: {running_bs}, "
f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
)
# logger.debug(
# f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
# f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
# f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
# f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
# )
# Return the new batch
new_batch = Batch.init_new(
......@@ -733,87 +720,74 @@ class ModelTpServer:
break
class ModelTpService(rpyc.Service):
exposed_ModelTpServer = ModelTpServer
class ModelTpClient:
def __init__(
self,
gpu_ids: List[int],
server_args: ServerArgs,
model_port_args: ModelPortArgs,
model_overide_args,
):
server_args, model_port_args = obtain(server_args), obtain(model_port_args)
self.tp_size = server_args.tp_size
def run_tp_server(
gpu_id: int,
tp_rank: int,
server_args: ServerArgs,
nccl_port: int,
model_overide_args: dict,
):
"""Run a tensor parallel server."""
try:
model_server = ModelTpServer(
gpu_id,
tp_rank,
server_args,
nccl_port,
model_overide_args,
)
tp_cpu_group = model_server.model_runner.tp_group.cpu_group
while True:
recv_reqs = broadcast_recv_input(None, tp_rank, tp_cpu_group)
model_server.exposed_step(recv_reqs)
except Exception:
logger.error("Exception in run_tp_server:\n" + get_exception_traceback())
raise
def launch_tp_servers(
gpu_ids, tp_rank_range, server_args, nccl_port, model_overide_args
):
"""Launch multiple tensor parallel servers."""
procs = []
for i in tp_rank_range:
proc = multiprocessing.Process(
target=run_tp_server,
args=(gpu_ids[i], i, server_args, nccl_port, model_overide_args),
)
proc.start()
procs.append(proc)
if self.tp_size * server_args.dp_size == 1:
# Init model
assert len(gpu_ids) == 1
self.model_server = ModelTpService().exposed_ModelTpServer(
gpu_ids[0],
0,
server_args,
model_port_args,
model_overide_args,
)
return procs
# Wrap functions
def async_wrap(f):
async def _func(*args, **kwargs):
return f(*args, **kwargs)
return _func
def broadcast_recv_input(data, rank, dist_group):
"""Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
self.step = async_wrap(self.model_server.exposed_step)
if rank == 0:
if len(data) == 0:
tensor_size = torch.tensor([0], dtype=torch.long)
dist.broadcast(tensor_size, src=0, group=dist_group)
else:
with ThreadPoolExecutor(self.tp_size) as executor:
# Launch model processes
if server_args.nnodes == 1:
self.procs = list(
executor.map(
lambda args: start_rpyc_service_process(*args),
[
(ModelTpService, p)
for p in model_port_args.model_tp_ports
],
)
)
addrs = [("localhost", p) for p in model_port_args.model_tp_ports]
else:
addrs = [
(ip, port)
for ip, port in zip(
model_port_args.model_tp_ips, model_port_args.model_tp_ports
)
]
self.model_services = list(
executor.map(lambda args: connect_rpyc_service(*args), addrs)
)
# Init model
def init_model(i):
return self.model_services[i].ModelTpServer(
gpu_ids[i],
i,
server_args,
model_port_args,
model_overide_args,
)
self.model_servers = list(executor.map(init_model, range(self.tp_size)))
# Wrap functions
def async_wrap(func_name):
fs = [rpyc.async_(getattr(m, func_name)) for m in self.model_servers]
async def _func(*args, **kwargs):
tasks = [f(*args, **kwargs) for f in fs]
await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks])
return obtain(tasks[0].value)
return _func
self.step = async_wrap("step")
serialized_data = pickle.dumps(data)
size = len(serialized_data)
tensor_data = torch.ByteTensor(list(serialized_data))
tensor_size = torch.tensor([size], dtype=torch.long)
dist.broadcast(tensor_size, src=0, group=dist_group)
dist.broadcast(tensor_data, src=0, group=dist_group)
else:
tensor_size = torch.tensor([0], dtype=torch.long)
dist.broadcast(tensor_size, src=0, group=dist_group)
size = tensor_size.item()
if size == 0:
return []
tensor_data = torch.empty(size, dtype=torch.uint8)
dist.broadcast(tensor_data, src=0, group=dist_group)
serialized_data = bytes(tensor_data.tolist())
data = pickle.loads(serialized_data)
return data
......@@ -61,7 +61,7 @@ class TokenizerManager:
self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
self.send_to_router = context.socket(zmq.PUSH)
self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.router_port}")
self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.controller_port}")
self.model_path = server_args.model_path
self.hf_config = get_config(
......
......@@ -44,15 +44,13 @@ from sglang.srt.openai_api_adapter import (
v1_chat_completions,
v1_completions,
)
from sglang.srt.server_args import ModelPortArgs, PortArgs, ServerArgs
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
API_KEY_HEADER_NAME,
APIKeyValidatorMiddleware,
allocate_init_ports,
assert_pkg_version,
enable_show_time_cost,
receive_addrs,
send_addrs_to_rank_0,
)
from sglang.utils import get_exception_traceback
......@@ -98,6 +96,7 @@ async def flush_cache():
async def generate_request(obj: GenerateReqInput, request: Request):
"""Handle a generate request."""
if obj.stream:
async def stream_results():
......@@ -146,7 +145,10 @@ def _set_global_server_args(server_args: ServerArgs):
}
def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_args=None):
def launch_server(server_args: ServerArgs,
model_overide_args: Optional[dict] = None,
pipe_finish_writer: Optional[mp.connection.Connection] = None):
"""Launch an HTTP server."""
global tokenizer_manager
logging.basicConfig(
......@@ -173,39 +175,23 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
if server_args.chat_template:
# TODO: replace this with huggingface transformers template
load_chat_template_for_openai_api(server_args.chat_template)
_set_global_server_args(server_args)
# Allocate ports
assert server_args.tp_size % server_args.nnodes == 0
tp_size_local = server_args.tp_size // server_args.nnodes
server_args.port, server_args.additional_ports = allocate_init_ports(
server_args.port,
server_args.additional_ports,
tp_size_local,
server_args.dp_size,
)
ports = server_args.additional_ports
model_port_args = []
for i in range(server_args.dp_size):
model_port_args.append(
ModelPortArgs(
nccl_port=ports[3 + i * (tp_size_local + 1)],
model_tp_ips=[None] * tp_size_local,
model_tp_ports=ports[
3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)
],
)
)
port_args = PortArgs(
tokenizer_port=ports[0],
router_port=ports[1],
controller_port=ports[1],
detokenizer_port=ports[2],
model_port_args=model_port_args,
nccl_ports=ports[3:],
)
# Handle multi-node tp
# Handle multi-node tensor parallelism
if server_args.nnodes > 1:
assert server_args.dp_size == 1, "Multi-node dp is not supported."
......@@ -224,7 +210,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
gpu_ids,
tp_rank_range,
server_args,
port_args.model_port_args[0],
ports[3],
model_overide_args,
)
while True:
......@@ -232,18 +218,18 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
# Launch processes
tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
if server_args.dp_size == 1:
start_process = start_controller_process_single
else:
start_process = start_controller_process_multi
proc_router = mp.Process(
proc_controller = mp.Process(
target=start_process,
args=(server_args, port_args, pipe_router_writer, model_overide_args),
args=(server_args, port_args, pipe_controller_writer, model_overide_args),
)
proc_router.start()
proc_controller.start()
proc_detoken = mp.Process(
target=start_detokenizer_process,
args=(
......@@ -255,68 +241,27 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
proc_detoken.start()
# Wait for the model to finish loading
router_init_state = pipe_router_reader.recv()
controller_init_state = pipe_controller_reader.recv()
detoken_init_state = pipe_detoken_reader.recv()
if router_init_state != "init ok" or detoken_init_state != "init ok":
proc_router.kill()
if controller_init_state != "init ok" or detoken_init_state != "init ok":
proc_controller.kill()
proc_detoken.kill()
print(
f"Initialization failed. router_init_state: {router_init_state}", flush=True
f"Initialization failed. controller_init_state: {controller_init_state}", flush=True
)
print(
f"Initialization failed. detoken_init_state: {detoken_init_state}",
flush=True,
)
sys.exit(1)
assert proc_router.is_alive() and proc_detoken.is_alive()
assert proc_controller.is_alive() and proc_detoken.is_alive()
if server_args.api_key and server_args.api_key != "":
app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
# Send a warmup request
def _wait_and_warmup():
headers = {}
url = server_args.url()
if server_args.api_key:
headers[API_KEY_HEADER_NAME] = server_args.api_key
# Wait until the server is launched
for _ in range(120):
time.sleep(0.5)
try:
requests.get(url + "/get_model_info", timeout=5, headers=headers)
break
except requests.exceptions.RequestException:
pass
# Send a warmup request
try:
for _ in range(server_args.dp_size):
res = requests.post(
url + "/generate",
json={
"text": "The capital city of France is",
"sampling_params": {
"temperature": 0,
"max_new_tokens": 8,
},
},
headers=headers,
timeout=600,
)
assert res.status_code == 200
except Exception as e:
if pipe_finish_writer is not None:
pipe_finish_writer.send(get_exception_traceback())
print(f"Initialization failed. warmup error: {e}", flush=True)
raise e
logger.info("The server is fired up and ready to roll!")
if pipe_finish_writer is not None:
pipe_finish_writer.send("init ok")
t = threading.Thread(target=_wait_and_warmup)
t = threading.Thread(target=_wait_and_warmup, args=(server_args, pipe_finish_writer))
t.start()
# Listen for requests
......@@ -333,6 +278,48 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
t.join()
def _wait_and_warmup(server_args, pipe_finish_writer):
headers = {}
url = server_args.url()
if server_args.api_key:
headers[API_KEY_HEADER_NAME] = server_args.api_key
# Wait until the server is launched
for _ in range(120):
time.sleep(0.5)
try:
requests.get(url + "/get_model_info", timeout=5, headers=headers)
break
except requests.exceptions.RequestException:
pass
# Send a warmup request
try:
for _ in range(server_args.dp_size):
res = requests.post(
url + "/generate",
json={
"text": "The capital city of France is",
"sampling_params": {
"temperature": 0,
"max_new_tokens": 8,
},
},
headers=headers,
timeout=600,
)
assert res.status_code == 200
except Exception as e:
if pipe_finish_writer is not None:
pipe_finish_writer.send(get_exception_traceback())
print(f"Initialization failed. warmup error: {e}", flush=True)
raise e
logger.info("The server is fired up and ready to roll!")
if pipe_finish_writer is not None:
pipe_finish_writer.send("init ok")
class Runtime:
"""
A wrapper for the server.
......@@ -354,7 +341,6 @@ class Runtime:
self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
self.server_args.port,
self.server_args.additional_ports,
self.server_args.tp_size,
self.server_args.dp_size,
)
......@@ -367,7 +353,7 @@ class Runtime:
pipe_reader, pipe_writer = mp.Pipe(duplex=False)
proc = mp.Process(
target=launch_server,
args=(self.server_args, pipe_writer, model_overide_args),
args=(self.server_args, model_overide_args, pipe_writer),
)
proc.start()
pipe_writer.close()
......
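With the reordered launch_server signature above, callers that do not need the warmup pipe can pass only ServerArgs. A minimal sketch, assuming ServerArgs needs nothing beyond model_path here (the model path is a placeholder); it mirrors the updated CLI entry point earlier in this diff (ServerArgs.from_cli_args followed by launch_server(server_args)):

from sglang.srt.server import ServerArgs, launch_server

if __name__ == "__main__":
    # model_overide_args and pipe_finish_writer fall back to their defaults.
    launch_server(ServerArgs(model_path="meta-llama/Llama-2-7b-chat-hf"))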
......@@ -337,16 +337,9 @@ class ServerArgs:
)
@dataclasses.dataclass
class ModelPortArgs:
nccl_port: int
model_tp_ips: List[str]
model_tp_ports: List[int]
@dataclasses.dataclass
class PortArgs:
tokenizer_port: int
router_port: int
controller_port: int
detokenizer_port: int
model_port_args: List[ModelPortArgs]
nccl_ports: List[int]
......@@ -3,7 +3,6 @@
import base64
import fcntl
import logging
import multiprocessing
import os
import random
import socket
......@@ -16,12 +15,10 @@ from typing import List, Optional
import numpy as np
import psutil
import requests
import rpyc
import torch
import triton
from fastapi.responses import JSONResponse
from packaging import version as pkg_version
from rpyc.utils.server import ThreadedServer
from starlette.middleware.base import BaseHTTPMiddleware
logger = logging.getLogger(__name__)
......@@ -148,7 +145,6 @@ def is_port_available(port):
def allocate_init_ports(
port: Optional[int] = None,
additional_ports: Optional[List[int]] = None,
tp_size: int = 1,
dp_size: int = 1,
):
"""Allocate ports for all connections."""
......@@ -160,8 +156,8 @@ def allocate_init_ports(
ret_ports = list(set(x for x in ret_ports if is_port_available(x)))
cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000
# HTTP + Tokenizer + Controller + Detokenizer + dp_size * (nccl + tp_size)
num_ports_needed = 4 + dp_size * (1 + tp_size)
# HTTP + Tokenizer + Controller + Detokenizer + dp_size * 1 (nccl)
num_ports_needed = 4 + dp_size
while len(ret_ports) < num_ports_needed:
if cur_port not in ret_ports and is_port_available(cur_port):
ret_ports.append(cur_port)
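A quick worked example of the change in port accounting: with dp_size=2 and tp_size=4, the old formula reserved 4 + 2 * (1 + 4) = 14 ports (HTTP, tokenizer, controller, detokenizer, plus one NCCL port and four per-rank rpyc ports per data-parallel group), while the new formula reserves only 4 + 2 = 6, since the per-rank rpyc ports no longer exist.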
......@@ -371,49 +367,6 @@ def load_image(image_file):
return image, image_size
def connect_rpyc_service(host, port):
repeat_count = 0
while repeat_count < 20:
try:
con = rpyc.connect(
host,
port,
config={
"allow_public_attrs": True,
"allow_pickle": True,
"sync_request_timeout": 3600,
},
)
break
except ConnectionRefusedError as e:
time.sleep(1)
repeat_count += 1
if repeat_count == 20:
raise RuntimeError(f"Connect rpyc error: {e}")
return con.root
def start_rpyc_service(service: rpyc.Service, port: int):
t = ThreadedServer(
service=service,
port=port,
protocol_config={
"allow_public_attrs": True,
"allow_pickle": True,
"sync_request_timeout": 3600,
},
)
t.logger.setLevel(logging.WARN)
t.start()
def start_rpyc_service_process(service: rpyc.Service, port: int):
proc = multiprocessing.Process(target=start_rpyc_service, args=(service, port))
proc.start()
return proc
def suppress_other_loggers():
from vllm.logger import logger as vllm_default_logger
......