Unverified Commit ac201a0e authored by yzds's avatar yzds Committed by GitHub
Browse files

[Feature] Support Decode Context Parallel (DCP) for MLA (#23734)


Signed-off-by: hongchao <hongchao@msh.team>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: hongchao <hongchao@msh.team>
Co-authored-by: youkaichao <youkaichao@gmail.com>
parent 3c529fc9
......@@ -837,7 +837,7 @@ steps:
- pytest -v -s models/test_oot_registration.py # it needs a clean process
- pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins
- label: Pipeline Parallelism Test # 45min
- label: Pipeline + Context Parallelism Test # 45min
timeout_in_minutes: 60
mirror_hardwares: [amdexperimental]
working_dir: "/vllm-workspace/tests"
......@@ -851,6 +851,7 @@ steps:
commands:
- pytest -v -s distributed/test_pp_cudagraph.py
- pytest -v -s distributed/test_pipeline_parallel.py
# - pytest -v -s distributed/test_context_parallel.py # TODO: enable it on Hopper runners or add triton MLA support
- label: LoRA TP Test (Distributed) # 17 min
timeout_in_minutes: 30
......
......@@ -36,13 +36,6 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
const std::string& kv_cache_dtype,
torch::Tensor& scale);
// Context-parallel variant of concat_and_cache_mla: rows of kv_c / k_pe are
// picked through cp_local_token_select_indices and written, concatenated,
// into kv_cache at the slots listed in slot_mapping.
void cp_fused_concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
torch::Tensor& cp_local_token_select_indices,
torch::Tensor& kv_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
torch::Tensor& scale);
// Just for unittest
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const double scale, const std::string& kv_cache_dtype);
......
......@@ -396,51 +396,6 @@ __global__ void concat_and_cache_mla_kernel(
copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
}
// Fused gather + concat + cache write for MLA under context parallelism.
//
// Thread block `blockIdx.x` handles one entry of `slot_mapping`: it takes
// source row `cp_local_token_select_indices[blockIdx.x]` of `kv_c` and
// `k_pe` and stores them back to back ([kv_c | k_pe]) in the cache entry
// addressed by `slot_mapping[blockIdx.x]`. Entries whose slot is negative
// (padding) are skipped. Values are copied as-is for
// Fp8KVCacheDataType::kAuto and passed through fp8::scaled_convert with
// `*scale` otherwise.
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void cp_fused_concat_and_cache_mla_kernel(
    const scalar_t* __restrict__ kv_c,  // [num_full_tokens, kv_lora_rank]
    const scalar_t* __restrict__ k_pe,  // [num_full_tokens, pe_dim]
    const int64_t* __restrict__ cp_local_token_select_indices,  // [num_tokens]
    cache_t* __restrict__ kv_cache,  // [num_blocks, block_size, (kv_lora_rank
                                     // + pe_dim)]
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int block_stride,                    //
    const int entry_stride,                    //
    const int kv_c_stride,                     //
    const int k_pe_stride,                     //
    const int kv_lora_rank,                    //
    const int pe_dim,                          //
    const int block_size,                      //
    const float* scale                         //
) {
  const int64_t src_row = cp_local_token_select_indices[blockIdx.x];
  const int64_t slot = slot_mapping[blockIdx.x];
  // NOTE: the slot is -1 for padded tokens; nothing is written for them.
  if (slot < 0) {
    return;
  }

  // First element of the destination cache entry for this slot.
  const int64_t entry_base =
      (slot / block_size) * block_stride + (slot % block_size) * entry_stride;
  cache_t* __restrict__ entry = kv_cache + entry_base;

  // Copies `width` elements of one source row into `out`, with the threads
  // of the block striding over the columns.
  const auto write_row = [&](const scalar_t* __restrict__ row,
                             cache_t* __restrict__ out, const int width) {
    for (int col = threadIdx.x; col < width; col += blockDim.x) {
      if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
        out[col] = row[col];
      } else {
        out[col] =
            fp8::scaled_convert<cache_t, scalar_t, kv_dt>(row[col], *scale);
      }
    }
  };

  // kv_c occupies the first kv_lora_rank elements of the entry, k_pe follows.
  write_row(kv_c + src_row * kv_c_stride, entry, kv_lora_rank);
  write_row(k_pe + src_row * k_pe_stride, entry + kv_lora_rank, pe_dim);
}
} // namespace vllm
// KV_T is the data type of key and value tensors.
......@@ -554,20 +509,6 @@ void reshape_and_cache_flash(
kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
reinterpret_cast<const float*>(scale.data_ptr()));
// Launches vllm::cp_fused_concat_and_cache_mla_kernel with the <<<grid, block,
// 0, stream>>> configuration found at the expansion site.
// KV_T is the data type of key and value tensors.
// CACHE_T is the stored data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache.
// The body refers to call-site locals by name (grid, block, stream, kv_c,
// k_pe, cp_local_token_select_indices, kv_cache, slot_mapping, scale and the
// stride/size ints), so those names must be in scope wherever it is expanded.
#define CALL_CP_FUSED_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \
vllm::cp_fused_concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(kv_c.data_ptr()), \
reinterpret_cast<KV_T*>(k_pe.data_ptr()), \
cp_local_token_select_indices.data_ptr<int64_t>(), \
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, entry_stride, \
kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
reinterpret_cast<const float*>(scale.data_ptr()));
void concat_and_cache_mla(
torch::Tensor& kv_c, // [num_tokens, kv_lora_rank]
torch::Tensor& k_pe, // [num_tokens, pe_dim]
......@@ -606,50 +547,6 @@ void concat_and_cache_mla(
CALL_CONCAT_AND_CACHE_MLA);
}
// Note(hc): cp_fused_concat_and_cache_mla fuses the following three kernel
// calls into one:
//   k_c_normed.index_select(0, cp_local_token_select_indices)
//   k_pe.squeeze(1).index_select(0, cp_local_token_select_indices)
//   concat_and_cache_mla
//
// One thread block is launched per entry of slot_mapping; block i reads
// source row cp_local_token_select_indices[i] of kv_c / k_pe and writes it
// to cache slot slot_mapping[i] (negative slots are skipped by the kernel).
void cp_fused_concat_and_cache_mla(
    torch::Tensor& kv_c,  // [num_total_tokens, kv_lora_rank]
    torch::Tensor& k_pe,  // [num_total_tokens, pe_dim]
    torch::Tensor& cp_local_token_select_indices,  // [num_tokens]
    torch::Tensor& kv_cache,      // [num_blocks, block_size, (kv_lora_rank +
                                  // pe_dim)]
    torch::Tensor& slot_mapping,  // [num_tokens] or [num_actual_tokens]
    const std::string& kv_cache_dtype, torch::Tensor& scale) {
  // NOTE(woosuk): slot_mapping.size(0) is used as the number of tokens to
  // write. In vLLM V1 the input tensors may carry extra rows (padding for
  // CUDA graphs) that slot_mapping does not cover, so the size of
  // slot_mapping, not of kv_c / k_pe, determines the launch grid.
  int num_tokens = slot_mapping.size(0);
  int kv_lora_rank = kv_c.size(1);
  int pe_dim = k_pe.size(1);
  int block_size = kv_cache.size(1);

  TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
  // Every launched block dereferences cp_local_token_select_indices at its
  // block index before looking at the slot, so the index tensor has to cover
  // the whole grid; otherwise the kernel reads out of bounds.
  TORCH_CHECK(cp_local_token_select_indices.size(0) >= num_tokens,
              "cp_local_token_select_indices has ",
              cp_local_token_select_indices.size(0),
              " entries but slot_mapping has ", num_tokens);

  // Nothing to write; a launch with an empty grid is an invalid
  // configuration, so return before reaching it.
  if (num_tokens == 0) {
    return;
  }

  int kv_c_stride = kv_c.stride(0);
  int k_pe_stride = k_pe.stride(0);
  int block_stride = kv_cache.stride(0);
  int entry_stride = kv_cache.stride(1);

  // The local names below are consumed by CALL_CP_FUSED_CONCAT_AND_CACHE_MLA.
  dim3 grid(num_tokens);
  dim3 block(std::min(kv_lora_rank, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
                             CALL_CP_FUSED_CONCAT_AND_CACHE_MLA);
}
namespace vllm {
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
......
......@@ -693,16 +693,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
" Tensor scale) -> ()");
cache_ops.impl("concat_and_cache_mla", torch::kCUDA, &concat_and_cache_mla);
cache_ops.def(
"cp_fused_concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
" Tensor cp_local_token_select_indices,"
" Tensor! kv_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" Tensor scale) -> ()");
cache_ops.impl("cp_fused_concat_and_cache_mla", torch::kCUDA,
&cp_fused_concat_and_cache_mla);
// Convert the key and value cache to fp8 data type.
cache_ops.def(
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
WARNING: This test runs in both single-node (4 GPUs) and multi-node
(2 node with 2 GPUs each) modes. If the test only uses 2 GPUs, it is
important to set the distributed backend to "mp" to avoid Ray scheduling
all workers in a node other than the head node, which can cause the test
to fail.
"""
import json
import os
from dataclasses import dataclass
from typing import Literal, NamedTuple, Optional
import pytest
from vllm.config import RunnerOption
from vllm.logger import init_logger
from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import compare_two_settings, create_new_process_for_each_test
logger = init_logger("test_context_parallel")
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
class ParallelSetup(NamedTuple):
    # One parallel layout exercised by the test. The fields are turned into
    # CLI flags by _compare_cp_with_tp.
    # value for --tensor-parallel-size (both runs)
    tp_size: int
    # value for --pipeline-parallel-size (both runs)
    pp_size: int
    # value for --decode-context-parallel-size (DCP run only)
    dcp_size: int
    # when True, --enforce-eager is added
    eager_mode: bool
    # when True, --enable-chunked-prefill is added
    chunked_prefill: bool
class CPTestOptions(NamedTuple):
    # Per-model switches consumed by _compare_cp_with_tp.
    # when True, the test is skipped unless VLLM_MULTI_NODE is "1"
    multi_node_only: bool
    # forwarded as --load-format when set; the value "dummy" additionally
    # shrinks the model through hf overrides to avoid OOM
    load_format: Optional[str] = None
@dataclass
class CPTestSettings:
    """Bundle of everything needed to parametrize the test for one model."""

    parallel_setups: list[ParallelSetup]
    # NOTE: distributed_backends and vllm_major_versions must have the same
    # length; they are zipped pairwise (not crossed) when the test
    # parameters are generated.
    distributed_backends: list[str]
    # vllm major version: "0" for V0, "1" for V1
    vllm_major_versions: list[str]
    runner: RunnerOption
    test_options: CPTestOptions

    def __post_init__(self):
        """Reject settings whose backend and version lists cannot be zipped."""
        if len(self.distributed_backends) == len(self.vllm_major_versions):
            return
        raise ValueError(
            f"Length mismatch: distributed_backends "
            f"({len(self.distributed_backends)}) != "
            f"vllm_major_versions ({len(self.vllm_major_versions)})")

    @staticmethod
    def detailed(
        *,
        tp_base: int = 4,
        pp_base: int = 1,
        dcp_base: int = 1,
        multi_node_only: bool = False,
        runner: RunnerOption = "auto",
        load_format: Optional[str] = None,
    ):
        """Build the default settings: DCP sizes 2x and 4x `dcp_base` with
        chunked prefill on and eager mode off, run on the "mp" backend with
        the V1 engine."""
        setups = [
            ParallelSetup(tp_size=tp_base,
                          pp_size=pp_multiplier * pp_base,
                          dcp_size=dcp_multiplier * dcp_base,
                          eager_mode=eager_mode_val,
                          chunked_prefill=chunked_prefill_val)
            for eager_mode_val in [False]
            for pp_multiplier in [1]
            for dcp_multiplier in [2, 4]
            for chunked_prefill_val in [True]
        ]
        options = CPTestOptions(multi_node_only=multi_node_only,
                                load_format=load_format)
        return CPTestSettings(
            parallel_setups=setups,
            distributed_backends=["mp"],
            vllm_major_versions=["1"],
            runner=runner,
            test_options=options,
        )

    def iter_params(self, model_id: str):
        """Yield one pytest parameter tuple per (setup, backend/version pair)."""
        opts = self.test_options
        for setup in self.parallel_setups:
            pairs = zip(self.distributed_backends, self.vllm_major_versions)
            for backend, major_version in pairs:
                yield (model_id, setup, backend, major_version, self.runner,
                       opts)
def _compare_cp_with_tp(
    model_id: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    vllm_major_version: str,
    runner: RunnerOption,
    test_options: CPTestOptions,
    num_gpus_available: int,
    *,
    method: Literal["generate"],
    is_multimodal: bool,
):
    """Compare a DCP run of `model_id` against a run without DCP.

    Both runs share tensor/pipeline parallel sizes, backend and common
    flags; only the first one adds ``--decode-context-parallel-size``. The
    comparison itself is delegated to ``compare_two_settings``.

    The test is skipped when the model is unavailable, when fewer than
    ``tp_size * pp_size`` GPUs are present, or when the multi-node
    requirements of the setup are not met.
    """
    (
        tp_size,
        pp_size,
        dcp_size,
        eager_mode,
        chunked_prefill,
    ) = parallel_setup

    multi_node_only, load_format = test_options

    model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
    model_info.check_transformers_version(on_fail="skip")

    trust_remote_code = model_info.trust_remote_code
    tokenizer_mode = model_info.tokenizer_mode
    # Work on a copy: the overrides object belongs to the shared registry
    # entry, and updating it in place would leak the shrunken dummy config
    # into every later user of the same model info.
    hf_overrides = dict(model_info.hf_overrides or {})

    if load_format == "dummy":
        # Avoid OOM
        text_overrides = {
            "num_hidden_layers": 4,
            "hidden_size": 512,
            "intermediate_size": 800,
            "num_attention_heads": 4,
            "num_key_value_heads": 1,
        }

        if is_multimodal:
            hf_overrides.update({"text_config": text_overrides})
        else:
            hf_overrides.update(text_overrides)
    else:
        model_info.check_available_online(on_fail="skip")

    if num_gpus_available < tp_size * pp_size:
        pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
    if VLLM_MULTI_NODE and distributed_backend == "mp":
        pytest.skip("Skipping multi-node context parallel test for "
                    "multiprocessing distributed backend")
    if multi_node_only and not VLLM_MULTI_NODE:
        pytest.skip("Not in multi-node setting")

    common_args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "8",
    ]
    if chunked_prefill:
        common_args.append("--enable-chunked-prefill")
    if eager_mode:
        common_args.append("--enforce-eager")
    if runner != "auto":
        common_args.extend(["--runner", runner])
    if trust_remote_code:
        common_args.append("--trust-remote-code")
    if tokenizer_mode:
        common_args.extend(["--tokenizer-mode", tokenizer_mode])
    if load_format:
        common_args.extend(["--load-format", load_format])
    if hf_overrides:
        common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])

    # Note(hc): DCP only supports the V1 engine. Each run gets its own dict
    # so that a change made to one environment cannot affect the other.
    cp_env = {"VLLM_USE_V1": vllm_major_version}
    tp_env = dict(cp_env)

    parallel_args = [
        "--tensor-parallel-size",
        str(tp_size),
        "--pipeline-parallel-size",
        str(pp_size),
    ]
    backend_args = [
        "--distributed-executor-backend",
        distributed_backend,
    ]

    cp_args = [
        *common_args,
        *parallel_args,
        "--decode-context-parallel-size",
        str(dcp_size),
        *backend_args,
    ]

    tp_args = [
        *common_args,
        *parallel_args,
        *backend_args,
    ]

    try:
        compare_two_settings(model_id,
                             cp_args,
                             tp_args,
                             cp_env,
                             tp_env,
                             method=method,
                             max_wait_seconds=720)
    except Exception:
        if vllm_major_version == "0":
            # Failures on V0 are treated as best-effort (historically
            # flaky): log them instead of failing the test.
            logger.exception("Ray Compiled Graph tests failed")
        else:
            raise
# Settings per model id; every entry can contribute test parameters.
CP_TEXT_GENERATION_MODELS = {
    # [MLA attention only]
    "deepseek-ai/DeepSeek-V2-Lite-Chat": CPTestSettings.detailed(),
}

# Allow-list applied on top of CP_TEXT_GENERATION_MODELS: only model ids
# listed here are actually parametrized into test_cp_generation.
CP_TEST_MODELS = [
    # TODO support other models
    # [LANGUAGE GENERATION]
    "deepseek-ai/DeepSeek-V2-Lite-Chat",
]
@pytest.mark.parametrize(
    ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
     "runner", "test_options"),
    [
        params for model_id, settings in CP_TEXT_GENERATION_MODELS.items()
        if model_id in CP_TEST_MODELS
        for params in settings.iter_params(model_id)
    ],
)
@create_new_process_for_each_test()
def test_cp_generation(
    model_id: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    vllm_major_version: str,
    runner: RunnerOption,
    test_options: CPTestOptions,
    num_gpus_available,
):
    """Text-generation comparison of a DCP run against a run without DCP,
    for every allow-listed model and parallel setup."""
    _compare_cp_with_tp(
        model_id=model_id,
        parallel_setup=parallel_setup,
        distributed_backend=distributed_backend,
        vllm_major_version=vllm_major_version,
        runner=runner,
        test_options=test_options,
        num_gpus_available=num_gpus_available,
        method="generate",
        is_multimodal=False,
    )
......@@ -1625,20 +1625,6 @@ def concat_and_cache_mla(
scale)
def cp_fused_concat_and_cache_mla(
    kv_c: torch.Tensor,
    k_pe: torch.Tensor,
    cp_local_token_select_indices: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    scale: torch.Tensor,
) -> None:
    """Forward all arguments, unchanged and in order, to the
    `_C_cache_ops.cp_fused_concat_and_cache_mla` custom op."""
    fused_op = torch.ops._C_cache_ops.cp_fused_concat_and_cache_mla
    fused_op(
        kv_c,
        k_pe,
        cp_local_token_select_indices,
        kv_cache,
        slot_mapping,
        kv_cache_dtype,
        scale,
    )
def copy_blocks(key_caches: list[torch.Tensor],
value_caches: list[torch.Tensor],
block_mapping: torch.Tensor) -> None:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.distributed.parallel_state import GroupCoordinator
from vllm.triton_utils import tl, triton
@triton.jit
def _correct_attn_cp_out_kernel(outputs_ptr, new_output_ptr, lses_ptr,
                                vlse_ptr, outputs_stride_B, outputs_stride_H,
                                outputs_stride_D, lses_stride_N, lses_stride_B,
                                lses_stride_H, lse_idx, HEAD_DIM: tl.constexpr,
                                N_ROUNDED: tl.constexpr):
    """
    Apply the all-gathered lses to correct each local rank's attention
    output. A cross-rank reduction is still needed afterwards to obtain the
    final attention output.

    One program handles one (batch, head) pair, taken from program ids 0
    and 1. It combines the N_ROUNDED lse values of that pair into a single
    log-sum-exp, stores it to `vlse_ptr`, and scales the HEAD_DIM output
    values by exp(lses[lse_idx] - combined_lse).

    Args:
        outputs_ptr:    local attention output, read with the outputs_stride_*
                        strides; documented layout [ B, H, D ]
        new_output_ptr: destination for the scaled output, addressed with the
                        same offsets as outputs_ptr
        lses_ptr:       gathered lses, read with the lses_stride_* strides;
                        documented layout [ N, B, H ]
        vlse_ptr:       destination for the combined lse; addressed with
                        lses_stride_B / lses_stride_H
        lse_idx:        index along N of the lse that belongs to this rank
        HEAD_DIM, N_ROUNDED: compile-time extents of the D and N ranges.
                        Loads are unmasked, so they must match the real
                        sizes. NOTE(review): tl.arange is documented to need
                        power-of-two extents -- confirm for both.
    """
    batch_idx = tl.program_id(axis=0).to(tl.int64)
    head_idx = tl.program_id(axis=1).to(tl.int64)
    d_offsets = tl.arange(0, HEAD_DIM)
    num_n_offsets = tl.arange(0, N_ROUNDED)

    # shape = [N]
    lse_offsets = num_n_offsets * lses_stride_N + batch_idx * \
        lses_stride_B + head_idx * lses_stride_H

    # calc final lse
    lse = tl.load(lses_ptr + lse_offsets)
    # NaN (x != x) and +inf entries are replaced by -inf so that they
    # contribute exp(-inf) = 0 to the sum below.
    lse = tl.where((lse != lse) | (lse == float('inf')), -float('inf'), lse)
    # log-sum-exp with the maximum subtracted before exponentiating
    lse_max = tl.max(lse, axis=0)
    lse -= lse_max
    lse_exp = tl.exp(lse)
    lse_acc = tl.sum(lse_exp, axis=0)
    lse = tl.log(lse_acc)
    lse += lse_max

    # the combined lse is a scalar per (batch, head)
    lse_offsets = batch_idx * lses_stride_B + head_idx * lses_stride_H
    tl.store(vlse_ptr + lse_offsets, lse)

    # shape = [D]
    output_offsets = batch_idx * outputs_stride_B + \
        head_idx * outputs_stride_H + \
        d_offsets * outputs_stride_D

    # correct output
    lse_offset = lse_idx * lses_stride_N + batch_idx * \
        lses_stride_B + head_idx * lses_stride_H
    lse_tmp = tl.load(lses_ptr + lse_offset)
    lse_finally = lse_tmp - lse
    # same NaN / +inf guard as above: such differences become -inf, which
    # turns the scale factor into 0
    lse_finally = tl.where(
        (lse_finally != lse_finally) | (lse_finally == float('inf')),
        -float('inf'), lse_finally)
    factor = tl.exp(lse_finally)
    output = tl.load(outputs_ptr + output_offsets)
    output = output * factor

    tl.store(new_output_ptr + output_offsets, output)
class CPTritonContext:
    """Remembers the object returned by the first kernel launch so that
    later launches reuse it and avoid recompilation of the Triton JIT.
    """

    def __init__(self):
        self.inner_kernel = None

    def call_kernel(self, kernel, grid, *regular_args, **const_args):
        """Launch `kernel` on `grid`.

        The first call goes through `kernel[grid]` with both the regular and
        the constant (keyword) arguments and stores what that launch returns.
        Every later call launches the stored object with the regular
        arguments only; `kernel` and `const_args` are ignored then.
        """
        cached = self.inner_kernel
        if cached is not None:
            cached[grid](*regular_args)
            return
        launch = kernel[grid]
        self.inner_kernel = launch(*regular_args, **const_args)
def correct_attn_out(out: torch.Tensor, lses: torch.Tensor, cp_rank: int,
                     ctx: CPTritonContext):
    """
    Apply the all-gathered lses to correct each local rank's attention
    output. A cross-rank reduction is still needed afterwards to obtain the
    final attention output.

    `out` is corrected in place: the kernel receives it both as input and as
    output buffer.

    Args:
        out:     [ B, H, D ] local attention output
        lses:    [ N, B, H ] lses gathered from the N ranks
        cp_rank: index along N of the lse that belongs to this rank
        ctx:     launch cache; a fresh one is created when None is passed

    Return:
        output: [ B, H, D ] (the same tensor object as `out`)
        lse   : [ B, H ] combined lse
    """
    if ctx is None:
        ctx = CPTritonContext()

    # The kernel addresses the combined-lse buffer with the B/H strides of
    # `lses`. empty_like(lses[0]) only reproduces those strides when the
    # slice is dense, so the buffer is allocated with the strides spelled
    # out to keep the writes in place for non-contiguous `lses` as well.
    lse = torch.empty_strided(tuple(lses.shape[1:]),
                              tuple(lses.stride()[1:]),
                              dtype=lses.dtype,
                              device=lses.device)

    # one program per (batch, head)
    grid = (out.shape[0], out.shape[1], 1)
    regular_args = (out, out, lses, lse, *out.stride(), *lses.stride(),
                    cp_rank)
    const_args = {
        "HEAD_DIM": out.shape[-1],
        "N_ROUNDED": lses.shape[0],
    }

    ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args,
                    **const_args)
    return out, lse
def cp_lse_ag_out_rs(cp_attn_out: torch.Tensor,
                     cp_attn_lse: torch.Tensor,
                     cp_group: GroupCoordinator,
                     ctx: CPTritonContext = None):
    """
    Combine per-rank attention results across `cp_group`: all-gather the
    lses, rescale the local output with them, then reduce-scatter the output
    along dim 1.

    cp_attn_out: [ B, H, D ]
    cp_attn_lse: [ B, H ]

    With a group of size 1 the input output tensor is returned untouched.
    `ctx` caches the kernel launch; a fresh one is created when omitted.
    """
    if cp_group.world_size == 1:
        return cp_attn_out

    if ctx is None:
        ctx = CPTritonContext()

    cp_attn_lse = cp_attn_lse.contiguous()
    # Reshape the gathered lses to [ N, B, H ] directly instead of
    # allocating a throwaway tensor just to serve as a shape template.
    gathered = cp_group.all_gather(cp_attn_lse, dim=0)
    lses = gathered.view(cp_group.world_size, *cp_attn_lse.shape)
    out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
    assert out.is_contiguous()
    out = cp_group.reduce_scatter(out, dim=1)
    return out
......@@ -105,7 +105,9 @@ def flash_mla_with_kvcache(
descale_q,
descale_k,
)
return out, softmax_lse
# Note(hc): need revisit when we support DCP with decode query_len > 1.
return out.squeeze(1), softmax_lse.squeeze(-1)
#
......
......@@ -170,6 +170,11 @@ class ParallelConfig:
Set to be private as it's not intended to be configured by users.
"""
decode_context_parallel_size: int = 1
"""Number of decode context parallel groups, because the world size does
not change by dcp, it simply reuse the GPUs of TP group, and tp_size
needs to be divisible by dcp_size."""
@property
def world_size_across_dp(self) -> int:
"""world_size_across_dp is TPxPPxDP, it is the size of the world
......
......@@ -904,6 +904,18 @@ def get_tensor_model_parallel_group():
return get_tp_group()
_DCP: Optional[GroupCoordinator] = None
def get_dcp_group() -> GroupCoordinator:
    """Return the module-level decode context parallel group `_DCP`.

    Fails with an AssertionError while `_DCP` is still None.
    """
    assert _DCP is not None, (
        "decode context model parallel group is not initialized")
    return _DCP
# kept for backward compatibility
get_context_model_parallel_group = get_dcp_group
_PP: Optional[GroupCoordinator] = None
_DP: Optional[GroupCoordinator] = None
......@@ -1034,6 +1046,7 @@ def init_distributed_environment(
def initialize_model_parallel(
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
decode_context_model_parallel_size: Optional[int] = 1,
backend: Optional[str] = None,
) -> None:
"""
......@@ -1098,6 +1111,23 @@ def initialize_model_parallel(
use_message_queue_broadcaster=True,
group_name="tp")
# Build the DCP model-parallel groups.
global _DCP
assert _DCP is None, (
"decode context model parallel group is already initialized")
# Note(hc): In the current implementation of decode context parallel,
# dcp_size must not exceed tp_size, because the world size does not
# change by DCP, it simply reuse the GPUs of TP group, and split one
# TP group into tp_size//dcp_size DCP groups.
group_ranks = all_ranks.reshape(
-1, decode_context_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
_DCP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
use_message_queue_broadcaster=True,
group_name="dcp")
# Build the pipeline model-parallel groups.
global _PP
assert _PP is None, (
......@@ -1141,6 +1171,7 @@ def initialize_model_parallel(
def ensure_model_parallel_initialized(
tensor_model_parallel_size: int,
pipeline_model_parallel_size: int,
decode_context_model_parallel_size: Optional[int] = 1,
backend: Optional[str] = None,
) -> None:
"""Helper to initialize model parallel groups if they are not initialized,
......@@ -1151,7 +1182,8 @@ def ensure_model_parallel_initialized(
get_world_group().device_group)
if not model_parallel_is_initialized():
initialize_model_parallel(tensor_model_parallel_size,
pipeline_model_parallel_size, backend)
pipeline_model_parallel_size,
decode_context_model_parallel_size, backend)
return
assert (
......@@ -1226,6 +1258,16 @@ def get_tensor_model_parallel_rank():
return get_tp_group().rank_in_group
def get_decode_context_model_parallel_world_size():
    """Return world size for the decode context model parallel group."""
    dcp_group = get_dcp_group()
    return dcp_group.world_size
def get_decode_context_model_parallel_rank():
    """Return my rank for the decode context model parallel group."""
    dcp_group = get_dcp_group()
    return dcp_group.rank_in_group
def get_node_count() -> int:
"""Return the total number of nodes in the distributed environment. """
assert _NODE_COUNT is not None, (
......@@ -1246,6 +1288,11 @@ def destroy_model_parallel():
_PP.destroy()
_PP = None
global _DCP
if _DCP:
_DCP.destroy()
_DCP = None
global _DP
if _DP:
_DP.destroy()
......
......@@ -306,6 +306,8 @@ class EngineArgs:
# number of P/D disaggregation (or other disaggregation) workers
pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
decode_context_parallel_size: int = \
ParallelConfig.decode_context_parallel_size
data_parallel_size: int = ParallelConfig.data_parallel_size
data_parallel_rank: Optional[int] = None
data_parallel_start_rank: Optional[int] = None
......@@ -636,6 +638,9 @@ class EngineArgs:
**parallel_kwargs["pipeline_parallel_size"])
parallel_group.add_argument("--tensor-parallel-size", "-tp",
**parallel_kwargs["tensor_parallel_size"])
parallel_group.add_argument(
"--decode-context-parallel-size", "-dcp",
**parallel_kwargs["decode_context_parallel_size"])
parallel_group.add_argument("--data-parallel-size", "-dp",
**parallel_kwargs["data_parallel_size"])
parallel_group.add_argument(
......@@ -1156,6 +1161,17 @@ class EngineArgs:
# global layers in interleaved sliding window models.
sliding_window = model_config.get_sliding_window()
# Note(hc): In the current implementation of decode context
# parallel(DCP), tp_size needs to be divisible by dcp_size,
# because the world size does not change by dcp, it simply
# reuse the GPUs of TP group, and split one TP group into
# tp_size//dcp_size DCP groups.
assert self.tensor_parallel_size % self.decode_context_parallel_size \
== 0, (
f"tp_size={self.tensor_parallel_size} must be divisible by"
f"dcp_size={self.decode_context_parallel_size}."
)
cache_config = CacheConfig(
block_size=self.block_size,
gpu_memory_utilization=self.gpu_memory_utilization,
......@@ -1306,6 +1322,7 @@ class EngineArgs:
distributed_executor_backend=self.distributed_executor_backend,
worker_cls=self.worker_cls,
worker_extension_cls=self.worker_extension_cls,
decode_context_parallel_size=self.decode_context_parallel_size,
)
speculative_config = self.create_speculative_config(
......
......@@ -201,10 +201,11 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
AttentionMetadata,
MLAAttentionImpl)
from vllm.attention.backends.utils import get_mla_dims
from vllm.attention.ops.common import cp_lse_ag_out_rs
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import is_global_first_rank
from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase,
......@@ -323,6 +324,13 @@ class MLACommonPrefillMetadata:
seq_lens: torch.Tensor
workspace: torch.Tensor
# for mla DCP
cp_chunk_seq_lens: Optional[list[list[int]]] = None
origin_context_lens: Optional[list[int]] = None
cp_cu_seq_lens: Optional[torch.Tensor] = None
chunk_size: Optional[int] = None
cu_seq_lens_lst: Optional[list[list[int]]] = None
block_table: torch.Tensor
query_start_loc: torch.Tensor
max_query_len: int
......@@ -444,6 +452,13 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
parallel_config)
self.mla_dims = get_mla_dims(self.model_config)
self.aot_schedule = current_platform.is_cuda()
try:
self.dcp_world_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
except AssertionError:
# DCP might not be initialized in testing
self.dcp_world_size = 1
self.dcp_rank = 0
# Dont try to access the runner on AMD
if self.aot_schedule:
......@@ -465,6 +480,21 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
128 * 1024)
assert self.chunked_prefill_workspace_size >= \
scheduler_config.max_num_seqs * cache_config.block_size
if self.dcp_world_size > 1:
# Note(hc): The local kvcache is incomplete when DCP is triggered,
# an additional kvcache allgather across the DCP group is therefore
# required, so the workspace has to be enlarged by 1/DCP relative
# to the original TP allocation.
assert self.chunked_prefill_workspace_size % \
self.dcp_world_size == 0
self.chunked_prefill_workspace = torch.empty(
(self.chunked_prefill_workspace_size +
self.chunked_prefill_workspace_size // self.dcp_world_size,
self.model_config.get_head_size()),
dtype=self.model_config.dtype,
device=device,
)
else:
self.chunked_prefill_workspace = torch.empty(
(self.chunked_prefill_workspace_size,
self.model_config.get_head_size()),
......@@ -631,6 +661,12 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
split_decodes_and_prefills(common_attn_metadata,
decode_threshold=self.reorder_batch_threshold)
# Note(hc): update seq_lens of decode reqs under DCP.
if self.dcp_world_size > 1:
seq_lens[:num_decodes] = seq_lens[:num_decodes] \
// self.dcp_world_size + (self.dcp_rank <= \
(seq_lens[:num_decodes] - 1) % self.dcp_world_size)
assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == num_tokens
......@@ -639,6 +675,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
reqs_start = num_decodes # prefill_start
context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
# Note(hc): The context lengths in the perspective of dcp rank0.
cp_context_lens_cpu = torch.ceil(context_lens_cpu.float() /
self.dcp_world_size).int()
origin_context_lens = context_lens_cpu.tolist()
max_context_len_cpu = context_lens_cpu.max().item()
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
prefill_query_start_loc = query_start_loc[
......@@ -691,14 +731,60 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
out=cu_seq_lens_cpu[:, 1:],
dtype=torch.int32)
if self.dcp_world_size > 1:
# Note(hc): The above max_context_chunk already enforces
# block_size alignment, DCP just need the block_size can
# be divisible by dcp_world_size, because DCP use
# cp_gather_cache which not require `cp_chunk_starts`
# aligned to page_size.
assert max_context_chunk % self.dcp_world_size == 0
cp_max_context_chunk = max_context_chunk // \
self.dcp_world_size
cp_chunk_starts = \
torch.arange(num_chunks, dtype=torch.int32) \
.unsqueeze(1).expand(-1, num_prefills) \
* cp_max_context_chunk
cp_chunk_ends = torch.min(
cp_context_lens_cpu.unsqueeze(0),
cp_chunk_starts + cp_max_context_chunk)
cp_chunk_seq_lens = (cp_chunk_ends -
cp_chunk_starts).clamp(min=0)
cp_cu_seq_lens_cpu = torch.zeros(num_chunks,
num_prefills + 1,
dtype=torch.int32,
pin_memory=True)
torch.cumsum(cp_chunk_seq_lens,
dim=1,
out=cp_cu_seq_lens_cpu[:, 1:],
dtype=torch.int32)
chunked_context_metadata_cls = \
CudnnPrefillMetadata.ChunkedContextMetadata \
if self._use_cudnn_prefill else \
MLACommonPrefillMetadata.ChunkedContextMetadata
if self.dcp_world_size > 1:
chunked_context_metadata = \
chunked_context_metadata_cls(
cu_seq_lens=cu_seq_lens_cpu \
.to(device, non_blocking=True),
starts=cp_chunk_starts.to(device, non_blocking=True),
seq_tot=cp_chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
seq_lens=chunk_seq_lens,
workspace=self.chunked_prefill_workspace,
cp_chunk_seq_lens=cp_chunk_seq_lens.tolist(),
origin_context_lens=origin_context_lens,
cp_cu_seq_lens=cp_cu_seq_lens_cpu \
.to(device, non_blocking=True),
chunk_size=max_context_chunk,
cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
)
else:
chunked_context_metadata = \
chunked_context_metadata_cls(
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
cu_seq_lens=cu_seq_lens_cpu \
.to(device, non_blocking=True),
starts=chunk_starts.to(device, non_blocking=True),
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
......@@ -757,6 +843,71 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
return attn_metadata
def reorg_kvcache(
    allgatered_kv_c_normed: torch.Tensor,
    allgatered_k_pe: torch.Tensor,
    cp_chunk_seq_lens_lst: list[int],
    origin_context_lens: list[int],
    cp_world_size: int,
    sum_seq_len: int,
    max_seq_len: int,
    chunk_size: int,
    chunk_idx: int,
    toks: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Reorganize the all-gathered kvcache from the per-rank layout into one
    contiguous run of rows per sequence, as expected by the attention
    kernel.

    The gathered tensors hold `cp_world_size` slabs of `toks` rows each,
    one slab per rank. Within a slab every sequence occupies
    `cp_chunk_seq_len` rows. Ranks after the one holding the last token of
    the chunk have one valid row fewer; their extra row is dropped here.

    Args:
        allgatered_kv_c_normed: gathered kv_c rows, rank-major.
        allgatered_k_pe: gathered k_pe rows, same row layout.
        cp_chunk_seq_lens_lst: chunk context lengths under CP.
        origin_context_lens: origin full context lengths under CP.
        cp_world_size: CP size.
        sum_seq_len: expected total number of rows in the result.
        max_seq_len: expected largest per-sequence row count in the result.
        chunk_size: equals to max_context_chunk from
            chunked_context_metadata building.
        chunk_idx: chunk idx of chunked_prefill.
        toks: the number of tokens for local gather cache.

    Returns:
        The reorganized (kv_c_normed, k_pe) tensors. Their sizes are checked
        against `sum_seq_len` / `max_seq_len` with assertions.
    """
    kv_c_parts = []
    k_pe_parts = []
    longest_seq = 0
    # row offset of the current sequence inside every rank's slab
    seq_start = 0

    for local_len, full_context_len in zip(cp_chunk_seq_lens_lst,
                                           origin_context_lens):
        if local_len == 0:
            # this sequence contributes no rows to the chunk
            continue

        # tokens of this sequence that fall into the current chunk, and the
        # rank that holds the last of them
        tokens_in_chunk = min(chunk_size,
                              full_context_len - chunk_size * chunk_idx)
        last_token_rank = (tokens_in_chunk - 1) % cp_world_size

        seq_rows = 0
        for rank in range(cp_world_size):
            rank_len = local_len - 1 if rank > last_token_rank else local_len
            if rank_len == 0:
                continue
            begin = rank * toks + seq_start
            rows = slice(begin, begin + rank_len)
            kv_c_parts.append(allgatered_kv_c_normed[rows])
            k_pe_parts.append(allgatered_k_pe[rows])
            seq_rows += rank_len

        longest_seq = max(longest_seq, seq_rows)
        seq_start += local_len

    reorganized_kv_c_normed = torch.cat(kv_c_parts, dim=0)
    reorganized_k_pe = torch.cat(k_pe_parts, dim=0)
    assert reorganized_kv_c_normed.shape[0] == sum_seq_len
    assert reorganized_k_pe.shape[0] == sum_seq_len
    assert longest_seq == max_seq_len
    return reorganized_kv_c_normed, reorganized_k_pe
class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
"""
NOTE: Please read the comment at the top of the file before trying to
......@@ -836,6 +987,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
self.vllm_flash_attn_version == 3
and current_platform.get_device_capability()[0] == 9)
self.dcp_world_size: Optional[int] = None
def _flash_attn_varlen_diff_headdims(self,
q,
k,
......@@ -1152,6 +1305,108 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
return output, output_lse
def _context_parallel_compute_prefill_context(
    self,
    q: torch.Tensor,
    kv_c_and_k_pe_cache: torch.Tensor,
    attn_metadata: MLACommonMetadata,
    k_scale: torch.Tensor,
    dcp_world_size: int,
):
    """Attend `q` over the chunked prefill context when DCP is enabled.

    For each context chunk this method:
      1. gathers cache entries into the chunked-context workspace with
         `ops.cp_gather_cache`,
      2. all-gathers the local slice across the DCP group,
      3. reorders the gathered tokens with `reorg_kvcache`,
      4. up-projects the latent KV through `kv_b_proj` and builds `k`/`v`,
      5. runs `_run_prefill_context_chunk` on the chunk, and
      6. folds the chunk result into the running result with
         `merge_attn_states`.

    Args:
        q: query tensor, forwarded unchanged to
            `_run_prefill_context_chunk`.
        kv_c_and_k_pe_cache: cache handed to `ops.cp_gather_cache` as
            `src_cache`.
        attn_metadata: metadata whose `prefill.chunked_context` carries
            the per-chunk DCP bookkeeping asserted below.
        k_scale: must be None; a scaled KV cache is rejected by assertion.
        dcp_world_size: DCP group size, used to partition the workspace
            and passed to `reorg_kvcache` as `cp_world_size`.

    Returns:
        `(output, output_lse)`: attention output and softmax LSE merged
        over all context chunks.
    """
    assert k_scale is None, "DCP not support sacled kvcache now."
    assert attn_metadata.prefill is not None
    prefill_metadata = attn_metadata.prefill
    chunked_ctx = prefill_metadata.chunked_context
    assert chunked_ctx is not None
    assert chunked_ctx.cp_chunk_seq_lens is not None
    assert chunked_ctx.origin_context_lens is not None
    assert chunked_ctx.cp_cu_seq_lens is not None
    assert chunked_ctx.chunk_size is not None
    assert chunked_ctx.cu_seq_lens_lst is not None

    output = None
    num_chunks = len(chunked_ctx.seq_tot)
    workspace = chunked_ctx.workspace

    for chunk_idx in range(num_chunks):
        toks = chunked_ctx.seq_tot[chunk_idx]

        # Gather this chunk's cache entries into `workspace` (the `dst`).
        ops.cp_gather_cache(
            src_cache=kv_c_and_k_pe_cache,
            dst=workspace,
            block_table=prefill_metadata.block_table,
            cu_seq_lens=chunked_ctx.cp_cu_seq_lens[chunk_idx],
            batch_size=attn_metadata.num_prefills,
            seq_starts=chunked_ctx.starts[chunk_idx],
        )

        # Workspace layout:
        # |------- N tokens ---------|-------- N*dcp_size tokens --------|
        # |<- used for local gather->|<------- used for allgather ------>|
        region = workspace.shape[0] // (dcp_world_size + 1)
        assert region * (dcp_world_size + 1) == workspace.shape[0]
        assert toks <= region
        local_kv = workspace[:toks]
        allgather_region = workspace[region:region * (1 + dcp_world_size)]
        gathered_toks = toks * dcp_world_size
        assert gathered_toks <= allgather_region.shape[0]
        gathered_kv = allgather_region[:gathered_toks]

        # Collect every rank's local slice, concatenated along dim 0.
        gathered_kv.copy_(get_dcp_group().all_gather(local_kv, dim=0))
        assert gathered_kv.shape[-1] == (self.kv_lora_rank +
                                         self.qk_rope_head_dim)

        # Split the last dim into the latent KV part and the rope part.
        allgatered_kv_c_normed, allgatered_k_pe = gathered_kv.unsqueeze(
            1).split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

        kv_c_normed, k_pe = reorg_kvcache(
            allgatered_kv_c_normed,
            allgatered_k_pe,
            cp_chunk_seq_lens_lst=chunked_ctx.cp_chunk_seq_lens[chunk_idx],
            origin_context_lens=chunked_ctx.origin_context_lens,
            cp_world_size=dcp_world_size,
            sum_seq_len=chunked_ctx.cu_seq_lens_lst[chunk_idx][-1],
            max_seq_len=chunked_ctx.max_seq_lens[chunk_idx],
            chunk_size=chunked_ctx.chunk_size,
            chunk_idx=chunk_idx,
            toks=toks)

        # Up-project the latent KV and split it into per-head k_nope / v.
        kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
            -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv_nope.split(
            [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        # Expand k_pe to k_nope's leading dims before concatenating.
        k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
                      dim=-1)

        attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
            prefill=prefill_metadata,
            chunk_idx=chunk_idx,
            q=q,
            k=k,
            v=v,
        )

        # The first chunk seeds the running result; later chunks are merged.
        if output is None:
            output, output_lse = attn_output, attn_softmax_lse
            continue

        merged_output = torch.empty_like(output)
        merged_lse = torch.empty_like(output_lse)
        merge_attn_states(
            output=merged_output,
            output_lse=merged_lse,
            prefix_output=output,
            prefix_lse=output_lse,
            suffix_output=attn_output,
            suffix_lse=attn_softmax_lse,
        )
        output, output_lse = merged_output, merged_lse

    return output, output_lse
def _forward_prefill(
self,
q: torch.Tensor,
......@@ -1162,6 +1417,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k_scale: torch.Tensor,
) -> torch.Tensor:
assert attn_metadata.prefill is not None
assert self.dcp_world_size is not None
has_context = attn_metadata.prefill.chunked_context is not None
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
......@@ -1181,7 +1437,14 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if has_context:
suffix_output, suffix_lse = output
context_output, context_lse = self._compute_prefill_context( \
if self.dcp_world_size > 1:
context_output, context_lse = \
self._context_parallel_compute_prefill_context(
q, kv_c_and_k_pe_cache, attn_metadata,
k_scale=None, dcp_world_size=self.dcp_world_size)
else:
context_output, context_lse = \
self._compute_prefill_context(
q, kv_c_and_k_pe_cache, attn_metadata, k_scale)
output = torch.empty_like(suffix_output)
......@@ -1202,12 +1465,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
@abstractmethod
def _forward_decode(
self,
ql_nope: torch.Tensor,
q_pe: torch.Tensor,
q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: M,
layer: AttentionLayer,
) -> torch.Tensor:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
raise NotImplementedError
def forward(
......@@ -1235,6 +1497,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
# same expert outputs.
return output.fill_(0)
if self.dcp_world_size is None:
self.dcp_world_size = get_dcp_group().world_size
fp8_attention = self.kv_cache_dtype.startswith("fp8")
num_actual_toks = attn_metadata.num_actual_tokens
......@@ -1313,7 +1578,26 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
layer._q_scale)
decode_q_pe = decode_q_pe.reshape(q_pe_shape)
output[:num_decode_tokens] = self._forward_decode(
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer)
decode_q = (decode_ql_nope, decode_q_pe)
if self.dcp_world_size > 1:
assert not fp8_attention, "DCP not support fp8 kvcache now."
# concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P)
decode_q = torch.cat(decode_q, dim=-1)
# all-gather decode_q along the head dim.
decode_q = get_dcp_group().all_gather(decode_q, dim=1)
# call decode attn
attn_out, lse = self._forward_decode(decode_q, kv_cache,
attn_metadata, layer)
# correct the dcp attn_out with the lse.
if self.dcp_world_size > 1:
assert lse is not None, (
"For a mla backend want to enable"
"DCP, it is mandatory that the corresponding decode attn"
"kernel return the softmax lse.")
attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group())
# v_up projection
output[:num_decode_tokens] = self._v_up_proj(attn_out)
return output_padded
......@@ -232,7 +232,7 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
self._workspace.get_buf(),
self.scale, self._num_kv_splits)
return self._v_up_proj(o)
return o
# TODO: Currently we leave it here only for backup in case something is
# wrong with the new SM100 CUTLASS MLA kernel
......@@ -265,21 +265,25 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
attn_metadata.decode.seq_lens,
attn_metadata.decode.block_table, self.scale)
return self._v_up_proj(o)
return o
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
layer: AttentionLayer,
) -> torch.Tensor:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
if type(q) is tuple:
q_nope, q_pe = q
else:
q_nope, q_pe = torch.split(
q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
if self._use_old_cutlass_mla:
# TODO: Remove the old cutlass MLA kernel after more extensive
# testing
return self._old_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache,
attn_metadata)
attn_metadata), None
return self._sm100_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache,
attn_metadata)
attn_metadata), None
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar, Optional
from typing import ClassVar, Optional, Union
import torch
......@@ -154,15 +154,20 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: FlashAttnMLAMetadata,
layer: AttentionLayer,
) -> torch.Tensor:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
if type(q) is tuple:
q_nope, q_pe = q
else:
q_nope, q_pe = torch.split(
q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError(
"FP8 FlashAttention MLA not yet supported")
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar, Optional
from typing import ClassVar, Optional, Union
import torch
......@@ -169,20 +169,20 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: FlashMLAMetadata,
layer: AttentionLayer,
) -> torch.Tensor:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode)
if type(q) is tuple:
q = torch.cat(q, dim=-1)
o, _ = flash_mla_with_kvcache(
q=q,
assert isinstance(q, torch.Tensor)
o, lse = flash_mla_with_kvcache(
q=q.unsqueeze(1), # Add seqlen dim of 1 (decode)
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
......@@ -196,4 +196,4 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
descale_k=layer._k_scale.reshape(1),
)
return self._v_up_proj(o)
return o, lse
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar, Optional
from typing import ClassVar, Optional, Union
import torch
......@@ -220,18 +220,19 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: AiterMLAMetadata,
layer: AttentionLayer,
) -> torch.Tensor:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
B = q_nope.shape[0]
if type(q) is tuple:
q = torch.cat(q, dim=-1)
q = torch.cat([q_nope, q_pe], dim=-1)
assert isinstance(q, torch.Tensor)
B = q.shape[0]
o = torch.zeros(B,
self.num_heads,
self.kv_lora_rank,
......@@ -249,4 +250,4 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
attn_metadata.decode.paged_kv_indices,
attn_metadata.decode.paged_kv_last_page_len)
return self._v_up_proj(o)
return o, None
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing import Optional, Union
import torch
......@@ -123,21 +123,22 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
layer: AttentionLayer,
) -> torch.Tensor:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError("FP8 Triton MLA not yet supported")
B = q_nope.shape[0]
if type(q) is tuple:
q = torch.cat(q, dim=-1)
q = torch.cat([q_nope, q_pe], dim=-1)
assert isinstance(q, torch.Tensor)
B = q.shape[0]
o = torch.zeros(B,
self.num_heads,
self.kv_lora_rank,
......@@ -171,4 +172,4 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
attn_metadata.decode.seq_lens, attn_logits,
num_kv_splits, self.scale, PAGE_SIZE)
return self._v_up_proj(o)
return o, None
......@@ -24,6 +24,7 @@ class KVCacheCoordinator(ABC):
use_eagle: bool,
enable_caching: bool,
enable_kv_cache_events: bool,
dcp_world_size: int,
):
self.kv_cache_config = kv_cache_config
self.max_model_len = max_model_len
......@@ -39,6 +40,7 @@ class KVCacheCoordinator(ABC):
kv_cache_spec=kv_cache_group.kv_cache_spec,
block_pool=self.block_pool,
kv_cache_group_id=i,
dcp_world_size=dcp_world_size,
) for i, kv_cache_group in enumerate(
self.kv_cache_config.kv_cache_groups))
......@@ -197,9 +199,14 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
"""
def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
use_eagle: bool, enable_kv_cache_events: bool):
super().__init__(kv_cache_config, max_model_len, use_eagle, False,
enable_kv_cache_events)
use_eagle: bool, enable_kv_cache_events: bool,
dcp_world_size: int):
super().__init__(kv_cache_config,
max_model_len,
use_eagle,
False,
enable_kv_cache_events,
dcp_world_size=dcp_world_size)
self.num_single_type_manager = len(self.single_type_managers)
def get_num_common_prefix_blocks(self, request_id: str,
......@@ -225,12 +232,19 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
use_eagle: bool, enable_caching: bool,
enable_kv_cache_events: bool):
super().__init__(kv_cache_config, max_model_len, use_eagle,
enable_caching, enable_kv_cache_events)
enable_kv_cache_events: bool, dcp_world_size: int):
super().__init__(kv_cache_config,
max_model_len,
use_eagle,
enable_caching,
enable_kv_cache_events,
dcp_world_size=dcp_world_size)
self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[
0].kv_cache_spec
self.block_size = self.kv_cache_spec.block_size
self.dcp_world_size = dcp_world_size
if dcp_world_size > 1:
self.block_size *= dcp_world_size
assert len(self.kv_cache_config.kv_cache_groups) == 1, (
"UnitaryKVCacheCoordinator assumes only one kv cache group")
......@@ -246,6 +260,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
block_pool=self.block_pool,
kv_cache_spec=self.kv_cache_spec,
use_eagle=self.use_eagle,
dcp_world_size=self.dcp_world_size,
)
return hit_blocks, len(hit_blocks[0]) * self.block_size
......@@ -261,9 +276,14 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
use_eagle: bool, enable_caching: bool,
enable_kv_cache_events: bool):
super().__init__(kv_cache_config, max_model_len, use_eagle,
enable_caching, enable_kv_cache_events)
enable_kv_cache_events: bool, dcp_world_size: int):
super().__init__(kv_cache_config,
max_model_len,
use_eagle,
enable_caching,
enable_kv_cache_events,
dcp_world_size=dcp_world_size)
assert dcp_world_size == 1, "DCP not support hybrid attn now."
self.verify_and_split_kv_cache_groups()
def verify_and_split_kv_cache_groups(self) -> None:
......@@ -394,17 +414,27 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
return hit_blocks, hit_length
def get_kv_cache_coordinator(
kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool,
def get_kv_cache_coordinator(kv_cache_config: KVCacheConfig,
max_model_len: int, use_eagle: bool,
enable_caching: bool,
enable_kv_cache_events: bool) -> KVCacheCoordinator:
enable_kv_cache_events: bool,
dcp_world_size: int) -> KVCacheCoordinator:
if not enable_caching:
return KVCacheCoordinatorNoPrefixCache(kv_cache_config, max_model_len,
return KVCacheCoordinatorNoPrefixCache(kv_cache_config,
max_model_len,
use_eagle,
enable_kv_cache_events)
enable_kv_cache_events,
dcp_world_size=dcp_world_size)
if len(kv_cache_config.kv_cache_groups) == 1:
return UnitaryKVCacheCoordinator(kv_cache_config, max_model_len,
use_eagle, enable_caching,
enable_kv_cache_events)
return HybridKVCacheCoordinator(kv_cache_config, max_model_len, use_eagle,
enable_caching, enable_kv_cache_events)
return UnitaryKVCacheCoordinator(kv_cache_config,
max_model_len,
use_eagle,
enable_caching,
enable_kv_cache_events,
dcp_world_size=dcp_world_size)
return HybridKVCacheCoordinator(kv_cache_config,
max_model_len,
use_eagle,
enable_caching,
enable_kv_cache_events,
dcp_world_size=dcp_world_size)
......@@ -91,6 +91,7 @@ class KVCacheManager:
use_eagle: bool = False,
log_stats: bool = False,
enable_kv_cache_events: bool = False,
dcp_world_size: int = 1,
) -> None:
self.max_model_len = max_model_len
......@@ -109,12 +110,20 @@ class KVCacheManager:
self.block_size = kv_cache_config.kv_cache_groups[
0].kv_cache_spec.block_size
if dcp_world_size > 1:
assert len(kv_cache_config.kv_cache_groups) == 1
# Note(hc): need revisit. When both DCP and any future
# PCP are enabled, the block_size may need to be scaled
# by a factor of dcp_size × pcp_size?
self.block_size *= dcp_world_size
self.coordinator = get_kv_cache_coordinator(
kv_cache_config=kv_cache_config,
max_model_len=self.max_model_len,
use_eagle=self.use_eagle,
enable_caching=self.enable_caching,
enable_kv_cache_events=enable_kv_cache_events,
dcp_world_size=dcp_world_size,
)
self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
self.block_pool = self.coordinator.block_pool
......
......@@ -846,6 +846,12 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
)
num_tokens = num_blocks * vllm_config.cache_config.block_size
if vllm_config.parallel_config.decode_context_parallel_size > 1:
num_tokens *= vllm_config.parallel_config.decode_context_parallel_size
logger.info(
"Multiplying the GPU KV cache size by the dcp_world_size %d.",
vllm_config.parallel_config.decode_context_parallel_size)
num_tokens_str = f"{num_tokens:,}"
logger.info("GPU KV cache size: %s tokens", num_tokens_str)
max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment