Unverified Commit 6a2941f4 authored by Ying Sheng, committed by GitHub

Improve tensor parallel performance (#625)


Co-authored-by: Mingyi <wisclmy0611@gmail.com>
parent 5ac8b806
@@ -377,6 +377,14 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
 python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --mem-fraction-static 0.7
 ```
 - See [hyperparameter_tuning.md](docs/hyperparameter_tuning.md) on tuning hyperparameters for better performance.
+- Add `--nnodes 2` to run tensor parallelism on multiple nodes. If you have two nodes with two GPUs on each node and want to run TP=4, let `sgl-dev-1` be the hostname of the first node and `50000` be an available port.
+```
+# Node 0
+python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --tp 4 --nccl-init sgl-dev-1:50000 --nnodes 2 --node-rank 0
+
+# Node 1
+python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --tp 4 --nccl-init sgl-dev-1:50000 --nnodes 2 --node-rank 1
+```

 ### Supported Models
 - Llama
...
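
To make the multi-node example above concrete, the launcher in this commit assigns each node a contiguous slice of the global TP ranks and reuses local GPU indices on every node. Here is a minimal sketch of that split, using the illustrative helper name `local_tp_ranks` (not part of SGLang's API):

```
def local_tp_ranks(tp_size: int, nnodes: int, node_rank: int):
    """Which TP ranks and local GPU ids one node handles (illustrative only)."""
    tp_size_local = tp_size // nnodes                  # e.g. 4 // 2 = 2 ranks per node
    gpu_ids = list(range(tp_size_local))               # each node uses local GPUs 0..tp_size_local-1
    tp_ranks = list(range(node_rank * tp_size_local,
                          (node_rank + 1) * tp_size_local))
    return gpu_ids, tp_ranks

print(local_tp_ranks(4, 2, 0))  # ([0, 1], [0, 1]) -> node 0 hosts TP ranks 0 and 1
print(local_tp_ranks(4, 2, 1))  # ([0, 1], [2, 3]) -> node 1 hosts TP ranks 2 and 3
```
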
@@ -96,8 +96,11 @@ def run_one_batch_size(bs):
     ret = response.json()
     print(ret)

+    input_len = args.input_len if args.input_len else 1
+    output_len = max_new_tokens
+
     output_throughput = bs * max_new_tokens / latency
-    overall_throughput = bs * (args.input_len + max_new_tokens) / latency
+    overall_throughput = bs * (input_len + output_len) / latency
     print(f"latency: {latency:.2f} s")
     print(f"decode throughput: {output_throughput:.2f} token/s")
     print(f"overall throughput: {overall_throughput:.2f} token/s")
...
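
For readers skimming the numbers this script prints, the two throughput figures reduce to simple ratios; a small restatement (an illustrative helper, not part of the benchmark script):

```
def batch_throughputs(bs: int, input_len: int, max_new_tokens: int, latency: float):
    """Decode throughput counts only generated tokens; overall throughput also
    counts the prompt tokens processed in the same wall-clock time."""
    decode = bs * max_new_tokens / latency
    overall = bs * (input_len + max_new_tokens) / latency
    return decode, overall

# e.g. 16 requests, 128-token prompts, 256 new tokens each, finishing in 4.0 s
print(batch_throughputs(16, 128, 256, 4.0))  # (1024.0, 1536.0) tokens/s
```
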
@@ -312,6 +312,9 @@ def main(args: argparse.Namespace):
         np.sum([output_len for _, output_len, _ in REQUEST_LATENCY]) / benchmark_time
     )

+    #latencies = [round(latency, 2) for _, _, latency in REQUEST_LATENCY]
+    #print(latencies)
+
     print(f"Total time: {benchmark_time:.2f} s")
     print(f"Request throughput: {args.num_prompts / benchmark_time:.2f} requests/s")
     print(f"Decoding throughput: {decoding_throughput:.2f} token/s")
...
@@ -2,11 +2,10 @@
 - `backend`: Various backends for the language interpreter.
 - `lang`: The frontend language.
-- `srt`: The runtime for running local models.
+- `srt`: The serving engine for running local models. (SRT = SGLang Runtime).
 - `test`: Test utilities.
 - `api.py`: Public API.
 - `bench_latency.py`: Benchmark utilities.
 - `global_config.py`: The global configs and constants.
 - `launch_server.py`: The entry point of launching local server.
 - `utils.py`: Common utilities.

@@ -42,6 +42,8 @@ class LoadBalanceMethod(Enum):
 class Controller:
+    """A controller that manages multiple data parallel workers."""
+
     def __init__(
         self,
         load_balance_method: str,
@@ -183,9 +185,11 @@ def start_controller_process(
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
     pipe_writer.send("init ok")

-    loop = asyncio.get_event_loop()
+    loop = asyncio.new_event_loop()
+    loop.set_default_executor(ThreadPoolExecutor(max_workers=256))
     asyncio.set_event_loop(loop)
     loop.create_task(controller.loop_for_recv_requests())
     loop.run_until_complete(controller.loop_for_forward())

"""A controller that manages a group of tensor parallel workers.""" """A controller that manages a group of tensor parallel workers."""
import asyncio import multiprocessing
import logging import logging
from concurrent.futures import ThreadPoolExecutor import os
import pickle
import uvloop import torch
import torch.distributed as dist
import zmq import zmq
import zmq.asyncio import zmq.asyncio
from sglang.global_config import global_config from sglang.srt.managers.controller.tp_worker import ModelTpServer
from sglang.srt.managers.controller.tp_worker import ModelTpClient from sglang.srt.server_args import PortArgs, ServerArgs, ModelPortArgs
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import kill_parent_process from sglang.srt.utils import kill_parent_process
from sglang.utils import get_exception_traceback from sglang.utils import get_exception_traceback
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
logger = logging.getLogger("srt.controller") logger = logging.getLogger("srt.controller")
def run_tp_server(
gpu_id: int,
tp_rank: int,
server_args: ServerArgs,
model_port_args: ModelPortArgs,
model_overide_args: dict,
):
"""Run a tp server."""
try:
model_server = ModelTpServer(
gpu_id,
tp_rank,
server_args,
model_port_args,
model_overide_args,
)
tp_cpu_group = model_server.model_runner.tp_group.cpu_group
while True:
recv_reqs = broadcast_recv_input(None, tp_rank, tp_cpu_group)
model_server.exposed_step(recv_reqs)
except Exception:
logger.error("Exception in run_tp_server:\n" + get_exception_traceback())
raise
def launch_tp_servers(gpu_ids, tp_rank_range, server_args,
model_port_args, model_overide_args):
"""Launch multiple tp servers."""
procs = []
for i in tp_rank_range:
proc = multiprocessing.Process(target=run_tp_server, args=(
gpu_ids[i], i, server_args, model_port_args, model_overide_args
))
proc.start()
procs.append(proc)
return procs
def broadcast_recv_input(data, rank, dist_group):
"""Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
if rank == 0:
if len(data) == 0:
tensor_size = torch.tensor([0], dtype=torch.long)
dist.broadcast(tensor_size, src=0, group=dist_group)
else:
serialized_data = pickle.dumps(data)
size = len(serialized_data)
tensor_data = torch.ByteTensor(list(serialized_data))
tensor_size = torch.tensor([size], dtype=torch.long)
dist.broadcast(tensor_size, src=0, group=dist_group)
dist.broadcast(tensor_data, src=0, group=dist_group)
else:
tensor_size = torch.tensor([0], dtype=torch.long)
dist.broadcast(tensor_size, src=0, group=dist_group)
size = tensor_size.item()
if size == 0:
return []
tensor_data = torch.empty(size, dtype=torch.uint8)
dist.broadcast(tensor_data, src=0, group=dist_group)
serialized_data = bytes(tensor_data.tolist())
data = pickle.loads(serialized_data)
return data
class ControllerSingle: class ControllerSingle:
def __init__(self, model_client: ModelTpClient, port_args: PortArgs): """A controller that manages a group of tensor parallel workers."""
def __init__(self, server_args: ServerArgs, port_args: PortArgs, model_overide_args: dict):
# Parse args
self.server_args = server_args
# Init communication # Init communication
context = zmq.asyncio.Context(2) context = zmq.Context(2)
self.recv_from_tokenizer = context.socket(zmq.PULL) self.recv_from_tokenizer = context.socket(zmq.PULL)
self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}") self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
...@@ -31,44 +107,52 @@ class ControllerSingle: ...@@ -31,44 +107,52 @@ class ControllerSingle:
f"tcp://127.0.0.1:{port_args.detokenizer_port}" f"tcp://127.0.0.1:{port_args.detokenizer_port}"
) )
# Init status # Init model server
self.model_client = model_client tp_size_local = server_args.tp_size // server_args.nnodes
self.recv_reqs = [] gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
# Init some configs # Launch other tp ranks
self.request_dependency_delay = global_config.request_dependency_delay if tp_size_local > 1:
tp_rank_range = range(1, tp_size_local)
self.tp_procs = launch_tp_servers(
gpu_ids, tp_rank_range, server_args,
port_args.model_port_args[0], model_overide_args)
# Launch tp rank 0
self.tp_server = ModelTpServer(
gpu_ids[0],
0,
server_args,
port_args.model_port_args[0],
model_overide_args,
)
self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group
async def loop_for_forward(self): def loop_for_forward(self):
while True: while True:
next_step_input = list(self.recv_reqs) recv_reqs = self.recv_requests()
self.recv_reqs = []
out_pyobjs = await self.model_client.step(next_step_input)
for obj in out_pyobjs: if self.server_args.tp_size > 1:
self.send_to_detokenizer.send_pyobj(obj) broadcast_recv_input(recv_reqs, 0, self.tp_cpu_group)
# async sleep for receiving the subsequent request and avoiding cache miss out_pyobjs = self.tp_server.exposed_step(recv_reqs)
slept = False
if len(out_pyobjs) != 0:
has_finished = any(
[obj.finished_reason is not None for obj in out_pyobjs]
)
if has_finished:
if self.request_dependency_delay > 0:
slept = True
await asyncio.sleep(self.request_dependency_delay)
if not slept: for obj in out_pyobjs:
await asyncio.sleep(global_config.wait_for_new_request_delay) self.send_to_detokenizer.send_pyobj(obj)
async def loop_for_recv_requests(self): def recv_requests(self):
recv_reqs = []
while True: while True:
recv_req = await self.recv_from_tokenizer.recv_pyobj() try:
self.recv_reqs.append(recv_req) recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
recv_reqs.append(recv_req)
except zmq.ZMQError:
break
return recv_reqs
def start_controller_process( def start_controller_process(
server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args: dict
): ):
logging.basicConfig( logging.basicConfig(
level=getattr(logging, server_args.log_level.upper()), level=getattr(logging, server_args.log_level.upper()),
...@@ -76,27 +160,18 @@ def start_controller_process( ...@@ -76,27 +160,18 @@ def start_controller_process(
) )
try: try:
tp_size_local = server_args.tp_size // server_args.nnodes controller = ControllerSingle(server_args, port_args, model_overide_args)
model_client = ModelTpClient(
[i for _ in range(server_args.nnodes) for i in range(tp_size_local)],
server_args,
port_args.model_port_args[0],
model_overide_args,
)
controller = ControllerSingle(model_client, port_args)
except Exception: except Exception:
pipe_writer.send(get_exception_traceback()) pipe_writer.send(get_exception_traceback())
raise raise
pipe_writer.send("init ok") pipe_writer.send("init ok")
loop = asyncio.new_event_loop()
loop.set_default_executor(ThreadPoolExecutor(max_workers=256))
asyncio.set_event_loop(loop)
loop.create_task(controller.loop_for_recv_requests())
try: try:
loop.run_until_complete(controller.loop_for_forward()) controller.loop_for_forward()
except Exception: except Exception:
logger.error("Exception in ControllerSingle:\n" + get_exception_traceback()) logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
finally: finally:
for t in controller.tp_procs:
os.kill(t.pid, 9)
kill_parent_process() kill_parent_process()
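
The key idea in the rewritten controller is that tensor parallel workers no longer talk to the controller over RPC: rank 0 steps its worker in-process and pushes each batch of requests to the other ranks through a CPU process group, which is what `broadcast_recv_input` above does. Below is a self-contained sketch of that pattern, assuming a plain `gloo` group created with `torchrun` (the commit instead reuses the model runner's TP CPU group):

```
# Illustrative sketch, not SGLang code. Run with: torchrun --nproc-per-node=2 demo.py
import pickle

import torch
import torch.distributed as dist


def broadcast_pyobj(data, rank, group):
    """Send an arbitrary picklable object from rank 0 to every rank in `group`."""
    if rank == 0:
        payload = pickle.dumps(data)
        size = torch.tensor([len(payload)], dtype=torch.long)
        buf = torch.tensor(list(payload), dtype=torch.uint8)
        dist.broadcast(size, src=0, group=group)
        dist.broadcast(buf, src=0, group=group)
        return data
    size = torch.tensor([0], dtype=torch.long)
    dist.broadcast(size, src=0, group=group)
    buf = torch.empty(size.item(), dtype=torch.uint8)
    dist.broadcast(buf, src=0, group=group)
    return pickle.loads(bytes(buf.tolist()))


if __name__ == "__main__":
    dist.init_process_group(backend="gloo")  # CPU-only group, standing in for tp_group.cpu_group
    rank = dist.get_rank()
    reqs = ["req-1", "req-2"] if rank == 0 else None
    print(rank, broadcast_pyobj(reqs, rank, dist.group.WORLD))
```

Pushing requests this way keeps scheduling decisions on rank 0 only, while every rank still sees the same batches, without a round trip through an RPC layer per step.
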
@@ -11,7 +11,7 @@ import torch
 import torch.nn as nn
 from vllm.config import DeviceConfig, LoadConfig
 from vllm.config import ModelConfig as VllmModelConfig
-from vllm.distributed import init_distributed_environment, initialize_model_parallel
+from vllm.distributed import init_distributed_environment, initialize_model_parallel, get_tp_group
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
@@ -75,6 +75,7 @@ class ModelRunner:
             distributed_init_method=nccl_init_method,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+        self.tp_group = get_tp_group()
         total_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
         )
...
@@ -53,7 +53,7 @@ class ModelTpServer:
         tp_rank: int,
         server_args: ServerArgs,
         model_port_args: ModelPortArgs,
-        model_overide_args,
+        model_overide_args: dict,
     ):
         server_args, model_port_args = obtain(server_args), obtain(model_port_args)
         suppress_other_loggers()
@@ -178,7 +178,7 @@ class ModelTpServer:
         self.new_token_ratio_recovery = global_config.new_token_ratio_recovery

     def exposed_step(self, recv_reqs):
-        if self.tp_size * self.dp_size != 1:
+        if not isinstance(recv_reqs, list):
             recv_reqs = obtain(recv_reqs)

         try:
@@ -206,11 +206,11 @@
     @torch.inference_mode()
     def forward_step(self):
-        new_batch = self.get_new_fill_batch()
+        new_batch = self.get_new_prefill_batch()

         if new_batch is not None:
-            # Run a new fill batch
-            self.forward_fill_batch(new_batch)
+            # Run a new prefill batch
+            self.forward_prefill_batch(new_batch)
             self.cache_filled_batch(new_batch)

             if not new_batch.is_empty():
@@ -219,7 +219,7 @@
                 else:
                     self.running_batch.merge(new_batch)
         else:
-            # Run decode batch
+            # Run a decode batch
             if self.running_batch is not None:
                 # Run a few decode batches continuously for reducing overhead
                 for _ in range(global_config.num_continue_decode_steps):
@@ -312,7 +312,7 @@
         )
         self.forward_queue.append(req)

-    def get_new_fill_batch(self) -> Optional[Batch]:
+    def get_new_prefill_batch(self) -> Optional[Batch]:
         running_bs = (
             len(self.running_batch.reqs) if self.running_batch is not None else 0
         )
@@ -436,7 +436,7 @@
         self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
         return new_batch

-    def forward_fill_batch(self, batch: Batch):
+    def forward_prefill_batch(self, batch: Batch):
         # Build batch tensors
         batch.prepare_for_extend(
             self.model_config.vocab_size, self.int_token_logit_bias
@@ -746,8 +746,8 @@ class ModelTpClient:
         # Init model
         assert len(gpu_ids) == 1
         self.model_server = ModelTpService().exposed_ModelTpServer(
-            0,
             gpu_ids[0],
+            0,
             server_args,
             model_port_args,
             model_overide_args,
...
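
One line worth a note above is the new guard in `exposed_step`: when the controller calls the worker in-process, the requests arrive as a plain Python list, whereas a call coming through rpyc delivers a remote reference that has to be copied locally first. A hedged illustration of that distinction (assuming `obtain` is rpyc's `rpyc.utils.classic.obtain`, as already applied to `server_args` above):

```
from rpyc.utils.classic import obtain

def normalize_requests(recv_reqs):
    # In-process callers pass a real list; remote rpyc callers pass a netref
    # proxy, which obtain() copies by value into this process.
    if not isinstance(recv_reqs, list):
        recv_reqs = obtain(recv_reqs)
    return recv_reqs
```
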
@@ -33,9 +33,9 @@ from sglang.srt.managers.controller.manager_multi import (
     start_controller_process as start_controller_process_multi,
 )
 from sglang.srt.managers.controller.manager_single import (
+    launch_tp_servers,
     start_controller_process as start_controller_process_single,
 )
-from sglang.srt.managers.controller.tp_worker import ModelTpService
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
 from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -53,7 +53,6 @@ from sglang.srt.utils import (
     enable_show_time_cost,
     receive_addrs,
     send_addrs_to_rank_0,
-    start_rpyc_service_process,
 )
 from sglang.utils import get_exception_traceback
@@ -192,21 +191,17 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
         model_port_args=model_port_args,
     )

-    # TODO multi-node dp is not supported
-    assert not (server_args.dp_size > 1 and server_args.node_rank is not None)
+    # Handle multi-node tp
     if server_args.nnodes > 1:
+        assert server_args.dp_size == 1, "Multi-node dp is not supported."
+
         if server_args.node_rank != 0:
-            send_addrs_to_rank_0(model_port_args[0], server_args)
-        else:
-            receive_addrs(model_port_args[0], server_args)
-        for i in range(tp_size_local):
-            start_rpyc_service_process(
-                ModelTpService, model_port_args[0].model_tp_ports[i]
-            )
-        if server_args.node_rank != 0:
-            logger.info(
-                f"[node_rank={server_args.node_rank}]: Listen for connections..."
-            )
+            tp_size_local = server_args.tp_size // server_args.nnodes
+            gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
+            tp_rank_range = list(range(server_args.node_rank * tp_size_local,
+                                       (server_args.node_rank + 1) * tp_size_local))
+            procs = launch_tp_servers(gpu_ids, tp_rank_range, server_args,
+                                      port_args.model_port_args[0], model_overide_args)
             while True:
                 pass
...
@@ -67,10 +67,12 @@ class ServerArgs:
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
         if self.mem_fraction_static is None:
-            if self.tp_size >= 8:
+            if self.tp_size >= 16:
+                self.mem_fraction_static = 0.74
+            elif self.tp_size >= 8:
                 self.mem_fraction_static = 0.78
             elif self.tp_size >= 4:
-                self.mem_fraction_static = 0.80
+                self.mem_fraction_static = 0.82
             elif self.tp_size >= 2:
                 self.mem_fraction_static = 0.85
             else:
...
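
For context on what these defaults trade off: `--mem-fraction-static` (also shown in the README hunk at the top) bounds the fraction of each GPU's memory reserved for static allocations, i.e. model weights plus the KV cache pool. A rough, illustrative helper mirroring the tiers above (not part of `ServerArgs`; the single-GPU default falls outside the hunk shown):

```
def static_memory_budget_gb(total_gpu_mem_gb: float, tp_size: int) -> float:
    """Approximate per-GPU memory reserved for weights + KV cache pool."""
    if tp_size >= 16:
        frac = 0.74
    elif tp_size >= 8:
        frac = 0.78
    elif tp_size >= 4:
        frac = 0.82
    elif tp_size >= 2:
        frac = 0.85
    else:
        raise ValueError("the tp_size == 1 default is not shown in this diff")
    return total_gpu_mem_gb * frac

print(static_memory_budget_gb(80, 8))  # 62.4 GB reserved on an 80 GB GPU at tp_size=8
```
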