Unverified Commit 0463f7fb authored by Ying Sheng, committed by GitHub

Support data parallelism (static) (#480)


Co-authored-by: Ying Sheng <ying.sheng@databricks.com>
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
parent 565d7274
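Usage note (not part of the diff): the static data parallelism added here is driven by the dp_size and load_balance_method server arguments that the new controller code reads below. A minimal launch sketch, assuming these ServerArgs fields are exposed as sglang.Runtime keyword arguments in the same way tp_size already is:

# Hypothetical usage sketch; kwarg names are assumed to mirror the ServerArgs
# fields referenced in this diff (tp_size, dp_size, load_balance_method).
import sglang as sgl

runtime = sgl.Runtime(
    model_path="meta-llama/Llama-2-7b-chat-hf",  # any supported model
    tp_size=1,                           # tensor-parallel degree per DP worker
    dp_size=2,                           # two data-parallel workers -> GPUs 0 and 1
    load_balance_method="round_robin",   # or "shortest_queue"
)
sgl.set_default_backend(runtime)
# ... run sgl programs against the runtime ...
runtime.shutdown()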
@@ -26,7 +26,8 @@ class GlobalConfig:
        self.concate_and_append_mode = "no_adjust"

        # Request dependency time due to network delay
-        self.request_dependency_time = 0.03
+        self.request_dependency_delay = 0.03
+        self.wait_for_new_request_delay = 0.0006

        # New generation token ratio estimation
        self.base_new_token_ratio = 0.4
...
@@ -5,7 +5,7 @@ from vllm.distributed import (
    tensor_model_parallel_all_gather,
)

-from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
+from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata


class LogitsProcessor(nn.Module):
...
@@ -5,7 +5,7 @@ from torch import nn
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
from sglang.srt.layers.extend_attention import extend_attention_fwd
from sglang.srt.layers.token_attention import token_attention_fwd
-from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
+from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata


class RadixAttention(nn.Module):
@@ -20,7 +20,7 @@ class RadixAttention(nn.Module):
        assert np.allclose(scaling, 1.0 / (head_dim**0.5))

-        from sglang.srt.managers.router.model_runner import global_server_args_dict
+        from sglang.srt.managers.controller.model_runner import global_server_args_dict

        if global_server_args_dict.get("enable_flashinfer", False):
            self.prefill_forward = self.prefill_forward_flashinfer
...
@@ -5,7 +5,7 @@ import torch
import triton
import triton.language as tl

-from sglang.srt.managers.router.model_runner import global_server_args_dict
+from sglang.srt.managers.controller.model_runner import global_server_args_dict
from sglang.srt.utils import wrap_kernel_launcher

if global_server_args_dict.get("attention_reduce_in_fp32", False):
...
"""A data parallel worker thread."""
import asyncio
import logging
import queue
import threading
from typing import List, Callable
import uvloop
import zmq
from sglang.global_config import global_config
from sglang.srt.managers.controller.tp_worker import ModelTpClient
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.utils import get_exception_traceback
logger = logging.getLogger("srt.controller")
CHECKING_INTERVAL = 5
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
class DataParallelWorkerThread(threading.Thread):
def __init__(
self,
worker_id: int,
request_queue: queue.Queue,
detokenizer_port: int,
step_func: Callable,
):
super(DataParallelWorkerThread, self).__init__()
self.worker_id = worker_id
self.request_queue = request_queue
self.liveness = True
self.request_dependency_delay = global_config.request_dependency_delay
context = zmq.asyncio.Context()
self.send_to_detokenizer = context.socket(zmq.PUSH)
self.send_to_detokenizer.connect(f"tcp://127.0.0.1:{detokenizer_port}")
self.step = step_func
async def loop_for_forward(self):
while self.liveness:
requests = []
while not self.request_queue.empty():
requests.append(self.request_queue.get())
try:
out_pyobjs = await self.step(requests)
except Exception:
for r in requests:
self.request_queue.put(r)
logger.error(
f"Worker thread {self.worker_id}: "
f"failed to get back from Model Server\n"
f"{get_exception_traceback()}"
)
self.liveness = False
for obj in out_pyobjs:
self.send_to_detokenizer.send_pyobj(obj)
# async sleep for receiving the subsequent request and avoiding cache miss
if len(out_pyobjs) != 0:
has_finished = any([obj.finished for obj in out_pyobjs])
if has_finished:
await asyncio.sleep(self.request_dependency_delay)
await asyncio.sleep(global_config.wait_for_new_request_delay)
async def monitoring(self):
while True:
await asyncio.sleep(CHECKING_INTERVAL)
# can plug in monitoring logic here
def run(self):
logger.info(f"DataParallelWorkerThread {self.worker_id} start")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.create_task(self.monitoring())
loop.run_until_complete(self.loop_for_forward())
def start_data_parallel_worker(
server_args: ServerArgs,
port_args: PortArgs,
model_overide_args,
gpu_ids: List[int],
worker_id: int,
):
model_tp_client = ModelTpClient(
gpu_ids,
server_args,
port_args.model_port_args[worker_id],
model_overide_args,
)
worker_thread = DataParallelWorkerThread(
worker_id=worker_id,
request_queue=queue.Queue(),
detokenizer_port=port_args.detokenizer_port,
step_func=model_tp_client.step,
)
worker_thread.start()
return worker_thread
\ No newline at end of file
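Each DataParallelWorkerThread above drives one ModelTpClient over a contiguous slice of GPUs. A small worked example of the gpu_ids layout computed by the multi-worker controller further below (illustrative sizes only):

# Illustrative only: how DP workers map onto GPUs for tp_size=2, dp_size=4.
tp_size, dp_size = 2, 4
for worker_id in range(dp_size):
    gpu_ids = list(range(worker_id * tp_size, (worker_id + 1) * tp_size))
    print(f"DP worker {worker_id} -> GPUs {gpu_ids}")
# DP worker 0 -> GPUs [0, 1]
# DP worker 1 -> GPUs [2, 3]
# DP worker 2 -> GPUs [4, 5]
# DP worker 3 -> GPUs [6, 7]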
"""Meta data for requests and batches"""
from dataclasses import dataclass from dataclasses import dataclass
from enum import IntEnum, auto from enum import IntEnum, auto
from typing import List from typing import List
...@@ -5,7 +6,7 @@ from typing import List ...@@ -5,7 +6,7 @@ from typing import List
import numpy as np import numpy as np
import torch import torch
from sglang.srt.managers.router.radix_cache import RadixCache from sglang.srt.managers.controller.radix_cache import RadixCache
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
......
"""
A controller that manages multiple data parallel workers.
Each data parallel worker can manage multiple tensor parallel workers.
"""
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from enum import Enum, auto
from typing import Dict
import zmq
import zmq.asyncio
from sglang.global_config import global_config
from sglang.srt.managers.io_struct import (
AbortReq,
FlushCacheReq,
TokenizedGenerateReqInput,
)
from sglang.srt.managers.controller.dp_worker import (
DataParallelWorkerThread,
start_data_parallel_worker,
)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.utils import get_exception_traceback
logger = logging.getLogger("srt.controller")
class LoadBalanceMethod(Enum):
ROUND_ROBIN = auto()
SHORTEST_QUEUE = auto()
@classmethod
def from_str(cls, method: str):
method = method.upper()
try:
return cls[method]
except KeyError as exc:
raise ValueError(f"Invalid load balance method: {method}") from exc
class Controller:
def __init__(
self,
load_balance_method: str,
server_args: ServerArgs,
port_args: PortArgs,
model_overide_args,
):
self.load_balance_method = LoadBalanceMethod.from_str(load_balance_method)
self.server_args = server_args
self.port_args = port_args
if self.load_balance_method == LoadBalanceMethod.ROUND_ROBIN:
self.round_robin_counter = 0
self.dispatch_lookup = {
LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
}
self.dispatching = self.dispatch_lookup[self.load_balance_method]
# Init communication
context = zmq.asyncio.Context()
self.recv_from_tokenizer = context.socket(zmq.PULL)
self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
# Init status
self.recv_reqs = []
# Start data parallel workers
self.workers: Dict[int, DataParallelWorkerThread] = {}
tp_size = server_args.tp_size
def start_dp_worker(i):
try:
gpu_ids = list(range(i * tp_size, (i + 1) * tp_size))
worker_thread = start_data_parallel_worker(
server_args, port_args, model_overide_args, gpu_ids, i
)
self.workers[i] = worker_thread
except Exception:
logger.error(
f"Failed to start local worker {i}\n{get_exception_traceback()}"
)
with ThreadPoolExecutor(server_args.dp_size) as executor:
executor.map(start_dp_worker, range(server_args.dp_size))
def have_any_live_worker(self):
return any(worker_thread.liveness for worker_thread in self.workers.values())
def put_req_to_worker(self, worker_id, req):
self.workers[worker_id].request_queue.put(req)
async def round_robin_scheduler(self, input_requests):
available_workers = list(self.workers.keys())
for r in input_requests:
self.put_req_to_worker(available_workers[self.round_robin_counter], r)
self.round_robin_counter = (self.round_robin_counter + 1) % len(
available_workers
)
return
async def shortest_queue_scheduler(self, input_requests):
for r in input_requests:
worker = min(
self.workers, key=lambda w: self.workers[w].request_queue.qsize()
)
self.put_req_to_worker(worker, r)
return
async def remove_dead_workers(self):
for i in list(self.workers.keys()):
worker_thread = self.workers[i]
if not worker_thread.liveness:
worker_thread.join()
# move unsuccessful requests back to the queue
while not worker_thread.request_queue.empty():
self.recv_reqs.append(worker_thread.request_queue.get())
del self.workers[i]
logger.info(f"Stale worker {i} removed")
async def loop_for_forward(self):
while True:
await self.remove_dead_workers()
if self.have_any_live_worker():
next_step_input = list(self.recv_reqs)
self.recv_reqs = []
if next_step_input:
await self.dispatching(next_step_input)
#else:
# logger.error("There is no live worker.")
await asyncio.sleep(global_config.wait_for_new_request_delay)
async def loop_for_recv_requests(self):
while True:
recv_req = await self.recv_from_tokenizer.recv_pyobj()
if isinstance(recv_req, FlushCacheReq):
# TODO(lsyin): apply more specific flushCacheReq
for worker_thread in self.workers.values():
worker_thread.request_queue.put(recv_req)
elif isinstance(recv_req, TokenizedGenerateReqInput):
self.recv_reqs.append(recv_req)
elif isinstance(recv_req, AbortReq):
in_queue = False
for i, req in enumerate(self.recv_reqs):
if req.rid == recv_req.rid:
self.recv_reqs[i] = recv_req
in_queue = True
break
if not in_queue:
# Send abort req to all TP groups
for worker in list(self.workers.keys()):
self.put_req_to_worker(worker, recv_req)
else:
logger.error(f"Invalid object: {recv_req}")
def start_controller_process(
server_args: ServerArgs,
port_args: PortArgs,
pipe_writer,
model_overide_args=None,
):
logging.basicConfig(
level=getattr(logging, server_args.log_level.upper()),
format="%(message)s",
)
try:
controller = Controller(
server_args.load_balance_method, server_args, port_args, model_overide_args
)
except Exception:
pipe_writer.send(get_exception_traceback())
raise
pipe_writer.send("init ok")
loop = asyncio.get_event_loop()
asyncio.set_event_loop(loop)
loop.create_task(controller.loop_for_recv_requests())
loop.run_until_complete(controller.loop_for_forward())
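The two dispatch policies above differ only in how a target worker is chosen per request. A self-contained sketch of that selection logic, using plain queues in place of worker threads (toy request ids, no sglang imports):

# Standalone illustration of round-robin vs. shortest-queue dispatch.
import queue

workers = {i: queue.Queue() for i in range(3)}

# Round robin: cycle through worker ids regardless of load.
round_robin_counter = 0
for rid in ["a", "b", "c", "d"]:
    ids = list(workers.keys())
    workers[ids[round_robin_counter]].put(rid)
    round_robin_counter = (round_robin_counter + 1) % len(ids)

# Shortest queue: always pick the worker with the fewest pending requests.
for rid in ["e", "f"]:
    target = min(workers, key=lambda w: workers[w].qsize())
    workers[target].put(rid)

print({i: workers[i].qsize() for i in workers})  # {0: 2, 1: 2, 2: 2}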
"""A controller that manages a group of tensor parallel workers."""
import asyncio import asyncio
import logging import logging
...@@ -6,15 +7,15 @@ import zmq ...@@ -6,15 +7,15 @@ import zmq
import zmq.asyncio import zmq.asyncio
from sglang.global_config import global_config from sglang.global_config import global_config
from sglang.srt.managers.router.model_rpc import ModelRpcClient from sglang.srt.managers.controller.tp_worker import ModelTpClient
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.utils import get_exception_traceback from sglang.utils import get_exception_traceback
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
class RouterManager: class ControllerSingle:
def __init__(self, model_client: ModelRpcClient, port_args: PortArgs): def __init__(self, model_client: ModelTpClient, port_args: PortArgs):
# Init communication # Init communication
context = zmq.asyncio.Context(2) context = zmq.asyncio.Context(2)
self.recv_from_tokenizer = context.socket(zmq.PULL) self.recv_from_tokenizer = context.socket(zmq.PULL)
...@@ -30,7 +31,7 @@ class RouterManager: ...@@ -30,7 +31,7 @@ class RouterManager:
self.recv_reqs = [] self.recv_reqs = []
# Init some configs # Init some configs
self.request_dependency_time = global_config.request_dependency_time self.request_dependency_delay = global_config.request_dependency_delay
async def loop_for_forward(self): async def loop_for_forward(self):
while True: while True:
...@@ -46,12 +47,12 @@ class RouterManager: ...@@ -46,12 +47,12 @@ class RouterManager:
if len(out_pyobjs) != 0: if len(out_pyobjs) != 0:
has_finished = any([obj.finished for obj in out_pyobjs]) has_finished = any([obj.finished for obj in out_pyobjs])
if has_finished: if has_finished:
if self.request_dependency_time > 0: if self.request_dependency_delay > 0:
slept = True slept = True
await asyncio.sleep(self.request_dependency_time) await asyncio.sleep(self.request_dependency_delay)
if not slept: if not slept:
await asyncio.sleep(0.0006) await asyncio.sleep(global_config.wait_for_new_request_delay)
async def loop_for_recv_requests(self): async def loop_for_recv_requests(self):
while True: while True:
...@@ -59,7 +60,7 @@ class RouterManager: ...@@ -59,7 +60,7 @@ class RouterManager:
self.recv_reqs.append(recv_req) self.recv_reqs.append(recv_req)
def start_router_process( def start_controller_process(
server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args
): ):
logging.basicConfig( logging.basicConfig(
...@@ -68,8 +69,13 @@ def start_router_process( ...@@ -68,8 +69,13 @@ def start_router_process(
) )
try: try:
model_client = ModelRpcClient(server_args, port_args, model_overide_args) model_client = ModelTpClient(
router = RouterManager(model_client, port_args) list(range(server_args.tp_size)),
server_args,
port_args.model_port_args[0],
model_overide_args,
)
controller = ControllerSingle(model_client, port_args)
except Exception: except Exception:
pipe_writer.send(get_exception_traceback()) pipe_writer.send(get_exception_traceback())
raise raise
...@@ -78,5 +84,5 @@ def start_router_process( ...@@ -78,5 +84,5 @@ def start_router_process(
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
loop.create_task(router.loop_for_recv_requests()) loop.create_task(controller.loop_for_recv_requests())
loop.run_until_complete(router.loop_for_forward()) loop.run_until_complete(controller.loop_for_forward())
\ No newline at end of file
@@ -15,13 +15,13 @@ from vllm.distributed import initialize_model_parallel
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry

-from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
+from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model

-logger = logging.getLogger("model_runner")
+logger = logging.getLogger("srt.model_runner")

# for server args in model endpoints
global_server_args_dict = {}
@@ -215,14 +215,16 @@ class ModelRunner:
    def __init__(
        self,
        model_config,
-        mem_fraction_static,
-        tp_rank,
-        tp_size,
-        nccl_port,
+        mem_fraction_static: float,
+        gpu_id: int,
+        tp_rank: int,
+        tp_size: int,
+        nccl_port: int,
        server_args: ServerArgs,
    ):
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
+        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.nccl_port = nccl_port
@@ -235,9 +237,9 @@ class ModelRunner:
        }

        # Init torch distributed
-        logger.info(f"[rank={self.tp_rank}] Set cuda device.")
-        torch.cuda.set_device(self.tp_rank)
-        logger.info(f"[rank={self.tp_rank}] Init torch begin. Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB")
+        logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
+        torch.cuda.set_device(self.gpu_id)
+        logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
        torch.distributed.init_process_group(
            backend="nccl",
            world_size=self.tp_size,
@@ -245,22 +247,26 @@ class ModelRunner:
            init_method=f"tcp://127.0.0.1:{self.nccl_port}",
        )
        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
-        logger.info(f"[rank={self.tp_rank}] Init torch end.")

-        total_gpu_memory = get_available_gpu_memory(self.tp_rank, distributed=self.tp_size > 1)
+        total_gpu_memory = get_available_gpu_memory(
+            self.gpu_id, distributed=self.tp_size > 1
+        )

        if self.tp_size > 1:
-            total_local_gpu_memory = get_available_gpu_memory(self.tp_rank)
+            total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)
            if total_local_gpu_memory < total_gpu_memory * 0.9:
-                raise ValueError("The memory capacity is unbalanced. Some GPUs may be occupied by other processes.")
+                raise ValueError(
+                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
+                )

        self.load_model()
        self.init_memory_pool(total_gpu_memory)
        self.is_multimodal_model = is_multimodal_model(self.model_config)

    def load_model(self):
-        logger.info(f"[rank={self.tp_rank}] Load weight begin.")
+        logger.info(
+            f"[gpu_id={self.gpu_id}] Load weight begin. "
+            f"Avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+        )

        device_config = DeviceConfig()
        load_config = LoadConfig(load_format=self.server_args.load_format)
@@ -286,12 +292,16 @@ class ModelRunner:
            parallel_config=None,
            scheduler_config=None,
        )
-        logger.info(f"[rank={self.tp_rank}] Load weight end. "
-                    f"Type={type(self.model).__name__}. "
-                    f"Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB")
+        logger.info(
+            f"[gpu_id={self.gpu_id}] Load weight end. "
+            f"Type={type(self.model).__name__}. "
+            f"Avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+        )

    def profile_max_num_token(self, total_gpu_memory):
-        available_gpu_memory = get_available_gpu_memory(self.tp_rank, distributed=self.tp_size > 1)
+        available_gpu_memory = get_available_gpu_memory(
+            self.gpu_id, distributed=self.tp_size > 1
+        )
        head_dim = self.model_config.head_dim
        head_num = self.model_config.num_key_value_heads // self.tp_size
        cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
@@ -306,7 +316,7 @@ class ModelRunner:
        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
-                "Not enought memory. " "Please try to increase --mem-fraction-static."
+                "Not enought memory. Please try to increase --mem-fraction-static."
            )

        self.req_to_token_pool = ReqToTokenPool(
@@ -320,6 +330,10 @@ class ModelRunner:
            head_dim=self.model_config.head_dim,
            layer_num=self.model_config.num_hidden_layers,
        )
+        logger.info(
+            f"[gpu_id={self.gpu_id}] Memory pool end. "
+            f"Avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+        )

    @torch.inference_mode()
    def forward_prefill(self, batch: Batch):
@@ -424,8 +438,8 @@ def import_model_classes():
        if hasattr(module, "EntryClass"):
            entry = module.EntryClass
            if isinstance(entry, list):  # To support multiple model classes in one module
-                for cls in entry:
-                    model_arch_name_to_cls[cls.__name__] = cls
+                for tmp in entry:
+                    model_arch_name_to_cls[tmp.__name__] = tmp
            else:
                model_arch_name_to_cls[entry.__name__] = entry
    return model_arch_name_to_cls
@@ -442,4 +456,4 @@ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:

# Monkey patch model loader
setattr(ModelRegistry, "load_model_cls", load_model_cls_srt)
\ No newline at end of file
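For intuition on profile_max_num_token above: cell_size is the per-token KV-cache footprint in bytes (key and value tensors across all layers, 2 bytes per fp16 element). A worked example with assumed Llama-7B-like dimensions:

# Worked example with assumed model dimensions (Llama-7B-like, fp16, tp_size=1).
head_dim = 128
num_key_value_heads = 32
num_hidden_layers = 32
tp_size = 1

head_num = num_key_value_heads // tp_size
# 2 tensors (K and V) * 2 bytes per fp16 element
cell_size = head_num * head_dim * num_hidden_layers * 2 * 2
print(cell_size)                         # 524288 bytes = 0.5 MiB per token

# If roughly 20 GiB of the GPU were left for the KV pool after weights,
# the pool could hold on the order of:
print(int(20 * (1 << 30) / cell_size))   # ~40960 tokens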
@@ -2,7 +2,7 @@ import random
from collections import defaultdict


-class Scheduler:
+class ScheduleHeuristic:
    def __init__(
        self,
        schedule_heuristic,
...
import asyncio
import logging
-import multiprocessing
import time
import warnings
from concurrent.futures import ThreadPoolExecutor
-from typing import List, Optional
+from typing import List

import rpyc
import torch
from rpyc.utils.classic import obtain
-from rpyc.utils.server import ThreadedServer

-try:
-    from vllm.logger import _default_handler as vllm_default_logger
-except ImportError:
-    from vllm.logger import logger as vllm_default_logger

from sglang.global_config import global_config
from sglang.srt.constrained.fsm_cache import FSMCache
@@ -26,38 +19,41 @@ from sglang.srt.managers.io_struct import (
    FlushCacheReq,
    TokenizedGenerateReqInput,
)
-from sglang.srt.managers.router.infer_batch import Batch, FinishReason, ForwardMode, Req
-from sglang.srt.managers.router.model_runner import ModelRunner
-from sglang.srt.managers.router.radix_cache import RadixCache
-from sglang.srt.managers.router.scheduler import Scheduler
+from sglang.srt.managers.controller.infer_batch import Batch, FinishReason, ForwardMode, Req
+from sglang.srt.managers.controller.model_runner import ModelRunner
+from sglang.srt.managers.controller.radix_cache import RadixCache
+from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
from sglang.srt.model_config import ModelConfig
-from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.server_args import ModelPortArgs, ServerArgs
from sglang.srt.utils import (
    get_int_token_logit_bias,
    is_multimodal_model,
    set_random_seed,
+    start_rpyc_process,
+    suppress_other_loggers,
)
from sglang.utils import get_exception_traceback

-logger = logging.getLogger("model_rpc")
+logger = logging.getLogger("srt.model_tp")

-vllm_default_logger.setLevel(logging.WARN)
-logging.getLogger("vllm.utils").setLevel(logging.WARN)
-logging.getLogger("vllm.selector").setLevel(logging.WARN)


-class ModelRpcServer:
+class ModelTpServer:
    def __init__(
        self,
+        gpu_id: int,
        tp_rank: int,
        server_args: ServerArgs,
-        port_args: PortArgs,
-        model_overide_args: Optional[dict] = None,
+        model_port_args: ModelPortArgs,
+        model_overide_args,
    ):
-        server_args, port_args = [obtain(x) for x in [server_args, port_args]]
+        server_args, model_port_args = obtain(server_args), obtain(model_port_args)
+        suppress_other_loggers()

        # Copy arguments
+        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
+        self.dp_size = server_args.dp_size
        self.schedule_heuristic = server_args.schedule_heuristic
        self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
@@ -68,16 +64,16 @@ class ModelRpcServer:
            context_length=server_args.context_length,
            model_overide_args=model_overide_args,
        )

-        # For model end global settings
        self.model_runner = ModelRunner(
            model_config=self.model_config,
            mem_fraction_static=server_args.mem_fraction_static,
+            gpu_id=gpu_id,
            tp_rank=tp_rank,
            tp_size=server_args.tp_size,
-            nccl_port=port_args.nccl_port,
+            nccl_port=model_port_args.nccl_port,
            server_args=server_args,
        )

        if is_multimodal_model(server_args.model_path):
            self.processor = get_processor(
                server_args.tokenizer_path,
@@ -95,21 +91,21 @@ class ModelRpcServer:
        self.max_prefill_tokens = max(
            self.model_config.context_len,
            (
-                self.max_total_num_tokens // 6
+                min(self.max_total_num_tokens // 6, 65536)
                if server_args.max_prefill_tokens is None
                else server_args.max_prefill_tokens
            ),
        )
        self.max_running_requests = (self.max_total_num_tokens // 2
            if server_args.max_running_requests is None else server_args.max_running_requests)
        self.int_token_logit_bias = torch.tensor(
            get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
        )
        set_random_seed(server_args.random_seed)

        # Print info
-        logger.info(f"[rank={self.tp_rank}] "
+        logger.info(
+            f"[gpu_id={self.gpu_id}] "
            f"max_total_num_tokens={self.max_total_num_tokens}, "
            f"max_prefill_tokens={self.max_prefill_tokens}, "
            f"context_len={self.model_config.context_len}, "
@@ -124,7 +120,7 @@ class ModelRpcServer:
            disable=server_args.disable_radix_cache,
        )
        self.tree_cache_metrics = {"total": 0, "hit": 0}
-        self.scheduler = Scheduler(
+        self.scheduler = ScheduleHeuristic(
            self.schedule_heuristic,
            self.max_running_requests,
            self.max_prefill_tokens,
@@ -170,7 +166,7 @@ class ModelRpcServer:
        self.new_token_ratio_recovery = global_config.new_token_ratio_recovery

    def exposed_step(self, recv_reqs):
-        if self.tp_size != 1:
+        if self.tp_size * self.dp_size != 1:
            recv_reqs = obtain(recv_reqs)

        try:
@@ -188,7 +184,7 @@ class ModelRpcServer:
            # Forward
            self.forward_step()
        except Exception:
-            logger.error("Exception in ModelRpcClient:\n" + get_exception_traceback())
+            logger.error("Exception in ModelTpClient:\n" + get_exception_traceback())

        # Return results
        ret = self.out_pyobjs
@@ -224,16 +220,17 @@ class ModelRpcServer:
                self.token_to_kv_pool.available_size()
                + self.tree_cache.evictable_size()
            )
-            throuhgput = self.num_generated_tokens / (
+            throughput = self.num_generated_tokens / (
                time.time() - self.last_stats_tic
            )
            self.num_generated_tokens = 0
            self.last_stats_tic = time.time()
            logger.info(
+                f"[gpu_id={self.gpu_id}] "
                f"#running-req: {len(self.running_batch.reqs)}, "
                f"#token: {num_used}, "
                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-                f"gen throughput (token/s): {throuhgput:.2f}, "
+                f"gen throughput (token/s): {throughput:.2f}, "
                f"#queue-req: {len(self.forward_queue)}"
            )
@@ -405,7 +402,7 @@ class ModelRpcServer:
                f"#new_token: {new_batch_input_tokens}. "
                f"#remaining_req: {len(self.forward_queue) - len(can_run_list)}. "
                f"#running_req: {running_req}. "
-                f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
+                f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%. "
            )
            # logger.debug(
            #     f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
@@ -724,20 +721,30 @@ class ModelRpcServer:
                break


-class ModelRpcService(rpyc.Service):
-    exposed_ModelRpcServer = ModelRpcServer
+class ModelTpService(rpyc.Service):
+    exposed_ModelTpServer = ModelTpServer


-class ModelRpcClient:
+class ModelTpClient:
    def __init__(
-        self, server_args: ServerArgs, port_args: PortArgs, model_overide_args
+        self,
+        gpu_ids: List[int],
+        server_args: ServerArgs,
+        model_port_args: ModelPortArgs,
+        model_overide_args,
    ):
-        tp_size = server_args.tp_size
+        server_args, model_port_args = obtain(server_args), obtain(model_port_args)
+        self.tp_size = server_args.tp_size

-        if tp_size == 1:
+        if self.tp_size * server_args.dp_size == 1:
            # Init model
-            self.model_server = ModelRpcService().exposed_ModelRpcServer(
-                0, server_args, port_args, model_overide_args
+            assert len(gpu_ids) == 1
+            self.model_server = ModelTpService().exposed_ModelTpServer(
+                0,
+                gpu_ids[0],
+                server_args,
+                model_port_args,
+                model_overide_args,
            )

            # Wrap functions
@@ -749,19 +756,26 @@ class ModelRpcClient:
            self.step = async_wrap(self.model_server.exposed_step)
        else:
-            with ThreadPoolExecutor(tp_size) as executor:
+            with ThreadPoolExecutor(self.tp_size) as executor:
                # Launch model processes
-                rets = executor.map(start_model_process, port_args.model_rpc_ports)
-                self.remote_services = [x[0] for x in rets]
+                rets = executor.map(
+                    lambda args: start_rpyc_process(*args),
+                    [(ModelTpService, p) for p in model_port_args.model_tp_ports],
+                )
+                self.model_services = [x[0] for x in rets]
                self.procs = [x[1] for x in rets]

                # Init model
                def init_model(i):
-                    return self.remote_services[i].ModelRpcServer(
-                        i, server_args, port_args, model_overide_args
+                    return self.model_services[i].ModelTpServer(
+                        gpu_ids[i],
+                        i,
+                        server_args,
+                        model_port_args,
+                        model_overide_args,
                    )

-                self.model_servers = executor.map(init_model, range(tp_size))
+                self.model_servers = executor.map(init_model, range(self.tp_size))

                # Wrap functions
                def async_wrap(func_name):
@@ -774,45 +788,4 @@ class ModelRpcClient:
                    return _func

                self.step = async_wrap("step")
\ No newline at end of file
-def _init_service(port):
-    t = ThreadedServer(
-        ModelRpcService(),
-        port=port,
-        protocol_config={
-            "allow_public_attrs": True,
-            "allow_pickle": True,
-            "sync_request_timeout": 3600,
-        },
-    )
-    t.start()
-
-
-def start_model_process(port):
-    proc = multiprocessing.Process(target=_init_service, args=(port,))
-    proc.start()
-    time.sleep(1)
-
-    repeat_count = 0
-    while repeat_count < 20:
-        try:
-            con = rpyc.connect(
-                "localhost",
-                port,
-                config={
-                    "allow_public_attrs": True,
-                    "allow_pickle": True,
-                    "sync_request_timeout": 3600,
-                },
-            )
-            break
-        except ConnectionRefusedError:
-            time.sleep(1)
-            repeat_count += 1
-    if repeat_count == 20:
-        raise RuntimeError("init rpc env error!")
-
-    assert proc.is_alive()
-    return con.root, proc
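The deleted helpers above are superseded by start_rpyc_process imported from sglang.srt.utils, whose implementation is not part of this diff. A plausible sketch, assuming it only generalizes the removed start_model_process by taking the rpyc service class as an argument:

# Sketch only; the real sglang.srt.utils.start_rpyc_process is not shown in this
# diff and may differ. It is assumed to mirror the deleted helper above.
import multiprocessing
import time

import rpyc
from rpyc.utils.server import ThreadedServer


def _init_service(service_cls, port):
    ThreadedServer(
        service_cls(),
        port=port,
        protocol_config={
            "allow_public_attrs": True,
            "allow_pickle": True,
            "sync_request_timeout": 3600,
        },
    ).start()


def start_rpyc_process(service_cls, port):
    # Launch the rpyc server in a subprocess, then poll until it accepts connections.
    proc = multiprocessing.Process(target=_init_service, args=(service_cls, port))
    proc.start()
    for _ in range(20):
        try:
            con = rpyc.connect(
                "localhost",
                port,
                config={
                    "allow_public_attrs": True,
                    "allow_pickle": True,
                    "sync_request_timeout": 3600,
                },
            )
            assert proc.is_alive()
            return con.root, proc
        except ConnectionRefusedError:
            time.sleep(1)
    raise RuntimeError("init rpc env error!")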
@@ -27,7 +27,6 @@ class GenerateReqInput:
    return_text_in_logprobs: bool = False
    # Whether to stream output
    stream: bool = False
-    # TODO: make all parameters a Union[List[T], T] to allow for batched requests

    def post_init(self):
@@ -135,4 +134,4 @@ class AbortReq:

@dataclass
class DetokenizeReqInput:
    input_ids: List[int]
\ No newline at end of file
@@ -48,7 +48,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.managers.controller.model_runner import InputMetadata


@torch.compile
...
@@ -29,7 +29,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.managers.controller.model_runner import InputMetadata


class DbrxRouter(nn.Module):
...
@@ -22,7 +22,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.managers.controller.model_runner import InputMetadata


class GemmaMLP(nn.Module):
...
@@ -37,7 +37,7 @@ from vllm.utils import print_warning_once
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.fused_moe import fused_moe
from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.managers.controller.model_runner import InputMetadata

use_fused = True
...
@@ -4,9 +4,13 @@
from typing import Any, Dict, Optional, Tuple, Iterable

import torch
+import tqdm
from torch import nn
from transformers import LlamaConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size
+)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
@@ -24,7 +28,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.managers.controller.model_runner import InputMetadata


class LlamaMLP(nn.Module):
@@ -284,6 +288,8 @@ class LlamaForCausalLM(nn.Module):
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
+        if get_tensor_model_parallel_rank() == 0:
+            weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name or "projector" in name:
                continue
...
@@ -10,8 +10,8 @@ from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

-from sglang.srt.managers.router.infer_batch import ForwardMode
-from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.managers.controller.infer_batch import ForwardMode
+from sglang.srt.managers.controller.model_runner import InputMetadata
from sglang.srt.mm_utils import (
    get_anyres_image_grid_shape,
    unpad_image,
...
@@ -10,8 +10,8 @@ from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

-from sglang.srt.managers.router.infer_batch import ForwardMode
-from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.managers.controller.infer_batch import ForwardMode
+from sglang.srt.managers.controller.model_runner import InputMetadata
from sglang.srt.mm_utils import (
    get_anyres_image_grid_shape,
    unpad_image,
...