"cpu/arange_interleave.cpp" did not exist on "4a569c27736957e6606fd3b3f69712a808189a74"
Unverified commit 0463f7fb authored by Ying Sheng, committed by GitHub

Support data parallelism (static) (#480)


Co-authored-by: Ying Sheng <ying.sheng@databricks.com>
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
parent 565d7274
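
This commit introduces static data parallelism: the server launches `dp_size` worker threads, each owning a tensor-parallel group of `tp_size` GPUs, and a controller dispatches incoming requests to per-worker queues using a round-robin or shortest-queue policy. The following is a minimal, self-contained sketch of that dispatch pattern; the class names and methods are simplified stand-ins for illustration, not the SGLang API.

```python
# Minimal sketch of static data-parallel dispatch (illustrative only;
# names are simplified stand-ins for the Controller / DataParallelWorkerThread
# classes introduced in this commit).
import queue
import threading


class Worker(threading.Thread):
    """Each data-parallel worker owns a private request queue."""

    def __init__(self, worker_id: int):
        super().__init__(daemon=True)
        self.worker_id = worker_id
        self.request_queue: queue.Queue = queue.Queue()

    def run(self):
        while True:
            req = self.request_queue.get()
            if req is None:  # shutdown sentinel
                break
            # A real worker would run a model step here and forward the
            # outputs to the detokenizer; we just print.
            print(f"worker {self.worker_id} handled {req}")


class Dispatcher:
    """Dispatches requests to workers with a static policy."""

    def __init__(self, dp_size: int, policy: str = "round_robin"):
        self.workers = [Worker(i) for i in range(dp_size)]
        for w in self.workers:
            w.start()
        self.policy = policy
        self.rr_counter = 0

    def dispatch(self, req):
        if self.policy == "round_robin":
            target = self.workers[self.rr_counter]
            self.rr_counter = (self.rr_counter + 1) % len(self.workers)
        else:  # shortest_queue
            target = min(self.workers, key=lambda w: w.request_queue.qsize())
        target.request_queue.put(req)

    def shutdown(self):
        for w in self.workers:
            w.request_queue.put(None)
        for w in self.workers:
            w.join()


if __name__ == "__main__":
    dispatcher = Dispatcher(dp_size=2)
    for i in range(4):
        dispatcher.dispatch(f"request-{i}")
    dispatcher.shutdown()
```
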
......@@ -26,7 +26,8 @@ class GlobalConfig:
self.concate_and_append_mode = "no_adjust"
# Request dependency time due to network delay
self.request_dependency_time = 0.03
self.request_dependency_delay = 0.03
self.wait_for_new_request_delay = 0.0006
# New generation token ratio estimation
self.base_new_token_ratio = 0.4
......
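
The renamed `request_dependency_delay` and the new `wait_for_new_request_delay` control how the forward loops below yield: after a request finishes, the loop sleeps briefly so a dependent follow-up request (delayed by a network round trip) can arrive and reuse the cached prefix; otherwise it sleeps a much shorter interval while polling for new work. A hedged sketch of that pattern, with a placeholder `step` coroutine and queue standing in for the real worker:

```python
# Hedged sketch of how the two delays from GlobalConfig are used; `step`
# and the request queue are placeholders, not the SGLang interfaces.
import asyncio

REQUEST_DEPENDENCY_DELAY = 0.03      # give a dependent follow-up request time to arrive
WAIT_FOR_NEW_REQUEST_DELAY = 0.0006  # idle polling interval between scheduling steps


async def forward_loop(request_queue: asyncio.Queue, step):
    while True:
        batch = []
        while not request_queue.empty():
            batch.append(request_queue.get_nowait())
        outputs = await step(batch)
        if outputs and any(out.get("finished") for out in outputs):
            # Sleeping here improves cache reuse: the follow-up request that
            # depends on a finished one usually shares its prefix.
            await asyncio.sleep(REQUEST_DEPENDENCY_DELAY)
        await asyncio.sleep(WAIT_FOR_NEW_REQUEST_DELAY)
```
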
......@@ -5,7 +5,7 @@ from vllm.distributed import (
tensor_model_parallel_all_gather,
)
from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
class LogitsProcessor(nn.Module):
......
......@@ -5,7 +5,7 @@ from torch import nn
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
from sglang.srt.layers.extend_attention import extend_attention_fwd
from sglang.srt.layers.token_attention import token_attention_fwd
from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
class RadixAttention(nn.Module):
......@@ -20,7 +20,7 @@ class RadixAttention(nn.Module):
assert np.allclose(scaling, 1.0 / (head_dim**0.5))
from sglang.srt.managers.router.model_runner import global_server_args_dict
from sglang.srt.managers.controller.model_runner import global_server_args_dict
if global_server_args_dict.get("enable_flashinfer", False):
self.prefill_forward = self.prefill_forward_flashinfer
......
......@@ -5,7 +5,7 @@ import torch
import triton
import triton.language as tl
from sglang.srt.managers.router.model_runner import global_server_args_dict
from sglang.srt.managers.controller.model_runner import global_server_args_dict
from sglang.srt.utils import wrap_kernel_launcher
if global_server_args_dict.get("attention_reduce_in_fp32", False):
......
"""A data parallel worker thread."""
import asyncio
import logging
import queue
import threading
from typing import List, Callable
import uvloop
import zmq
from sglang.global_config import global_config
from sglang.srt.managers.controller.tp_worker import ModelTpClient
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.utils import get_exception_traceback
logger = logging.getLogger("srt.controller")
CHECKING_INTERVAL = 5
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
class DataParallelWorkerThread(threading.Thread):
def __init__(
self,
worker_id: int,
request_queue: queue.Queue,
detokenizer_port: int,
step_func: Callable,
):
super(DataParallelWorkerThread, self).__init__()
self.worker_id = worker_id
self.request_queue = request_queue
self.liveness = True
self.request_dependency_delay = global_config.request_dependency_delay
context = zmq.asyncio.Context()
self.send_to_detokenizer = context.socket(zmq.PUSH)
self.send_to_detokenizer.connect(f"tcp://127.0.0.1:{detokenizer_port}")
self.step = step_func
async def loop_for_forward(self):
while self.liveness:
requests = []
while not self.request_queue.empty():
requests.append(self.request_queue.get())
try:
out_pyobjs = await self.step(requests)
except Exception:
for r in requests:
self.request_queue.put(r)
logger.error(
f"Worker thread {self.worker_id}: "
f"failed to get back from Model Server\n"
f"{get_exception_traceback()}"
)
self.liveness = False
for obj in out_pyobjs:
self.send_to_detokenizer.send_pyobj(obj)
# async sleep for receiving the subsequent request and avoiding cache miss
if len(out_pyobjs) != 0:
has_finished = any([obj.finished for obj in out_pyobjs])
if has_finished:
await asyncio.sleep(self.request_dependency_delay)
await asyncio.sleep(global_config.wait_for_new_request_delay)
async def monitoring(self):
while True:
await asyncio.sleep(CHECKING_INTERVAL)
# can plug in monitoring logic here
def run(self):
logger.info(f"DataParallelWorkerThread {self.worker_id} start")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.create_task(self.monitoring())
loop.run_until_complete(self.loop_for_forward())
def start_data_parallel_worker(
server_args: ServerArgs,
port_args: PortArgs,
model_overide_args,
gpu_ids: List[int],
worker_id: int,
):
model_tp_client = ModelTpClient(
gpu_ids,
server_args,
port_args.model_port_args[worker_id],
model_overide_args,
)
worker_thread = DataParallelWorkerThread(
worker_id=worker_id,
request_queue=queue.Queue(),
detokenizer_port=port_args.detokenizer_port,
step_func=model_tp_client.step,
)
worker_thread.start()
return worker_thread
\ No newline at end of file
"""Meta data for requests and batches"""
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import List
......@@ -5,7 +6,7 @@ from typing import List
import numpy as np
import torch
from sglang.srt.managers.router.radix_cache import RadixCache
from sglang.srt.managers.controller.radix_cache import RadixCache
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
......
"""
A controller that manages multiple data parallel workers.
Each data parallel worker can manage multiple tensor parallel workers.
"""
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from enum import Enum, auto
from typing import Dict
import zmq
import zmq.asyncio
from sglang.global_config import global_config
from sglang.srt.managers.io_struct import (
AbortReq,
FlushCacheReq,
TokenizedGenerateReqInput,
)
from sglang.srt.managers.controller.dp_worker import (
DataParallelWorkerThread,
start_data_parallel_worker,
)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.utils import get_exception_traceback
logger = logging.getLogger("srt.controller")
class LoadBalanceMethod(Enum):
ROUND_ROBIN = auto()
SHORTEST_QUEUE = auto()
@classmethod
def from_str(cls, method: str):
method = method.upper()
try:
return cls[method]
except KeyError as exc:
raise ValueError(f"Invalid load balance method: {method}") from exc
class Controller:
def __init__(
self,
load_balance_method: str,
server_args: ServerArgs,
port_args: PortArgs,
model_overide_args,
):
self.load_balance_method = LoadBalanceMethod.from_str(load_balance_method)
self.server_args = server_args
self.port_args = port_args
if self.load_balance_method == LoadBalanceMethod.ROUND_ROBIN:
self.round_robin_counter = 0
self.dispatch_lookup = {
LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
}
self.dispatching = self.dispatch_lookup[self.load_balance_method]
# Init communication
context = zmq.asyncio.Context()
self.recv_from_tokenizer = context.socket(zmq.PULL)
self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
# Init status
self.recv_reqs = []
# Start data parallel workers
self.workers: Dict[int, DataParallelWorkerThread] = {}
tp_size = server_args.tp_size
def start_dp_worker(i):
try:
gpu_ids = list(range(i * tp_size, (i + 1) * tp_size))
worker_thread = start_data_parallel_worker(
server_args, port_args, model_overide_args, gpu_ids, i
)
self.workers[i] = worker_thread
except Exception:
logger.error(
f"Failed to start local worker {i}\n{get_exception_traceback()}"
)
with ThreadPoolExecutor(server_args.dp_size) as executor:
executor.map(start_dp_worker, range(server_args.dp_size))
def have_any_live_worker(self):
return any(worker_thread.liveness for worker_thread in self.workers.values())
def put_req_to_worker(self, worker_id, req):
self.workers[worker_id].request_queue.put(req)
async def round_robin_scheduler(self, input_requests):
available_workers = list(self.workers.keys())
for r in input_requests:
self.put_req_to_worker(available_workers[self.round_robin_counter], r)
self.round_robin_counter = (self.round_robin_counter + 1) % len(
available_workers
)
return
async def shortest_queue_scheduler(self, input_requests):
for r in input_requests:
worker = min(
self.workers, key=lambda w: self.workers[w].request_queue.qsize()
)
self.put_req_to_worker(worker, r)
return
async def remove_dead_workers(self):
for i in list(self.workers.keys()):
worker_thread = self.workers[i]
if not worker_thread.liveness:
worker_thread.join()
# move unsuccessful requests back to the queue
while not worker_thread.request_queue.empty():
self.recv_reqs.append(worker_thread.request_queue.get())
del self.workers[i]
logger.info(f"Stale worker {i} removed")
async def loop_for_forward(self):
while True:
await self.remove_dead_workers()
if self.have_any_live_worker():
next_step_input = list(self.recv_reqs)
self.recv_reqs = []
if next_step_input:
await self.dispatching(next_step_input)
#else:
# logger.error("There is no live worker.")
await asyncio.sleep(global_config.wait_for_new_request_delay)
async def loop_for_recv_requests(self):
while True:
recv_req = await self.recv_from_tokenizer.recv_pyobj()
if isinstance(recv_req, FlushCacheReq):
# TODO(lsyin): apply more specific flushCacheReq
for worker_thread in self.workers.values():
worker_thread.request_queue.put(recv_req)
elif isinstance(recv_req, TokenizedGenerateReqInput):
self.recv_reqs.append(recv_req)
elif isinstance(recv_req, AbortReq):
in_queue = False
for i, req in enumerate(self.recv_reqs):
if req.rid == recv_req.rid:
self.recv_reqs[i] = recv_req
in_queue = True
break
if not in_queue:
# Send abort req to all TP groups
for worker in list(self.workers.keys()):
self.put_req_to_worker(worker, recv_req)
else:
logger.error(f"Invalid object: {recv_req}")
def start_controller_process(
server_args: ServerArgs,
port_args: PortArgs,
pipe_writer,
model_overide_args=None,
):
logging.basicConfig(
level=getattr(logging, server_args.log_level.upper()),
format="%(message)s",
)
try:
controller = Controller(
server_args.load_balance_method, server_args, port_args, model_overide_args
)
except Exception:
pipe_writer.send(get_exception_traceback())
raise
pipe_writer.send("init ok")
loop = asyncio.get_event_loop()
asyncio.set_event_loop(loop)
loop.create_task(controller.loop_for_recv_requests())
loop.run_until_complete(controller.loop_for_forward())
"""A controller that manages a group of tensor parallel workers."""
import asyncio
import logging
......@@ -6,15 +7,15 @@ import zmq
import zmq.asyncio
from sglang.global_config import global_config
from sglang.srt.managers.router.model_rpc import ModelRpcClient
from sglang.srt.managers.controller.tp_worker import ModelTpClient
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.utils import get_exception_traceback
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
class RouterManager:
def __init__(self, model_client: ModelRpcClient, port_args: PortArgs):
class ControllerSingle:
def __init__(self, model_client: ModelTpClient, port_args: PortArgs):
# Init communication
context = zmq.asyncio.Context(2)
self.recv_from_tokenizer = context.socket(zmq.PULL)
......@@ -30,7 +31,7 @@ class RouterManager:
self.recv_reqs = []
# Init some configs
self.request_dependency_time = global_config.request_dependency_time
self.request_dependency_delay = global_config.request_dependency_delay
async def loop_for_forward(self):
while True:
......@@ -46,12 +47,12 @@ class RouterManager:
if len(out_pyobjs) != 0:
has_finished = any([obj.finished for obj in out_pyobjs])
if has_finished:
if self.request_dependency_time > 0:
if self.request_dependency_delay > 0:
slept = True
await asyncio.sleep(self.request_dependency_time)
await asyncio.sleep(self.request_dependency_delay)
if not slept:
await asyncio.sleep(0.0006)
await asyncio.sleep(global_config.wait_for_new_request_delay)
async def loop_for_recv_requests(self):
while True:
......@@ -59,7 +60,7 @@ class RouterManager:
self.recv_reqs.append(recv_req)
def start_router_process(
def start_controller_process(
server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args
):
logging.basicConfig(
......@@ -68,8 +69,13 @@ def start_router_process(
)
try:
model_client = ModelRpcClient(server_args, port_args, model_overide_args)
router = RouterManager(model_client, port_args)
model_client = ModelTpClient(
list(range(server_args.tp_size)),
server_args,
port_args.model_port_args[0],
model_overide_args,
)
controller = ControllerSingle(model_client, port_args)
except Exception:
pipe_writer.send(get_exception_traceback())
raise
......@@ -78,5 +84,5 @@ def start_router_process(
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.create_task(router.loop_for_recv_requests())
loop.run_until_complete(router.loop_for_forward())
loop.create_task(controller.loop_for_recv_requests())
loop.run_until_complete(controller.loop_for_forward())
\ No newline at end of file
......@@ -15,13 +15,13 @@ from vllm.distributed import initialize_model_parallel
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model
logger = logging.getLogger("model_runner")
logger = logging.getLogger("srt.model_runner")
# for server args in model endpoints
global_server_args_dict = {}
......@@ -215,14 +215,16 @@ class ModelRunner:
def __init__(
self,
model_config,
mem_fraction_static,
tp_rank,
tp_size,
nccl_port,
mem_fraction_static: float,
gpu_id: int,
tp_rank: int,
tp_size: int,
nccl_port: int,
server_args: ServerArgs,
):
self.model_config = model_config
self.mem_fraction_static = mem_fraction_static
self.gpu_id = gpu_id
self.tp_rank = tp_rank
self.tp_size = tp_size
self.nccl_port = nccl_port
......@@ -235,9 +237,9 @@ class ModelRunner:
}
# Init torch distributed
logger.info(f"[rank={self.tp_rank}] Set cuda device.")
torch.cuda.set_device(self.tp_rank)
logger.info(f"[rank={self.tp_rank}] Init torch begin. Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB")
logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
torch.cuda.set_device(self.gpu_id)
logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
torch.distributed.init_process_group(
backend="nccl",
world_size=self.tp_size,
......@@ -245,22 +247,26 @@ class ModelRunner:
init_method=f"tcp://127.0.0.1:{self.nccl_port}",
)
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
logger.info(f"[rank={self.tp_rank}] Init torch end.")
total_gpu_memory = get_available_gpu_memory(self.tp_rank, distributed=self.tp_size > 1)
total_gpu_memory = get_available_gpu_memory(
self.gpu_id, distributed=self.tp_size > 1
)
if self.tp_size > 1:
total_local_gpu_memory = get_available_gpu_memory(self.tp_rank)
total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)
if total_local_gpu_memory < total_gpu_memory * 0.9:
raise ValueError("The memory capacity is unbalanced. Some GPUs may be occupied by other processes.")
raise ValueError(
"The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
)
self.load_model()
self.init_memory_pool(total_gpu_memory)
self.is_multimodal_model = is_multimodal_model(self.model_config)
def load_model(self):
logger.info(f"[rank={self.tp_rank}] Load weight begin.")
logger.info(
f"[gpu_id={self.gpu_id}] Load weight begin. "
f"Avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
)
device_config = DeviceConfig()
load_config = LoadConfig(load_format=self.server_args.load_format)
......@@ -286,12 +292,16 @@ class ModelRunner:
parallel_config=None,
scheduler_config=None,
)
logger.info(f"[rank={self.tp_rank}] Load weight end. "
logger.info(
f"[gpu_id={self.gpu_id}] Load weight end. "
f"Type={type(self.model).__name__}. "
f"Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB")
f"Avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
)
def profile_max_num_token(self, total_gpu_memory):
available_gpu_memory = get_available_gpu_memory(self.tp_rank, distributed=self.tp_size > 1)
available_gpu_memory = get_available_gpu_memory(
self.gpu_id, distributed=self.tp_size > 1
)
head_dim = self.model_config.head_dim
head_num = self.model_config.num_key_value_heads // self.tp_size
cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
......@@ -306,7 +316,7 @@ class ModelRunner:
if self.max_total_num_tokens <= 0:
raise RuntimeError(
"Not enought memory. " "Please try to increase --mem-fraction-static."
"Not enought memory. Please try to increase --mem-fraction-static."
)
self.req_to_token_pool = ReqToTokenPool(
......@@ -320,6 +330,10 @@ class ModelRunner:
head_dim=self.model_config.head_dim,
layer_num=self.model_config.num_hidden_layers,
)
logger.info(
f"[gpu_id={self.gpu_id}] Memory pool end. "
f"Avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
)
@torch.inference_mode()
def forward_prefill(self, batch: Batch):
......@@ -424,8 +438,8 @@ def import_model_classes():
if hasattr(module, "EntryClass"):
entry = module.EntryClass
if isinstance(entry, list): # To support multiple model classes in one module
for cls in entry:
model_arch_name_to_cls[cls.__name__] = cls
for tmp in entry:
model_arch_name_to_cls[tmp.__name__] = tmp
else:
model_arch_name_to_cls[entry.__name__] = entry
return model_arch_name_to_cls
......
......@@ -2,7 +2,7 @@ import random
from collections import defaultdict
class Scheduler:
class ScheduleHeuristic:
def __init__(
self,
schedule_heuristic,
......
import asyncio
import logging
import multiprocessing
import time
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional
from typing import List
import rpyc
import torch
from rpyc.utils.classic import obtain
from rpyc.utils.server import ThreadedServer
try:
from vllm.logger import _default_handler as vllm_default_logger
except ImportError:
from vllm.logger import logger as vllm_default_logger
from sglang.global_config import global_config
from sglang.srt.constrained.fsm_cache import FSMCache
......@@ -26,38 +19,41 @@ from sglang.srt.managers.io_struct import (
FlushCacheReq,
TokenizedGenerateReqInput,
)
from sglang.srt.managers.router.infer_batch import Batch, FinishReason, ForwardMode, Req
from sglang.srt.managers.router.model_runner import ModelRunner
from sglang.srt.managers.router.radix_cache import RadixCache
from sglang.srt.managers.router.scheduler import Scheduler
from sglang.srt.managers.controller.infer_batch import Batch, FinishReason, ForwardMode, Req
from sglang.srt.managers.controller.model_runner import ModelRunner
from sglang.srt.managers.controller.radix_cache import RadixCache
from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
from sglang.srt.model_config import ModelConfig
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.server_args import ModelPortArgs, ServerArgs
from sglang.srt.utils import (
get_int_token_logit_bias,
is_multimodal_model,
set_random_seed,
start_rpyc_process,
suppress_other_loggers,
)
from sglang.utils import get_exception_traceback
logger = logging.getLogger("model_rpc")
vllm_default_logger.setLevel(logging.WARN)
logging.getLogger("vllm.utils").setLevel(logging.WARN)
logging.getLogger("vllm.selector").setLevel(logging.WARN)
logger = logging.getLogger("srt.model_tp")
class ModelRpcServer:
class ModelTpServer:
def __init__(
self,
gpu_id: int,
tp_rank: int,
server_args: ServerArgs,
port_args: PortArgs,
model_overide_args: Optional[dict] = None,
model_port_args: ModelPortArgs,
model_overide_args,
):
server_args, port_args = [obtain(x) for x in [server_args, port_args]]
server_args, model_port_args = obtain(server_args), obtain(model_port_args)
suppress_other_loggers()
# Copy arguments
self.gpu_id = gpu_id
self.tp_rank = tp_rank
self.tp_size = server_args.tp_size
self.dp_size = server_args.dp_size
self.schedule_heuristic = server_args.schedule_heuristic
self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
......@@ -68,16 +64,16 @@ class ModelRpcServer:
context_length=server_args.context_length,
model_overide_args=model_overide_args,
)
# For model end global settings
self.model_runner = ModelRunner(
model_config=self.model_config,
mem_fraction_static=server_args.mem_fraction_static,
gpu_id=gpu_id,
tp_rank=tp_rank,
tp_size=server_args.tp_size,
nccl_port=port_args.nccl_port,
nccl_port=model_port_args.nccl_port,
server_args=server_args,
)
if is_multimodal_model(server_args.model_path):
self.processor = get_processor(
server_args.tokenizer_path,
......@@ -95,21 +91,21 @@ class ModelRpcServer:
self.max_prefill_tokens = max(
self.model_config.context_len,
(
self.max_total_num_tokens // 6
min(self.max_total_num_tokens // 6, 65536)
if server_args.max_prefill_tokens is None
else server_args.max_prefill_tokens
),
)
self.max_running_requests = (self.max_total_num_tokens // 2
if server_args.max_running_requests is None else server_args.max_running_requests)
self.int_token_logit_bias = torch.tensor(
get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
)
set_random_seed(server_args.random_seed)
# Print info
logger.info(f"[rank={self.tp_rank}] "
logger.info(
f"[gpu_id={self.gpu_id}] "
f"max_total_num_tokens={self.max_total_num_tokens}, "
f"max_prefill_tokens={self.max_prefill_tokens}, "
f"context_len={self.model_config.context_len}, "
......@@ -124,7 +120,7 @@ class ModelRpcServer:
disable=server_args.disable_radix_cache,
)
self.tree_cache_metrics = {"total": 0, "hit": 0}
self.scheduler = Scheduler(
self.scheduler = ScheduleHeuristic(
self.schedule_heuristic,
self.max_running_requests,
self.max_prefill_tokens,
......@@ -170,7 +166,7 @@ class ModelRpcServer:
self.new_token_ratio_recovery = global_config.new_token_ratio_recovery
def exposed_step(self, recv_reqs):
if self.tp_size != 1:
if self.tp_size * self.dp_size != 1:
recv_reqs = obtain(recv_reqs)
try:
......@@ -188,7 +184,7 @@ class ModelRpcServer:
# Forward
self.forward_step()
except Exception:
logger.error("Exception in ModelRpcClient:\n" + get_exception_traceback())
logger.error("Exception in ModelTpClient:\n" + get_exception_traceback())
# Return results
ret = self.out_pyobjs
......@@ -224,16 +220,17 @@ class ModelRpcServer:
self.token_to_kv_pool.available_size()
+ self.tree_cache.evictable_size()
)
throuhgput = self.num_generated_tokens / (
throughput = self.num_generated_tokens / (
time.time() - self.last_stats_tic
)
self.num_generated_tokens = 0
self.last_stats_tic = time.time()
logger.info(
f"[gpu_id={self.gpu_id}] "
f"#running-req: {len(self.running_batch.reqs)}, "
f"#token: {num_used}, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"gen throughput (token/s): {throuhgput:.2f}, "
f"gen throughput (token/s): {throughput:.2f}, "
f"#queue-req: {len(self.forward_queue)}"
)
......@@ -405,7 +402,7 @@ class ModelRpcServer:
f"#new_token: {new_batch_input_tokens}. "
f"#remaining_req: {len(self.forward_queue) - len(can_run_list)}. "
f"#running_req: {running_req}. "
f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%. "
)
# logger.debug(
# f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
......@@ -724,20 +721,30 @@ class ModelRpcServer:
break
class ModelRpcService(rpyc.Service):
exposed_ModelRpcServer = ModelRpcServer
class ModelTpService(rpyc.Service):
exposed_ModelTpServer = ModelTpServer
class ModelRpcClient:
class ModelTpClient:
def __init__(
self, server_args: ServerArgs, port_args: PortArgs, model_overide_args
self,
gpu_ids: List[int],
server_args: ServerArgs,
model_port_args: ModelPortArgs,
model_overide_args,
):
tp_size = server_args.tp_size
server_args, model_port_args = obtain(server_args), obtain(model_port_args)
self.tp_size = server_args.tp_size
if tp_size == 1:
if self.tp_size * server_args.dp_size == 1:
# Init model
self.model_server = ModelRpcService().exposed_ModelRpcServer(
0, server_args, port_args, model_overide_args
assert len(gpu_ids) == 1
self.model_server = ModelTpService().exposed_ModelTpServer(
0,
gpu_ids[0],
server_args,
model_port_args,
model_overide_args,
)
# Wrap functions
......@@ -749,19 +756,26 @@ class ModelRpcClient:
self.step = async_wrap(self.model_server.exposed_step)
else:
with ThreadPoolExecutor(tp_size) as executor:
with ThreadPoolExecutor(self.tp_size) as executor:
# Launch model processes
rets = executor.map(start_model_process, port_args.model_rpc_ports)
self.remote_services = [x[0] for x in rets]
rets = executor.map(
lambda args: start_rpyc_process(*args),
[(ModelTpService, p) for p in model_port_args.model_tp_ports],
)
self.model_services = [x[0] for x in rets]
self.procs = [x[1] for x in rets]
# Init model
def init_model(i):
return self.remote_services[i].ModelRpcServer(
i, server_args, port_args, model_overide_args
return self.model_services[i].ModelTpServer(
gpu_ids[i],
i,
server_args,
model_port_args,
model_overide_args,
)
self.model_servers = executor.map(init_model, range(tp_size))
self.model_servers = executor.map(init_model, range(self.tp_size))
# Wrap functions
def async_wrap(func_name):
......@@ -775,44 +789,3 @@ class ModelRpcClient:
return _func
self.step = async_wrap("step")
\ No newline at end of file
def _init_service(port):
t = ThreadedServer(
ModelRpcService(),
port=port,
protocol_config={
"allow_public_attrs": True,
"allow_pickle": True,
"sync_request_timeout": 3600,
},
)
t.start()
def start_model_process(port):
proc = multiprocessing.Process(target=_init_service, args=(port,))
proc.start()
time.sleep(1)
repeat_count = 0
while repeat_count < 20:
try:
con = rpyc.connect(
"localhost",
port,
config={
"allow_public_attrs": True,
"allow_pickle": True,
"sync_request_timeout": 3600,
},
)
break
except ConnectionRefusedError:
time.sleep(1)
repeat_count += 1
if repeat_count == 20:
raise RuntimeError("init rpc env error!")
assert proc.is_alive()
return con.root, proc
......@@ -27,7 +27,6 @@ class GenerateReqInput:
return_text_in_logprobs: bool = False
# Whether to stream output
stream: bool = False
# TODO: make all parameters a Union[List[T], T] to allow for batched requests
def post_init(self):
......
......@@ -48,7 +48,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.managers.controller.model_runner import InputMetadata
@torch.compile
......
......@@ -29,7 +29,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.managers.controller.model_runner import InputMetadata
class DbrxRouter(nn.Module):
......
......@@ -22,7 +22,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.managers.controller.model_runner import InputMetadata
class GemmaMLP(nn.Module):
......
......@@ -37,7 +37,7 @@ from vllm.utils import print_warning_once
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.fused_moe import fused_moe
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.managers.controller.model_runner import InputMetadata
use_fused = True
......
......@@ -4,9 +4,13 @@
from typing import Any, Dict, Optional, Tuple, Iterable
import torch
import tqdm
from torch import nn
from transformers import LlamaConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
......@@ -24,7 +28,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.managers.controller.model_runner import InputMetadata
class LlamaMLP(nn.Module):
......@@ -284,6 +288,8 @@ class LlamaForCausalLM(nn.Module):
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
if get_tensor_model_parallel_rank() == 0:
weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name or "projector" in name:
continue
......
......@@ -10,8 +10,8 @@ from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.managers.router.infer_batch import ForwardMode
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.managers.controller.infer_batch import ForwardMode
from sglang.srt.managers.controller.model_runner import InputMetadata
from sglang.srt.mm_utils import (
get_anyres_image_grid_shape,
unpad_image,
......
......@@ -10,8 +10,8 @@ from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.managers.router.infer_batch import ForwardMode
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.managers.controller.infer_batch import ForwardMode
from sglang.srt.managers.controller.model_runner import InputMetadata
from sglang.srt.mm_utils import (
get_anyres_image_grid_shape,
unpad_image,
......