Unverified Commit ac201a0e authored by yzds's avatar yzds Committed by GitHub
Browse files

[Feature] Support Decode Context Parallel (DCP) for MLA (#23734)


Signed-off-by: default avatarhongchao <hongchao@msh.team>
Signed-off-by: default avataryoukaichao <youkaichao@gmail.com>
Co-authored-by: default avatarhongchao <hongchao@msh.team>
Co-authored-by: default avataryoukaichao <youkaichao@gmail.com>
parent 3c529fc9
......@@ -837,7 +837,7 @@ steps:
- pytest -v -s models/test_oot_registration.py # it needs a clean process
- pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins
- label: Pipeline Parallelism Test # 45min
- label: Pipeline + Context Parallelism Test # 45min
timeout_in_minutes: 60
mirror_hardwares: [amdexperimental]
working_dir: "/vllm-workspace/tests"
......@@ -851,6 +851,7 @@ steps:
commands:
- pytest -v -s distributed/test_pp_cudagraph.py
- pytest -v -s distributed/test_pipeline_parallel.py
# - pytest -v -s distributed/test_context_parallel.py # TODO: enable it on Hopper runners or add triton MLA support
- label: LoRA TP Test (Distributed) # 17 min
timeout_in_minutes: 30
......
......@@ -36,13 +36,6 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
const std::string& kv_cache_dtype,
torch::Tensor& scale);
void cp_fused_concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
torch::Tensor& cp_local_token_select_indices,
torch::Tensor& kv_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
torch::Tensor& scale);
// Just for unittest
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const double scale, const std::string& kv_cache_dtype);
......
......@@ -396,51 +396,6 @@ __global__ void concat_and_cache_mla_kernel(
copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
}
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void cp_fused_concat_and_cache_mla_kernel(
const scalar_t* __restrict__ kv_c, // [num_full_tokens, kv_lora_rank]
const scalar_t* __restrict__ k_pe, // [num_full_tokens, pe_dim]
const int64_t* __restrict__ cp_local_token_select_indices, // [num_tokens]
cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank
// + pe_dim)]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int block_stride, //
const int entry_stride, //
const int kv_c_stride, //
const int k_pe_stride, //
const int kv_lora_rank, //
const int pe_dim, //
const int block_size, //
const float* scale //
) {
const int64_t token_idx = cp_local_token_select_indices[blockIdx.x];
const int64_t slot_idx = slot_mapping[blockIdx.x];
// NOTE: slot_idx can be -1 if the token is padded
if (slot_idx < 0) {
return;
}
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
auto copy = [&](const scalar_t* __restrict__ src, cache_t* __restrict__ dst,
int src_stride, int dst_stride, int size, int offset) {
for (int i = threadIdx.x; i < size; i += blockDim.x) {
const int64_t src_idx = token_idx * src_stride + i;
const int64_t dst_idx =
block_idx * block_stride + block_offset * entry_stride + i + offset;
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
dst[dst_idx] = src[src_idx];
} else {
dst[dst_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(src[src_idx], *scale);
}
}
};
copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0);
copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
}
} // namespace vllm
// KV_T is the data type of key and value tensors.
......@@ -554,20 +509,6 @@ void reshape_and_cache_flash(
kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
reinterpret_cast<const float*>(scale.data_ptr()));
// KV_T is the data type of key and value tensors.
// CACHE_T is the stored data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_CP_FUSED_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \
vllm::cp_fused_concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(kv_c.data_ptr()), \
reinterpret_cast<KV_T*>(k_pe.data_ptr()), \
cp_local_token_select_indices.data_ptr<int64_t>(), \
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, entry_stride, \
kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
reinterpret_cast<const float*>(scale.data_ptr()));
void concat_and_cache_mla(
torch::Tensor& kv_c, // [num_tokens, kv_lora_rank]
torch::Tensor& k_pe, // [num_tokens, pe_dim]
......@@ -606,50 +547,6 @@ void concat_and_cache_mla(
CALL_CONCAT_AND_CACHE_MLA);
}
// Note(hc): cp_fused_concat_and_cache_mla fuses the following three kernel
// calls into one:
// k_c_normed.index_select(0, cp_local_token_select_indices) + \
// k_pe.squeeze(1).index_select(0, cp_local_token_select_indices) + \
// concat_and_cache_mla.
void cp_fused_concat_and_cache_mla(
torch::Tensor& kv_c, // [num_total_tokens, kv_lora_rank]
torch::Tensor& k_pe, // [num_total_tokens, pe_dim]
torch::Tensor& cp_local_token_select_indices, // [num_tokens]
torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank +
// pe_dim)]
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
const std::string& kv_cache_dtype, torch::Tensor& scale) {
// NOTE(woosuk): In vLLM V1, key.size(0) can be different from
// slot_mapping.size(0) because of padding for CUDA graphs.
// In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
// both include padding.
// In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0)
// since key includes padding for CUDA graphs, while slot_mapping does not.
// In this case, slot_mapping.size(0) represents the actual number of tokens
// before padding.
// For compatibility with both cases, we use slot_mapping.size(0) as the
// number of tokens.
int num_tokens = slot_mapping.size(0);
int kv_lora_rank = kv_c.size(1);
int pe_dim = k_pe.size(1);
int block_size = kv_cache.size(1);
TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
int kv_c_stride = kv_c.stride(0);
int k_pe_stride = k_pe.stride(0);
int block_stride = kv_cache.stride(0);
int entry_stride = kv_cache.stride(1);
dim3 grid(num_tokens);
dim3 block(std::min(kv_lora_rank, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
CALL_CP_FUSED_CONCAT_AND_CACHE_MLA);
}
namespace vllm {
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
......
......@@ -693,16 +693,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
" Tensor scale) -> ()");
cache_ops.impl("concat_and_cache_mla", torch::kCUDA, &concat_and_cache_mla);
cache_ops.def(
"cp_fused_concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
" Tensor cp_local_token_select_indices,"
" Tensor! kv_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" Tensor scale) -> ()");
cache_ops.impl("cp_fused_concat_and_cache_mla", torch::kCUDA,
&cp_fused_concat_and_cache_mla);
// Convert the key and value cache to fp8 data type.
cache_ops.def(
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
WARNING: This test runs in both single-node (4 GPUs) and multi-node
(2 node with 2 GPUs each) modes. If the test only uses 2 GPUs, it is
important to set the distributed backend to "mp" to avoid Ray scheduling
all workers in a node other than the head node, which can cause the test
to fail.
"""
import json
import os
from dataclasses import dataclass
from typing import Literal, NamedTuple, Optional
import pytest
from vllm.config import RunnerOption
from vllm.logger import init_logger
from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import compare_two_settings, create_new_process_for_each_test
logger = init_logger("test_context_parallel")
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
class ParallelSetup(NamedTuple):
tp_size: int
pp_size: int
dcp_size: int
eager_mode: bool
chunked_prefill: bool
class CPTestOptions(NamedTuple):
multi_node_only: bool
load_format: Optional[str] = None
@dataclass
class CPTestSettings:
parallel_setups: list[ParallelSetup]
# NOTE: the length of distributed_backends and
# vllm_major_versions should be the same, and they
# are first zipped together to iterate over all
# test settings.
distributed_backends: list[str]
# vllm major version: "0" for V0, "1" for V1
vllm_major_versions: list[str]
runner: RunnerOption
test_options: CPTestOptions
def __post_init__(self):
if len(self.distributed_backends) != len(self.vllm_major_versions):
raise ValueError(
f"Length mismatch: distributed_backends "
f"({len(self.distributed_backends)}) != "
f"vllm_major_versions ({len(self.vllm_major_versions)})")
@staticmethod
def detailed(
*,
tp_base: int = 4,
pp_base: int = 1,
dcp_base: int = 1,
multi_node_only: bool = False,
runner: RunnerOption = "auto",
load_format: Optional[str] = None,
):
parallel_setups = []
for eager_mode_val in [False]:
for pp_multiplier in [1]:
for dcp_multiplier in [2, 4]:
for chunked_prefill_val in [True]:
parallel_setups.append(
ParallelSetup(tp_size=tp_base,
pp_size=pp_multiplier * pp_base,
dcp_size=dcp_multiplier * dcp_base,
eager_mode=eager_mode_val,
chunked_prefill=chunked_prefill_val))
return CPTestSettings(
parallel_setups=parallel_setups,
distributed_backends=["mp"],
vllm_major_versions=["1"],
runner=runner,
test_options=CPTestOptions(multi_node_only=multi_node_only,
load_format=load_format),
)
def iter_params(self, model_id: str):
opts = self.test_options
for parallel_setup in self.parallel_setups:
for backend, vllm_major_version in zip(self.distributed_backends,
self.vllm_major_versions):
yield (model_id, parallel_setup, backend, vllm_major_version,
self.runner, opts)
def _compare_cp_with_tp(
model_id: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
vllm_major_version: str,
runner: RunnerOption,
test_options: CPTestOptions,
num_gpus_available: int,
*,
method: Literal["generate"],
is_multimodal: bool,
):
(
tp_size,
pp_size,
dcp_size,
eager_mode,
chunked_prefill,
) = parallel_setup
multi_node_only, load_format = test_options
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
model_info.check_transformers_version(on_fail="skip")
trust_remote_code = model_info.trust_remote_code
tokenizer_mode = model_info.tokenizer_mode
hf_overrides = model_info.hf_overrides
if load_format == "dummy":
# Avoid OOM
text_overrides = {
"num_hidden_layers": 4,
"hidden_size": 512,
"intermediate_size": 800,
"num_attention_heads": 4,
"num_key_value_heads": 1,
}
if is_multimodal:
hf_overrides.update({"text_config": text_overrides})
else:
hf_overrides.update(text_overrides)
else:
model_info.check_available_online(on_fail="skip")
if num_gpus_available < tp_size * pp_size:
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
if VLLM_MULTI_NODE and distributed_backend == "mp":
pytest.skip("Skipping multi-node pipeline parallel test for "
"multiprocessing distributed backend")
if multi_node_only and not VLLM_MULTI_NODE:
pytest.skip("Not in multi-node setting")
common_args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"--max-model-len",
"2048",
"--max-num-seqs",
"8",
]
if chunked_prefill:
common_args.append("--enable-chunked-prefill")
if eager_mode:
common_args.append("--enforce-eager")
if runner != "auto":
common_args.extend(["--runner", runner])
if trust_remote_code:
common_args.append("--trust-remote-code")
if tokenizer_mode:
common_args.extend(["--tokenizer-mode", tokenizer_mode])
if load_format:
common_args.extend(["--load-format", load_format])
if hf_overrides:
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
cp_env = tp_env = {
"VLLM_USE_V1":
vllm_major_version, # Note(hc): DCP only support V1 engine only
}
cp_args = [
*common_args,
"--tensor-parallel-size",
str(tp_size),
"--pipeline-parallel-size",
str(pp_size),
"--decode-context-parallel-size",
str(dcp_size),
"--distributed-executor-backend",
distributed_backend,
]
tp_args = [
*common_args,
"--tensor-parallel-size",
str(tp_size),
"--pipeline-parallel-size",
str(pp_size),
"--distributed-executor-backend",
distributed_backend,
]
try:
compare_two_settings(model_id,
cp_args,
tp_args,
cp_env,
tp_env,
method=method,
max_wait_seconds=720)
except Exception:
testing_ray_compiled_graph = cp_env is not None
if testing_ray_compiled_graph and vllm_major_version == "0":
# Ray Compiled Graph tests are flaky for V0,
# so we don't want to fail the test
logger.exception("Ray Compiled Graph tests failed")
else:
raise
CP_TEXT_GENERATION_MODELS = {
# [MLA attention only]
"deepseek-ai/DeepSeek-V2-Lite-Chat": CPTestSettings.detailed(),
}
CP_TEST_MODELS = [
# TODO support other models
# [LANGUAGE GENERATION]
"deepseek-ai/DeepSeek-V2-Lite-Chat",
]
@pytest.mark.parametrize(
("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
"runner", "test_options"),
[
params for model_id, settings in CP_TEXT_GENERATION_MODELS.items()
for params in settings.iter_params(model_id)
if model_id in CP_TEST_MODELS
],
)
@create_new_process_for_each_test()
def test_cp_generation(
model_id: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
vllm_major_version: str,
runner: RunnerOption,
test_options: CPTestOptions,
num_gpus_available,
):
_compare_cp_with_tp(model_id,
parallel_setup,
distributed_backend,
vllm_major_version,
runner,
test_options,
num_gpus_available,
method="generate",
is_multimodal=False)
......@@ -1625,20 +1625,6 @@ def concat_and_cache_mla(
scale)
def cp_fused_concat_and_cache_mla(
kv_c: torch.Tensor,
k_pe: torch.Tensor,
cp_local_token_select_indices: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
scale: torch.Tensor,
) -> None:
torch.ops._C_cache_ops.cp_fused_concat_and_cache_mla(
kv_c, k_pe, cp_local_token_select_indices, kv_cache, slot_mapping,
kv_cache_dtype, scale)
def copy_blocks(key_caches: list[torch.Tensor],
value_caches: list[torch.Tensor],
block_mapping: torch.Tensor) -> None:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.distributed.parallel_state import GroupCoordinator
from vllm.triton_utils import tl, triton
@triton.jit
def _correct_attn_cp_out_kernel(outputs_ptr, new_output_ptr, lses_ptr,
vlse_ptr, outputs_stride_B, outputs_stride_H,
outputs_stride_D, lses_stride_N, lses_stride_B,
lses_stride_H, lse_idx, HEAD_DIM: tl.constexpr,
N_ROUNDED: tl.constexpr):
"""
Apply the all-gathered lses to correct each local rank's attention
output. we still need perform a cross-rank reduction to obtain the
final attention output.
Args:
output: [ B, H, D ]
lses : [ N, B, H ]
cp, batch, q_heads, v_head_dim
Return:
output: [ B, H, D ]
lse : [ B, H ]
"""
batch_idx = tl.program_id(axis=0).to(tl.int64)
head_idx = tl.program_id(axis=1).to(tl.int64)
d_offsets = tl.arange(0, HEAD_DIM)
num_n_offsets = tl.arange(0, N_ROUNDED)
# shape = [N]
lse_offsets = num_n_offsets * lses_stride_N + batch_idx * \
lses_stride_B + head_idx * lses_stride_H
# calc final lse
lse = tl.load(lses_ptr + lse_offsets)
lse = tl.where((lse != lse) | (lse == float('inf')), -float('inf'), lse)
lse_max = tl.max(lse, axis=0)
lse -= lse_max
lse_exp = tl.exp(lse)
lse_acc = tl.sum(lse_exp, axis=0)
lse = tl.log(lse_acc)
lse += lse_max
lse_offsets = batch_idx * lses_stride_B + head_idx * lses_stride_H
tl.store(vlse_ptr + lse_offsets, lse)
# shape = [D]
output_offsets = batch_idx * outputs_stride_B + \
head_idx * outputs_stride_H + \
d_offsets * outputs_stride_D
# correct output
lse_offset = lse_idx * lses_stride_N + batch_idx * \
lses_stride_B + head_idx * lses_stride_H
lse_tmp = tl.load(lses_ptr + lse_offset)
lse_finally = lse_tmp - lse
lse_finally = tl.where(
(lse_finally != lse_finally) | (lse_finally == float('inf')),
-float('inf'), lse_finally)
factor = tl.exp(lse_finally)
output = tl.load(outputs_ptr + output_offsets)
output = output * factor
tl.store(new_output_ptr + output_offsets, output)
class CPTritonContext:
""" The CPTritonContext is used to avoid recompilation of the Triton JIT.
"""
def __init__(self):
self.inner_kernel = None
def call_kernel(self, kernel, grid, *regular_args, **const_args):
if self.inner_kernel is None:
self.inner_kernel = kernel[grid](*regular_args, **const_args)
else:
self.inner_kernel[grid](*regular_args)
def correct_attn_out(out: torch.Tensor, lses: torch.Tensor, cp_rank: int,
ctx: CPTritonContext):
"""
Apply the all-gathered lses to correct each local rank's attention
output. we still need perform a cross-rank reduction to obtain the
final attention output.
Args:
output: [ B, H, D ]
lses : [ N, B, H ]
Return:
output: [ B, H, D ]
lse : [ B, H ]
"""
if ctx is None:
ctx = CPTritonContext()
lse = torch.empty_like(lses[0])
grid = (out.shape[0], out.shape[1], 1)
regular_args = (out, out, lses, lse, *out.stride(), *lses.stride(),
cp_rank)
const_args = {
"HEAD_DIM": out.shape[-1],
"N_ROUNDED": lses.shape[0],
}
ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args,
**const_args)
return out, lse
def cp_lse_ag_out_rs(cp_attn_out: torch.Tensor,
cp_attn_lse: torch.Tensor,
cp_group: GroupCoordinator,
ctx: CPTritonContext = None):
"""
cp_attn_out: [ B, H, D ]
cp_attn_lse: [ B, H ]
"""
if cp_group.world_size == 1:
return cp_attn_out
if ctx is None:
ctx = CPTritonContext()
lses = torch.empty((cp_group.world_size, ) + cp_attn_lse.shape,
dtype=cp_attn_lse.dtype,
device=cp_attn_lse.device)
cp_attn_lse = cp_attn_lse.contiguous()
lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
assert out.is_contiguous()
out = cp_group.reduce_scatter(out, dim=1)
return out
......@@ -105,7 +105,9 @@ def flash_mla_with_kvcache(
descale_q,
descale_k,
)
return out, softmax_lse
# Note(hc): need revisit when we support DCP with decode query_len > 1.
return out.squeeze(1), softmax_lse.squeeze(-1)
#
......
......@@ -170,6 +170,11 @@ class ParallelConfig:
Set to be private as it's not intended to be configured by users.
"""
decode_context_parallel_size: int = 1
"""Number of decode context parallel groups, because the world size does
not change by dcp, it simply reuse the GPUs of TP group, and tp_size
needs to be divisible by dcp_size."""
@property
def world_size_across_dp(self) -> int:
"""world_size_across_dp is TPxPPxDP, it is the size of the world
......
......@@ -904,6 +904,18 @@ def get_tensor_model_parallel_group():
return get_tp_group()
_DCP: Optional[GroupCoordinator] = None
def get_dcp_group() -> GroupCoordinator:
assert _DCP is not None, (
"decode context model parallel group is not initialized")
return _DCP
# kept for backward compatibility
get_context_model_parallel_group = get_dcp_group
_PP: Optional[GroupCoordinator] = None
_DP: Optional[GroupCoordinator] = None
......@@ -1034,6 +1046,7 @@ def init_distributed_environment(
def initialize_model_parallel(
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
decode_context_model_parallel_size: Optional[int] = 1,
backend: Optional[str] = None,
) -> None:
"""
......@@ -1098,6 +1111,23 @@ def initialize_model_parallel(
use_message_queue_broadcaster=True,
group_name="tp")
# Build the DCP model-parallel groups.
global _DCP
assert _DCP is None, (
"decode context model parallel group is already initialized")
# Note(hc): In the current implementation of decode context parallel,
# dcp_size must not exceed tp_size, because the world size does not
# change by DCP, it simply reuse the GPUs of TP group, and split one
# TP group into tp_size//dcp_size DCP groups.
group_ranks = all_ranks.reshape(
-1, decode_context_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
_DCP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
use_message_queue_broadcaster=True,
group_name="dcp")
# Build the pipeline model-parallel groups.
global _PP
assert _PP is None, (
......@@ -1141,6 +1171,7 @@ def initialize_model_parallel(
def ensure_model_parallel_initialized(
tensor_model_parallel_size: int,
pipeline_model_parallel_size: int,
decode_context_model_parallel_size: Optional[int] = 1,
backend: Optional[str] = None,
) -> None:
"""Helper to initialize model parallel groups if they are not initialized,
......@@ -1151,7 +1182,8 @@ def ensure_model_parallel_initialized(
get_world_group().device_group)
if not model_parallel_is_initialized():
initialize_model_parallel(tensor_model_parallel_size,
pipeline_model_parallel_size, backend)
pipeline_model_parallel_size,
decode_context_model_parallel_size, backend)
return
assert (
......@@ -1226,6 +1258,16 @@ def get_tensor_model_parallel_rank():
return get_tp_group().rank_in_group
def get_decode_context_model_parallel_world_size():
"""Return world size for the decode context model parallel group."""
return get_dcp_group().world_size
def get_decode_context_model_parallel_rank():
"""Return my rank for the decode context model parallel group."""
return get_dcp_group().rank_in_group
def get_node_count() -> int:
"""Return the total number of nodes in the distributed environment. """
assert _NODE_COUNT is not None, (
......@@ -1246,6 +1288,11 @@ def destroy_model_parallel():
_PP.destroy()
_PP = None
global _DCP
if _DCP:
_DCP.destroy()
_DCP = None
global _DP
if _DP:
_DP.destroy()
......
......@@ -306,6 +306,8 @@ class EngineArgs:
# number of P/D disaggregation (or other disaggregation) workers
pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
decode_context_parallel_size: int = \
ParallelConfig.decode_context_parallel_size
data_parallel_size: int = ParallelConfig.data_parallel_size
data_parallel_rank: Optional[int] = None
data_parallel_start_rank: Optional[int] = None
......@@ -636,6 +638,9 @@ class EngineArgs:
**parallel_kwargs["pipeline_parallel_size"])
parallel_group.add_argument("--tensor-parallel-size", "-tp",
**parallel_kwargs["tensor_parallel_size"])
parallel_group.add_argument(
"--decode-context-parallel-size", "-dcp",
**parallel_kwargs["decode_context_parallel_size"])
parallel_group.add_argument("--data-parallel-size", "-dp",
**parallel_kwargs["data_parallel_size"])
parallel_group.add_argument(
......@@ -1156,6 +1161,17 @@ class EngineArgs:
# global layers in interleaved sliding window models.
sliding_window = model_config.get_sliding_window()
# Note(hc): In the current implementation of decode context
# parallel(DCP), tp_size needs to be divisible by dcp_size,
# because the world size does not change by dcp, it simply
# reuse the GPUs of TP group, and split one TP group into
# tp_size//dcp_size DCP groups.
assert self.tensor_parallel_size % self.decode_context_parallel_size \
== 0, (
f"tp_size={self.tensor_parallel_size} must be divisible by"
f"dcp_size={self.decode_context_parallel_size}."
)
cache_config = CacheConfig(
block_size=self.block_size,
gpu_memory_utilization=self.gpu_memory_utilization,
......@@ -1306,6 +1322,7 @@ class EngineArgs:
distributed_executor_backend=self.distributed_executor_backend,
worker_cls=self.worker_cls,
worker_extension_cls=self.worker_extension_cls,
decode_context_parallel_size=self.decode_context_parallel_size,
)
speculative_config = self.create_speculative_config(
......
This diff is collapsed.
......@@ -232,7 +232,7 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
self._workspace.get_buf(),
self.scale, self._num_kv_splits)
return self._v_up_proj(o)
return o
# TODO: Currently we leave it here only for backup in case something is
# wrong with the new SM100 CUTLASS MLA kernel
......@@ -265,21 +265,25 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
attn_metadata.decode.seq_lens,
attn_metadata.decode.block_table, self.scale)
return self._v_up_proj(o)
return o
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
layer: AttentionLayer,
) -> torch.Tensor:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
if type(q) is tuple:
q_nope, q_pe = q
else:
q_nope, q_pe = torch.split(
q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
if self._use_old_cutlass_mla:
# TODO: Remove the old cutlass MLA kernel after more extensive
# testing
return self._old_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache,
attn_metadata)
attn_metadata), None
return self._sm100_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache,
attn_metadata)
attn_metadata), None
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar, Optional
from typing import ClassVar, Optional, Union
import torch
......@@ -154,15 +154,20 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: FlashAttnMLAMetadata,
layer: AttentionLayer,
) -> torch.Tensor:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
if type(q) is tuple:
q_nope, q_pe = q
else:
q_nope, q_pe = torch.split(
q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError(
"FP8 FlashAttention MLA not yet supported")
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar, Optional
from typing import ClassVar, Optional, Union
import torch
......@@ -169,20 +169,20 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: FlashMLAMetadata,
layer: AttentionLayer,
) -> torch.Tensor:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode)
if type(q) is tuple:
q = torch.cat(q, dim=-1)
o, _ = flash_mla_with_kvcache(
q=q,
assert isinstance(q, torch.Tensor)
o, lse = flash_mla_with_kvcache(
q=q.unsqueeze(1), # Add seqlen dim of 1 (decode)
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
......@@ -196,4 +196,4 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
descale_k=layer._k_scale.reshape(1),
)
return self._v_up_proj(o)
return o, lse
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar, Optional
from typing import ClassVar, Optional, Union
import torch
......@@ -220,18 +220,19 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: AiterMLAMetadata,
layer: AttentionLayer,
) -> torch.Tensor:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
B = q_nope.shape[0]
if type(q) is tuple:
q = torch.cat(q, dim=-1)
q = torch.cat([q_nope, q_pe], dim=-1)
assert isinstance(q, torch.Tensor)
B = q.shape[0]
o = torch.zeros(B,
self.num_heads,
self.kv_lora_rank,
......@@ -249,4 +250,4 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
attn_metadata.decode.paged_kv_indices,
attn_metadata.decode.paged_kv_last_page_len)
return self._v_up_proj(o)
return o, None
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing import Optional, Union
import torch
......@@ -123,21 +123,22 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
layer: AttentionLayer,
) -> torch.Tensor:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError("FP8 Triton MLA not yet supported")
B = q_nope.shape[0]
if type(q) is tuple:
q = torch.cat(q, dim=-1)
q = torch.cat([q_nope, q_pe], dim=-1)
assert isinstance(q, torch.Tensor)
B = q.shape[0]
o = torch.zeros(B,
self.num_heads,
self.kv_lora_rank,
......@@ -171,4 +172,4 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
attn_metadata.decode.seq_lens, attn_logits,
num_kv_splits, self.scale, PAGE_SIZE)
return self._v_up_proj(o)
return o, None
......@@ -24,6 +24,7 @@ class KVCacheCoordinator(ABC):
use_eagle: bool,
enable_caching: bool,
enable_kv_cache_events: bool,
dcp_world_size: int,
):
self.kv_cache_config = kv_cache_config
self.max_model_len = max_model_len
......@@ -39,6 +40,7 @@ class KVCacheCoordinator(ABC):
kv_cache_spec=kv_cache_group.kv_cache_spec,
block_pool=self.block_pool,
kv_cache_group_id=i,
dcp_world_size=dcp_world_size,
) for i, kv_cache_group in enumerate(
self.kv_cache_config.kv_cache_groups))
......@@ -197,9 +199,14 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
"""
def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
use_eagle: bool, enable_kv_cache_events: bool):
super().__init__(kv_cache_config, max_model_len, use_eagle, False,
enable_kv_cache_events)
use_eagle: bool, enable_kv_cache_events: bool,
dcp_world_size: int):
super().__init__(kv_cache_config,
max_model_len,
use_eagle,
False,
enable_kv_cache_events,
dcp_world_size=dcp_world_size)
self.num_single_type_manager = len(self.single_type_managers)
def get_num_common_prefix_blocks(self, request_id: str,
......@@ -225,12 +232,19 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
use_eagle: bool, enable_caching: bool,
enable_kv_cache_events: bool):
super().__init__(kv_cache_config, max_model_len, use_eagle,
enable_caching, enable_kv_cache_events)
enable_kv_cache_events: bool, dcp_world_size: int):
super().__init__(kv_cache_config,
max_model_len,
use_eagle,
enable_caching,
enable_kv_cache_events,
dcp_world_size=dcp_world_size)
self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[
0].kv_cache_spec
self.block_size = self.kv_cache_spec.block_size
self.dcp_world_size = dcp_world_size
if dcp_world_size > 1:
self.block_size *= dcp_world_size
assert len(self.kv_cache_config.kv_cache_groups) == 1, (
"UnitaryKVCacheCoordinator assumes only one kv cache group")
......@@ -246,6 +260,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
block_pool=self.block_pool,
kv_cache_spec=self.kv_cache_spec,
use_eagle=self.use_eagle,
dcp_world_size=self.dcp_world_size,
)
return hit_blocks, len(hit_blocks[0]) * self.block_size
......@@ -261,9 +276,14 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
use_eagle: bool, enable_caching: bool,
enable_kv_cache_events: bool):
super().__init__(kv_cache_config, max_model_len, use_eagle,
enable_caching, enable_kv_cache_events)
enable_kv_cache_events: bool, dcp_world_size: int):
super().__init__(kv_cache_config,
max_model_len,
use_eagle,
enable_caching,
enable_kv_cache_events,
dcp_world_size=dcp_world_size)
assert dcp_world_size == 1, "DCP not support hybrid attn now."
self.verify_and_split_kv_cache_groups()
def verify_and_split_kv_cache_groups(self) -> None:
......@@ -394,17 +414,27 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
return hit_blocks, hit_length
def get_kv_cache_coordinator(
kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool,
enable_caching: bool,
enable_kv_cache_events: bool) -> KVCacheCoordinator:
def get_kv_cache_coordinator(kv_cache_config: KVCacheConfig,
max_model_len: int, use_eagle: bool,
enable_caching: bool,
enable_kv_cache_events: bool,
dcp_world_size: int) -> KVCacheCoordinator:
if not enable_caching:
return KVCacheCoordinatorNoPrefixCache(kv_cache_config, max_model_len,
return KVCacheCoordinatorNoPrefixCache(kv_cache_config,
max_model_len,
use_eagle,
enable_kv_cache_events)
enable_kv_cache_events,
dcp_world_size=dcp_world_size)
if len(kv_cache_config.kv_cache_groups) == 1:
return UnitaryKVCacheCoordinator(kv_cache_config, max_model_len,
use_eagle, enable_caching,
enable_kv_cache_events)
return HybridKVCacheCoordinator(kv_cache_config, max_model_len, use_eagle,
enable_caching, enable_kv_cache_events)
return UnitaryKVCacheCoordinator(kv_cache_config,
max_model_len,
use_eagle,
enable_caching,
enable_kv_cache_events,
dcp_world_size=dcp_world_size)
return HybridKVCacheCoordinator(kv_cache_config,
max_model_len,
use_eagle,
enable_caching,
enable_kv_cache_events,
dcp_world_size=dcp_world_size)
......@@ -91,6 +91,7 @@ class KVCacheManager:
use_eagle: bool = False,
log_stats: bool = False,
enable_kv_cache_events: bool = False,
dcp_world_size: int = 1,
) -> None:
self.max_model_len = max_model_len
......@@ -109,12 +110,20 @@ class KVCacheManager:
self.block_size = kv_cache_config.kv_cache_groups[
0].kv_cache_spec.block_size
if dcp_world_size > 1:
assert len(kv_cache_config.kv_cache_groups) == 1
# Note(hc): need revisit. When both DCP and any future
# PCP are enabled, the block_size may need to be scaled
# by a factor of dcp_size × pcp_size?
self.block_size *= dcp_world_size
self.coordinator = get_kv_cache_coordinator(
kv_cache_config=kv_cache_config,
max_model_len=self.max_model_len,
use_eagle=self.use_eagle,
enable_caching=self.enable_caching,
enable_kv_cache_events=enable_kv_cache_events,
dcp_world_size=dcp_world_size,
)
self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
self.block_pool = self.coordinator.block_pool
......
......@@ -846,6 +846,12 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
)
num_tokens = num_blocks * vllm_config.cache_config.block_size
if vllm_config.parallel_config.decode_context_parallel_size > 1:
num_tokens *= vllm_config.parallel_config.decode_context_parallel_size
logger.info(
"Multiplying the GPU KV cache size by the dcp_world_size %d.",
vllm_config.parallel_config.decode_context_parallel_size)
num_tokens_str = f"{num_tokens:,}"
logger.info("GPU KV cache size: %s tokens", num_tokens_str)
max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment