Unverified commit 87d319c5, authored by Ryan Rock, committed by GitHub
Browse files

[AMD][CI] Support Triton attention with ExampleConnector (#34931)


Signed-off-by: Ryan Rock <ryan.rock@amd.com>
parent a9ec392c
......@@ -8,7 +8,7 @@ from PIL import Image
from vllm import LLM, EngineArgs, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.config import KVTransferConfig
from vllm.config import AttentionConfig, KVTransferConfig
from vllm.multimodal.utils import encode_image_url
from vllm.platforms import current_platform
......@@ -110,14 +110,17 @@ def process_prompt(processor, llm: LLM, question: str, image_urls: list[Image]):
print("-" * 50)
@pytest.mark.skipif(
current_platform.is_rocm(),
reason=(
"hipErrorLaunchFailure when running this test, see issue:"
"https://github.com/ROCm/pytorch/issues/2822"
@pytest.mark.parametrize(
"attn_backend",
(
["FLASH_ATTN", "TRITON_ATTN"]
if current_platform.is_cuda()
else ["TRITON_ATTN"]
if current_platform.is_rocm()
else []
),
)
def test_shared_storage_connector_hashes(tmp_path):
def test_shared_storage_connector_hashes(tmp_path, attn_backend):
"""
Tests that ExampleConnector saves KV to the storage locations
with proper hashes; that are unique for inputs with identical text but
......@@ -138,6 +141,7 @@ def test_shared_storage_connector_hashes(tmp_path):
max_model_len=8192,
max_num_seqs=1,
gpu_memory_utilization=0.4,
attention_config=AttentionConfig(backend=attn_backend),
enforce_eager=True,
kv_transfer_config=kv_transfer_config,
limit_mm_per_prompt={"image": 2},
......
......@@ -20,7 +20,6 @@ from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlKVConnectorStats,
)
from vllm.platforms import current_platform
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
......@@ -97,13 +96,6 @@ def _compare_directories(dir1: Path, dir2: Path) -> bool:
return True
@pytest.mark.skipif(
current_platform.is_rocm(),
reason=(
"hipErrorLaunchFailure when running this test, see issue:"
"https://github.com/ROCm/pytorch/issues/2822"
),
)
def test_multi_example_connector_consistency():
"""
Tests that MultiConnector with two ExampleConnectors saves
......
......@@ -17,6 +17,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadata
from vllm.utils.hashing import safe_hash
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
......@@ -118,12 +119,12 @@ class ExampleConnector(KVConnectorBase_V1):
The number of elements in kv_caches and layer_names should be
the same.
"""
attn_metadata = forward_context.attn_metadata
def inject_kv_into_layer(
dst_kv_cache_layer: torch.Tensor,
src_kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> None:
"""Inject the KV cache into the layer.
......@@ -145,6 +146,10 @@ class ExampleConnector(KVConnectorBase_V1):
num_pages * page_size, -1
)
dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
elif isinstance(attn_metadata, TritonAttentionMetadata):
block_idxs = slot_mapping // self._block_size
offsets = slot_mapping % self._block_size
dst_kv_cache_layer[block_idxs, :, offsets] = src_kv_cache
else:
num_pages = dst_kv_cache_layer_shape[1]
page_size = dst_kv_cache_layer_shape[2]
......@@ -186,7 +191,13 @@ class ExampleConnector(KVConnectorBase_V1):
layer_name, request.token_ids, request.mm_hashes
)
kv_cache = safetensors.torch.load_file(filename)["kv_cache"].cuda()
inject_kv_into_layer(kv_cache_layer, kv_cache, request.slot_mapping)
if isinstance(attn_metadata, dict):
inject_kv_into_layer(
kv_cache_layer,
kv_cache,
request.slot_mapping,
attn_metadata[layer_name],
)
def wait_for_layer_load(self, layer_name: str) -> None:
"""Blocking until the KV for a specific layer is loaded into vLLM's
......@@ -229,6 +240,10 @@ class ExampleConnector(KVConnectorBase_V1):
if isinstance(attn_metadata, MLACommonMetadata):
num_pages, page_size = layer.shape[0], layer.shape[1]
return layer.reshape(num_pages * page_size, -1)[slot_mapping, ...]
elif isinstance(attn_metadata, TritonAttentionMetadata):
block_idxs = slot_mapping // self._block_size
offsets = slot_mapping % self._block_size
return layer[block_idxs, :, offsets]
num_pages, page_size = layer.shape[1], layer.shape[2]
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, ...]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment