Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
87d319c5
Unverified
Commit
87d319c5
authored
Mar 01, 2026
by
Ryan Rock
Committed by
GitHub
Mar 01, 2026
Browse files
[AMD][CI] Support Triton attention with ExampleConnector (#34931)
Signed-off-by:
Ryan Rock
<
ryan.rock@amd.com
>
parent
a9ec392c
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
28 additions
and
17 deletions
+28
-17
tests/v1/kv_connector/unit/test_example_connector.py
tests/v1/kv_connector/unit/test_example_connector.py
+11
-7
tests/v1/kv_connector/unit/test_multi_connector.py
tests/v1/kv_connector/unit/test_multi_connector.py
+0
-8
vllm/distributed/kv_transfer/kv_connector/v1/example_connector.py
...tributed/kv_transfer/kv_connector/v1/example_connector.py
+17
-2
No files found.
tests/v1/kv_connector/unit/test_example_connector.py
View file @
87d319c5
...
...
@@ -8,7 +8,7 @@ from PIL import Image
from
vllm
import
LLM
,
EngineArgs
,
SamplingParams
from
vllm.assets.image
import
ImageAsset
from
vllm.config
import
KVTransferConfig
from
vllm.config
import
AttentionConfig
,
KVTransferConfig
from
vllm.multimodal.utils
import
encode_image_url
from
vllm.platforms
import
current_platform
...
...
@@ -110,14 +110,17 @@ def process_prompt(processor, llm: LLM, question: str, image_urls: list[Image]):
print
(
"-"
*
50
)
@
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
reason
=
(
"hipErrorLaunchFailure when running this test, see issue:"
"https://github.com/ROCm/pytorch/issues/2822"
@
pytest
.
mark
.
parametrize
(
"attn_backend"
,
(
[
"FLASH_ATTN"
,
"TRITON_ATTN"
]
if
current_platform
.
is_cuda
()
else
[
"TRITON_ATTN"
]
if
current_platform
.
is_rocm
()
else
[]
),
)
def
test_shared_storage_connector_hashes
(
tmp_path
):
def
test_shared_storage_connector_hashes
(
tmp_path
,
attn_backend
):
"""
Tests that ExampleConnector saves KV to the storage locations
with proper hashes; that are unique for inputs with identical text but
...
...
@@ -138,6 +141,7 @@ def test_shared_storage_connector_hashes(tmp_path):
max_model_len
=
8192
,
max_num_seqs
=
1
,
gpu_memory_utilization
=
0.4
,
attention_config
=
AttentionConfig
(
backend
=
attn_backend
),
enforce_eager
=
True
,
kv_transfer_config
=
kv_transfer_config
,
limit_mm_per_prompt
=
{
"image"
:
2
},
...
...
tests/v1/kv_connector/unit/test_multi_connector.py
View file @
87d319c5
...
...
@@ -20,7 +20,6 @@ from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
from
vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector
import
(
NixlKVConnectorStats
,
)
from
vllm.platforms
import
current_platform
MODEL_NAME
=
"meta-llama/Llama-3.2-1B-Instruct"
...
...
@@ -97,13 +96,6 @@ def _compare_directories(dir1: Path, dir2: Path) -> bool:
return
True
@
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
reason
=
(
"hipErrorLaunchFailure when running this test, see issue:"
"https://github.com/ROCm/pytorch/issues/2822"
),
)
def
test_multi_example_connector_consistency
():
"""
Tests that MultiConnector with two ExampleConnectors saves
...
...
vllm/distributed/kv_transfer/kv_connector/v1/example_connector.py
View file @
87d319c5
...
...
@@ -17,6 +17,7 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.attention.mla_attention
import
MLACommonMetadata
from
vllm.utils.hashing
import
safe_hash
from
vllm.v1.attention.backend
import
AttentionMetadata
from
vllm.v1.attention.backends.triton_attn
import
TritonAttentionMetadata
from
vllm.v1.core.sched.output
import
SchedulerOutput
if
TYPE_CHECKING
:
...
...
@@ -118,12 +119,12 @@ class ExampleConnector(KVConnectorBase_V1):
The number of elements in kv_caches and layer_names should be
the same.
"""
attn_metadata
=
forward_context
.
attn_metadata
def
inject_kv_into_layer
(
dst_kv_cache_layer
:
torch
.
Tensor
,
src_kv_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
None
:
"""Inject the KV cache into the layer.
...
...
@@ -145,6 +146,10 @@ class ExampleConnector(KVConnectorBase_V1):
num_pages
*
page_size
,
-
1
)
dst_kv_cache_layer
[
slot_mapping
,
...]
=
src_kv_cache
elif
isinstance
(
attn_metadata
,
TritonAttentionMetadata
):
block_idxs
=
slot_mapping
//
self
.
_block_size
offsets
=
slot_mapping
%
self
.
_block_size
dst_kv_cache_layer
[
block_idxs
,
:,
offsets
]
=
src_kv_cache
else
:
num_pages
=
dst_kv_cache_layer_shape
[
1
]
page_size
=
dst_kv_cache_layer_shape
[
2
]
...
...
@@ -186,7 +191,13 @@ class ExampleConnector(KVConnectorBase_V1):
layer_name
,
request
.
token_ids
,
request
.
mm_hashes
)
kv_cache
=
safetensors
.
torch
.
load_file
(
filename
)[
"kv_cache"
].
cuda
()
inject_kv_into_layer
(
kv_cache_layer
,
kv_cache
,
request
.
slot_mapping
)
if
isinstance
(
attn_metadata
,
dict
):
inject_kv_into_layer
(
kv_cache_layer
,
kv_cache
,
request
.
slot_mapping
,
attn_metadata
[
layer_name
],
)
def
wait_for_layer_load
(
self
,
layer_name
:
str
)
->
None
:
"""Blocking until the KV for a specific layer is loaded into vLLM's
...
...
@@ -229,6 +240,10 @@ class ExampleConnector(KVConnectorBase_V1):
if
isinstance
(
attn_metadata
,
MLACommonMetadata
):
num_pages
,
page_size
=
layer
.
shape
[
0
],
layer
.
shape
[
1
]
return
layer
.
reshape
(
num_pages
*
page_size
,
-
1
)[
slot_mapping
,
...]
elif
isinstance
(
attn_metadata
,
TritonAttentionMetadata
):
block_idxs
=
slot_mapping
//
self
.
_block_size
offsets
=
slot_mapping
%
self
.
_block_size
return
layer
[
block_idxs
,
:,
offsets
]
num_pages
,
page_size
=
layer
.
shape
[
1
],
layer
.
shape
[
2
]
return
layer
.
reshape
(
2
,
num_pages
*
page_size
,
-
1
)[:,
slot_mapping
,
...]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment