Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b2fac671
Unverified
Commit
b2fac671
authored
Jun 05, 2025
by
Nicolò Lucchesi
Committed by
GitHub
Jun 04, 2025
Browse files
[P/D] Heterogeneous TP (#18833)
Signed-off-by:
nicklucche
<
nlucches@redhat.com
>
parent
23027e2d
Changes
6
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
287 additions
and
100 deletions
+287
-100
tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
+8
-3
tests/v1/kv_connector/nixl_integration/test_accuracy.py
tests/v1/kv_connector/nixl_integration/test_accuracy.py
+1
-0
vllm/distributed/kv_transfer/kv_connector/utils.py
vllm/distributed/kv_transfer/kv_connector/utils.py
+17
-1
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+243
-95
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+16
-0
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+2
-1
No files found.
tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
View file @
b2fac671
...
...
@@ -8,7 +8,9 @@ MODELS=(
# Number of prefill and decode instances to create
NUM_PREFILL_INSTANCES
=
${
NUM_PREFILL_INSTANCES
:-
1
}
# Default to 1
NUM_DECODE_INSTANCES
=
${
NUM_DECODE_INSTANCES
:-
2
}
# Default to 2
NUM_DECODE_INSTANCES
=
${
NUM_DECODE_INSTANCES
:-
1
}
# Default to 1
PREFILLER_TP_SIZE
=
${
PREFILLER_TP_SIZE
:-
1
}
DECODER_TP_SIZE
=
${
DECODER_TP_SIZE
:-
1
}
# Find the git repository root directory
GIT_ROOT
=
$(
git rev-parse
--show-toplevel
)
...
...
@@ -74,9 +76,10 @@ run_tests_for_model() {
for
i
in
$(
seq
0
$((
NUM_PREFILL_INSTANCES-1
))
)
;
do
# Calculate GPU ID - we'll distribute across available GPUs
GPU_ID
=
$((
i
%
$(
get_num_gpus
)
))
# Calculate port number (base port + instance number)
PORT
=
$((
8100
+
i
))
# Calculate side channel port
# Calculate side channel port
. Avoid clash with with TP workers.
SIDE_CHANNEL_PORT
=
$((
5559
+
i
))
echo
"Starting prefill instance
$i
on GPU
$GPU_ID
, port
$PORT
"
...
...
@@ -87,6 +90,7 @@ run_tests_for_model() {
--enforce-eager
\
--disable-log-requests
\
--gpu-memory-utilization 0.2
\
--tensor-parallel-size
$PREFILLER_TP_SIZE
\
--kv-transfer-config '{
\"
kv_connector
\"
:
\"
NixlConnector
\"
,
\"
kv_role
\"
:
\"
kv_both
\"
}'"
if
[
-n
"
$model_args
"
]
;
then
...
...
@@ -109,7 +113,7 @@ run_tests_for_model() {
# Calculate port number (base port + instance number)
PORT
=
$((
8200
+
i
))
# Calculate side channel port
SIDE_CHANNEL_PORT
=
$((
5659
+
i
))
SIDE_CHANNEL_PORT
=
$((
5659
+
i
*
$DECODER_TP_SIZE
))
echo
"Starting decode instance
$i
on GPU
$GPU_ID
, port
$PORT
"
...
...
@@ -119,6 +123,7 @@ run_tests_for_model() {
--enforce-eager
\
--disable-log-requests
\
--gpu-memory-utilization 0.2
\
--tensor-parallel-size
$DECODER_TP_SIZE
\
--kv-transfer-config '{
\"
kv_connector
\"
:
\"
NixlConnector
\"
,
\"
kv_role
\"
:
\"
kv_both
\"
}'"
if
[
-n
"
$model_args
"
]
;
then
...
...
tests/v1/kv_connector/nixl_integration/test_accuracy.py
View file @
b2fac671
...
...
@@ -14,6 +14,7 @@ RTOL = 0.03
# Model-specific expected values
EXPECTED_VALUES
=
{
"Qwen/Qwen3-0.6B"
:
0.41
,
"deepseek-ai/deepseek-vl2-small"
:
0.59
}
SIMPLE_PROMPT
=
"The best part about working on vLLM is that I got to meet so many people across various different organizations like UCB, Google, and Meta which means"
,
# noqa: E501
...
...
vllm/distributed/kv_transfer/kv_connector/utils.py
View file @
b2fac671
...
...
@@ -3,11 +3,12 @@
"""
KV cache helper for store.
"""
import
torch
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
,
get_current_vllm_config
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
...
...
@@ -90,3 +91,18 @@ class model_aware_kv_ops_helper:
layer
.
self_attn
.
attn
.
_k_scale
,
layer
.
self_attn
.
attn
.
_v_scale
,
)
def
get_kv_connector_cache_layout
():
vllm_config
=
get_current_vllm_config
()
kv_config
=
vllm_config
.
kv_transfer_config
if
vllm_config
.
model_config
is
None
:
logger
.
warning
(
"Unable to detect current VLLM config. "
\
"Defaulting to NHD kv cache layout."
)
else
:
use_mla
=
vllm_config
.
model_config
.
use_mla
if
not
use_mla
and
kv_config
.
kv_connector
==
"NixlConnector"
:
logger
.
info
(
"NixlConnector detected. Setting KV cache "
\
"layout to HND for better xfer performance."
)
return
"HND"
return
"NHD"
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
b2fac671
This diff is collapsed.
Click to expand it.
vllm/v1/attention/backends/flash_attn.py
View file @
b2fac671
...
...
@@ -16,6 +16,8 @@ from vllm.attention.ops.merge_attn_states import merge_attn_states
from
vllm.attention.utils.fa_utils
import
(
flash_attn_supports_fp8
,
get_flash_attn_version
)
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.distributed.kv_transfer.kv_connector.utils
import
(
get_kv_connector_cache_layout
)
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
...
...
@@ -70,6 +72,20 @@ class FlashAttentionBackend(AttentionBackend):
raise
ValueError
(
"Block size must be a multiple of 16."
)
return
(
2
,
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
@
staticmethod
def
get_kv_cache_stride_order
()
->
tuple
[
int
,
...]:
# NOTE When running disaggregated PD with NIXL, HND layout is used for
# faster transfer. `stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want.
cache_layout
=
get_kv_connector_cache_layout
()
if
cache_layout
==
"NHD"
:
stride_order
=
(
0
,
1
,
2
,
3
,
4
)
elif
cache_layout
==
"HND"
:
stride_order
=
(
0
,
1
,
3
,
2
,
4
)
else
:
raise
ValueError
(
"Unknown cache layout format %s."
,
cache_layout
)
return
stride_order
@
dataclass
class
FlashAttentionMetadata
:
...
...
vllm/worker/worker_base.py
View file @
b2fac671
...
...
@@ -597,6 +597,7 @@ class WorkerWrapperBase:
def
initialize_from_config
(
self
,
kv_cache_configs
:
List
[
Any
])
->
None
:
kv_cache_config
=
kv_cache_configs
[
self
.
rpc_rank
]
with
set_current_vllm_config
(
self
.
vllm_config
):
self
.
worker
.
initialize_from_config
(
kv_cache_config
)
# type: ignore
def
init_device
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment