Unverified Commit 87d319c5 authored by Ryan Rock's avatar Ryan Rock Committed by GitHub
Browse files

[AMD][CI] Support Triton attention with ExampleConnector (#34931)


Signed-off-by: default avatarRyan Rock <ryan.rock@amd.com>
parent a9ec392c
...@@ -8,7 +8,7 @@ from PIL import Image ...@@ -8,7 +8,7 @@ from PIL import Image
from vllm import LLM, EngineArgs, SamplingParams from vllm import LLM, EngineArgs, SamplingParams
from vllm.assets.image import ImageAsset from vllm.assets.image import ImageAsset
from vllm.config import KVTransferConfig from vllm.config import AttentionConfig, KVTransferConfig
from vllm.multimodal.utils import encode_image_url from vllm.multimodal.utils import encode_image_url
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -110,14 +110,17 @@ def process_prompt(processor, llm: LLM, question: str, image_urls: list[Image]): ...@@ -110,14 +110,17 @@ def process_prompt(processor, llm: LLM, question: str, image_urls: list[Image]):
print("-" * 50) print("-" * 50)
@pytest.mark.skipif( @pytest.mark.parametrize(
current_platform.is_rocm(), "attn_backend",
reason=( (
"hipErrorLaunchFailure when running this test, see issue:" ["FLASH_ATTN", "TRITON_ATTN"]
"https://github.com/ROCm/pytorch/issues/2822" if current_platform.is_cuda()
else ["TRITON_ATTN"]
if current_platform.is_rocm()
else []
), ),
) )
def test_shared_storage_connector_hashes(tmp_path): def test_shared_storage_connector_hashes(tmp_path, attn_backend):
""" """
Tests that ExampleConnector saves KV to the storage locations Tests that ExampleConnector saves KV to the storage locations
with proper hashes; that are unique for inputs with identical text but with proper hashes; that are unique for inputs with identical text but
...@@ -138,6 +141,7 @@ def test_shared_storage_connector_hashes(tmp_path): ...@@ -138,6 +141,7 @@ def test_shared_storage_connector_hashes(tmp_path):
max_model_len=8192, max_model_len=8192,
max_num_seqs=1, max_num_seqs=1,
gpu_memory_utilization=0.4, gpu_memory_utilization=0.4,
attention_config=AttentionConfig(backend=attn_backend),
enforce_eager=True, enforce_eager=True,
kv_transfer_config=kv_transfer_config, kv_transfer_config=kv_transfer_config,
limit_mm_per_prompt={"image": 2}, limit_mm_per_prompt={"image": 2},
......
...@@ -20,7 +20,6 @@ from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import ( ...@@ -20,7 +20,6 @@ from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlKVConnectorStats, NixlKVConnectorStats,
) )
from vllm.platforms import current_platform
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
...@@ -97,13 +96,6 @@ def _compare_directories(dir1: Path, dir2: Path) -> bool: ...@@ -97,13 +96,6 @@ def _compare_directories(dir1: Path, dir2: Path) -> bool:
return True return True
@pytest.mark.skipif(
current_platform.is_rocm(),
reason=(
"hipErrorLaunchFailure when running this test, see issue:"
"https://github.com/ROCm/pytorch/issues/2822"
),
)
def test_multi_example_connector_consistency(): def test_multi_example_connector_consistency():
""" """
Tests that MultiConnector with two ExampleConnectors saves Tests that MultiConnector with two ExampleConnectors saves
......
...@@ -17,6 +17,7 @@ from vllm.logger import init_logger ...@@ -17,6 +17,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadata from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadata
from vllm.utils.hashing import safe_hash from vllm.utils.hashing import safe_hash
from vllm.v1.attention.backend import AttentionMetadata from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -118,12 +119,12 @@ class ExampleConnector(KVConnectorBase_V1): ...@@ -118,12 +119,12 @@ class ExampleConnector(KVConnectorBase_V1):
The number of elements in kv_caches and layer_names should be The number of elements in kv_caches and layer_names should be
the same. the same.
""" """
attn_metadata = forward_context.attn_metadata
def inject_kv_into_layer( def inject_kv_into_layer(
dst_kv_cache_layer: torch.Tensor, dst_kv_cache_layer: torch.Tensor,
src_kv_cache: torch.Tensor, src_kv_cache: torch.Tensor,
slot_mapping: torch.Tensor, slot_mapping: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> None: ) -> None:
"""Inject the KV cache into the layer. """Inject the KV cache into the layer.
...@@ -145,6 +146,10 @@ class ExampleConnector(KVConnectorBase_V1): ...@@ -145,6 +146,10 @@ class ExampleConnector(KVConnectorBase_V1):
num_pages * page_size, -1 num_pages * page_size, -1
) )
dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
elif isinstance(attn_metadata, TritonAttentionMetadata):
block_idxs = slot_mapping // self._block_size
offsets = slot_mapping % self._block_size
dst_kv_cache_layer[block_idxs, :, offsets] = src_kv_cache
else: else:
num_pages = dst_kv_cache_layer_shape[1] num_pages = dst_kv_cache_layer_shape[1]
page_size = dst_kv_cache_layer_shape[2] page_size = dst_kv_cache_layer_shape[2]
...@@ -186,7 +191,13 @@ class ExampleConnector(KVConnectorBase_V1): ...@@ -186,7 +191,13 @@ class ExampleConnector(KVConnectorBase_V1):
layer_name, request.token_ids, request.mm_hashes layer_name, request.token_ids, request.mm_hashes
) )
kv_cache = safetensors.torch.load_file(filename)["kv_cache"].cuda() kv_cache = safetensors.torch.load_file(filename)["kv_cache"].cuda()
inject_kv_into_layer(kv_cache_layer, kv_cache, request.slot_mapping) if isinstance(attn_metadata, dict):
inject_kv_into_layer(
kv_cache_layer,
kv_cache,
request.slot_mapping,
attn_metadata[layer_name],
)
def wait_for_layer_load(self, layer_name: str) -> None: def wait_for_layer_load(self, layer_name: str) -> None:
"""Blocking until the KV for a specific layer is loaded into vLLM's """Blocking until the KV for a specific layer is loaded into vLLM's
...@@ -229,6 +240,10 @@ class ExampleConnector(KVConnectorBase_V1): ...@@ -229,6 +240,10 @@ class ExampleConnector(KVConnectorBase_V1):
if isinstance(attn_metadata, MLACommonMetadata): if isinstance(attn_metadata, MLACommonMetadata):
num_pages, page_size = layer.shape[0], layer.shape[1] num_pages, page_size = layer.shape[0], layer.shape[1]
return layer.reshape(num_pages * page_size, -1)[slot_mapping, ...] return layer.reshape(num_pages * page_size, -1)[slot_mapping, ...]
elif isinstance(attn_metadata, TritonAttentionMetadata):
block_idxs = slot_mapping // self._block_size
offsets = slot_mapping % self._block_size
return layer[block_idxs, :, offsets]
num_pages, page_size = layer.shape[1], layer.shape[2] num_pages, page_size = layer.shape[1], layer.shape[2]
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, ...] return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, ...]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment