Unverified Commit 23cc66f7 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Add back data parallelism (#1635)

parent 5d09ca57
...@@ -255,12 +255,11 @@ jobs: ...@@ -255,12 +255,11 @@ jobs:
python3 test_mla.py python3 test_mla.py
python3 test_mla_fp8.py python3 test_mla_fp8.py
# Temporarily disabled - name: Evaluate Data Parallelism Accuracy (DP=2)
#- name: Evaluate Data Parallelism Accuracy (TP=2) timeout-minutes: 10
# timeout-minutes: 10 run: |
# run: | cd test/srt
# cd test/srt python3 test_data_parallelism.py
# python3 test_data_parallelism.py
finish: finish:
needs: [ needs: [
......
...@@ -139,7 +139,7 @@ def load_model(server_args, port_args, tp_rank): ...@@ -139,7 +139,7 @@ def load_model(server_args, port_args, tp_rank):
gpu_id=tp_rank, gpu_id=tp_rank,
tp_rank=tp_rank, tp_rank=tp_rank,
tp_size=server_args.tp_size, tp_size=server_args.tp_size,
nccl_port=port_args.nccl_ports[0], nccl_port=port_args.nccl_port,
server_args=server_args, server_args=server_args,
) )
rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}") rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
......
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""A controller that dispatches requests to multiple data parallel workers."""
import logging
import multiprocessing as mp
from enum import Enum, auto
import zmq
from sglang.srt.managers.io_struct import (
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
TokenizedRewardReqInput,
)
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
configure_logger,
kill_parent_process,
suppress_other_loggers,
)
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
class LoadBalanceMethod(Enum):
"""Load balance method."""
ROUND_ROBIN = auto()
SHORTEST_QUEUE = auto()
@classmethod
def from_str(cls, method: str):
method = method.upper()
try:
return cls[method]
except KeyError as exc:
raise ValueError(f"Invalid load balance method: {method}") from exc
class DataParallelController:
"""A controller that dispatches requests to multiple data parallel workers."""
def __init__(self, server_args, port_args) -> None:
# Parse args
self.server_args = server_args
self.port_args = port_args
self.load_balance_method = LoadBalanceMethod.from_str(
server_args.load_balance_method
)
# Init inter-process communication
self.context = zmq.Context(1 + server_args.dp_size)
self.recv_from_tokenizer = self.context.socket(zmq.PULL)
self.recv_from_tokenizer.bind(f"ipc://{port_args.scheduler_input_ipc_name}")
# Dispatch method
self.round_robin_counter = 0
dispatch_lookup = {
LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
}
self.dispatching = dispatch_lookup[self.load_balance_method]
# Start data parallel workers
base_gpu_id = 0
self.workers = []
for dp_rank in range(server_args.dp_size):
tmp_port_args = PortArgs.init_new(server_args)
tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name
send_to = self.launch_tensor_parallel_group(
server_args,
tmp_port_args,
base_gpu_id,
dp_rank,
)
self.workers.append(send_to)
base_gpu_id += server_args.tp_size
def launch_tensor_parallel_group(
self,
server_args: ServerArgs,
port_args: PortArgs,
base_gpu_id: int,
dp_rank: int,
):
# Launch tensor parallel scheduler processes
scheduler_procs = []
scheduler_pipe_readers = []
tp_size_per_node = server_args.tp_size // server_args.nnodes
tp_rank_range = range(
tp_size_per_node * server_args.node_rank,
tp_size_per_node * (server_args.node_rank + 1),
)
for tp_rank in tp_rank_range:
reader, writer = mp.Pipe(duplex=False)
gpu_id = base_gpu_id + tp_rank % tp_size_per_node
proc = mp.Process(
target=run_scheduler_process,
args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
)
proc.start()
scheduler_procs.append(proc)
scheduler_pipe_readers.append(reader)
send_to = self.context.socket(zmq.PUSH)
send_to.connect(f"ipc://{port_args.scheduler_input_ipc_name}")
# Wait for model to finish loading
for i in range(len(scheduler_pipe_readers)):
scheduler_pipe_readers[i].recv()
return send_to
def round_robin_scheduler(self, req):
self.workers[self.round_robin_counter].send_pyobj(req)
self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
def shortest_queue_scheduler(self, input_requests):
raise NotImplementedError()
def event_loop(self):
while True:
while True:
try:
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
except zmq.ZMQError:
break
if isinstance(
recv_req,
(
TokenizedGenerateReqInput,
TokenizedEmbeddingReqInput,
TokenizedRewardReqInput,
),
):
self.dispatching(recv_req)
else:
# Send other control messages to all workers
for worker in self.workers:
worker.queue.put(recv_req)
def run_data_parallel_controller_process(
server_args: ServerArgs,
port_args: PortArgs,
pipe_writer,
):
configure_logger(server_args)
suppress_other_loggers()
try:
controller = DataParallelController(server_args, port_args)
pipe_writer.send("ready")
controller.event_loop()
except Exception:
msg = get_exception_traceback()
logger.error(msg)
kill_parent_process()
...@@ -142,7 +142,7 @@ class Scheduler: ...@@ -142,7 +142,7 @@ class Scheduler:
gpu_id=gpu_id, gpu_id=gpu_id,
tp_rank=tp_rank, tp_rank=tp_rank,
server_args=server_args, server_args=server_args,
nccl_port=port_args.nccl_ports[0], nccl_port=port_args.nccl_port,
) )
self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
...@@ -1042,9 +1042,14 @@ def run_scheduler_process( ...@@ -1042,9 +1042,14 @@ def run_scheduler_process(
port_args: PortArgs, port_args: PortArgs,
gpu_id: int, gpu_id: int,
tp_rank: int, tp_rank: int,
dp_rank: Optional[int],
pipe_writer, pipe_writer,
): ):
configure_logger(server_args, prefix=f" TP{tp_rank}") if dp_rank is None:
configure_logger(server_args, prefix=f" TP{tp_rank}")
else:
configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")
suppress_other_loggers() suppress_other_loggers()
try: try:
......
...@@ -141,7 +141,7 @@ class ModelRunner: ...@@ -141,7 +141,7 @@ class ModelRunner:
self.init_attention_backend() self.init_attention_backend()
def init_torch_distributed(self): def init_torch_distributed(self):
logger.info("Init torch distributed begin.") logger.info("Init torch distributed begin.")
# Init torch distributed # Init torch distributed
if self.device == "cuda": if self.device == "cuda":
torch.cuda.set_device(self.gpu_id) torch.cuda.set_device(self.gpu_id)
......
...@@ -44,6 +44,9 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse ...@@ -44,6 +44,9 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.data_parallel_controller import (
run_data_parallel_controller_process,
)
from sglang.srt.managers.detokenizer_manager import run_detokenizer_process from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
from sglang.srt.managers.io_struct import ( from sglang.srt.managers.io_struct import (
EmbeddingReqInput, EmbeddingReqInput,
...@@ -337,30 +340,40 @@ def launch_engine( ...@@ -337,30 +340,40 @@ def launch_engine(
server_args.model_path, server_args.tokenizer_path server_args.model_path, server_args.tokenizer_path
) )
# Launch tensor parallel scheduler processes if server_args.dp_size == 1:
scheduler_procs = [] # Launch tensor parallel scheduler processes
scheduler_pipe_readers = [] scheduler_procs = []
tp_size_per_node = server_args.tp_size // server_args.nnodes scheduler_pipe_readers = []
tp_rank_range = range( tp_size_per_node = server_args.tp_size // server_args.nnodes
tp_size_per_node * server_args.node_rank, tp_rank_range = range(
tp_size_per_node * (server_args.node_rank + 1), tp_size_per_node * server_args.node_rank,
) tp_size_per_node * (server_args.node_rank + 1),
for tp_rank in tp_rank_range: )
for tp_rank in tp_rank_range:
reader, writer = mp.Pipe(duplex=False)
gpu_id = tp_rank % tp_size_per_node
proc = mp.Process(
target=run_scheduler_process,
args=(server_args, port_args, gpu_id, tp_rank, None, writer),
)
proc.start()
scheduler_procs.append(proc)
scheduler_pipe_readers.append(reader)
if server_args.node_rank >= 1:
# For other nodes, they do not need to run tokenizer or detokenizer,
# so they can just wait here.
while True:
pass
else:
# Launch the data parallel controller
reader, writer = mp.Pipe(duplex=False) reader, writer = mp.Pipe(duplex=False)
gpu_id = tp_rank % tp_size_per_node scheduler_pipe_readers = [reader]
proc = mp.Process( proc = mp.Process(
target=run_scheduler_process, target=run_data_parallel_controller_process,
args=(server_args, port_args, gpu_id, tp_rank, writer), args=(server_args, port_args, writer),
) )
proc.start() proc.start()
scheduler_procs.append(proc)
scheduler_pipe_readers.append(reader)
if server_args.node_rank >= 1:
# For other nodes, they do not need to run tokenizer or detokenizer,
# so they can just wait here.
while True:
pass
# Launch detokenizer process # Launch detokenizer process
detoken_proc = mp.Process( detoken_proc = mp.Process(
......
...@@ -574,7 +574,7 @@ class ServerArgs: ...@@ -574,7 +574,7 @@ class ServerArgs:
self.tp_size % self.nnodes == 0 self.tp_size % self.nnodes == 0
), "tp_size must be divisible by number of nodes" ), "tp_size must be divisible by number of nodes"
assert not ( assert not (
self.dp_size > 1 and self.node_rank is not None self.dp_size > 1 and self.nnodes != 1
), "multi-node data parallel is not supported" ), "multi-node data parallel is not supported"
assert ( assert (
self.max_loras_per_batch > 0 self.max_loras_per_batch > 0
...@@ -583,11 +583,6 @@ class ServerArgs: ...@@ -583,11 +583,6 @@ class ServerArgs:
and (self.lora_paths is None or self.disable_radix_cache) and (self.lora_paths is None or self.disable_radix_cache)
), "compatibility of lora and cuda graph and radix attention is in progress" ), "compatibility of lora and cuda graph and radix attention is in progress"
assert self.dp_size == 1, (
"The support for data parallelism is temporarily disabled during refactor. "
"Please use sglang<=0.3.2 or wait for later updates."
)
if isinstance(self.lora_paths, list): if isinstance(self.lora_paths, list):
lora_paths = self.lora_paths lora_paths = self.lora_paths
self.lora_paths = {} self.lora_paths = {}
...@@ -626,8 +621,8 @@ class PortArgs: ...@@ -626,8 +621,8 @@ class PortArgs:
# The ipc filename for detokenizer to receive inputs from scheduler (zmq) # The ipc filename for detokenizer to receive inputs from scheduler (zmq)
detokenizer_ipc_name: str detokenizer_ipc_name: str
# The port for nccl initialization for multiple TP groups (torch.dist) # The port for nccl initialization (torch.dist)
nccl_ports: List[int] nccl_port: int
@staticmethod @staticmethod
def init_new(server_args) -> "PortArgs": def init_new(server_args) -> "PortArgs":
...@@ -641,7 +636,7 @@ class PortArgs: ...@@ -641,7 +636,7 @@ class PortArgs:
tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name, tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
scheduler_input_ipc_name=tempfile.NamedTemporaryFile(delete=False).name, scheduler_input_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
detokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name, detokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
nccl_ports=[port], nccl_port=port,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment