Unverified Commit 8b6ce52e authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Support multi-node DP attention (#2925)


Co-authored-by: default avatardhou-xai <dhou@x.ai>
parent 58f3f2b8
......@@ -26,8 +26,8 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
- To run tensor parallelism on multiple nodes, add `--nnodes 2`. If you have two nodes with two GPUs on each node and want to run TP=4, let `sgl-dev-0` be the hostname of the first node and `50000` be an available port, you can use the following commands. If you meet deadlock, please try to add `--disable-cuda-graph`
```
# Node 0
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --nccl-init sgl-dev-0:50000 --nnodes 2 --node-rank 0
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --dist-init-addr sgl-dev-0:50000 --nnodes 2 --node-rank 0
# Node 1
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --nccl-init sgl-dev-0:50000 --nnodes 2 --node-rank 1
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --dist-init-addr sgl-dev-0:50000 --nnodes 2 --node-rank 1
```
......@@ -11,9 +11,9 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instr
```bash
# on the first node, replace 172.16.4.52:20000 with your own node ip address and port
python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --nccl-init-addr 172.16.4.52:20000 --nnodes 2 --node-rank 0
python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --dist-init-addr-addr 172.16.4.52:20000 --nnodes 2 --node-rank 0
# on the second node, replace 172.18.45.52:20000 with your own node ip address and port
python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --nccl-init-addr 172.18.45.52:20000 --nnodes 2 --node-rank 1
python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --dist-init-addr-addr 172.18.45.52:20000 --nnodes 2 --node-rank 1
```
......@@ -18,6 +18,7 @@ import triton.language as tl
from sglang.global_config import global_config
from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import is_flashinfer_available
......@@ -62,9 +63,9 @@ class FlashInferAttnBackend(AttentionBackend):
self.decode_use_tensor_cores = should_use_tensor_core(
kv_cache_dtype=model_runner.kv_cache_dtype,
num_attention_heads=model_runner.model_config.num_attention_heads
// model_runner.tp_size,
// get_attention_tp_size(),
num_kv_heads=model_runner.model_config.get_num_kv_heads(
model_runner.tp_size
get_attention_tp_size()
),
)
self.max_context_len = model_runner.model_config.context_len
......@@ -147,7 +148,7 @@ class FlashInferAttnBackend(AttentionBackend):
self.prefill_cuda_graph_metadata = {}
def init_forward_metadata(self, forward_batch: ForwardBatch):
if forward_batch.forward_mode.is_decode():
if forward_batch.forward_mode.is_decode_or_idle():
self.indices_updater_decode.update(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
......@@ -238,7 +239,7 @@ class FlashInferAttnBackend(AttentionBackend):
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
):
if forward_mode.is_decode():
if forward_mode.is_decode_or_idle():
decode_wrappers = []
for i in range(self.num_wrappers):
decode_wrappers.append(
......@@ -307,7 +308,7 @@ class FlashInferAttnBackend(AttentionBackend):
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
):
if forward_mode.is_decode():
if forward_mode.is_decode_or_idle():
self.indices_updater_decode.update(
req_pool_indices[:bs],
seq_lens[:bs],
......@@ -453,10 +454,10 @@ class FlashInferIndicesUpdaterDecode:
def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
# Parse Constants
self.num_qo_heads = (
model_runner.model_config.num_attention_heads // model_runner.tp_size
model_runner.model_config.num_attention_heads // get_attention_tp_size()
)
self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
model_runner.tp_size
get_attention_tp_size()
)
self.head_dim = model_runner.model_config.head_dim
self.data_type = model_runner.kv_cache_dtype
......@@ -625,10 +626,10 @@ class FlashInferIndicesUpdaterPrefill:
def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
# Parse Constants
self.num_qo_heads = (
model_runner.model_config.num_attention_heads // model_runner.tp_size
model_runner.model_config.num_attention_heads // get_attention_tp_size()
)
self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
model_runner.tp_size
get_attention_tp_size()
)
self.head_dim = model_runner.model_config.head_dim
self.data_type = model_runner.kv_cache_dtype
......
......@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Optional
import torch
from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
if TYPE_CHECKING:
......@@ -28,12 +29,9 @@ class TritonAttnBackend(AttentionBackend):
self.decode_attention_fwd = decode_attention_fwd
self.extend_attention_fwd = extend_attention_fwd
if model_runner.server_args.enable_dp_attention:
self.num_head = model_runner.model_config.num_attention_heads
else:
self.num_head = (
model_runner.model_config.num_attention_heads // model_runner.tp_size
)
self.num_head = (
model_runner.model_config.num_attention_heads // get_attention_tp_size()
)
self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
......
import torch
from vllm.distributed import GroupCoordinator, get_tp_group
_ATTN_TP_GROUP = None
_ATTN_TP_RANK = None
_ATTN_TP_SIZE = None
_DP_RANK = None
_DP_SIZE = None
def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
if not enable_dp_attention:
return tp_rank, tp_size, 0
attn_tp_size = tp_size // dp_size
dp_rank = tp_rank // attn_tp_size
attn_tp_rank = tp_rank % attn_tp_size
return attn_tp_rank, attn_tp_size, dp_rank
def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
_ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
enable_dp_attention, tp_rank, tp_size, dp_size
)
_DP_SIZE = dp_size
tp_group = get_tp_group()
_ATTN_TP_GROUP = GroupCoordinator(
[
list(range(head, head + _ATTN_TP_SIZE))
for head in range(0, tp_size, _ATTN_TP_SIZE)
],
tp_rank,
torch.distributed.get_backend(tp_group.device_group),
False,
False,
False,
False,
False,
group_name="attention_tp",
)
def get_attention_tp_group():
assert _ATTN_TP_GROUP is not None, "dp attention not initialized!"
return _ATTN_TP_GROUP
def get_attention_tp_rank():
assert _ATTN_TP_RANK is not None, "dp attention not initialized!"
return _ATTN_TP_RANK
def get_attention_tp_size():
assert _ATTN_TP_SIZE is not None, "dp attention not initialized!"
return _ATTN_TP_SIZE
def get_attention_dp_rank():
assert _DP_RANK is not None, "dp attention not initialized!"
return _DP_RANK
def get_attention_dp_size():
assert _DP_SIZE is not None, "dp attention not initialized!"
return _DP_SIZE
......@@ -133,7 +133,7 @@ class LogitsProcessor(nn.Module):
# Get the last hidden states and last logits for the next token prediction
if (
logits_metadata.forward_mode.is_decode()
logits_metadata.forward_mode.is_decode_or_idle()
or logits_metadata.forward_mode.is_target_verify()
):
last_index = None
......
......@@ -23,6 +23,7 @@ import psutil
import setproctitle
import zmq
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.managers.io_struct import (
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
......@@ -63,9 +64,10 @@ class DataParallelController:
# Init inter-process communication
self.context = zmq.Context(1 + server_args.dp_size)
self.recv_from_tokenizer = get_zmq_socket(
self.context, zmq.PULL, port_args.scheduler_input_ipc_name
)
if server_args.node_rank == 0:
self.recv_from_tokenizer = get_zmq_socket(
self.context, zmq.PULL, port_args.scheduler_input_ipc_name
)
# Dispatch method
self.round_robin_counter = 0
......@@ -75,33 +77,47 @@ class DataParallelController:
}
self.dispatching = dispatch_lookup[self.load_balance_method]
# Start data parallel workers
base_gpu_id = 0
# Launch data parallel workers
self.scheduler_procs = []
self.workers = [None] * server_args.dp_size
if not server_args.enable_dp_attention:
dp_port_args = self.launch_dp_schedulers(server_args, port_args)
else:
dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args)
# Only node rank 0 runs the real data parallel controller that dispatches the requests.
if server_args.node_rank == 0:
for dp_rank in range(server_args.dp_size):
self.workers[dp_rank] = get_zmq_socket(
self.context,
zmq.PUSH,
dp_port_args[dp_rank].scheduler_input_ipc_name,
)
def launch_dp_schedulers(self, server_args, port_args):
base_gpu_id = 0
threads = []
sockets = []
dp_port_args = []
for dp_rank in range(server_args.dp_size):
tmp_port_args = PortArgs.init_new(server_args)
tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name
dp_port_args.append(tmp_port_args)
if server_args.enable_dp_attention:
# Data parallelism resues the tensor parallelism group,
# so all dp ranks should use the same nccl port.
tmp_port_args.nccl_port = port_args.nccl_port
else:
# This port is checked free in PortArgs.init_new.
# We hold it first so that the next dp worker gets a different port
sockets.append(bind_port(tmp_port_args.nccl_port))
# This port is checked free in PortArgs.init_new.
# We hold it first so that the next dp worker gets a different port
sockets.append(bind_port(tmp_port_args.nccl_port))
# Create a thread for each worker
thread = threading.Thread(
target=self.launch_worker_func,
target=self.launch_tensor_parallel_group,
args=(server_args, tmp_port_args, base_gpu_id, dp_rank),
)
threads.append(thread)
base_gpu_id += 1 if server_args.enable_dp_attention else server_args.tp_size
base_gpu_id += server_args.tp_size
# Free all sockets before starting the threads to launch TP workers
for sock in sockets:
......@@ -113,26 +129,14 @@ class DataParallelController:
for thread in threads:
thread.join()
def launch_worker_func(
self,
server_args: ServerArgs,
port_args: PortArgs,
base_gpu_id: int,
dp_rank: int,
):
logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")
return dp_port_args
launch_func_ = (
self.launch_tensor_parallel_process
if server_args.enable_dp_attention
else self.launch_tensor_parallel_group
)
self.workers[dp_rank] = launch_func_(
server_args,
port_args,
base_gpu_id,
dp_rank,
)
def launch_dp_attention_schedulers(self, server_args, port_args):
self.launch_tensor_parallel_group(server_args, port_args, 0, None)
dp_port_args = []
for dp_rank in range(server_args.dp_size):
dp_port_args.append(PortArgs.init_new(server_args, dp_rank))
return dp_port_args
def launch_tensor_parallel_group(
self,
......@@ -141,8 +145,10 @@ class DataParallelController:
base_gpu_id: int,
dp_rank: int,
):
if not server_args.enable_dp_attention:
logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")
# Launch tensor parallel scheduler processes
scheduler_procs = []
scheduler_pipe_readers = []
tp_size_per_node = server_args.tp_size // server_args.nnodes
tp_rank_range = range(
......@@ -150,53 +156,39 @@ class DataParallelController:
tp_size_per_node * (server_args.node_rank + 1),
)
for tp_rank in tp_rank_range:
rank_port_args = port_args
if server_args.enable_dp_attention:
# dp attention has different sharding logic
_, _, dp_rank = compute_dp_attention_world_info(
server_args.enable_dp_attention,
tp_rank,
server_args.tp_size,
server_args.dp_size,
)
# compute zmq ports for this dp rank
rank_port_args = PortArgs.init_new(server_args, dp_rank)
# Data parallelism resues the tensor parallelism group,
# so all dp ranks should use the same nccl port.
rank_port_args.nccl_port = port_args.nccl_port
reader, writer = mp.Pipe(duplex=False)
gpu_id = server_args.base_gpu_id + base_gpu_id + tp_rank % tp_size_per_node
proc = mp.Process(
target=run_scheduler_process,
args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer),
)
proc.start()
scheduler_procs.append(proc)
self.scheduler_procs.append(proc)
scheduler_pipe_readers.append(reader)
send_to = get_zmq_socket(
self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
)
# Wait for model to finish loading and get max token nums
# Wait for model to finish loading
scheduler_info = []
for i in range(len(scheduler_pipe_readers)):
scheduler_info.append(scheduler_pipe_readers[i].recv())
self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
return send_to
def launch_tensor_parallel_process(
self,
server_args: ServerArgs,
port_args: PortArgs,
base_gpu_id: int,
dp_rank: int,
):
reader, writer = mp.Pipe(duplex=False)
gpu_id = base_gpu_id
tp_rank = dp_rank
proc = mp.Process(
target=run_scheduler_process,
args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
)
proc.start()
send_to = get_zmq_socket(
self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
)
scheduler_info = reader.recv()
self.max_total_num_tokens = scheduler_info["max_total_num_tokens"]
return send_to
def round_robin_scheduler(self, req):
self.workers[self.round_robin_counter].send_pyobj(req)
self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
......@@ -221,8 +213,8 @@ class DataParallelController:
):
self.dispatching(recv_req)
else:
# Send other control messages to all workers
for worker in self.workers:
# Send other control messages to first worker of tp group
for worker in self.workers[:: self.server_args.tp_size]:
worker.send_pyobj(recv_req)
......@@ -240,7 +232,13 @@ def run_data_parallel_controller_process(
pipe_writer.send(
{"status": "ready", "max_total_num_tokens": controller.max_total_num_tokens}
)
controller.event_loop()
if server_args.node_rank == 0:
controller.event_loop()
for proc in controller.scheduler_procs:
proc.join()
logger.error(
f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
)
except Exception:
traceback = get_exception_traceback()
logger.error(f"DataParallelController hit an exception: {traceback}")
......
......@@ -1003,6 +1003,11 @@ class ScheduleBatch:
self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
self.seq_lens_sum = 0
self.extend_num_tokens = 0
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
self,
self.model_config.vocab_size,
enable_overlap_schedule=self.enable_overlap,
)
def prepare_for_decode(self):
self.forward_mode = ForwardMode.DECODE
......@@ -1117,7 +1122,7 @@ class ScheduleBatch:
self.spec_info.merge_batch(other.spec_info)
def get_model_worker_batch(self):
if self.forward_mode.is_decode() or self.forward_mode.is_idle():
if self.forward_mode.is_decode_or_idle():
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
else:
extend_seq_lens = self.extend_lens
......
......@@ -33,6 +33,7 @@ import zmq
from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
AbortReq,
......@@ -135,7 +136,17 @@ class Scheduler:
# Init inter-process communication
context = zmq.Context(2)
if self.tp_rank == 0 or self.server_args.enable_dp_attention:
self.dp_size = server_args.dp_size
self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
compute_dp_attention_world_info(
server_args.enable_dp_attention,
self.tp_rank,
self.tp_size,
self.dp_size,
)
)
if self.attn_tp_rank == 0:
self.recv_from_tokenizer = get_zmq_socket(
context, zmq.PULL, port_args.scheduler_input_ipc_name
)
......@@ -244,6 +255,7 @@ class Scheduler:
_,
) = self.tp_worker.get_worker_info()
self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
global_server_args_dict.update(worker_global_server_args_dict)
set_random_seed(self.random_seed)
......@@ -447,6 +459,10 @@ class Scheduler:
self.process_input_requests(recv_reqs)
batch = self.get_next_batch_to_run()
if self.server_args.enable_dp_attention: # TODO: simplify this
batch = self.prepare_dp_attn_batch(batch)
self.cur_batch = batch
if batch:
......@@ -479,7 +495,7 @@ class Scheduler:
def recv_requests(self) -> List[Req]:
"""Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
if self.tp_rank == 0 or self.server_args.enable_dp_attention:
if self.attn_tp_rank == 0:
recv_reqs = []
while True:
......@@ -491,7 +507,40 @@ class Scheduler:
else:
recv_reqs = None
if self.tp_size != 1 and not self.server_args.enable_dp_attention:
if self.server_args.enable_dp_attention:
if self.attn_tp_rank == 0:
work_reqs = [
req
for req in recv_reqs
if isinstance(
req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
)
]
control_reqs = [
req
for req in recv_reqs
if not isinstance(
req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
)
]
else:
work_reqs = None
control_reqs = None
if self.attn_tp_size != 1:
attn_tp_rank_0 = self.dp_rank * self.attn_tp_size
work_reqs = broadcast_pyobj(
work_reqs,
self.attn_tp_rank,
self.attn_tp_cpu_group,
src=attn_tp_rank_0,
)
if self.tp_size != 1:
control_reqs = broadcast_pyobj(
control_reqs, self.tp_rank, self.tp_cpu_group
)
recv_reqs = work_reqs + control_reqs
elif self.tp_size != 1:
recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
return recv_reqs
......@@ -887,7 +936,7 @@ class Scheduler:
self.being_chunked_req.is_being_chunked += 1
# Print stats
if self.tp_rank == 0:
if self.attn_tp_rank == 0:
self.log_prefill_stats(adder, can_run_list, running_bs, has_being_chunked)
# Create a new batch
......@@ -974,7 +1023,7 @@ class Scheduler:
self.forward_ct += 1
if self.is_generation:
if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
if batch.forward_mode.is_decode_or_idle() or batch.extend_num_tokens != 0:
if self.spec_algorithm.is_none():
model_worker_batch = batch.get_model_worker_batch()
logits_output, next_token_ids = (
......@@ -988,18 +1037,8 @@ class Scheduler:
num_accepted_tokens,
) = self.draft_worker.forward_batch_speculative_generation(batch)
self.num_generated_tokens += num_accepted_tokens
elif batch.forward_mode.is_idle():
model_worker_batch = batch.get_model_worker_batch()
self.tp_worker.forward_batch_idle(model_worker_batch)
return
else:
logits_output = None
if self.skip_tokenizer_init:
next_token_ids = torch.full(
(batch.batch_size(),), self.tokenizer.eos_token_id
)
else:
next_token_ids = torch.full((batch.batch_size(),), 0)
assert False, "batch.extend_num_tokens == 0, this is unexpected!"
batch.output_ids = next_token_ids
ret = logits_output, next_token_ids, model_worker_batch.bid
else: # embedding or reward model
......@@ -1016,6 +1055,9 @@ class Scheduler:
self.running_batch = None
elif batch.forward_mode.is_extend():
self.process_batch_result_prefill(batch, result)
elif batch.forward_mode.is_idle():
if self.enable_overlap:
self.tp_worker.resolve_batch_result(result[-1])
elif batch.forward_mode.is_dummy_first():
batch.next_batch_sampling_info.update_regex_vocab_mask()
self.current_stream.synchronize()
......@@ -1166,7 +1208,7 @@ class Scheduler:
self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
if (
self.tp_rank == 0
self.attn_tp_rank == 0
and self.forward_ct_decode % self.server_args.decode_log_interval == 0
):
self.log_decode_stats()
......@@ -1402,12 +1444,7 @@ class Scheduler:
# Check forward mode for cuda graph
if not self.server_args.disable_cuda_graph:
forward_mode_state = torch.tensor(
(
1
if local_batch.forward_mode.is_decode()
or local_batch.forward_mode.is_idle()
else 0
),
(1 if local_batch.forward_mode.is_decode_or_idle() else 0),
dtype=torch.int32,
)
torch.distributed.all_reduce(
......
......@@ -101,6 +101,7 @@ class TpModelWorker:
self.max_total_num_tokens // 2
if server_args.max_running_requests is None
else server_args.max_running_requests
// (server_args.dp_size if server_args.enable_dp_attention else 1)
),
self.model_runner.req_to_token_pool.size,
)
......@@ -142,16 +143,15 @@ class TpModelWorker:
def get_tp_cpu_group(self):
return self.model_runner.tp_group.cpu_group
def get_attention_tp_cpu_group(self):
return self.model_runner.attention_tp_group.cpu_group
def get_memory_pool(self):
return (
self.model_runner.req_to_token_pool,
self.model_runner.token_to_kv_pool,
)
def forward_batch_idle(self, model_worker_batch: ModelWorkerBatch):
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
self.model_runner.forward(forward_batch)
def forward_batch_generation(
self,
model_worker_batch: ModelWorkerBatch,
......
......@@ -92,6 +92,9 @@ class TpModelWorkerClient:
def get_tp_cpu_group(self):
return self.worker.get_tp_cpu_group()
def get_attention_tp_cpu_group(self):
return self.worker.get_attention_tp_cpu_group()
def get_memory_pool(self):
return (
self.worker.model_runner.req_to_token_pool,
......
......@@ -122,6 +122,7 @@ class CudaGraphRunner:
self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
self.enable_dp_attention = self.model_runner.server_args.enable_dp_attention
self.tp_size = self.model_runner.tp_size
self.dp_size = self.model_runner.server_args.dp_size
# Batch sizes to capture
self.capture_bs = self.model_runner.server_args.cuda_graph_bs
......@@ -218,7 +219,7 @@ class CudaGraphRunner:
if self.enable_dp_attention:
self.gathered_buffer = torch.zeros(
(
self.max_bs * self.tp_size,
self.max_bs * self.dp_size,
self.model_runner.model_config.hidden_size,
),
dtype=self.model_runner.dtype,
......
......@@ -35,6 +35,10 @@ from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttn
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
from sglang.srt.layers.dp_attention import (
get_attention_tp_group,
initialize_dp_attention,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import Sampler
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
......@@ -235,11 +239,18 @@ class ModelRunner:
distributed_init_method=dist_init_method,
)
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
initialize_dp_attention(
enable_dp_attention=self.server_args.enable_dp_attention,
tp_rank=self.tp_rank,
tp_size=self.tp_size,
dp_size=self.server_args.dp_size,
)
min_per_gpu_memory = get_available_gpu_memory(
self.device, self.gpu_id, distributed=self.tp_size > 1
)
self.tp_group = get_tp_group()
self.attention_tp_group = get_attention_tp_group()
# Check memory for tensor parallelism
if self.tp_size > 1:
......
......@@ -855,10 +855,9 @@ class DeepseekV2ForCausalLM(nn.Module):
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch)
if not forward_batch.forward_mode.is_idle():
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
......
......@@ -239,15 +239,14 @@ class ServerArgs:
# Others
if self.enable_dp_attention:
assert self.tp_size % self.dp_size == 0
self.dp_size = self.tp_size
self.chunked_prefill_size = self.chunked_prefill_size // 2
self.schedule_conservativeness = self.schedule_conservativeness * 0.3
self.disable_overlap_schedule = True
logger.warning(
f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
"Data parallel size is adjusted to be the same as tensor parallel size. "
"Overlap scheduler is disabled."
)
# Speculative Decoding
......@@ -880,8 +879,8 @@ class ServerArgs:
self.tp_size % self.nnodes == 0
), "tp_size must be divisible by number of nodes"
assert not (
self.dp_size > 1 and self.nnodes != 1
), "multi-node data parallel is not supported"
self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
), "multi-node data parallel is not supported unless dp attention!"
assert (
self.max_loras_per_batch > 0
# FIXME
......@@ -919,6 +918,9 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
return server_args
ZMQ_TCP_PORT_DELTA = 233
@dataclasses.dataclass
class PortArgs:
# The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
......@@ -932,7 +934,7 @@ class PortArgs:
nccl_port: int
@staticmethod
def init_new(server_args) -> "PortArgs":
def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
port = server_args.port + random.randint(100, 1000)
while True:
if is_port_available(port):
......@@ -942,12 +944,39 @@ class PortArgs:
else:
port -= 43
return PortArgs(
tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
scheduler_input_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
detokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
nccl_port=port,
)
if not server_args.enable_dp_attention:
# Normal case, use IPC within a single node
return PortArgs(
tokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
scheduler_input_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
detokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
nccl_port=port,
)
else:
# DP attention. Use TCP + port to handle both single-node and multi-node.
if server_args.nnodes == 1 and server_args.dist_init_addr is None:
dist_init_addr = ("127.0.0.1", server_args.port + ZMQ_TCP_PORT_DELTA)
else:
dist_init_addr = server_args.dist_init_addr.split(":")
assert (
len(dist_init_addr) == 2
), "please provide --dist-init-addr as host:port of head node"
dist_init_host, dist_init_port = dist_init_addr
port_base = int(dist_init_port) + 1
if dp_rank is None:
scheduler_input_port = (
port_base + 2
) # TokenizerManager to DataParallelController
else:
scheduler_input_port = port_base + 2 + 1 + dp_rank
return PortArgs(
tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
nccl_port=port,
)
class LoRAPathAction(argparse.Action):
......
......@@ -802,11 +802,11 @@ def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint:
if socket_type == zmq.PUSH:
socket.setsockopt(zmq.SNDHWM, 0)
socket.setsockopt(zmq.SNDBUF, buf_size)
socket.connect(f"ipc://{endpoint}")
socket.connect(endpoint)
elif socket_type == zmq.PULL:
socket.setsockopt(zmq.RCVHWM, 0)
socket.setsockopt(zmq.RCVBUF, buf_size)
socket.bind(f"ipc://{endpoint}")
socket.bind(endpoint)
else:
raise ValueError(f"Unsupported socket type: {socket_type}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment