Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
46cbbca0
Unverified
Commit
46cbbca0
authored
Dec 05, 2025
by
Qiu
Committed by
GitHub
Dec 04, 2025
Browse files
[CI][DCP][Perf] reduce DCP CI execution time (#29858)
Signed-off-by:
QiuChunshuo
<
qiuchunshuo@huawei.com
>
parent
b286a311
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
100 additions
and
94 deletions
+100
-94
tests/distributed/test_context_parallel.py
tests/distributed/test_context_parallel.py
+95
-93
tests/models/registry.py
tests/models/registry.py
+5
-1
No files found.
tests/distributed/test_context_parallel.py
View file @
46cbbca0
...
@@ -16,16 +16,35 @@ from typing import Literal, NamedTuple
...
@@ -16,16 +16,35 @@ from typing import Literal, NamedTuple
import
pytest
import
pytest
import
torch
import
torch
from
tests.evals.gsm8k.gsm8k_eval
import
evaluate_gsm8k
from
tests.utils
import
RemoteOpenAIServer
,
create_new_process_for_each_test
from
vllm.config.model
import
RunnerOption
from
vllm.config.model
import
RunnerOption
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
..models.registry
import
HF_EXAMPLE_MODELS
from
..models.registry
import
HF_EXAMPLE_MODELS
from
..utils
import
compare_two_settings
,
create_new_process_for_each_test
logger
=
init_logger
(
"test_context_parallel"
)
logger
=
init_logger
(
"test_context_parallel"
)
VLLM_MULTI_NODE
=
os
.
getenv
(
"VLLM_MULTI_NODE"
,
"0"
)
==
"1"
VLLM_MULTI_NODE
=
os
.
getenv
(
"VLLM_MULTI_NODE"
,
"0"
)
==
"1"
CP_TEST_MODELS
=
[
# TODO support other models
# [LANGUAGE GENERATION]
"deepseek-ai/DeepSeek-V2-Lite-Chat"
,
"Qwen/Qwen2.5-1.5B-Instruct"
,
]
# GSM8K eval configuration
NUM_QUESTIONS
=
256
# Fast eval for CI
NUM_SHOTS
=
5
# Few-shot examples
# tp accuracy with 2% buffer
MIN_ACCURACY
=
{
# .buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml
"deepseek-ai/DeepSeek-V2-Lite-Chat"
:
0.64
,
# .buildkite/lm-eval-harness/configs/Qwen2.5-1.5B-Instruct.yaml
"Qwen/Qwen2.5-1.5B-Instruct"
:
0.52
,
}
class
ParallelSetup
(
NamedTuple
):
class
ParallelSetup
(
NamedTuple
):
tp_size
:
int
tp_size
:
int
...
@@ -38,7 +57,6 @@ class ParallelSetup(NamedTuple):
...
@@ -38,7 +57,6 @@ class ParallelSetup(NamedTuple):
class
CPTestOptions
(
NamedTuple
):
class
CPTestOptions
(
NamedTuple
):
multi_node_only
:
bool
multi_node_only
:
bool
load_format
:
str
|
None
=
None
attn_backend
:
str
|
None
=
None
attn_backend
:
str
|
None
=
None
...
@@ -54,17 +72,20 @@ class CPTestSettings:
...
@@ -54,17 +72,20 @@ class CPTestSettings:
*
,
*
,
tp_base
:
int
=
4
,
tp_base
:
int
=
4
,
pp_base
:
int
=
1
,
pp_base
:
int
=
1
,
dcp_
base
:
int
=
1
,
dcp_
multipliers
:
list
[
float
]
|
None
=
None
,
cp_kv_cache_interleave_size
:
int
=
1
,
cp_kv_cache_interleave_size
:
int
=
1
,
multi_node_only
:
bool
=
False
,
multi_node_only
:
bool
=
False
,
runner
:
RunnerOption
=
"auto"
,
runner
:
RunnerOption
=
"auto"
,
load_format
:
str
|
None
=
None
,
attn_backend
:
str
|
None
=
None
,
attn_backend
:
str
|
None
=
None
,
):
):
parallel_setups
=
[]
parallel_setups
=
[]
if
dcp_multipliers
is
None
:
dcp_multipliers
=
[
0.5
,
]
for
eager_mode_val
in
[
False
]:
for
eager_mode_val
in
[
False
]:
for
pp_multiplier
in
[
1
]:
for
pp_multiplier
in
[
1
]:
for
dcp_multiplier
in
[
0.5
,
1
]
:
for
dcp_multiplier
in
dcp_multipliers
:
for
chunked_prefill_val
in
[
True
]:
for
chunked_prefill_val
in
[
True
]:
parallel_setups
.
append
(
parallel_setups
.
append
(
ParallelSetup
(
ParallelSetup
(
...
@@ -82,7 +103,6 @@ class CPTestSettings:
...
@@ -82,7 +103,6 @@ class CPTestSettings:
runner
=
runner
,
runner
=
runner
,
test_options
=
CPTestOptions
(
test_options
=
CPTestOptions
(
multi_node_only
=
multi_node_only
,
multi_node_only
=
multi_node_only
,
load_format
=
load_format
,
attn_backend
=
attn_backend
,
attn_backend
=
attn_backend
,
),
),
)
)
...
@@ -101,7 +121,24 @@ class CPTestSettings:
...
@@ -101,7 +121,24 @@ class CPTestSettings:
)
)
def
_compare_cp_with_tp
(
CP_TEXT_GENERATION_MODELS
=
{
"deepseek-ai/DeepSeek-V2-Lite-Chat"
:
[
CPTestSettings
.
detailed
(
dcp_multipliers
=
[
0.5
,
1
],
cp_kv_cache_interleave_size
=
64
),
],
"Qwen/Qwen2.5-1.5B-Instruct"
:
[
CPTestSettings
.
detailed
(
cp_kv_cache_interleave_size
=
16
,
attn_backend
=
"FLASH_ATTN"
),
CPTestSettings
.
detailed
(
cp_kv_cache_interleave_size
=
16
,
attn_backend
=
"FLASHINFER"
),
],
}
def
_test_cp_gsm8k
(
model_id
:
str
,
model_id
:
str
,
parallel_setup
:
ParallelSetup
,
parallel_setup
:
ParallelSetup
,
distributed_backend
:
str
,
distributed_backend
:
str
,
...
@@ -121,7 +158,7 @@ def _compare_cp_with_tp(
...
@@ -121,7 +158,7 @@ def _compare_cp_with_tp(
chunked_prefill
,
chunked_prefill
,
)
=
parallel_setup
)
=
parallel_setup
multi_node_only
,
load_format
,
attn_backend
=
test_options
multi_node_only
,
attn_backend
=
test_options
model_info
=
HF_EXAMPLE_MODELS
.
find_hf_info
(
model_id
)
model_info
=
HF_EXAMPLE_MODELS
.
find_hf_info
(
model_id
)
model_info
.
check_transformers_version
(
on_fail
=
"skip"
)
model_info
.
check_transformers_version
(
on_fail
=
"skip"
)
...
@@ -130,22 +167,7 @@ def _compare_cp_with_tp(
...
@@ -130,22 +167,7 @@ def _compare_cp_with_tp(
tokenizer_mode
=
model_info
.
tokenizer_mode
tokenizer_mode
=
model_info
.
tokenizer_mode
hf_overrides
=
model_info
.
hf_overrides
hf_overrides
=
model_info
.
hf_overrides
if
load_format
==
"dummy"
:
model_info
.
check_available_online
(
on_fail
=
"skip"
)
# Avoid OOM
text_overrides
=
{
"num_hidden_layers"
:
4
,
"hidden_size"
:
512
,
"intermediate_size"
:
800
,
"num_attention_heads"
:
4
,
"num_key_value_heads"
:
1
,
}
if
is_multimodal
:
hf_overrides
.
update
({
"text_config"
:
text_overrides
})
else
:
hf_overrides
.
update
(
text_overrides
)
else
:
model_info
.
check_available_online
(
on_fail
=
"skip"
)
if
num_gpus_available
<
tp_size
*
pp_size
:
if
num_gpus_available
<
tp_size
*
pp_size
:
pytest
.
skip
(
f
"Need at least
{
tp_size
}
x
{
pp_size
}
GPUs"
)
pytest
.
skip
(
f
"Need at least
{
tp_size
}
x
{
pp_size
}
GPUs"
)
...
@@ -157,90 +179,70 @@ def _compare_cp_with_tp(
...
@@ -157,90 +179,70 @@ def _compare_cp_with_tp(
if
multi_node_only
and
not
VLLM_MULTI_NODE
:
if
multi_node_only
and
not
VLLM_MULTI_NODE
:
pytest
.
skip
(
"Not in multi-node setting"
)
pytest
.
skip
(
"Not in multi-node setting"
)
common
_args
=
[
server
_args
=
[
# use half precision for speed and memory savings in CI environment
# use half precision for speed and memory savings in CI environment
"--dtype"
,
"--dtype"
,
"bfloat16"
,
"bfloat16"
,
"--max-model-len"
,
"--max-model-len"
,
"
2048
"
,
"
4096
"
,
"--max-num-seqs"
,
"--max-num-seqs"
,
"
8
"
,
"
64
"
,
]
]
if
chunked_prefill
:
if
chunked_prefill
:
common
_args
.
append
(
"--enable-chunked-prefill"
)
server
_args
.
append
(
"--enable-chunked-prefill"
)
if
eager_mode
:
if
eager_mode
:
common
_args
.
append
(
"--enforce-eager"
)
server
_args
.
append
(
"--enforce-eager"
)
if
runner
!=
"auto"
:
if
runner
!=
"auto"
:
common
_args
.
extend
([
"--runner"
,
runner
])
server
_args
.
extend
([
"--runner"
,
runner
])
if
trust_remote_code
:
if
trust_remote_code
:
common
_args
.
append
(
"--trust-remote-code"
)
server
_args
.
append
(
"--trust-remote-code"
)
if
tokenizer_mode
:
if
tokenizer_mode
:
common_args
.
extend
([
"--tokenizer-mode"
,
tokenizer_mode
])
server_args
.
extend
([
"--tokenizer-mode"
,
tokenizer_mode
])
if
load_format
:
common_args
.
extend
([
"--load-format"
,
load_format
])
if
hf_overrides
:
if
hf_overrides
:
common_args
.
extend
([
"--hf-overrides"
,
json
.
dumps
(
hf_overrides
)])
server_args
.
extend
([
"--hf-overrides"
,
json
.
dumps
(
hf_overrides
)])
if
not
attn_backend
:
server_args
.
extend
(
cp_env
=
tp_env
=
{}
[
else
:
"--tensor-parallel-size"
,
cp_env
=
tp_env
=
{
str
(
tp_size
),
"VLLM_ATTENTION_BACKEND"
:
attn_backend
,
"--pipeline-parallel-size"
,
}
str
(
pp_size
),
"--decode-context-parallel-size"
,
cp_args
=
[
str
(
dcp_size
),
*
common_args
,
"--dcp-kv-cache-interleave-size"
,
"--tensor-parallel-size"
,
str
(
cp_kv_cache_interleave_size
),
str
(
tp_size
),
"--distributed-executor-backend"
,
"--pipeline-parallel-size"
,
distributed_backend
,
str
(
pp_size
),
]
"--decode-context-parallel-size"
,
)
str
(
dcp_size
),
"--dcp-kv-cache-interleave-size"
,
str
(
cp_kv_cache_interleave_size
),
"--distributed-executor-backend"
,
distributed_backend
,
]
tp_args
=
[
server_env
=
{}
*
common_args
,
if
attn_backend
:
"--tensor-parallel-size"
,
server_env
[
"VLLM_ATTENTION_BACKEND"
]
=
attn_backend
str
(
tp_size
),
"--pipeline-parallel-size"
,
str
(
pp_size
),
"--distributed-executor-backend"
,
distributed_backend
,
]
compare_two_settings
(
with
RemoteOpenAIServer
(
model_id
,
model_id
,
cp_args
,
server_args
,
tp_args
,
env_dict
=
server_env
,
cp_env
,
tp_env
,
method
=
method
,
max_wait_seconds
=
720
,
max_wait_seconds
=
720
,
)
)
as
remote_server
:
host
=
f
"http://
{
remote_server
.
host
}
"
port
=
remote_server
.
port
CP_TEXT_GENERATION_MODELS
=
{
"deepseek-ai/DeepSeek-V2-Lite-Chat"
:
[
# Run GSM8K evaluation
CPTestSettings
.
detailed
(),
results
=
evaluate_gsm8k
(
CPTestSettings
.
detailed
(
tp_base
=
2
),
num_questions
=
NUM_QUESTIONS
,
CPTestSettings
.
detailed
(
tp_base
=
2
,
cp_kv_cache_interleave_size
=
64
),
num_shots
=
NUM_SHOTS
,
],
host
=
host
,
"bigcode/gpt_bigcode-santacoder"
:
[
port
=
port
,
CPTestSettings
.
detailed
(),
)
CPTestSettings
.
detailed
(
tp_base
=
2
),
],
}
CP_TEST_MODELS
=
[
# Validate accuracy is reasonable
# TODO support other models
accuracy
=
results
[
"accuracy"
]
# [LANGUAGE GENERATION
]
min_accuracy
=
MIN_ACCURACY
[
model_id
]
"deepseek-ai/DeepSeek-V2-Lite-Chat"
,
assert
accuracy
>=
min_accuracy
,
(
"bigcode/gpt_bigcode-santacoder"
,
f
"TP+DCP accuracy too low:
{
accuracy
:.
3
f
}
<
{
min_accuracy
:.
3
f
}
"
]
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -274,12 +276,12 @@ def test_cp_generation(
...
@@ -274,12 +276,12 @@ def test_cp_generation(
):
):
pytest
.
skip
(
reason
=
"MLA+DCP requires compute capability of 9.0 or higher"
)
pytest
.
skip
(
reason
=
"MLA+DCP requires compute capability of 9.0 or higher"
)
if
(
if
(
model_id
==
"
bigcode/gpt_bigcode-santacoder
"
model_id
==
"
Qwen/Qwen2.5-1.5B-Instruct
"
and
torch
.
cuda
.
get_device_capability
()
!=
(
9
,
0
)
and
torch
.
cuda
.
get_device_capability
()
!=
(
9
,
0
)
):
):
pytest
.
skip
(
reason
=
"GQA+DCP currently requires compute capability of 9.0"
)
pytest
.
skip
(
reason
=
"GQA+DCP currently requires compute capability of 9.0"
)
_
compare_cp_with_tp
(
_
test_cp_gsm8k
(
model_id
,
model_id
,
parallel_setup
,
parallel_setup
,
distributed_backend
,
distributed_backend
,
...
...
tests/models/registry.py
View file @
46cbbca0
...
@@ -416,7 +416,11 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
...
@@ -416,7 +416,11 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code
=
True
,
trust_remote_code
=
True
,
),
),
"Qwen2ForCausalLM"
:
_HfExamplesInfo
(
"Qwen2ForCausalLM"
:
_HfExamplesInfo
(
"Qwen/Qwen2-0.5B-Instruct"
,
extras
=
{
"2.5"
:
"Qwen/Qwen2.5-0.5B-Instruct"
}
"Qwen/Qwen2-0.5B-Instruct"
,
extras
=
{
"2.5"
:
"Qwen/Qwen2.5-0.5B-Instruct"
,
"2.5-1.5B"
:
"Qwen/Qwen2.5-1.5B-Instruct"
,
},
),
),
"Qwen2MoeForCausalLM"
:
_HfExamplesInfo
(
"Qwen/Qwen1.5-MoE-A2.7B-Chat"
),
"Qwen2MoeForCausalLM"
:
_HfExamplesInfo
(
"Qwen/Qwen1.5-MoE-A2.7B-Chat"
),
"Qwen3ForCausalLM"
:
_HfExamplesInfo
(
"Qwen/Qwen3-8B"
),
"Qwen3ForCausalLM"
:
_HfExamplesInfo
(
"Qwen/Qwen3-8B"
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment