Unverified commit 8b6ce52e, authored by Lianmin Zheng, committed by GitHub

Support multi-node DP attention (#2925)


Co-authored-by: dhou-xai <dhou@x.ai>
parent 58f3f2b8
...@@ -26,8 +26,8 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
- To run tensor parallelism on multiple nodes, add `--nnodes 2`. If you have two nodes with two GPUs each and want to run TP=4, let `sgl-dev-0` be the hostname of the first node and `50000` be an available port; you can then use the following commands. If you encounter a deadlock, try adding `--disable-cuda-graph`.
```
# Node 0
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --dist-init-addr sgl-dev-0:50000 --nnodes 2 --node-rank 0
# Node 1
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --dist-init-addr sgl-dev-0:50000 --nnodes 2 --node-rank 1
```
...@@ -11,9 +11,9 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instr
```bash
# on the first node, replace 172.16.4.52:20000 with the first node's ip address and an available port
python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --dist-init-addr 172.16.4.52:20000 --nnodes 2 --node-rank 0
# on the second node, use the same --dist-init-addr as the first node
python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --dist-init-addr 172.16.4.52:20000 --nnodes 2 --node-rank 1
```
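
Since this commit adds multi-node DP attention, a hedged launch sketch may be useful alongside the plain TP example above. The `--enable-dp-attention` flag name is inferred from the `enable_dp_attention` server argument used throughout this diff, and `dp_size` is forced equal to `tp_size` by the ServerArgs change below, so no separate DP-size flag is shown; treat this as an illustration, not official documentation.

```bash
# Hypothetical two-node launch with DP attention (flag name assumed from this diff)
# Node 0
python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --enable-dp-attention --dist-init-addr 172.16.4.52:20000 --nnodes 2 --node-rank 0
# Node 1
python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --enable-dp-attention --dist-init-addr 172.16.4.52:20000 --nnodes 2 --node-rank 1
```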
...@@ -18,6 +18,7 @@ import triton.language as tl
from sglang.global_config import global_config
from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import is_flashinfer_available
...@@ -62,9 +63,9 @@ class FlashInferAttnBackend(AttentionBackend):
        self.decode_use_tensor_cores = should_use_tensor_core(
            kv_cache_dtype=model_runner.kv_cache_dtype,
            num_attention_heads=model_runner.model_config.num_attention_heads
            // get_attention_tp_size(),
            num_kv_heads=model_runner.model_config.get_num_kv_heads(
                get_attention_tp_size()
            ),
        )
        self.max_context_len = model_runner.model_config.context_len
...@@ -147,7 +148,7 @@ class FlashInferAttnBackend(AttentionBackend):
        self.prefill_cuda_graph_metadata = {}

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        if forward_batch.forward_mode.is_decode_or_idle():
            self.indices_updater_decode.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
...@@ -238,7 +239,7 @@ class FlashInferAttnBackend(AttentionBackend):
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
    ):
        if forward_mode.is_decode_or_idle():
            decode_wrappers = []
            for i in range(self.num_wrappers):
                decode_wrappers.append(
...@@ -307,7 +308,7 @@ class FlashInferAttnBackend(AttentionBackend):
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
    ):
        if forward_mode.is_decode_or_idle():
            self.indices_updater_decode.update(
                req_pool_indices[:bs],
                seq_lens[:bs],
...@@ -453,10 +454,10 @@ class FlashInferIndicesUpdaterDecode:
    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
        # Parse Constants
        self.num_qo_heads = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
            get_attention_tp_size()
        )
        self.head_dim = model_runner.model_config.head_dim
        self.data_type = model_runner.kv_cache_dtype
...@@ -625,10 +626,10 @@ class FlashInferIndicesUpdaterPrefill:
    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
        # Parse Constants
        self.num_qo_heads = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
            get_attention_tp_size()
        )
        self.head_dim = model_runner.model_config.head_dim
        self.data_type = model_runner.kv_cache_dtype
...
...@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Optional
import torch

from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

if TYPE_CHECKING:
...@@ -28,12 +29,9 @@ class TritonAttnBackend(AttentionBackend):
        self.decode_attention_fwd = decode_attention_fwd
        self.extend_attention_fwd = extend_attention_fwd

        self.num_head = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
...
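
Both attention backends above now divide the head count by the attention-TP size instead of the full TP size. A hedged numeric sketch (all sizes are made up for illustration):

```python
# Illustration only: with DP attention, each replica shards heads over a smaller group.
num_attention_heads = 32
tp_size, dp_size = 8, 4
attn_tp_size = tp_size // dp_size  # 2 ranks shard attention inside one DP replica

heads_per_rank_full_tp = num_attention_heads // tp_size        # 4 heads per rank
heads_per_rank_attn_tp = num_attention_heads // attn_tp_size   # 16 heads per rank
print(heads_per_rank_full_tp, heads_per_rank_attn_tp)          # 4 16
```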
import torch
from vllm.distributed import GroupCoordinator, get_tp_group

_ATTN_TP_GROUP = None
_ATTN_TP_RANK = None
_ATTN_TP_SIZE = None
_DP_RANK = None
_DP_SIZE = None


def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
    if not enable_dp_attention:
        return tp_rank, tp_size, 0

    attn_tp_size = tp_size // dp_size
    dp_rank = tp_rank // attn_tp_size
    attn_tp_rank = tp_rank % attn_tp_size
    return attn_tp_rank, attn_tp_size, dp_rank


def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
    global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE

    _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
        enable_dp_attention, tp_rank, tp_size, dp_size
    )
    _DP_SIZE = dp_size

    tp_group = get_tp_group()
    _ATTN_TP_GROUP = GroupCoordinator(
        [
            list(range(head, head + _ATTN_TP_SIZE))
            for head in range(0, tp_size, _ATTN_TP_SIZE)
        ],
        tp_rank,
        torch.distributed.get_backend(tp_group.device_group),
        False,
        False,
        False,
        False,
        False,
        group_name="attention_tp",
    )


def get_attention_tp_group():
    assert _ATTN_TP_GROUP is not None, "dp attention not initialized!"
    return _ATTN_TP_GROUP


def get_attention_tp_rank():
    assert _ATTN_TP_RANK is not None, "dp attention not initialized!"
    return _ATTN_TP_RANK


def get_attention_tp_size():
    assert _ATTN_TP_SIZE is not None, "dp attention not initialized!"
    return _ATTN_TP_SIZE


def get_attention_dp_rank():
    assert _DP_RANK is not None, "dp attention not initialized!"
    return _DP_RANK


def get_attention_dp_size():
    assert _DP_SIZE is not None, "dp attention not initialized!"
    return _DP_SIZE
...@@ -133,7 +133,7 @@ class LogitsProcessor(nn.Module):
        # Get the last hidden states and last logits for the next token prediction
        if (
            logits_metadata.forward_mode.is_decode_or_idle()
            or logits_metadata.forward_mode.is_target_verify()
        ):
            last_index = None
...
...@@ -23,6 +23,7 @@ import psutil
import setproctitle
import zmq

from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.managers.io_struct import (
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
...@@ -63,9 +64,10 @@ class DataParallelController:

        # Init inter-process communication
        self.context = zmq.Context(1 + server_args.dp_size)
        if server_args.node_rank == 0:
            self.recv_from_tokenizer = get_zmq_socket(
                self.context, zmq.PULL, port_args.scheduler_input_ipc_name
            )

        # Dispatch method
        self.round_robin_counter = 0
...@@ -75,33 +77,47 @@ class DataParallelController:
        }
        self.dispatching = dispatch_lookup[self.load_balance_method]

        # Launch data parallel workers
        self.scheduler_procs = []
        self.workers = [None] * server_args.dp_size

        if not server_args.enable_dp_attention:
            dp_port_args = self.launch_dp_schedulers(server_args, port_args)
        else:
            dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args)

        # Only node rank 0 runs the real data parallel controller that dispatches the requests.
        if server_args.node_rank == 0:
            for dp_rank in range(server_args.dp_size):
                self.workers[dp_rank] = get_zmq_socket(
                    self.context,
                    zmq.PUSH,
                    dp_port_args[dp_rank].scheduler_input_ipc_name,
                )

    def launch_dp_schedulers(self, server_args, port_args):
        base_gpu_id = 0

        threads = []
        sockets = []
        dp_port_args = []
        for dp_rank in range(server_args.dp_size):
            tmp_port_args = PortArgs.init_new(server_args)
            tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
            tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name
            dp_port_args.append(tmp_port_args)

            # This port is checked free in PortArgs.init_new.
            # We hold it first so that the next dp worker gets a different port
            sockets.append(bind_port(tmp_port_args.nccl_port))

            # Create a thread for each worker
            thread = threading.Thread(
                target=self.launch_tensor_parallel_group,
                args=(server_args, tmp_port_args, base_gpu_id, dp_rank),
            )
            threads.append(thread)
            base_gpu_id += server_args.tp_size

        # Free all sockets before starting the threads to launch TP workers
        for sock in sockets:
...@@ -113,26 +129,14 @@ class DataParallelController:
        for thread in threads:
            thread.join()

        return dp_port_args

    def launch_dp_attention_schedulers(self, server_args, port_args):
        self.launch_tensor_parallel_group(server_args, port_args, 0, None)
        dp_port_args = []
        for dp_rank in range(server_args.dp_size):
            dp_port_args.append(PortArgs.init_new(server_args, dp_rank))
        return dp_port_args

    def launch_tensor_parallel_group(
        self,
...@@ -141,8 +145,10 @@ class DataParallelController:
        base_gpu_id: int,
        dp_rank: int,
    ):
        if not server_args.enable_dp_attention:
            logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")

        # Launch tensor parallel scheduler processes
        scheduler_pipe_readers = []
        tp_size_per_node = server_args.tp_size // server_args.nnodes
        tp_rank_range = range(
...@@ -150,53 +156,39 @@ class DataParallelController:
            tp_size_per_node * (server_args.node_rank + 1),
        )
        for tp_rank in tp_rank_range:
            rank_port_args = port_args

            if server_args.enable_dp_attention:
                # dp attention has different sharding logic
                _, _, dp_rank = compute_dp_attention_world_info(
                    server_args.enable_dp_attention,
                    tp_rank,
                    server_args.tp_size,
                    server_args.dp_size,
                )
                # compute zmq ports for this dp rank
                rank_port_args = PortArgs.init_new(server_args, dp_rank)

                # Data parallelism reuses the tensor parallelism group,
                # so all dp ranks should use the same nccl port.
                rank_port_args.nccl_port = port_args.nccl_port

            reader, writer = mp.Pipe(duplex=False)
            gpu_id = server_args.base_gpu_id + base_gpu_id + tp_rank % tp_size_per_node
            proc = mp.Process(
                target=run_scheduler_process,
                args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer),
            )
            proc.start()
            self.scheduler_procs.append(proc)
            scheduler_pipe_readers.append(reader)

        # Wait for model to finish loading
        scheduler_info = []
        for i in range(len(scheduler_pipe_readers)):
            scheduler_info.append(scheduler_pipe_readers[i].recv())

        self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]

    def round_robin_scheduler(self, req):
        self.workers[self.round_robin_counter].send_pyobj(req)
        self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
...@@ -221,8 +213,8 @@ class DataParallelController:
            ):
                self.dispatching(recv_req)
            else:
                # Send other control messages to first worker of tp group
                for worker in self.workers[:: self.server_args.tp_size]:
                    worker.send_pyobj(recv_req)
...@@ -240,7 +232,13 @@ def run_data_parallel_controller_process(
        pipe_writer.send(
            {"status": "ready", "max_total_num_tokens": controller.max_total_num_tokens}
        )
        if server_args.node_rank == 0:
            controller.event_loop()
        for proc in controller.scheduler_procs:
            proc.join()
            logger.error(
                f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
            )
    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"DataParallelController hit an exception: {traceback}")
...
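
To make the GPU assignment in `launch_tensor_parallel_group` concrete, here is a hedged sketch of how `gpu_id` works out for the DP-attention path on a hypothetical 2-node, TP=16 setup with 8 GPUs per node and `base_gpu_id = 0`; all numbers are illustrative assumptions.

```python
# Illustration only: which local GPU each TP rank lands on with nnodes=2, tp_size=16.
tp_size, nnodes, base_gpu_id, server_base_gpu_id = 16, 2, 0, 0
tp_size_per_node = tp_size // nnodes  # 8

for node_rank in range(nnodes):
    ranks = range(tp_size_per_node * node_rank, tp_size_per_node * (node_rank + 1))
    gpu_ids = [server_base_gpu_id + base_gpu_id + r % tp_size_per_node for r in ranks]
    print(node_rank, list(ranks), gpu_ids)
# 0 [0, 1, 2, 3, 4, 5, 6, 7] [0, 1, 2, 3, 4, 5, 6, 7]
# 1 [8, 9, 10, 11, 12, 13, 14, 15] [0, 1, 2, 3, 4, 5, 6, 7]
```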
...@@ -1003,6 +1003,11 @@ class ScheduleBatch:
        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens_sum = 0
        self.extend_num_tokens = 0
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
            enable_overlap_schedule=self.enable_overlap,
        )

    def prepare_for_decode(self):
        self.forward_mode = ForwardMode.DECODE
...@@ -1117,7 +1122,7 @@ class ScheduleBatch:
            self.spec_info.merge_batch(other.spec_info)

    def get_model_worker_batch(self):
        if self.forward_mode.is_decode_or_idle():
            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
        else:
            extend_seq_lens = self.extend_lens
...
...@@ -33,6 +33,7 @@ import zmq
from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
    AbortReq,
...@@ -135,7 +136,17 @@ class Scheduler:

        # Init inter-process communication
        context = zmq.Context(2)

        self.dp_size = server_args.dp_size
        self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
            compute_dp_attention_world_info(
                server_args.enable_dp_attention,
                self.tp_rank,
                self.tp_size,
                self.dp_size,
            )
        )

        if self.attn_tp_rank == 0:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name
            )
...@@ -244,6 +255,7 @@ class Scheduler:
            _,
        ) = self.tp_worker.get_worker_info()
        self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
        self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)
...@@ -447,6 +459,10 @@ class Scheduler:
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()

            if self.server_args.enable_dp_attention:  # TODO: simplify this
                batch = self.prepare_dp_attn_batch(batch)

            self.cur_batch = batch

            if batch:
...@@ -479,7 +495,7 @@ class Scheduler:
    def recv_requests(self) -> List[Req]:
        """Receive requests at attn_tp_rank = 0 and broadcast them to all other ranks."""
        if self.attn_tp_rank == 0:
            recv_reqs = []

            while True:
...@@ -491,7 +507,40 @@ class Scheduler:
        else:
            recv_reqs = None

        if self.server_args.enable_dp_attention:
            if self.attn_tp_rank == 0:
                work_reqs = [
                    req
                    for req in recv_reqs
                    if isinstance(
                        req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
                    )
                ]
                control_reqs = [
                    req
                    for req in recv_reqs
                    if not isinstance(
                        req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
                    )
                ]
            else:
                work_reqs = None
                control_reqs = None

            if self.attn_tp_size != 1:
                attn_tp_rank_0 = self.dp_rank * self.attn_tp_size
                work_reqs = broadcast_pyobj(
                    work_reqs,
                    self.attn_tp_rank,
                    self.attn_tp_cpu_group,
                    src=attn_tp_rank_0,
                )
            if self.tp_size != 1:
                control_reqs = broadcast_pyobj(
                    control_reqs, self.tp_rank, self.tp_cpu_group
                )
            recv_reqs = work_reqs + control_reqs
        elif self.tp_size != 1:
            recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
        return recv_reqs
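
The request-routing change above boils down to two object broadcasts over different process groups: work requests travel only inside an attention-TP subgroup (from that subgroup's rank 0), while control messages still go over the full TP group. A minimal, hedged sketch with plain `torch.distributed` follows; SGLang's `broadcast_pyobj` helper is not reproduced here, and an already-initialized process group plus the subgroup construction are assumptions for illustration.

```python
import torch.distributed as dist

def make_attn_tp_groups(tp_size: int, attn_tp_size: int):
    # One subgroup per DP replica, e.g. [0, 1] and [2, 3] for tp_size=4, attn_tp_size=2.
    # dist.new_group must be called by every rank for every subgroup.
    return [
        dist.new_group(list(range(start, start + attn_tp_size)))
        for start in range(0, tp_size, attn_tp_size)
    ]

def broadcast_work_reqs(reqs, tp_rank: int, attn_tp_size: int, attn_groups):
    dp_rank = tp_rank // attn_tp_size
    src = dp_rank * attn_tp_size  # global rank of attn-TP rank 0 in this replica
    buf = [reqs if tp_rank == src else None]
    # Ship the Python objects from the subgroup's rank 0 to its other members.
    dist.broadcast_object_list(buf, src=src, group=attn_groups[dp_rank])
    return buf[0]
```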
...@@ -887,7 +936,7 @@ class Scheduler:
                self.being_chunked_req.is_being_chunked += 1

        # Print stats
        if self.attn_tp_rank == 0:
            self.log_prefill_stats(adder, can_run_list, running_bs, has_being_chunked)

        # Create a new batch
...@@ -974,7 +1023,7 @@ class Scheduler:
        self.forward_ct += 1

        if self.is_generation:
            if batch.forward_mode.is_decode_or_idle() or batch.extend_num_tokens != 0:
                if self.spec_algorithm.is_none():
                    model_worker_batch = batch.get_model_worker_batch()
                    logits_output, next_token_ids = (
...@@ -988,18 +1037,8 @@ class Scheduler:
                        num_accepted_tokens,
                    ) = self.draft_worker.forward_batch_speculative_generation(batch)
                    self.num_generated_tokens += num_accepted_tokens
            else:
                assert False, "batch.extend_num_tokens == 0, this is unexpected!"

            batch.output_ids = next_token_ids
            ret = logits_output, next_token_ids, model_worker_batch.bid
        else:  # embedding or reward model
...@@ -1016,6 +1055,9 @@ class Scheduler:
                self.running_batch = None
        elif batch.forward_mode.is_extend():
            self.process_batch_result_prefill(batch, result)
        elif batch.forward_mode.is_idle():
            if self.enable_overlap:
                self.tp_worker.resolve_batch_result(result[-1])
        elif batch.forward_mode.is_dummy_first():
            batch.next_batch_sampling_info.update_regex_vocab_mask()
            self.current_stream.synchronize()
...@@ -1166,7 +1208,7 @@ class Scheduler:
            self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
            if (
                self.attn_tp_rank == 0
                and self.forward_ct_decode % self.server_args.decode_log_interval == 0
            ):
                self.log_decode_stats()
...@@ -1402,12 +1444,7 @@ class Scheduler:
        # Check forward mode for cuda graph
        if not self.server_args.disable_cuda_graph:
            forward_mode_state = torch.tensor(
                (1 if local_batch.forward_mode.is_decode_or_idle() else 0),
                dtype=torch.int32,
            )
            torch.distributed.all_reduce(
...
...@@ -101,6 +101,7 @@ class TpModelWorker:
                self.max_total_num_tokens // 2
                if server_args.max_running_requests is None
                else server_args.max_running_requests
                // (server_args.dp_size if server_args.enable_dp_attention else 1)
            ),
            self.model_runner.req_to_token_pool.size,
        )
...@@ -142,16 +143,15 @@ class TpModelWorker:
    def get_tp_cpu_group(self):
        return self.model_runner.tp_group.cpu_group

    def get_attention_tp_cpu_group(self):
        return self.model_runner.attention_tp_group.cpu_group

    def get_memory_pool(self):
        return (
            self.model_runner.req_to_token_pool,
            self.model_runner.token_to_kv_pool,
        )

    def forward_batch_generation(
        self,
        model_worker_batch: ModelWorkerBatch,
...
...@@ -92,6 +92,9 @@ class TpModelWorkerClient:
    def get_tp_cpu_group(self):
        return self.worker.get_tp_cpu_group()

    def get_attention_tp_cpu_group(self):
        return self.worker.get_attention_tp_cpu_group()

    def get_memory_pool(self):
        return (
            self.worker.model_runner.req_to_token_pool,
...
...@@ -122,6 +122,7 @@ class CudaGraphRunner:
        self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
        self.enable_dp_attention = self.model_runner.server_args.enable_dp_attention
        self.tp_size = self.model_runner.tp_size
        self.dp_size = self.model_runner.server_args.dp_size

        # Batch sizes to capture
        self.capture_bs = self.model_runner.server_args.cuda_graph_bs
...@@ -218,7 +219,7 @@ class CudaGraphRunner:
        if self.enable_dp_attention:
            self.gathered_buffer = torch.zeros(
                (
                    self.max_bs * self.dp_size,
                    self.model_runner.model_config.hidden_size,
                ),
                dtype=self.model_runner.dtype,
...
...@@ -35,6 +35,10 @@ from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttn
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
from sglang.srt.layers.dp_attention import (
    get_attention_tp_group,
    initialize_dp_attention,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import Sampler
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
...@@ -235,11 +239,18 @@ class ModelRunner:
            distributed_init_method=dist_init_method,
        )
        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
        initialize_dp_attention(
            enable_dp_attention=self.server_args.enable_dp_attention,
            tp_rank=self.tp_rank,
            tp_size=self.tp_size,
            dp_size=self.server_args.dp_size,
        )

        min_per_gpu_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )
        self.tp_group = get_tp_group()
        self.attention_tp_group = get_attention_tp_group()

        # Check memory for tensor parallelism
        if self.tp_size > 1:
...
...@@ -855,10 +855,9 @@ class DeepseekV2ForCausalLM(nn.Module):
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, forward_batch)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
...
...@@ -239,15 +239,14 @@ class ServerArgs:
        # Others
        if self.enable_dp_attention:
            assert self.tp_size % self.dp_size == 0
            self.dp_size = self.tp_size
            self.chunked_prefill_size = self.chunked_prefill_size // 2
            self.schedule_conservativeness = self.schedule_conservativeness * 0.3
            logger.warning(
                f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
                f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
                "Data parallel size is adjusted to be the same as tensor parallel size. "
            )

        # Speculative Decoding
...@@ -880,8 +879,8 @@ class ServerArgs:
            self.tp_size % self.nnodes == 0
        ), "tp_size must be divisible by number of nodes"
        assert not (
            self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
        ), "multi-node data parallel is not supported unless dp attention!"
        assert (
            self.max_loras_per_batch > 0
            # FIXME
...@@ -919,6 +918,9 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
    return server_args


ZMQ_TCP_PORT_DELTA = 233


@dataclasses.dataclass
class PortArgs:
    # The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
...@@ -932,7 +934,7 @@ class PortArgs:
    nccl_port: int

    @staticmethod
    def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
        port = server_args.port + random.randint(100, 1000)
        while True:
            if is_port_available(port):
...@@ -942,12 +944,39 @@ class PortArgs:
            else:
                port -= 43

        if not server_args.enable_dp_attention:
            # Normal case, use IPC within a single node
            return PortArgs(
                tokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
                scheduler_input_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
                detokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
                nccl_port=port,
            )
        else:
            # DP attention. Use TCP + port to handle both single-node and multi-node.
            if server_args.nnodes == 1 and server_args.dist_init_addr is None:
                dist_init_addr = ("127.0.0.1", server_args.port + ZMQ_TCP_PORT_DELTA)
            else:
                dist_init_addr = server_args.dist_init_addr.split(":")
            assert (
                len(dist_init_addr) == 2
            ), "please provide --dist-init-addr as host:port of head node"

            dist_init_host, dist_init_port = dist_init_addr
            port_base = int(dist_init_port) + 1
            if dp_rank is None:
                scheduler_input_port = (
                    port_base + 2
                )  # TokenizerManager to DataParallelController
            else:
                scheduler_input_port = port_base + 2 + 1 + dp_rank

            return PortArgs(
                tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
                scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
                detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
                nccl_port=port,
            )
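
As a quick sanity check of the port layout above, a hedged example with `--dist-init-addr sgl-dev-0:50000` and two DP ranks; the values follow directly from the arithmetic in `init_new` and are not separate documentation.

```python
# Hedged illustration of the TCP endpoints computed by PortArgs.init_new
# for --dist-init-addr sgl-dev-0:50000, so port_base = 50000 + 1 = 50001.
host, port_base = "sgl-dev-0", 50001

endpoints = {
    "tokenizer_ipc_name": f"tcp://{host}:{port_base}",        # tcp://sgl-dev-0:50001
    "detokenizer_ipc_name": f"tcp://{host}:{port_base + 1}",  # tcp://sgl-dev-0:50002
    "controller_input (dp_rank=None)": f"tcp://{host}:{port_base + 2}",      # :50003
    "scheduler_input (dp_rank=0)": f"tcp://{host}:{port_base + 2 + 1 + 0}",  # :50004
    "scheduler_input (dp_rank=1)": f"tcp://{host}:{port_base + 2 + 1 + 1}",  # :50005
}
```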
class LoRAPathAction(argparse.Action):
...
...@@ -802,11 +802,11 @@ def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint:
    if socket_type == zmq.PUSH:
        socket.setsockopt(zmq.SNDHWM, 0)
        socket.setsockopt(zmq.SNDBUF, buf_size)
        socket.connect(endpoint)
    elif socket_type == zmq.PULL:
        socket.setsockopt(zmq.RCVHWM, 0)
        socket.setsockopt(zmq.RCVBUF, buf_size)
        socket.bind(endpoint)
    else:
        raise ValueError(f"Unsupported socket type: {socket_type}")
...