Unverified Commit 4516d44b authored by Jingchun Gao's avatar Jingchun Gao Committed by GitHub
Browse files

[DCP] Support Decode Context Parallel (DCP) for GQA with Flashinfer (#25438)


Signed-off-by: default avatargaojc <1055866782@qq.com>
Signed-off-by: default avatarJingchun Gao <gaojingchun1@huawei.com>
Signed-off-by: default avatarJingchun Gao <63247409+gjc0824@users.noreply.github.com>
Signed-off-by: default avatarQiuChunshuo <qiuchunshuo@huawei.com>
Co-authored-by: default avatargaojingchun (A) <g00955623@china.huawei.com>
Co-authored-by: default avatarJingchun Gao <gaojingchun1@huawei.com>
Co-authored-by: default avatarQiuChunshuo <qiuchunshuo@huawei.com>
parent 41b92f7d
...@@ -39,6 +39,7 @@ class ParallelSetup(NamedTuple): ...@@ -39,6 +39,7 @@ class ParallelSetup(NamedTuple):
class CPTestOptions(NamedTuple): class CPTestOptions(NamedTuple):
multi_node_only: bool multi_node_only: bool
load_format: str | None = None load_format: str | None = None
attn_backend: str | None = None
@dataclass @dataclass
...@@ -58,6 +59,7 @@ class CPTestSettings: ...@@ -58,6 +59,7 @@ class CPTestSettings:
multi_node_only: bool = False, multi_node_only: bool = False,
runner: RunnerOption = "auto", runner: RunnerOption = "auto",
load_format: str | None = None, load_format: str | None = None,
attn_backend: str | None = None,
): ):
parallel_setups = [] parallel_setups = []
for eager_mode_val in [False]: for eager_mode_val in [False]:
...@@ -79,7 +81,9 @@ class CPTestSettings: ...@@ -79,7 +81,9 @@ class CPTestSettings:
distributed_backends=["mp"], distributed_backends=["mp"],
runner=runner, runner=runner,
test_options=CPTestOptions( test_options=CPTestOptions(
multi_node_only=multi_node_only, load_format=load_format multi_node_only=multi_node_only,
load_format=load_format,
attn_backend=attn_backend,
), ),
) )
...@@ -117,7 +121,7 @@ def _compare_cp_with_tp( ...@@ -117,7 +121,7 @@ def _compare_cp_with_tp(
chunked_prefill, chunked_prefill,
) = parallel_setup ) = parallel_setup
multi_node_only, load_format = test_options multi_node_only, load_format, attn_backend = test_options
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
model_info.check_transformers_version(on_fail="skip") model_info.check_transformers_version(on_fail="skip")
...@@ -177,6 +181,13 @@ def _compare_cp_with_tp( ...@@ -177,6 +181,13 @@ def _compare_cp_with_tp(
if hf_overrides: if hf_overrides:
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)]) common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
if not attn_backend:
cp_env = tp_env = {}
else:
cp_env = tp_env = {
"VLLM_ATTENTION_BACKEND": attn_backend,
}
cp_args = [ cp_args = [
*common_args, *common_args,
"--tensor-parallel-size", "--tensor-parallel-size",
...@@ -205,6 +216,8 @@ def _compare_cp_with_tp( ...@@ -205,6 +216,8 @@ def _compare_cp_with_tp(
model_id, model_id,
cp_args, cp_args,
tp_args, tp_args,
cp_env,
tp_env,
method=method, method=method,
max_wait_seconds=720, max_wait_seconds=720,
) )
......
...@@ -1183,6 +1183,14 @@ class ModelConfig: ...@@ -1183,6 +1183,14 @@ class ModelConfig:
f"but got {decode_context_parallel_size}" f"but got {decode_context_parallel_size}"
) )
num_q_per_kv = total_num_attention_heads // total_num_kv_heads
assert num_q_per_kv % decode_context_parallel_size == 0, (
f"Total number of q per kv attn heads ({num_q_per_kv})"
" must be divisible by dcp world size when enable "
"decode context parallel for GQA "
f"({parallel_config.decode_context_parallel_size})."
)
def get_sliding_window(self) -> int | None: def get_sliding_window(self) -> int | None:
"""Get the sliding window size from the HF text config if present.""" """Get the sliding window size from the HF text config if present."""
return getattr(self.hf_text_config, "sliding_window", None) return getattr(self.hf_text_config, "sliding_window", None)
......
...@@ -259,6 +259,7 @@ def use_trtllm_attention( ...@@ -259,6 +259,7 @@ def use_trtllm_attention(
num_kv_heads: int, num_kv_heads: int,
num_tokens: int, num_tokens: int,
max_seq_len: int, max_seq_len: int,
dcp_world_size: int,
kv_cache_dtype: str, kv_cache_dtype: str,
q_dtype: torch.dtype, q_dtype: torch.dtype,
is_prefill: bool, is_prefill: bool,
...@@ -272,6 +273,14 @@ def use_trtllm_attention( ...@@ -272,6 +273,14 @@ def use_trtllm_attention(
if force_use_trtllm is not None and not force_use_trtllm: if force_use_trtllm is not None and not force_use_trtllm:
return False return False
# Decode context parallel is not supported
if dcp_world_size > 1:
logger.warning_once(
"Trtllm does not support returning LSE and as a result "
"does not support DCP, reverting to FlashInfer"
)
return False
# The platform is not supported # The platform is not supported
if not supports_trtllm_attention(): if not supports_trtllm_attention():
if force_use_trtllm: if force_use_trtllm:
......
This diff is collapsed.
...@@ -31,6 +31,7 @@ from vllm.distributed import destroy_distributed_environment, destroy_model_para ...@@ -31,6 +31,7 @@ from vllm.distributed import destroy_distributed_environment, destroy_model_para
from vllm.distributed.device_communicators.shm_broadcast import Handle, MessageQueue from vllm.distributed.device_communicators.shm_broadcast import Handle, MessageQueue
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
get_dcp_group,
get_dp_group, get_dp_group,
get_ep_group, get_ep_group,
get_pp_group, get_pp_group,
...@@ -726,6 +727,8 @@ class WorkerProc: ...@@ -726,6 +727,8 @@ class WorkerProc:
pp_rank = get_pp_group().rank_in_group pp_rank = get_pp_group().rank_in_group
tp_size = get_tp_group().world_size tp_size = get_tp_group().world_size
tp_rank = get_tp_group().rank_in_group tp_rank = get_tp_group().rank_in_group
dcp_size = get_dcp_group().world_size
dcp_rank = get_dcp_group().rank_in_group
process_name = "Worker" process_name = "Worker"
if dp_size > 1: if dp_size > 1:
process_name += f"_DP{dp_rank}" process_name += f"_DP{dp_rank}"
...@@ -733,6 +736,8 @@ class WorkerProc: ...@@ -733,6 +736,8 @@ class WorkerProc:
process_name += f"_PP{pp_rank}" process_name += f"_PP{pp_rank}"
if tp_size > 1: if tp_size > 1:
process_name += f"_TP{tp_rank}" process_name += f"_TP{tp_rank}"
if dcp_size > 1:
process_name += f"_DCP{dcp_rank}"
if enable_ep: if enable_ep:
ep_rank = get_ep_group().rank_in_group ep_rank = get_ep_group().rank_in_group
process_name += f"_EP{ep_rank}" process_name += f"_EP{ep_rank}"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment