Unverified commit d774acad authored by Mingyi, committed by GitHub

Remove the dependency of rpyc (#646)

parent d93388da
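Note: in short, this commit drops the rpyc RPC layer entirely. Tensor parallel ranks now run as plain multiprocessing processes that receive their inputs via a torch.distributed (gloo) broadcast from rank 0, each data parallel worker becomes a separate controller process fed through a multiprocessing.Queue, and the tokenizer/controller/detokenizer keep talking over ZMQ. The old router_port / ModelPortArgs plumbing is replaced by a single controller_port plus one NCCL port per data parallel worker.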
@@ -21,7 +21,7 @@ dependencies = [
 [project.optional-dependencies]
 srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "packaging", "pillow",
-       "psutil", "pydantic", "rpyc", "torch", "uvicorn", "uvloop", "zmq", "vllm==0.5.1", "outlines>=0.0.44"]
+       "psutil", "pydantic", "torch", "uvicorn", "uvloop", "zmq", "vllm==0.5.1", "outlines>=0.0.44"]
 openai = ["openai>=1.0", "tiktoken"]
 anthropic = ["anthropic>=0.20.0"]
 litellm = ["litellm>=1.0.0"]
......
@@ -11,4 +11,4 @@ if __name__ == "__main__":
     args = parser.parse_args()
     server_args = ServerArgs.from_cli_args(args)
-    launch_server(server_args, None)
+    launch_server(server_args)

 """Launch the inference server for Llava-video model."""
 import argparse
-import multiprocessing as mp
 from sglang.srt.server import ServerArgs, launch_server
@@ -27,6 +26,4 @@ if __name__ == "__main__":
     server_args = ServerArgs.from_cli_args(args)
-    pipe_reader, pipe_writer = mp.Pipe(duplex=False)
-    launch_server(server_args, pipe_writer, model_overide_args)
+    launch_server(server_args, model_overide_args, None)
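With the new signature, example scripts no longer pass a pipe. A minimal launcher after this change might look like the sketch below (illustrative; it assumes the existing `ServerArgs.add_cli_args` helper used by the repo's examples):

```python
# Hedged sketch of a minimal launcher after this commit; not a file in this diff.
import argparse

from sglang.srt.server import ServerArgs, launch_server

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)  # assumed helper, as in the existing examples
    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)
    # model_overide_args and pipe_finish_writer now default to None.
    launch_server(server_args)
```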
"""A data parallel worker thread."""
import asyncio
import logging
import queue
import threading
from typing import Callable, List
import uvloop
import zmq
from sglang.global_config import global_config
from sglang.srt.managers.controller.tp_worker import ModelTpClient
from sglang.srt.managers.io_struct import BatchTokenIDOut
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import kill_parent_process
from sglang.utils import get_exception_traceback
logger = logging.getLogger("srt.controller")
CHECKING_INTERVAL = 5
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
class DataParallelWorkerThread(threading.Thread):
def __init__(
self,
worker_id: int,
request_queue: queue.Queue,
detokenizer_port: int,
step_func: Callable,
):
super(DataParallelWorkerThread, self).__init__()
self.worker_id = worker_id
self.request_queue = request_queue
self.liveness = True
self.request_dependency_delay = global_config.request_dependency_delay
context = zmq.asyncio.Context()
self.send_to_detokenizer = context.socket(zmq.PUSH)
self.send_to_detokenizer.connect(f"tcp://127.0.0.1:{detokenizer_port}")
self.step = step_func
async def loop_for_forward(self):
while self.liveness:
requests = []
while not self.request_queue.empty():
requests.append(self.request_queue.get())
out_pyobjs: List[BatchTokenIDOut] = []
try:
out_pyobjs = await self.step(requests)
except Exception:
for r in requests:
self.request_queue.put(r)
logger.error(
f"Worker thread {self.worker_id}: "
f"failed to get back from Model Server\n"
f"{get_exception_traceback()}"
)
self.liveness = False
# Crash the whole server when there are any errors.
# TODO(lianmin): make this an option.
kill_parent_process()
return
for obj in out_pyobjs:
self.send_to_detokenizer.send_pyobj(obj)
# async sleep for receiving the subsequent request and avoiding cache miss
if len(out_pyobjs) != 0:
has_finished = any(
[obj.finished_reason is not None for obj in out_pyobjs]
)
if has_finished:
await asyncio.sleep(self.request_dependency_delay)
await asyncio.sleep(global_config.wait_for_new_request_delay)
async def monitoring(self):
while True:
await asyncio.sleep(CHECKING_INTERVAL)
# can plug in monitoring logic here
def run(self):
logger.info(f"DataParallelWorkerThread {self.worker_id} start")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.create_task(self.monitoring())
loop.run_until_complete(self.loop_for_forward())
def start_data_parallel_worker(
server_args: ServerArgs,
port_args: PortArgs,
model_overide_args,
gpu_ids: List[int],
worker_id: int,
):
model_tp_client = ModelTpClient(
gpu_ids,
server_args,
port_args.model_port_args[worker_id],
model_overide_args,
)
worker_thread = DataParallelWorkerThread(
worker_id=worker_id,
request_queue=queue.Queue(),
detokenizer_port=port_args.detokenizer_port,
step_func=model_tp_client.step,
)
worker_thread.start()
return worker_thread
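The thread above (an asyncio loop inside a Python thread, fed by a `queue.Queue` and calling the TP ranks over rpyc) is the design this commit removes. Each data parallel worker is now a separate process fed by a `multiprocessing.Queue`; a hedged, self-contained sketch of that pattern (names here are illustrative, the real implementation is `ControllerMulti` below):

```python
# Process-per-worker sketch: the controller dispatches by putting requests
# into a multiprocessing.Queue that the worker process drains.
import multiprocessing as mp


def worker_main(worker_id: int, request_queue: mp.Queue):
    while True:
        requests = []
        while not request_queue.empty():  # non-blocking drain, like recv_requests_from_mp_queue
            requests.append(request_queue.get())
        if requests:
            print(f"worker {worker_id}: stepping on {len(requests)} request(s)")
            # ... run one model step and push outputs to the detokenizer here ...


if __name__ == "__main__":
    q = mp.Queue()
    proc = mp.Process(target=worker_main, args=(0, q), daemon=True)
    proc.start()
    q.put({"rid": "req-0"})
    proc.join(timeout=1)  # demo only; a real controller keeps dispatching
```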
@@ -3,19 +3,17 @@ A controller that manages multiple data parallel workers.
 Each data parallel worker can manage multiple tensor parallel workers.
 """
-import asyncio
+import dataclasses
 import logging
-from concurrent.futures import ThreadPoolExecutor
+import multiprocessing
+import os
 from enum import Enum, auto
-from typing import Dict

+import numpy as np
 import zmq
-import zmq.asyncio

-from sglang.global_config import global_config
-from sglang.srt.managers.controller.dp_worker import (
-    DataParallelWorkerThread,
-    start_data_parallel_worker,
+from sglang.srt.managers.controller.manager_single import (
+    start_controller_process as start_controller_process_single,
 )
 from sglang.srt.managers.io_struct import (
     AbortReq,
@@ -23,12 +21,14 @@ from sglang.srt.managers.io_struct import (
     TokenizedGenerateReqInput,
 )
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.utils import kill_parent_process
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger("srt.controller")


 class LoadBalanceMethod(Enum):
+    """Load balance method."""
+
     ROUND_ROBIN = auto()
     SHORTEST_QUEUE = auto()
@@ -41,155 +41,155 @@ class LoadBalanceMethod(Enum):
             raise ValueError(f"Invalid load balance method: {method}") from exc


-class Controller:
+@dataclasses.dataclass
+class WorkerHandle:
+    """Store the handle of a data parallel worker."""
+
+    proc: multiprocessing.Process
+    queue: multiprocessing.Queue
+
+
+class ControllerMulti:
     """A controller that manages multiple data parallel workers."""

     def __init__(
         self,
-        load_balance_method: str,
         server_args: ServerArgs,
         port_args: PortArgs,
         model_overide_args,
     ):
-        self.load_balance_method = LoadBalanceMethod.from_str(load_balance_method)
+        # Parse args
         self.server_args = server_args
         self.port_args = port_args
+        self.model_overide_args = model_overide_args
+        self.load_balance_method = LoadBalanceMethod.from_str(
+            server_args.load_balance_method)

-        if self.load_balance_method == LoadBalanceMethod.ROUND_ROBIN:
-            self.round_robin_counter = 0
+        # Init communication
+        context = zmq.Context()
+        self.recv_from_tokenizer = context.socket(zmq.PULL)
+        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.controller_port}")

-        self.dispatch_lookup = {
+        # Dispatch method
+        self.round_robin_counter = 0
+        dispatch_lookup = {
             LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
             LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
         }
-        self.dispatching = self.dispatch_lookup[self.load_balance_method]
-
-        # Init communication
-        context = zmq.asyncio.Context()
-        self.recv_from_tokenizer = context.socket(zmq.PULL)
-        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
-
-        # Init status
-        self.recv_reqs = []
+        self.dispatching = dispatch_lookup[self.load_balance_method]

         # Start data parallel workers
-        self.workers: Dict[int, DataParallelWorkerThread] = {}
-        tp_size = server_args.tp_size
-
-        def start_dp_worker(i):
-            try:
-                gpu_ids = list(range(i * tp_size, (i + 1) * tp_size))
-                worker_thread = start_data_parallel_worker(
-                    server_args, port_args, model_overide_args, gpu_ids, i
-                )
-                self.workers[i] = worker_thread
-            except Exception:
-                logger.error(
-                    f"Failed to start local worker {i}\n{get_exception_traceback()}"
-                )
-
-        for i in range(server_args.dp_size):
-            start_dp_worker(i)
-
-        # Parallel launch is slower, probably due to the disk bandwidth limitations.
-        # with ThreadPoolExecutor(server_args.dp_size) as executor:
-        #     executor.map(start_dp_worker, range(server_args.dp_size))
-
-    def have_any_live_worker(self):
-        return any(worker_thread.liveness for worker_thread in self.workers.values())
-
-    def put_req_to_worker(self, worker_id, req):
-        self.workers[worker_id].request_queue.put(req)
-
-    async def round_robin_scheduler(self, input_requests):
-        available_workers = list(self.workers.keys())
+        self.workers = []
+        for i in range(server_args.dp_size):
+            self.start_dp_worker(i)
+
+    def start_dp_worker(self, dp_worker_id: int):
+        tp_size = self.server_args.tp_size
+
+        pipe_controller_reader, pipe_controller_writer = multiprocessing.Pipe(duplex=False)
+
+        gpu_ids = list(range(dp_worker_id * tp_size, (dp_worker_id + 1) * tp_size))
+        queue = multiprocessing.Queue()
+        proc = multiprocessing.Process(
+            target=start_controller_process_single,
+            args=(
+                self.server_args,
+                self.port_args,
+                pipe_controller_writer,
+                self.model_overide_args,
+                True,
+                gpu_ids,
+                dp_worker_id,
+                queue,
+            ),
+        )
+        proc.start()
+
+        controller_init_state = pipe_controller_reader.recv()
+        if controller_init_state != "init ok":
+            raise RuntimeError(
+                f"Initialization failed. controller_init_state: {controller_init_state}"
+            )
+        self.workers.append(WorkerHandle(
+            proc=proc,
+            queue=queue,
+        ))
+
+    def round_robin_scheduler(self, input_requests):
         for r in input_requests:
-            self.put_req_to_worker(available_workers[self.round_robin_counter], r)
+            self.workers[self.round_robin_counter].queue.put(r)
             self.round_robin_counter = (self.round_robin_counter + 1) % len(
-                available_workers
+                self.workers
             )
-        return

-    async def shortest_queue_scheduler(self, input_requests):
+    def shortest_queue_scheduler(self, input_requests):
         for r in input_requests:
-            worker = min(
-                self.workers, key=lambda w: self.workers[w].request_queue.qsize()
-            )
-            self.put_req_to_worker(worker, r)
-        return
-
-    async def remove_dead_workers(self):
-        for i in list(self.workers.keys()):
-            worker_thread = self.workers[i]
-            if not worker_thread.liveness:
-                worker_thread.join()
-                # move unsuccessful requests back to the queue
-                while not worker_thread.request_queue.empty():
-                    self.recv_reqs.append(worker_thread.request_queue.get())
-                del self.workers[i]
-                logger.info(f"Stale worker {i} removed")
-
-    async def loop_for_forward(self):
-        while True:
-            await self.remove_dead_workers()
-
-            if self.have_any_live_worker():
-                next_step_input = list(self.recv_reqs)
-                self.recv_reqs = []
-                if next_step_input:
-                    await self.dispatching(next_step_input)
-            # else:
-            #     logger.error("There is no live worker.")
-
-            await asyncio.sleep(global_config.wait_for_new_request_delay)
-
-    async def loop_for_recv_requests(self):
+            queue_sizes = [worker.queue.qsize() for worker in self.workers]
+            wid = np.argmin(queue_sizes)
+            self.workers[wid].queue.put(r)
+
+    def loop_for_forward(self):
+        while True:
+            recv_reqs = self.recv_requests()
+            self.dispatching(recv_reqs)
+
+    def recv_requests(self):
+        recv_reqs = []
         while True:
-            recv_req = await self.recv_from_tokenizer.recv_pyobj()
+            try:
+                recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
+            except zmq.ZMQError:
+                break
+
             if isinstance(recv_req, FlushCacheReq):
                 # TODO(lsyin): apply more specific flushCacheReq
-                for worker_thread in self.workers.values():
-                    worker_thread.request_queue.put(recv_req)
-            elif isinstance(recv_req, TokenizedGenerateReqInput):
-                self.recv_reqs.append(recv_req)
+                for worker in self.workers:
+                    worker.queue.put(recv_req)
             elif isinstance(recv_req, AbortReq):
                 in_queue = False
-                for i, req in enumerate(self.recv_reqs):
+                for i, req in enumerate(recv_reqs):
                     if req.rid == recv_req.rid:
-                        self.recv_reqs[i] = recv_req
+                        recv_reqs[i] = recv_req
                         in_queue = True
                         break
                 if not in_queue:
                     # Send abort req to all TP groups
-                    for worker in list(self.workers.keys()):
-                        self.put_req_to_worker(worker, recv_req)
+                    for worker in self.workers:
+                        worker.queue.put(recv_req)
+            elif isinstance(recv_req, TokenizedGenerateReqInput):
+                recv_reqs.append(recv_req)
             else:
                 logger.error(f"Invalid object: {recv_req}")

+        return recv_reqs
+

 def start_controller_process(
     server_args: ServerArgs,
     port_args: PortArgs,
     pipe_writer,
-    model_overide_args=None,
+    model_overide_args: dict,
 ):
+    """Start a controller process."""
     logging.basicConfig(
         level=getattr(logging, server_args.log_level.upper()),
         format="%(message)s",
     )

     try:
-        controller = Controller(
-            server_args.load_balance_method, server_args, port_args, model_overide_args
-        )
+        controller = ControllerMulti(server_args, port_args, model_overide_args)
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
-    pipe_writer.send("init ok")

-    loop = asyncio.new_event_loop()
-    loop.set_default_executor(ThreadPoolExecutor(max_workers=256))
-    asyncio.set_event_loop(loop)
-    loop.create_task(controller.loop_for_recv_requests())
-    loop.run_until_complete(controller.loop_for_forward())
+    pipe_writer.send("init ok")
+
+    try:
+        controller.loop_for_forward()
+    except Exception:
+        logger.error("Exception in ControllerMulti:\n" + get_exception_traceback())
+    finally:
+        for w in controller.workers:
+            os.kill(w.proc.pid, 9)
+        kill_parent_process()
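The new controller no longer awaits one message at a time; it drains everything currently sitting on its ZMQ PULL socket with non-blocking receives and moves on. A self-contained sketch of that drain pattern (the port and payloads are made up):

```python
# Non-blocking drain of a ZMQ PULL socket, as in ControllerMulti.recv_requests.
import time

import zmq

ctx = zmq.Context()
pull = ctx.socket(zmq.PULL)
pull.bind("tcp://127.0.0.1:25555")   # arbitrary free port
push = ctx.socket(zmq.PUSH)
push.connect("tcp://127.0.0.1:25555")

for i in range(3):
    push.send_pyobj({"rid": f"req-{i}"})
time.sleep(0.1)  # give the messages time to arrive

recv_reqs = []
while True:
    try:
        recv_reqs.append(pull.recv_pyobj(zmq.NOBLOCK))
    except zmq.ZMQError:
        break  # nothing left to read right now

print(recv_reqs)  # three request dicts, in order
```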
@@ -3,126 +3,61 @@
 import logging
 import multiprocessing
 import os
-import pickle
+from typing import List

-import torch
-import torch.distributed as dist
 import zmq
-import zmq.asyncio

-from sglang.srt.managers.controller.tp_worker import ModelTpServer
-from sglang.srt.server_args import ModelPortArgs, PortArgs, ServerArgs
+from sglang.srt.managers.controller.tp_worker import (
+    broadcast_recv_input, launch_tp_servers, ModelTpServer
+)
+from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import kill_parent_process
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger("srt.controller")


-def run_tp_server(
-    gpu_id: int,
-    tp_rank: int,
-    server_args: ServerArgs,
-    model_port_args: ModelPortArgs,
-    model_overide_args: dict,
-):
-    """Run a tp server."""
-    try:
-        model_server = ModelTpServer(
-            gpu_id,
-            tp_rank,
-            server_args,
-            model_port_args,
-            model_overide_args,
-        )
-        tp_cpu_group = model_server.model_runner.tp_group.cpu_group
-
-        while True:
-            recv_reqs = broadcast_recv_input(None, tp_rank, tp_cpu_group)
-            model_server.exposed_step(recv_reqs)
-    except Exception:
-        logger.error("Exception in run_tp_server:\n" + get_exception_traceback())
-        raise
-
-
-def launch_tp_servers(
-    gpu_ids, tp_rank_range, server_args, model_port_args, model_overide_args
-):
-    """Launch multiple tp servers."""
-    procs = []
-    for i in tp_rank_range:
-        proc = multiprocessing.Process(
-            target=run_tp_server,
-            args=(gpu_ids[i], i, server_args, model_port_args, model_overide_args),
-        )
-        proc.start()
-        procs.append(proc)
-
-    return procs
-
-
-def broadcast_recv_input(data, rank, dist_group):
-    """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
-    if rank == 0:
-        if len(data) == 0:
-            tensor_size = torch.tensor([0], dtype=torch.long)
-            dist.broadcast(tensor_size, src=0, group=dist_group)
-        else:
-            serialized_data = pickle.dumps(data)
-            size = len(serialized_data)
-            tensor_data = torch.ByteTensor(list(serialized_data))
-            tensor_size = torch.tensor([size], dtype=torch.long)
-
-            dist.broadcast(tensor_size, src=0, group=dist_group)
-            dist.broadcast(tensor_data, src=0, group=dist_group)
-    else:
-        tensor_size = torch.tensor([0], dtype=torch.long)
-        dist.broadcast(tensor_size, src=0, group=dist_group)
-        size = tensor_size.item()
-
-        if size == 0:
-            return []
-
-        tensor_data = torch.empty(size, dtype=torch.uint8)
-        dist.broadcast(tensor_data, src=0, group=dist_group)
-
-        serialized_data = bytes(tensor_data.tolist())
-        data = pickle.loads(serialized_data)
-
-    return data
-
-
 class ControllerSingle:
     """A controller that manages a group of tensor parallel workers."""

     def __init__(
-        self, server_args: ServerArgs, port_args: PortArgs, model_overide_args: dict
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        model_overide_args: dict,
+        gpu_ids: List[int],
+        is_data_parallel_worker: bool,
+        dp_worker_id: int,
+        mp_queue: multiprocessing.Queue,
     ):
         # Parse args
-        self.server_args = server_args
-        self.tp_procs = []
+        self.tp_size = server_args.tp_size
+        self.is_dp_worker = is_data_parallel_worker
+        self.dp_worker_id = dp_worker_id
+        self.mp_queue = mp_queue

         # Init communication
         context = zmq.Context(2)
-        self.recv_from_tokenizer = context.socket(zmq.PULL)
-        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
+        if not self.is_dp_worker:
+            self.recv_from_tokenizer = context.socket(zmq.PULL)
+            self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.controller_port}")

         self.send_to_detokenizer = context.socket(zmq.PUSH)
         self.send_to_detokenizer.connect(
             f"tcp://127.0.0.1:{port_args.detokenizer_port}"
         )

-        # Init model server
-        tp_size_local = server_args.tp_size // server_args.nnodes
-        gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
-
         # Launch other tp ranks
+        tp_size_local = server_args.tp_size // server_args.nnodes
+        self.tp_procs = []
         if tp_size_local > 1:
             tp_rank_range = range(1, tp_size_local)
             self.tp_procs = launch_tp_servers(
                 gpu_ids,
                 tp_rank_range,
                 server_args,
-                port_args.model_port_args[0],
+                port_args.nccl_ports[dp_worker_id],
                 model_overide_args,
             )
@@ -131,16 +66,19 @@ class ControllerSingle:
             gpu_ids[0],
             0,
             server_args,
-            port_args.model_port_args[0],
+            port_args.nccl_ports[dp_worker_id],
             model_overide_args,
         )
         self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group

     def loop_for_forward(self):
         while True:
-            recv_reqs = self.recv_requests()
+            if not self.is_dp_worker:
+                recv_reqs = self.recv_requests_from_zmq()
+            else:
+                recv_reqs = self.recv_requests_from_mp_queue()

-            if self.server_args.tp_size > 1:
+            if self.tp_size > 1:
                 broadcast_recv_input(recv_reqs, 0, self.tp_cpu_group)

             out_pyobjs = self.tp_server.exposed_step(recv_reqs)
@@ -148,27 +86,51 @@ class ControllerSingle:
             for obj in out_pyobjs:
                 self.send_to_detokenizer.send_pyobj(obj)

-    def recv_requests(self):
+    def recv_requests_from_zmq(self):
         recv_reqs = []
         while True:
             try:
                 recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
-                recv_reqs.append(recv_req)
             except zmq.ZMQError:
                 break
+            recv_reqs.append(recv_req)
+
+        return recv_reqs
+
+    def recv_requests_from_mp_queue(self):
+        recv_reqs = []
+        while not self.mp_queue.empty():
+            recv_reqs.append(self.mp_queue.get())

         return recv_reqs


 def start_controller_process(
-    server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args: dict
+    server_args: ServerArgs,
+    port_args: PortArgs,
+    pipe_writer: multiprocessing.connection.Connection,
+    model_overide_args: dict,
+    is_data_parallel_worker: bool = False,
+    gpu_ids: List[int] = None,
+    dp_worker_id: int = None,
+    queue: multiprocessing.connection.Connection = None,
 ):
+    """Start a controller process."""
     logging.basicConfig(
         level=getattr(logging, server_args.log_level.upper()),
         format="%(message)s",
     )

+    if not is_data_parallel_worker:
+        tp_size_local = server_args.tp_size // server_args.nnodes
+        gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
+        dp_worker_id = 0
+        queue = None
+
     try:
-        controller = ControllerSingle(server_args, port_args, model_overide_args)
+        controller = ControllerSingle(server_args, port_args, model_overide_args,
+                                      gpu_ids, is_data_parallel_worker,
+                                      dp_worker_id, queue)
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
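`ControllerMulti.start_dp_worker` (above) and `launch_server` (below) both rely on the same readiness handshake: the child sends `"init ok"` over a one-way pipe once its heavy initialization is done. A hedged, self-contained sketch:

```python
# "init ok" handshake over a one-way multiprocessing.Pipe (illustrative names).
import multiprocessing as mp


def worker_main(pipe_writer):
    try:
        pass  # ... heavyweight init here: load the model, bind sockets, ...
    except Exception:
        pipe_writer.send("init failed")
        raise
    pipe_writer.send("init ok")
    # ... the serving loop would run here ...


if __name__ == "__main__":
    reader, writer = mp.Pipe(duplex=False)
    proc = mp.Process(target=worker_main, args=(writer,))
    proc.start()

    state = reader.recv()  # blocks until the child reports readiness
    if state != "init ok":
        proc.kill()
        raise RuntimeError(f"Initialization failed. state: {state}")
    print("worker ready")
    proc.join()
```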
......
"""A tensor parallel worker.""" """A tensor parallel worker."""
import asyncio
import logging import logging
import multiprocessing
import pickle
import time import time
import warnings import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional from typing import List, Optional
import rpyc
import torch import torch
from rpyc.utils.classic import obtain import torch.distributed as dist
from sglang.global_config import global_config from sglang.global_config import global_config
from sglang.srt.constrained.fsm_cache import FSMCache from sglang.srt.constrained.fsm_cache import FSMCache
...@@ -32,13 +31,11 @@ from sglang.srt.managers.io_struct import ( ...@@ -32,13 +31,11 @@ from sglang.srt.managers.io_struct import (
TokenizedGenerateReqInput, TokenizedGenerateReqInput,
) )
from sglang.srt.model_config import ModelConfig from sglang.srt.model_config import ModelConfig
from sglang.srt.server_args import ModelPortArgs, ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import ( from sglang.srt.utils import (
connect_rpyc_service,
get_int_token_logit_bias, get_int_token_logit_bias,
is_multimodal_model, is_multimodal_model,
set_random_seed, set_random_seed,
start_rpyc_service_process,
suppress_other_loggers, suppress_other_loggers,
) )
from sglang.utils import get_exception_traceback from sglang.utils import get_exception_traceback
...@@ -52,10 +49,9 @@ class ModelTpServer: ...@@ -52,10 +49,9 @@ class ModelTpServer:
gpu_id: int, gpu_id: int,
tp_rank: int, tp_rank: int,
server_args: ServerArgs, server_args: ServerArgs,
model_port_args: ModelPortArgs, nccl_port: int,
model_overide_args: dict, model_overide_args: dict,
): ):
server_args, model_port_args = obtain(server_args), obtain(model_port_args)
suppress_other_loggers() suppress_other_loggers()
# Copy arguments # Copy arguments
...@@ -79,7 +75,7 @@ class ModelTpServer: ...@@ -79,7 +75,7 @@ class ModelTpServer:
gpu_id=gpu_id, gpu_id=gpu_id,
tp_rank=tp_rank, tp_rank=tp_rank,
tp_size=server_args.tp_size, tp_size=server_args.tp_size,
nccl_port=model_port_args.nccl_port, nccl_port=nccl_port,
server_args=server_args, server_args=server_args,
) )
...@@ -178,9 +174,6 @@ class ModelTpServer: ...@@ -178,9 +174,6 @@ class ModelTpServer:
self.new_token_ratio_recovery = global_config.new_token_ratio_recovery self.new_token_ratio_recovery = global_config.new_token_ratio_recovery
def exposed_step(self, recv_reqs): def exposed_step(self, recv_reqs):
if not isinstance(recv_reqs, list):
recv_reqs = obtain(recv_reqs)
try: try:
# Recv requests # Recv requests
for recv_req in recv_reqs: for recv_req in recv_reqs:
...@@ -425,12 +418,6 @@ class ModelTpServer: ...@@ -425,12 +418,6 @@ class ModelTpServer:
f"#running-req: {running_bs}, " f"#running-req: {running_bs}, "
f"#queue-req: {len(self.forward_queue) - len(can_run_list)}" f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
) )
# logger.debug(
# f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
# f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
# f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
# f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
# )
# Return the new batch # Return the new batch
new_batch = Batch.init_new( new_batch = Batch.init_new(
...@@ -733,87 +720,74 @@ class ModelTpServer: ...@@ -733,87 +720,74 @@ class ModelTpServer:
break break
class ModelTpService(rpyc.Service): def run_tp_server(
exposed_ModelTpServer = ModelTpServer gpu_id: int,
tp_rank: int,
class ModelTpClient:
def __init__(
self,
gpu_ids: List[int],
server_args: ServerArgs, server_args: ServerArgs,
model_port_args: ModelPortArgs, nccl_port: int,
model_overide_args, model_overide_args: dict,
): ):
server_args, model_port_args = obtain(server_args), obtain(model_port_args) """Run a tensor parallel server."""
self.tp_size = server_args.tp_size try:
model_server = ModelTpServer(
if self.tp_size * server_args.dp_size == 1: gpu_id,
# Init model tp_rank,
assert len(gpu_ids) == 1
self.model_server = ModelTpService().exposed_ModelTpServer(
gpu_ids[0],
0,
server_args, server_args,
model_port_args, nccl_port,
model_overide_args, model_overide_args,
) )
tp_cpu_group = model_server.model_runner.tp_group.cpu_group
# Wrap functions while True:
def async_wrap(f): recv_reqs = broadcast_recv_input(None, tp_rank, tp_cpu_group)
async def _func(*args, **kwargs): model_server.exposed_step(recv_reqs)
return f(*args, **kwargs) except Exception:
logger.error("Exception in run_tp_server:\n" + get_exception_traceback())
raise
return _func
self.step = async_wrap(self.model_server.exposed_step) def launch_tp_servers(
else: gpu_ids, tp_rank_range, server_args, nccl_port, model_overide_args
with ThreadPoolExecutor(self.tp_size) as executor: ):
# Launch model processes """Launch multiple tensor parallel servers."""
if server_args.nnodes == 1: procs = []
self.procs = list( for i in tp_rank_range:
executor.map( proc = multiprocessing.Process(
lambda args: start_rpyc_service_process(*args), target=run_tp_server,
[ args=(gpu_ids[i], i, server_args, nccl_port, model_overide_args),
(ModelTpService, p)
for p in model_port_args.model_tp_ports
],
)
) )
addrs = [("localhost", p) for p in model_port_args.model_tp_ports] proc.start()
else: procs.append(proc)
addrs = [
(ip, port)
for ip, port in zip(
model_port_args.model_tp_ips, model_port_args.model_tp_ports
)
]
self.model_services = list( return procs
executor.map(lambda args: connect_rpyc_service(*args), addrs)
)
# Init model
def init_model(i):
return self.model_services[i].ModelTpServer(
gpu_ids[i],
i,
server_args,
model_port_args,
model_overide_args,
)
self.model_servers = list(executor.map(init_model, range(self.tp_size))) def broadcast_recv_input(data, rank, dist_group):
"""Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
# Wrap functions if rank == 0:
def async_wrap(func_name): if len(data) == 0:
fs = [rpyc.async_(getattr(m, func_name)) for m in self.model_servers] tensor_size = torch.tensor([0], dtype=torch.long)
dist.broadcast(tensor_size, src=0, group=dist_group)
else:
serialized_data = pickle.dumps(data)
size = len(serialized_data)
tensor_data = torch.ByteTensor(list(serialized_data))
tensor_size = torch.tensor([size], dtype=torch.long)
dist.broadcast(tensor_size, src=0, group=dist_group)
dist.broadcast(tensor_data, src=0, group=dist_group)
else:
tensor_size = torch.tensor([0], dtype=torch.long)
dist.broadcast(tensor_size, src=0, group=dist_group)
size = tensor_size.item()
async def _func(*args, **kwargs): if size == 0:
tasks = [f(*args, **kwargs) for f in fs] return []
await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks])
return obtain(tasks[0].value)
return _func tensor_data = torch.empty(size, dtype=torch.uint8)
dist.broadcast(tensor_data, src=0, group=dist_group)
self.step = async_wrap("step") serialized_data = bytes(tensor_data.tolist())
data = pickle.loads(serialized_data)
return data
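`broadcast_recv_input` is the piece that replaces rpyc for intra-TP-group communication: rank 0 pickles the request list and broadcasts its size, then its bytes, over the CPU (gloo) group. A self-contained two-process demo of the same idea (the master port is arbitrary, and the explicit process-group argument is dropped here for brevity):

```python
# Pickle-over-torch.distributed broadcast, mirroring broadcast_recv_input above.
import os
import pickle

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def broadcast_pyobj(data, rank):
    if rank == 0:
        buf = pickle.dumps(data)
        dist.broadcast(torch.tensor([len(buf)], dtype=torch.long), src=0)
        dist.broadcast(torch.ByteTensor(list(buf)), src=0)
        return data
    size = torch.tensor([0], dtype=torch.long)
    dist.broadcast(size, src=0)
    payload = torch.empty(size.item(), dtype=torch.uint8)
    dist.broadcast(payload, src=0)
    return pickle.loads(bytes(payload.tolist()))


def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29511"  # arbitrary free port
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    reqs = [{"rid": "req-0"}] if rank == 0 else None
    reqs = broadcast_pyobj(reqs, rank)
    print(f"rank {rank} received {reqs}")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```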
@@ -61,7 +61,7 @@ class TokenizerManager:
         self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")

         self.send_to_router = context.socket(zmq.PUSH)
-        self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.router_port}")
+        self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.controller_port}")

         self.model_path = server_args.model_path
         self.hf_config = get_config(
......
@@ -44,15 +44,13 @@ from sglang.srt.openai_api_adapter import (
     v1_chat_completions,
     v1_completions,
 )
-from sglang.srt.server_args import ModelPortArgs, PortArgs, ServerArgs
+from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     API_KEY_HEADER_NAME,
     APIKeyValidatorMiddleware,
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
-    receive_addrs,
-    send_addrs_to_rank_0,
 )
 from sglang.utils import get_exception_traceback
@@ -98,6 +96,7 @@ async def flush_cache():
 async def generate_request(obj: GenerateReqInput, request: Request):
+    """Handle a generate request."""
     if obj.stream:

         async def stream_results():
@@ -146,7 +145,10 @@ def _set_global_server_args(server_args: ServerArgs):
     }


-def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_args=None):
+def launch_server(server_args: ServerArgs,
+                  model_overide_args: Optional[dict] = None,
+                  pipe_finish_writer: Optional[mp.connection.Connection] = None):
+    """Launch an HTTP server."""
     global tokenizer_manager

     logging.basicConfig(
@@ -173,39 +175,23 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     if server_args.chat_template:
         # TODO: replace this with huggingface transformers template
         load_chat_template_for_openai_api(server_args.chat_template)

     _set_global_server_args(server_args)

     # Allocate ports
-    assert server_args.tp_size % server_args.nnodes == 0
-    tp_size_local = server_args.tp_size // server_args.nnodes
     server_args.port, server_args.additional_ports = allocate_init_ports(
         server_args.port,
         server_args.additional_ports,
-        tp_size_local,
         server_args.dp_size,
     )
     ports = server_args.additional_ports
-    model_port_args = []
-    for i in range(server_args.dp_size):
-        model_port_args.append(
-            ModelPortArgs(
-                nccl_port=ports[3 + i * (tp_size_local + 1)],
-                model_tp_ips=[None] * tp_size_local,
-                model_tp_ports=ports[
-                    3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)
-                ],
-            )
-        )
     port_args = PortArgs(
         tokenizer_port=ports[0],
-        router_port=ports[1],
+        controller_port=ports[1],
         detokenizer_port=ports[2],
-        model_port_args=model_port_args,
+        nccl_ports=ports[3:],
     )

-    # Handle multi-node tp
+    # Handle multi-node tensor parallelism
     if server_args.nnodes > 1:
         assert server_args.dp_size == 1, "Multi-node dp is not supported."
@@ -224,7 +210,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
             gpu_ids,
             tp_rank_range,
             server_args,
-            port_args.model_port_args[0],
+            ports[3],
             model_overide_args,
         )
         while True:
@@ -232,18 +218,18 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     # Launch processes
     tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
-    pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
+    pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
     pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)

     if server_args.dp_size == 1:
         start_process = start_controller_process_single
     else:
         start_process = start_controller_process_multi
-    proc_router = mp.Process(
+    proc_controller = mp.Process(
         target=start_process,
-        args=(server_args, port_args, pipe_router_writer, model_overide_args),
+        args=(server_args, port_args, pipe_controller_writer, model_overide_args),
     )
-    proc_router.start()
+    proc_controller.start()
     proc_detoken = mp.Process(
         target=start_detokenizer_process,
         args=(
@@ -255,27 +241,44 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     proc_detoken.start()

     # Wait for the model to finish loading
-    router_init_state = pipe_router_reader.recv()
+    controller_init_state = pipe_controller_reader.recv()
     detoken_init_state = pipe_detoken_reader.recv()

-    if router_init_state != "init ok" or detoken_init_state != "init ok":
-        proc_router.kill()
+    if controller_init_state != "init ok" or detoken_init_state != "init ok":
+        proc_controller.kill()
         proc_detoken.kill()
         print(
-            f"Initialization failed. router_init_state: {router_init_state}", flush=True
+            f"Initialization failed. controller_init_state: {controller_init_state}", flush=True
         )
         print(
             f"Initialization failed. detoken_init_state: {detoken_init_state}",
             flush=True,
         )
         sys.exit(1)
-    assert proc_router.is_alive() and proc_detoken.is_alive()
+    assert proc_controller.is_alive() and proc_detoken.is_alive()

     if server_args.api_key and server_args.api_key != "":
         app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)

     # Send a warmup request
-    def _wait_and_warmup():
+    t = threading.Thread(target=_wait_and_warmup, args=(server_args, pipe_finish_writer))
+    t.start()
+
+    # Listen for requests
+    try:
+        uvicorn.run(
+            app,
+            host=server_args.host,
+            port=server_args.port,
+            log_level=server_args.log_level_http or server_args.log_level,
+            timeout_keep_alive=5,
+            loop="uvloop",
+        )
+    finally:
+        t.join()
+
+
+def _wait_and_warmup(server_args, pipe_finish_writer):
     headers = {}
     url = server_args.url()
     if server_args.api_key:
@@ -316,22 +319,6 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     if pipe_finish_writer is not None:
         pipe_finish_writer.send("init ok")

-    t = threading.Thread(target=_wait_and_warmup)
-    t.start()
-
-    # Listen for requests
-    try:
-        uvicorn.run(
-            app,
-            host=server_args.host,
-            port=server_args.port,
-            log_level=server_args.log_level_http or server_args.log_level,
-            timeout_keep_alive=5,
-            loop="uvloop",
-        )
-    finally:
-        t.join()
-

 class Runtime:
     """
@@ -354,7 +341,6 @@ class Runtime:
         self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
             self.server_args.port,
             self.server_args.additional_ports,
-            self.server_args.tp_size,
             self.server_args.dp_size,
         )
@@ -367,7 +353,7 @@ class Runtime:
         pipe_reader, pipe_writer = mp.Pipe(duplex=False)
         proc = mp.Process(
             target=launch_server,
-            args=(self.server_args, pipe_writer, model_overide_args),
+            args=(self.server_args, model_overide_args, pipe_writer),
         )
         proc.start()
         pipe_writer.close()
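`launch_server` now blocks in `uvicorn.run` while a helper thread performs the warmup request and reports `"init ok"` back through `pipe_finish_writer`. The pattern, stripped of sglang specifics (all names below are illustrative):

```python
# Warmup thread + blocking uvicorn.run, as in the new launch_server layout.
import threading
import time

import requests
import uvicorn
from fastapi import FastAPI

app = FastAPI()


@app.get("/health")
def health():
    return {"ok": True}


def _wait_and_warmup(base_url: str):
    # Poll until the HTTP server answers, then issue one warmup request.
    for _ in range(60):
        try:
            requests.get(f"{base_url}/health", timeout=1)
            return
        except requests.exceptions.RequestException:
            time.sleep(1)


if __name__ == "__main__":
    base_url = "http://127.0.0.1:8000"
    t = threading.Thread(target=_wait_and_warmup, args=(base_url,))
    t.start()
    try:
        uvicorn.run(app, host="127.0.0.1", port=8000)  # blocks until shutdown
    finally:
        t.join()
```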
......
@@ -337,16 +337,9 @@ class ServerArgs:
     )


-@dataclasses.dataclass
-class ModelPortArgs:
-    nccl_port: int
-    model_tp_ips: List[str]
-    model_tp_ports: List[int]
-
-
 @dataclasses.dataclass
 class PortArgs:
     tokenizer_port: int
-    router_port: int
+    controller_port: int
     detokenizer_port: int
-    model_port_args: List[ModelPortArgs]
+    nccl_ports: List[int]
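With `ModelPortArgs` gone, port bookkeeping is just one controller port plus one NCCL port per data parallel worker (4 + dp_size ports overall, counting the HTTP port, as `allocate_init_ports` computes below). A hedged sketch with a stand-in dataclass and made-up port numbers:

```python
# Stand-in for the new PortArgs layout; the port numbers are invented.
from dataclasses import dataclass
from typing import List


@dataclass
class PortArgs:
    tokenizer_port: int
    controller_port: int
    detokenizer_port: int
    nccl_ports: List[int]


dp_size = 2
# allocate_init_ports reserves 4 + dp_size ports; the first is the HTTP port,
# the rest become server_args.additional_ports.
additional_ports = [30001, 30002, 30003, 30004, 30005]
assert len(additional_ports) == 3 + dp_size

port_args = PortArgs(
    tokenizer_port=additional_ports[0],
    controller_port=additional_ports[1],
    detokenizer_port=additional_ports[2],
    nccl_ports=additional_ports[3:],  # one NCCL port per data parallel worker
)
print(port_args)
```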
@@ -3,7 +3,6 @@
 import base64
 import fcntl
 import logging
-import multiprocessing
 import os
 import random
 import socket
@@ -16,12 +15,10 @@ from typing import List, Optional
 import numpy as np
 import psutil
 import requests
-import rpyc
 import torch
 import triton
 from fastapi.responses import JSONResponse
 from packaging import version as pkg_version
-from rpyc.utils.server import ThreadedServer
 from starlette.middleware.base import BaseHTTPMiddleware

 logger = logging.getLogger(__name__)
@@ -148,7 +145,6 @@ def is_port_available(port):
 def allocate_init_ports(
     port: Optional[int] = None,
     additional_ports: Optional[List[int]] = None,
-    tp_size: int = 1,
     dp_size: int = 1,
 ):
     """Allocate ports for all connections."""
@@ -160,8 +156,8 @@ def allocate_init_ports(
     ret_ports = list(set(x for x in ret_ports if is_port_available(x)))
     cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000

-    # HTTP + Tokenizer + Controller + Detokenizer + dp_size * (nccl + tp_size)
-    num_ports_needed = 4 + dp_size * (1 + tp_size)
+    # HTTP + Tokenizer + Controller + Detokenizer + dp_size * 1 (nccl)
+    num_ports_needed = 4 + dp_size
     while len(ret_ports) < num_ports_needed:
         if cur_port not in ret_ports and is_port_available(cur_port):
             ret_ports.append(cur_port)
@@ -371,49 +367,6 @@ def load_image(image_file):
     return image, image_size


-def connect_rpyc_service(host, port):
-    repeat_count = 0
-    while repeat_count < 20:
-        try:
-            con = rpyc.connect(
-                host,
-                port,
-                config={
-                    "allow_public_attrs": True,
-                    "allow_pickle": True,
-                    "sync_request_timeout": 3600,
-                },
-            )
-            break
-        except ConnectionRefusedError as e:
-            time.sleep(1)
-        repeat_count += 1
-    if repeat_count == 20:
-        raise RuntimeError(f"Connect rpyc error: {e}")
-
-    return con.root
-
-
-def start_rpyc_service(service: rpyc.Service, port: int):
-    t = ThreadedServer(
-        service=service,
-        port=port,
-        protocol_config={
-            "allow_public_attrs": True,
-            "allow_pickle": True,
-            "sync_request_timeout": 3600,
-        },
-    )
-    t.logger.setLevel(logging.WARN)
-    t.start()
-
-
-def start_rpyc_service_process(service: rpyc.Service, port: int):
-    proc = multiprocessing.Process(target=start_rpyc_service, args=(service, port))
-    proc.start()
-    return proc
-
-
 def suppress_other_loggers():
     from vllm.logger import logger as vllm_default_logger
......