"vllm/vscode:/vscode.git/clone" did not exist on "9912b8ccb861593d76216afa583ac593faf5a309"
Unverified Commit 4516d44b authored by Jingchun Gao's avatar Jingchun Gao Committed by GitHub
Browse files

[DCP] Support Decode Context Parallel (DCP) for GQA with Flashinfer (#25438)


Signed-off-by: default avatargaojc <1055866782@qq.com>
Signed-off-by: default avatarJingchun Gao <gaojingchun1@huawei.com>
Signed-off-by: default avatarJingchun Gao <63247409+gjc0824@users.noreply.github.com>
Signed-off-by: default avatarQiuChunshuo <qiuchunshuo@huawei.com>
Co-authored-by: default avatargaojingchun (A) <g00955623@china.huawei.com>
Co-authored-by: default avatarJingchun Gao <gaojingchun1@huawei.com>
Co-authored-by: default avatarQiuChunshuo <qiuchunshuo@huawei.com>
parent 41b92f7d
......@@ -39,6 +39,7 @@ class ParallelSetup(NamedTuple):
class CPTestOptions(NamedTuple):
multi_node_only: bool
load_format: str | None = None
attn_backend: str | None = None
@dataclass
......@@ -58,6 +59,7 @@ class CPTestSettings:
multi_node_only: bool = False,
runner: RunnerOption = "auto",
load_format: str | None = None,
attn_backend: str | None = None,
):
parallel_setups = []
for eager_mode_val in [False]:
......@@ -79,7 +81,9 @@ class CPTestSettings:
distributed_backends=["mp"],
runner=runner,
test_options=CPTestOptions(
multi_node_only=multi_node_only, load_format=load_format
multi_node_only=multi_node_only,
load_format=load_format,
attn_backend=attn_backend,
),
)
......@@ -117,7 +121,7 @@ def _compare_cp_with_tp(
chunked_prefill,
) = parallel_setup
multi_node_only, load_format = test_options
multi_node_only, load_format, attn_backend = test_options
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
model_info.check_transformers_version(on_fail="skip")
......@@ -177,6 +181,13 @@ def _compare_cp_with_tp(
if hf_overrides:
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
if not attn_backend:
cp_env = tp_env = {}
else:
cp_env = tp_env = {
"VLLM_ATTENTION_BACKEND": attn_backend,
}
cp_args = [
*common_args,
"--tensor-parallel-size",
......@@ -205,6 +216,8 @@ def _compare_cp_with_tp(
model_id,
cp_args,
tp_args,
cp_env,
tp_env,
method=method,
max_wait_seconds=720,
)
......
......@@ -1183,6 +1183,14 @@ class ModelConfig:
f"but got {decode_context_parallel_size}"
)
num_q_per_kv = total_num_attention_heads // total_num_kv_heads
assert num_q_per_kv % decode_context_parallel_size == 0, (
f"Total number of q per kv attn heads ({num_q_per_kv})"
" must be divisible by dcp world size when enable "
"decode context parallel for GQA "
f"({parallel_config.decode_context_parallel_size})."
)
def get_sliding_window(self) -> int | None:
"""Get the sliding window size from the HF text config if present."""
return getattr(self.hf_text_config, "sliding_window", None)
......
......@@ -259,6 +259,7 @@ def use_trtllm_attention(
num_kv_heads: int,
num_tokens: int,
max_seq_len: int,
dcp_world_size: int,
kv_cache_dtype: str,
q_dtype: torch.dtype,
is_prefill: bool,
......@@ -272,6 +273,14 @@ def use_trtllm_attention(
if force_use_trtllm is not None and not force_use_trtllm:
return False
# Decode context parallel is not supported
if dcp_world_size > 1:
logger.warning_once(
"Trtllm does not support returning LSE and as a result "
"does not support DCP, reverting to FlashInfer"
)
return False
# The platform is not supported
if not supports_trtllm_attention():
if force_use_trtllm:
......
This diff is collapsed.
......@@ -31,6 +31,7 @@ from vllm.distributed import destroy_distributed_environment, destroy_model_para
from vllm.distributed.device_communicators.shm_broadcast import Handle, MessageQueue
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.distributed.parallel_state import (
get_dcp_group,
get_dp_group,
get_ep_group,
get_pp_group,
......@@ -726,6 +727,8 @@ class WorkerProc:
pp_rank = get_pp_group().rank_in_group
tp_size = get_tp_group().world_size
tp_rank = get_tp_group().rank_in_group
dcp_size = get_dcp_group().world_size
dcp_rank = get_dcp_group().rank_in_group
process_name = "Worker"
if dp_size > 1:
process_name += f"_DP{dp_rank}"
......@@ -733,6 +736,8 @@ class WorkerProc:
process_name += f"_PP{pp_rank}"
if tp_size > 1:
process_name += f"_TP{tp_rank}"
if dcp_size > 1:
process_name += f"_DCP{dcp_rank}"
if enable_ep:
ep_rank = get_ep_group().rank_in_group
process_name += f"_EP{ep_rank}"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment