Unverified Commit ac201a0e authored by yzds's avatar yzds Committed by GitHub
Browse files

[Feature] Support Decode Context Parallel (DCP) for MLA (#23734)


Signed-off-by: default avatarhongchao <hongchao@msh.team>
Signed-off-by: default avataryoukaichao <youkaichao@gmail.com>
Co-authored-by: default avatarhongchao <hongchao@msh.team>
Co-authored-by: default avataryoukaichao <youkaichao@gmail.com>
parent 3c529fc9
......@@ -100,6 +100,15 @@ class Scheduler(SchedulerInterface):
self.block_size = self.cache_config.block_size
self.dcp_world_size = \
vllm_config.parallel_config.decode_context_parallel_size
# Note(hc): The scheduler’s block_size must be multiplied
# by dcp_world_size, since block hashes are computed on the
# original full token sequence at a granularity of
# original_block_size × dcp_world_size.
if self.dcp_world_size > 1:
self.block_size *= self.dcp_world_size
# req_id -> Request
self.requests: dict[str, Request] = {}
# Scheduling policy
......@@ -161,6 +170,7 @@ class Scheduler(SchedulerInterface):
use_eagle=self.use_eagle,
log_stats=self.log_stats,
enable_kv_cache_events=self.enable_kv_cache_events,
dcp_world_size=self.dcp_world_size,
)
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
......
......@@ -25,6 +25,7 @@ class SingleTypeKVCacheManager(ABC):
kv_cache_spec: KVCacheSpec,
block_pool: BlockPool,
kv_cache_group_id: int,
dcp_world_size: int = 1,
) -> None:
"""
Initializes the SingleTypeKVCacheManager.
......@@ -33,8 +34,10 @@ class SingleTypeKVCacheManager(ABC):
block_pool: The block pool.
kv_cache_group_id: The id of the kv cache group of this manager.
"""
self.block_size = kv_cache_spec.block_size
self.dcp_world_size = dcp_world_size
if self.dcp_world_size > 1:
self.block_size *= dcp_world_size
self.kv_cache_spec = kv_cache_spec
self.block_pool = block_pool
......@@ -196,6 +199,7 @@ class SingleTypeKVCacheManager(ABC):
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
dcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
"""
Get the longest cache hit prefix of the blocks that is not longer than
......@@ -253,6 +257,7 @@ class FullAttentionManager(SingleTypeKVCacheManager):
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
dcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(
kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec)
......@@ -260,7 +265,10 @@ class FullAttentionManager(SingleTypeKVCacheManager):
"and chunked local attention groups"
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
[] for _ in range(len(kv_cache_group_ids)))
max_num_blocks = max_length // kv_cache_spec.block_size
block_size = kv_cache_spec.block_size
if dcp_world_size > 1:
block_size *= dcp_world_size
max_num_blocks = max_length // block_size
for block_hash in itertools.islice(block_hashes, max_num_blocks):
# block_hashes is a chain of block hashes. If a block hash is not
# in the cached_block_hash_to_id, the following block hashes are
......@@ -310,9 +318,11 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
dcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(kv_cache_spec, SlidingWindowSpec), (
"SlidingWindowManager can only be used for sliding window groups")
assert dcp_world_size == 1, "DCP not support sliding window attn now."
# The number of contiguous blocks needed for prefix cache hit.
# -1 since the input token itself is also included in the window
......@@ -408,6 +418,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
dcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
"""
For chunked local attention, we need to find the longest cache hit
......@@ -445,6 +456,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
"chunked local attention groups")
assert use_eagle is False, ("Hybrid KV cache is not supported for " +
"eagle + chunked local attention.")
assert dcp_world_size == 1, "DCP not support chunked local attn now."
max_num_blocks = max_length // kv_cache_spec.block_size
if max_length > 0:
local_attention_start_idx = (max_length //
......@@ -525,10 +537,12 @@ class MambaManager(SingleTypeKVCacheManager):
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
dcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(
kv_cache_spec,
MambaSpec), ("MambaManager can only be used for mamba groups")
assert dcp_world_size == 1, "DCP not support mamba now."
# Prefix caching is not supported for mamba now. Always return empty
# list.
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
......@@ -583,6 +597,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
dcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(kv_cache_spec, CrossAttentionSpec), (
"CrossAttentionManager can only be used for cross-attention groups"
......
......@@ -86,6 +86,12 @@ class FullAttentionSpec(AttentionSpec):
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
max_model_len = vllm_config.model_config.max_model_len
dcp_world_size = \
vllm_config.parallel_config.decode_context_parallel_size
# Note(hc): each dcp rank only need save
# (max_model_len//dcp_world_size) tokens locally.
if dcp_world_size > 1:
max_model_len = cdiv(max_model_len, dcp_world_size)
return cdiv(max_model_len, self.block_size) * self.page_size_bytes
@classmethod
......@@ -162,6 +168,8 @@ class SlidingWindowSpec(AttentionSpec):
assert not self.use_mla, "MLA is not supported for sliding window"
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
assert vllm_config.parallel_config.decode_context_parallel_size == 1, \
"DCP not support sliding window."
max_model_len = vllm_config.model_config.max_model_len
max_num_batched_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens)
......
......@@ -4,6 +4,7 @@
import numpy as np
import torch
from vllm.distributed import get_dcp_group
from vllm.logger import init_logger
from vllm.utils import cdiv
......@@ -50,6 +51,13 @@ class BlockTable:
self.slot_mapping = torch.zeros(self.max_num_batched_tokens,
dtype=torch.int64,
device=self.device)
try:
self.dcp_world_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
except AssertionError:
# DCP might not be initialized in testing
self.dcp_world_size = 1
self.dcp_rank = 0
def append_row(
self,
......@@ -89,13 +97,36 @@ class BlockTable:
# NOTE(woosuk): We can't simply use `token_indices // block_size`
# here because M (max_model_len) is not necessarily divisible by
# block_size.
block_table_indices = (req_indices * self.max_num_blocks_per_req +
positions // self.block_size)
block_numbers = self.block_table_np.ravel()[block_table_indices]
block_offsets = positions % self.block_size
np.add(block_numbers * self.block_size,
block_offsets,
out=self.slot_mapping_np[:req_indices.shape[0]])
if self.dcp_world_size > 1:
# Note(hc): The DCP implement store kvcache with a interleave
# style, the kvcache for the token whose token_idx is i is
# always stored on the GPU whose dcp_rank equals i % cp_world_size:
# Use a "virtual block" which equals to world_size * block_size
# for block_table_indices calculation.
virtual_block_size = self.block_size * self.dcp_world_size
block_table_indices = (req_indices * self.max_num_blocks_per_req +
positions // virtual_block_size)
block_numbers = self.block_table_np.ravel()[block_table_indices]
# Use virtual_block_size for mask calculation, which marks local
# tokens.
virtual_block_offsets = positions % virtual_block_size
mask = virtual_block_offsets % self.dcp_world_size == self.dcp_rank
# Calcuate local block_offsets
block_offsets = virtual_block_offsets // self.dcp_world_size
# Calcuate slot_mapping
slot_mapping = block_numbers * self.block_size + block_offsets
# Write final slots, use -1 for not-local
self.slot_mapping_np[:req_indices.shape[0]] = np.where(
mask, slot_mapping, -1)
else:
block_table_indices = (req_indices * self.max_num_blocks_per_req +
positions // self.block_size)
block_numbers = self.block_table_np.ravel()[block_table_indices]
block_offsets = positions % self.block_size
np.add(block_numbers * self.block_size,
block_offsets,
out=self.slot_mapping_np[:req_indices.shape[0]])
def commit_block_table(self, num_reqs: int) -> None:
self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
......@@ -128,9 +159,19 @@ class MultiGroupBlockTable:
def __init__(self, max_num_reqs: int, max_model_len: int,
max_num_batched_tokens: int, pin_memory: bool,
device: torch.device, block_sizes: list[int]) -> None:
# Note(hc): each dcp rank only store
# (max_model_len//dcp_world_size) tokens in kvcache,
# so the block_size which used for calc max_num_blocks_per_req
# must be multiplied by dcp_world_size.
try:
dcp_world_size = get_dcp_group().world_size
except AssertionError:
# DCP might not be initialized in testing
dcp_world_size = 1
self.block_tables = [
BlockTable(block_size, max_num_reqs, cdiv(max_model_len,
block_size),
BlockTable(block_size, max_num_reqs,
cdiv(max_model_len, block_size * dcp_world_size),
max_num_batched_tokens, pin_memory, device)
for block_size in block_sizes
]
......
......@@ -56,6 +56,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes, LazyLoader, cdiv, check_use_alibi,
get_dtype_size, is_pin_memory_available, round_up,
supports_dynamo)
from vllm.v1.attention.backends.mla.flashmla import FlashMLABackend
from vllm.v1.attention.backends.utils import (
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
create_fast_prefill_custom_backend,
......@@ -187,6 +188,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
model_config.is_multimodal_raw_input_only_model)
self.max_model_len = model_config.max_model_len
self.dcp_world_size = self.parallel_config.decode_context_parallel_size
self.max_num_tokens = scheduler_config.max_num_batched_tokens
self.max_num_reqs = scheduler_config.max_num_seqs
......@@ -428,6 +430,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return
if self.reorder_batch_threshold is not None:
if self.dcp_world_size > 1:
assert self.reorder_batch_threshold == 1, \
"DCP not support reorder_batch_threshold > 1 now."
reorder_batch_to_split_decodes_and_prefills(
self.input_batch,
scheduler_output,
......@@ -3305,6 +3310,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
get_kv_transfer_group().set_host_xfer_buffer_ops(
copy_kv_blocks)
if self.dcp_world_size > 1:
assert self.attn_groups[0][0].backend is FlashMLABackend, (
"DCP only support flashmla now."
"For a mla backend want to enable DCP, it is mandatory that the"
"corresponding decode attn kernel return the softmax lse.")
def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
"""
Add encoder-only layers to the KV cache config.
......
......@@ -616,7 +616,9 @@ def init_worker_distributed_environment(
init_distributed_environment(parallel_config.world_size, rank,
distributed_init_method, local_rank, backend)
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size,
parallel_config.decode_context_parallel_size)
ensure_kv_transfer_initialized(vllm_config)
......@@ -539,8 +539,10 @@ def init_worker_distributed_environment(
init_distributed_environment(parallel_config.world_size, rank,
distributed_init_method, local_rank,
current_platform.dist_backend)
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size,
parallel_config.decode_context_parallel_size)
ensure_kv_transfer_initialized(vllm_config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment