Unverified Commit 543aa485 authored by afeldman-nm's avatar afeldman-nm Committed by GitHub
Browse files

[Kernel] Correctly invoke prefill & decode kernels for cross-attention...


[Kernel] Correctly invoke prefill & decode kernels for cross-attention (towards eventual encoder/decoder model support) (#4888)
Co-authored-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
parent f7a8fa39
...@@ -47,32 +47,32 @@ def test_flash_attn(monkeypatch): ...@@ -47,32 +47,32 @@ def test_flash_attn(monkeypatch):
# Unsupported CUDA arch # Unsupported CUDA arch
with patch("torch.cuda.get_device_capability", return_value=[7, 5]): with patch("torch.cuda.get_device_capability", return_value=[7, 5]):
backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
assert backend.name != "FLASH_ATTN" assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported data type # Unsupported data type
backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, 16) backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, 16)
assert backend.name != "FLASH_ATTN" assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported kv cache data type # Unsupported kv cache data type
backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16) backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16)
assert backend.name != "FLASH_ATTN" assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported block size # Unsupported block size
backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8) backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8)
assert backend.name != "FLASH_ATTN" assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported sliding window # Unsupported sliding window
backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16) backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16)
assert backend.name != "FLASH_ATTN" assert backend.name != STR_FLASH_ATTN_VAL
# flash-attn is not installed # flash-attn is not installed
with patch.dict('sys.modules', {'vllm_flash_attn': None}): with patch.dict('sys.modules', {'vllm_flash_attn': None}):
backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
assert backend.name != "FLASH_ATTN" assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported head size # Unsupported head size
backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16) backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16)
assert backend.name != "FLASH_ATTN" assert backend.name != STR_FLASH_ATTN_VAL
def test_invalid_env(monkeypatch): def test_invalid_env(monkeypatch):
......
This diff is collapsed.
This diff is collapsed.
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
from enum import Enum, auto
from typing import (Any, Dict, Generic, List, Optional, Set, Tuple, Type, from typing import (Any, Dict, Generic, List, Optional, Set, Tuple, Type,
TypeVar) TypeVar)
import torch import torch
class AttentionType(Enum):
DECODER = auto() # Decoder attention between previous layer Q/K/V
ENCODER = auto() # Encoder attention between previous layer Q/K/V
ENCODER_DECODER = auto() # Attention between dec. Q and enc. K/V
class AttentionBackend(ABC): class AttentionBackend(ABC):
"""Abstract class for attention backends.""" """Abstract class for attention backends."""
...@@ -128,5 +135,6 @@ class AttentionImpl(ABC, Generic[T]): ...@@ -128,5 +135,6 @@ class AttentionImpl(ABC, Generic[T]):
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: T, attn_metadata: T,
kv_scale: float = 1.0, kv_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
...@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type ...@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type
import torch import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata) AttentionMetadata, AttentionType)
from vllm.attention.ops.blocksparse_attention.interface import ( from vllm.attention.ops.blocksparse_attention.interface import (
LocalStridedBlockSparseAttn, get_head_sliding_step) LocalStridedBlockSparseAttn, get_head_sliding_step)
from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.ops.paged_attn import PagedAttention
...@@ -328,6 +328,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl): ...@@ -328,6 +328,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: BlocksparseFlashAttentionMetadata, attn_metadata: BlocksparseFlashAttentionMetadata,
kv_scale: float = 1.0, kv_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention. """Forward pass with FlashAttention and PagedAttention.
...@@ -340,6 +341,12 @@ class BlocksparseFlashAttentionImpl(AttentionImpl): ...@@ -340,6 +341,12 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"BlocksparseFlashAttentionImpl")
num_tokens, hidden_size = query.shape num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors. # Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
......
...@@ -7,7 +7,7 @@ from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache ...@@ -7,7 +7,7 @@ from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata) AttentionMetadata, AttentionType)
class FlashAttentionBackend(AttentionBackend): class FlashAttentionBackend(AttentionBackend):
...@@ -257,6 +257,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -257,6 +257,7 @@ class FlashAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata, attn_metadata: FlashAttentionMetadata,
kv_scale: float = 1.0, kv_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashAttention. """Forward pass with FlashAttention.
...@@ -269,6 +270,12 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -269,6 +270,12 @@ class FlashAttentionImpl(AttentionImpl):
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashAttentionImpl")
# NOTE(woosuk): FlashAttention does not support FP8 KV cache. # NOTE(woosuk): FlashAttention does not support FP8 KV cache.
assert kv_scale == 1.0, "kv_scale is not supported in FlashAttention." assert kv_scale == 1.0, "kv_scale is not supported in FlashAttention."
......
...@@ -14,7 +14,7 @@ import torch ...@@ -14,7 +14,7 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata) AttentionMetadata, AttentionType)
class FlashInferBackend(AttentionBackend): class FlashInferBackend(AttentionBackend):
...@@ -224,8 +224,14 @@ class FlashInferImpl(AttentionImpl): ...@@ -224,8 +224,14 @@ class FlashInferImpl(AttentionImpl):
kv_cache: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor],
attn_metadata: FlashInferMetadata, attn_metadata: FlashInferMetadata,
kv_scale: float = 1.0, kv_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
assert kv_scale == 1.0 assert kv_scale == 1.0
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashInferImpl")
num_tokens, hidden_size = query.shape num_tokens, hidden_size = query.shape
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size)
......
...@@ -7,7 +7,7 @@ import torch ...@@ -7,7 +7,7 @@ import torch
from vllm._ipex_ops import ipex_ops from vllm._ipex_ops import ipex_ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata) AttentionMetadata, AttentionType)
from vllm.attention.ops.paged_attn import (PagedAttention, from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata) PagedAttentionMetadata)
...@@ -157,6 +157,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): ...@@ -157,6 +157,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
kv_cache: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor],
attn_metadata: IpexAttnMetadata, # type: ignore attn_metadata: IpexAttnMetadata, # type: ignore
kv_scale: float = 1.0, kv_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with IPEX varlen_attention and PagedAttention. """Forward pass with IPEX varlen_attention and PagedAttention.
...@@ -170,6 +171,11 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): ...@@ -170,6 +171,11 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
assert kv_scale == 1.0 assert kv_scale == 1.0
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"IpexAttnBackendImpl")
num_tokens, hidden_size = query.shape num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors. # Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
......
...@@ -6,7 +6,7 @@ import torch_xla.experimental.custom_kernel # Required to register custom ops. ...@@ -6,7 +6,7 @@ import torch_xla.experimental.custom_kernel # Required to register custom ops.
import torch_xla.experimental.dynamo_set_buffer_donor import torch_xla.experimental.dynamo_set_buffer_donor
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata) AttentionMetadata, AttentionType)
class PallasAttentionBackend(AttentionBackend): class PallasAttentionBackend(AttentionBackend):
...@@ -132,6 +132,7 @@ class PallasAttentionBackendImpl(AttentionImpl): ...@@ -132,6 +132,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]], kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]],
attn_metadata: PallasMetadata, attn_metadata: PallasMetadata,
kv_scale: float = 1.0, kv_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with Pallas attention. """Forward pass with Pallas attention.
...@@ -146,6 +147,11 @@ class PallasAttentionBackendImpl(AttentionImpl): ...@@ -146,6 +147,11 @@ class PallasAttentionBackendImpl(AttentionImpl):
shape = [batch_size, seq_len, num_heads * head_size] shape = [batch_size, seq_len, num_heads * head_size]
""" """
assert kv_scale == 1.0 assert kv_scale == 1.0
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"PallasAttentionBackendImpl")
batch_size, seq_len, hidden_size = query.shape batch_size, seq_len, hidden_size = query.shape
query = query.view(batch_size, seq_len, self.num_heads, self.head_size) query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size) key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
......
...@@ -6,7 +6,7 @@ import torch ...@@ -6,7 +6,7 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata) AttentionMetadata, AttentionType)
from vllm.attention.ops.paged_attn import (PagedAttention, from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata) PagedAttentionMetadata)
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -297,6 +297,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -297,6 +297,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: ROCmFlashAttentionMetadata, attn_metadata: ROCmFlashAttentionMetadata,
kv_scale: float = 1.0, kv_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention. """Forward pass with FlashAttention and PagedAttention.
...@@ -309,6 +310,12 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -309,6 +310,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"ROCmFlashAttentionImpl")
num_tokens, hidden_size = query.shape num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors. # Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
......
...@@ -7,7 +7,7 @@ import torch ...@@ -7,7 +7,7 @@ import torch
from torch.nn.functional import scaled_dot_product_attention from torch.nn.functional import scaled_dot_product_attention
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata) AttentionMetadata, AttentionType)
from vllm.attention.ops.paged_attn import PagedAttentionMetadata from vllm.attention.ops.paged_attn import PagedAttentionMetadata
from vllm.utils import is_cpu from vllm.utils import is_cpu
...@@ -145,6 +145,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -145,6 +145,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
kv_cache: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor],
attn_metadata: TorchSDPAMetadata, # type: ignore attn_metadata: TorchSDPAMetadata, # type: ignore
kv_scale: float = 1.0, kv_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention. """Forward pass with torch SDPA and PagedAttention.
...@@ -158,6 +159,11 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -158,6 +159,11 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
assert kv_scale == 1.0 assert kv_scale == 1.0
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"TorchSDPABackendImpl")
num_tokens, hidden_size = query.shape num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors. # Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
......
"""Attention backend utils"""
# Error string(s) for encoder/decoder
# unsupported attention scenarios
STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
"with encoder/decoder models.")
This diff is collapsed.
...@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional ...@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata, AttentionType
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
...@@ -90,9 +90,16 @@ class Attention(nn.Module): ...@@ -90,9 +90,16 @@ class Attention(nn.Module):
value: torch.Tensor, value: torch.Tensor,
kv_cache: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
return self.impl.forward(query, key, value, kv_cache, attn_metadata,
self._kv_scale) return self.impl.forward(query,
key,
value,
kv_cache,
attn_metadata,
self._kv_scale,
attn_type=attn_type)
def extra_repr(self) -> str: def extra_repr(self) -> str:
s = f"head_size={self.impl.head_size}" # type: ignore s = f"head_size={self.impl.head_size}" # type: ignore
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment