Unverified commit ac201a0e, authored by yzds and committed by GitHub
Browse files

[Feature] Support Decode Context Parallel (DCP) for MLA (#23734)


Signed-off-by: hongchao <hongchao@msh.team>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: hongchao <hongchao@msh.team>
Co-authored-by: youkaichao <youkaichao@gmail.com>
parent 3c529fc9
...@@ -100,6 +100,15 @@ class Scheduler(SchedulerInterface): ...@@ -100,6 +100,15 @@ class Scheduler(SchedulerInterface):
self.block_size = self.cache_config.block_size self.block_size = self.cache_config.block_size
self.dcp_world_size = \
vllm_config.parallel_config.decode_context_parallel_size
# Note(hc): The scheduler’s block_size must be multiplied
# by dcp_world_size, since block hashes are computed on the
# original full token sequence at a granularity of
# original_block_size × dcp_world_size.
if self.dcp_world_size > 1:
self.block_size *= self.dcp_world_size
# req_id -> Request # req_id -> Request
self.requests: dict[str, Request] = {} self.requests: dict[str, Request] = {}
# Scheduling policy # Scheduling policy
...@@ -161,6 +170,7 @@ class Scheduler(SchedulerInterface): ...@@ -161,6 +170,7 @@ class Scheduler(SchedulerInterface):
use_eagle=self.use_eagle, use_eagle=self.use_eagle,
log_stats=self.log_stats, log_stats=self.log_stats,
enable_kv_cache_events=self.enable_kv_cache_events, enable_kv_cache_events=self.enable_kv_cache_events,
dcp_world_size=self.dcp_world_size,
) )
self.use_pp = self.parallel_config.pipeline_parallel_size > 1 self.use_pp = self.parallel_config.pipeline_parallel_size > 1
......
...@@ -25,6 +25,7 @@ class SingleTypeKVCacheManager(ABC): ...@@ -25,6 +25,7 @@ class SingleTypeKVCacheManager(ABC):
kv_cache_spec: KVCacheSpec, kv_cache_spec: KVCacheSpec,
block_pool: BlockPool, block_pool: BlockPool,
kv_cache_group_id: int, kv_cache_group_id: int,
dcp_world_size: int = 1,
) -> None: ) -> None:
""" """
Initializes the SingleTypeKVCacheManager. Initializes the SingleTypeKVCacheManager.
...@@ -33,8 +34,10 @@ class SingleTypeKVCacheManager(ABC): ...@@ -33,8 +34,10 @@ class SingleTypeKVCacheManager(ABC):
block_pool: The block pool. block_pool: The block pool.
kv_cache_group_id: The id of the kv cache group of this manager. kv_cache_group_id: The id of the kv cache group of this manager.
""" """
self.block_size = kv_cache_spec.block_size self.block_size = kv_cache_spec.block_size
self.dcp_world_size = dcp_world_size
if self.dcp_world_size > 1:
self.block_size *= dcp_world_size
self.kv_cache_spec = kv_cache_spec self.kv_cache_spec = kv_cache_spec
self.block_pool = block_pool self.block_pool = block_pool
...@@ -196,6 +199,7 @@ class SingleTypeKVCacheManager(ABC): ...@@ -196,6 +199,7 @@ class SingleTypeKVCacheManager(ABC):
block_pool: BlockPool, block_pool: BlockPool,
kv_cache_spec: KVCacheSpec, kv_cache_spec: KVCacheSpec,
use_eagle: bool, use_eagle: bool,
dcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]: ) -> tuple[list[KVCacheBlock], ...]:
""" """
Get the longest cache hit prefix of the blocks that is not longer than Get the longest cache hit prefix of the blocks that is not longer than
...@@ -253,6 +257,7 @@ class FullAttentionManager(SingleTypeKVCacheManager): ...@@ -253,6 +257,7 @@ class FullAttentionManager(SingleTypeKVCacheManager):
block_pool: BlockPool, block_pool: BlockPool,
kv_cache_spec: KVCacheSpec, kv_cache_spec: KVCacheSpec,
use_eagle: bool, use_eagle: bool,
dcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]: ) -> tuple[list[KVCacheBlock], ...]:
assert isinstance( assert isinstance(
kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec) kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec)
...@@ -260,7 +265,10 @@ class FullAttentionManager(SingleTypeKVCacheManager): ...@@ -260,7 +265,10 @@ class FullAttentionManager(SingleTypeKVCacheManager):
"and chunked local attention groups" "and chunked local attention groups"
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple( computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
[] for _ in range(len(kv_cache_group_ids))) [] for _ in range(len(kv_cache_group_ids)))
max_num_blocks = max_length // kv_cache_spec.block_size block_size = kv_cache_spec.block_size
if dcp_world_size > 1:
block_size *= dcp_world_size
max_num_blocks = max_length // block_size
for block_hash in itertools.islice(block_hashes, max_num_blocks): for block_hash in itertools.islice(block_hashes, max_num_blocks):
# block_hashes is a chain of block hashes. If a block hash is not # block_hashes is a chain of block hashes. If a block hash is not
# in the cached_block_hash_to_id, the following block hashes are # in the cached_block_hash_to_id, the following block hashes are
...@@ -310,9 +318,11 @@ class SlidingWindowManager(SingleTypeKVCacheManager): ...@@ -310,9 +318,11 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
block_pool: BlockPool, block_pool: BlockPool,
kv_cache_spec: KVCacheSpec, kv_cache_spec: KVCacheSpec,
use_eagle: bool, use_eagle: bool,
dcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]: ) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(kv_cache_spec, SlidingWindowSpec), ( assert isinstance(kv_cache_spec, SlidingWindowSpec), (
"SlidingWindowManager can only be used for sliding window groups") "SlidingWindowManager can only be used for sliding window groups")
assert dcp_world_size == 1, "DCP not support sliding window attn now."
# The number of contiguous blocks needed for prefix cache hit. # The number of contiguous blocks needed for prefix cache hit.
# -1 since the input token itself is also included in the window # -1 since the input token itself is also included in the window
...@@ -408,6 +418,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager): ...@@ -408,6 +418,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
block_pool: BlockPool, block_pool: BlockPool,
kv_cache_spec: KVCacheSpec, kv_cache_spec: KVCacheSpec,
use_eagle: bool, use_eagle: bool,
dcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]: ) -> tuple[list[KVCacheBlock], ...]:
""" """
For chunked local attention, we need to find the longest cache hit For chunked local attention, we need to find the longest cache hit
...@@ -445,6 +456,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager): ...@@ -445,6 +456,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
"chunked local attention groups") "chunked local attention groups")
assert use_eagle is False, ("Hybrid KV cache is not supported for " + assert use_eagle is False, ("Hybrid KV cache is not supported for " +
"eagle + chunked local attention.") "eagle + chunked local attention.")
assert dcp_world_size == 1, "DCP not support chunked local attn now."
max_num_blocks = max_length // kv_cache_spec.block_size max_num_blocks = max_length // kv_cache_spec.block_size
if max_length > 0: if max_length > 0:
local_attention_start_idx = (max_length // local_attention_start_idx = (max_length //
...@@ -525,10 +537,12 @@ class MambaManager(SingleTypeKVCacheManager): ...@@ -525,10 +537,12 @@ class MambaManager(SingleTypeKVCacheManager):
block_pool: BlockPool, block_pool: BlockPool,
kv_cache_spec: KVCacheSpec, kv_cache_spec: KVCacheSpec,
use_eagle: bool, use_eagle: bool,
dcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]: ) -> tuple[list[KVCacheBlock], ...]:
assert isinstance( assert isinstance(
kv_cache_spec, kv_cache_spec,
MambaSpec), ("MambaManager can only be used for mamba groups") MambaSpec), ("MambaManager can only be used for mamba groups")
assert dcp_world_size == 1, "DCP not support mamba now."
# Prefix caching is not supported for mamba now. Always return empty # Prefix caching is not supported for mamba now. Always return empty
# list. # list.
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple( computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
...@@ -583,6 +597,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager): ...@@ -583,6 +597,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
block_pool: BlockPool, block_pool: BlockPool,
kv_cache_spec: KVCacheSpec, kv_cache_spec: KVCacheSpec,
use_eagle: bool, use_eagle: bool,
dcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]: ) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(kv_cache_spec, CrossAttentionSpec), ( assert isinstance(kv_cache_spec, CrossAttentionSpec), (
"CrossAttentionManager can only be used for cross-attention groups" "CrossAttentionManager can only be used for cross-attention groups"
......
...@@ -86,6 +86,12 @@ class FullAttentionSpec(AttentionSpec): ...@@ -86,6 +86,12 @@ class FullAttentionSpec(AttentionSpec):
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
max_model_len = vllm_config.model_config.max_model_len max_model_len = vllm_config.model_config.max_model_len
dcp_world_size = \
vllm_config.parallel_config.decode_context_parallel_size
# Note(hc): each DCP rank only needs to store
# cdiv(max_model_len, dcp_world_size) tokens locally.
if dcp_world_size > 1:
max_model_len = cdiv(max_model_len, dcp_world_size)
return cdiv(max_model_len, self.block_size) * self.page_size_bytes return cdiv(max_model_len, self.block_size) * self.page_size_bytes
@classmethod @classmethod
...@@ -162,6 +168,8 @@ class SlidingWindowSpec(AttentionSpec): ...@@ -162,6 +168,8 @@ class SlidingWindowSpec(AttentionSpec):
assert not self.use_mla, "MLA is not supported for sliding window" assert not self.use_mla, "MLA is not supported for sliding window"
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
assert vllm_config.parallel_config.decode_context_parallel_size == 1, \
"DCP not support sliding window."
max_model_len = vllm_config.model_config.max_model_len max_model_len = vllm_config.model_config.max_model_len
max_num_batched_tokens = ( max_num_batched_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens) vllm_config.scheduler_config.max_num_batched_tokens)
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import numpy as np import numpy as np
import torch import torch
from vllm.distributed import get_dcp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import cdiv from vllm.utils import cdiv
...@@ -50,6 +51,13 @@ class BlockTable: ...@@ -50,6 +51,13 @@ class BlockTable:
self.slot_mapping = torch.zeros(self.max_num_batched_tokens, self.slot_mapping = torch.zeros(self.max_num_batched_tokens,
dtype=torch.int64, dtype=torch.int64,
device=self.device) device=self.device)
try:
self.dcp_world_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
except AssertionError:
# DCP might not be initialized in testing
self.dcp_world_size = 1
self.dcp_rank = 0
def append_row( def append_row(
self, self,
...@@ -89,13 +97,36 @@ class BlockTable: ...@@ -89,13 +97,36 @@ class BlockTable:
# NOTE(woosuk): We can't simply use `token_indices // block_size` # NOTE(woosuk): We can't simply use `token_indices // block_size`
# here because M (max_model_len) is not necessarily divisible by # here because M (max_model_len) is not necessarily divisible by
# block_size. # block_size.
block_table_indices = (req_indices * self.max_num_blocks_per_req + if self.dcp_world_size > 1:
positions // self.block_size) # Note(hc): The DCP implement store kvcache with a interleave
block_numbers = self.block_table_np.ravel()[block_table_indices] # style, the kvcache for the token whose token_idx is i is
block_offsets = positions % self.block_size # always stored on the GPU whose dcp_rank equals i % cp_world_size:
np.add(block_numbers * self.block_size,
block_offsets, # Use a "virtual block" which equals to world_size * block_size
out=self.slot_mapping_np[:req_indices.shape[0]]) # for block_table_indices calculation.
virtual_block_size = self.block_size * self.dcp_world_size
block_table_indices = (req_indices * self.max_num_blocks_per_req +
positions // virtual_block_size)
block_numbers = self.block_table_np.ravel()[block_table_indices]
# Use virtual_block_size for mask calculation, which marks local
# tokens.
virtual_block_offsets = positions % virtual_block_size
mask = virtual_block_offsets % self.dcp_world_size == self.dcp_rank
# Calculate local block_offsets
block_offsets = virtual_block_offsets // self.dcp_world_size
# Calculate slot_mapping
slot_mapping = block_numbers * self.block_size + block_offsets
# Write final slots, use -1 for not-local
self.slot_mapping_np[:req_indices.shape[0]] = np.where(
mask, slot_mapping, -1)
else:
block_table_indices = (req_indices * self.max_num_blocks_per_req +
positions // self.block_size)
block_numbers = self.block_table_np.ravel()[block_table_indices]
block_offsets = positions % self.block_size
np.add(block_numbers * self.block_size,
block_offsets,
out=self.slot_mapping_np[:req_indices.shape[0]])
def commit_block_table(self, num_reqs: int) -> None: def commit_block_table(self, num_reqs: int) -> None:
self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs], self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
...@@ -128,9 +159,19 @@ class MultiGroupBlockTable: ...@@ -128,9 +159,19 @@ class MultiGroupBlockTable:
def __init__(self, max_num_reqs: int, max_model_len: int, def __init__(self, max_num_reqs: int, max_model_len: int,
max_num_batched_tokens: int, pin_memory: bool, max_num_batched_tokens: int, pin_memory: bool,
device: torch.device, block_sizes: list[int]) -> None: device: torch.device, block_sizes: list[int]) -> None:
# Note(hc): each DCP rank stores only about
# (max_model_len // dcp_world_size) tokens in its KV cache,
# so the block size used to calculate max_num_blocks_per_req
# must be multiplied by dcp_world_size.
try:
dcp_world_size = get_dcp_group().world_size
except AssertionError:
# DCP might not be initialized in testing
dcp_world_size = 1
self.block_tables = [ self.block_tables = [
BlockTable(block_size, max_num_reqs, cdiv(max_model_len, BlockTable(block_size, max_num_reqs,
block_size), cdiv(max_model_len, block_size * dcp_world_size),
max_num_batched_tokens, pin_memory, device) max_num_batched_tokens, pin_memory, device)
for block_size in block_sizes for block_size in block_sizes
] ]
......
...@@ -56,6 +56,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, ...@@ -56,6 +56,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes, LazyLoader, cdiv, check_use_alibi, GiB_bytes, LazyLoader, cdiv, check_use_alibi,
get_dtype_size, is_pin_memory_available, round_up, get_dtype_size, is_pin_memory_available, round_up,
supports_dynamo) supports_dynamo)
from vllm.v1.attention.backends.mla.flashmla import FlashMLABackend
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
create_fast_prefill_custom_backend, create_fast_prefill_custom_backend,
...@@ -187,6 +188,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -187,6 +188,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
model_config.is_multimodal_raw_input_only_model) model_config.is_multimodal_raw_input_only_model)
self.max_model_len = model_config.max_model_len self.max_model_len = model_config.max_model_len
self.dcp_world_size = self.parallel_config.decode_context_parallel_size
self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_tokens = scheduler_config.max_num_batched_tokens
self.max_num_reqs = scheduler_config.max_num_seqs self.max_num_reqs = scheduler_config.max_num_seqs
...@@ -428,6 +430,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -428,6 +430,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return return
if self.reorder_batch_threshold is not None: if self.reorder_batch_threshold is not None:
if self.dcp_world_size > 1:
assert self.reorder_batch_threshold == 1, \
"DCP not support reorder_batch_threshold > 1 now."
reorder_batch_to_split_decodes_and_prefills( reorder_batch_to_split_decodes_and_prefills(
self.input_batch, self.input_batch,
scheduler_output, scheduler_output,
...@@ -3305,6 +3310,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -3305,6 +3310,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
get_kv_transfer_group().set_host_xfer_buffer_ops( get_kv_transfer_group().set_host_xfer_buffer_ops(
copy_kv_blocks) copy_kv_blocks)
if self.dcp_world_size > 1:
assert self.attn_groups[0][0].backend is FlashMLABackend, (
"DCP only support flashmla now."
"For a mla backend want to enable DCP, it is mandatory that the"
"corresponding decode attn kernel return the softmax lse.")
def may_add_encoder_only_layers_to_kv_cache_config(self) -> None: def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
""" """
Add encoder-only layers to the KV cache config. Add encoder-only layers to the KV cache config.
......
...@@ -616,7 +616,9 @@ def init_worker_distributed_environment( ...@@ -616,7 +616,9 @@ def init_worker_distributed_environment(
init_distributed_environment(parallel_config.world_size, rank, init_distributed_environment(parallel_config.world_size, rank,
distributed_init_method, local_rank, backend) distributed_init_method, local_rank, backend)
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, ensure_model_parallel_initialized(
parallel_config.pipeline_parallel_size) parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size,
parallel_config.decode_context_parallel_size)
ensure_kv_transfer_initialized(vllm_config) ensure_kv_transfer_initialized(vllm_config)
...@@ -539,8 +539,10 @@ def init_worker_distributed_environment( ...@@ -539,8 +539,10 @@ def init_worker_distributed_environment(
init_distributed_environment(parallel_config.world_size, rank, init_distributed_environment(parallel_config.world_size, rank,
distributed_init_method, local_rank, distributed_init_method, local_rank,
current_platform.dist_backend) current_platform.dist_backend)
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, ensure_model_parallel_initialized(
parallel_config.pipeline_parallel_size) parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size,
parallel_config.decode_context_parallel_size)
ensure_kv_transfer_initialized(vllm_config) ensure_kv_transfer_initialized(vllm_config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment