Unverified Commit 6a2941f4 authored by Ying Sheng, committed by GitHub

Improve tensor parallel performance (#625)


Co-authored-by: Mingyi <wisclmy0611@gmail.com>
parent 5ac8b806
@@ -377,6 +377,14 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --mem-fraction-static 0.7
```
- See [hyperparameter_tuning.md](docs/hyperparameter_tuning.md) for guidance on tuning hyperparameters for better performance.
- Add `--nnodes 2` to run tensor parallelism on multiple nodes. For example, if you have two nodes with two GPUs each and want to run TP=4, let `sgl-dev-1` be the hostname of the first node and `50000` be an available port; then run the commands below (a short sketch of the rank-to-node mapping follows them).
```
# Node 0
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --tp 4 --nccl-init sgl-dev-1:50000 --nnodes 2 --node-rank 0
# Node 1
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --tp 4 --nccl-init sgl-dev-1:50000 --nnodes 2 --node-rank 1
```
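For reference, here is a rough sketch (illustrative Python, not part of the CLI) of how TP ranks are split across nodes in this setup: each node drives `tp_size // nnodes` local GPUs, and global TP ranks are assigned to the nodes in contiguous blocks.
```
# Illustrative only: rank-to-node mapping for tp_size=4 across nnodes=2.
tp_size, nnodes = 4, 2
tp_size_local = tp_size // nnodes  # 2 TP ranks (and GPUs) per node
for node_rank in range(nnodes):
    tp_ranks = list(range(node_rank * tp_size_local, (node_rank + 1) * tp_size_local))
    gpu_ids = list(range(tp_size_local))  # local CUDA device indices on that node
    print(f"node {node_rank}: TP ranks {tp_ranks} on local GPUs {gpu_ids}")
# node 0: TP ranks [0, 1] on local GPUs [0, 1]
# node 1: TP ranks [2, 3] on local GPUs [0, 1]
```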
### Supported Models
- Llama
@@ -96,8 +96,11 @@ def run_one_batch_size(bs):
ret = response.json()
print(ret)
input_len = args.input_len if args.input_len else 1
output_len = max_new_tokens
output_throughput = bs * max_new_tokens / latency
overall_throughput = bs * (args.input_len + max_new_tokens) / latency
overall_throughput = bs * (input_len + output_len) / latency
print(f"latency: {latency:.2f} s")
print(f"decode throughput: {output_throughput:.2f} token/s")
print(f"overall throughput: {overall_throughput:.2f} token/s")
@@ -312,6 +312,9 @@ def main(args: argparse.Namespace):
np.sum([output_len for _, output_len, _ in REQUEST_LATENCY]) / benchmark_time
)
#latencies = [round(latency, 2) for _, _, latency in REQUEST_LATENCY]
#print(latencies)
print(f"Total time: {benchmark_time:.2f} s")
print(f"Request throughput: {args.num_prompts / benchmark_time:.2f} requests/s")
print(f"Decoding throughput: {decoding_throughput:.2f} token/s")
@@ -2,11 +2,10 @@
- `backend`: Various backends for the language interpreter.
- `lang`: The frontend language.
- `srt`: The runtime for running local models.
- `srt`: The serving engine for running local models. (SRT = SGLang Runtime).
- `test`: Test utilities.
- `api.py`: Public API.
- `bench_latency.py`: Benchmark utilities.
- `global_config.py`: The global configs and constants.
- `launch_server.py`: The entry point for launching the local server.
- `utils.py`: Common utilities.
@@ -42,6 +42,8 @@ class LoadBalanceMethod(Enum):
class Controller:
"""A controller that manages multiple data parallel workers."""
def __init__(
self,
load_balance_method: str,
@@ -183,9 +185,11 @@ def start_controller_process(
except Exception:
pipe_writer.send(get_exception_traceback())
raise
pipe_writer.send("init ok")
loop = asyncio.get_event_loop()
loop = asyncio.new_event_loop()
loop.set_default_executor(ThreadPoolExecutor(max_workers=256))
asyncio.set_event_loop(loop)
loop.create_task(controller.loop_for_recv_requests())
loop.run_until_complete(controller.loop_for_forward())
"""A controller that manages a group of tensor parallel workers."""
import asyncio
import multiprocessing
import logging
from concurrent.futures import ThreadPoolExecutor
import os
import pickle
import uvloop
import torch
import torch.distributed as dist
import zmq
import zmq.asyncio
from sglang.global_config import global_config
from sglang.srt.managers.controller.tp_worker import ModelTpClient
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.managers.controller.tp_worker import ModelTpServer
from sglang.srt.server_args import PortArgs, ServerArgs, ModelPortArgs
from sglang.srt.utils import kill_parent_process
from sglang.utils import get_exception_traceback
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
logger = logging.getLogger("srt.controller")
def run_tp_server(
gpu_id: int,
tp_rank: int,
server_args: ServerArgs,
model_port_args: ModelPortArgs,
model_overide_args: dict,
):
"""Run a tp server."""
try:
model_server = ModelTpServer(
gpu_id,
tp_rank,
server_args,
model_port_args,
model_overide_args,
)
tp_cpu_group = model_server.model_runner.tp_group.cpu_group
while True:
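# Block until the controller (TP rank 0) broadcasts the next batch of requests, then run one step.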
recv_reqs = broadcast_recv_input(None, tp_rank, tp_cpu_group)
model_server.exposed_step(recv_reqs)
except Exception:
logger.error("Exception in run_tp_server:\n" + get_exception_traceback())
raise
def launch_tp_servers(gpu_ids, tp_rank_range, server_args,
model_port_args, model_overide_args):
"""Launch multiple tp servers."""
procs = []
for i in tp_rank_range:
proc = multiprocessing.Process(target=run_tp_server, args=(
gpu_ids[i], i, server_args, model_port_args, model_overide_args
))
proc.start()
procs.append(proc)
return procs
def broadcast_recv_input(data, rank, dist_group):
"""Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
if rank == 0:
if len(data) == 0:
tensor_size = torch.tensor([0], dtype=torch.long)
dist.broadcast(tensor_size, src=0, group=dist_group)
else:
serialized_data = pickle.dumps(data)
size = len(serialized_data)
tensor_data = torch.ByteTensor(list(serialized_data))
tensor_size = torch.tensor([size], dtype=torch.long)
dist.broadcast(tensor_size, src=0, group=dist_group)
dist.broadcast(tensor_data, src=0, group=dist_group)
else:
tensor_size = torch.tensor([0], dtype=torch.long)
dist.broadcast(tensor_size, src=0, group=dist_group)
size = tensor_size.item()
if size == 0:
return []
tensor_data = torch.empty(size, dtype=torch.uint8)
dist.broadcast(tensor_data, src=0, group=dist_group)
serialized_data = bytes(tensor_data.tolist())
data = pickle.loads(serialized_data)
return data
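The helper above uses a two-step protocol: rank 0 first broadcasts the byte length of the pickled payload, then the payload itself, so the other ranks know how large a buffer to allocate before receiving (an empty request list is signalled by a zero length). A minimal sketch of how it is driven, mirroring the controller (rank 0, when `tp_size > 1`) and the extra TP servers; `tp_server` and `tp_cpu_group` are assumed to come from a `ModelTpServer` and its model runner's TP group:
```
def step_once(tp_rank, tp_server, tp_cpu_group, pending_reqs):
    # Sketch only: rank 0 ships the requests it pulled from the tokenizer;
    # every other rank blocks until that broadcast arrives.
    if tp_rank == 0:
        broadcast_recv_input(pending_reqs, 0, tp_cpu_group)
        recv_reqs = pending_reqs
    else:
        recv_reqs = broadcast_recv_input(None, tp_rank, tp_cpu_group)
    return tp_server.exposed_step(recv_reqs)
```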
class ControllerSingle:
def __init__(self, model_client: ModelTpClient, port_args: PortArgs):
"""A controller that manages a group of tensor parallel workers."""
def __init__(self, server_args: ServerArgs, port_args: PortArgs, model_overide_args: dict):
# Parse args
self.server_args = server_args
# Init communication
context = zmq.asyncio.Context(2)
context = zmq.Context(2)
self.recv_from_tokenizer = context.socket(zmq.PULL)
self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
@@ -31,44 +107,52 @@ class ControllerSingle:
f"tcp://127.0.0.1:{port_args.detokenizer_port}"
)
# Init status
self.model_client = model_client
self.recv_reqs = []
# Init some configs
self.request_dependency_delay = global_config.request_dependency_delay
# Init model server
tp_size_local = server_args.tp_size // server_args.nnodes
gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
# Launch other tp ranks
if tp_size_local > 1:
tp_rank_range = range(1, tp_size_local)
self.tp_procs = launch_tp_servers(
gpu_ids, tp_rank_range, server_args,
port_args.model_port_args[0], model_overide_args)
# Launch tp rank 0
self.tp_server = ModelTpServer(
gpu_ids[0],
0,
server_args,
port_args.model_port_args[0],
model_overide_args,
)
self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group
async def loop_for_forward(self):
def loop_for_forward(self):
while True:
next_step_input = list(self.recv_reqs)
self.recv_reqs = []
out_pyobjs = await self.model_client.step(next_step_input)
recv_reqs = self.recv_requests()
if self.server_args.tp_size > 1:
broadcast_recv_input(recv_reqs, 0, self.tp_cpu_group)
out_pyobjs = self.tp_server.exposed_step(recv_reqs)
for obj in out_pyobjs:
self.send_to_detokenizer.send_pyobj(obj)
# async sleep for receiving the subsequent request and avoiding cache miss
slept = False
if len(out_pyobjs) != 0:
has_finished = any(
[obj.finished_reason is not None for obj in out_pyobjs]
)
if has_finished:
if self.request_dependency_delay > 0:
slept = True
await asyncio.sleep(self.request_dependency_delay)
if not slept:
await asyncio.sleep(global_config.wait_for_new_request_delay)
async def loop_for_recv_requests(self):
def recv_requests(self):
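# Drain every request currently queued on the tokenizer socket without blocking, so the forward loop never stalls waiting for input.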
recv_reqs = []
while True:
recv_req = await self.recv_from_tokenizer.recv_pyobj()
self.recv_reqs.append(recv_req)
try:
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
recv_reqs.append(recv_req)
except zmq.ZMQError:
break
return recv_reqs
def start_controller_process(
server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args
server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args: dict
):
logging.basicConfig(
level=getattr(logging, server_args.log_level.upper()),
@@ -76,27 +160,18 @@ def start_controller_process(
)
try:
tp_size_local = server_args.tp_size // server_args.nnodes
model_client = ModelTpClient(
[i for _ in range(server_args.nnodes) for i in range(tp_size_local)],
server_args,
port_args.model_port_args[0],
model_overide_args,
)
controller = ControllerSingle(model_client, port_args)
controller = ControllerSingle(server_args, port_args, model_overide_args)
except Exception:
pipe_writer.send(get_exception_traceback())
raise
pipe_writer.send("init ok")
loop = asyncio.new_event_loop()
loop.set_default_executor(ThreadPoolExecutor(max_workers=256))
asyncio.set_event_loop(loop)
loop.create_task(controller.loop_for_recv_requests())
try:
loop.run_until_complete(controller.loop_for_forward())
controller.loop_for_forward()
except Exception:
logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
finally:
for t in controller.tp_procs:
os.kill(t.pid, 9)
kill_parent_process()
@@ -11,7 +11,7 @@ import torch
import torch.nn as nn
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
from vllm.distributed import init_distributed_environment, initialize_model_parallel
from vllm.distributed import init_distributed_environment, initialize_model_parallel, get_tp_group
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry
@@ -75,6 +75,7 @@ class ModelRunner:
distributed_init_method=nccl_init_method,
)
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
self.tp_group = get_tp_group()
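# Keep a handle to vLLM's TP group; its CPU process group is reused by the controller to broadcast request objects between TP ranks (see broadcast_recv_input).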
total_gpu_memory = get_available_gpu_memory(
self.gpu_id, distributed=self.tp_size > 1
)
@@ -53,7 +53,7 @@ class ModelTpServer:
tp_rank: int,
server_args: ServerArgs,
model_port_args: ModelPortArgs,
model_overide_args,
model_overide_args: dict,
):
server_args, model_port_args = obtain(server_args), obtain(model_port_args)
suppress_other_loggers()
@@ -178,7 +178,7 @@ class ModelTpServer:
self.new_token_ratio_recovery = global_config.new_token_ratio_recovery
def exposed_step(self, recv_reqs):
if self.tp_size * self.dp_size != 1:
if not isinstance(recv_reqs, list):
recv_reqs = obtain(recv_reqs)
try:
@@ -206,11 +206,11 @@ class ModelTpServer:
@torch.inference_mode()
def forward_step(self):
new_batch = self.get_new_fill_batch()
new_batch = self.get_new_prefill_batch()
if new_batch is not None:
# Run a new fill batch
self.forward_fill_batch(new_batch)
# Run a new prefill batch
self.forward_prefill_batch(new_batch)
self.cache_filled_batch(new_batch)
if not new_batch.is_empty():
@@ -219,7 +219,7 @@ class ModelTpServer:
else:
self.running_batch.merge(new_batch)
else:
# Run decode batch
# Run a decode batch
if self.running_batch is not None:
# Run a few decode batches continuously for reducing overhead
for _ in range(global_config.num_continue_decode_steps):
@@ -312,7 +312,7 @@ class ModelTpServer:
)
self.forward_queue.append(req)
def get_new_fill_batch(self) -> Optional[Batch]:
def get_new_prefill_batch(self) -> Optional[Batch]:
running_bs = (
len(self.running_batch.reqs) if self.running_batch is not None else 0
)
@@ -436,7 +436,7 @@ class ModelTpServer:
self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
return new_batch
def forward_fill_batch(self, batch: Batch):
def forward_prefill_batch(self, batch: Batch):
# Build batch tensors
batch.prepare_for_extend(
self.model_config.vocab_size, self.int_token_logit_bias
@@ -746,8 +746,8 @@ class ModelTpClient:
# Init model
assert len(gpu_ids) == 1
self.model_server = ModelTpService().exposed_ModelTpServer(
0,
gpu_ids[0],
0,
server_args,
model_port_args,
model_overide_args,
@@ -33,9 +33,9 @@ from sglang.srt.managers.controller.manager_multi import (
start_controller_process as start_controller_process_multi,
)
from sglang.srt.managers.controller.manager_single import (
launch_tp_servers,
start_controller_process as start_controller_process_single,
)
from sglang.srt.managers.controller.tp_worker import ModelTpService
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -53,7 +53,6 @@ from sglang.srt.utils import (
enable_show_time_cost,
receive_addrs,
send_addrs_to_rank_0,
start_rpyc_service_process,
)
from sglang.utils import get_exception_traceback
@@ -192,21 +191,17 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
model_port_args=model_port_args,
)
# TODO multi-node dp is not supported
assert not (server_args.dp_size > 1 and server_args.node_rank is not None)
# Handle multi-node tp
if server_args.nnodes > 1:
assert server_args.dp_size == 1, "Multi-node dp is not supported."
if server_args.node_rank != 0:
send_addrs_to_rank_0(model_port_args[0], server_args)
else:
receive_addrs(model_port_args[0], server_args)
for i in range(tp_size_local):
start_rpyc_service_process(
ModelTpService, model_port_args[0].model_tp_ports[i]
)
if server_args.node_rank != 0:
logger.info(
f"[node_rank={server_args.node_rank}]: Listen for connections..."
)
tp_size_local = server_args.tp_size // server_args.nnodes
gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
tp_rank_range = list(range(server_args.node_rank * tp_size_local,
(server_args.node_rank + 1) * tp_size_local))
procs = launch_tp_servers(gpu_ids, tp_rank_range, server_args,
port_args.model_port_args[0], model_overide_args)
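# Non-zero node ranks only host the TP worker processes; the controller and HTTP server run on node 0, so just keep this launcher process (and its children) alive.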
while True:
pass
@@ -67,10 +67,12 @@ class ServerArgs:
if self.tokenizer_path is None:
self.tokenizer_path = self.model_path
if self.mem_fraction_static is None:
if self.tp_size >= 8:
if self.tp_size >= 16:
self.mem_fraction_static = 0.74
elif self.tp_size >= 8:
self.mem_fraction_static = 0.78
elif self.tp_size >= 4:
self.mem_fraction_static = 0.80
self.mem_fraction_static = 0.82
elif self.tp_size >= 2:
self.mem_fraction_static = 0.85
else: