Unverified Commit 11383cec authored by Ying Sheng, committed by GitHub

[PP] Add pipeline parallelism (#5724)

parent e97e57e6
@@ -154,6 +154,8 @@ def load_model(server_args, port_args, tp_rank):
         gpu_id=tp_rank,
         tp_rank=tp_rank,
         tp_size=server_args.tp_size,
+        pp_rank=0,
+        pp_size=1,
         nccl_port=port_args.nccl_port,
         server_args=server_args,
     )
......
@@ -126,7 +126,6 @@ class Engine(EngineBase):
             server_args=server_args,
             port_args=port_args,
         )
         self.server_args = server_args
         self.tokenizer_manager = tokenizer_manager
         self.scheduler_info = scheduler_info
@@ -301,7 +300,6 @@ class Engine(EngineBase):
         internal_states = loop.run_until_complete(
             self.tokenizer_manager.get_internal_state()
         )
         return {
             **dataclasses.asdict(self.tokenizer_manager.server_args),
             **self.scheduler_info,
@@ -520,25 +518,44 @@ def _launch_subprocesses(
         )

         scheduler_pipe_readers = []
-        tp_size_per_node = server_args.tp_size // server_args.nnodes
+        nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
+        tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
         tp_rank_range = range(
-            tp_size_per_node * server_args.node_rank,
-            tp_size_per_node * (server_args.node_rank + 1),
+            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
+            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
         )
-        for tp_rank in tp_rank_range:
-            reader, writer = mp.Pipe(duplex=False)
-            gpu_id = (
-                server_args.base_gpu_id
-                + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
-            )
-            proc = mp.Process(
-                target=run_scheduler_process,
-                args=(server_args, port_args, gpu_id, tp_rank, None, writer),
-            )
-            with memory_saver_adapter.configure_subprocess():
-                proc.start()
-            scheduler_procs.append(proc)
-            scheduler_pipe_readers.append(reader)
+
+        pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
+        pp_rank_range = range(
+            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
+            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
+        )
+
+        for pp_rank in pp_rank_range:
+            for tp_rank in tp_rank_range:
+                reader, writer = mp.Pipe(duplex=False)
+                gpu_id = (
+                    server_args.base_gpu_id
+                    + ((pp_rank % pp_size_per_node) * tp_size_per_node)
+                    + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
+                )
+                proc = mp.Process(
+                    target=run_scheduler_process,
+                    args=(
+                        server_args,
+                        port_args,
+                        gpu_id,
+                        tp_rank,
+                        pp_rank,
+                        None,
+                        writer,
+                    ),
+                )
+                with memory_saver_adapter.configure_subprocess():
+                    proc.start()
+                scheduler_procs.append(proc)
+                scheduler_pipe_readers.append(reader)
     else:
         # Launch the data parallel controller
         reader, writer = mp.Pipe(duplex=False)
......
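Note: a standalone sketch (hypothetical node and GPU counts, not part of the patch) of the rank-to-GPU mapping that the nested pp/tp loops above implement. Each node owns one contiguous slice of TP ranks and one slice of PP stages, and gpu_id is derived from the local (pp_rank, tp_rank) position.

# Sketch only: reproduces the formulas above with made-up values
# (2 nodes, tp_size=8, pp_size=2, 8 GPUs per node).
def plan_node(node_rank, nnodes=2, tp_size=8, pp_size=2, base_gpu_id=0, gpu_id_step=1):
    nnodes_per_tp_group = max(nnodes // pp_size, 1)      # 1 node per TP group
    tp_size_per_node = tp_size // nnodes_per_tp_group    # 8 TP ranks per node
    pp_size_per_node = max(pp_size // nnodes, 1)         # 1 PP stage per node

    tp_rank_range = range(
        tp_size_per_node * (node_rank % nnodes_per_tp_group),
        tp_size_per_node * (node_rank % nnodes_per_tp_group + 1),
    )
    pp_rank_range = range(
        pp_size_per_node * (node_rank // nnodes_per_tp_group),
        pp_size_per_node * (node_rank // nnodes_per_tp_group + 1),
    )

    plan = []
    for pp_rank in pp_rank_range:
        for tp_rank in tp_rank_range:
            gpu_id = (
                base_gpu_id
                + (pp_rank % pp_size_per_node) * tp_size_per_node
                + (tp_rank % tp_size_per_node) * gpu_id_step
            )
            plan.append((pp_rank, tp_rank, gpu_id))
    return plan

print(plan_node(0))  # node 0 hosts pp_rank 0, tp_ranks 0-7 on GPUs 0-7
print(plan_node(1))  # node 1 hosts pp_rank 1, tp_ranks 0-7 on GPUs 0-7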
@@ -43,6 +43,7 @@ def initialize_dp_attention(
     tp_rank: int,
     tp_size: int,
     dp_size: int,
+    pp_size: int,
 ):
     global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
@@ -53,17 +54,19 @@ def initialize_dp_attention(
     )

     if enable_dp_attention:
+        local_rank = tp_rank % (tp_size // dp_size)
         _DP_SIZE = dp_size
     else:
+        local_rank = tp_rank
         _DP_SIZE = 1

     tp_group = get_tp_group()
     _ATTN_TP_GROUP = GroupCoordinator(
         [
             list(range(head, head + _ATTN_TP_SIZE))
-            for head in range(0, tp_size, _ATTN_TP_SIZE)
+            for head in range(0, pp_size * tp_size, _ATTN_TP_SIZE)
         ],
-        tp_group.local_rank,
+        local_rank,
         torch.distributed.get_backend(tp_group.device_group),
         SYNC_TOKEN_IDS_ACROSS_TP,
         False,
......
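Note: an illustrative snippet (hypothetical sizes) of the attention-TP group layout built above. With pipeline parallelism, the rank groups now tile the full pp_size * tp_size world instead of a single TP group.

# Illustration only: tp_size=4, pp_size=2, attention-TP size 2.
tp_size, pp_size, attn_tp_size = 4, 2, 2
groups = [
    list(range(head, head + attn_tp_size))
    for head in range(0, pp_size * tp_size, attn_tp_size)
]
print(groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]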
import logging
import re
import torch
logger = logging.getLogger(__name__)
def get_layer_id(weight_name):
# example weight name: model.layers.10.self_attn.qkv_proj.weight
match = re.search(r"layers\.(\d+)\.", weight_name)
if match:
return int(match.group(1))
return None
class PPMissingLayer(torch.nn.Identity):
# Adapted from
# https://github.com/vllm-project/vllm/blob/18ed3132d2bfe1df9a74729457b69243955221e8/vllm/model_executor/models/utils.py#L468C1-L486C1
"""
A placeholder layer for missing layers in a pipeline parallel model.
"""
def __init__(self, *args, **kwargs):
super().__init__()
self.return_tuple = kwargs.get("return_tuple", False)
def forward(self, *args, **kwargs):
"""
Return the first arg from args or the first value from kwargs.
Wraps the input in a tuple if `self.return_tuple` is True.
"""
input = args[0] if args else next(iter(kwargs.values()))
return (input,) if self.return_tuple else input
@@ -181,44 +181,62 @@ class DataParallelController:
             enable=server_args.enable_memory_saver
         )

-        # Launch tensor parallel scheduler processes
         scheduler_pipe_readers = []
-        tp_size_per_node = server_args.tp_size // server_args.nnodes
+
+        nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
+        tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
         tp_rank_range = range(
-            tp_size_per_node * server_args.node_rank,
-            tp_size_per_node * (server_args.node_rank + 1),
+            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
+            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
         )
-        for tp_rank in tp_rank_range:
-            rank_port_args = port_args
-
-            if server_args.enable_dp_attention:
-                # dp attention has different sharding logic
-                _, _, dp_rank = compute_dp_attention_world_info(
-                    server_args.enable_dp_attention,
-                    tp_rank,
-                    server_args.tp_size,
-                    server_args.dp_size,
-                )
-                # compute zmq ports for this dp rank
-                rank_port_args = PortArgs.init_new(server_args, dp_rank)
-                # Data parallelism reuses the tensor parallelism group,
-                # so all dp ranks should use the same nccl port.
-                rank_port_args.nccl_port = port_args.nccl_port
-
-            reader, writer = mp.Pipe(duplex=False)
-            gpu_id = (
-                server_args.base_gpu_id
-                + base_gpu_id
-                + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
-            )
-            proc = mp.Process(
-                target=run_scheduler_process,
-                args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer),
-            )
-            with memory_saver_adapter.configure_subprocess():
-                proc.start()
-            self.scheduler_procs.append(proc)
-            scheduler_pipe_readers.append(reader)
+
+        pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
+        pp_rank_range = range(
+            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
+            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
+        )
+
+        for pp_rank in pp_rank_range:
+            for tp_rank in tp_rank_range:
+                rank_port_args = port_args
+
+                if server_args.enable_dp_attention:
+                    # dp attention has different sharding logic
+                    _, _, dp_rank = compute_dp_attention_world_info(
+                        server_args.enable_dp_attention,
+                        tp_rank,
+                        server_args.tp_size,
+                        server_args.dp_size,
+                    )
+                    # compute zmq ports for this dp rank
+                    rank_port_args = PortArgs.init_new(server_args, dp_rank)
+                    # Data parallelism reuses the tensor parallelism group,
+                    # so all dp ranks should use the same nccl port.
+                    rank_port_args.nccl_port = port_args.nccl_port
+
+                reader, writer = mp.Pipe(duplex=False)
+                gpu_id = (
+                    server_args.base_gpu_id
+                    + base_gpu_id
+                    + ((pp_rank % pp_size_per_node) * tp_size_per_node)
+                    + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
+                )
+                proc = mp.Process(
+                    target=run_scheduler_process,
+                    args=(
+                        server_args,
+                        rank_port_args,
+                        gpu_id,
+                        tp_rank,
+                        pp_rank,
+                        dp_rank,
+                        writer,
+                    ),
+                )
+                with memory_saver_adapter.configure_subprocess():
+                    proc.start()
+                self.scheduler_procs.append(proc)
+                scheduler_pipe_readers.append(reader)

         # Wait for model to finish loading
         scheduler_info = []
......
@@ -66,23 +66,24 @@ INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 # Put some global args for easy access
 global_server_args_dict = {
     "attention_backend": ServerArgs.attention_backend,
-    "sampling_backend": ServerArgs.sampling_backend,
-    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
-    "torchao_config": ServerArgs.torchao_config,
-    "enable_nan_detection": ServerArgs.enable_nan_detection,
-    "enable_dp_attention": ServerArgs.enable_dp_attention,
-    "enable_ep_moe": ServerArgs.enable_ep_moe,
-    "enable_deepep_moe": ServerArgs.enable_deepep_moe,
+    "chunked_prefill_size": ServerArgs.chunked_prefill_size,
     "deepep_mode": ServerArgs.deepep_mode,
     "device": ServerArgs.device,
-    "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
-    "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
+    "disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
     "disable_radix_cache": ServerArgs.disable_radix_cache,
+    "enable_deepep_moe": ServerArgs.enable_deepep_moe,
+    "enable_dp_attention": ServerArgs.enable_dp_attention,
+    "enable_ep_moe": ServerArgs.enable_ep_moe,
+    "enable_nan_detection": ServerArgs.enable_nan_detection,
     "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
+    "max_micro_batch_size": ServerArgs.max_micro_batch_size,
     "moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
-    "chunked_prefill_size": ServerArgs.chunked_prefill_size,
     "n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
-    "disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
+    "sampling_backend": ServerArgs.sampling_backend,
+    "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
+    "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
+    "torchao_config": ServerArgs.torchao_config,
+    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
 }

 logger = logging.getLogger(__name__)
@@ -728,6 +729,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # Events
     launch_done: Optional[threading.Event] = None

+    # For chunked prefill in PP
+    chunked_req: Optional[Req] = None
+
     # Sampling info
     sampling_info: SamplingBatchInfo = None
     next_batch_sampling_info: SamplingBatchInfo = None
@@ -761,7 +765,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # For extend and mixed chunked prefill
     prefix_lens: List[int] = None
     extend_lens: List[int] = None
-    extend_num_tokens: int = None
+    extend_num_tokens: Optional[int] = None
     decoding_reqs: List[Req] = None
     extend_logprob_start_lens: List[int] = None
     # It comes as an empty list if logprob is not required.
@@ -803,6 +807,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         enable_overlap: bool,
         spec_algorithm: SpeculativeAlgorithm,
         enable_custom_logit_processor: bool,
+        chunked_req: Optional[Req] = None,
     ):
         return_logprob = any(req.return_logprob for req in reqs)
@@ -820,6 +825,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             spec_algorithm=spec_algorithm,
             enable_custom_logit_processor=enable_custom_logit_processor,
             return_hidden_states=any(req.return_hidden_states for req in reqs),
+            chunked_req=chunked_req,
         )

     def batch_size(self):
@@ -1236,7 +1242,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):

     def retract_decode(self, server_args: ServerArgs):
         """Retract the decoding requests when there is not enough memory."""
-        sorted_indices = [i for i in range(len(self.reqs))]
+        sorted_indices = list(range(len(self.reqs)))

         # TODO(lsyin): improve retraction policy for radix cache
         # For spec decoding, filter_batch API can only filter
@@ -1413,15 +1419,19 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):

     def filter_batch(
         self,
-        chunked_req_to_exclude: Optional[Req] = None,
+        chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None,
         keep_indices: Optional[List[int]] = None,
     ):
         if keep_indices is None:
+            if isinstance(chunked_req_to_exclude, Req):
+                chunked_req_to_exclude = [chunked_req_to_exclude]
+            elif chunked_req_to_exclude is None:
+                chunked_req_to_exclude = []
             keep_indices = [
                 i
                 for i in range(len(self.reqs))
                 if not self.reqs[i].finished()
-                and self.reqs[i] is not chunked_req_to_exclude
+                and not self.reqs[i] in chunked_req_to_exclude
             ]

         if keep_indices is None or len(keep_indices) == 0:
......
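Note: a standalone illustration (with a stand-in Req class) of the new chunked_req_to_exclude normalization above. filter_batch now accepts either a single request or a list of requests, which matters under pipeline parallelism where several in-flight microbatches can each carry their own chunked request.

# Illustration only: the same single/list/None normalization as filter_batch.
class Req:  # stand-in for sglang's Req
    pass

def normalize(chunked_req_to_exclude):
    if isinstance(chunked_req_to_exclude, Req):
        return [chunked_req_to_exclude]
    elif chunked_req_to_exclude is None:
        return []
    return chunked_req_to_exclude

r1, r2 = Req(), Req()
assert normalize(None) == []
assert normalize(r1) == [r1]
assert normalize([r1, r2]) == [r1, r2]  # PP: one chunked req per microbatch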
@@ -278,7 +278,7 @@ class SchedulerOutputProcessorMixin:
             self.attn_tp_rank == 0
             and self.forward_ct_decode % self.server_args.decode_log_interval == 0
         ):
-            self.log_decode_stats()
+            self.log_decode_stats(running_batch=batch)

     def add_input_logprob_return_values(
         self: Scheduler,
......
@@ -15,11 +15,12 @@

 import logging
 import threading
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union

 import torch

 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.distributed import get_pp_group, get_tp_group, get_world_group
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
@@ -31,7 +32,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
@@ -47,6 +48,7 @@ class TpModelWorker:
         server_args: ServerArgs,
         gpu_id: int,
         tp_rank: int,
+        pp_rank: int,
         dp_rank: Optional[int],
         nccl_port: int,
         is_draft_worker: bool = False,
@@ -54,7 +56,9 @@ class TpModelWorker:
         token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
     ):
         # Parse args
+        self.tp_size = server_args.tp_size
         self.tp_rank = tp_rank
+        self.pp_rank = pp_rank

         # Init model and tokenizer
         self.model_config = ModelConfig(
@@ -73,12 +77,15 @@ class TpModelWorker:
             quantization=server_args.quantization,
             is_draft_model=is_draft_worker,
         )

         self.model_runner = ModelRunner(
             model_config=self.model_config,
             mem_fraction_static=server_args.mem_fraction_static,
             gpu_id=gpu_id,
             tp_rank=tp_rank,
             tp_size=server_args.tp_size,
+            pp_rank=pp_rank,
+            pp_size=server_args.pp_size,
             nccl_port=nccl_port,
             server_args=server_args,
             is_draft_worker=is_draft_worker,
@@ -105,6 +112,10 @@ class TpModelWorker:
         )
         self.device = self.model_runner.device

+        # Init nccl groups
+        self.pp_group = get_pp_group()
+        self.world_group = get_world_group()
+
         # Profile number of tokens
         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
         self.max_prefill_tokens = server_args.max_prefill_tokens
@@ -130,8 +141,9 @@ class TpModelWorker:
         # Sync random seed across TP workers
         self.random_seed = broadcast_pyobj(
             [server_args.random_seed],
-            self.tp_rank,
-            self.model_runner.tp_group.cpu_group,
+            self.tp_size * self.pp_rank + tp_rank,
+            self.world_group.cpu_group,
+            src=self.world_group.ranks[0],
         )[0]
         set_random_seed(self.random_seed)
@@ -156,11 +168,14 @@ class TpModelWorker:
     def get_pad_input_ids_func(self):
         return getattr(self.model_runner.model, "pad_input_ids", None)

-    def get_tp_cpu_group(self):
-        return self.model_runner.tp_group.cpu_group
+    def get_tp_group(self):
+        return self.model_runner.tp_group
+
+    def get_attention_tp_group(self):
+        return self.model_runner.attention_tp_group

     def get_attention_tp_cpu_group(self):
-        return self.model_runner.attention_tp_group.cpu_group
+        return getattr(self.model_runner.attention_tp_group, "cpu_group", None)

     def get_memory_pool(self):
         return (
@@ -172,19 +187,38 @@ class TpModelWorker:
         self,
         model_worker_batch: ModelWorkerBatch,
         skip_sample: bool = False,
-    ) -> Tuple[LogitsProcessorOutput, Optional[torch.Tensor]]:
+    ) -> Tuple[Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor]]:
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        logits_output = self.model_runner.forward(forward_batch)
-        if model_worker_batch.launch_done is not None:
-            model_worker_batch.launch_done.set()

-        if skip_sample:
-            next_token_ids = None
-        else:
-            next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
+        pp_proxy_tensors = None
+        if not self.pp_group.is_first_rank:
+            pp_proxy_tensors = PPProxyTensors(
+                self.pp_group.recv_tensor_dict(
+                    all_gather_group=self.get_attention_tp_group()
+                )
+            )
+
+        if self.pp_group.is_last_rank:
+            logits_output = self.model_runner.forward(
+                forward_batch, pp_proxy_tensors=pp_proxy_tensors
+            )
+            if model_worker_batch.launch_done is not None:
+                model_worker_batch.launch_done.set()
+
+            if skip_sample:
+                next_token_ids = None
+            else:
+                next_token_ids = self.model_runner.sample(
+                    logits_output, model_worker_batch
+                )

-        return logits_output, next_token_ids
+            return logits_output, next_token_ids
+        else:
+            pp_proxy_tensors = self.model_runner.forward(
+                forward_batch,
+                pp_proxy_tensors=pp_proxy_tensors,
+            )
+            return pp_proxy_tensors.tensors, None

     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
......
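Note: a toy, single-process sketch of the control flow above (no real NCCL; the actual worker receives proxy tensors with pp_group.recv_tensor_dict and the sending side is presumed to use a matching send on the scheduler path, which is not shown in this excerpt). Every stage except the first consumes proxy tensors, every stage except the last returns proxy tensors instead of logits, and only the last stage samples.

# Toy stand-in for forward_batch_generation's PP branching; shapes are made up.
import torch

def run_stage(pp_rank, pp_size, input_ids=None, proxy=None):
    is_first, is_last = pp_rank == 0, pp_rank == pp_size - 1
    hidden = torch.randn(input_ids.shape[0], 16) if is_first else proxy["hidden_states"]
    hidden = hidden + 1.0                      # stand-in for this stage's layers
    if is_last:
        logits = hidden @ torch.randn(16, 32)  # stand-in for the LM head
        next_token_ids = logits.argmax(dim=-1)
        return logits, next_token_ids
    return {"hidden_states": hidden, "residual": torch.zeros_like(hidden)}, None

proxy, _ = run_stage(0, 2, input_ids=torch.arange(4))
logits, next_token_ids = run_stage(1, 2, proxy=proxy)
print(next_token_ids.shape)  # torch.Size([4])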
@@ -56,11 +56,14 @@ class TpModelWorkerClient:
         server_args: ServerArgs,
         gpu_id: int,
         tp_rank: int,
+        pp_rank: int,
         dp_rank: Optional[int],
         nccl_port: int,
     ):
         # Load the model
-        self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
+        self.worker = TpModelWorker(
+            server_args, gpu_id, tp_rank, pp_rank, dp_rank, nccl_port
+        )
         self.max_running_requests = self.worker.max_running_requests
         self.device = self.worker.device
         self.gpu_id = gpu_id
@@ -91,8 +94,11 @@ class TpModelWorkerClient:
     def get_pad_input_ids_func(self):
         return self.worker.get_pad_input_ids_func()

-    def get_tp_cpu_group(self):
-        return self.worker.get_tp_cpu_group()
+    def get_tp_group(self):
+        return self.worker.get_tp_group()
+
+    def get_attention_tp_group(self):
+        return self.worker.get_attention_tp_group()

     def get_attention_tp_cpu_group(self):
         return self.worker.get_attention_tp_cpu_group()
......
@@ -214,6 +214,8 @@ class MHATokenToKVPool(KVCache):
         layer_num: int,
         device: str,
         enable_memory_saver: bool,
+        start_layer: Optional[int] = None,
+        end_layer: Optional[int] = None,
     ):
         self.size = size
         self.page_size = page_size
@@ -232,6 +234,8 @@ class MHATokenToKVPool(KVCache):
         self.head_dim = head_dim
         self.layer_num = layer_num
         self._create_buffers()
+        self.start_layer = start_layer or 0
+        self.end_layer = end_layer or layer_num - 1

         self.layer_transfer_counter = None
         self.capture_mode = False
@@ -281,6 +285,8 @@ class MHATokenToKVPool(KVCache):
     # for disagg
     def get_contiguous_buf_infos(self):
+        # layer_num x [seq_len, head_num, head_dim]
+        # layer_num x [page_num, page_size, head_num, head_dim]
         kv_data_ptrs = [
             self.get_key_buffer(i).data_ptr() for i in range(self.layer_num)
         ] + [self.get_value_buffer(i).data_ptr() for i in range(self.layer_num)]
@@ -320,24 +326,24 @@ class MHATokenToKVPool(KVCache):
         # transfer prepared data from host to device
         flat_data = flat_data.to(device=self.device, non_blocking=False)
         k_data, v_data = flat_data[0], flat_data[1]
-        self.k_buffer[layer_id][indices] = k_data
-        self.v_buffer[layer_id][indices] = v_data
+        self.k_buffer[layer_id - self.start_layer][indices] = k_data
+        self.v_buffer[layer_id - self.start_layer][indices] = v_data

     def get_key_buffer(self, layer_id: int):
         if self.layer_transfer_counter is not None:
-            self.layer_transfer_counter.wait_until(layer_id)
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)

         if self.store_dtype != self.dtype:
-            return self.k_buffer[layer_id].view(self.dtype)
-        return self.k_buffer[layer_id]
+            return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
+        return self.k_buffer[layer_id - self.start_layer]

     def get_value_buffer(self, layer_id: int):
         if self.layer_transfer_counter is not None:
-            self.layer_transfer_counter.wait_until(layer_id)
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)

         if self.store_dtype != self.dtype:
-            return self.v_buffer[layer_id].view(self.dtype)
-        return self.v_buffer[layer_id]
+            return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
+        return self.v_buffer[layer_id - self.start_layer]

     def get_kv_buffer(self, layer_id: int):
         return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
@@ -369,12 +375,12 @@ class MHATokenToKVPool(KVCache):
             current_stream = self.device_module.current_stream()
             self.alt_stream.wait_stream(current_stream)
             with self.device_module.stream(self.alt_stream):
-                self.k_buffer[layer_id][loc] = cache_k
-                self.v_buffer[layer_id][loc] = cache_v
+                self.k_buffer[layer_id - self.start_layer][loc] = cache_k
+                self.v_buffer[layer_id - self.start_layer][loc] = cache_v
             current_stream.wait_stream(self.alt_stream)
         else:
-            self.k_buffer[layer_id][loc] = cache_k
-            self.v_buffer[layer_id][loc] = cache_v
+            self.k_buffer[layer_id - self.start_layer][loc] = cache_k
+            self.v_buffer[layer_id - self.start_layer][loc] = cache_v


 @torch.compile
@@ -484,6 +490,8 @@ class MLATokenToKVPool(KVCache):
         layer_num: int,
         device: str,
         enable_memory_saver: bool,
+        start_layer: Optional[int] = None,
+        end_layer: Optional[int] = None,
     ):
         self.size = size
         self.page_size = page_size
@@ -497,6 +505,8 @@ class MLATokenToKVPool(KVCache):
         self.kv_lora_rank = kv_lora_rank
         self.qk_rope_head_dim = qk_rope_head_dim
         self.layer_num = layer_num
+        self.start_layer = start_layer or 0
+        self.end_layer = end_layer or layer_num - 1

         memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=enable_memory_saver
@@ -540,19 +550,21 @@ class MLATokenToKVPool(KVCache):

     def get_key_buffer(self, layer_id: int):
         if self.layer_transfer_counter is not None:
-            self.layer_transfer_counter.wait_until(layer_id)
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)

         if self.store_dtype != self.dtype:
-            return self.kv_buffer[layer_id].view(self.dtype)
-        return self.kv_buffer[layer_id]
+            return self.kv_buffer[layer_id - self.start_layer].view(self.dtype)
+        return self.kv_buffer[layer_id - self.start_layer]

     def get_value_buffer(self, layer_id: int):
         if self.layer_transfer_counter is not None:
-            self.layer_transfer_counter.wait_until(layer_id)
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)

         if self.store_dtype != self.dtype:
-            return self.kv_buffer[layer_id][..., : self.kv_lora_rank].view(self.dtype)
-        return self.kv_buffer[layer_id][..., : self.kv_lora_rank]
+            return self.kv_buffer[layer_id - self.start_layer][
+                ..., : self.kv_lora_rank
+            ].view(self.dtype)
+        return self.kv_buffer[layer_id - self.start_layer][..., : self.kv_lora_rank]

     def get_kv_buffer(self, layer_id: int):
         return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
@@ -568,9 +580,11 @@ class MLATokenToKVPool(KVCache):
         if cache_k.dtype != self.dtype:
             cache_k = cache_k.to(self.dtype)
         if self.store_dtype != self.dtype:
-            self.kv_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
+            self.kv_buffer[layer_id - self.start_layer][loc] = cache_k.view(
+                self.store_dtype
+            )
         else:
-            self.kv_buffer[layer_id][loc] = cache_k
+            self.kv_buffer[layer_id - self.start_layer][loc] = cache_k

     def set_mla_kv_buffer(
         self,
@@ -605,7 +619,7 @@ class MLATokenToKVPool(KVCache):
     def transfer_per_layer(self, indices, flat_data, layer_id):
         # transfer prepared data from host to device
         flat_data = flat_data.to(device=self.device, non_blocking=False)
-        self.kv_buffer[layer_id][indices] = flat_data
+        self.kv_buffer[layer_id - self.start_layer][indices] = flat_data
 class DoubleSparseTokenToKVPool(KVCache):
@@ -620,6 +634,8 @@ class DoubleSparseTokenToKVPool(KVCache):
         device: str,
         heavy_channel_num: int,
         enable_memory_saver: bool,
+        start_layer: Optional[int] = None,
+        end_layer: Optional[int] = None,
     ):
         self.size = size
         self.page_size = page_size
@@ -657,17 +673,23 @@ class DoubleSparseTokenToKVPool(KVCache):
                 for _ in range(layer_num)
             ]

+        self.start_layer = start_layer or 0
+        self.end_layer = end_layer or layer_num - 1
+
     def get_key_buffer(self, layer_id: int):
-        return self.k_buffer[layer_id]
+        return self.k_buffer[layer_id - self.start_layer]

     def get_value_buffer(self, layer_id: int):
-        return self.v_buffer[layer_id]
+        return self.v_buffer[layer_id - self.start_layer]

     def get_label_buffer(self, layer_id: int):
-        return self.label_buffer[layer_id]
+        return self.label_buffer[layer_id - self.start_layer]

     def get_kv_buffer(self, layer_id: int):
-        return self.k_buffer[layer_id], self.v_buffer[layer_id]
+        return (
+            self.k_buffer[layer_id - self.start_layer],
+            self.v_buffer[layer_id - self.start_layer],
+        )

     def set_kv_buffer(
         self,
@@ -679,9 +701,9 @@ class DoubleSparseTokenToKVPool(KVCache):
     ):
         # NOTE(Andy): ignore the dtype check
         layer_id = layer.layer_id
-        self.k_buffer[layer_id][loc] = cache_k
-        self.v_buffer[layer_id][loc] = cache_v
-        self.label_buffer[layer_id][loc] = cache_label
+        self.k_buffer[layer_id - self.start_layer][loc] = cache_k
+        self.v_buffer[layer_id - self.start_layer][loc] = cache_v
+        self.label_buffer[layer_id - self.start_layer][loc] = cache_label

     def get_flat_data(self, indices):
         pass
@@ -930,7 +952,7 @@ class MHATokenToKVPoolHost(HostKVCache):
         return self.kv_buffer[:, :, indices]

     def get_flat_data_by_layer(self, indices, layer_id):
-        return self.kv_buffer[:, layer_id, indices]
+        return self.kv_buffer[:, layer_id - self.start_layer, indices]

     def assign_flat_data(self, indices, flat_data):
         self.kv_buffer[:, :, indices] = flat_data
@@ -955,12 +977,20 @@ class MHATokenToKVPoolHost(HostKVCache):
         for i in range(len(device_indices_cpu)):
             h_index = host_indices[i * self.page_size]
             d_index = device_indices_cpu[i]
-            device_pool.k_buffer[layer_id][d_index : d_index + self.page_size].copy_(
-                self.kv_buffer[0, layer_id, h_index : h_index + self.page_size],
+            device_pool.k_buffer[layer_id - self.start_layer][
+                d_index : d_index + self.page_size
+            ].copy_(
+                self.kv_buffer[
+                    0, layer_id - self.start_layer, h_index : h_index + self.page_size
+                ],
                 non_blocking=True,
             )
-            device_pool.v_buffer[layer_id][d_index : d_index + self.page_size].copy_(
-                self.kv_buffer[1, layer_id, h_index : h_index + self.page_size],
+            device_pool.v_buffer[layer_id - self.start_layer][
+                d_index : d_index + self.page_size
+            ].copy_(
+                self.kv_buffer[
+                    1, layer_id - self.start_layer, h_index : h_index + self.page_size
+                ],
                 non_blocking=True,
             )
@@ -1015,7 +1045,7 @@ class MLATokenToKVPoolHost(HostKVCache):
         return self.kv_buffer[:, indices]

     def get_flat_data_by_layer(self, indices, layer_id):
-        return self.kv_buffer[layer_id, indices]
+        return self.kv_buffer[layer_id - self.start_layer, indices]

     def assign_flat_data(self, indices, flat_data):
         self.kv_buffer[:, indices] = flat_data
@@ -1036,7 +1066,11 @@ class MLATokenToKVPoolHost(HostKVCache):
         for i in range(len(device_indices_cpu)):
             h_index = host_indices[i * self.page_size]
             d_index = device_indices_cpu[i]
-            device_pool.kv_buffer[layer_id][d_index : d_index + self.page_size].copy_(
-                self.kv_buffer[layer_id, h_index : h_index + self.page_size],
+            device_pool.kv_buffer[layer_id - self.start_layer][
+                d_index : d_index + self.page_size
+            ].copy_(
+                self.kv_buffer[
+                    layer_id - self.start_layer, h_index : h_index + self.page_size
+                ],
                 non_blocking=True,
            )
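Note: a minimal illustration (hypothetical layer counts) of why the pools above subtract start_layer: each pipeline stage allocates KV buffers only for its own layer range, while attention layers keep their global layer_id.

# Illustration only: stage 1 of a 2-stage split of a 32-layer model.
num_hidden_layers, pp_rank, pp_size = 32, 1, 2
per_stage = num_hidden_layers // pp_size
start_layer, end_layer = pp_rank * per_stage, (pp_rank + 1) * per_stage  # 16, 32

# The pool only holds buffers for the 16 layers this stage owns.
k_buffer = [f"buf-for-layer-{i}" for i in range(start_layer, end_layer)]

global_layer_id = 20                        # a layer owned by this stage
local_index = global_layer_id - start_layer # index into the local buffer list
assert k_buffer[local_index] == "buf-for-layer-20"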
@@ -16,6 +16,7 @@
 from __future__ import annotations

 import bisect
+import inspect
 import os
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable
@@ -33,12 +34,14 @@ from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,
     ForwardMode,
+    PPProxyTensors,
 )
 from sglang.srt.patch_torch import monkey_patch_torch_compile
 from sglang.srt.utils import (
     get_available_gpu_memory,
     get_device_memory_capacity,
     is_hip,
+    rank0_log,
 )

 if TYPE_CHECKING:
@@ -188,10 +191,11 @@ class CudaGraphRunner:
         self.speculative_algorithm = model_runner.server_args.speculative_algorithm
         self.tp_size = model_runner.server_args.tp_size
         self.dp_size = model_runner.server_args.dp_size
+        self.pp_size = model_runner.server_args.pp_size

         # Batch sizes to capture
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
+        rank0_log(f"Capture cuda graph bs {self.capture_bs}")
         self.capture_forward_mode = ForwardMode.DECODE
         self.capture_hidden_mode = CaptureHiddenMode.NULL
         self.num_tokens_per_bs = 1
@@ -234,6 +238,19 @@ class CudaGraphRunner:
             self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)

+            # pipeline parallelism
+            if self.pp_size > 1:
+                self.pp_proxy_tensors = {
+                    "hidden_states": torch.zeros(
+                        (self.max_bs, self.model_runner.model_config.hidden_size),
+                        dtype=torch.bfloat16,
+                    ),
+                    "residual": torch.zeros(
+                        (self.max_bs, self.model_runner.model_config.hidden_size),
+                        dtype=torch.bfloat16,
+                    ),
+                }
+
             # Speculative_inference
             if (
                 model_runner.spec_algorithm.is_eagle3()
@@ -384,6 +401,12 @@ class CudaGraphRunner:
             encoder_lens = None
         mrope_positions = self.mrope_positions[:, :bs]

+        # pipeline parallelism
+        if self.pp_size > 1:
+            pp_proxy_tensors = PPProxyTensors(
+                {k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()}
+            )
+
         if self.enable_dp_attention or self.enable_sp_layernorm:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
@@ -456,8 +479,20 @@ class CudaGraphRunner:
             # Clean intermediate result cache for DP attention
             forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None

-            logits_output = forward(input_ids, forward_batch.positions, forward_batch)
-            return logits_output.next_token_logits, logits_output.hidden_states
+            kwargs = {}
+            if (
+                self.pp_size > 1
+                and "pp_proxy_tensors" in inspect.signature(forward).parameters
+            ):
+                kwargs["pp_proxy_tensors"] = pp_proxy_tensors
+
+            logits_output_or_pp_proxy_tensors = forward(
+                input_ids,
+                forward_batch.positions,
+                forward_batch,
+                **kwargs,
+            )
+            return logits_output_or_pp_proxy_tensors

         for _ in range(2):
             torch.cuda.synchronize()
@@ -490,7 +525,11 @@ class CudaGraphRunner:
                 self.capture_hidden_mode = hidden_mode_from_spec_info
                 self.capture()

-    def replay_prepare(self, forward_batch: ForwardBatch):
+    def replay_prepare(
+        self,
+        forward_batch: ForwardBatch,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ):
         self.recapture_if_needed(forward_batch)

         raw_bs = forward_batch.batch_size
@@ -519,6 +558,11 @@ class CudaGraphRunner:
             self.seq_lens_cpu.fill_(1)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)

+        if pp_proxy_tensors:
+            for key in self.pp_proxy_tensors.keys():
+                dim = pp_proxy_tensors[key].shape[0]
+                self.pp_proxy_tensors[key][:dim].copy_(pp_proxy_tensors[key])
+
         if self.is_encoder_decoder:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
@@ -547,10 +591,13 @@ class CudaGraphRunner:
         self.bs = bs

     def replay(
-        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
-    ) -> LogitsProcessorOutput:
+        self,
+        forward_batch: ForwardBatch,
+        skip_attn_backend_init: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
         if not skip_attn_backend_init:
-            self.replay_prepare(forward_batch)
+            self.replay_prepare(forward_batch, pp_proxy_tensors)
         else:
             # In speculative decoding, these two fields are still needed.
             self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
@@ -558,17 +605,19 @@ class CudaGraphRunner:
         # Replay
         self.graphs[self.bs].replay()

-        next_token_logits, hidden_states = self.output_buffers[self.bs]
-
-        logits_output = LogitsProcessorOutput(
-            next_token_logits=next_token_logits[: self.raw_num_token],
-            hidden_states=(
-                hidden_states[: self.raw_num_token]
-                if hidden_states is not None
-                else None
-            ),
-        )
-        return logits_output
+        output = self.output_buffers[self.bs]
+        if isinstance(output, LogitsProcessorOutput):
+            return LogitsProcessorOutput(
+                next_token_logits=output.next_token_logits[: self.raw_num_token],
+                hidden_states=(
+                    output.hidden_states[: self.raw_num_token]
+                    if output.hidden_states is not None
+                    else None
+                ),
+            )
+        else:
+            assert isinstance(output, PPProxyTensors)
+            return PPProxyTensors({k: v[: self.bs] for k, v in output.tensors.items()})

     def get_spec_info(self, num_tokens: int):
         spec_info = None
......
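Note: a standalone illustration (hypothetical shapes) of why replay_prepare() above copies incoming proxy tensors into the pre-allocated pp_proxy_tensors buffers: a captured CUDA graph always reads and writes the same memory addresses, so variable-size inputs must be staged into fixed max-batch-size tensors and the result sliced back down, which is also why replay() trims its output to the real batch size.

# Illustration only: staging into a fixed buffer, then slicing the result.
import torch

max_bs, hidden_size = 160, 4096
static_buf = {
    "hidden_states": torch.zeros(max_bs, hidden_size, dtype=torch.bfloat16),
    "residual": torch.zeros(max_bs, hidden_size, dtype=torch.bfloat16),
}

incoming = {k: torch.randn(8, hidden_size).to(torch.bfloat16) for k in static_buf}
for key, value in incoming.items():
    static_buf[key][: value.shape[0]].copy_(value)  # same trick as replay_prepare

# ... graph.replay() would run here, reading static_buf in place ...
output = {k: v[:8] for k, v in static_buf.items()}  # slice back to the real batch
print(output["hidden_states"].shape)                # torch.Size([8, 4096])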
@@ -31,7 +31,7 @@ from __future__ import annotations

 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union

 import torch
 import triton
@@ -585,6 +585,36 @@ class ForwardBatch:
             self.prepare_chunked_kv_indices(device)


+class PPProxyTensors:
+    # adapted from https://github.com/vllm-project/vllm/blob/d14e98d924724b284dc5eaf8070d935e214e50c0/vllm/sequence.py#L1103
+    tensors: Dict[str, torch.Tensor]
+
+    def __init__(self, tensors):
+        # manually define this function, so that
+        # Dynamo knows `PPProxyTensors()` comes from this file.
+        # Otherwise, dataclass will generate this function by evaluating
+        # a string, and we will lose the information about the source file.
+        self.tensors = tensors
+
+    def __getitem__(self, key: Union[str, slice]):
+        if isinstance(key, str):
+            return self.tensors[key]
+        elif isinstance(key, slice):
+            return self.__class__({k: v[key] for k, v in self.tensors.items()})
+
+    def __setitem__(self, key: str, value: torch.Tensor):
+        self.tensors[key] = value
+
+    def __len__(self):
+        return len(self.tensors)
+
+    def __eq__(self, other: object):
+        return isinstance(other, self.__class__) and self
+
+    def __repr__(self) -> str:
+        return f"PPProxyTensors(tensors={self.tensors})"
+
+
 def compute_position_triton(
     extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
 ):
......
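Note: a quick usage illustration (hypothetical shapes) of the PPProxyTensors container defined above: it is a thin wrapper over a dict of named tensors that supports key lookup and batched slicing.

import torch

proxy = PPProxyTensors(
    {
        "hidden_states": torch.randn(8, 4096),
        "residual": torch.randn(8, 4096),
    }
)
print(len(proxy))                    # 2
print(proxy["hidden_states"].shape)  # torch.Size([8, 4096])

first_half = proxy[:4]               # slicing returns a new PPProxyTensors
print(first_half["residual"].shape)  # torch.Size([4, 4096])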
@@ -13,8 +13,10 @@
 # ==============================================================================
 """ModelRunner runs the forward passes of the models."""

+import collections
 import datetime
 import gc
+import inspect
 import json
 import logging
 import os
@@ -59,7 +61,7 @@ from sglang.srt.mem_cache.memory_pool import (
 )
 from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader import get_model
 from sglang.srt.model_loader.loader import (
     DefaultModelLoader,
@@ -111,6 +113,8 @@ class ModelRunner:
         gpu_id: int,
         tp_rank: int,
         tp_size: int,
+        pp_rank: int,
+        pp_size: int,
         nccl_port: int,
         server_args: ServerArgs,
         is_draft_worker: bool = False,
@@ -124,6 +128,8 @@ class ModelRunner:
         self.gpu_id = gpu_id
         self.tp_rank = tp_rank
         self.tp_size = tp_size
+        self.pp_rank = pp_rank
+        self.pp_size = pp_size
         self.dist_port = nccl_port
         self.server_args = server_args
         self.is_draft_worker = is_draft_worker
@@ -149,24 +155,24 @@ class ModelRunner:
         global_server_args_dict.update(
             {
                 "attention_backend": server_args.attention_backend,
-                "sampling_backend": server_args.sampling_backend,
-                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
-                "torchao_config": server_args.torchao_config,
+                "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
+                "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
+                "deepep_mode": server_args.deepep_mode,
+                "device": server_args.device,
+                "disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
+                "disable_radix_cache": server_args.disable_radix_cache,
                 "enable_nan_detection": server_args.enable_nan_detection,
                 "enable_dp_attention": server_args.enable_dp_attention,
                 "enable_ep_moe": server_args.enable_ep_moe,
                 "enable_deepep_moe": server_args.enable_deepep_moe,
-                "deepep_mode": server_args.deepep_mode,
-                "device": server_args.device,
-                "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
-                "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
-                "disable_radix_cache": server_args.disable_radix_cache,
                 "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
                 "moe_dense_tp_size": server_args.moe_dense_tp_size,
-                "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
-                "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
                 "n_share_experts_fusion": server_args.n_share_experts_fusion,
-                "disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
+                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
+                "torchao_config": server_args.torchao_config,
+                "sampling_backend": server_args.sampling_backend,
+                "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
+                "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
                 "use_mla_backend": self.use_mla_backend,
             }
         )
@@ -184,6 +190,11 @@ class ModelRunner:
         # If it is a draft model, tp_group can be different
         self.initialize(min_per_gpu_memory)

+        # temporary cached values
+        self.support_pp = (
+            "pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
+        )
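Note: an illustrative snippet (class names are made up) of the convention the inspect.signature() check above relies on: a model opts into pipeline parallelism simply by accepting a pp_proxy_tensors keyword in its forward().

import inspect

class PPAwareModel:
    def forward(self, input_ids, positions, forward_batch, pp_proxy_tensors=None):
        ...

class TPOnlyModel:
    def forward(self, input_ids, positions, forward_batch):
        ...

for cls in (PPAwareModel, TPOnlyModel):
    support_pp = "pp_proxy_tensors" in inspect.signature(cls.forward).parameters
    print(cls.__name__, support_pp)  # PPAwareModel True / TPOnlyModel False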
     def initialize(self, min_per_gpu_memory: float):
         server_args = self.server_args
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
@@ -194,6 +205,12 @@ class ModelRunner:
         self.sampler = Sampler()
         self.load_model()

+        self.start_layer = getattr(self.model, "start_layer", 0)
+        self.end_layer = getattr(
+            self.model, "end_layer", self.model_config.num_hidden_layers
+        )
+        self.num_effective_layers = self.end_layer - self.start_layer
+
         # Apply torchao quantization
         torchao_applied = getattr(self.model, "torchao_applied", False)
         # In layered loading, torchao may have been applied
@@ -360,18 +377,22 @@ class ModelRunner:
             # Only initialize the distributed environment on the target model worker.
             init_distributed_environment(
                 backend=backend,
-                world_size=self.tp_size,
-                rank=self.tp_rank,
+                world_size=self.tp_size * self.pp_size,
+                rank=self.tp_size * self.pp_rank + self.tp_rank,
                 local_rank=self.gpu_id,
                 distributed_init_method=dist_init_method,
                 timeout=self.server_args.dist_timeout,
             )
-            initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+            initialize_model_parallel(
+                tensor_model_parallel_size=self.tp_size,
+                pipeline_model_parallel_size=self.pp_size,
+            )
             initialize_dp_attention(
                 enable_dp_attention=self.server_args.enable_dp_attention,
                 tp_rank=self.tp_rank,
                 tp_size=self.tp_size,
                 dp_size=self.server_args.dp_size,
+                pp_size=self.server_args.pp_size,
             )

         min_per_gpu_memory = get_available_gpu_memory(
...@@ -698,6 +719,8 @@ class ModelRunner: ...@@ -698,6 +719,8 @@ class ModelRunner:
if not self.is_draft_worker if not self.is_draft_worker
else self.model_config.hf_config.num_nextn_predict_layers else self.model_config.hf_config.num_nextn_predict_layers
) )
# FIXME: pipeline parallelism is not compatible with mla backend
assert self.pp_size == 1
cell_size = ( cell_size = (
(self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim) (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
* num_layers * num_layers
...@@ -707,7 +730,7 @@ class ModelRunner: ...@@ -707,7 +730,7 @@ class ModelRunner:
cell_size = ( cell_size = (
self.model_config.get_num_kv_heads(get_attention_tp_size()) self.model_config.get_num_kv_heads(get_attention_tp_size())
* self.model_config.head_dim * self.model_config.head_dim
* self.model_config.num_hidden_layers * self.num_effective_layers
* 2 * 2
* torch._utils._element_size(self.kv_cache_dtype) * torch._utils._element_size(self.kv_cache_dtype)
) )
...@@ -819,9 +842,11 @@ class ModelRunner: ...@@ -819,9 +842,11 @@ class ModelRunner:
self.model_config.num_hidden_layers self.model_config.num_hidden_layers
if not self.is_draft_worker if not self.is_draft_worker
else self.model_config.hf_config.num_nextn_predict_layers else self.model_config.hf_config.num_nextn_predict_layers
), ), # PP is not compatible with mla backend
device=self.device, device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver, enable_memory_saver=self.server_args.enable_memory_saver,
start_layer=self.start_layer,
end_layer=self.end_layer,
) )
elif self.server_args.enable_double_sparsity: elif self.server_args.enable_double_sparsity:
self.token_to_kv_pool = DoubleSparseTokenToKVPool( self.token_to_kv_pool = DoubleSparseTokenToKVPool(
...@@ -830,10 +855,12 @@ class ModelRunner: ...@@ -830,10 +855,12 @@ class ModelRunner:
dtype=self.kv_cache_dtype, dtype=self.kv_cache_dtype,
head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()), head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
head_dim=self.model_config.head_dim, head_dim=self.model_config.head_dim,
layer_num=self.model_config.num_hidden_layers, layer_num=self.num_effective_layers,
device=self.device, device=self.device,
heavy_channel_num=self.server_args.ds_heavy_channel_num, heavy_channel_num=self.server_args.ds_heavy_channel_num,
enable_memory_saver=self.server_args.enable_memory_saver, enable_memory_saver=self.server_args.enable_memory_saver,
start_layer=self.start_layer,
end_layer=self.end_layer,
) )
else: else:
self.token_to_kv_pool = MHATokenToKVPool( self.token_to_kv_pool = MHATokenToKVPool(
...@@ -842,9 +869,11 @@ class ModelRunner: ...@@ -842,9 +869,11 @@ class ModelRunner:
dtype=self.kv_cache_dtype, dtype=self.kv_cache_dtype,
head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()), head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
head_dim=self.model_config.head_dim, head_dim=self.model_config.head_dim,
layer_num=self.model_config.num_hidden_layers, layer_num=self.num_effective_layers,
device=self.device, device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver, enable_memory_saver=self.server_args.enable_memory_saver,
start_layer=self.start_layer,
end_layer=self.end_layer,
) )
if self.token_to_kv_pool_allocator is None: if self.token_to_kv_pool_allocator is None:
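For reference, the start_layer / end_layer / num_effective_layers attributes used in the KV-cache sizing above describe the contiguous slice of decoder layers owned by this pipeline stage. A minimal sketch of how such a split could be computed, assuming an even partition with the remainder given to the later stages (the helper name and policy are hypothetical, not the actual sglang implementation):

def partition_layers(num_hidden_layers: int, pp_rank: int, pp_size: int):
    # Hypothetical even split; returns the half-open range [start_layer, end_layer).
    base = num_hidden_layers // pp_size
    extra = num_hidden_layers % pp_size
    start = pp_rank * base + max(0, pp_rank - (pp_size - extra))
    end = start + base + (1 if pp_rank >= pp_size - extra else 0)
    return start, end

# Example: 32 layers over pp_size=4 gives (0, 8), (8, 16), (16, 24), (24, 32),
# so each stage sizes its KV cache with num_effective_layers = end - start = 8.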
...@@ -957,7 +986,7 @@ class ModelRunner: ...@@ -957,7 +986,7 @@ class ModelRunner:
with open(self.server_args.ds_channel_config_path, "r") as f: with open(self.server_args.ds_channel_config_path, "r") as f:
channel_config = json.load(f) channel_config = json.load(f)
for i in range(self.model_config.num_hidden_layers): for i in range(self.start_layer, self.end_layer):
key = "model.layers." + str(i) + ".self_attn" + selected_channel key = "model.layers." + str(i) + ".self_attn" + selected_channel
self.sorted_channels.append( self.sorted_channels.append(
torch.tensor(channel_config[key])[ torch.tensor(channel_config[key])[
...@@ -997,64 +1026,82 @@ class ModelRunner: ...@@ -997,64 +1026,82 @@ class ModelRunner:
device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,)) device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
tensor_parallel(self.model, device_mesh) tensor_parallel(self.model, device_mesh)
def forward_decode(self, forward_batch: ForwardBatch): def forward_decode(
self, forward_batch: ForwardBatch, pp_proxy_tensors=None
) -> LogitsProcessorOutput:
self.attn_backend.init_forward_metadata(forward_batch) self.attn_backend.init_forward_metadata(forward_batch)
# FIXME: add pp_proxy_tensors arg to all models
kwargs = {}
if self.support_pp:
kwargs["pp_proxy_tensors"] = pp_proxy_tensors
return self.model.forward( return self.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch forward_batch.input_ids, forward_batch.positions, forward_batch, **kwargs
) )
def forward_extend( def forward_extend(
self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False self,
): forward_batch: ForwardBatch,
skip_attn_backend_init: bool = False,
pp_proxy_tensors=None,
) -> LogitsProcessorOutput:
if not skip_attn_backend_init: if not skip_attn_backend_init:
self.attn_backend.init_forward_metadata(forward_batch) self.attn_backend.init_forward_metadata(forward_batch)
if self.is_generation: kwargs = {}
if forward_batch.input_embeds is None: if self.support_pp:
return self.model.forward( kwargs["pp_proxy_tensors"] = pp_proxy_tensors
forward_batch.input_ids, forward_batch.positions, forward_batch if forward_batch.input_embeds is not None:
) kwargs["input_embeds"] = forward_batch.input_embeds.bfloat16()
else: if not self.is_generation:
return self.model.forward( kwargs["get_embedding"] = True
forward_batch.input_ids, return self.model.forward(
forward_batch.positions, forward_batch.input_ids,
forward_batch, forward_batch.positions,
input_embeds=forward_batch.input_embeds.bfloat16(), forward_batch,
) **kwargs,
else: )
# Only embedding models have get_embedding parameter
return self.model.forward(
forward_batch.input_ids,
forward_batch.positions,
forward_batch,
get_embedding=True,
)
def forward_idle(self, forward_batch: ForwardBatch): def forward_idle(
self, forward_batch: ForwardBatch, pp_proxy_tensors=None
) -> LogitsProcessorOutput:
kwargs = {}
if self.support_pp:
kwargs["pp_proxy_tensors"] = pp_proxy_tensors
return self.model.forward( return self.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch forward_batch.input_ids,
forward_batch.positions,
forward_batch,
**kwargs,
) )
def forward( def forward(
self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False self,
) -> LogitsProcessorOutput: forward_batch: ForwardBatch,
if ( skip_attn_backend_init: bool = False,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> Union[LogitsProcessorOutput, PPProxyTensors]:
can_run_cuda_graph = bool(
forward_batch.forward_mode.is_cuda_graph() forward_batch.forward_mode.is_cuda_graph()
and self.cuda_graph_runner and self.cuda_graph_runner
and self.cuda_graph_runner.can_run(forward_batch) and self.cuda_graph_runner.can_run(forward_batch)
): )
if can_run_cuda_graph:
return self.cuda_graph_runner.replay( return self.cuda_graph_runner.replay(
forward_batch, skip_attn_backend_init=skip_attn_backend_init forward_batch,
skip_attn_backend_init=skip_attn_backend_init,
pp_proxy_tensors=pp_proxy_tensors,
) )
if forward_batch.forward_mode.is_decode(): if forward_batch.forward_mode.is_decode():
return self.forward_decode(forward_batch) return self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
elif forward_batch.forward_mode.is_extend(): elif forward_batch.forward_mode.is_extend():
return self.forward_extend( return self.forward_extend(
forward_batch, skip_attn_backend_init=skip_attn_backend_init forward_batch,
skip_attn_backend_init=skip_attn_backend_init,
pp_proxy_tensors=pp_proxy_tensors,
) )
elif forward_batch.forward_mode.is_idle(): elif forward_batch.forward_mode.is_idle():
return self.forward_idle(forward_batch) return self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
else: else:
raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}") raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
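The pp_proxy_tensors argument threaded through forward_decode / forward_extend / forward_idle carries the activations handed from one pipeline stage to the next (for Llama, "hidden_states" and "residual"). A minimal sketch of such a container, assuming a plain dict-of-tensors design; the real PPProxyTensors in forward_batch_info may differ:

from typing import Dict
import torch

class PPProxyTensorsSketch:
    """Named activations passed between pipeline stages (hypothetical sketch)."""

    def __init__(self, tensors: Dict[str, torch.Tensor]):
        self.tensors = tensors

    def __getitem__(self, key: str) -> torch.Tensor:
        # Supports the pp_proxy_tensors["hidden_states"] access pattern used below.
        return self.tensors[key]

    def items(self):
        # Lets a scheduler iterate over entries when sending them to the next rank.
        return self.tensors.items()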
......
...@@ -17,13 +17,14 @@ ...@@ -17,13 +17,14 @@
"""Inference-only LLaMA model compatible with HuggingFace weights.""" """Inference-only LLaMA model compatible with HuggingFace weights."""
import logging import logging
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import LlamaConfig from transformers import LlamaConfig
from sglang.srt.distributed import ( from sglang.srt.distributed import (
get_pp_group,
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
) )
...@@ -39,11 +40,12 @@ from sglang.srt.layers.pooler import Pooler, PoolingType ...@@ -39,11 +40,12 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import ( from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import ( from sglang.srt.model_loader.weight_utils import (
default_weight_loader, default_weight_loader,
kv_cache_scales_loader, kv_cache_scales_loader,
...@@ -275,21 +277,31 @@ class LlamaModel(nn.Module): ...@@ -275,21 +277,31 @@ class LlamaModel(nn.Module):
self.config = config self.config = config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding( self.pp_group = get_pp_group()
config.vocab_size, if self.pp_group.is_first_rank:
config.hidden_size, self.embed_tokens = VocabParallelEmbedding(
quant_config=quant_config, config.vocab_size,
prefix=add_prefix("embed_tokens", prefix), config.hidden_size,
) quant_config=quant_config,
self.layers = make_layers( prefix=add_prefix("embed_tokens", prefix),
)
else:
self.embed_tokens = PPMissingLayer()
self.layers, self.start_layer, self.end_layer = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda idx, prefix: LlamaDecoderLayer( lambda idx, prefix: LlamaDecoderLayer(
config=config, layer_id=idx, quant_config=quant_config, prefix=prefix config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
), ),
pp_rank=self.pp_group.rank_in_group,
pp_size=self.pp_group.world_size,
prefix="model.layers", prefix="model.layers",
) )
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) if self.pp_group.is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer(return_tuple=True)
self.layers_to_capture = [] self.layers_to_capture = []
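Ranks that do not own the embedding or the final norm still keep the attribute slots, filled with PPMissingLayer so the module tree and weight names stay uniform across stages. A sketch of what such a placeholder could look like (assumed behavior; the class imported from sglang.srt.layers.utils may differ):

from torch import nn

class PPMissingLayerSketch(nn.Module):
    """Identity placeholder for a module owned by another pipeline stage."""

    def __init__(self, return_tuple: bool = False):
        super().__init__()
        self.return_tuple = return_tuple

    def forward(self, *args, **kwargs):
        # Pass the first positional input through unchanged; optionally wrap it
        # in a tuple for call sites that expect a tuple-valued return.
        out = args[0] if args else None
        return (out,) if self.return_tuple else out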
def forward( def forward(
...@@ -298,14 +310,23 @@ class LlamaModel(nn.Module): ...@@ -298,14 +310,23 @@ class LlamaModel(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]: pp_proxy_tensors: Optional[PPProxyTensors] = None,
if input_embeds is None: ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]], PPProxyTensors]:
hidden_states = self.embed_tokens(input_ids) if self.pp_group.is_first_rank:
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
else:
hidden_states = input_embeds
residual = None
else: else:
hidden_states = input_embeds assert pp_proxy_tensors is not None
residual = None # FIXME(@ying): reduce the number of proxy tensors by not fusing layer norms
hidden_states = pp_proxy_tensors["hidden_states"]
residual = pp_proxy_tensors["residual"]
deferred_norm = None
aux_hidden_states = [] aux_hidden_states = []
for i in range(len(self.layers)): for i in range(self.start_layer, self.end_layer):
if i in self.layers_to_capture: if i in self.layers_to_capture:
aux_hidden_states.append(hidden_states + residual) aux_hidden_states.append(hidden_states + residual)
layer = self.layers[i] layer = self.layers[i]
...@@ -315,7 +336,16 @@ class LlamaModel(nn.Module): ...@@ -315,7 +336,16 @@ class LlamaModel(nn.Module):
forward_batch, forward_batch,
residual, residual,
) )
hidden_states, _ = self.norm(hidden_states, residual)
if not self.pp_group.is_last_rank:
return PPProxyTensors(
{
"hidden_states": hidden_states,
"residual": residual,
}
)
else:
hidden_states, _ = self.norm(hidden_states, residual)
if len(aux_hidden_states) == 0: if len(aux_hidden_states) == 0:
return hidden_states return hidden_states
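On a non-last rank the model now returns PPProxyTensors instead of hidden states, and the runner has to ship those tensors to the next stage before the final rank can produce logits. A hedged sketch of that hand-off with raw torch.distributed point-to-point calls (shapes, the batch fields, and the helper itself are assumptions; the actual scheduler in this PR drives this through the pp_group abstraction and micro-batching):

import torch
import torch.distributed as dist

def run_stage(model, forward_batch, pp_rank: int, pp_size: int, hidden_size: int):
    pp_proxy = None
    if pp_rank > 0:
        # Receive the previous stage's activations (num_tokens is assumed known).
        hidden = torch.empty(forward_batch.num_tokens, hidden_size, device="cuda")
        residual = torch.empty_like(hidden)
        dist.recv(hidden, src=pp_rank - 1)
        dist.recv(residual, src=pp_rank - 1)
        pp_proxy = {"hidden_states": hidden, "residual": residual}

    out = model.forward(
        forward_batch.input_ids,
        forward_batch.positions,
        forward_batch,
        pp_proxy_tensors=pp_proxy,
    )

    if pp_rank < pp_size - 1:
        # Non-last ranks forward their proxy tensors and produce no logits.
        dist.send(out["hidden_states"], dst=pp_rank + 1)
        dist.send(out["residual"], dst=pp_rank + 1)
        return None
    return out  # output of the last stage (logits or pooled embedding)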
...@@ -376,6 +406,7 @@ class LlamaForCausalLM(nn.Module): ...@@ -376,6 +406,7 @@ class LlamaForCausalLM(nn.Module):
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.pp_group = get_pp_group()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = self._init_model(config, quant_config, add_prefix("model", prefix)) self.model = self._init_model(config, quant_config, add_prefix("model", prefix))
...@@ -419,23 +450,41 @@ class LlamaForCausalLM(nn.Module): ...@@ -419,23 +450,41 @@ class LlamaForCausalLM(nn.Module):
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
get_embedding: bool = False, get_embedding: bool = False,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> LogitsProcessorOutput: ) -> LogitsProcessorOutput:
hidden_states = self.model(
input_ids,
positions,
forward_batch,
input_embeds,
pp_proxy_tensors=pp_proxy_tensors,
)
aux_hidden_states = None aux_hidden_states = None
if self.capture_aux_hidden_states: if self.capture_aux_hidden_states:
hidden_states, aux_hidden_states = self.model( hidden_states, aux_hidden_states = hidden_states
input_ids, positions, forward_batch, input_embeds
) if self.pp_group.is_last_rank:
if not get_embedding:
return self.logits_processor(
input_ids,
hidden_states,
self.lm_head,
forward_batch,
aux_hidden_states,
)
else:
return self.pooler(hidden_states, forward_batch)
else: else:
hidden_states = self.model( return hidden_states
input_ids, positions, forward_batch, input_embeds
)
if not get_embedding: @property
return self.logits_processor( def start_layer(self):
input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states return self.model.start_layer
)
else: @property
return self.pooler(hidden_states, forward_batch) def end_layer(self):
return self.model.end_layer
def get_input_embeddings(self) -> nn.Embedding: def get_input_embeddings(self) -> nn.Embedding:
return self.model.embed_tokens return self.model.embed_tokens
...@@ -491,6 +540,16 @@ class LlamaForCausalLM(nn.Module): ...@@ -491,6 +540,16 @@ class LlamaForCausalLM(nn.Module):
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
for name, loaded_weight in weights: for name, loaded_weight in weights:
layer_id = get_layer_id(name)
if (
layer_id is not None
and hasattr(self.model, "start_layer")
and (
layer_id < self.model.start_layer
or layer_id >= self.model.end_layer
)
):
continue
if "rotary_emb.inv_freq" in name or "projector" in name: if "rotary_emb.inv_freq" in name or "projector" in name:
continue continue
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
...@@ -637,6 +696,9 @@ class LlamaForCausalLM(nn.Module): ...@@ -637,6 +696,9 @@ class LlamaForCausalLM(nn.Module):
self.model.load_kv_cache_scales(quantization_param_path) self.model.load_kv_cache_scales(quantization_param_path)
def set_eagle3_layers_to_capture(self): def set_eagle3_layers_to_capture(self):
if not self.pp_group.is_last_rank:
return
self.capture_aux_hidden_states = True self.capture_aux_hidden_states = True
num_layers = self.config.num_hidden_layers num_layers = self.config.num_hidden_layers
self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3] self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
......
...@@ -46,7 +46,7 @@ from sglang.srt.layers.radix_attention import RadixAttention ...@@ -46,7 +46,7 @@ from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers
...@@ -431,6 +431,7 @@ class Llama4Model(nn.Module): ...@@ -431,6 +431,7 @@ class Llama4Model(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]: ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
if input_embeds is None: if input_embeds is None:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
......
...@@ -25,13 +25,14 @@ import torch ...@@ -25,13 +25,14 @@ import torch
from torch import nn from torch import nn
from transformers import LlamaConfig from transformers import LlamaConfig
from sglang.srt.distributed import get_pp_group
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ( from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.models.llama import LlamaDecoderLayer, LlamaForCausalLM from sglang.srt.models.llama import LlamaDecoderLayer, LlamaForCausalLM
...@@ -86,6 +87,7 @@ class LlamaModel(nn.Module): ...@@ -86,6 +87,7 @@ class LlamaModel(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if input_embeds is None: if input_embeds is None:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
...@@ -118,6 +120,7 @@ class LlamaForCausalLMEagle(LlamaForCausalLM): ...@@ -118,6 +120,7 @@ class LlamaForCausalLMEagle(LlamaForCausalLM):
nn.Module.__init__(self) nn.Module.__init__(self)
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.pp_group = get_pp_group()
self.model = LlamaModel( self.model = LlamaModel(
config, quant_config=quant_config, prefix=add_prefix("model", prefix) config, quant_config=quant_config, prefix=add_prefix("model", prefix)
) )
......
...@@ -25,6 +25,7 @@ import torch ...@@ -25,6 +25,7 @@ import torch
from torch import nn from torch import nn
from transformers import LlamaConfig from transformers import LlamaConfig
from sglang.srt.distributed import get_pp_group
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
...@@ -33,7 +34,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ...@@ -33,7 +34,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.models.llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM from sglang.srt.models.llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM
...@@ -118,6 +119,7 @@ class LlamaModel(nn.Module): ...@@ -118,6 +119,7 @@ class LlamaModel(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if input_embeds is None: if input_embeds is None:
embeds = self.embed_tokens(input_ids) embeds = self.embed_tokens(input_ids)
...@@ -155,6 +157,7 @@ class LlamaForCausalLMEagle3(LlamaForCausalLM): ...@@ -155,6 +157,7 @@ class LlamaForCausalLMEagle3(LlamaForCausalLM):
nn.Module.__init__(self) nn.Module.__init__(self)
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.pp_group = get_pp_group()
if self.config.num_hidden_layers != 1: if self.config.num_hidden_layers != 1:
raise ValueError("EAGLE3 currently only supports 1 layer") raise ValueError("EAGLE3 currently only supports 1 layer")
......
...@@ -78,6 +78,8 @@ class ServerArgs: ...@@ -78,6 +78,8 @@ class ServerArgs:
# Other runtime options # Other runtime options
tp_size: int = 1 tp_size: int = 1
pp_size: int = 1
max_micro_batch_size: Optional[int] = None
stream_interval: int = 1 stream_interval: int = 1
stream_output: bool = False stream_output: bool = False
random_seed: Optional[int] = None random_seed: Optional[int] = None
...@@ -222,14 +224,18 @@ class ServerArgs: ...@@ -222,14 +224,18 @@ class ServerArgs:
# Set mem fraction static, which depends on the tensor parallelism size # Set mem fraction static, which depends on the tensor parallelism size
if self.mem_fraction_static is None: if self.mem_fraction_static is None:
if self.tp_size >= 16: parallel_size = self.tp_size * self.pp_size
self.mem_fraction_static = 0.79 if gpu_mem <= 81920:
elif self.tp_size >= 8: if parallel_size >= 16:
self.mem_fraction_static = 0.81 self.mem_fraction_static = 0.79
elif self.tp_size >= 4: elif parallel_size >= 8:
self.mem_fraction_static = 0.85 self.mem_fraction_static = 0.81
elif self.tp_size >= 2: elif parallel_size >= 4:
self.mem_fraction_static = 0.87 self.mem_fraction_static = 0.85
elif parallel_size >= 2:
self.mem_fraction_static = 0.87
else:
self.mem_fraction_static = 0.88
else: else:
self.mem_fraction_static = 0.88 self.mem_fraction_static = 0.88
if gpu_mem > 96 * 1024: if gpu_mem > 96 * 1024:
...@@ -244,6 +250,8 @@ class ServerArgs: ...@@ -244,6 +250,8 @@ class ServerArgs:
if self.chunked_prefill_size is None: if self.chunked_prefill_size is None:
if gpu_mem is not None and gpu_mem < 25_000: if gpu_mem is not None and gpu_mem < 25_000:
self.chunked_prefill_size = 2048 self.chunked_prefill_size = 2048
elif self.disaggregation_mode != "null":
self.chunked_prefill_size = 16384
else: else:
self.chunked_prefill_size = 8192 self.chunked_prefill_size = 8192
assert self.chunked_prefill_size % self.page_size == 0 assert self.chunked_prefill_size % self.page_size == 0
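Read as a standalone rule, the new default for mem_fraction_static now scales with tp_size * pp_size instead of tp_size alone, reserving more headroom as the parallel group grows. A sketch of the visible branch (gpu_mem is assumed to be in MB, so 81920 corresponds to 80 GB cards; the gpu_mem > 96 GB adjustment is truncated in the hunk and omitted here):

from typing import Optional

def default_mem_fraction_static(tp_size: int, pp_size: int, gpu_mem_mb: Optional[float]) -> float:
    # Sketch of the defaulting rule above; the None guard is added for safety here.
    parallel_size = tp_size * pp_size
    if gpu_mem_mb is not None and gpu_mem_mb <= 81920:
        if parallel_size >= 16:
            return 0.79
        if parallel_size >= 8:
            return 0.81
        if parallel_size >= 4:
            return 0.85
        if parallel_size >= 2:
            return 0.87
        return 0.88
    return 0.88

# Example: tp_size=4, pp_size=2 on an 80 GB GPU -> parallel_size=8 -> 0.81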
...@@ -643,6 +651,19 @@ class ServerArgs: ...@@ -643,6 +651,19 @@ class ServerArgs:
default=ServerArgs.tp_size, default=ServerArgs.tp_size,
help="The tensor parallelism size.", help="The tensor parallelism size.",
) )
parser.add_argument(
"--pipeline-parallel-size",
"--pp-size",
type=int,
default=ServerArgs.pp_size,
help="The pipeline parallelism size.",
)
parser.add_argument(
"--max-micro-batch-size",
type=int,
default=ServerArgs.max_micro_batch_size,
help="The maximum micro batch size in pipeline parallelism.",
)
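With the new flags wired in, pipeline parallelism can be requested directly at launch. A typical invocation might look like the following (model path and sizes are placeholders, and the programmatic form assumes the ServerArgs dataclass fields shown above):

# Command line (placeholder model path):
#   python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct \
#       --tp-size 2 --pp-size 2 --max-micro-batch-size 8

# Equivalent programmatic configuration:
from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder
    tp_size=2,
    pp_size=2,
    max_micro_batch_size=8,
)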
parser.add_argument( parser.add_argument(
"--stream-interval", "--stream-interval",
type=int, type=int,
...@@ -1232,6 +1253,7 @@ class ServerArgs: ...@@ -1232,6 +1253,7 @@ class ServerArgs:
@classmethod @classmethod
def from_cli_args(cls, args: argparse.Namespace): def from_cli_args(cls, args: argparse.Namespace):
args.tp_size = args.tensor_parallel_size args.tp_size = args.tensor_parallel_size
args.pp_size = args.pipeline_parallel_size
args.dp_size = args.data_parallel_size args.dp_size = args.data_parallel_size
args.ep_size = args.expert_parallel_size args.ep_size = args.expert_parallel_size
attrs = [attr.name for attr in dataclasses.fields(cls)] attrs = [attr.name for attr in dataclasses.fields(cls)]
...@@ -1245,8 +1267,19 @@ class ServerArgs: ...@@ -1245,8 +1267,19 @@ class ServerArgs:
def check_server_args(self): def check_server_args(self):
assert ( assert (
self.tp_size % self.nnodes == 0 self.tp_size * self.pp_size
), "tp_size must be divisible by number of nodes" ) % self.nnodes == 0, "tp_size must be divisible by number of nodes"
# FIXME pp constraints
if self.pp_size > 1:
logger.warning(f"Turn off overlap scheule for pipeline parallelism.")
self.disable_overlap_schedule = True
assert (
self.disable_overlap_schedule
and self.speculative_algorithm is None
and not self.enable_mixed_chunk
), "Pipeline parallelism is not compatible with overlap schedule, speculative decoding, mixed chunked prefill."
assert not ( assert not (
self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
), "multi-node data parallel is not supported unless dp attention!" ), "multi-node data parallel is not supported unless dp attention!"
......
...@@ -106,11 +106,12 @@ class EAGLEWorker(TpModelWorker): ...@@ -106,11 +106,12 @@ class EAGLEWorker(TpModelWorker):
# Init draft worker # Init draft worker
with empty_context(): with empty_context():
super().__init__( super().__init__(
server_args=server_args,
gpu_id=gpu_id, gpu_id=gpu_id,
tp_rank=tp_rank, tp_rank=tp_rank,
server_args=server_args, pp_rank=0, # FIXME
nccl_port=nccl_port,
dp_rank=dp_rank, dp_rank=dp_rank,
nccl_port=nccl_port,
is_draft_worker=True, is_draft_worker=True,
req_to_token_pool=self.req_to_token_pool, req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
......