Commit afd0da21 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.7.1' into v0.7.1-dev

parents 1a11f127 4f4d427a
......@@ -31,7 +31,7 @@ if __name__ == "__main__":
type=str,
required=True,
help="json trace file output by "
"examples/offline_profile.py")
"examples/offline_inference/profiling.py")
parser.add_argument("--phase",
type=str,
required=True,
......
......@@ -534,11 +534,11 @@ def main(
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--json-trace",
type=str,
required=True,
help="json trace file output by examples/offline_profile.py")
parser.add_argument("--json-trace",
type=str,
required=True,
help="json trace file output by \
examples/offline_inference/profiling.py")
parser.add_argument("--output-directory",
type=str,
required=False,
......
......@@ -274,8 +274,9 @@ def SummarizeEntries(entries, extra_step_types):
print(' {:.1f} s weighted time ({:.1f} s elapsed time sum, {:1.1f}x '
'parallelism)'.format(length, total_cpu_time,
total_cpu_time * 1.0 / length))
print(' %d build steps completed, average of %1.2f/s' %
(len(entries), len(entries) / (length)))
print(' {} build steps completed, average of {:1.2f}/s'.format(
len(entries),
len(entries) / (length)))
def main():
......
......@@ -19,4 +19,4 @@ if ! [ -x "$(command -v shellcheck)" ]; then
fi
# TODO - fix warnings in .buildkite/run-amd-test.sh
find . -name "*.sh" -not -path "./.buildkite/run-amd-test.sh" -print0 | xargs -0 -I {} sh -c 'git check-ignore -q "{}" || shellcheck "{}"'
find . -name "*.sh" ".git" -prune -not -path "./.buildkite/run-amd-test.sh" -print0 | xargs -0 -I {} sh -c 'git check-ignore -q "{}" || shellcheck -s bash "{}"'
#!/bin/bash
sphinx-lint --disable trailing-whitespace,missing-final-newline docs
"""vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
import os
import torch
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
......@@ -17,6 +20,19 @@ from vllm.sampling_params import SamplingParams
from vllm.version import __version__, __version_tuple__, __hcu_version__
# set some common config/environment variables that should be set
# for all processes created by vllm and all processes
# that interact with vllm workers.
# they are executed whenever `import vllm` is called.
# see https://github.com/NVIDIA/nccl/issues/1234
os.environ['NCCL_CUMEM_ENABLE'] = '0'
# see https://github.com/vllm-project/vllm/issues/10480
os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1'
# see https://github.com/vllm-project/vllm/issues/10619
torch._inductor.config.compile_threads = 1
__all__ = [
"__version__",
"__version_tuple__",
......
......@@ -30,8 +30,7 @@ with contextlib.suppress(ImportError):
import vllm._moe_C # noqa: F401
supports_moe_ops = True
# neuron has torch version that doesn't even have impl_abstract
if TYPE_CHECKING or current_platform.is_neuron():
if TYPE_CHECKING:
def register_fake(fn):
return lambda name: fn
......@@ -42,49 +41,6 @@ else:
from torch.library import impl_abstract as register_fake
# activation ops
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.silu_and_mul(out, x)
def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_and_mul(out, x)
def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_tanh_and_mul(out, x)
def silu_and_mul_opt(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.silu_and_mul_opt(out, x)
def gelu_and_mul_opt(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_and_mul_opt(out, x)
def gelu_tanh_and_mul_opt(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_tanh_and_mul_opt(out, x)
def fatrelu_and_mul(out: torch.Tensor,
x: torch.Tensor,
threshold: float = 0.0) -> None:
torch.ops._C.fatrelu_and_mul(out, x, threshold)
def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_fast(out, x)
def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_new(out, x)
def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_quick(out, x)
# page attention ops
def paged_attention_v1(
out: torch.Tensor,
......@@ -99,8 +55,8 @@ def paged_attention_v1(
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
......@@ -131,8 +87,8 @@ def paged_attention_v2(
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
......@@ -486,8 +442,8 @@ def paged_attention_rocm(
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
) -> None:
torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
key_cache, value_cache, num_kv_heads,
......@@ -866,16 +822,43 @@ def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
def cutlass_scaled_mm_supports_block_fp8(cuda_device_capability: int) -> bool:
return torch.ops._C.cutlass_scaled_mm_supports_block_fp8(
cuda_device_capability)
def cutlass_scaled_mm(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
# assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
# assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
# assert bias is None or bias.shape[0] == b.shape[
# 1] and bias.dtype == out_dtype
"""
`cutlass_scaled_mm` implements a fused version of
`output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
where scale_a * a and scale_b * b are implemented using numpy-style
broadcasting.
In order to support blockwise scaling like found in DeepSeek V3 we also
support extended "group" broadcast rules. We extend the numpy-style
broadcasting rules with the following rule:
"if the extent of a dimension in the source shape is between 1 and
corresponding extent in the target shape we repeat each element along
that dimension src_shape[dim] // target_shape[dim] times consecutively"
example if we have:
a = [[1, 2], and target_shape = (2, 4)
[3, 4]]
then we would expand a to:
a = [[1, 1, 2, 2],
[3, 3, 4, 4]]
currently we only support the case:
scale_a.shape * [1, 128] == a.shape
scale_b.shape * [128, 128] == b.shape
"""
assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
assert bias is None or bias.shape[0] == b.shape[
1] and bias.dtype == out_dtype
# m = a.shape[0]
# n = b.shape[1]
......@@ -1283,8 +1266,8 @@ def scaled_int8_quant(
if scale is not None:
# static-per-tensor quantization.
assert symmetric == (
azp is
None), "azp must only be provided for asymmetric quantization."
azp
is None), "azp must only be provided for asymmetric quantization."
torch.ops._C.static_scaled_int8_quant(output, input, scale, azp)
return output, scale, azp
......@@ -1419,8 +1402,8 @@ def reshape_and_cache(
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
) -> None:
torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache,
value_cache, slot_mapping,
......@@ -1434,8 +1417,8 @@ def reshape_and_cache_flash(
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
) -> None:
torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache,
value_cache, slot_mapping,
......@@ -1443,6 +1426,19 @@ def reshape_and_cache_flash(
v_scale)
def concat_and_cache_mla(
kv_c: torch.Tensor,
k_pe: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
scale: torch.Tensor,
) -> None:
torch.ops._C_cache_ops.concat_and_cache_mla(kv_c, k_pe, kv_cache,
slot_mapping, kv_cache_dtype,
scale)
def copy_blocks(key_caches: List[torch.Tensor],
value_caches: List[torch.Tensor],
block_mapping: torch.Tensor) -> None:
......
......@@ -21,12 +21,10 @@ class AudioAsset:
name: Literal["winning_call", "mary_had_lamb"]
@property
def audio_and_sample_rate(self) -> tuple[npt.NDArray, int]:
def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]:
audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg",
s3_prefix=ASSET_DIR)
y, sr = librosa.load(audio_path, sr=None)
assert isinstance(sr, int)
return y, sr
return librosa.load(audio_path, sr=None)
@property
def url(self) -> str:
......
......@@ -26,4 +26,4 @@ class ImageAsset:
"""
image_path = get_vllm_public_assets(filename=f"{self.name}.pt",
s3_prefix=VLM_IMAGES_DIR)
return torch.load(image_path, map_location="cpu")
return torch.load(image_path, map_location="cpu", weights_only=True)
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, fields
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
Tuple, Type, TypeVar)
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional,
Protocol, Set, Tuple, Type, TypeVar)
import torch
......@@ -31,6 +31,10 @@ class AttentionType:
class AttentionBackend(ABC):
"""Abstract class for attention backends."""
# For some attention backends, we allocate an output tensor before
# calling the custom op. When piecewise cudagraph is enabled, this
# makes sure the output tensor is allocated inside the cudagraph.
accept_output_buffer: bool = False
@staticmethod
@abstractmethod
......@@ -61,11 +65,6 @@ class AttentionBackend(ABC):
def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
raise NotImplementedError
@classmethod
def make_metadata_builder(cls, *args,
**kwargs) -> "AttentionMetadataBuilder":
return cls.get_builder_cls()(*args, **kwargs)
@staticmethod
@abstractmethod
def get_kv_cache_shape(
......@@ -124,6 +123,10 @@ class AttentionMetadata:
multi_modal_placeholder_index_maps: Optional[Dict[
str, MultiModalPlaceholderMap.IndexMap]]
# Enable/disable KV scales calculation. This is so that we can disable the
# calculation until after prefill and cuda graph capture.
enable_kv_scales_calculation: bool
@property
@abstractmethod
def prefill_metadata(self) -> Optional["AttentionMetadata"]:
......@@ -210,6 +213,12 @@ class AttentionMetadataBuilder(ABC, Generic[T]):
@abstractmethod
def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:
"""Create the builder, remember some configuration and parameters."""
raise NotImplementedError
@abstractmethod
def prepare(self) -> None:
"""Prepare for one batch."""
raise NotImplementedError
@abstractmethod
......@@ -219,6 +228,24 @@ class AttentionMetadataBuilder(ABC, Generic[T]):
raise NotImplementedError
class AttentionLayer(Protocol):
_k_scale: torch.Tensor
_v_scale: torch.Tensor
_k_scale_float: float
_v_scale_float: float
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
...
class AttentionImpl(ABC, Generic[T]):
@abstractmethod
......@@ -233,20 +260,35 @@ class AttentionImpl(ABC, Generic[T]):
kv_cache_dtype: str = "auto",
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
raise NotImplementedError
@abstractmethod
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: T,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
@abstractmethod
def forward(
self,
layer: AttentionLayer,
hidden_states_or_cq: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: T,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
......@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type
import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import (CommonAttentionState,
CommonMetadataBuilder)
......@@ -90,8 +91,7 @@ class BlocksparseFlashAttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
# For attention layer compatibility
return "FLASH_ATTN"
return "BLOCK_SPARSE_FLASH_ATTN"
@staticmethod
def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]:
......@@ -225,6 +225,7 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
max_query_len=self.max_query_len,
......@@ -255,6 +256,7 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
num_decode_tokens=self.num_decode_tokens,
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_query_len=None,
......@@ -305,6 +307,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
assert blocksparse_params is not None
assert alibi_slopes is None, ValueError(
......@@ -355,16 +358,20 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
active_head_range=self.blocksparse_params.active_head_range,
)
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"BlocksparseFlashAttentionImpl")
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: BlocksparseFlashAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
......@@ -380,12 +387,6 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"BlocksparseFlashAttentionImpl")
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
......@@ -407,8 +408,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)
if prefill_meta := attn_metadata.prefill_metadata:
......@@ -445,8 +446,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
self.num_kv_heads,
self.scale,
self.alibi_slopes,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
tp_rank=self.tp_rank,
blocksparse_local_blocks=self.local_blocks,
blocksparse_vert_stride=self.vert_stride,
......
......@@ -8,6 +8,7 @@ import torch
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionType)
......@@ -16,19 +17,27 @@ from vllm.attention.backends.utils import (
compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
is_all_encoder_attn_metadata_set, is_block_tables_empty)
from vllm.envs import VLLM_FLASH_ATTN_VERSION
from vllm.logger import init_logger
from vllm.multimodal import MultiModalPlaceholderMap
from vllm.platforms import current_platform
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
flash_attn_varlen_func,
flash_attn_with_kvcache,
is_fa_version_supported)
if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata)
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
flash_attn_with_kvcache)
logger = init_logger(__name__)
class FlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_supported_head_sizes() -> List[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
......@@ -224,6 +233,7 @@ class FlashAttentionMetadata(AttentionMetadata):
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=self.max_query_len,
......@@ -268,6 +278,7 @@ class FlashAttentionMetadata(AttentionMetadata):
num_decode_tokens=self.num_decode_tokens,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
seq_lens=None,
seq_lens_tensor=seq_lens_tensor,
max_decode_query_len=self.max_decode_query_len,
......@@ -372,6 +383,12 @@ class FlashAttentionMetadataBuilder(
AttentionMetadataBuilder[FlashAttentionMetadata]):
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.input_builder = input_builder
self.runner = input_builder.runner
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
def prepare(self):
self.slot_mapping: List[int] = []
self.prefill_seq_lens: List[int] = []
self.context_lens: List[int] = []
......@@ -385,11 +402,6 @@ class FlashAttentionMetadataBuilder(
self.num_decode_tokens = 0
self.has_prefix_cache_hit = False
self.input_builder = input_builder
self.runner = input_builder.runner
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
chunked_prefill_enabled: bool, prefix_cache_hit: bool):
......@@ -552,6 +564,7 @@ class FlashAttentionMetadataBuilder(
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=True,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_decode_query_len=max_decode_query_len,
......@@ -602,6 +615,7 @@ class FlashAttentionImpl(AttentionImpl):
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
if blocksparse_params is not None:
raise ValueError(
......@@ -629,17 +643,35 @@ class FlashAttentionImpl(AttentionImpl):
raise ValueError(
f"Head size {head_size} is not supported by FlashAttention. "
f"Supported head sizes are: {support_head_sizes}.")
self.attn_type = attn_type
# if hopper default to FA3, otherwise stick to FA2 for now
# TODO(lucas): profile FA3 on ampere to see if it makes sense to
# use FA3 as default for both
if current_platform.get_device_capability()[0] >= 9:
self.fa_version = 3 if is_fa_version_supported(3) else 2
else:
self.fa_version = 2
if VLLM_FLASH_ATTN_VERSION is not None:
assert VLLM_FLASH_ATTN_VERSION in [2, 3]
self.fa_version = VLLM_FLASH_ATTN_VERSION
if not is_fa_version_supported(self.fa_version):
logger.error("Cannot use FA version %d is not supported due to %s",
self.fa_version,
fa_version_unsupported_reason(self.fa_version))
assert is_fa_version_supported(self.fa_version)
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
......@@ -656,11 +688,12 @@ class FlashAttentionImpl(AttentionImpl):
NOTE: It in-place updates the output tensor.
"""
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
assert k_scale == 1.0 and v_scale == 1.0, (
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0, (
"key/v_scale is not supported in FlashAttention.")
assert output is not None, "Output tensor must be provided."
attn_type = self.attn_type
if (attn_type == AttentionType.ENCODER
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
raise AttributeError("Encoder attention requires setting "
......@@ -707,8 +740,8 @@ class FlashAttentionImpl(AttentionImpl):
kv_cache[1],
updated_slot_mapping.flatten(), # type: ignore[union-attr]
kv_cache_dtype,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)
(num_prefill_query_tokens, num_prefill_kv_tokens,
......@@ -749,6 +782,7 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
out=prefill_output,
fa_version=self.fa_version,
)
else:
# prefix-enabled attention
......@@ -762,7 +796,7 @@ class FlashAttentionImpl(AttentionImpl):
v=value_cache,
cu_seqlens_q=prefill_meta.query_start_loc,
max_seqlen_q=prefill_meta.max_query_len,
cu_seqlens_k=prefill_meta.seq_start_loc,
seqused_k=prefill_meta.seq_lens_tensor,
max_seqlen_k=max_seq_len,
softmax_scale=softmax_scale,
causal=True,
......@@ -771,6 +805,7 @@ class FlashAttentionImpl(AttentionImpl):
block_table=prefill_meta.block_tables,
softcap=logits_soft_cap,
out=prefill_output,
fa_version=self.fa_version,
)
if decode_meta := attn_metadata.decode_metadata:
......@@ -790,7 +825,7 @@ class FlashAttentionImpl(AttentionImpl):
v=value_cache,
cu_seqlens_q=decode_meta.query_start_loc,
max_seqlen_q=decode_meta.max_decode_query_len,
cu_seqlens_k=decode_meta.seq_start_loc,
seqused_k=decode_meta.seq_lens_tensor,
max_seqlen_k=decode_meta.max_decode_seq_len,
softmax_scale=softmax_scale,
causal=True,
......@@ -799,6 +834,7 @@ class FlashAttentionImpl(AttentionImpl):
softcap=logits_soft_cap,
block_table=decode_meta.block_tables,
out=decode_output,
fa_version=self.fa_version,
)
else:
# Use flash_attn_with_kvcache for normal decoding.
......@@ -819,6 +855,7 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
out=decode_output.unsqueeze(1),
fa_version=self.fa_version,
)
return output
......
import dataclasses
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
......@@ -13,9 +14,11 @@ try:
from vllm.vllm_flash_attn import flash_attn_varlen_func
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
except ImportError:
BatchDecodeWithPagedKVCacheWrapper = None
CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
BatchPrefillWithPagedKVCacheWrapper = None
# Avoid turning these types into variables during type checking
if not TYPE_CHECKING:
BatchDecodeWithPagedKVCacheWrapper = None
CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
BatchPrefillWithPagedKVCacheWrapper = None
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
import torch
......@@ -23,13 +26,16 @@ import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionState, AttentionType)
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.attention.layer import Attention
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
make_tensor_with_pad)
......@@ -98,6 +104,72 @@ class FlashInferBackend(AttentionBackend):
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
@dataclass
class PerLayerParameters:
"""
Currently, FlashInfer backend only support models in which all layers share
the same values for the following hyperparameters.
"""
window_left: int
logits_soft_cap: Optional[float]
sm_scale: float
def get_per_layer_parameters(
vllm_config: VllmConfig) -> Dict[str, PerLayerParameters]:
"""
Scan all attention layers and determine some hyperparameters
to use during `plan`.
"""
layers = vllm_config.compilation_config.static_forward_context
per_layer_params: Dict[str, PerLayerParameters] = {}
for key, layer in layers.items():
assert isinstance(layer, Attention)
impl = layer.impl
assert isinstance(impl, FlashInferImpl)
# Infer hyperparameters from the attention layer
window_size = impl.sliding_window
window_left = window_size[0] if window_size is not None else -1
logits_soft_cap = impl.logits_soft_cap
sm_scale = impl.scale
per_layer_params[key] = PerLayerParameters(window_left,
logits_soft_cap, sm_scale)
return per_layer_params
def infer_global_hyperparameters(
per_layer_params: Dict[str, PerLayerParameters]) -> PerLayerParameters:
"""
Currently, FlashInfer backend only support models in which all layers share
the same values for the following hyperparameters:
- `window_left`
- `logits_soft_cap`
- `sm_scale`
So this function asserts that all layers share the same values for these
hyperparameters and returns the global values.
"""
assert len(per_layer_params) > 0, "No attention layers found in the model."
param_sets = list(per_layer_params.values())
global_params = param_sets[0]
for params in param_sets:
assert params == global_params, (
"FlashInfer backend currently only supports models in which all "
"layers share the same values for the following hyperparameters: "
"`window_left`, `logits_soft_cap`, `sm_scale`.")
return global_params
class FlashInferState(AttentionState):
def __init__(self, runner):
......@@ -107,6 +179,11 @@ class FlashInferState(AttentionState):
self._decode_wrapper = None
self._prefill_wrapper = None
# Global hyperparameters shared by all attention layers
self.global_hyperparameters: Optional[PerLayerParameters] = None
self.vllm_config = get_current_vllm_config()
def _get_workspace_buffer(self):
if self._workspace_buffer is None:
self._workspace_buffer = torch.empty(
......@@ -214,10 +291,14 @@ class FlashInferState(AttentionState):
batch_size + 1,
dtype=torch.int32)
global_params = infer_global_hyperparameters(
get_per_layer_parameters(self.vllm_config))
attn_metadata = self.runner.attn_backend.make_metadata(
num_prefills=0,
slot_mapping=self._graph_slot_mapping[:batch_size],
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
num_prefill_tokens=0,
num_decode_tokens=batch_size,
max_prefill_seq_len=0,
......@@ -236,7 +317,9 @@ class FlashInferState(AttentionState):
q_data_type=self.runner.model_config.dtype,
use_cuda_graph=True,
decode_wrapper=self._graph_decode_wrapper,
prefill_wrapper=None)
prefill_wrapper=None,
**dataclasses.asdict(global_params),
)
attn_metadata.begin_forward()
return attn_metadata
......@@ -256,7 +339,12 @@ class FlashInferState(AttentionState):
def begin_forward(self, model_input):
assert not self._is_graph_capturing
state = self
if model_input.attn_metadata.use_cuda_graph:
use_cuda_graph = model_input.attn_metadata.use_cuda_graph
is_decode = model_input.attn_metadata.num_prefills == 0
# In case of multistep chunked-prefill, there might be prefill requests
# scheduled while CUDA graph mode is enabled. We don't run graph in that
# case.
if use_cuda_graph and is_decode:
batch_size = model_input.input_tokens.shape[0]
state = (self.runner.graph_runners[model_input.virtual_engine]
[batch_size].attn_state)
......@@ -318,9 +406,28 @@ class FlashInferMetadata(AttentionMetadata):
data_type: torch.dtype = None
# The data type of the query
q_data_type: torch.dtype = None
device: torch.device = torch.device("cuda")
# FlashInfer 0.2 encourages passing host tensors
device: torch.device = torch.device("cpu")
is_profile_run: bool = False
# The FlashInfer backend currently supports only models in which all layers
# share the same following hyperparameters:
# The left (inclusive) window size for the attention window, when
# set to `-1`, the window size will be set to the full length of
# the sequence. Defaults to `-1`.
window_left: int = -1
# The attention logits soft capping value (used in Gemini, Grok and
# Gemma-2, etc.), if not provided, will be set to `0`. If greater
# than 0, the logits will be capped according to formula:
# $$\texttt{logits\_soft\_cap} \times
# \mathrm{tanh}(x / \texttt{logits\_soft\_cap})$$,
# where $x$ is the input logits.
logits_soft_cap: Optional[float] = None
# The scale used in softmax, if not provided, will be set to
# `1.0 / sqrt(head_dim)`.
sm_scale: Optional[float] = None
def __post_init__(self):
# Refer to
# https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
......@@ -356,14 +463,21 @@ class FlashInferMetadata(AttentionMetadata):
self.block_table_bound = self.block_table_bound.to(self.device)
self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
self.paged_kv_indices = self.paged_kv_indices.to(self.device)
self.prefill_wrapper.end_forward()
self.prefill_wrapper.begin_forward(
self.prefill_wrapper.plan(
self.query_start_loc,
self.paged_kv_indptr[:self.num_prefills + 1],
self.paged_kv_indices,
self.paged_kv_last_page_len[:self.num_prefills],
self.num_qo_heads, self.num_kv_heads, self.head_dim,
self.page_size)
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
self.page_size,
causal=True,
sm_scale=self.sm_scale,
window_left=self.window_left,
logits_soft_cap=self.logits_soft_cap,
q_data_type=self.q_data_type,
kv_data_type=self.data_type)
if self.num_decode_tokens > 0:
assert self.paged_kv_indices is not None
assert self.paged_kv_indptr is not None
......@@ -379,8 +493,7 @@ class FlashInferMetadata(AttentionMetadata):
self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
assert self.decode_wrapper is not None
self.decode_wrapper.end_forward()
self.decode_wrapper.begin_forward(
self.decode_wrapper.plan(
self.paged_kv_indptr[self.num_prefills:],
self.paged_kv_indices,
self.paged_kv_last_page_len[self.num_prefills:],
......@@ -390,8 +503,11 @@ class FlashInferMetadata(AttentionMetadata):
self.page_size,
# Disable flashinfer's pos encoding and use vllm's rope.
pos_encoding_mode="NONE",
window_left=self.window_left,
logits_soft_cap=self.logits_soft_cap,
sm_scale=self.sm_scale,
# kv-cache data type.
data_type=self.data_type,
kv_data_type=self.data_type,
# query data type.
q_data_type=self.q_data_type)
......@@ -429,10 +545,24 @@ class FlashInferMetadata(AttentionMetadata):
Update metadata in-place to advance one decode step.
"""
assert not turn_prefills_into_decodes, \
("Chunked prefill is not supported with flashinfer yet."
"turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
"specific parameter.")
if turn_prefills_into_decodes:
# When Multi-Step is enabled with Chunked-Prefill, prefills and
# decodes are scheduled together. In the first step, all the
# prefills turn into decodes. This update reflects that
# conversion.
assert self.num_decode_tokens + self.num_prefills == num_seqs
# Flashinfer doesn't support speculative decoding + chunked-prefill
# + multi-step scheduling yet.
assert self.decode_query_len == 1
self.num_decode_tokens += self.num_prefills
self.num_prefills = 0
self.num_prefill_tokens = 0
self.max_prefill_seq_len = 0
self.max_query_len = 1
self.slot_mapping = self.slot_mapping[:num_seqs]
else:
assert self.seq_lens_tensor is not None
assert num_seqs > 0
assert num_queries > 0
......@@ -468,6 +598,19 @@ class FlashInferMetadata(AttentionMetadata):
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.input_builder = input_builder
self.runner = input_builder.runner
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
# Global hyperparameters shared by all attention layers
self.global_hyperparameters: Optional[PerLayerParameters] = None
self.vllm_config = get_current_vllm_config()
def prepare(self):
self.slot_mapping: List[int] = []
self.prefill_seq_lens: List[int] = []
self.context_lens: List[int] = []
......@@ -480,12 +623,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.num_prefill_tokens = 0
self.num_decode_tokens = 0
self.input_builder = input_builder
self.runner = input_builder.runner
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
# Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
# for the precise definition of the following fields.
# An example:
......@@ -505,6 +642,20 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.total_blocks = 0
self.is_profile_run: bool = False
if self.global_hyperparameters is None:
# Infer global hyperparameters, since currently we only support
# models in which all layers share the same values for the
# following hyperparameters:
# - `window_left`
# - `logits_soft_cap`
# - `sm_scale`
inferred_params = infer_global_hyperparameters(
get_per_layer_parameters(self.vllm_config))
self.global_hyperparameters = inferred_params
self.window_left = inferred_params.window_left
self.logits_soft_cap = inferred_params.logits_soft_cap
self.sm_scale = inferred_params.sm_scale
def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
chunked_prefill_enabled: bool):
......@@ -713,6 +864,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=False,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
max_prefill_seq_len=max_prefill_seq_len,
......@@ -734,7 +886,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
data_type=kv_cache_dtype,
q_data_type=self.runner.model_config.dtype,
use_cuda_graph=use_captured_graph,
is_profile_run=self.is_profile_run)
is_profile_run=self.is_profile_run,
window_left=self.window_left,
logits_soft_cap=self.logits_soft_cap,
sm_scale=self.sm_scale,
)
class FlashInferImpl(AttentionImpl):
......@@ -750,6 +906,7 @@ class FlashInferImpl(AttentionImpl):
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
......@@ -766,27 +923,24 @@ class FlashInferImpl(AttentionImpl):
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashInferImpl")
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: FlashInferMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# TODO: directly write to output tensor
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashInferImpl")
num_heads: int = self.num_heads
head_size: int = self.head_size
num_kv_heads: int = self.num_kv_heads
......@@ -810,8 +964,8 @@ class FlashInferImpl(AttentionImpl):
kv_cache[:, 1],
attn_metadata.slot_mapping.flatten(),
kv_cache_dtype,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
# to process the cache when the kv_cache_dtype is fp8
......@@ -865,25 +1019,34 @@ class FlashInferImpl(AttentionImpl):
else:
assert prefill_meta is not None
assert prefill_meta.prefill_wrapper is not None
prefill_output = prefill_meta.prefill_wrapper.forward(
assert prefill_meta.prefill_wrapper._causal
assert prefill_meta.prefill_wrapper._window_left == window_left
assert prefill_meta.prefill_wrapper._logits_soft_cap == (
logits_soft_cap or 0.0)
assert prefill_meta.prefill_wrapper._sm_scale == softmax_scale
prefill_output = prefill_meta.prefill_wrapper.run(
query,
kv_cache,
logits_soft_cap=logits_soft_cap,
causal=True,
k_scale=k_scale,
v_scale=v_scale,
window_left=window_left)
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
)
if decode_meta := attn_metadata.decode_metadata:
assert decode_meta is not None
assert decode_meta.decode_wrapper is not None
decode_output = decode_meta.decode_wrapper.forward(
assert decode_meta.decode_wrapper._window_left == window_left
assert decode_meta.decode_wrapper._logits_soft_cap == (
logits_soft_cap or 0.0)
assert decode_meta.decode_wrapper._sm_scale == softmax_scale
decode_output = decode_meta.decode_wrapper.run(
decode_query,
kv_cache,
sm_scale=softmax_scale,
logits_soft_cap=logits_soft_cap,
k_scale=k_scale,
v_scale=v_scale,
window_left=window_left)
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
)
if prefill_output is None and decode_output is not None:
# Decode only batch.
......
......@@ -11,6 +11,7 @@ import vllm_hpu_extension.ops as ops
from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention,
......@@ -102,6 +103,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
max_seq_len: int = 4096,
attn_type: str = AttentionType.DECODER,
) -> None:
super(AttentionImpl, self).__init__()
self.kv_cache_dtype = kv_cache_dtype
......@@ -143,16 +145,20 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {suppored_head_sizes}.")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"HPUAttentionImpl")
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: HPUAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
......@@ -166,11 +172,6 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"HPUAttentionImpl")
batch_size, seq_len, hidden_size = query.shape
_, seq_len_kv, _ = key.shape
......
......@@ -7,6 +7,7 @@ import torch
from vllm._ipex_ops import ipex_ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.paged_attn import (PagedAttention,
......@@ -115,6 +116,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
if blocksparse_params is not None:
raise ValueError(
......@@ -146,6 +148,11 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
raise NotImplementedError(
"IPEX backend does not support FP8 KV cache. "
"Please use xFormers backend instead.")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"IpexAttnBackendImpl")
def split_kv_cache(
self,
......@@ -165,14 +172,12 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: IpexAttnMetadata, # type: ignore
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with IPEX varlen_attention and PagedAttention.
......@@ -188,12 +193,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert k_scale == 1.0 and v_scale == 1.0
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"IpexAttnBackendImpl")
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
......@@ -210,8 +210,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
value_cache,
attn_metadata.slot_mapping.flatten(),
self.kv_cache_dtype,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)
if attn_metadata.is_prompt:
......@@ -296,8 +296,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
max_seq_len,
self.alibi_slopes,
self.kv_cache_dtype,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)
else:
# Run PagedAttention V2.
......@@ -329,8 +329,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
max_seq_len,
self.alibi_slopes,
self.kv_cache_dtype,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)
# Reshape the output tensor.
......
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Generic, List, Optional, Tuple
import torch
from compressed_tensors.quantization import QuantizationStrategy
from vllm import _custom_ops as ops
from vllm import envs
from vllm.attention.backends.abstract import (AttentionLayer,
AttentionMetadata,
MLAAttentionImpl, T)
from vllm.distributed import (get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, RowParallelLinear,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsLinearMethod)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsW8A8Fp8)
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
scaled_dequantize, scaled_quantize)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.vllm_flash_attn import flash_attn_varlen_func
@dataclass
class MLACommonMetadata(AttentionMetadata):
# Input positions for rotrary embeddings since for MLA the rotary
# position embeddings are applied inside the attention backend
input_positions: torch.Tensor
class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
"""
Common class for implementing repeated parts
Main reference: DeepseekV2 paper, and FlashInfer Implementation
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
Deepseek's MLA attention works the following way:
* Use a single latent vector to represent the entire KV cache.
* The attention "simulates" a multi-head attention, while the compute is
similar to multi-query attention.
* The dataflow is as follows,
* B: batch/sequence length
* H: hidden size
* N: number of attention heads
* Lq: latent dimension for Q
* Lkv: latent dimension for K/V
* P: nope dimension, P+R is the actual head_dim in common attention.
* R: rope dimension, this slide of the head_dim goes through rope.
* V: V head dim.
* kv_c: latent/compressed KV
* q_c: latent/compressed Q
#
# Outside the MLA attention backend
#
1. The hidden states (B, H) are projected down into cq (B, Lq) and
kv_c_k_pe (B, Lkv+R).
2. The kv_c_k_pe is split into kv_c (B, Lkv) and k_pe (B, R). cq
and kv_c are normalized.
#
# Inside the MLA attention backend
#
* if prefill:
3. The q_c is then projected up into the multi-head version.
* q_c goes from (B, Lq) to (B, N, (P+R)), which is split into q_nope
(B, N, P) and q_pe (B, N, R).
4. q_pe, k_pe are then passed through rotary embeddings.
5. kv_c and k_pe are concatenated and inserted into the cache
6. The kv_c is then projected up into the multi-head version.
* kv_c goes from (B, Lkv) to (B, N, (P+V)) which has the nope
dimensions for K and V, which is split into k_nope (B, N, P)
and v (B, N, V).
7. q (B, N, (P+R)) and k (B, N, (P+R)) matrices are assembled from
q_nope, q_pe, k_nope, k_pe.
8. Attention is computued with q, k, v.
9. The attention computation returns (B, N, V), which is projected back
to (B, H) using out projection.
* if decode:
3. Here's the change, we do not perform up the full up projection for
q_c, and there is no up projection at all for kv_c. This is
achieved by the technique of "weight absorption". The paper says
"Fortunately, due to the associative law of matrix multiplication,
we can absorb WUK into WUQ, and WUV into WO"
* The q up projection turns (B, Lq) into (B, N, (P+R)), we split it
into W_UQ (Lq, N, P) and W_QR (Lq, N, R).
* The kv_c up projection turns (B, Lkv) into (B, N, (P+V)), we split
it into W_UK (Lkv, N, P) and W_UV (Lkv, N, V).
* The out projection shape W_O (N*V, H) turns (B, N, V) into (B, H).
* We can precompute the product of W_UQ and W_UK into
W_UQ_UK (Lq, N, Lkv), which is possible due to QK^T operation in
attention.
* We can precompute the product of W_UV and W_O into
W_UV_O (N, Lkv, H), which is possible due to V@O as the
"epilogue" of attention
4. We still need to compute q_pe (B, N, R) by applying W_QR to q_latent.
5. q_pe, k_pe are then passed through rotary embeddings.
6. kv_c and k_pe are concatenated and inserted into the cache
7. By applying W_UQ_UK to q_latent, we have the new q_nope of shape
(B, N, Lkv).
8. q (B, N, (Lkv+R)), k (B, (Lkv+R)) are assembled from q_nope, q_pe,
kv_a, k_pe. v (B, Lkv) is exactly the same vector as kv_a.
9. The attention is computed with q, k, v. Note that we just performed
a MQA attention with (LKv+R) as our head dim.
10. The KV cache is updated using the new entries k (B, N, (Lkv+R)),
which included the v and rope values.
11. The attention computation returns (B, N, Lkv), which is projected
back to (B, H) using W_UV_O.
From @tsu-bin's calculation, we only want to use the absorption technique
for decode. The prefill algorithm should still use the up-projected MHA
for less flops and memory usage.
"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
# MLA Specific Arguments
q_lora_rank: Optional[int],
kv_lora_rank: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
qk_head_dim: int,
v_head_dim: int,
rotary_emb: RotaryEmbedding,
# q_proj should be q_b_proj if q_lora_rank is not None, but from an
# attention backend perspective we rely on the layer to pass in the
# correct matrix
q_proj: ColumnParallelLinear,
kv_b_proj: ColumnParallelLinear,
o_proj: RowParallelLinear,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
self.kv_cache_dtype = kv_cache_dtype
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_head_dim
self.v_head_dim = v_head_dim
self.rotary_emb = rotary_emb
self.q_proj = q_proj
self.kv_b_proj = kv_b_proj
self.o_proj = o_proj
def _v_up_proj_and_o_proj(self, x):
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
if is_fp8(self.W_UV_O):
output_parallel = apply_fp8_linear_generic(
x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales,
self.reqaunt_input_group_shape,
self.reqaunt_weight_group_shape)
else:
output_parallel = torch.matmul(x.flatten(start_dim=1),
self.W_UV_O)
if self.tp_size > 1:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel
return output
else:
x = torch.einsum("bnl,lnv->bnv", x, self.W_UV)
return self.o_proj(x.reshape(-1,
self.num_heads * self.v_head_dim))[0]
def _q_proj_and_k_up_proj(self, x):
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
if is_fp8(self.W_Q_UK):
return apply_fp8_linear_generic(
x, self.W_Q_UK, self.W_Q_UK_scales,
self.reqaunt_input_group_shape,
self.reqaunt_weight_group_shape).view(
-1, self.num_heads, self.kv_lora_rank)
return torch.matmul(x, self.W_Q_UK)\
.view(-1, self.num_heads, self.kv_lora_rank)
else:
x = torch.matmul(x, self.W_Q)\
.view(-1, self.num_heads, self.qk_nope_head_dim)
return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\
.view(-1, self.num_heads, self.kv_lora_rank)
def process_weights_after_loading(self, act_dtype: torch.dtype):
def is_layer_fp8(layer: LinearBase) -> bool:
return isinstance(layer.quant_method, Fp8LinearMethod) or\
(isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
and isinstance(layer.scheme, CompressedTensorsW8A8Fp8))
def quantization_scheme_supported(layer: LinearBase) -> bool:
return isinstance(layer.quant_method, UnquantizedLinearMethod) or \
is_layer_fp8(layer)
# TODO(lucas) This is very gross, we need a more wide scale refactor of
# all the FP8 code with a more standard way of
# defining schemes/group-shapes, we should also potentially force
# quant_methods to support a decompress function
#
# returns input_group_shape, weight_group_shape
def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \
Tuple[Tuple[int, int], Tuple[int, int]]:
if isinstance(layer.quant_method, Fp8LinearMethod):
if layer.quant_method.block_quant is not None:
weight_block_size = \
layer.quant_method.quant_config.weight_block_size
# per-token-group (1, X), block-quantized (X, Y)
return (1, weight_block_size[-1]), weight_block_size
else:
return (-1, -1), (-1, -1) # per-tensor, per-tensor
elif isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
and isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
# this is hacky but we always assume the for
# CompressedTensorsW8A8Fp8 the input is dynamic per-token
# we ignore if it is static-per-tensor since we are going to
# requantize after later anyways
strategy = layer.scheme.strategy
if strategy == QuantizationStrategy.TENSOR:
return (1, -1), (-1, -1) # per-token, per-tensor
elif strategy == QuantizationStrategy.CHANNEL:
return (1, -1), (-1, 1) # per-token, per-channel
else:
raise NotImplementedError(
f"QuantizationStrategy.{strategy} is not supported for "
"fp8 MLA, please run with VLLM_MLA_DISABLE=1")
else:
raise NotImplementedError(
"Can't determine scale group shapes for "
f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1"
)
def get_scales(layer: LinearBase) -> torch.Tensor:
if hasattr(layer, "weight_scale_inv"):
return layer.weight_scale_inv
return layer.weight_scale
def get_and_maybe_dequant_weights(layer: LinearBase):
if is_layer_fp8(layer):
if isinstance(layer.quant_method, \
CompressedTensorsLinearMethod) and \
isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
# NOTE(lucas): note sure why but `CompressedTensorsW8A8Fp8`
# seems to store weights as (input, output) instead of
# (output, input) so we need to transpose
weight = layer.weight.T # standardize to (output, input)
else:
weight = layer.weight
_, weight_scale_group_shape = \
get_scale_group_shapes_for_fp8(layer)
scales = get_scales(layer)
return scaled_dequantize(weight, scales,
weight_scale_group_shape)
else:
return layer.weight
if not (quantization_scheme_supported(self.kv_b_proj) and\
quantization_scheme_supported(self.q_proj) and\
quantization_scheme_supported(self.o_proj)):
raise NotImplementedError(
"Only FP8 and UnquantizedLinearMethod are supported for MLA"
", please run with VLLM_MLA_DISABLE=1")
weight_dtype = self.kv_b_proj.weight.dtype
assert self.o_proj.weight.dtype == weight_dtype
assert self.q_proj.weight.dtype == weight_dtype
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
assert kv_b_proj_weight.shape == (
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
f"{kv_b_proj_weight.shape=}, "
f"{self.kv_lora_rank=}, "
f"{self.num_heads=}, "
f"{self.qk_nope_head_dim=}, "
f"{self.v_head_dim=}")
kv_b_proj_weight = kv_b_proj_weight.view(
self.kv_lora_rank,
self.num_heads,
self.qk_nope_head_dim + self.v_head_dim,
)
W_UK, W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\
.view(-1, self.num_heads, self.qk_head_dim)
# can be W_Q or W_UQ depending q_lora_rank, the former if
# q_lora_rank is None, the latter otherwise. From the Attention backend
# perspective though we call these both W_Q and rely on the layer
# to pass in the correct matrix
W_Q = q_proj_weight[..., :self.qk_nope_head_dim]
self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\
.flatten(start_dim=1).contiguous()
# W_QR is small so for simplicity we dont bother requantizing it
self.W_QR = self.W_QR.to(act_dtype)
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
requantization_enabled = not envs.VLLM_MLA_DISABLE_REQUANTIZATION
if is_fp8(weight_dtype) and requantization_enabled:
# This assumes it wise to requantize using the same group shapes
# (i.e. strategy, per-tensor, per-channel, block etc.) that the
# weights were originally quantized
requant_input_group_shape, requant_weight_group_shape = \
get_scale_group_shapes_for_fp8(self.q_proj)
assert (requant_input_group_shape, requant_weight_group_shape)\
== get_scale_group_shapes_for_fp8(self.kv_b_proj)
assert (requant_input_group_shape, requant_weight_group_shape)\
== get_scale_group_shapes_for_fp8(self.o_proj)
self.reqaunt_input_group_shape = requant_input_group_shape
self.reqaunt_weight_group_shape = requant_weight_group_shape
#
# Perform matrix-absorption following
# https://github.com/flashinfer-ai/flashinfer/pull/551
# for decode, as a result we end up with absorbed weights for decode
# and another copy of raw weights for prefill.
#
self.W_UK, self.W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
# We absorb `W_UK` into `W_Q` resulting in either W_Q_UK or W_UQ_UK
# depending q_lora_rank, the former if q_lora_rank is None, the
# latter otherwise
# basically if q_lora_rank is none we are absorbing into q_proj
# instead of UQ
W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
.flatten(start_dim=1).contiguous()
if is_fp8(weight_dtype) and requantization_enabled:
W_Q_UK, W_Q_UK_scales = scaled_quantize(
W_Q_UK,
self.reqaunt_weight_group_shape,
quant_dtype=current_platform_fp8_dtype)
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
self.W_Q_UK = W_Q_UK.T.contiguous()
self.W_Q_UK_scales = W_Q_UK_scales.T.contiguous()
else:
self.W_Q_UK = W_Q_UK.to(act_dtype)
W_O = get_and_maybe_dequant_weights(self.o_proj)\
.view(-1, self.num_heads, self.v_head_dim)
W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
.flatten(start_dim=0, end_dim=1).contiguous()
if is_fp8(weight_dtype) and requantization_enabled:
W_UV_O, W_UV_O_scales = scaled_quantize(
W_UV_O,
self.reqaunt_weight_group_shape,
quant_dtype=current_platform_fp8_dtype)
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
self.W_UV_O = W_UV_O.T.contiguous()
self.W_UV_O_scales = W_UV_O_scales.T.contiguous()
else:
self.W_UV_O = W_UV_O.to(act_dtype)
self.tp_size = get_tensor_model_parallel_world_size()
else:
if is_fp8(weight_dtype):
raise NotImplementedError(
"Currently fp8 requires matrix absorption")
self.W_UV = W_UV
self.W_UK = W_UK
self.W_Q = W_Q.flatten(start_dim=1)
@abstractmethod
def _forward_prefill(
self,
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
attn_metadata: T,
) -> torch.Tensor:
raise NotImplementedError
@abstractmethod
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: T,
) -> torch.Tensor:
raise NotImplementedError
def forward(
self,
layer: AttentionLayer,
hidden_states_or_q_c: torch.Tensor, # query in unified attn
k_c_normed: torch.Tensor, # key in unified attn
k_pe: torch.Tensor, # value in unified attn
kv_cache: torch.Tensor,
attn_metadata: T,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if output is not None:
raise NotImplementedError(
"output is not yet supported for MLAImplBase")
is_decode = attn_metadata.decode_metadata is not None
is_prefill = attn_metadata.prefill_metadata is not None
if (is_decode and is_prefill):
raise NotImplementedError(
"chunked prefill is not supported for MLAImplBase")
# Restore head dim (for rotary embedding)
k_pe = k_pe.unsqueeze(1)
assert hasattr(attn_metadata, "input_positions")
if is_decode:
q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)\
.view(-1, self.num_heads, self.qk_rope_head_dim)
q_pe, k_pe = \
self.rotary_emb(attn_metadata.input_positions, q_pe, k_pe)
else:
assert is_prefill
q = self.q_proj(hidden_states_or_q_c)[0]\
.view(-1, self.num_heads, self.qk_head_dim)
# TODO(lucas): there must be a nicer way to write this line
q[..., self.qk_nope_head_dim:], k_pe = \
self.rotary_emb(
attn_metadata.input_positions,
q[..., self.qk_nope_head_dim:], k_pe)
# write the latent and rope to kv cache
if kv_cache.numel() > 0:
ops.concat_and_cache_mla(
k_c_normed,
k_pe.squeeze(1),
kv_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache_dtype=self.kv_cache_dtype,
scale=layer._k_scale,
)
if attn_metadata.prefill_metadata is not None:
return self._forward_prefill(q, k_c_normed, k_pe, attn_metadata)
if attn_metadata.decode_metadata is not None:
return self._forward_decode(q_nope, q_pe, kv_cache, attn_metadata)
# Optional common flash-attn based prefill
def _forward_prefill_flash(
self,
q: torch.Tensor,
k_c_normed: torch.Tensor,
k_pe: torch.Tensor,
seq_start_loc: torch.Tensor,
max_prefill_seq_len: int,
) -> torch.Tensor:
kv_nope = self.kv_b_proj(k_c_normed)[0]\
.view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim
v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
value=0)
attn_output = flash_attn_varlen_func(
q=q,
k=k,
v=v_padded,
cu_seqlens_q=seq_start_loc,
cu_seqlens_k=seq_start_loc,
max_seqlen_q=max_prefill_seq_len,
max_seqlen_k=max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
)
attn_output = attn_output\
.view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
.reshape(-1, self.num_heads * v.shape[-1])
return self.o_proj(attn_output)[0]
......@@ -5,6 +5,7 @@ import torch
import torch_xla.experimental.custom_kernel # Required to register custom ops.
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
......@@ -100,6 +101,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
......@@ -108,6 +110,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.logits_soft_cap = logits_soft_cap
if head_size % 128 != 0:
raise NotImplementedError("Head size must be a multiple of 128.")
if alibi_slopes is not None:
......@@ -118,9 +121,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
raise NotImplementedError("FP8 KV cache dtype is not supported.")
if blocksparse_params is not None:
raise NotImplementedError("Blocksparse is not supported.")
if logits_soft_cap is not None:
raise NotImplementedError(
"Attention logits soft-capping is not supported.")
if torch_xla.tpu.version() < 4:
raise NotImplementedError("TPU version must be 4 or higher.")
......@@ -141,16 +141,20 @@ class PallasAttentionBackendImpl(AttentionImpl):
# megacore mode will be None.
self.megacore_mode = "batch"
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"PallasAttentionBackendImpl")
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor],
attn_metadata: PallasMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with Pallas attention.
......@@ -167,12 +171,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
"""
assert k_scale == 1.0 and v_scale == 1.0
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"PallasAttentionBackendImpl")
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
batch_size, seq_len, hidden_size = query.shape
query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
......@@ -229,6 +228,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
num_kv_pages_per_compute_block,
num_queries_per_compute_block,
use_kernel=True,
attn_logits_soft_cap=self.logits_soft_cap,
)
else:
# Decoding run.
......@@ -256,6 +256,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
attn_metadata.block_tables,
pages_per_compute_block,
self.megacore_mode,
attn_logits_soft_cap=self.logits_soft_cap,
)
else:
chunk_size = max_num_seq
......@@ -279,6 +280,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
attn_metadata.block_tables[chunk_start:chunk_end],
pages_per_compute_block,
self.megacore_mode,
attn_logits_soft_cap=self.logits_soft_cap,
)
output[chunk_start:chunk_end] = chunk_output
......@@ -312,6 +314,8 @@ def paged_attention(
block_tables: torch.Tensor,
pages_per_compute_block: int,
megacore_mode: Optional[str],
*,
attn_logits_soft_cap: Optional[float],
) -> torch.Tensor:
batch_size = query.shape[0]
if megacore_mode == "batch" and batch_size % 2 != 0:
......@@ -319,26 +323,13 @@ def paged_attention(
else:
megacore_mode = megacore_mode
# NOTE(woosuk): A temporary workaround to avoid the error:
# "xla::paged_attention() Expected a value of type 'str' for
# argument 'megacore_mode' but instead found type 'NoneType'."
if megacore_mode is not None:
output = torch.ops.xla.paged_attention(
query,
key_cache,
value_cache,
context_lens,
block_tables,
pages_per_compute_block,
megacore_mode=megacore_mode,
)
else:
output = torch.ops.xla.paged_attention(
query,
key_cache,
value_cache,
context_lens,
block_tables,
pages_per_compute_block,
)
return output
return torch.ops.xla.paged_attention(
query,
key_cache,
value_cache,
context_lens,
block_tables,
pages_per_compute_block,
megacore_mode=megacore_mode,
attn_logits_soft_cap=attn_logits_soft_cap,
)
......@@ -140,6 +140,7 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
max_decode_query_len=0,
......@@ -173,6 +174,7 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
num_decode_tokens=self.num_decode_tokens,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_decode_query_len=self.max_decode_query_len,
......@@ -253,6 +255,11 @@ class PlaceholderAttentionMetadataBuilder(
AttentionMetadataBuilder[PlaceholderAttentionMetadata]):
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.input_builder = input_builder
self.runner = input_builder.runner
def prepare(self):
self.prefill_seq_lens: List[int] = []
self.context_lens: List[int] = []
self.curr_seq_lens: List[int] = []
......@@ -263,9 +270,6 @@ class PlaceholderAttentionMetadataBuilder(
self.num_prefill_tokens = 0
self.num_decode_tokens = 0
self.input_builder = input_builder
self.runner = input_builder.runner
def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
chunked_prefill_enabled: bool):
......@@ -378,6 +382,7 @@ class PlaceholderAttentionMetadataBuilder(
num_prefills=self.num_prefills,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=True,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
......
......@@ -7,6 +7,7 @@ import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import (CommonAttentionState,
CommonMetadataBuilder)
......@@ -90,6 +91,17 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
seq_lens: Optional[List[int]]
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len: int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len: int
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
......@@ -100,30 +112,18 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
# |-- query_len ---|
# Maximum query length in the batch. None for decoding.
max_query_len: Optional[int]
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len: int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len: int
max_query_len: Optional[int] = None
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc: Optional[torch.Tensor]
query_start_loc: Optional[torch.Tensor] = None
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor]
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
seq_start_loc: Optional[torch.Tensor] = None
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]
context_lens_tensor: Optional[torch.Tensor] = None
# Max number of query tokens among request in the batch.
max_decode_query_len: Optional[int] = None
......@@ -133,6 +133,23 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
tree_attention_masks_tensor: Optional[torch.Tensor] = None
block_tables_list: Optional[List[int]] = None
# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
encoder_seq_lens: Optional[List[int]] = None
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
# Maximum sequence length among encoder sequences
max_encoder_seq_len: Optional[int] = None
# Number of tokens input to encoder
num_encoder_tokens: Optional[int] = None
# Cross-attention memory-mapping data structures: slot mapping
# and block tables
cross_slot_mapping: Optional[torch.Tensor] = None
cross_block_tables: Optional[torch.Tensor] = None
@property
def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
......@@ -144,10 +161,7 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
assert self.seq_lens is not None
assert self.seq_lens_tensor is not None
assert self.query_start_loc is not None
assert self.context_lens_tensor is not None
assert self.block_tables is not None
assert self.seq_start_loc is not None
self._cached_prefill_metadata = ROCmFlashAttentionMetadata(
num_prefills=self.num_prefills,
......@@ -156,19 +170,28 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=0,
query_start_loc=self.query_start_loc[:self.num_prefills + 1],
seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
query_start_loc=None if self.query_start_loc is None else
self.query_start_loc[:self.num_prefills + 1],
seq_start_loc=None if self.seq_start_loc is None else
self.seq_start_loc[:self.num_prefills + 1],
context_lens_tensor=None if self.context_lens_tensor is None else
self.context_lens_tensor[:self.num_prefills],
block_tables=self.block_tables[:self.num_prefills],
use_cuda_graph=False,
# Begin encoder & cross attn fields below...
encoder_seq_lens=self.encoder_seq_lens,
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
max_encoder_seq_len=self.max_encoder_seq_len,
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables
tree_attention_masks_tensor=self.tree_attention_masks_tensor,
block_tables_list=self.block_tables_list
)
block_tables_list=self.block_tables_list)
return self._cached_prefill_metadata
@property
......@@ -187,6 +210,7 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
num_decode_tokens=self.num_decode_tokens,
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_query_len=None,
......@@ -197,9 +221,14 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
context_lens_tensor=None,
block_tables=self.block_tables[self.num_prefills:],
use_cuda_graph=self.use_cuda_graph,
# Begin encoder & cross attn fields below...
encoder_seq_lens=self.encoder_seq_lens,
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
max_encoder_seq_len=self.max_encoder_seq_len,
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables,
tree_attention_masks_tensor=self.tree_attention_masks_tensor,
block_tables_list=self.block_tables_list
)
block_tables_list=self.block_tables_list)
# Batch may be composed of prefill|decodes, adjust query start indices
# to refer to the start of decodes when the two are split apart.
# E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
......@@ -309,6 +338,97 @@ def _make_alibi_bias(alibi_slopes: torch.Tensor,
return attn_biases
def _get_seq_len_block_table_args(
attn_metadata: ROCmFlashAttentionMetadata,
attn_type: str,
) -> tuple:
'''
The particular choice of sequence-length
attributes which should be extracted from attn_metadata is dependent
on the type of attention operation.
Decoder attn -> select entirely decoder self-attention-related fields
Encoder/decoder cross-attn -> select encoder sequence lengths
Encoder attn -> select encoder sequence lengths fields
Arguments:
* attn_metadata: Attention metadata structure associated with attention op
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate sequence-lengths tensors for query and key
* Appropriate max sequence-length scalar
'''
partial_prefix_sum = 0
if attn_type == AttentionType.ENCODER:
assert attn_metadata.encoder_seq_lens is not None
assert attn_metadata.encoder_seq_lens_tensor is not None
query_seq_start_loc = torch.tensor(
[0] + [
partial_prefix_sum := partial_prefix_sum + i
for i in attn_metadata.encoder_seq_lens
],
device=attn_metadata.encoder_seq_lens_tensor.device,
dtype=attn_metadata.encoder_seq_lens_tensor.dtype)
causal_mask = False
# No block tables associated with encoder attention
return (query_seq_start_loc, attn_metadata.max_encoder_seq_len,
query_seq_start_loc, attn_metadata.max_encoder_seq_len,
attn_metadata.encoder_seq_lens, causal_mask)
elif attn_type == AttentionType.DECODER:
# Decoder self-attention
# Choose max_seq_len based on whether we are in prompt_run
assert attn_metadata.seq_lens is not None
assert attn_metadata.seq_lens_tensor is not None
query_seq_start_loc = torch.tensor(
[0] + [
partial_prefix_sum := partial_prefix_sum + i
for i in attn_metadata.seq_lens
],
device=attn_metadata.seq_lens_tensor.device,
dtype=attn_metadata.seq_lens_tensor.dtype)
max_seq_len = attn_metadata.max_prefill_seq_len
causal_mask = True
return (query_seq_start_loc, max_seq_len, query_seq_start_loc,
max_seq_len, attn_metadata.seq_lens, causal_mask)
elif attn_type == AttentionType.ENCODER_DECODER:
assert attn_metadata.seq_lens is not None
assert attn_metadata.encoder_seq_lens_tensor is not None
query_start_loc = torch.tensor(
[0] + [
partial_prefix_sum := partial_prefix_sum + i
for i in attn_metadata.seq_lens
],
device=attn_metadata.encoder_seq_lens_tensor.device,
dtype=attn_metadata.encoder_seq_lens_tensor.dtype)
partial_prefix_sum = 0
assert attn_metadata.encoder_seq_lens is not None
assert attn_metadata.seq_lens_tensor is not None
key_seq_start_loc = torch.tensor(
[0] + [
partial_prefix_sum := partial_prefix_sum + i
for i in attn_metadata.encoder_seq_lens
],
device=attn_metadata.seq_lens_tensor.device,
dtype=attn_metadata.seq_lens_tensor.dtype)
causal_mask = False
# Enc/dec cross-attention KVs match encoder sequence length;
# cross-attention utilizes special "cross" block tables
return (query_start_loc, attn_metadata.max_prefill_seq_len,
key_seq_start_loc, attn_metadata.max_encoder_seq_len,
attn_metadata.seq_lens, causal_mask)
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
class ROCmFlashAttentionImpl(AttentionImpl):
"""
If the input tensors contain prompt tokens, the layout is as follows:
......@@ -346,21 +466,18 @@ class ROCmFlashAttentionImpl(AttentionImpl):
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
if blocksparse_params is not None:
raise ValueError(
"ROCmFlashAttention does not support blocksparse attention.")
'''
if logits_soft_cap is not None:
raise ValueError(
"ROCmFlashAttention does not support attention logits soft "
"capping.")
'''
if logits_soft_cap is None:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap = 0
self.logits_soft_cap = logits_soft_cap
self.logits_soft_cap = 0.0
else:
self.logits_soft_cap = logits_soft_cap
self.attn_type = attn_type
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
......@@ -385,11 +502,22 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
if self.use_triton_flash_attn:
# from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
# triton_attention)
from vllm.attention.ops.flash_attn_triton_mqa_gqa import (
flash_attn_varlen_func)
self.attn_func = flash_attn_varlen_func # triton_attention
if logits_soft_cap is not None:
raise ValueError(
"ROCm Triton FlashAttention does not support attention"
"logits soft capping."
" please try using the ROCm CK "
"FA backend instead by setting the env var "
"`VLLM_USE_TRITON_FLASH_ATTN=0`")
from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
triton_attention)
self.attn_func = triton_attention
# from vllm.attention.ops.flash_attn_triton_mqa_gqa import (
# flash_attn_varlen_func)
# self.attn_func = flash_attn_varlen_func
logger.debug("Using Triton FA in ROCmBackend")
if self.sliding_window != (-1, -1):
logger.warning("ROCm Triton FA does not currently support "
......@@ -411,8 +539,13 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self.use_naive_attn = True
if self.use_naive_attn:
if logits_soft_cap is not None:
raise ValueError(
"ROCm Naive FlashAttention does not support"
"attention logits soft capping.")
self.attn_func = _sdpa_attention
logger.debug("Using naive attention in ROCmBackend")
logger.debug("Using naive (SDPA) attention in ROCmBackend")
def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
......@@ -424,18 +557,47 @@ class ROCmFlashAttentionImpl(AttentionImpl):
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: ROCmFlashAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
For decoder-only models: query, key and value must be non-None.
For encoder/decoder models:
* ROCmFlashAttentionImpl.forward() may be invoked for both self- and
cross-attention layers.
* For self-attention: query, key and value must be non-None.
* For cross-attention:
* Query must be non-None
* During prefill, key and value must be non-None; key and value
get cached for use during decode.
* During decode, key and value may be None, since:
(1) key and value tensors were cached during prefill, and
(2) cross-attention key and value tensors do not grow during
decode
A note on how the attn_type (attention type enum) argument impacts
attention forward() behavior:
* DECODER: normal decoder-only behavior;
use decoder self-attention block table
* ENCODER: no KV caching; pass encoder sequence
attributes (encoder_seq_lens/encoder_seq_lens_tensor/
max_encoder_seq_len) to kernel, in lieu of decoder
sequence attributes (seq_lens/seq_lens_tensor/max_seq_len)
* ENCODER_DECODER: cross-attention behavior;
use cross-attention block table for caching KVs derived
from encoder hidden states; since KV sequence lengths
will match encoder sequence lengths, pass encoder sequence
attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
max_encoder_seq_len)
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
......@@ -444,60 +606,80 @@ class ROCmFlashAttentionImpl(AttentionImpl):
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
attn_type: Select attention type, between encoder attention,
decoder self-attention, or encoder/decoder cross-
attention. Defaults to decoder self-attention,
which is the vLLM default generally
Returns:
shape = [num_tokens, num_heads * head_size]
"""
# Reminder: Please update docs/source/usage/compatibility_matrix.md
# If the feature combo become valid
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"ROCmFlashAttentionImpl")
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
if key is not None:
assert value is not None
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
else:
assert value is None
if kv_cache.numel() > 0:
if self.attn_type != AttentionType.ENCODER and kv_cache.numel() > 0:
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
PagedAttention.write_to_paged_cache(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
k_scale,
v_scale,
)
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
if key is not None and value is not None:
# Reshape the input keys and values and store them in the
# cache. If kv_cache is not provided, the new key and value
# tensors are not cached. This happens during the initial
# memory profiling run.
PagedAttention.write_to_paged_cache(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping
if self.attn_type != AttentionType.ENCODER_DECODER else
attn_metadata.cross_slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if self.attn_type != AttentionType.ENCODER:
num_prefill_tokens = attn_metadata.num_prefill_tokens
else:
assert attn_metadata.num_encoder_tokens is not None
num_prefill_tokens = attn_metadata.num_encoder_tokens
output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:]
# QKV for prefill.
query = query[:num_prefill_tokens]
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens
if key is not None and value is not None \
and self.attn_type != AttentionType.ENCODER_DECODER:
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
assert prefill_meta.seq_lens is not None
# normal attention and DECODER
if self.attn_type == AttentionType.DECODER and (
kv_cache.numel() == 0 or prefill_meta.block_tables is None
or prefill_meta.block_tables.numel() == 0):
(query_seq_start_loc, query_max_seq_len, key_seq_start_loc,
key_max_seq_len, seq_lens,
causal_mask) = (prefill_meta.seq_start_loc,
prefill_meta.max_prefill_seq_len,
prefill_meta.seq_start_loc,
prefill_meta.max_prefill_seq_len,
attn_metadata.seq_lens, True)
# prefix-enabled attention and ENCODER/ENCODER_DECODER
else:
(query_seq_start_loc, query_max_seq_len, key_seq_start_loc,
key_max_seq_len, seq_lens,
causal_mask) = _get_seq_len_block_table_args(
prefill_meta, self.attn_type)
# Prompt run.
if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
# triton attention
# When block_tables are not filled, it means q and k are the
......@@ -508,30 +690,21 @@ class ROCmFlashAttentionImpl(AttentionImpl):
attn_masks = _make_alibi_bias(
self.alibi_slopes,
query.dtype,
attn_metadata.seq_lens,
seq_lens,
make_attn_mask=False) # type: ignore
# out = self.attn_func(
# query,
# key,
# value,
# prefill_meta.seq_lens,
# num_tokens,
# self.num_heads,
# self.head_size,
# self.scale,
# attn_masks,
# )
out = self.attn_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlens_q=prefill_meta.max_prefill_seq_len,
max_seqlens_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
out, _ = self.attn_func(
query,
key,
value,
None,
query_seq_start_loc,
key_seq_start_loc,
query_max_seq_len,
key_max_seq_len,
causal_mask,
self.scale,
attn_masks[0][None]
if attn_masks is not None else None,
)
elif self.use_naive_attn:
......@@ -553,11 +726,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
query,
key,
value,
prefill_meta.seq_lens,
num_tokens,
query_seq_start_loc,
num_prefill_tokens,
self.num_heads,
self.head_size,
self.scale,
causal_mask,
attn_masks,
)
else:
......@@ -565,10 +739,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
cu_seqlens_q=query_seq_start_loc,
cu_seqlens_k=key_seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
max_seqlen_k=key_max_seq_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
......@@ -578,7 +752,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# common code for prefill
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
if output.shape[0] > num_prefill_tokens:
output[:num_prefill_tokens] = out
else:
output = out
else:
# prefix-enabled attention
output[:num_prefill_tokens] = PagedAttention.forward_prefix(
......@@ -595,8 +772,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
prefill_meta.max_query_len,
self.alibi_slopes,
self.sliding_window[0],
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)
if decode_meta := attn_metadata.decode_metadata:
......@@ -610,7 +787,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# decode_meta.max_decode_seq_len)
use_custom = False
if use_custom:
max_seq_len = decode_meta.max_decode_seq_len
max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
!= AttentionType.ENCODER_DECODER else
decode_meta.max_encoder_seq_len)
assert max_seq_len is not None
max_num_partitions = (
(max_seq_len + _PARTITION_SIZE_ROCM - 1) //
_PARTITION_SIZE_ROCM)
......@@ -626,8 +806,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
device=output.device,
)
max_logits = torch.empty_like(exp_sums)
if num_prefill_tokens > 0:
out = output[num_prefill_tokens:]
else:
out = output
ops.paged_attention_rocm(
output[num_prefill_tokens:],
out,
exp_sums,
max_logits,
tmp_output,
......@@ -636,14 +820,18 @@ class ROCmFlashAttentionImpl(AttentionImpl):
value_cache,
self.num_kv_heads,
self.scale,
decode_meta.block_tables,
decode_meta.seq_lens_tensor,
decode_meta.block_tables
if self.attn_type != AttentionType.ENCODER_DECODER else
decode_meta.cross_block_tables,
decode_meta.seq_lens_tensor
if self.attn_type != AttentionType.ENCODER_DECODER else
decode_meta.encoder_seq_lens_tensor,
block_size,
max_seq_len,
self.alibi_slopes,
self.kv_cache_dtype,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)
else:
tree_attention_masks_tensor = decode_meta.tree_attention_masks_tensor
......@@ -651,21 +839,27 @@ class ROCmFlashAttentionImpl(AttentionImpl):
decode_query,
key_cache,
value_cache,
decode_meta.block_tables,
decode_meta.seq_lens_tensor,
decode_meta.max_decode_seq_len,
decode_meta.block_tables
if self.attn_type != AttentionType.ENCODER_DECODER else
decode_meta.cross_block_tables,
decode_meta.seq_lens_tensor
if self.attn_type != AttentionType.ENCODER_DECODER else
decode_meta.encoder_seq_lens_tensor,
decode_meta.max_decode_seq_len
if self.attn_type != AttentionType.ENCODER_DECODER else
decode_meta.max_encoder_seq_len,
self.kv_cache_dtype,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
attn_masks=tree_attention_masks_tensor,
attn_masks_stride=tree_attention_masks_tensor.stride(0) if tree_attention_masks_tensor is not None else 0
)
# Reshape the output tensor.
return output.view(num_tokens, hidden_size)
return output.view(-1, self.num_heads * self.head_size)
def _sdpa_attention(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment