Unverified Commit ab33d2a6 authored by Wentao Ye's avatar Wentao Ye Committed by GitHub
Browse files

[Feature] Decode Context Parallel support for GPU model runner v2 (#34179)


Signed-off-by: default avataryewentao256 <zhyanwentao@126.com>
parent be3af2d2
...@@ -12,6 +12,7 @@ from vllm.v1.attention.backend import ( ...@@ -12,6 +12,7 @@ from vllm.v1.attention.backend import (
AttentionMetadataBuilder, AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
) )
from vllm.v1.attention.backends.utils import get_dcp_local_seq_lens
from vllm.v1.kv_cache_interface import ( from vllm.v1.kv_cache_interface import (
AttentionSpec, AttentionSpec,
KVCacheConfig, KVCacheConfig,
...@@ -143,6 +144,28 @@ def build_slot_mappings_by_layer( ...@@ -143,6 +144,28 @@ def build_slot_mappings_by_layer(
return slot_mappings_by_layer return slot_mappings_by_layer
def prepare_dcp_local_seq_lens(
    dcp_local_seq_lens: torch.Tensor,
    seq_lens: torch.Tensor,
    num_reqs: int,
    dcp_size: int,
    dcp_rank: int,
    cp_kv_cache_interleave_size: int,
) -> None:
    """Populate the persistent DCP local seq_lens buffer (CUDA graph safe).

    The buffer is updated strictly in place (``copy_`` / ``zero_``), so the
    tensor object handed in by the caller is never replaced.

    Args:
        dcp_local_seq_lens: Destination buffer, written in place. Entries
            ``[:num_reqs]`` receive the values returned by
            ``get_dcp_local_seq_lens``; entries ``[num_reqs:]`` are zeroed.
        seq_lens: Source sequence lengths; only ``seq_lens[:num_reqs]`` is
            read.
        num_reqs: Number of leading entries to process.
        dcp_size: Decode-context-parallel world size. When ``<= 1`` the
            function returns immediately and leaves the buffer untouched.
        dcp_rank: Rank forwarded to ``get_dcp_local_seq_lens``.
        cp_kv_cache_interleave_size: Interleave size forwarded to
            ``get_dcp_local_seq_lens``.

    Returns:
        None. The result is delivered through ``dcp_local_seq_lens``.
    """
    # Without DCP there is nothing to compute; the buffer is left as-is.
    if dcp_size <= 1:
        return

    # NOTE(review): presumably the per-request KV length held by this rank;
    # confirm against get_dcp_local_seq_lens.
    local_seq_lens = get_dcp_local_seq_lens(
        seq_lens[:num_reqs],
        dcp_size=dcp_size,
        dcp_rank=dcp_rank,
        cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
    )

    # Write into the existing storage rather than rebinding the buffer.
    dcp_local_seq_lens[:num_reqs].copy_(local_seq_lens, non_blocking=True)
    # Clear the unused tail so values from an earlier, larger batch do not
    # linger past the active requests.
    dcp_local_seq_lens[num_reqs:].zero_()
def build_attn_metadata( def build_attn_metadata(
attn_metadata_builders: list[AttentionMetadataBuilder], attn_metadata_builders: list[AttentionMetadataBuilder],
num_reqs: int, num_reqs: int,
...@@ -155,9 +178,13 @@ def build_attn_metadata( ...@@ -155,9 +178,13 @@ def build_attn_metadata(
block_tables: Sequence[torch.Tensor], block_tables: Sequence[torch.Tensor],
slot_mappings: torch.Tensor, slot_mappings: torch.Tensor,
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
dcp_local_seq_lens: torch.Tensor | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
seq_lens = seq_lens[:num_reqs] seq_lens = seq_lens[:num_reqs]
if dcp_local_seq_lens is not None:
dcp_local_seq_lens = dcp_local_seq_lens[:num_reqs]
attn_metadata: dict[str, Any] = {} attn_metadata: dict[str, Any] = {}
kv_cache_groups = kv_cache_config.kv_cache_groups kv_cache_groups = kv_cache_config.kv_cache_groups
for i, kv_cache_spec in enumerate(kv_cache_groups): for i, kv_cache_spec in enumerate(kv_cache_groups):
...@@ -175,6 +202,7 @@ def build_attn_metadata( ...@@ -175,6 +202,7 @@ def build_attn_metadata(
block_table_tensor=block_table, block_table_tensor=block_table,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
causal=True, causal=True,
dcp_local_seq_lens=dcp_local_seq_lens,
) )
attn_metadata_builder = attn_metadata_builders[i] attn_metadata_builder = attn_metadata_builders[i]
......
...@@ -4,6 +4,7 @@ from collections.abc import Iterable ...@@ -4,6 +4,7 @@ from collections.abc import Iterable
import torch import torch
from vllm.distributed import get_dcp_group
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import PAD_SLOT_ID from vllm.v1.attention.backends.utils import PAD_SLOT_ID
...@@ -18,19 +19,36 @@ class BlockTables: ...@@ -18,19 +19,36 @@ class BlockTables:
max_num_batched_tokens: int, max_num_batched_tokens: int,
max_model_len: int, max_model_len: int,
device: torch.device, device: torch.device,
cp_kv_cache_interleave_size: int = 1,
): ):
self.block_sizes = block_sizes self.block_sizes = block_sizes
self.max_num_reqs = max_num_reqs self.max_num_reqs = max_num_reqs
self.max_num_batched_tokens = max_num_batched_tokens self.max_num_batched_tokens = max_num_batched_tokens
self.max_model_len = max_model_len self.max_model_len = max_model_len
self.device = device self.device = device
assert cp_kv_cache_interleave_size >= 1
self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size
try:
dcp = get_dcp_group()
self.dcp_world_size, self.dcp_rank = dcp.world_size, dcp.rank_in_group
except AssertionError:
self.dcp_world_size, self.dcp_rank = 1, 0
# TODO(wentao): PCP support
self.total_cp_world_size = self.dcp_world_size
self.total_cp_rank = self.dcp_rank
self.num_kv_cache_groups = len(self.block_sizes) self.num_kv_cache_groups = len(self.block_sizes)
# num_kv_cache_groups x [max_num_reqs, max_num_blocks] # num_kv_cache_groups x [max_num_reqs, max_num_blocks]
self.block_tables: list[StagedWriteTensor] = [] self.block_tables: list[StagedWriteTensor] = []
for i in range(self.num_kv_cache_groups): for i in range(self.num_kv_cache_groups):
block_size = self.block_sizes[i] block_size = self.block_sizes[i]
max_num_blocks = cdiv(self.max_model_len, block_size) # with DCP, a request's KV is sharded across
# ranks, so one physical block on this rank
# corresponds to `block_size * total_cp_world_size`
# tokens in the global (unsharded) sequence.
virtual_block_size = block_size * self.total_cp_world_size
max_num_blocks = cdiv(self.max_model_len, virtual_block_size)
block_table = StagedWriteTensor( block_table = StagedWriteTensor(
(self.max_num_reqs, max_num_blocks), (self.max_num_reqs, max_num_blocks),
dtype=torch.int32, dtype=torch.int32,
...@@ -131,6 +149,9 @@ class BlockTables: ...@@ -131,6 +149,9 @@ class BlockTables:
self.block_sizes_tensor, self.block_sizes_tensor,
self.slot_mappings, self.slot_mappings,
self.slot_mappings.stride(0), self.slot_mappings.stride(0),
TOTAL_CP_WORLD_SIZE=self.total_cp_world_size,
TOTAL_CP_RANK=self.total_cp_rank,
CP_KV_CACHE_INTERLEAVE_SIZE=self.cp_kv_cache_interleave_size,
PAD_ID=PAD_SLOT_ID, PAD_ID=PAD_SLOT_ID,
TRITON_BLOCK_SIZE=1024, # type: ignore TRITON_BLOCK_SIZE=1024, # type: ignore
) )
...@@ -183,6 +204,9 @@ def _compute_slot_mappings_kernel( ...@@ -183,6 +204,9 @@ def _compute_slot_mappings_kernel(
block_sizes, # [num_kv_cache_groups] block_sizes, # [num_kv_cache_groups]
slot_mappings_ptr, # [num_kv_cache_groups, max_num_tokens] slot_mappings_ptr, # [num_kv_cache_groups, max_num_tokens]
slot_mappings_stride, slot_mappings_stride,
TOTAL_CP_WORLD_SIZE: tl.constexpr,
TOTAL_CP_RANK: tl.constexpr,
CP_KV_CACHE_INTERLEAVE_SIZE: tl.constexpr,
PAD_ID: tl.constexpr, PAD_ID: tl.constexpr,
TRITON_BLOCK_SIZE: tl.constexpr, TRITON_BLOCK_SIZE: tl.constexpr,
): ):
...@@ -201,6 +225,7 @@ def _compute_slot_mappings_kernel( ...@@ -201,6 +225,7 @@ def _compute_slot_mappings_kernel(
block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32) block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
block_table_stride = tl.load(block_table_strides + group_id) block_table_stride = tl.load(block_table_strides + group_id)
block_size = tl.load(block_sizes + group_id) block_size = tl.load(block_sizes + group_id)
virtual_block_size = block_size * TOTAL_CP_WORLD_SIZE
req_state_idx = tl.load(idx_mapping + batch_idx) req_state_idx = tl.load(idx_mapping + batch_idx)
start_idx = tl.load(query_start_loc + batch_idx) start_idx = tl.load(query_start_loc + batch_idx)
...@@ -208,11 +233,26 @@ def _compute_slot_mappings_kernel( ...@@ -208,11 +233,26 @@ def _compute_slot_mappings_kernel(
for i in range(start_idx, end_idx, TRITON_BLOCK_SIZE): for i in range(start_idx, end_idx, TRITON_BLOCK_SIZE):
offset = i + tl.arange(0, TRITON_BLOCK_SIZE) offset = i + tl.arange(0, TRITON_BLOCK_SIZE)
positions = tl.load(pos + offset, mask=offset < end_idx, other=0) positions = tl.load(pos + offset, mask=offset < end_idx, other=0)
block_indices = positions // block_size block_indices = positions // virtual_block_size
block_numbers = tl.load( block_numbers = tl.load(
block_table_ptr + req_state_idx * block_table_stride + block_indices block_table_ptr + req_state_idx * block_table_stride + block_indices
) )
slot_ids = block_numbers * block_size + positions % block_size virtual_block_offsets = positions - block_indices * virtual_block_size
# determine whether the token is stored on this CP rank.
is_local = (
virtual_block_offsets // CP_KV_CACHE_INTERLEAVE_SIZE
) % TOTAL_CP_WORLD_SIZE == TOTAL_CP_RANK
# mapping virtual block offsets to local block offsets.
local_block_offsets = (
virtual_block_offsets // (TOTAL_CP_WORLD_SIZE * CP_KV_CACHE_INTERLEAVE_SIZE)
) * CP_KV_CACHE_INTERLEAVE_SIZE + (
virtual_block_offsets % CP_KV_CACHE_INTERLEAVE_SIZE
)
# physical slot index
slot_ids = block_numbers * block_size + local_block_offsets
slot_ids = tl.where(is_local, slot_ids, PAD_ID)
tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx) tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)
......
...@@ -10,6 +10,7 @@ from tqdm import tqdm ...@@ -10,6 +10,7 @@ from tqdm import tqdm
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode from vllm.config.compilation import CUDAGraphMode
from vllm.distributed import get_dcp_group
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.v1.attention.backend import AttentionMetadataBuilder from vllm.v1.attention.backend import AttentionMetadataBuilder
...@@ -17,6 +18,7 @@ from vllm.v1.kv_cache_interface import KVCacheConfig ...@@ -17,6 +18,7 @@ from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import ( from vllm.v1.worker.gpu.attn_utils import (
build_attn_metadata, build_attn_metadata,
build_slot_mappings_by_layer, build_slot_mappings_by_layer,
prepare_dcp_local_seq_lens,
) )
from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
...@@ -257,6 +259,23 @@ def prepare_inputs_to_capture( ...@@ -257,6 +259,23 @@ def prepare_inputs_to_capture(
input_buffers.seq_lens[:num_reqs] = num_tokens input_buffers.seq_lens[:num_reqs] = num_tokens
input_buffers.seq_lens[num_reqs:] = 0 input_buffers.seq_lens[num_reqs:] = 0
try:
dcp_group = get_dcp_group()
dcp_world_size = dcp_group.world_size
dcp_rank = dcp_group.rank_in_group
except AssertionError:
dcp_world_size = 1
dcp_rank = 0
if dcp_world_size > 1:
prepare_dcp_local_seq_lens(
input_buffers.dcp_local_seq_lens,
input_buffers.seq_lens,
num_reqs,
dcp_size=dcp_world_size,
dcp_rank=dcp_rank,
cp_kv_cache_interleave_size=block_tables.cp_kv_cache_interleave_size,
)
input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables] input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
slot_mappings = block_tables.slot_mappings[:, :num_tokens] slot_mappings = block_tables.slot_mappings[:, :num_tokens]
slot_mappings_by_layer = build_slot_mappings_by_layer( slot_mappings_by_layer = build_slot_mappings_by_layer(
...@@ -275,5 +294,6 @@ def prepare_inputs_to_capture( ...@@ -275,5 +294,6 @@ def prepare_inputs_to_capture(
block_tables=input_block_tables, block_tables=input_block_tables,
slot_mappings=slot_mappings, slot_mappings=slot_mappings,
kv_cache_config=kv_cache_config, kv_cache_config=kv_cache_config,
dcp_local_seq_lens=input_buffers.dcp_local_seq_lens,
) )
return attn_metadata, slot_mappings_by_layer return attn_metadata, slot_mappings_by_layer
...@@ -27,6 +27,10 @@ class InputBuffers: ...@@ -27,6 +27,10 @@ class InputBuffers:
max_num_reqs + 1, dtype=torch.int32, device=device max_num_reqs + 1, dtype=torch.int32, device=device
) )
self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device) self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device)
# DCP: per-request local seq_lens buffer
self.dcp_local_seq_lens = torch.zeros(
max_num_reqs, dtype=torch.int32, device=device
)
@dataclass @dataclass
......
...@@ -11,6 +11,7 @@ import torch.nn as nn ...@@ -11,6 +11,7 @@ import torch.nn as nn
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode from vllm.config.compilation import CUDAGraphMode
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
get_dcp_group,
get_pp_group, get_pp_group,
prepare_communication_buffer_for_model, prepare_communication_buffer_for_model,
) )
...@@ -24,6 +25,7 @@ from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE ...@@ -24,6 +25,7 @@ from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
from vllm.v1.worker.cp_utils import check_attention_cp_compatibility
from vllm.v1.worker.gpu.async_utils import AsyncOutput from vllm.v1.worker.gpu.async_utils import AsyncOutput
from vllm.v1.worker.gpu.attn_utils import ( from vllm.v1.worker.gpu.attn_utils import (
build_attn_metadata, build_attn_metadata,
...@@ -31,6 +33,7 @@ from vllm.v1.worker.gpu.attn_utils import ( ...@@ -31,6 +33,7 @@ from vllm.v1.worker.gpu.attn_utils import (
get_kv_cache_spec, get_kv_cache_spec,
init_attn_backend, init_attn_backend,
init_kv_cache, init_kv_cache,
prepare_dcp_local_seq_lens,
) )
from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu
...@@ -248,11 +251,15 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -248,11 +251,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
max_num_batched_tokens=self.max_num_tokens, max_num_batched_tokens=self.max_num_tokens,
max_model_len=self.max_model_len, max_model_len=self.max_model_len,
device=self.device, device=self.device,
cp_kv_cache_interleave_size=(
self.parallel_config.cp_kv_cache_interleave_size
),
) )
self.attn_backends, self.attn_metadata_builders = init_attn_backend( self.attn_backends, self.attn_metadata_builders = init_attn_backend(
self.kv_cache_config, self.vllm_config, self.device self.kv_cache_config, self.vllm_config, self.device
) )
check_attention_cp_compatibility(self.vllm_config)
if self.do_spec_decode: if self.do_spec_decode:
# HACK(woosuk) # HACK(woosuk)
self.speculator.set_attn( self.speculator.set_attn(
...@@ -294,6 +301,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -294,6 +301,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
block_tables=block_tables, block_tables=block_tables,
slot_mappings=slot_mappings, slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config, kv_cache_config=self.kv_cache_config,
dcp_local_seq_lens=self.input_buffers.dcp_local_seq_lens,
) )
input_batch.attn_metadata = attn_metadata input_batch.attn_metadata = attn_metadata
input_batch.slot_mappings = slot_mappings_by_layer input_batch.slot_mappings = slot_mappings_by_layer
...@@ -627,6 +635,19 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -627,6 +635,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
) )
seq_lens = self.input_buffers.seq_lens[:num_reqs] seq_lens = self.input_buffers.seq_lens[:num_reqs]
dcp_size = self.parallel_config.decode_context_parallel_size
if dcp_size > 1:
prepare_dcp_local_seq_lens(
self.input_buffers.dcp_local_seq_lens,
seq_lens,
num_reqs,
dcp_size=dcp_size,
dcp_rank=get_dcp_group().rank_in_group,
cp_kv_cache_interleave_size=(
self.parallel_config.cp_kv_cache_interleave_size
),
)
# Prepare M-RoPE positions. # Prepare M-RoPE positions.
if self.uses_mrope: if self.uses_mrope:
self.mrope_states.prepare_mrope_positions( self.mrope_states.prepare_mrope_positions(
...@@ -674,6 +695,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -674,6 +695,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
block_tables=block_tables, block_tables=block_tables,
slot_mappings=slot_mappings, slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config, kv_cache_config=self.kv_cache_config,
dcp_local_seq_lens=self.input_buffers.dcp_local_seq_lens,
) )
input_ids = self.input_buffers.input_ids[:num_tokens_after_padding] input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment