Unverified Commit 95be2a7f authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Model Runner V2] Minor simplification for DCP (#34786)


Signed-off-by: default avatarWoosuk Kwon <woosuk@inferact.ai>
parent 0e60c925
...@@ -12,7 +12,6 @@ from vllm.v1.attention.backend import ( ...@@ -12,7 +12,6 @@ from vllm.v1.attention.backend import (
AttentionMetadataBuilder, AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
) )
from vllm.v1.attention.backends.utils import get_dcp_local_seq_lens
from vllm.v1.kv_cache_interface import ( from vllm.v1.kv_cache_interface import (
AttentionSpec, AttentionSpec,
KVCacheConfig, KVCacheConfig,
...@@ -144,28 +143,6 @@ def build_slot_mappings_by_layer( ...@@ -144,28 +143,6 @@ def build_slot_mappings_by_layer(
return slot_mappings_by_layer return slot_mappings_by_layer
def prepare_dcp_local_seq_lens(
dcp_local_seq_lens: torch.Tensor,
seq_lens: torch.Tensor,
num_reqs: int,
dcp_size: int,
dcp_rank: int,
cp_kv_cache_interleave_size: int,
) -> None:
"""Populate the persistent DCP local seq_lens buffer (CUDA graph safe)."""
if dcp_size <= 1:
return
local_seq_lens = get_dcp_local_seq_lens(
seq_lens[:num_reqs],
dcp_size=dcp_size,
dcp_rank=dcp_rank,
cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
)
dcp_local_seq_lens[:num_reqs].copy_(local_seq_lens, non_blocking=True)
dcp_local_seq_lens[num_reqs:].zero_()
def build_attn_metadata( def build_attn_metadata(
attn_metadata_builders: list[AttentionMetadataBuilder], attn_metadata_builders: list[AttentionMetadataBuilder],
num_reqs: int, num_reqs: int,
...@@ -181,7 +158,6 @@ def build_attn_metadata( ...@@ -181,7 +158,6 @@ def build_attn_metadata(
dcp_local_seq_lens: torch.Tensor | None = None, dcp_local_seq_lens: torch.Tensor | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
seq_lens = seq_lens[:num_reqs] seq_lens = seq_lens[:num_reqs]
if dcp_local_seq_lens is not None: if dcp_local_seq_lens is not None:
dcp_local_seq_lens = dcp_local_seq_lens[:num_reqs] dcp_local_seq_lens = dcp_local_seq_lens[:num_reqs]
......
...@@ -4,7 +4,6 @@ from collections.abc import Iterable ...@@ -4,7 +4,6 @@ from collections.abc import Iterable
import torch import torch
from vllm.distributed import get_dcp_group
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import PAD_SLOT_ID from vllm.v1.attention.backends.utils import PAD_SLOT_ID
...@@ -19,36 +18,29 @@ class BlockTables: ...@@ -19,36 +18,29 @@ class BlockTables:
max_num_batched_tokens: int, max_num_batched_tokens: int,
max_model_len: int, max_model_len: int,
device: torch.device, device: torch.device,
cp_kv_cache_interleave_size: int = 1, cp_size: int = 1,
cp_rank: int = 0,
cp_interleave: int = 1,
): ):
self.block_sizes = block_sizes self.block_sizes = block_sizes
self.max_num_reqs = max_num_reqs self.max_num_reqs = max_num_reqs
self.max_num_batched_tokens = max_num_batched_tokens self.max_num_batched_tokens = max_num_batched_tokens
self.max_model_len = max_model_len self.max_model_len = max_model_len
self.device = device self.device = device
assert cp_kv_cache_interleave_size >= 1
self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size
try: self.cp_size = cp_size
dcp = get_dcp_group() self.cp_rank = cp_rank
self.dcp_world_size, self.dcp_rank = dcp.world_size, dcp.rank_in_group self.cp_interleave = cp_interleave
except AssertionError:
self.dcp_world_size, self.dcp_rank = 1, 0
# TODO(wentao): PCP supprot
self.total_cp_world_size = self.dcp_world_size
self.total_cp_rank = self.dcp_rank
self.num_kv_cache_groups = len(self.block_sizes) self.num_kv_cache_groups = len(self.block_sizes)
# num_kv_cache_groups x [max_num_reqs, max_num_blocks] # num_kv_cache_groups x [max_num_reqs, max_num_blocks]
self.block_tables: list[StagedWriteTensor] = [] self.block_tables: list[StagedWriteTensor] = []
for i in range(self.num_kv_cache_groups): for i in range(self.num_kv_cache_groups):
block_size = self.block_sizes[i] block_size = self.block_sizes[i]
# with DCP, a request's KV is sharded across # When using DCP, each request's KV cache is sharded among different ranks.
# ranks, so one physical block on this rank # As a result, one block on the current rank covers `block_size * cp_size`
# corresponds to `block_size * total_cp_world_size` # tokens in the full, global (unsharded) sequence.
# tokens in the global (unsharded) sequence. max_num_blocks = cdiv(self.max_model_len, block_size * self.cp_size)
virtual_block_size = block_size * self.total_cp_world_size
max_num_blocks = cdiv(self.max_model_len, virtual_block_size)
block_table = StagedWriteTensor( block_table = StagedWriteTensor(
(self.max_num_reqs, max_num_blocks), (self.max_num_reqs, max_num_blocks),
dtype=torch.int32, dtype=torch.int32,
...@@ -149,9 +141,9 @@ class BlockTables: ...@@ -149,9 +141,9 @@ class BlockTables:
self.block_sizes_tensor, self.block_sizes_tensor,
self.slot_mappings, self.slot_mappings,
self.slot_mappings.stride(0), self.slot_mappings.stride(0),
TOTAL_CP_WORLD_SIZE=self.total_cp_world_size, self.cp_rank,
TOTAL_CP_RANK=self.total_cp_rank, CP_SIZE=self.cp_size,
CP_KV_CACHE_INTERLEAVE_SIZE=self.cp_kv_cache_interleave_size, CP_INTERLEAVE=self.cp_interleave,
PAD_ID=PAD_SLOT_ID, PAD_ID=PAD_SLOT_ID,
TRITON_BLOCK_SIZE=1024, # type: ignore TRITON_BLOCK_SIZE=1024, # type: ignore
) )
...@@ -204,9 +196,9 @@ def _compute_slot_mappings_kernel( ...@@ -204,9 +196,9 @@ def _compute_slot_mappings_kernel(
block_sizes, # [num_kv_cache_groups] block_sizes, # [num_kv_cache_groups]
slot_mappings_ptr, # [num_kv_cache_groups, max_num_tokens] slot_mappings_ptr, # [num_kv_cache_groups, max_num_tokens]
slot_mappings_stride, slot_mappings_stride,
TOTAL_CP_WORLD_SIZE: tl.constexpr, cp_rank,
TOTAL_CP_RANK: tl.constexpr, CP_SIZE: tl.constexpr,
CP_KV_CACHE_INTERLEAVE_SIZE: tl.constexpr, CP_INTERLEAVE: tl.constexpr,
PAD_ID: tl.constexpr, PAD_ID: tl.constexpr,
TRITON_BLOCK_SIZE: tl.constexpr, TRITON_BLOCK_SIZE: tl.constexpr,
): ):
...@@ -225,7 +217,6 @@ def _compute_slot_mappings_kernel( ...@@ -225,7 +217,6 @@ def _compute_slot_mappings_kernel(
block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32) block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
block_table_stride = tl.load(block_table_strides + group_id) block_table_stride = tl.load(block_table_strides + group_id)
block_size = tl.load(block_sizes + group_id) block_size = tl.load(block_sizes + group_id)
virtual_block_size = block_size * TOTAL_CP_WORLD_SIZE
req_state_idx = tl.load(idx_mapping + batch_idx) req_state_idx = tl.load(idx_mapping + batch_idx)
start_idx = tl.load(query_start_loc + batch_idx) start_idx = tl.load(query_start_loc + batch_idx)
...@@ -233,26 +224,25 @@ def _compute_slot_mappings_kernel( ...@@ -233,26 +224,25 @@ def _compute_slot_mappings_kernel(
for i in range(start_idx, end_idx, TRITON_BLOCK_SIZE): for i in range(start_idx, end_idx, TRITON_BLOCK_SIZE):
offset = i + tl.arange(0, TRITON_BLOCK_SIZE) offset = i + tl.arange(0, TRITON_BLOCK_SIZE)
positions = tl.load(pos + offset, mask=offset < end_idx, other=0) positions = tl.load(pos + offset, mask=offset < end_idx, other=0)
block_indices = positions // virtual_block_size
block_indices = positions // (block_size * CP_SIZE)
block_offsets = positions % (block_size * CP_SIZE)
block_numbers = tl.load( block_numbers = tl.load(
block_table_ptr + req_state_idx * block_table_stride + block_indices block_table_ptr + req_state_idx * block_table_stride + block_indices
) )
virtual_block_offsets = positions - block_indices * virtual_block_size
# determine whether the token is stored on this CP rank. if CP_SIZE == 1:
is_local = ( # Common case: Context parallelism is not used.
virtual_block_offsets // CP_KV_CACHE_INTERLEAVE_SIZE slot_ids = block_numbers * block_size + block_offsets
) % TOTAL_CP_WORLD_SIZE == TOTAL_CP_RANK else:
# mapping virture block offsets to local block offsets. # Context parallelism is used.
local_block_offsets = ( is_local = block_offsets // CP_INTERLEAVE % CP_SIZE == cp_rank
virtual_block_offsets // (TOTAL_CP_WORLD_SIZE * CP_KV_CACHE_INTERLEAVE_SIZE) rounds = block_offsets // (CP_INTERLEAVE * CP_SIZE)
) * CP_KV_CACHE_INTERLEAVE_SIZE + ( remainder = block_offsets % CP_INTERLEAVE
virtual_block_offsets % CP_KV_CACHE_INTERLEAVE_SIZE local_offsets = rounds * CP_INTERLEAVE + remainder
) slot_ids = block_numbers * block_size + local_offsets
slot_ids = tl.where(is_local, slot_ids, PAD_ID)
# physical slot index
slot_ids = block_numbers * block_size + local_block_offsets
slot_ids = tl.where(is_local, slot_ids, PAD_ID)
tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx) tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
def prepare_dcp_local_seq_lens(
dcp_local_seq_lens: torch.Tensor,
seq_lens: torch.Tensor,
num_reqs: int,
dcp_size: int,
dcp_rank: int,
cp_interleave: int,
) -> None:
"""Populate the persistent DCP local seq_lens buffer (CUDA graph safe)."""
if dcp_size == 1:
return
max_num_reqs = dcp_local_seq_lens.shape[0]
BLOCK_SIZE = 128
num_blocks = triton.cdiv(max_num_reqs, BLOCK_SIZE)
_dcp_local_seq_lens_kernel[(num_blocks,)](
dcp_local_seq_lens,
seq_lens,
dcp_size,
dcp_rank,
cp_interleave,
num_reqs,
max_num_reqs,
BLOCK_SIZE,
)
@triton.jit
def _dcp_local_seq_lens_kernel(
out_ptr,
seq_lens_ptr,
dcp_size,
dcp_rank,
cp_interleave,
num_reqs,
max_num_reqs,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
block = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
seq_lens = tl.load(seq_lens_ptr + block, mask=block < num_reqs)
# Distribute KV cache among different ranks, in a round-robin manner.
rounds = seq_lens // (dcp_size * cp_interleave)
remainder = seq_lens % (dcp_size * cp_interleave)
remainder = tl.maximum(remainder - dcp_rank * cp_interleave, 0)
remainder = tl.minimum(remainder, cp_interleave)
local_seq_lens = rounds * cp_interleave + remainder
# For [num_reqs, max_num_reqs), pad with 0
local_seq_lens = tl.where(block < num_reqs, local_seq_lens, 0)
tl.store(out_ptr + block, local_seq_lens, mask=block < max_num_reqs)
...@@ -10,7 +10,6 @@ from tqdm import tqdm ...@@ -10,7 +10,6 @@ from tqdm import tqdm
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode from vllm.config.compilation import CUDAGraphMode
from vllm.distributed import get_dcp_group
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.v1.attention.backend import AttentionMetadataBuilder from vllm.v1.attention.backend import AttentionMetadataBuilder
...@@ -18,7 +17,6 @@ from vllm.v1.kv_cache_interface import KVCacheConfig ...@@ -18,7 +17,6 @@ from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import ( from vllm.v1.worker.gpu.attn_utils import (
build_attn_metadata, build_attn_metadata,
build_slot_mappings_by_layer, build_slot_mappings_by_layer,
prepare_dcp_local_seq_lens,
) )
from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
...@@ -259,22 +257,8 @@ def prepare_inputs_to_capture( ...@@ -259,22 +257,8 @@ def prepare_inputs_to_capture(
input_buffers.seq_lens[:num_reqs] = num_tokens input_buffers.seq_lens[:num_reqs] = num_tokens
input_buffers.seq_lens[num_reqs:] = 0 input_buffers.seq_lens[num_reqs:] = 0
try: input_buffers.dcp_local_seq_lens[:num_reqs] = num_tokens
dcp_group = get_dcp_group() input_buffers.dcp_local_seq_lens[num_reqs:] = 0
dcp_world_size = dcp_group.world_size
dcp_rank = dcp_group.rank_in_group
except AssertionError:
dcp_world_size = 1
dcp_rank = 0
if dcp_world_size > 1:
prepare_dcp_local_seq_lens(
input_buffers.dcp_local_seq_lens,
input_buffers.seq_lens,
num_reqs,
dcp_size=dcp_world_size,
dcp_rank=dcp_rank,
cp_kv_cache_interleave_size=block_tables.cp_kv_cache_interleave_size,
)
input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables] input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
slot_mappings = block_tables.slot_mappings[:, :num_tokens] slot_mappings = block_tables.slot_mappings[:, :num_tokens]
......
...@@ -33,10 +33,10 @@ from vllm.v1.worker.gpu.attn_utils import ( ...@@ -33,10 +33,10 @@ from vllm.v1.worker.gpu.attn_utils import (
get_kv_cache_spec, get_kv_cache_spec,
init_attn_backend, init_attn_backend,
init_kv_cache, init_kv_cache,
prepare_dcp_local_seq_lens,
) )
from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu
from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens
from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
from vllm.v1.worker.gpu.dp_utils import ( from vllm.v1.worker.gpu.dp_utils import (
get_cudagraph_and_dp_padding, get_cudagraph_and_dp_padding,
...@@ -192,6 +192,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -192,6 +192,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.is_first_pp_rank = True self.is_first_pp_rank = True
self.is_last_pp_rank = True self.is_last_pp_rank = True
# Decode context parallelism.
self.dcp_size = self.parallel_config.decode_context_parallel_size
self.use_dcp = self.dcp_size > 1
self.dcp_rank = get_dcp_group().rank_in_group if self.use_dcp else 0
self.cp_interleave = self.parallel_config.cp_kv_cache_interleave_size
def update_max_model_len(self, max_model_len: int) -> None: def update_max_model_len(self, max_model_len: int) -> None:
self.max_model_len = max_model_len self.max_model_len = max_model_len
self.req_states.max_model_len = max_model_len self.req_states.max_model_len = max_model_len
...@@ -251,9 +257,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -251,9 +257,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
max_num_batched_tokens=self.max_num_tokens, max_num_batched_tokens=self.max_num_tokens,
max_model_len=self.max_model_len, max_model_len=self.max_model_len,
device=self.device, device=self.device,
cp_kv_cache_interleave_size=( cp_size=self.dcp_size,
self.parallel_config.cp_kv_cache_interleave_size cp_rank=self.dcp_rank,
), cp_interleave=self.cp_interleave,
) )
self.attn_backends, self.attn_metadata_builders = init_attn_backend( self.attn_backends, self.attn_metadata_builders = init_attn_backend(
...@@ -636,18 +642,17 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -636,18 +642,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
) )
seq_lens = self.input_buffers.seq_lens[:num_reqs] seq_lens = self.input_buffers.seq_lens[:num_reqs]
dcp_size = self.parallel_config.decode_context_parallel_size if self.use_dcp:
if dcp_size > 1: # Prepare dcp local seq_lens.
prepare_dcp_local_seq_lens( prepare_dcp_local_seq_lens(
self.input_buffers.dcp_local_seq_lens, self.input_buffers.dcp_local_seq_lens,
seq_lens, self.input_buffers.seq_lens,
num_reqs, num_reqs,
dcp_size=dcp_size, self.dcp_size,
dcp_rank=get_dcp_group().rank_in_group, self.dcp_rank,
cp_kv_cache_interleave_size=( self.cp_interleave,
self.parallel_config.cp_kv_cache_interleave_size
),
) )
dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs]
# Prepare M-RoPE positions. # Prepare M-RoPE positions.
if self.uses_mrope: if self.uses_mrope:
...@@ -696,7 +701,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -696,7 +701,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
block_tables=block_tables, block_tables=block_tables,
slot_mappings=slot_mappings, slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config, kv_cache_config=self.kv_cache_config,
dcp_local_seq_lens=self.input_buffers.dcp_local_seq_lens, dcp_local_seq_lens=dcp_local_seq_lens,
) )
input_ids = self.input_buffers.input_ids[:num_tokens_after_padding] input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment