Unverified Commit 11383cec authored by Ying Sheng, committed by GitHub

[PP] Add pipeline parallelism (#5724)

parent e97e57e6
......@@ -154,6 +154,8 @@ def load_model(server_args, port_args, tp_rank):
gpu_id=tp_rank,
tp_rank=tp_rank,
tp_size=server_args.tp_size,
pp_rank=0,
pp_size=1,
nccl_port=port_args.nccl_port,
server_args=server_args,
)
......
......@@ -126,7 +126,6 @@ class Engine(EngineBase):
server_args=server_args,
port_args=port_args,
)
self.server_args = server_args
self.tokenizer_manager = tokenizer_manager
self.scheduler_info = scheduler_info
......@@ -301,7 +300,6 @@ class Engine(EngineBase):
internal_states = loop.run_until_complete(
self.tokenizer_manager.get_internal_state()
)
return {
**dataclasses.asdict(self.tokenizer_manager.server_args),
**self.scheduler_info,
......@@ -520,25 +518,44 @@ def _launch_subprocesses(
)
scheduler_pipe_readers = []
tp_size_per_node = server_args.tp_size // server_args.nnodes
nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
tp_rank_range = range(
tp_size_per_node * server_args.node_rank,
tp_size_per_node * (server_args.node_rank + 1),
tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
)
for tp_rank in tp_rank_range:
reader, writer = mp.Pipe(duplex=False)
gpu_id = (
server_args.base_gpu_id
+ (tp_rank % tp_size_per_node) * server_args.gpu_id_step
)
proc = mp.Process(
target=run_scheduler_process,
args=(server_args, port_args, gpu_id, tp_rank, None, writer),
)
with memory_saver_adapter.configure_subprocess():
proc.start()
scheduler_procs.append(proc)
scheduler_pipe_readers.append(reader)
pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
pp_rank_range = range(
pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
)
for pp_rank in pp_rank_range:
for tp_rank in tp_rank_range:
reader, writer = mp.Pipe(duplex=False)
gpu_id = (
server_args.base_gpu_id
+ ((pp_rank % pp_size_per_node) * tp_size_per_node)
+ (tp_rank % tp_size_per_node) * server_args.gpu_id_step
)
proc = mp.Process(
target=run_scheduler_process,
args=(
server_args,
port_args,
gpu_id,
tp_rank,
pp_rank,
None,
writer,
),
)
with memory_saver_adapter.configure_subprocess():
proc.start()
scheduler_procs.append(proc)
scheduler_pipe_readers.append(reader)
else:
# Launch the data parallel controller
reader, writer = mp.Pipe(duplex=False)
......
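For reference, a minimal standalone sketch (plain Python, not SGLang API calls; the node/GPU counts below are hypothetical) of the rank and GPU assignment the launch loop above performs: each node now spawns one scheduler process per (pp_rank, tp_rank) pair it owns, and the GPU id is derived from both ranks.

```python
# Illustrative only: mirrors the rank-range / gpu_id arithmetic in the launch
# loop above for a hypothetical 2-node deployment with tp_size=4, pp_size=2.
def plan_node(node_rank, nnodes=2, tp_size=4, pp_size=2, base_gpu_id=0, gpu_id_step=1):
    nnodes_per_tp_group = max(nnodes // pp_size, 1)
    tp_size_per_node = tp_size // nnodes_per_tp_group
    tp_rank_range = range(
        tp_size_per_node * (node_rank % nnodes_per_tp_group),
        tp_size_per_node * (node_rank % nnodes_per_tp_group + 1),
    )
    pp_size_per_node = max(pp_size // nnodes, 1)
    pp_rank_range = range(
        pp_size_per_node * (node_rank // nnodes_per_tp_group),
        pp_size_per_node * (node_rank // nnodes_per_tp_group + 1),
    )
    return [
        (
            pp_rank,
            tp_rank,
            base_gpu_id
            + (pp_rank % pp_size_per_node) * tp_size_per_node
            + (tp_rank % tp_size_per_node) * gpu_id_step,
        )
        for pp_rank in pp_rank_range
        for tp_rank in tp_rank_range
    ]

# Node 0 hosts pp_rank 0, node 1 hosts pp_rank 1; each uses its local GPUs 0-3.
print(plan_node(0))  # [(0, 0, 0), (0, 1, 1), (0, 2, 2), (0, 3, 3)]
print(plan_node(1))  # [(1, 0, 0), (1, 1, 1), (1, 2, 2), (1, 3, 3)]
```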
......@@ -43,6 +43,7 @@ def initialize_dp_attention(
tp_rank: int,
tp_size: int,
dp_size: int,
pp_size: int,
):
global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
......@@ -53,17 +54,19 @@ def initialize_dp_attention(
)
if enable_dp_attention:
local_rank = tp_rank % (tp_size // dp_size)
_DP_SIZE = dp_size
else:
local_rank = tp_rank
_DP_SIZE = 1
tp_group = get_tp_group()
_ATTN_TP_GROUP = GroupCoordinator(
[
list(range(head, head + _ATTN_TP_SIZE))
for head in range(0, tp_size, _ATTN_TP_SIZE)
for head in range(0, pp_size * tp_size, _ATTN_TP_SIZE)
],
tp_group.local_rank,
local_rank,
torch.distributed.get_backend(tp_group.device_group),
SYNC_TOKEN_IDS_ACROSS_TP,
False,
......
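A small illustrative computation (plain Python, hypothetical sizes) of the attention-TP rank groups that `initialize_dp_attention` now enumerates over all `pp_size * tp_size` world ranks; deriving the attention-TP group size from `tp_size // dp_size` is an assumption made for this sketch.

```python
# Illustrative: rank groups for the attention TP GroupCoordinator, now spanning
# pp_size * tp_size ranks. The attn_tp_size derivation here is an assumption.
def attn_tp_groups(tp_size, pp_size, dp_size, enable_dp_attention):
    attn_tp_size = tp_size // dp_size if enable_dp_attention else tp_size
    return [
        list(range(head, head + attn_tp_size))
        for head in range(0, pp_size * tp_size, attn_tp_size)
    ]

print(attn_tp_groups(tp_size=4, pp_size=2, dp_size=1, enable_dp_attention=False))
# [[0, 1, 2, 3], [4, 5, 6, 7]]
print(attn_tp_groups(tp_size=4, pp_size=2, dp_size=2, enable_dp_attention=True))
# [[0, 1], [2, 3], [4, 5], [6, 7]]
```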
import logging
import re
import torch
logger = logging.getLogger(__name__)
def get_layer_id(weight_name):
# example weight name: model.layers.10.self_attn.qkv_proj.weight
match = re.search(r"layers\.(\d+)\.", weight_name)
if match:
return int(match.group(1))
return None
class PPMissingLayer(torch.nn.Identity):
# Adapted from
# https://github.com/vllm-project/vllm/blob/18ed3132d2bfe1df9a74729457b69243955221e8/vllm/model_executor/models/utils.py#L468C1-L486C1
"""
A placeholder layer for missing layers in a pipeline parallel model.
"""
def __init__(self, *args, **kwargs):
super().__init__()
self.return_tuple = kwargs.get("return_tuple", False)
def forward(self, *args, **kwargs):
"""
Return the first arg from args or the first value from kwargs.
Wraps the input in a tuple if `self.return_tuple` is True.
"""
input = args[0] if args else next(iter(kwargs.values()))
return (input,) if self.return_tuple else input
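A short usage sketch for the `PPMissingLayer` placeholder defined above (assumes the class above is in scope; tensor shapes are arbitrary).

```python
# Illustrative usage of PPMissingLayer, assuming the class defined above.
import torch

passthrough = PPMissingLayer()               # stands in for a module owned by another PP rank
wrapped = PPMissingLayer(return_tuple=True)  # e.g. the final norm on non-last ranks

x = torch.randn(2, 8)
assert passthrough(x) is x                   # behaves like nn.Identity
out = wrapped(x)
assert isinstance(out, tuple) and out[0] is x  # wraps the output in a tuple
```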
......@@ -181,44 +181,62 @@ class DataParallelController:
enable=server_args.enable_memory_saver
)
# Launch tensor parallel scheduler processes
scheduler_pipe_readers = []
tp_size_per_node = server_args.tp_size // server_args.nnodes
nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
tp_rank_range = range(
tp_size_per_node * server_args.node_rank,
tp_size_per_node * (server_args.node_rank + 1),
tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
)
pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
pp_rank_range = range(
pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
)
for tp_rank in tp_rank_range:
rank_port_args = port_args
if server_args.enable_dp_attention:
# dp attention has different sharding logic
_, _, dp_rank = compute_dp_attention_world_info(
server_args.enable_dp_attention,
tp_rank,
server_args.tp_size,
server_args.dp_size,
for pp_rank in pp_rank_range:
for tp_rank in tp_rank_range:
rank_port_args = port_args
if server_args.enable_dp_attention:
# dp attention has different sharding logic
_, _, dp_rank = compute_dp_attention_world_info(
server_args.enable_dp_attention,
tp_rank,
server_args.tp_size,
server_args.dp_size,
)
# compute zmq ports for this dp rank
rank_port_args = PortArgs.init_new(server_args, dp_rank)
# Data parallelism reuses the tensor parallelism group,
# so all dp ranks should use the same nccl port.
rank_port_args.nccl_port = port_args.nccl_port
reader, writer = mp.Pipe(duplex=False)
gpu_id = (
server_args.base_gpu_id
+ base_gpu_id
+ ((pp_rank % pp_size_per_node) * tp_size_per_node)
+ (tp_rank % tp_size_per_node) * server_args.gpu_id_step
)
# compute zmq ports for this dp rank
rank_port_args = PortArgs.init_new(server_args, dp_rank)
# Data parallelism reuses the tensor parallelism group,
# so all dp ranks should use the same nccl port.
rank_port_args.nccl_port = port_args.nccl_port
reader, writer = mp.Pipe(duplex=False)
gpu_id = (
server_args.base_gpu_id
+ base_gpu_id
+ (tp_rank % tp_size_per_node) * server_args.gpu_id_step
)
proc = mp.Process(
target=run_scheduler_process,
args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer),
)
with memory_saver_adapter.configure_subprocess():
proc.start()
self.scheduler_procs.append(proc)
scheduler_pipe_readers.append(reader)
proc = mp.Process(
target=run_scheduler_process,
args=(
server_args,
rank_port_args,
gpu_id,
tp_rank,
pp_rank,
dp_rank,
writer,
),
)
with memory_saver_adapter.configure_subprocess():
proc.start()
self.scheduler_procs.append(proc)
scheduler_pipe_readers.append(reader)
# Wait for model to finish loading
scheduler_info = []
......
......@@ -66,23 +66,24 @@ INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
# Put some global args for easy access
global_server_args_dict = {
"attention_backend": ServerArgs.attention_backend,
"sampling_backend": ServerArgs.sampling_backend,
"triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
"torchao_config": ServerArgs.torchao_config,
"enable_nan_detection": ServerArgs.enable_nan_detection,
"enable_dp_attention": ServerArgs.enable_dp_attention,
"enable_ep_moe": ServerArgs.enable_ep_moe,
"enable_deepep_moe": ServerArgs.enable_deepep_moe,
"chunked_prefill_size": ServerArgs.chunked_prefill_size,
"deepep_mode": ServerArgs.deepep_mode,
"device": ServerArgs.device,
"speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
"speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
"disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
"disable_radix_cache": ServerArgs.disable_radix_cache,
"enable_deepep_moe": ServerArgs.enable_deepep_moe,
"enable_dp_attention": ServerArgs.enable_dp_attention,
"enable_ep_moe": ServerArgs.enable_ep_moe,
"enable_nan_detection": ServerArgs.enable_nan_detection,
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
"max_micro_batch_size": ServerArgs.max_micro_batch_size,
"moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
"chunked_prefill_size": ServerArgs.chunked_prefill_size,
"n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
"disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
"sampling_backend": ServerArgs.sampling_backend,
"speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
"speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
"torchao_config": ServerArgs.torchao_config,
"triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
}
logger = logging.getLogger(__name__)
......@@ -728,6 +729,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# Events
launch_done: Optional[threading.Event] = None
# For chunked prefill in PP
chunked_req: Optional[Req] = None
# Sampling info
sampling_info: SamplingBatchInfo = None
next_batch_sampling_info: SamplingBatchInfo = None
......@@ -761,7 +765,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# For extend and mixed chunked prefill
prefix_lens: List[int] = None
extend_lens: List[int] = None
extend_num_tokens: int = None
extend_num_tokens: Optional[int] = None
decoding_reqs: List[Req] = None
extend_logprob_start_lens: List[int] = None
# It comes empty list if logprob is not required.
......@@ -803,6 +807,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
enable_overlap: bool,
spec_algorithm: SpeculativeAlgorithm,
enable_custom_logit_processor: bool,
chunked_req: Optional[Req] = None,
):
return_logprob = any(req.return_logprob for req in reqs)
......@@ -820,6 +825,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
spec_algorithm=spec_algorithm,
enable_custom_logit_processor=enable_custom_logit_processor,
return_hidden_states=any(req.return_hidden_states for req in reqs),
chunked_req=chunked_req,
)
def batch_size(self):
......@@ -1236,7 +1242,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
def retract_decode(self, server_args: ServerArgs):
"""Retract the decoding requests when there is not enough memory."""
sorted_indices = [i for i in range(len(self.reqs))]
sorted_indices = list(range(len(self.reqs)))
# TODO(lsyin): improve retraction policy for radix cache
# For spec decoding, filter_batch API can only filter
......@@ -1413,15 +1419,19 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
def filter_batch(
self,
chunked_req_to_exclude: Optional[Req] = None,
chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None,
keep_indices: Optional[List[int]] = None,
):
if keep_indices is None:
if isinstance(chunked_req_to_exclude, Req):
chunked_req_to_exclude = [chunked_req_to_exclude]
elif chunked_req_to_exclude is None:
chunked_req_to_exclude = []
keep_indices = [
i
for i in range(len(self.reqs))
if not self.reqs[i].finished()
and self.reqs[i] is not chunked_req_to_exclude
and self.reqs[i] not in chunked_req_to_exclude
]
if keep_indices is None or len(keep_indices) == 0:
......
......@@ -278,7 +278,7 @@ class SchedulerOutputProcessorMixin:
self.attn_tp_rank == 0
and self.forward_ct_decode % self.server_args.decode_log_interval == 0
):
self.log_decode_stats()
self.log_decode_stats(running_batch=batch)
def add_input_logprob_return_values(
self: Scheduler,
......
......@@ -15,11 +15,12 @@
import logging
import threading
from typing import Optional, Tuple
from typing import Optional, Tuple, Union
import torch
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.distributed import get_pp_group, get_tp_group, get_world_group
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
......@@ -31,7 +32,7 @@ from sglang.srt.managers.io_struct import (
)
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
......@@ -47,6 +48,7 @@ class TpModelWorker:
server_args: ServerArgs,
gpu_id: int,
tp_rank: int,
pp_rank: int,
dp_rank: Optional[int],
nccl_port: int,
is_draft_worker: bool = False,
......@@ -54,7 +56,9 @@ class TpModelWorker:
token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
):
# Parse args
self.tp_size = server_args.tp_size
self.tp_rank = tp_rank
self.pp_rank = pp_rank
# Init model and tokenizer
self.model_config = ModelConfig(
......@@ -73,12 +77,15 @@ class TpModelWorker:
quantization=server_args.quantization,
is_draft_model=is_draft_worker,
)
self.model_runner = ModelRunner(
model_config=self.model_config,
mem_fraction_static=server_args.mem_fraction_static,
gpu_id=gpu_id,
tp_rank=tp_rank,
tp_size=server_args.tp_size,
pp_rank=pp_rank,
pp_size=server_args.pp_size,
nccl_port=nccl_port,
server_args=server_args,
is_draft_worker=is_draft_worker,
......@@ -105,6 +112,10 @@ class TpModelWorker:
)
self.device = self.model_runner.device
# Init nccl groups
self.pp_group = get_pp_group()
self.world_group = get_world_group()
# Profile number of tokens
self.max_total_num_tokens = self.model_runner.max_total_num_tokens
self.max_prefill_tokens = server_args.max_prefill_tokens
......@@ -130,8 +141,9 @@ class TpModelWorker:
# Sync random seed across TP workers
self.random_seed = broadcast_pyobj(
[server_args.random_seed],
self.tp_rank,
self.model_runner.tp_group.cpu_group,
self.tp_size * self.pp_rank + tp_rank,
self.world_group.cpu_group,
src=self.world_group.ranks[0],
)[0]
set_random_seed(self.random_seed)
......@@ -156,11 +168,14 @@ class TpModelWorker:
def get_pad_input_ids_func(self):
return getattr(self.model_runner.model, "pad_input_ids", None)
def get_tp_cpu_group(self):
return self.model_runner.tp_group.cpu_group
def get_tp_group(self):
return self.model_runner.tp_group
def get_attention_tp_group(self):
return self.model_runner.attention_tp_group
def get_attention_tp_cpu_group(self):
return self.model_runner.attention_tp_group.cpu_group
return getattr(self.model_runner.attention_tp_group, "cpu_group", None)
def get_memory_pool(self):
return (
......@@ -172,19 +187,38 @@ class TpModelWorker:
self,
model_worker_batch: ModelWorkerBatch,
skip_sample: bool = False,
) -> Tuple[LogitsProcessorOutput, Optional[torch.Tensor]]:
) -> Tuple[Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor]]:
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output = self.model_runner.forward(forward_batch)
if model_worker_batch.launch_done is not None:
model_worker_batch.launch_done.set()
pp_proxy_tensors = None
if not self.pp_group.is_first_rank:
pp_proxy_tensors = PPProxyTensors(
self.pp_group.recv_tensor_dict(
all_gather_group=self.get_attention_tp_group()
)
)
if self.pp_group.is_last_rank:
logits_output = self.model_runner.forward(
forward_batch, pp_proxy_tensors=pp_proxy_tensors
)
if model_worker_batch.launch_done is not None:
model_worker_batch.launch_done.set()
if skip_sample:
next_token_ids = None
else:
next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
if skip_sample:
next_token_ids = None
else:
next_token_ids = self.model_runner.sample(
logits_output, model_worker_batch
)
return logits_output, next_token_ids
return logits_output, next_token_ids
else:
pp_proxy_tensors = self.model_runner.forward(
forward_batch,
pp_proxy_tensors=pp_proxy_tensors,
)
return pp_proxy_tensors.tensors, None
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
......
......@@ -56,11 +56,14 @@ class TpModelWorkerClient:
server_args: ServerArgs,
gpu_id: int,
tp_rank: int,
pp_rank: int,
dp_rank: Optional[int],
nccl_port: int,
):
# Load the model
self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
self.worker = TpModelWorker(
server_args, gpu_id, tp_rank, pp_rank, dp_rank, nccl_port
)
self.max_running_requests = self.worker.max_running_requests
self.device = self.worker.device
self.gpu_id = gpu_id
......@@ -91,8 +94,11 @@ class TpModelWorkerClient:
def get_pad_input_ids_func(self):
return self.worker.get_pad_input_ids_func()
def get_tp_cpu_group(self):
return self.worker.get_tp_cpu_group()
def get_tp_group(self):
return self.worker.get_tp_group()
def get_attention_tp_group(self):
return self.worker.get_attention_tp_group()
def get_attention_tp_cpu_group(self):
return self.worker.get_attention_tp_cpu_group()
......
......@@ -214,6 +214,8 @@ class MHATokenToKVPool(KVCache):
layer_num: int,
device: str,
enable_memory_saver: bool,
start_layer: Optional[int] = None,
end_layer: Optional[int] = None,
):
self.size = size
self.page_size = page_size
......@@ -232,6 +234,8 @@ class MHATokenToKVPool(KVCache):
self.head_dim = head_dim
self.layer_num = layer_num
self._create_buffers()
self.start_layer = start_layer or 0
self.end_layer = end_layer or layer_num - 1
self.layer_transfer_counter = None
self.capture_mode = False
......@@ -281,6 +285,8 @@ class MHATokenToKVPool(KVCache):
# for disagg
def get_contiguous_buf_infos(self):
# layer_num x [seq_len, head_num, head_dim]
# layer_num x [page_num, page_size, head_num, head_dim]
kv_data_ptrs = [
self.get_key_buffer(i).data_ptr() for i in range(self.layer_num)
] + [self.get_value_buffer(i).data_ptr() for i in range(self.layer_num)]
......@@ -320,24 +326,24 @@ class MHATokenToKVPool(KVCache):
# transfer prepared data from host to device
flat_data = flat_data.to(device=self.device, non_blocking=False)
k_data, v_data = flat_data[0], flat_data[1]
self.k_buffer[layer_id][indices] = k_data
self.v_buffer[layer_id][indices] = v_data
self.k_buffer[layer_id - self.start_layer][indices] = k_data
self.v_buffer[layer_id - self.start_layer][indices] = v_data
def get_key_buffer(self, layer_id: int):
if self.layer_transfer_counter is not None:
self.layer_transfer_counter.wait_until(layer_id)
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
if self.store_dtype != self.dtype:
return self.k_buffer[layer_id].view(self.dtype)
return self.k_buffer[layer_id]
return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
return self.k_buffer[layer_id - self.start_layer]
def get_value_buffer(self, layer_id: int):
if self.layer_transfer_counter is not None:
self.layer_transfer_counter.wait_until(layer_id)
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
if self.store_dtype != self.dtype:
return self.v_buffer[layer_id].view(self.dtype)
return self.v_buffer[layer_id]
return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
return self.v_buffer[layer_id - self.start_layer]
def get_kv_buffer(self, layer_id: int):
return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
......@@ -369,12 +375,12 @@ class MHATokenToKVPool(KVCache):
current_stream = self.device_module.current_stream()
self.alt_stream.wait_stream(current_stream)
with self.device_module.stream(self.alt_stream):
self.k_buffer[layer_id][loc] = cache_k
self.v_buffer[layer_id][loc] = cache_v
self.k_buffer[layer_id - self.start_layer][loc] = cache_k
self.v_buffer[layer_id - self.start_layer][loc] = cache_v
current_stream.wait_stream(self.alt_stream)
else:
self.k_buffer[layer_id][loc] = cache_k
self.v_buffer[layer_id][loc] = cache_v
self.k_buffer[layer_id - self.start_layer][loc] = cache_k
self.v_buffer[layer_id - self.start_layer][loc] = cache_v
@torch.compile
......@@ -484,6 +490,8 @@ class MLATokenToKVPool(KVCache):
layer_num: int,
device: str,
enable_memory_saver: bool,
start_layer: Optional[int] = None,
end_layer: Optional[int] = None,
):
self.size = size
self.page_size = page_size
......@@ -497,6 +505,8 @@ class MLATokenToKVPool(KVCache):
self.kv_lora_rank = kv_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.layer_num = layer_num
self.start_layer = start_layer or 0
self.end_layer = end_layer or layer_num - 1
memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=enable_memory_saver
......@@ -540,19 +550,21 @@ class MLATokenToKVPool(KVCache):
def get_key_buffer(self, layer_id: int):
if self.layer_transfer_counter is not None:
self.layer_transfer_counter.wait_until(layer_id)
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
if self.store_dtype != self.dtype:
return self.kv_buffer[layer_id].view(self.dtype)
return self.kv_buffer[layer_id]
return self.kv_buffer[layer_id - self.start_layer].view(self.dtype)
return self.kv_buffer[layer_id - self.start_layer]
def get_value_buffer(self, layer_id: int):
if self.layer_transfer_counter is not None:
self.layer_transfer_counter.wait_until(layer_id)
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
if self.store_dtype != self.dtype:
return self.kv_buffer[layer_id][..., : self.kv_lora_rank].view(self.dtype)
return self.kv_buffer[layer_id][..., : self.kv_lora_rank]
return self.kv_buffer[layer_id - self.start_layer][
..., : self.kv_lora_rank
].view(self.dtype)
return self.kv_buffer[layer_id - self.start_layer][..., : self.kv_lora_rank]
def get_kv_buffer(self, layer_id: int):
return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
......@@ -568,9 +580,11 @@ class MLATokenToKVPool(KVCache):
if cache_k.dtype != self.dtype:
cache_k = cache_k.to(self.dtype)
if self.store_dtype != self.dtype:
self.kv_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
self.kv_buffer[layer_id - self.start_layer][loc] = cache_k.view(
self.store_dtype
)
else:
self.kv_buffer[layer_id][loc] = cache_k
self.kv_buffer[layer_id - self.start_layer][loc] = cache_k
def set_mla_kv_buffer(
self,
......@@ -605,7 +619,7 @@ class MLATokenToKVPool(KVCache):
def transfer_per_layer(self, indices, flat_data, layer_id):
# transfer prepared data from host to device
flat_data = flat_data.to(device=self.device, non_blocking=False)
self.kv_buffer[layer_id][indices] = flat_data
self.kv_buffer[layer_id - self.start_layer][indices] = flat_data
class DoubleSparseTokenToKVPool(KVCache):
......@@ -620,6 +634,8 @@ class DoubleSparseTokenToKVPool(KVCache):
device: str,
heavy_channel_num: int,
enable_memory_saver: bool,
start_layer: Optional[int] = None,
end_layer: Optional[int] = None,
):
self.size = size
self.page_size = page_size
......@@ -657,17 +673,23 @@ class DoubleSparseTokenToKVPool(KVCache):
for _ in range(layer_num)
]
self.start_layer = start_layer or 0
self.end_layer = end_layer or layer_num - 1
def get_key_buffer(self, layer_id: int):
return self.k_buffer[layer_id]
return self.k_buffer[layer_id - self.start_layer]
def get_value_buffer(self, layer_id: int):
return self.v_buffer[layer_id]
return self.v_buffer[layer_id - self.start_layer]
def get_label_buffer(self, layer_id: int):
return self.label_buffer[layer_id]
return self.label_buffer[layer_id - self.start_layer]
def get_kv_buffer(self, layer_id: int):
return self.k_buffer[layer_id], self.v_buffer[layer_id]
return (
self.k_buffer[layer_id - self.start_layer],
self.v_buffer[layer_id - self.start_layer],
)
def set_kv_buffer(
self,
......@@ -679,9 +701,9 @@ class DoubleSparseTokenToKVPool(KVCache):
):
# NOTE(Andy): ignore the dtype check
layer_id = layer.layer_id
self.k_buffer[layer_id][loc] = cache_k
self.v_buffer[layer_id][loc] = cache_v
self.label_buffer[layer_id][loc] = cache_label
self.k_buffer[layer_id - self.start_layer][loc] = cache_k
self.v_buffer[layer_id - self.start_layer][loc] = cache_v
self.label_buffer[layer_id - self.start_layer][loc] = cache_label
def get_flat_data(self, indices):
pass
......@@ -930,7 +952,7 @@ class MHATokenToKVPoolHost(HostKVCache):
return self.kv_buffer[:, :, indices]
def get_flat_data_by_layer(self, indices, layer_id):
return self.kv_buffer[:, layer_id, indices]
return self.kv_buffer[:, layer_id - self.start_layer, indices]
def assign_flat_data(self, indices, flat_data):
self.kv_buffer[:, :, indices] = flat_data
......@@ -955,12 +977,20 @@ class MHATokenToKVPoolHost(HostKVCache):
for i in range(len(device_indices_cpu)):
h_index = host_indices[i * self.page_size]
d_index = device_indices_cpu[i]
device_pool.k_buffer[layer_id][d_index : d_index + self.page_size].copy_(
self.kv_buffer[0, layer_id, h_index : h_index + self.page_size],
device_pool.k_buffer[layer_id - self.start_layer][
d_index : d_index + self.page_size
].copy_(
self.kv_buffer[
0, layer_id - self.start_layer, h_index : h_index + self.page_size
],
non_blocking=True,
)
device_pool.v_buffer[layer_id][d_index : d_index + self.page_size].copy_(
self.kv_buffer[1, layer_id, h_index : h_index + self.page_size],
device_pool.v_buffer[layer_id - self.start_layer][
d_index : d_index + self.page_size
].copy_(
self.kv_buffer[
1, layer_id - self.start_layer, h_index : h_index + self.page_size
],
non_blocking=True,
)
......@@ -1015,7 +1045,7 @@ class MLATokenToKVPoolHost(HostKVCache):
return self.kv_buffer[:, indices]
def get_flat_data_by_layer(self, indices, layer_id):
return self.kv_buffer[layer_id, indices]
return self.kv_buffer[layer_id - self.start_layer, indices]
def assign_flat_data(self, indices, flat_data):
self.kv_buffer[:, indices] = flat_data
......@@ -1036,7 +1066,11 @@ class MLATokenToKVPoolHost(HostKVCache):
for i in range(len(device_indices_cpu)):
h_index = host_indices[i * self.page_size]
d_index = device_indices_cpu[i]
device_pool.kv_buffer[layer_id][d_index : d_index + self.page_size].copy_(
self.kv_buffer[layer_id, h_index : h_index + self.page_size],
device_pool.kv_buffer[layer_id - self.start_layer][
d_index : d_index + self.page_size
].copy_(
self.kv_buffer[
layer_id - self.start_layer, h_index : h_index + self.page_size
],
non_blocking=True,
)
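A condensed toy sketch of the indexing change running through the KV pools above: each PP rank allocates buffers only for its own layer slice, and global `layer_id`s are shifted by `start_layer` before indexing. The class below is a stand-in for illustration, not the real `MHATokenToKVPool`.

```python
# Toy stand-in for the layer-offset indexing used by the KV pools above.
import torch

class ToyKVPool:
    def __init__(self, layer_num, start_layer=None):
        self.start_layer = start_layer or 0
        # one buffer per *local* layer only (shapes here are arbitrary)
        self.k_buffer = [torch.zeros(4, 2, 8) for _ in range(layer_num)]

    def get_key_buffer(self, layer_id):
        # global layer_id -> local buffer index
        return self.k_buffer[layer_id - self.start_layer]

# e.g. the second stage of a 2-way split of a 32-layer model owns layers 16..31
pool = ToyKVPool(layer_num=16, start_layer=16)
assert pool.get_key_buffer(16) is pool.k_buffer[0]
assert pool.get_key_buffer(31) is pool.k_buffer[15]
```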
......@@ -16,6 +16,7 @@
from __future__ import annotations
import bisect
import inspect
import os
from contextlib import contextmanager
from typing import TYPE_CHECKING, Callable
......@@ -33,12 +34,14 @@ from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
ForwardBatch,
ForwardMode,
PPProxyTensors,
)
from sglang.srt.patch_torch import monkey_patch_torch_compile
from sglang.srt.utils import (
get_available_gpu_memory,
get_device_memory_capacity,
is_hip,
rank0_log,
)
if TYPE_CHECKING:
......@@ -188,10 +191,11 @@ class CudaGraphRunner:
self.speculative_algorithm = model_runner.server_args.speculative_algorithm
self.tp_size = model_runner.server_args.tp_size
self.dp_size = model_runner.server_args.dp_size
self.pp_size = model_runner.server_args.pp_size
# Batch sizes to capture
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
rank0_log(f"Capture cuda graph bs {self.capture_bs}")
self.capture_forward_mode = ForwardMode.DECODE
self.capture_hidden_mode = CaptureHiddenMode.NULL
self.num_tokens_per_bs = 1
......@@ -234,6 +238,19 @@ class CudaGraphRunner:
self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
# pipeline parallelism
if self.pp_size > 1:
self.pp_proxy_tensors = {
"hidden_states": torch.zeros(
(self.max_bs, self.model_runner.model_config.hidden_size),
dtype=torch.bfloat16,
),
"residual": torch.zeros(
(self.max_bs, self.model_runner.model_config.hidden_size),
dtype=torch.bfloat16,
),
}
# Speculative inference
if (
model_runner.spec_algorithm.is_eagle3()
......@@ -384,6 +401,12 @@ class CudaGraphRunner:
encoder_lens = None
mrope_positions = self.mrope_positions[:, :bs]
# pipeline parallelism
if self.pp_size > 1:
pp_proxy_tensors = PPProxyTensors(
{k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()}
)
if self.enable_dp_attention or self.enable_sp_layernorm:
self.global_num_tokens_gpu.copy_(
torch.tensor(
......@@ -456,8 +479,20 @@ class CudaGraphRunner:
# Clean intermediate result cache for DP attention
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
logits_output = forward(input_ids, forward_batch.positions, forward_batch)
return logits_output.next_token_logits, logits_output.hidden_states
kwargs = {}
if (
self.pp_size > 1
and "pp_proxy_tensors" in inspect.signature(forward).parameters
):
kwargs["pp_proxy_tensors"] = pp_proxy_tensors
logits_output_or_pp_proxy_tensors = forward(
input_ids,
forward_batch.positions,
forward_batch,
**kwargs,
)
return logits_output_or_pp_proxy_tensors
for _ in range(2):
torch.cuda.synchronize()
......@@ -490,7 +525,11 @@ class CudaGraphRunner:
self.capture_hidden_mode = hidden_mode_from_spec_info
self.capture()
def replay_prepare(self, forward_batch: ForwardBatch):
def replay_prepare(
self,
forward_batch: ForwardBatch,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
):
self.recapture_if_needed(forward_batch)
raw_bs = forward_batch.batch_size
......@@ -519,6 +558,11 @@ class CudaGraphRunner:
self.seq_lens_cpu.fill_(1)
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
if pp_proxy_tensors:
for key in self.pp_proxy_tensors.keys():
dim = pp_proxy_tensors[key].shape[0]
self.pp_proxy_tensors[key][:dim].copy_(pp_proxy_tensors[key])
if self.is_encoder_decoder:
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
if forward_batch.mrope_positions is not None:
......@@ -547,10 +591,13 @@ class CudaGraphRunner:
self.bs = bs
def replay(
self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
) -> LogitsProcessorOutput:
self,
forward_batch: ForwardBatch,
skip_attn_backend_init: bool = False,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> Union[LogitsProcessorOutput, PPProxyTensors]:
if not skip_attn_backend_init:
self.replay_prepare(forward_batch)
self.replay_prepare(forward_batch, pp_proxy_tensors)
else:
# In speculative decoding, these two fields are still needed.
self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
......@@ -558,17 +605,19 @@ class CudaGraphRunner:
# Replay
self.graphs[self.bs].replay()
next_token_logits, hidden_states = self.output_buffers[self.bs]
logits_output = LogitsProcessorOutput(
next_token_logits=next_token_logits[: self.raw_num_token],
hidden_states=(
hidden_states[: self.raw_num_token]
if hidden_states is not None
else None
),
)
return logits_output
output = self.output_buffers[self.bs]
if isinstance(output, LogitsProcessorOutput):
return LogitsProcessorOutput(
next_token_logits=output.next_token_logits[: self.raw_num_token],
hidden_states=(
output.hidden_states[: self.raw_num_token]
if output.hidden_states is not None
else None
),
)
else:
assert isinstance(output, PPProxyTensors)
return PPProxyTensors({k: v[: self.bs] for k, v in output.tensors.items()})
def get_spec_info(self, num_tokens: int):
spec_info = None
......
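A minimal sketch of the static-buffer pattern `replay_prepare` uses for `pp_proxy_tensors`: the proxy tensors of the current batch are copied in place into preallocated max-batch buffers before the captured graph is replayed (shapes and dtypes below are hypothetical).

```python
# Illustrative: copy incoming proxy tensors into preallocated CUDA-graph buffers.
import torch

max_bs, hidden_size = 8, 16
static_buffers = {
    "hidden_states": torch.zeros(max_bs, hidden_size),
    "residual": torch.zeros(max_bs, hidden_size),
}

incoming = {k: torch.randn(3, hidden_size) for k in static_buffers}  # batch of 3

for key, buf in static_buffers.items():
    dim = incoming[key].shape[0]
    buf[:dim].copy_(incoming[key])  # in-place copy keeps graph-captured pointers valid

assert torch.equal(static_buffers["residual"][:3], incoming["residual"])
```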
......@@ -31,7 +31,7 @@ from __future__ import annotations
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Union
import torch
import triton
......@@ -585,6 +585,36 @@ class ForwardBatch:
self.prepare_chunked_kv_indices(device)
class PPProxyTensors:
# adapted from https://github.com/vllm-project/vllm/blob/d14e98d924724b284dc5eaf8070d935e214e50c0/vllm/sequence.py#L1103
tensors: Dict[str, torch.Tensor]
def __init__(self, tensors):
# manually define this function, so that
# Dynamo knows `PPProxyTensors()` comes from this file.
# Otherwise, dataclass will generate this function by evaluating
# a string, and we will lose the information about the source file.
self.tensors = tensors
def __getitem__(self, key: Union[str, slice]):
if isinstance(key, str):
return self.tensors[key]
elif isinstance(key, slice):
return self.__class__({k: v[key] for k, v in self.tensors.items()})
def __setitem__(self, key: str, value: torch.Tensor):
self.tensors[key] = value
def __len__(self):
return len(self.tensors)
def __eq__(self, other: object):
return isinstance(other, self.__class__) and self
def __repr__(self) -> str:
return f"PPProxyTensors(tensors={self.tensors})"
def compute_position_triton(
extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
):
......
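A brief usage sketch for the `PPProxyTensors` container added above (assuming the class as defined in this diff; keys and shapes are arbitrary).

```python
# Illustrative usage of PPProxyTensors as defined above.
import torch

proxy = PPProxyTensors(
    {
        "hidden_states": torch.randn(4, 16),
        "residual": torch.randn(4, 16),
    }
)

hidden = proxy["hidden_states"]   # lookup by key
first_two = proxy[:2]             # slicing returns a new PPProxyTensors
assert isinstance(first_two, PPProxyTensors)
assert first_two["residual"].shape == (2, 16)
assert len(proxy) == 2            # number of tensors
```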
......@@ -13,8 +13,10 @@
# ==============================================================================
"""ModelRunner runs the forward passes of the models."""
import collections
import datetime
import gc
import inspect
import json
import logging
import os
......@@ -59,7 +61,7 @@ from sglang.srt.mem_cache.memory_pool import (
)
from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader import get_model
from sglang.srt.model_loader.loader import (
DefaultModelLoader,
......@@ -111,6 +113,8 @@ class ModelRunner:
gpu_id: int,
tp_rank: int,
tp_size: int,
pp_rank: int,
pp_size: int,
nccl_port: int,
server_args: ServerArgs,
is_draft_worker: bool = False,
......@@ -124,6 +128,8 @@ class ModelRunner:
self.gpu_id = gpu_id
self.tp_rank = tp_rank
self.tp_size = tp_size
self.pp_rank = pp_rank
self.pp_size = pp_size
self.dist_port = nccl_port
self.server_args = server_args
self.is_draft_worker = is_draft_worker
......@@ -149,24 +155,24 @@ class ModelRunner:
global_server_args_dict.update(
{
"attention_backend": server_args.attention_backend,
"sampling_backend": server_args.sampling_backend,
"triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
"torchao_config": server_args.torchao_config,
"debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
"debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
"deepep_mode": server_args.deepep_mode,
"device": server_args.device,
"disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
"disable_radix_cache": server_args.disable_radix_cache,
"enable_nan_detection": server_args.enable_nan_detection,
"enable_dp_attention": server_args.enable_dp_attention,
"enable_ep_moe": server_args.enable_ep_moe,
"enable_deepep_moe": server_args.enable_deepep_moe,
"deepep_mode": server_args.deepep_mode,
"device": server_args.device,
"speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
"speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
"disable_radix_cache": server_args.disable_radix_cache,
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
"moe_dense_tp_size": server_args.moe_dense_tp_size,
"debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
"debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
"n_share_experts_fusion": server_args.n_share_experts_fusion,
"disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
"triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
"torchao_config": server_args.torchao_config,
"sampling_backend": server_args.sampling_backend,
"speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
"speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
"use_mla_backend": self.use_mla_backend,
}
)
......@@ -184,6 +190,11 @@ class ModelRunner:
# If it is a draft model, tp_group can be different
self.initialize(min_per_gpu_memory)
# temporary cached values
self.support_pp = (
"pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
)
def initialize(self, min_per_gpu_memory: float):
server_args = self.server_args
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
......@@ -194,6 +205,12 @@ class ModelRunner:
self.sampler = Sampler()
self.load_model()
self.start_layer = getattr(self.model, "start_layer", 0)
self.end_layer = getattr(
self.model, "end_layer", self.model_config.num_hidden_layers
)
self.num_effective_layers = self.end_layer - self.start_layer
# Apply torchao quantization
torchao_applied = getattr(self.model, "torchao_applied", False)
# In layered loading, torchao may have been applied
......@@ -360,18 +377,22 @@ class ModelRunner:
# Only initialize the distributed environment on the target model worker.
init_distributed_environment(
backend=backend,
world_size=self.tp_size,
rank=self.tp_rank,
world_size=self.tp_size * self.pp_size,
rank=self.tp_size * self.pp_rank + self.tp_rank,
local_rank=self.gpu_id,
distributed_init_method=dist_init_method,
timeout=self.server_args.dist_timeout,
)
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
initialize_model_parallel(
tensor_model_parallel_size=self.tp_size,
pipeline_model_parallel_size=self.pp_size,
)
initialize_dp_attention(
enable_dp_attention=self.server_args.enable_dp_attention,
tp_rank=self.tp_rank,
tp_size=self.tp_size,
dp_size=self.server_args.dp_size,
pp_size=self.server_args.pp_size,
)
min_per_gpu_memory = get_available_gpu_memory(
......@@ -698,6 +719,8 @@ class ModelRunner:
if not self.is_draft_worker
else self.model_config.hf_config.num_nextn_predict_layers
)
# FIXME: pipeline parallelism is not compatible with mla backend
assert self.pp_size == 1
cell_size = (
(self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
* num_layers
......@@ -707,7 +730,7 @@ class ModelRunner:
cell_size = (
self.model_config.get_num_kv_heads(get_attention_tp_size())
* self.model_config.head_dim
* self.model_config.num_hidden_layers
* self.num_effective_layers
* 2
* torch._utils._element_size(self.kv_cache_dtype)
)
......@@ -819,9 +842,11 @@ class ModelRunner:
self.model_config.num_hidden_layers
if not self.is_draft_worker
else self.model_config.hf_config.num_nextn_predict_layers
),
), # PP is not compatible with mla backend
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
start_layer=self.start_layer,
end_layer=self.end_layer,
)
elif self.server_args.enable_double_sparsity:
self.token_to_kv_pool = DoubleSparseTokenToKVPool(
......@@ -830,10 +855,12 @@ class ModelRunner:
dtype=self.kv_cache_dtype,
head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
head_dim=self.model_config.head_dim,
layer_num=self.model_config.num_hidden_layers,
layer_num=self.num_effective_layers,
device=self.device,
heavy_channel_num=self.server_args.ds_heavy_channel_num,
enable_memory_saver=self.server_args.enable_memory_saver,
start_layer=self.start_layer,
end_layer=self.end_layer,
)
else:
self.token_to_kv_pool = MHATokenToKVPool(
......@@ -842,9 +869,11 @@ class ModelRunner:
dtype=self.kv_cache_dtype,
head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
head_dim=self.model_config.head_dim,
layer_num=self.model_config.num_hidden_layers,
layer_num=self.num_effective_layers,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
start_layer=self.start_layer,
end_layer=self.end_layer,
)
if self.token_to_kv_pool_allocator is None:
......@@ -957,7 +986,7 @@ class ModelRunner:
with open(self.server_args.ds_channel_config_path, "r") as f:
channel_config = json.load(f)
for i in range(self.model_config.num_hidden_layers):
for i in range(self.start_layer, self.end_layer):
key = "model.layers." + str(i) + ".self_attn" + selected_channel
self.sorted_channels.append(
torch.tensor(channel_config[key])[
......@@ -997,64 +1026,82 @@ class ModelRunner:
device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
tensor_parallel(self.model, device_mesh)
def forward_decode(self, forward_batch: ForwardBatch):
def forward_decode(
self, forward_batch: ForwardBatch, pp_proxy_tensors=None
) -> LogitsProcessorOutput:
self.attn_backend.init_forward_metadata(forward_batch)
# FIXME: add pp_proxy_tensors arg to all models
kwargs = {}
if self.support_pp:
kwargs["pp_proxy_tensors"] = pp_proxy_tensors
return self.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
forward_batch.input_ids, forward_batch.positions, forward_batch, **kwargs
)
def forward_extend(
self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
):
self,
forward_batch: ForwardBatch,
skip_attn_backend_init: bool = False,
pp_proxy_tensors=None,
) -> LogitsProcessorOutput:
if not skip_attn_backend_init:
self.attn_backend.init_forward_metadata(forward_batch)
if self.is_generation:
if forward_batch.input_embeds is None:
return self.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
)
else:
return self.model.forward(
forward_batch.input_ids,
forward_batch.positions,
forward_batch,
input_embeds=forward_batch.input_embeds.bfloat16(),
)
else:
# Only embedding models have get_embedding parameter
return self.model.forward(
forward_batch.input_ids,
forward_batch.positions,
forward_batch,
get_embedding=True,
)
kwargs = {}
if self.support_pp:
kwargs["pp_proxy_tensors"] = pp_proxy_tensors
if forward_batch.input_embeds is not None:
kwargs["input_embeds"] = forward_batch.input_embeds.bfloat16()
if not self.is_generation:
kwargs["get_embedding"] = True
return self.model.forward(
forward_batch.input_ids,
forward_batch.positions,
forward_batch,
**kwargs,
)
def forward_idle(self, forward_batch: ForwardBatch):
def forward_idle(
self, forward_batch: ForwardBatch, pp_proxy_tensors=None
) -> LogitsProcessorOutput:
kwargs = {}
if self.support_pp:
kwargs["pp_proxy_tensors"] = pp_proxy_tensors
return self.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
forward_batch.input_ids,
forward_batch.positions,
forward_batch,
**kwargs,
)
def forward(
self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
) -> LogitsProcessorOutput:
if (
self,
forward_batch: ForwardBatch,
skip_attn_backend_init: bool = False,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> Union[LogitsProcessorOutput, PPProxyTensors]:
can_run_cuda_graph = bool(
forward_batch.forward_mode.is_cuda_graph()
and self.cuda_graph_runner
and self.cuda_graph_runner.can_run(forward_batch)
):
)
if can_run_cuda_graph:
return self.cuda_graph_runner.replay(
forward_batch, skip_attn_backend_init=skip_attn_backend_init
forward_batch,
skip_attn_backend_init=skip_attn_backend_init,
pp_proxy_tensors=pp_proxy_tensors,
)
if forward_batch.forward_mode.is_decode():
return self.forward_decode(forward_batch)
return self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
elif forward_batch.forward_mode.is_extend():
return self.forward_extend(
forward_batch, skip_attn_backend_init=skip_attn_backend_init
forward_batch,
skip_attn_backend_init=skip_attn_backend_init,
pp_proxy_tensors=pp_proxy_tensors,
)
elif forward_batch.forward_mode.is_idle():
return self.forward_idle(forward_batch)
return self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
else:
raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
......
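A minimal sketch of the `support_pp` introspection used above: `ModelRunner` only forwards `pp_proxy_tensors` to models whose `forward` signature accepts it, so models without PP support keep working unchanged. The toy model classes below are hypothetical.

```python
# Illustrative: only pass pp_proxy_tensors to models whose forward() accepts it.
import inspect

class LegacyModel:
    def forward(self, input_ids, positions, forward_batch):
        return "logits"

class PPAwareModel:
    def forward(self, input_ids, positions, forward_batch, pp_proxy_tensors=None):
        return "logits_or_proxy"

def run(model, pp_proxy_tensors=None):
    support_pp = "pp_proxy_tensors" in inspect.signature(model.forward).parameters
    kwargs = {"pp_proxy_tensors": pp_proxy_tensors} if support_pp else {}
    return model.forward(None, None, None, **kwargs)

print(run(LegacyModel()))                        # no new argument passed
print(run(PPAwareModel(), pp_proxy_tensors={}))  # proxy tensors forwarded
```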
......@@ -17,13 +17,14 @@
"""Inference-only LLaMA model compatible with HuggingFace weights."""
import logging
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch
from torch import nn
from transformers import LlamaConfig
from sglang.srt.distributed import (
get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
......@@ -39,11 +40,12 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
kv_cache_scales_loader,
......@@ -275,21 +277,31 @@ class LlamaModel(nn.Module):
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("embed_tokens", prefix),
)
self.layers = make_layers(
self.pp_group = get_pp_group()
if self.pp_group.is_first_rank:
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("embed_tokens", prefix),
)
else:
self.embed_tokens = PPMissingLayer()
self.layers, self.start_layer, self.end_layer = make_layers(
config.num_hidden_layers,
lambda idx, prefix: LlamaDecoderLayer(
config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
),
pp_rank=self.pp_group.rank_in_group,
pp_size=self.pp_group.world_size,
prefix="model.layers",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
if self.pp_group.is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer(return_tuple=True)
self.layers_to_capture = []
def forward(
......@@ -298,14 +310,23 @@ class LlamaModel(nn.Module):
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]], PPProxyTensors]:
if self.pp_group.is_first_rank:
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
else:
hidden_states = input_embeds
residual = None
else:
hidden_states = input_embeds
residual = None
assert pp_proxy_tensors is not None
# FIXME(@ying): reduce the number of proxy tensors by not fusing layer norms
hidden_states = pp_proxy_tensors["hidden_states"]
residual = pp_proxy_tensors["residual"]
deferred_norm = None
aux_hidden_states = []
for i in range(len(self.layers)):
for i in range(self.start_layer, self.end_layer):
if i in self.layers_to_capture:
aux_hidden_states.append(hidden_states + residual)
layer = self.layers[i]
......@@ -315,7 +336,16 @@ class LlamaModel(nn.Module):
forward_batch,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
if not self.pp_group.is_last_rank:
return PPProxyTensors(
{
"hidden_states": hidden_states,
"residual": residual,
}
)
else:
hidden_states, _ = self.norm(hidden_states, residual)
if len(aux_hidden_states) == 0:
return hidden_states
......@@ -376,6 +406,7 @@ class LlamaForCausalLM(nn.Module):
prefix: str = "",
) -> None:
super().__init__()
self.pp_group = get_pp_group()
self.config = config
self.quant_config = quant_config
self.model = self._init_model(config, quant_config, add_prefix("model", prefix))
......@@ -419,23 +450,41 @@ class LlamaForCausalLM(nn.Module):
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
get_embedding: bool = False,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> LogitsProcessorOutput:
hidden_states = self.model(
input_ids,
positions,
forward_batch,
input_embeds,
pp_proxy_tensors=pp_proxy_tensors,
)
aux_hidden_states = None
if self.capture_aux_hidden_states:
hidden_states, aux_hidden_states = self.model(
input_ids, positions, forward_batch, input_embeds
)
hidden_states, aux_hidden_states = hidden_states
if self.pp_group.is_last_rank:
if not get_embedding:
return self.logits_processor(
input_ids,
hidden_states,
self.lm_head,
forward_batch,
aux_hidden_states,
)
else:
return self.pooler(hidden_states, forward_batch)
else:
hidden_states = self.model(
input_ids, positions, forward_batch, input_embeds
)
return hidden_states
if not get_embedding:
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
)
else:
return self.pooler(hidden_states, forward_batch)
@property
def start_layer(self):
return self.model.start_layer
@property
def end_layer(self):
return self.model.end_layer
def get_input_embeddings(self) -> nn.Embedding:
return self.model.embed_tokens
......@@ -491,6 +540,16 @@ class LlamaForCausalLM(nn.Module):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
layer_id = get_layer_id(name)
if (
layer_id is not None
and hasattr(self.model, "start_layer")
and (
layer_id < self.model.start_layer
or layer_id >= self.model.end_layer
)
):
continue
if "rotary_emb.inv_freq" in name or "projector" in name:
continue
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
......@@ -637,6 +696,9 @@ class LlamaForCausalLM(nn.Module):
self.model.load_kv_cache_scales(quantization_param_path)
def set_eagle3_layers_to_capture(self):
if not self.pp_group.is_last_rank:
return
self.capture_aux_hidden_states = True
num_layers = self.config.num_hidden_layers
self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
......
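A hedged sketch of how the per-rank layer range interacts with weight loading: `get_layer_id` (added earlier in this diff) parses the layer index from a weight name, and weights outside `[start_layer, end_layer)` are skipped. The even split below is an assumption for illustration; the real partition is chosen by `make_layers(..., pp_rank=..., pp_size=...)`, whose implementation is not shown in this diff.

```python
# Illustrative only. even_split() is an assumed partition; the actual split is
# decided by make_layers(), which is not part of this diff.
import re

def get_layer_id(weight_name):
    match = re.search(r"layers\.(\d+)\.", weight_name)
    return int(match.group(1)) if match else None

def even_split(num_layers, pp_rank, pp_size):
    per_rank = num_layers // pp_size
    start = pp_rank * per_rank
    end = num_layers if pp_rank == pp_size - 1 else start + per_rank
    return start, end

start_layer, end_layer = even_split(num_layers=32, pp_rank=1, pp_size=2)  # (16, 32)

weights = [
    "model.layers.3.self_attn.qkv_proj.weight",   # owned by pp_rank 0 -> skipped here
    "model.layers.20.mlp.gate_up_proj.weight",    # owned by this rank -> loaded
    "model.embed_tokens.weight",                  # no layer id -> not filtered by range
]
for name in weights:
    layer_id = get_layer_id(name)
    if layer_id is not None and not (start_layer <= layer_id < end_layer):
        continue  # this rank does not own the layer; skip the weight
    print("load", name)
```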
......@@ -46,7 +46,7 @@ from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers
......@@ -431,6 +431,7 @@ class Llama4Model(nn.Module):
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
......
......@@ -25,13 +25,14 @@ import torch
from torch import nn
from transformers import LlamaConfig
from sglang.srt.distributed import get_pp_group
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.models.llama import LlamaDecoderLayer, LlamaForCausalLM
......@@ -86,6 +87,7 @@ class LlamaModel(nn.Module):
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> torch.Tensor:
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
......@@ -118,6 +120,7 @@ class LlamaForCausalLMEagle(LlamaForCausalLM):
nn.Module.__init__(self)
self.config = config
self.quant_config = quant_config
self.pp_group = get_pp_group()
self.model = LlamaModel(
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
)
......
......@@ -25,6 +25,7 @@ import torch
from torch import nn
from transformers import LlamaConfig
from sglang.srt.distributed import get_pp_group
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
......@@ -33,7 +34,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.models.llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM
......@@ -118,6 +119,7 @@ class LlamaModel(nn.Module):
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> torch.Tensor:
if input_embeds is None:
embeds = self.embed_tokens(input_ids)
......@@ -155,6 +157,7 @@ class LlamaForCausalLMEagle3(LlamaForCausalLM):
nn.Module.__init__(self)
self.config = config
self.quant_config = quant_config
self.pp_group = get_pp_group()
if self.config.num_hidden_layers != 1:
raise ValueError("EAGLE3 currently only supports 1 layer")
......
......@@ -78,6 +78,8 @@ class ServerArgs:
# Other runtime options
tp_size: int = 1
pp_size: int = 1
max_micro_batch_size: Optional[int] = None
stream_interval: int = 1
stream_output: bool = False
random_seed: Optional[int] = None
......@@ -222,14 +224,18 @@ class ServerArgs:
# Set mem fraction static, which depends on the tensor parallelism size
if self.mem_fraction_static is None:
if self.tp_size >= 16:
self.mem_fraction_static = 0.79
elif self.tp_size >= 8:
self.mem_fraction_static = 0.81
elif self.tp_size >= 4:
self.mem_fraction_static = 0.85
elif self.tp_size >= 2:
self.mem_fraction_static = 0.87
parallel_size = self.tp_size * self.pp_size
if gpu_mem <= 81920:
if parallel_size >= 16:
self.mem_fraction_static = 0.79
elif parallel_size >= 8:
self.mem_fraction_static = 0.81
elif parallel_size >= 4:
self.mem_fraction_static = 0.85
elif parallel_size >= 2:
self.mem_fraction_static = 0.87
else:
self.mem_fraction_static = 0.88
else:
self.mem_fraction_static = 0.88
if gpu_mem > 96 * 1024:
......@@ -244,6 +250,8 @@ class ServerArgs:
if self.chunked_prefill_size is None:
if gpu_mem is not None and gpu_mem < 25_000:
self.chunked_prefill_size = 2048
elif self.disaggregation_mode != "null":
self.chunked_prefill_size = 16384
else:
self.chunked_prefill_size = 8192
assert self.chunked_prefill_size % self.page_size == 0
......@@ -643,6 +651,19 @@ class ServerArgs:
default=ServerArgs.tp_size,
help="The tensor parallelism size.",
)
parser.add_argument(
"--pipeline-parallel-size",
"--pp-size",
type=int,
default=ServerArgs.pp_size,
help="The pipeline parallelism size.",
)
parser.add_argument(
"--max-micro-batch-size",
type=int,
default=ServerArgs.max_micro_batch_size,
help="The maximum micro batch size in pipeline parallelism.",
)
parser.add_argument(
"--stream-interval",
type=int,
......@@ -1232,6 +1253,7 @@ class ServerArgs:
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
args.tp_size = args.tensor_parallel_size
args.pp_size = args.pipeline_parallel_size
args.dp_size = args.data_parallel_size
args.ep_size = args.expert_parallel_size
attrs = [attr.name for attr in dataclasses.fields(cls)]
......@@ -1245,8 +1267,19 @@ class ServerArgs:
def check_server_args(self):
assert (
self.tp_size % self.nnodes == 0
), "tp_size must be divisible by number of nodes"
self.tp_size * self.pp_size
) % self.nnodes == 0, "tp_size * pp_size must be divisible by the number of nodes"
# FIXME pp constraints
if self.pp_size > 1:
logger.warning(f"Turn off overlap scheule for pipeline parallelism.")
self.disable_overlap_schedule = True
assert (
self.disable_overlap_schedule
and self.speculative_algorithm is None
and not self.enable_mixed_chunk
), "Pipeline parallelism is not compatible with overlap schedule, speculative decoding, mixed chunked prefill."
assert not (
self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
), "multi-node data parallel is not supported unless dp attention!"
......
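A standalone sketch of the updated `mem_fraction_static` defaulting shown above, where the thresholds now key off `tp_size * pp_size` instead of `tp_size` alone. The `gpu_mem` unit (MiB) and the large-memory branch are simplified here; see the diff for the full logic.

```python
# Illustrative reimplementation of the defaulting logic above (simplified).
def default_mem_fraction_static(tp_size, pp_size, gpu_mem_mib):
    parallel_size = tp_size * pp_size
    if gpu_mem_mib <= 81920:  # cards with <= 80 GB
        if parallel_size >= 16:
            return 0.79
        elif parallel_size >= 8:
            return 0.81
        elif parallel_size >= 4:
            return 0.85
        elif parallel_size >= 2:
            return 0.87
        return 0.88
    return 0.88  # the >80 GB branch is tuned further in the diff

print(default_mem_fraction_static(tp_size=4, pp_size=2, gpu_mem_mib=81920))  # 0.81
print(default_mem_fraction_static(tp_size=2, pp_size=1, gpu_mem_mib=40960))  # 0.87
```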
......@@ -106,11 +106,12 @@ class EAGLEWorker(TpModelWorker):
# Init draft worker
with empty_context():
super().__init__(
server_args=server_args,
gpu_id=gpu_id,
tp_rank=tp_rank,
server_args=server_args,
nccl_port=nccl_port,
pp_rank=0, # FIXME
dp_rank=dp_rank,
nccl_port=nccl_port,
is_draft_worker=True,
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
......