Unverified Commit 04868543 authored by Lianmin Zheng, committed by GitHub

Improve process creation (#1534)

parent fd9ad817
......@@ -249,11 +249,12 @@ jobs:
python3 test_mla.py
python3 test_mla_fp8.py
- name: Evaluate Data Parallelism Accuracy (TP=2)
timeout-minutes: 10
run: |
cd test/srt
python3 test_data_parallelism.py
# Temporarily disabled
#- name: Evaluate Data Parallelism Accuracy (TP=2)
# timeout-minutes: 10
# run: |
# cd test/srt
# python3 test_data_parallelism.py
finish:
needs: [
......
......@@ -228,7 +228,7 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
- To enable fp8 weight quantization, add `--quantization fp8` on a fp16 checkpoint or directly load a fp8 checkpoint without specifying any arguments.
- To enable fp8 kv cache quantization, add `--kv-cache-dtype fp8_e5m2`.
- If the model does not have a chat template in the Hugging Face tokenizer, you can specify a [custom chat template](docs/en/custom_chat_template.md).
- To run tensor parallelism on multiple nodes, add `--nnodes 2`. If you have two nodes with two GPUs on each node and want to run TP=4, let `sgl-dev-0` be the hostname of the first node and `50000` be an available port.
- To run tensor parallelism on multiple nodes, add `--nnodes 2`. For example, if you have two nodes with two GPUs each and want to run TP=4, let `sgl-dev-0` be the hostname of the first node and `50000` be an available port; you can then use the following commands. If you encounter a deadlock, try adding `--disable-cuda-graph`.
```
# Node 0
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --nccl-init sgl-dev-0:50000 --nnodes 2 --node-rank 0
......
......@@ -84,7 +84,7 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
- To enable fp8 weight quantization, add `--quantization fp8` on a fp16 checkpoint or directly load a fp8 checkpoint without specifying any arguments.
- To enable fp8 kv cache quantization, add `--kv-cache-dtype fp8_e5m2`.
- If the model does not have a chat template in the Hugging Face tokenizer, you can specify a [custom chat template](https://sglang.readthedocs.io/en/latest/custom_chat_template.html).
- To run tensor parallelism on multiple nodes, add `--nnodes 2`. If you have two nodes with two GPUs on each node and want to run TP=4, let `sgl-dev-0` be the hostname of the first node and `50000` be an available port.
- To run tensor parallelism on multiple nodes, add `--nnodes 2`. For example, if you have two nodes with two GPUs each and want to run TP=4, let `sgl-dev-0` be the hostname of the first node and `50000` be an available port; you can then use the following commands. If you encounter a deadlock, try adding `--disable-cuda-graph`.
```
# Node 0
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --nccl-init sgl-dev-0:50000 --nnodes 2 --node-rank 0
......
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""
A controller that manages multiple data parallel workers.
Each data parallel worker can manage multiple tensor parallel workers.
"""
import dataclasses
import logging
import multiprocessing
from enum import Enum, auto
import numpy as np
import zmq
from sglang.srt.managers.controller_single import (
start_controller_process as start_controller_process_single,
)
from sglang.srt.managers.io_struct import (
AbortReq,
FlushCacheReq,
TokenizedGenerateReqInput,
)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import configure_logger, kill_parent_process
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
class LoadBalanceMethod(Enum):
"""Load balance method."""
ROUND_ROBIN = auto()
SHORTEST_QUEUE = auto()
@classmethod
def from_str(cls, method: str):
method = method.upper()
try:
return cls[method]
except KeyError as exc:
raise ValueError(f"Invalid load balance method: {method}") from exc
@dataclasses.dataclass
class WorkerHandle:
"""Store the handle of a data parallel worker."""
proc: multiprocessing.Process
queue: multiprocessing.Queue
class ControllerMulti:
"""A controller that manages multiple data parallel workers."""
def __init__(
self,
server_args: ServerArgs,
port_args: PortArgs,
):
# Parse args
self.server_args = server_args
self.port_args = port_args
self.load_balance_method = LoadBalanceMethod.from_str(
server_args.load_balance_method
)
# Init communication
context = zmq.Context()
self.recv_from_tokenizer = context.socket(zmq.PULL)
self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.controller_port}")
# Dispatch method
self.round_robin_counter = 0
dispatch_lookup = {
LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
}
self.dispatching = dispatch_lookup[self.load_balance_method]
# Start data parallel workers
self.workers = []
for i in range(server_args.dp_size):
self.start_dp_worker(i)
def start_dp_worker(self, dp_worker_id: int):
tp_size = self.server_args.tp_size
pipe_controller_reader, pipe_controller_writer = multiprocessing.Pipe(
duplex=False
)
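# Each data parallel worker owns a contiguous block of GPUs, e.g. with tp_size=2, worker 1 uses GPUs [2, 3].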
gpu_ids = list(range(dp_worker_id * tp_size, (dp_worker_id + 1) * tp_size))
queue = multiprocessing.Queue()
proc = multiprocessing.Process(
target=start_controller_process_single,
args=(
self.server_args,
self.port_args,
pipe_controller_writer,
True,
gpu_ids,
dp_worker_id,
queue,
),
)
proc.start()
controller_init_state = pipe_controller_reader.recv()
if controller_init_state != "init ok":
raise RuntimeError(
f"Initialization failed. controller_init_state: {controller_init_state}"
)
self.workers.append(
WorkerHandle(
proc=proc,
queue=queue,
)
)
def round_robin_scheduler(self, input_requests):
for r in input_requests:
self.workers[self.round_robin_counter].queue.put(r)
self.round_robin_counter = (self.round_robin_counter + 1) % len(
self.workers
)
def shortest_queue_scheduler(self, input_requests):
for r in input_requests:
queue_sizes = [worker.queue.qsize() for worker in self.workers]
wid = np.argmin(queue_sizes)
self.workers[wid].queue.put(r)
def loop_for_forward(self):
while True:
recv_reqs = self.recv_requests()
self.dispatching(recv_reqs)
def recv_requests(self):
recv_reqs = []
while True:
try:
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
except zmq.ZMQError:
break
if isinstance(recv_req, FlushCacheReq):
# TODO(lsyin): apply more specific flushCacheReq
for worker in self.workers:
worker.queue.put(recv_req)
elif isinstance(recv_req, AbortReq):
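# If the target request is still in this batch, replace it with the abort request in place.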
in_queue = False
for i, req in enumerate(recv_reqs):
if req.rid == recv_req.rid:
recv_reqs[i] = recv_req
in_queue = True
break
if not in_queue:
# Send abort req to all TP groups
for worker in self.workers:
worker.queue.put(recv_req)
elif isinstance(recv_req, TokenizedGenerateReqInput):
recv_reqs.append(recv_req)
else:
logger.error(f"Invalid object: {recv_req}")
return recv_reqs
def start_controller_process(
server_args: ServerArgs,
port_args: PortArgs,
pipe_writer,
):
"""Start a controller process."""
configure_logger(server_args)
try:
controller = ControllerMulti(server_args, port_args)
except Exception:
pipe_writer.send(get_exception_traceback())
raise
pipe_writer.send("init ok")
try:
controller.loop_for_forward()
except Exception:
logger.error("Exception in ControllerMulti:\n" + get_exception_traceback())
finally:
kill_parent_process()
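For reference, the dispatch policies above reduce to a few lines. The following is a minimal, self-contained sketch (not SGLang code) where plain lists stand in for the per-worker multiprocessing queues:

```python
# Minimal sketch of the two load-balancing policies, using plain lists as
# stand-ins for the per-worker multiprocessing queues.
import numpy as np


def round_robin_dispatch(worker_queues, requests, counter=0):
    # Cycle through workers in order; the counter persists across calls,
    # mirroring ControllerMulti.round_robin_counter.
    for r in requests:
        worker_queues[counter].append(r)
        counter = (counter + 1) % len(worker_queues)
    return counter


def shortest_queue_dispatch(worker_queues, requests):
    # Send each request to the worker with the fewest pending requests.
    for r in requests:
        sizes = [len(q) for q in worker_queues]
        worker_queues[int(np.argmin(sizes))].append(r)


queues = [[], []]
counter = round_robin_dispatch(queues, ["a", "b", "c"])
print([len(q) for q in queues])  # [2, 1]
shortest_queue_dispatch(queues, ["d"])
print([len(q) for q in queues])  # [2, 2]
```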
......@@ -16,6 +16,8 @@ limitations under the License.
"""DetokenizerManager is a process that detokenizes the token ids."""
import dataclasses
import logging
from collections import OrderedDict
from typing import List
import zmq
......@@ -29,8 +31,11 @@ from sglang.srt.managers.io_struct import (
)
from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import configure_logger, kill_parent_process
from sglang.utils import find_printable_text, get_exception_traceback
logger = logging.getLogger(__name__)
@dataclasses.dataclass
class DecodeStatus:
......@@ -53,8 +58,8 @@ class DetokenizerManager:
):
# Init inter-process communication
context = zmq.Context(2)
self.recv_from_router = context.socket(zmq.PULL)
self.recv_from_router.bind(f"tcp://127.0.0.1:{port_args.detokenizer_port}")
self.recv_from_scheduler = context.socket(zmq.PULL)
self.recv_from_scheduler.bind(f"tcp://127.0.0.1:{port_args.detokenizer_port}")
self.send_to_tokenizer = context.socket(zmq.PUSH)
self.send_to_tokenizer.connect(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
......@@ -68,13 +73,13 @@ class DetokenizerManager:
trust_remote_code=server_args.trust_remote_code,
)
self.decode_status = {}
self.decode_status = LimitedCapacityDict()
def handle_loop(self):
def event_loop(self):
"""The event loop that handles requests"""
while True:
recv_obj = self.recv_from_router.recv_pyobj()
recv_obj = self.recv_from_scheduler.recv_pyobj()
if isinstance(recv_obj, BatchEmbeddingOut):
# If it is embedding model, no detokenization is needed.
......@@ -165,15 +170,29 @@ class DetokenizerManager:
)
def start_detokenizer_process(
class LimitedCapacityDict(OrderedDict):
def __init__(self, capacity=1 << 15, *args, **kwargs):
super().__init__(*args, **kwargs)
self.capacity = capacity
def __setitem__(self, key, value):
if len(self) >= self.capacity:
# Remove the oldest element (first item in the dict)
self.popitem(last=False)
# Set the new item
super().__setitem__(key, value)
def run_detokenizer_process(
server_args: ServerArgs,
port_args: PortArgs,
pipe_writer,
):
configure_logger(server_args)
try:
manager = DetokenizerManager(server_args, port_args)
manager.event_loop()
except Exception:
pipe_writer.send(get_exception_traceback())
raise
pipe_writer.send("init ok")
manager.handle_loop()
msg = get_exception_traceback()
logger.error(msg)
kill_parent_process()
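A small usage sketch of the `LimitedCapacityDict` introduced above: once the capacity is reached, inserting a new key evicts the oldest entry. The tiny capacity below is only for illustration; the detokenizer uses the default of `1 << 15`.

```python
# Usage sketch for LimitedCapacityDict; capacity=2 is only for illustration.
d = LimitedCapacityDict(capacity=2)
d["a"] = 1
d["b"] = 2
d["c"] = 3             # the dict is full, so the oldest key "a" is evicted
print(list(d.keys()))  # ['b', 'c']
```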
......@@ -13,96 +13,75 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
"""A controller that manages a group of tensor parallel workers."""
"""A scheduler that manages a tensor parallel GPU worker."""
import logging
import multiprocessing
from typing import List
import zmq
from sglang.srt.managers.tp_worker import (
ModelTpServer,
broadcast_recv_input,
launch_tp_servers,
)
from sglang.srt.managers.tp_worker import ModelTpServer
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import configure_logger, kill_parent_process
from sglang.srt.utils import broadcast_pyobj, configure_logger, kill_parent_process
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
class ControllerSingle:
"""A controller that manages a group of tensor parallel workers."""
class Scheduler:
"""A scheduler that manages a tensor parallel GPU worker."""
def __init__(
self,
server_args: ServerArgs,
port_args: PortArgs,
gpu_ids: List[int],
is_data_parallel_worker: bool,
dp_worker_id: int,
mp_queue: multiprocessing.Queue,
gpu_id: int,
tp_rank: int,
):
# Parse args
self.tp_rank = tp_rank
self.tp_size = server_args.tp_size
self.is_dp_worker = is_data_parallel_worker
self.dp_worker_id = dp_worker_id
self.mp_queue = mp_queue
# Init inter-process communication
context = zmq.Context(2)
if not self.is_dp_worker:
if self.tp_rank == 0:
self.recv_from_tokenizer = context.socket(zmq.PULL)
self.recv_from_tokenizer.bind(
f"tcp://127.0.0.1:{port_args.controller_port}"
)
self.send_to_detokenizer = context.socket(zmq.PUSH)
self.send_to_detokenizer.connect(
f"tcp://127.0.0.1:{port_args.detokenizer_port}"
)
self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.scheduler_port}")
# Launch other tp ranks
tp_size_local = server_args.tp_size // server_args.nnodes
self.tp_procs = []
if tp_size_local > 1:
tp_rank_range = range(1, tp_size_local)
self.tp_procs = launch_tp_servers(
gpu_ids,
tp_rank_range,
server_args,
port_args.nccl_ports[dp_worker_id],
self.send_to_detokenizer = context.socket(zmq.PUSH)
self.send_to_detokenizer.connect(
f"tcp://127.0.0.1:{port_args.detokenizer_port}"
)
else:
self.send_to_detokenizer = None
# Launch tp rank 0
# Launch a tp server
self.tp_server = ModelTpServer(
gpu_ids[0],
0,
server_args,
port_args.nccl_ports[dp_worker_id],
gpu_id=gpu_id,
tp_rank=tp_rank,
server_args=server_args,
nccl_port=port_args.nccl_ports[0],
)
self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group
def loop_for_forward(self):
def event_loop(self):
while True:
if not self.is_dp_worker:
if self.tp_rank == 0:
recv_reqs = self.recv_requests_from_zmq()
else:
recv_reqs = self.recv_requests_from_mp_queue()
if self.tp_size > 1:
broadcast_recv_input(recv_reqs, 0, self.tp_cpu_group)
recv_reqs = None
recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
out_pyobjs = self.tp_server.exposed_step(recv_reqs)
for obj in out_pyobjs:
self.send_to_detokenizer.send_pyobj(obj)
if self.tp_rank == 0:
for obj in out_pyobjs:
self.send_to_detokenizer.send_pyobj(obj)
def recv_requests_from_zmq(self):
recv_reqs = []
while True:
try:
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
......@@ -112,53 +91,21 @@ class ControllerSingle:
return recv_reqs
def recv_requests_from_mp_queue(self):
recv_reqs = []
while not self.mp_queue.empty():
recv_reqs.append(self.mp_queue.get())
return recv_reqs
def start_controller_process(
def run_scheduler_process(
server_args: ServerArgs,
port_args: PortArgs,
gpu_id: int,
tp_rank: int,
pipe_writer: multiprocessing.connection.Connection,
is_data_parallel_worker: bool = False,
gpu_ids: List[int] = None,
dp_worker_id: int = None,
queue: multiprocessing.connection.Connection = None,
):
"""Start a controller process."""
if is_data_parallel_worker:
logger_prefix = f" DP{dp_worker_id} TP0"
else:
logger_prefix = " TP0"
configure_logger(server_args, prefix=logger_prefix)
if not is_data_parallel_worker:
tp_size_local = server_args.tp_size // server_args.nnodes
gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
dp_worker_id = 0
queue = None
try:
controller = ControllerSingle(
server_args,
port_args,
gpu_ids,
is_data_parallel_worker,
dp_worker_id,
queue,
)
except Exception:
pipe_writer.send(get_exception_traceback())
raise
pipe_writer.send("init ok")
configure_logger(server_args, prefix=f" TP{tp_rank}")
try:
controller.loop_for_forward()
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank)
pipe_writer.send("ready")
scheduler.event_loop()
except Exception:
logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
finally:
msg = get_exception_traceback()
logger.error(msg)
kill_parent_process()
......@@ -88,8 +88,8 @@ class TokenizerManager:
self.recv_from_detokenizer = context.socket(zmq.PULL)
self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
self.send_to_controller = context.socket(zmq.PUSH)
self.send_to_controller.connect(f"tcp://127.0.0.1:{port_args.controller_port}")
self.send_to_scheduler = context.socket(zmq.PUSH)
self.send_to_scheduler.connect(f"tcp://127.0.0.1:{port_args.scheduler_port}")
# Read model args
self.model_path = server_args.model_path
......@@ -285,7 +285,7 @@ class TokenizerManager:
input_ids,
sampling_params,
)
self.send_to_controller.send_pyobj(tokenized_obj)
self.send_to_scheduler.send_pyobj(tokenized_obj)
# Recv results
event = asyncio.Event()
......@@ -397,7 +397,7 @@ class TokenizerManager:
input_ids,
sampling_params,
)
self.send_to_controller.send_pyobj(tokenized_obj)
self.send_to_scheduler.send_pyobj(tokenized_obj)
event = asyncio.Event()
state = ReqState([], False, event)
......@@ -530,14 +530,14 @@ class TokenizerManager:
def flush_cache(self):
req = FlushCacheReq()
self.send_to_controller.send_pyobj(req)
self.send_to_scheduler.send_pyobj(req)
def abort_request(self, rid: str):
if rid not in self.rid_to_state:
return
del self.rid_to_state[rid]
req = AbortReq(rid)
self.send_to_controller.send_pyobj(req)
self.send_to_scheduler.send_pyobj(req)
async def update_weights(
self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
......@@ -554,7 +554,7 @@ class TokenizerManager:
# wait for the previous generation requests to finish
while len(self.rid_to_state) > 0:
await asyncio.sleep(0)
self.send_to_controller.send_pyobj(obj)
self.send_to_scheduler.send_pyobj(obj)
self.model_update_result = asyncio.Future()
result = await self.model_update_result
if result.success:
......@@ -665,6 +665,7 @@ class TokenizerManager:
def detokenize_logprob_tokens(
self, token_logprobs: List[Tuple[float, int]], decode_to_text: bool
):
# TODO(lianmin): This should run on DetokenizerManager
if not decode_to_text:
return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
......
......@@ -17,16 +17,12 @@ limitations under the License.
import json
import logging
import multiprocessing
import os
import pickle
import time
import warnings
from typing import Any, List, Optional, Union
from typing import List, Optional, Union
import torch
import torch.distributed
import torch.distributed as dist
from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
......@@ -58,7 +54,7 @@ from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
configure_logger,
broadcast_pyobj,
is_multimodal_model,
set_random_seed,
suppress_other_loggers,
......@@ -140,7 +136,7 @@ class ModelTpServer:
)
# Sync random seed across TP workers
server_args.random_seed = broadcast_recv_input(
server_args.random_seed = broadcast_pyobj(
[server_args.random_seed],
self.tp_rank,
self.model_runner.tp_group.cpu_group,
......@@ -935,82 +931,3 @@ class ModelTpServer:
else:
logger.error(message)
return success, message
def run_tp_server(
gpu_id: int,
tp_rank: int,
server_args: ServerArgs,
nccl_port: int,
):
"""Run a tensor parallel model server."""
configure_logger(server_args, prefix=f" TP{tp_rank}")
try:
model_server = ModelTpServer(
gpu_id,
tp_rank,
server_args,
nccl_port,
)
tp_cpu_group = model_server.model_runner.tp_group.cpu_group
while True:
recv_reqs = broadcast_recv_input(None, tp_rank, tp_cpu_group)
model_server.exposed_step(recv_reqs)
except Exception:
logger.error("Exception in run_tp_server:\n" + get_exception_traceback())
raise
def launch_tp_servers(
gpu_ids: List[int],
tp_rank_range: List[int],
server_args: ServerArgs,
nccl_port: int,
):
"""Launch multiple tensor parallel servers."""
procs = []
for i in tp_rank_range:
proc = multiprocessing.Process(
target=run_tp_server,
args=(gpu_ids[i], i, server_args, nccl_port),
)
proc.start()
procs.append(proc)
return procs
def broadcast_recv_input(
data: Any, rank: int, dist_group: torch.distributed.ProcessGroup
):
"""Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
if rank == 0:
if len(data) == 0:
tensor_size = torch.tensor([0], dtype=torch.long)
dist.broadcast(tensor_size, src=0, group=dist_group)
else:
serialized_data = pickle.dumps(data)
size = len(serialized_data)
tensor_data = torch.ByteTensor(list(serialized_data))
tensor_size = torch.tensor([size], dtype=torch.long)
dist.broadcast(tensor_size, src=0, group=dist_group)
dist.broadcast(tensor_data, src=0, group=dist_group)
return data
else:
tensor_size = torch.tensor([0], dtype=torch.long)
dist.broadcast(tensor_size, src=0, group=dist_group)
size = tensor_size.item()
if size == 0:
return []
tensor_data = torch.empty(size, dtype=torch.uint8)
dist.broadcast(tensor_data, src=0, group=dist_group)
serialized_data = bytes(tensor_data.tolist())
data = pickle.loads(serialized_data)
return data
......@@ -135,8 +135,8 @@ class ModelRunner:
if not self.server_args.enable_p2p_check:
monkey_patch_vllm_p2p_access_check(self.gpu_id)
if self.server_args.nccl_init_addr:
nccl_init_method = f"tcp://{self.server_args.nccl_init_addr}"
if self.server_args.dist_init_addr:
nccl_init_method = f"tcp://{self.server_args.dist_init_addr}"
else:
nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
......
......@@ -43,20 +43,14 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.constrained import disable_cache
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.controller_multi import (
start_controller_process as start_controller_process_multi,
)
from sglang.srt.managers.controller_single import launch_tp_servers
from sglang.srt.managers.controller_single import (
start_controller_process as start_controller_process_single,
)
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
from sglang.srt.managers.io_struct import (
EmbeddingReqInput,
GenerateReqInput,
RewardReqInput,
UpdateWeightReqInput,
)
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.openai_api.adapter import (
load_chat_template_for_openai_api,
......@@ -82,8 +76,7 @@ from sglang.srt.utils import (
is_hip,
kill_child_process,
maybe_set_triton_cache_manager,
prepare_model,
prepare_tokenizer,
prepare_model_and_tokenizer,
set_ulimit,
)
from sglang.utils import get_exception_traceback
......@@ -303,8 +296,8 @@ def launch_server(
"""Launch an HTTP server."""
global tokenizer_manager
# Configure global environment
configure_logger(server_args)
server_args.check_server_args()
_set_envs_and_config(server_args)
......@@ -317,81 +310,60 @@ def launch_server(
ports = server_args.additional_ports
port_args = PortArgs(
tokenizer_port=ports[0],
controller_port=ports[1],
scheduler_port=ports[1],
detokenizer_port=ports[2],
nccl_ports=ports[3:],
)
logger.info(f"{server_args=}")
# Use model from www.modelscope.cn, first download the model.
server_args.model_path = prepare_model(server_args.model_path)
server_args.tokenizer_path = prepare_tokenizer(server_args.tokenizer_path)
# Launch processes for multi-node tensor parallelism
if server_args.nnodes > 1 and server_args.node_rank != 0:
tp_size_local = server_args.tp_size // server_args.nnodes
gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
tp_rank_range = list(
range(
server_args.node_rank * tp_size_local,
(server_args.node_rank + 1) * tp_size_local,
)
)
procs = launch_tp_servers(
gpu_ids,
tp_rank_range,
server_args,
ports[3],
)
try:
for p in procs:
p.join()
finally:
kill_child_process(os.getpid(), including_parent=False)
return
# Launch processes
pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
# If using model from www.modelscope.cn, first download the model.
server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(
server_args.model_path, server_args.tokenizer_path
)
if server_args.dp_size == 1:
start_controller_process = start_controller_process_single
else:
start_controller_process = start_controller_process_multi
proc_controller = mp.Process(
target=start_controller_process,
args=(server_args, port_args, pipe_controller_writer),
# Launch tensor parallel scheduler processes
scheduler_procs = []
scheduler_pipe_readers = []
tp_size_per_node = server_args.tp_size // server_args.nnodes
tp_rank_range = range(
tp_size_per_node * server_args.node_rank,
tp_size_per_node * (server_args.node_rank + 1),
)
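# e.g., with tp_size=4 and nnodes=2, node 1 runs TP ranks [2, 3] on its local GPUs [0, 1]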
proc_controller.start()
for tp_rank in tp_rank_range:
reader, writer = mp.Pipe(duplex=False)
gpu_id = tp_rank % tp_size_per_node
proc = mp.Process(
target=run_scheduler_process,
args=(server_args, port_args, gpu_id, tp_rank, writer),
)
proc.start()
scheduler_procs.append(proc)
scheduler_pipe_readers.append(reader)
if server_args.node_rank >= 1:
# Nodes other than rank 0 do not need to run the tokenizer or detokenizer,
# so they just wait here.
while True:
pass
pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
proc_detoken = mp.Process(
target=start_detokenizer_process,
# Launch detokenizer process
detoken_proc = mp.Process(
target=run_detokenizer_process,
args=(
server_args,
port_args,
pipe_detoken_writer,
),
)
proc_detoken.start()
detoken_proc.start()
# Launch tokenizer process
tokenizer_manager = TokenizerManager(server_args, port_args)
if server_args.chat_template:
load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
# Wait for the model to finish loading
controller_init_state = pipe_controller_reader.recv()
detoken_init_state = pipe_detoken_reader.recv()
if controller_init_state != "init ok" or detoken_init_state != "init ok":
proc_controller.kill()
proc_detoken.kill()
raise RuntimeError(
"Initialization failed. "
f"controller_init_state: {controller_init_state}, "
f"detoken_init_state: {detoken_init_state}"
)
assert proc_controller.is_alive() and proc_detoken.is_alive()
# Wait for model to finish loading
for i in range(len(scheduler_pipe_readers)):
scheduler_pipe_readers[i].recv()
# Add api key authorization
if server_args.api_key:
......@@ -404,7 +376,7 @@ def launch_server(
t.start()
try:
# Listen for requests
# Listen for HTTP requests
uvicorn.run(
app,
host=server_args.host,
......@@ -451,9 +423,7 @@ def _set_envs_and_config(server_args: ServerArgs):
"at https://docs.flashinfer.ai/installation.html.",
)
if is_hip():
# to figure out a better method of not using fork later
mp.set_start_method("spawn", force=True)
mp.set_start_method("spawn", force=True)
def _wait_and_warmup(server_args, pipe_finish_writer, pid):
......@@ -517,7 +487,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
logger.info("The server is fired up and ready to roll!")
if pipe_finish_writer is not None:
pipe_finish_writer.send("init ok")
pipe_finish_writer.send("ready")
class Runtime:
......@@ -564,7 +534,7 @@ class Runtime:
except EOFError:
init_state = ""
if init_state != "init ok":
if init_state != "ready":
self.shutdown()
raise RuntimeError(
"Initialization failed. Please see the error messages above."
......
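The launch code above follows a simple spawn-and-handshake pattern: each child process receives the write end of a pipe and reports back once initialization finishes, while the parent blocks on the read end. A condensed sketch of that pattern, with a hypothetical worker body standing in for `run_scheduler_process`:

```python
# Condensed sketch of the spawn-and-handshake pattern; the worker body is
# hypothetical and only stands in for the real run_scheduler_process.
import multiprocessing as mp


def worker(writer):
    # ... heavy initialization (e.g. model loading) would happen here ...
    writer.send("ready")  # unblock the parent once initialization is done
    # ... the real worker would now enter its event loop ...


if __name__ == "__main__":
    readers = []
    for _ in range(2):  # e.g. one process per local TP rank
        reader, writer = mp.Pipe(duplex=False)
        mp.Process(target=worker, args=(writer,)).start()
        readers.append(reader)

    for reader in readers:
        reader.recv()  # block until every child has reported in
```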
......@@ -78,9 +78,9 @@ class ServerArgs:
load_balance_method: str = "round_robin"
# Distributed args
nccl_init_addr: Optional[str] = None
dist_init_addr: Optional[str] = None
nnodes: int = 1
node_rank: Optional[int] = None
node_rank: int = 0
# Model override args in JSON
json_model_override_args: str = "{}"
......@@ -426,14 +426,17 @@ class ServerArgs:
# Multi-node distributed serving args
parser.add_argument(
"--nccl-init-addr",
"--dist-init-addr",
"--nccl-init-addr", # For backward compatbility. This will be removed in the future.
type=str,
help="The nccl init address of multi-node server.",
help="The host address for initializing distributed backend (e.g., `192.168.0.2:25000`).",
)
parser.add_argument(
"--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
)
parser.add_argument("--node-rank", type=int, help="The node rank.")
parser.add_argument(
"--node-rank", type=int, default=ServerArgs.node_rank, help="The node rank."
)
# Model override args
parser.add_argument(
......@@ -583,6 +586,11 @@ class ServerArgs:
and (self.lora_paths is None or self.disable_radix_cache)
), "compatibility of lora and cuda graph and radix attention is in progress"
assert self.dp_size == 1, (
"The support for data parallelism is temporarily disabled during refactor. "
"Please use sglang<=0.3.2 or wait for later updates."
)
def prepare_server_args(argv: List[str]) -> ServerArgs:
"""
......@@ -604,9 +612,13 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
@dataclasses.dataclass
class PortArgs:
# The port for tokenizer to receive inputs from detokenizer (zmq)
tokenizer_port: int
controller_port: int
# The port for scheduler to receive inputs from tokenizer (zmq)
scheduler_port: int
# The port for detokenizer to receive inputs from scheduler (zmq)
detokenizer_port: int
# The port for nccl initialization for multiple TP groups (torch.dist)
nccl_ports: List[int]
......
......@@ -16,13 +16,12 @@ limitations under the License.
"""Common utilities."""
import base64
import fcntl
import logging
import os
import pickle
import random
import resource
import socket
import struct
import time
from importlib.metadata import PackageNotFoundError, version
from io import BytesIO
......@@ -36,7 +35,6 @@ import torch.distributed as dist
from fastapi.responses import JSONResponse
from packaging import version as pkg_version
from torch import nn
from torch.nn.parameter import Parameter
from triton.runtime.cache import (
FileCacheManager,
default_cache_dir,
......@@ -539,89 +537,6 @@ class CustomCacheManager(FileCacheManager):
raise RuntimeError("Could not create or locate cache dir")
def get_ip_address(ifname):
"""
Get the IP address of a network interface.
:param ifname: Name of the network interface (e.g., 'eth0')
:return: IP address of the network interface
"""
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
ip_address = fcntl.ioctl(
s.fileno(),
0x8915, # SIOCGIFADDR
struct.pack("256s", bytes(ifname[:15], "utf-8")),
)[20:24]
return socket.inet_ntoa(ip_address)
def send_addrs_to_rank_0(model_port_args, server_args):
assert server_args.node_rank != 0 and server_args.dp_size == 1
ifname = os.environ.get(
"SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
)
ip_addr = get_ip_address(ifname)
num_tp_ports = server_args.tp_size // server_args.nnodes
model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
ip_addr = [int(x) for x in ip_addr.split(".")]
addrs_tensor = torch.tensor(
ip_addr + model_port_args.model_tp_ports, dtype=torch.int
)
init_method = f"tcp://{server_args.nccl_init_addr}"
dist.init_process_group(
backend="gloo",
init_method=init_method,
rank=server_args.node_rank,
world_size=server_args.nnodes,
)
dist.send(addrs_tensor, dst=0)
print(
f"Node {server_args.node_rank} sent: ip_address {ip_addr} and ports {model_port_args.model_tp_ports}"
)
dist.barrier()
dist.destroy_process_group()
def receive_addrs(model_port_args, server_args):
assert server_args.node_rank == 0 and server_args.dp_size == 1
ifname = os.environ.get(
"SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
)
ip_addr = get_ip_address(ifname)
num_tp_ports = server_args.tp_size // server_args.nnodes
model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
init_method = f"tcp://{server_args.nccl_init_addr}"
dist.init_process_group(
backend="gloo",
init_method=init_method,
rank=server_args.node_rank,
world_size=server_args.nnodes,
)
for src_rank in range(1, server_args.nnodes):
tensor = torch.zeros(4 + num_tp_ports, dtype=torch.int)
dist.recv(tensor, src=src_rank)
ip = ".".join([str(x) for x in tensor[:4].tolist()])
ports = tensor[4:].tolist()
model_port_args.model_tp_ips[
num_tp_ports * src_rank : num_tp_ports * (src_rank + 1)
] = [ip] * num_tp_ports
model_port_args.model_tp_ports[
num_tp_ports * src_rank : num_tp_ports * (src_rank + 1)
] = ports
print(f"Node 0 received from rank {src_rank}: {tensor.tolist()}")
dist.barrier()
dist.destroy_process_group()
def set_ulimit(target_soft_limit=65535):
resource_type = resource.RLIMIT_NOFILE
current_soft, current_hard = resource.getrlimit(resource_type)
......@@ -645,24 +560,16 @@ def add_api_key_middleware(app, api_key: str):
return await call_next(request)
def prepare_model(model_path: str):
def prepare_model_and_tokenizer(model_path: str, tokenizer_path: str):
if "SGLANG_USE_MODELSCOPE" in os.environ:
if not os.path.exists(model_path):
from modelscope import snapshot_download
return snapshot_download(model_path)
return model_path
def prepare_tokenizer(tokenizer_path: str):
if "SGLANG_USE_MODELSCOPE" in os.environ:
if not os.path.exists(tokenizer_path):
from modelscope import snapshot_download
return snapshot_download(
model_path = snapshot_download(model_path)
tokenizer_path = snapshot_download(
tokenizer_path, ignore_patterns=["*.bin", "*.safetensors"]
)
return tokenizer_path
return model_path, tokenizer_path
def configure_logger(server_args, prefix: str = ""):
......@@ -704,3 +611,37 @@ def set_weight_attrs(
for key, value in weight_attrs.items():
assert not hasattr(weight, key), f"Overwriting existing tensor attribute: {key}"
setattr(weight, key, value)
def broadcast_pyobj(
data: List[Any], rank: int, dist_group: torch.distributed.ProcessGroup
):
"""Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
if rank == 0:
if len(data) == 0:
tensor_size = torch.tensor([0], dtype=torch.long)
dist.broadcast(tensor_size, src=0, group=dist_group)
else:
serialized_data = pickle.dumps(data)
size = len(serialized_data)
tensor_data = torch.ByteTensor(list(serialized_data))
tensor_size = torch.tensor([size], dtype=torch.long)
dist.broadcast(tensor_size, src=0, group=dist_group)
dist.broadcast(tensor_data, src=0, group=dist_group)
return data
else:
tensor_size = torch.tensor([0], dtype=torch.long)
dist.broadcast(tensor_size, src=0, group=dist_group)
size = tensor_size.item()
if size == 0:
return []
tensor_data = torch.empty(size, dtype=torch.uint8)
dist.broadcast(tensor_data, src=0, group=dist_group)
serialized_data = bytes(tensor_data.tolist())
data = pickle.loads(serialized_data)
return data
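A hedged usage sketch of `broadcast_pyobj`, mirroring how the scheduler uses it: rank 0 supplies the list of requests and the other TP ranks receive the deserialized copy. The process-group setup is assumed to exist already; `dist.group.WORLD` below is just a placeholder, since the scheduler passes `model_runner.tp_group.cpu_group`.

```python
# Hedged usage sketch; assumes torch.distributed is already initialized with a
# CPU-capable (gloo) backend. The scheduler passes its TP CPU group instead of WORLD.
import torch.distributed as dist

from sglang.srt.utils import broadcast_pyobj

rank = dist.get_rank()
if rank == 0:
    reqs = [{"rid": "example", "input_ids": [1, 2, 3]}]  # any picklable objects
else:
    reqs = []  # placeholder on non-zero ranks; overwritten by the broadcast
reqs = broadcast_pyobj(reqs, rank, dist.group.WORLD)
# Every rank now holds the same list of requests.
```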
......@@ -4,7 +4,7 @@ import subprocess
import unittest
from unittest import mock
from sglang.srt.utils import prepare_model, prepare_tokenizer
from sglang.srt.utils import prepare_model_and_tokenizer
class TestDownloadFromModelScope(unittest.TestCase):
......@@ -21,25 +21,17 @@ class TestDownloadFromModelScope(unittest.TestCase):
def tearDownClass(cls):
pass
def test_prepare_model(self):
def test_prepare_model_and_tokenizer(self):
from modelscope.utils.file_utils import get_model_cache_root
model_cache_root = get_model_cache_root()
if os.path.exists(model_cache_root):
shutil.rmtree(model_cache_root)
with mock.patch.dict(os.environ, self.with_modelscope_environ, clear=True):
model_path = prepare_model(self.model)
model_path, tokenizer_path = prepare_model_and_tokenizer(
self.model, self.model
)
assert os.path.exists(os.path.join(model_path, "pytorch_model.bin"))
def test_prepare_tokenizer(self):
from modelscope.utils.file_utils import get_model_cache_root
model_cache_root = get_model_cache_root()
if os.path.exists(model_cache_root):
shutil.rmtree(model_cache_root)
with mock.patch.dict(os.environ, self.with_modelscope_environ, clear=True):
tokenizer_path = prepare_tokenizer(self.model)
assert not os.path.exists(os.path.join(tokenizer_path, "pytorch_model.bin"))
assert os.path.exists(os.path.join(tokenizer_path, "config.json"))
......
......@@ -26,7 +26,7 @@ class TestTritonAttnBackend(unittest.TestCase):
)
if is_in_ci():
assert output_throughput > 155, f"{output_throughput=}"
assert output_throughput > 154, f"{output_throughput=}"
def test_mmlu(self):
model = DEFAULT_MODEL_NAME_FOR_TEST
......