Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
911355e2
Unverified
Commit
911355e2
authored
Mar 16, 2026
by
Andreas Karatzas
Committed by
GitHub
Mar 16, 2026
Browse files
[ROCm] Fix KV copy methods and auto-select attention backend for ROCm (#36845)
Signed-off-by:
Andreas Karatzas
<
akaratza@amd.com
>
parent
8d3f8f48
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
75 additions
and
17 deletions
+75
-17
tests/v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh
...connector/nixl_integration/spec_decode_acceptance_test.sh
+51
-17
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+24
-0
No files found.
tests/v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh
View file @
911355e2
...
...
@@ -21,6 +21,11 @@
# MODEL_NAME - target model (default: meta-llama/Llama-3.1-8B-Instruct)
# NUM_SPEC_TOKENS - number of speculative tokens (default: 3)
# GPU_MEMORY_UTILIZATION - (default: 0.7)
# ATTENTION_BACKEND - attention backend to use
# Default: TRITON_ATTN on ROCm, FLASH_ATTN on NVIDIA
# ROCm options: TRITON_ATTN, ROCM_ATTN, ROCM_AITER_FA,
# ROCM_AITER_UNIFIED_ATTN
# NVIDIA options: FLASH_ATTN, FLASHINFER
set
-x
# ── Model & spec decode config ──────────────────────────────────────────
...
...
@@ -51,6 +56,28 @@ GIT_ROOT=$(git rev-parse --show-toplevel)
SMI_BIN
=
$(
which nvidia-smi
||
which rocm-smi
||
echo
""
)
# ── Detect platform (NVIDIA vs ROCm) ────────────────────────────────────
if
[[
"
$SMI_BIN
"
==
*
"rocm"
*
]]
;
then
GPU_PLATFORM
=
"rocm"
GPU_DEVICE_VAR
=
"HIP_VISIBLE_DEVICES"
else
GPU_PLATFORM
=
"nvidia"
GPU_DEVICE_VAR
=
"CUDA_VISIBLE_DEVICES"
fi
echo
"Detected GPU platform:
${
GPU_PLATFORM
}
(using
${
GPU_DEVICE_VAR
}
)"
# ── Attention backend config ─────────────────────────────────────────────
if
[[
-z
"
${
ATTENTION_BACKEND
:-}
"
]]
;
then
if
[[
"
$GPU_PLATFORM
"
==
"rocm"
]]
;
then
ATTENTION_BACKEND
=
"TRITON_ATTN"
else
ATTENTION_BACKEND
=
"FLASH_ATTN"
fi
fi
echo
"Using attention backend:
${
ATTENTION_BACKEND
}
"
cleanup_instances
()
{
echo
""
echo
"Cleaning up..."
...
...
@@ -84,13 +111,16 @@ wait_for_server() {
# ── Resolve GPU list ─────────────────────────────────────────────────────
if
[[
-n
"
${
CUDA_VISIBLE_DEVICES
:-}
"
]]
;
then
IFS
=
','
read
-ra
ALL_GPUS
<<<
"
$CUDA_VISIBLE_DEVICES
"
# Accept either CUDA_VISIBLE_DEVICES or HIP_VISIBLE_DEVICES
VISIBLE_DEVICES
=
"
${
CUDA_VISIBLE_DEVICES
:-${
HIP_VISIBLE_DEVICES
:-}}
"
if
[[
-n
"
${
VISIBLE_DEVICES
}
"
]]
;
then
IFS
=
','
read
-ra
ALL_GPUS
<<<
"
$VISIBLE_DEVICES
"
else
ALL_GPUS
=()
if
[[
"
$
SMI_BIN
"
==
*
"nvidia"
*
]]
;
then
if
[[
"
$
GPU_PLATFORM
"
==
"nvidia"
]]
;
then
num
=
$(
$SMI_BIN
--query-gpu
=
name
--format
=
csv,noheader |
wc
-l
)
elif
[[
"
$
SMI_BIN
"
==
*
"rocm"
*
]]
;
then
elif
[[
"
$
GPU_PLATFORM
"
==
"rocm"
]]
;
then
num
=
$(
$SMI_BIN
-l
|
grep
-c
GPU
)
else
num
=
1
...
...
@@ -100,7 +130,7 @@ fi
TOTAL_GPUS_NEEDED
=
$((
(
NUM_PREFILL_INSTANCES
*
PREFILLER_TP_SIZE
)
+
(
NUM_DECODE_INSTANCES
*
DECODER_TP_SIZE
)
))
if
[[
${#
ALL_GPUS
[@]
}
-lt
$TOTAL_GPUS_NEEDED
]]
;
then
echo
"FAIL: Need
$TOTAL_GPUS_NEEDED
GPUs but only have
${#
ALL_GPUS
[@]
}
(
CUDA_VISIBLE_DEVICES=
${
CUDA_
VISIBLE_DEVICES
:-
not
set
}
)"
echo
"FAIL: Need
$TOTAL_GPUS_NEEDED
GPUs but only have
${#
ALL_GPUS
[@]
}
(
visible devices=
${
VISIBLE_DEVICES
:-
not
set
}
)"
exit
1
fi
...
...
@@ -124,6 +154,8 @@ run_test_for_device() {
echo
"SD model:
${
SD_MODEL
}
"
echo
"Spec tokens:
${
NUM_SPEC_TOKENS
}
"
echo
"KV buffer device:
${
kv_device
}
"
echo
"Attention backend:
${
ATTENTION_BACKEND
}
"
echo
"GPU platform:
${
GPU_PLATFORM
}
"
echo
"GPUs available:
${
ALL_GPUS
[*]
}
"
echo
"================================================================"
...
...
@@ -146,7 +178,8 @@ run_test_for_device() {
local
SIDE_CHANNEL_PORT
=
$((
5559
+
i
))
echo
"Starting prefill instance
$i
on GPU
$GPU_ID
, port
$PORT
"
CUDA_VISIBLE_DEVICES
=
$GPU_ID
\
env
\
${
GPU_DEVICE_VAR
}
=
$GPU_ID
\
VLLM_KV_CACHE_LAYOUT
=
'HND'
\
UCX_NET_DEVICES
=
all
\
VLLM_NIXL_SIDE_CHANNEL_PORT
=
$SIDE_CHANNEL_PORT
\
...
...
@@ -159,7 +192,7 @@ run_test_for_device() {
--tensor-parallel-size
$PREFILLER_TP_SIZE
\
--kv-transfer-config
"
$kv_config
"
\
--speculative-config
"
$PREFILL_SPEC_CONFIG
"
\
--attention-backend
FLASH_ATTN
&
--attention-backend
$ATTENTION_BACKEND
&
PREFILL_HOSTS+
=(
"localhost"
)
PREFILL_PORTS+
=(
"
$PORT
"
)
...
...
@@ -178,7 +211,8 @@ run_test_for_device() {
local
SIDE_CHANNEL_PORT
=
$((
5659
+
i
*
$DECODER_TP_SIZE
))
echo
"Starting decode instance
$i
on GPU
$GPU_ID
, port
$PORT
"
CUDA_VISIBLE_DEVICES
=
$GPU_ID
\
env
\
${
GPU_DEVICE_VAR
}
=
$GPU_ID
\
VLLM_KV_CACHE_LAYOUT
=
'HND'
\
UCX_NET_DEVICES
=
all
\
VLLM_NIXL_SIDE_CHANNEL_PORT
=
$SIDE_CHANNEL_PORT
\
...
...
@@ -191,7 +225,7 @@ run_test_for_device() {
--tensor-parallel-size
$DECODER_TP_SIZE
\
--kv-transfer-config
"
$kv_config
"
\
--speculative-config
"
$DECODE_SPEC_CONFIG
"
\
--attention-backend
FLASH_ATTN
&
--attention-backend
$ATTENTION_BACKEND
&
DECODE_HOSTS+
=(
"localhost"
)
DECODE_PORTS+
=(
"
$PORT
"
)
...
...
@@ -218,7 +252,7 @@ run_test_for_device() {
sleep
5
# Run test
echo
"Running spec decode acceptance test (kv_buffer_device=
${
kv_device
}
)..."
echo
"Running spec decode acceptance test (kv_buffer_device=
${
kv_device
}
, backend=
${
ATTENTION_BACKEND
}
)..."
DECODE_PORT
=
${
DECODE_PORTS
[0]
}
\
TEST_MODEL
=
$MODEL_NAME
\
python3
-m
pytest
-s
-x
"
${
GIT_ROOT
}
/tests/v1/kv_connector/nixl_integration/test_spec_decode_acceptance.py"
...
...
@@ -234,4 +268,4 @@ for device in $KV_BUFFER_DEVICES; do
run_test_for_device
"
$device
"
done
echo
"=== All spec decode acceptance tests passed ==="
echo
"=== All spec decode acceptance tests passed
(backend=
${
ATTENTION_BACKEND
}
)
==="
vllm/platforms/rocm.py
View file @
911355e2
...
...
@@ -851,6 +851,30 @@ class RocmPlatform(Platform):
"`dtype` flag in CLI, for example: --dtype=half."
)
@
classmethod
def
insert_blocks_to_device
(
cls
,
src_cache
:
torch
.
Tensor
,
dst_cache
:
torch
.
Tensor
,
src_block_indices
:
torch
.
Tensor
,
dst_block_indices
:
torch
.
Tensor
,
)
->
None
:
"""Copy blocks from src_cache to dst_cache on GPU."""
_src_cache
=
src_cache
[:,
src_block_indices
]
dst_cache
[:,
dst_block_indices
]
=
_src_cache
.
to
(
dst_cache
.
device
)
@
classmethod
def
swap_out_blocks_to_host
(
cls
,
src_cache
:
torch
.
Tensor
,
dst_cache
:
torch
.
Tensor
,
src_block_indices
:
torch
.
Tensor
,
dst_block_indices
:
torch
.
Tensor
,
)
->
None
:
"""Copy blocks from GPU to host (CPU)."""
_src_cache
=
src_cache
[:,
src_block_indices
]
dst_cache
[:,
dst_block_indices
]
=
_src_cache
.
cpu
()
@
classmethod
def
support_hybrid_kv_cache
(
cls
)
->
bool
:
return
True
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment