Unverified Commit 2fd893b4 authored by Qiu, committed by GitHub
Browse files

[Feature] Prefill Context Parallel (PCP) basic support (#28718)


Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
Signed-off-by: LookAround <lixushi@huawei.com>
Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com>
Signed-off-by: zhenwenqi2024 <zhenwenqi_2022@qq.com>
Co-authored-by: FENP <yuanyongjie.yyj@antgroup.com>
Co-authored-by: LookAround <lixushi@huawei.com>
Co-authored-by: Jingchun Gao <gaojingchun1@huawei.com>
Co-authored-by: zhenwenqi2024 <zhenwenqi_2022@qq.com>
Co-authored-by: Jingchun Gao <63247409+gjc0824@users.noreply.github.com>
parent 02f5903b
...@@ -128,6 +128,7 @@ class EngineCore: ...@@ -128,6 +128,7 @@ class EngineCore:
scheduler_block_size = ( scheduler_block_size = (
vllm_config.cache_config.block_size vllm_config.cache_config.block_size
* vllm_config.parallel_config.decode_context_parallel_size * vllm_config.parallel_config.decode_context_parallel_size
* vllm_config.parallel_config.prefill_context_parallel_size
) )
self.scheduler: SchedulerInterface = Scheduler( self.scheduler: SchedulerInterface = Scheduler(
......
...@@ -35,6 +35,7 @@ from vllm.distributed.parallel_state import ( ...@@ -35,6 +35,7 @@ from vllm.distributed.parallel_state import (
get_dp_group, get_dp_group,
get_ep_group, get_ep_group,
get_inner_dp_world_group, get_inner_dp_world_group,
get_pcp_group,
get_pp_group, get_pp_group,
get_tp_group, get_tp_group,
) )
...@@ -110,12 +111,14 @@ class MultiprocExecutor(Executor): ...@@ -110,12 +111,14 @@ class MultiprocExecutor(Executor):
f"({self.parallel_config.nnodes_within_dp}). " f"({self.parallel_config.nnodes_within_dp}). "
) )
self.local_world_size = self.parallel_config.local_world_size self.local_world_size = self.parallel_config.local_world_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size tp_size = self.parallel_config.tensor_parallel_size
pp_parallel_size = self.parallel_config.pipeline_parallel_size pp_size = self.parallel_config.pipeline_parallel_size
assert self.world_size == tensor_parallel_size * pp_parallel_size, ( pcp_size = self.parallel_config.prefill_context_parallel_size
assert self.world_size == tp_size * pp_size * pcp_size, (
f"world_size ({self.world_size}) must be equal to the " f"world_size ({self.world_size}) must be equal to the "
f"tensor_parallel_size ({tensor_parallel_size}) x pipeline" f"tensor_parallel_size ({tp_size}) x pipeline"
f"_parallel_size ({pp_parallel_size}). " f"_parallel_size ({pp_size}) x prefill_context"
f"_parallel_size ({pcp_size}). "
) )
# Set multiprocessing envs # Set multiprocessing envs
...@@ -424,7 +427,11 @@ class MultiprocExecutor(Executor): ...@@ -424,7 +427,11 @@ class MultiprocExecutor(Executor):
# 16-23, PP rank 2 # 16-23, PP rank 2
# 24-31, PP rank 3 # 24-31, PP rank 3
# so world_size - tp_size = 32 - 8 = 24 should be PP rank = -1 (i.e. 3) # so world_size - tp_size = 32 - 8 = 24 should be PP rank = -1 (i.e. 3)
return self.world_size - self.parallel_config.tensor_parallel_size return (
self.world_size
- self.parallel_config.tensor_parallel_size
* self.parallel_config.prefill_context_parallel_size
)
@dataclass @dataclass
...@@ -828,6 +835,8 @@ class WorkerProc: ...@@ -828,6 +835,8 @@ class WorkerProc:
dp_rank = get_dp_group().rank_in_group dp_rank = get_dp_group().rank_in_group
pp_size = get_pp_group().world_size pp_size = get_pp_group().world_size
pp_rank = get_pp_group().rank_in_group pp_rank = get_pp_group().rank_in_group
pcp_size = get_pcp_group().world_size
pcp_rank = get_pcp_group().rank_in_group
tp_size = get_tp_group().world_size tp_size = get_tp_group().world_size
tp_rank = get_tp_group().rank_in_group tp_rank = get_tp_group().rank_in_group
dcp_size = get_dcp_group().world_size dcp_size = get_dcp_group().world_size
...@@ -837,6 +846,8 @@ class WorkerProc: ...@@ -837,6 +846,8 @@ class WorkerProc:
process_name += f"_DP{dp_rank}" process_name += f"_DP{dp_rank}"
if pp_size > 1: if pp_size > 1:
process_name += f"_PP{pp_rank}" process_name += f"_PP{pp_rank}"
if pcp_size > 1:
process_name += f"_PCP{pcp_rank}"
if tp_size > 1: if tp_size > 1:
process_name += f"_TP{tp_rank}" process_name += f"_TP{tp_rank}"
if dcp_size > 1: if dcp_size > 1:
......
...@@ -95,10 +95,11 @@ class FullAttentionSpec(AttentionSpec): ...@@ -95,10 +95,11 @@ class FullAttentionSpec(AttentionSpec):
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
max_model_len = vllm_config.model_config.max_model_len max_model_len = vllm_config.model_config.max_model_len
dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
pcp_world_size = vllm_config.parallel_config.prefill_context_parallel_size
# Note(hc): each dcp rank only need save # Note(hc): each dcp rank only need save
# (max_model_len//dcp_world_size) tokens locally. # (max_model_len//dcp_world_size) tokens locally.
if dcp_world_size > 1: if dcp_world_size * pcp_world_size > 1:
max_model_len = cdiv(max_model_len, dcp_world_size) max_model_len = cdiv(max_model_len, dcp_world_size * pcp_world_size)
return cdiv(max_model_len, self.block_size) * self.page_size_bytes return cdiv(max_model_len, self.block_size) * self.page_size_bytes
@classmethod @classmethod
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import numpy as np import numpy as np
import torch import torch
from vllm.distributed import get_dcp_group from vllm.distributed import get_dcp_group, get_pcp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.utils import CpuGpuBuffer from vllm.v1.utils import CpuGpuBuffer
...@@ -22,7 +22,7 @@ class BlockTable: ...@@ -22,7 +22,7 @@ class BlockTable:
pin_memory: bool, pin_memory: bool,
device: torch.device, device: torch.device,
kernel_block_size: int, kernel_block_size: int,
dcp_kv_cache_interleave_size: int, cp_kv_cache_interleave_size: int,
): ):
""" """
Args: Args:
...@@ -80,6 +80,13 @@ class BlockTable: ...@@ -80,6 +80,13 @@ class BlockTable:
else: else:
self._kernel_block_arange = None self._kernel_block_arange = None
try:
self.pcp_world_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group().rank_in_group
except AssertionError:
# PCP might not be initialized in testing
self.pcp_world_size = 1
self.pcp_rank = 0
try: try:
self.dcp_world_size = get_dcp_group().world_size self.dcp_world_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group self.dcp_rank = get_dcp_group().rank_in_group
...@@ -87,7 +94,7 @@ class BlockTable: ...@@ -87,7 +94,7 @@ class BlockTable:
# DCP might not be initialized in testing # DCP might not be initialized in testing
self.dcp_world_size = 1 self.dcp_world_size = 1
self.dcp_rank = 0 self.dcp_rank = 0
self.dcp_kv_cache_interleave_size = dcp_kv_cache_interleave_size self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size
def append_row( def append_row(
self, self,
...@@ -131,14 +138,16 @@ class BlockTable: ...@@ -131,14 +138,16 @@ class BlockTable:
# NOTE(woosuk): We can't simply use `token_indices // block_size` # NOTE(woosuk): We can't simply use `token_indices // block_size`
# here because M (max_model_len) is not necessarily divisible by # here because M (max_model_len) is not necessarily divisible by
# block_size. # block_size.
if self.dcp_world_size > 1: total_cp_world_size = self.pcp_world_size * self.dcp_world_size
total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank
if total_cp_world_size > 1:
# Note(hc): The DCP implement store kvcache with an interleave # Note(hc): The DCP implement store kvcache with an interleave
# style, the kvcache for the token whose token_idx is i is # style, the kvcache for the token whose token_idx is i is
# always stored on the GPU whose dcp_rank equals i % cp_world_size: # always stored on the GPU whose dcp_rank equals i % cp_world_size:
# Use a "virtual block" which equals to world_size * block_size # Use a "virtual block" which equals to world_size * block_size
# for block_table_indices calculation. # for block_table_indices calculation.
virtual_block_size = self.block_size * self.dcp_world_size virtual_block_size = self.block_size * total_cp_world_size
block_table_indices = ( block_table_indices = (
req_indices * self.max_num_blocks_per_req req_indices * self.max_num_blocks_per_req
+ positions // virtual_block_size + positions // virtual_block_size
...@@ -150,16 +159,16 @@ class BlockTable: ...@@ -150,16 +159,16 @@ class BlockTable:
virtual_block_offsets = positions % virtual_block_size virtual_block_offsets = positions % virtual_block_size
mask = ( mask = (
virtual_block_offsets virtual_block_offsets
// self.dcp_kv_cache_interleave_size // self.cp_kv_cache_interleave_size
% self.dcp_world_size % total_cp_world_size
== self.dcp_rank == total_cp_rank
) )
# Calculate local block_offsets # Calculate local block_offsets
block_offsets = ( block_offsets = (
virtual_block_offsets virtual_block_offsets
// (self.dcp_world_size * self.dcp_kv_cache_interleave_size) // (total_cp_world_size * self.cp_kv_cache_interleave_size)
* self.dcp_kv_cache_interleave_size * self.cp_kv_cache_interleave_size
+ virtual_block_offsets % self.dcp_kv_cache_interleave_size + virtual_block_offsets % self.cp_kv_cache_interleave_size
) )
# Calculate slot_mapping # Calculate slot_mapping
slot_mapping = block_numbers * self.block_size + block_offsets slot_mapping = block_numbers * self.block_size + block_offsets
...@@ -253,7 +262,7 @@ class MultiGroupBlockTable: ...@@ -253,7 +262,7 @@ class MultiGroupBlockTable:
block_sizes: list[int], block_sizes: list[int],
kernel_block_sizes: list[int], kernel_block_sizes: list[int],
num_speculative_tokens: int = 0, num_speculative_tokens: int = 0,
dcp_kv_cache_interleave_size: int = 1, cp_kv_cache_interleave_size: int = 1,
) -> None: ) -> None:
# Note(hc): each dcp rank only store # Note(hc): each dcp rank only store
# (max_model_len//dcp_world_size) tokens in kvcache, # (max_model_len//dcp_world_size) tokens in kvcache,
...@@ -283,7 +292,7 @@ class MultiGroupBlockTable: ...@@ -283,7 +292,7 @@ class MultiGroupBlockTable:
pin_memory, pin_memory,
device, device,
kernel_block_size, kernel_block_size,
dcp_kv_cache_interleave_size, cp_kv_cache_interleave_size,
) )
for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes) for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes)
] ]
......
...@@ -87,7 +87,7 @@ class InputBatch: ...@@ -87,7 +87,7 @@ class InputBatch:
is_spec_decode: bool = False, is_spec_decode: bool = False,
is_pooling_model: bool = False, is_pooling_model: bool = False,
num_speculative_tokens: int = 0, num_speculative_tokens: int = 0,
dcp_kv_cache_interleave_size: int = 1, cp_kv_cache_interleave_size: int = 1,
): ):
self.is_pooling_model = is_pooling_model self.is_pooling_model = is_pooling_model
self.is_spec_decode = is_spec_decode self.is_spec_decode = is_spec_decode
...@@ -141,7 +141,7 @@ class InputBatch: ...@@ -141,7 +141,7 @@ class InputBatch:
block_sizes=block_sizes, block_sizes=block_sizes,
kernel_block_sizes=kernel_block_sizes, kernel_block_sizes=kernel_block_sizes,
num_speculative_tokens=num_speculative_tokens, num_speculative_tokens=num_speculative_tokens,
dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size, cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
) )
# Sampling-related. # Sampling-related.
......
...@@ -426,7 +426,7 @@ class GPUModelRunner( ...@@ -426,7 +426,7 @@ class GPUModelRunner(
# uses output token ids so we set this conservatively. # uses output token ids so we set this conservatively.
logitsprocs_need_output_token_ids=bool(custom_logitsprocs), logitsprocs_need_output_token_ids=bool(custom_logitsprocs),
is_pooling_model=self.is_pooling_model, is_pooling_model=self.is_pooling_model,
dcp_kv_cache_interleave_size=self.parallel_config.dcp_kv_cache_interleave_size, cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size,
) )
self.use_async_scheduling = self.scheduler_config.async_scheduling self.use_async_scheduling = self.scheduler_config.async_scheduling
...@@ -1436,7 +1436,7 @@ class GPUModelRunner( ...@@ -1436,7 +1436,7 @@ class GPUModelRunner(
self.seq_lens.cpu[:num_reqs], self.seq_lens.cpu[:num_reqs],
self.dcp_world_size, self.dcp_world_size,
self.dcp_rank, self.dcp_rank,
self.parallel_config.dcp_kv_cache_interleave_size, self.parallel_config.cp_kv_cache_interleave_size,
) )
self.dcp_local_seq_lens.copy_to_gpu(num_reqs) self.dcp_local_seq_lens.copy_to_gpu(num_reqs)
......
...@@ -26,6 +26,7 @@ from vllm.distributed.kv_transfer import ( ...@@ -26,6 +26,7 @@ from vllm.distributed.kv_transfer import (
has_kv_transfer_group, has_kv_transfer_group,
) )
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
get_pcp_group,
get_pp_group, get_pp_group,
get_tp_group, get_tp_group,
) )
...@@ -733,6 +734,7 @@ class Worker(WorkerBase): ...@@ -733,6 +734,7 @@ class Worker(WorkerBase):
module.global_num_experts = module.moe_config.num_experts module.global_num_experts = module.moe_config.num_experts
module.moe_parallel_config = FusedMoEParallelConfig.make( module.moe_parallel_config = FusedMoEParallelConfig.make(
tp_size_=get_tp_group().world_size, tp_size_=get_tp_group().world_size,
pcp_size_=get_pcp_group().world_size,
dp_size_=get_dp_group().world_size, dp_size_=get_dp_group().world_size,
vllm_parallel_config=parallel_config, vllm_parallel_config=parallel_config,
) )
...@@ -886,6 +888,7 @@ def init_worker_distributed_environment( ...@@ -886,6 +888,7 @@ def init_worker_distributed_environment(
ensure_model_parallel_initialized( ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size, parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size, parallel_config.pipeline_parallel_size,
parallel_config.prefill_context_parallel_size,
parallel_config.decode_context_parallel_size, parallel_config.decode_context_parallel_size,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment