Commit 99324e25 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.9.2' into v0.9.2-ori

parents cc7f22a8 a5dd03c1
...@@ -654,7 +654,6 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -654,7 +654,6 @@ class FlashAttentionImpl(AttentionImpl):
logits_soft_cap = 0 logits_soft_cap = 0
self.logits_soft_cap = logits_soft_cap self.logits_soft_cap = logits_soft_cap
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
support_head_sizes = FlashAttentionBackend.get_supported_head_sizes() support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
...@@ -673,6 +672,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -673,6 +672,7 @@ class FlashAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata, attn_metadata: FlashAttentionMetadata,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashAttention. """Forward pass with FlashAttention.
...@@ -692,6 +692,11 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -692,6 +692,11 @@ class FlashAttentionImpl(AttentionImpl):
""" """
assert output is not None, "Output tensor must be provided." assert output is not None, "Output tensor must be provided."
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashAttentionImpl")
# NOTE(woosuk): FlashAttention2 does not support FP8 KV cache. # NOTE(woosuk): FlashAttention2 does not support FP8 KV cache.
if not flash_attn_supports_fp8() or output.dtype != torch.bfloat16: if not flash_attn_supports_fp8() or output.dtype != torch.bfloat16:
assert ( assert (
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses import dataclasses
import os
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
...@@ -50,8 +49,7 @@ if TYPE_CHECKING: ...@@ -50,8 +49,7 @@ if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder, from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata) ModelInputForGPUWithSamplingMetadata)
FLASHINFER_KV_CACHE_LAYOUT: str = os.getenv("FLASHINFER_KV_CACHE_LAYOUT", FLASHINFER_KV_CACHE_LAYOUT: str = envs.VLLM_KV_CACHE_LAYOUT or "NHD"
"NHD").upper()
class FlashInferBackend(AttentionBackend): class FlashInferBackend(AttentionBackend):
...@@ -957,7 +955,6 @@ class FlashInferImpl(AttentionImpl): ...@@ -957,7 +955,6 @@ class FlashInferImpl(AttentionImpl):
self.kv_cache_dtype = kv_cache_dtype self.kv_cache_dtype = kv_cache_dtype
self.logits_soft_cap = logits_soft_cap self.logits_soft_cap = logits_soft_cap
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
if attn_type != AttentionType.DECODER: if attn_type != AttentionType.DECODER:
...@@ -975,8 +972,14 @@ class FlashInferImpl(AttentionImpl): ...@@ -975,8 +972,14 @@ class FlashInferImpl(AttentionImpl):
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: FlashInferMetadata, attn_metadata: FlashInferMetadata,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashInferImpl")
# TODO: directly write to output tensor # TODO: directly write to output tensor
num_heads: int = self.num_heads num_heads: int = self.num_heads
head_size: int = self.head_size head_size: int = self.head_size
......
...@@ -148,7 +148,6 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): ...@@ -148,7 +148,6 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
alibi_slopes_tensor = torch.tensor(alibi_slopes, alibi_slopes_tensor = torch.tensor(alibi_slopes,
dtype=torch.bfloat16) dtype=torch.bfloat16)
self.alibi_slopes = alibi_slopes_tensor self.alibi_slopes = alibi_slopes_tensor
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
if self.prefill_impl == 'fsdpa': if self.prefill_impl == 'fsdpa':
...@@ -181,6 +180,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): ...@@ -181,6 +180,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: HPUAttentionMetadata, attn_metadata: HPUAttentionMetadata,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention. """Forward pass with xFormers and PagedAttention.
...@@ -193,6 +193,11 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): ...@@ -193,6 +193,11 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for HPUAttentionImpl")
batch_size, seq_len, hidden_size = query.shape batch_size, seq_len, hidden_size = query.shape
_, seq_len_kv, _ = key.shape _, seq_len_kv, _ = key.shape
......
...@@ -145,7 +145,6 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): ...@@ -145,7 +145,6 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
self.sliding_window = sliding_window self.sliding_window = sliding_window
self.kv_cache_dtype = kv_cache_dtype self.kv_cache_dtype = kv_cache_dtype
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.need_mask = (self.sliding_window is not None) self.need_mask = (self.sliding_window is not None)
if logits_soft_cap is None: if logits_soft_cap is None:
...@@ -192,6 +191,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): ...@@ -192,6 +191,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: IpexAttnMetadata, # type: ignore attn_metadata: IpexAttnMetadata, # type: ignore
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with IPEX varlen_attention and PagedAttention. """Forward pass with IPEX varlen_attention and PagedAttention.
...@@ -206,6 +206,11 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): ...@@ -206,6 +206,11 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for IpexAttentionImpl")
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
num_tokens, hidden_size = query.shape num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors. # Reshape the query, key, and value tensors.
......
...@@ -1334,11 +1334,17 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1334,11 +1334,17 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: T, attn_metadata: T,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if output is not None: if output is not None:
raise NotImplementedError( raise NotImplementedError(
"output is not yet supported for MLAImplBase") "output is not yet supported for MLAImplBase")
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for MLAImplBase")
if attn_metadata.is_profile_run and \ if attn_metadata.is_profile_run and \
attn_metadata.context_chunk_workspace is not None: attn_metadata.context_chunk_workspace is not None:
# During the profile run try to simulate to worse case output size # During the profile run try to simulate to worse case output size
......
...@@ -121,9 +121,8 @@ class PallasAttentionBackendImpl(AttentionImpl): ...@@ -121,9 +121,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
self.scale = float(scale) self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.num_kv_heads = num_kv_heads
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.logits_soft_cap = logits_soft_cap self.logits_soft_cap = logits_soft_cap
if head_size % 128 != 0: if head_size % 128 != 0:
...@@ -172,6 +171,7 @@ class PallasAttentionBackendImpl(AttentionImpl): ...@@ -172,6 +171,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
kv_cache: Tuple[torch.Tensor, torch.Tensor], kv_cache: Tuple[torch.Tensor, torch.Tensor],
attn_metadata: PallasMetadata, attn_metadata: PallasMetadata,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with Pallas attention. """Forward pass with Pallas attention.
...@@ -187,6 +187,11 @@ class PallasAttentionBackendImpl(AttentionImpl): ...@@ -187,6 +187,11 @@ class PallasAttentionBackendImpl(AttentionImpl):
Returns: Returns:
shape = [batch_size, seq_len, num_heads * head_size] shape = [batch_size, seq_len, num_heads * head_size]
""" """
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for PallasAttentionImpl")
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
batch_size, seq_len, hidden_size = query.shape batch_size, seq_len, hidden_size = query.shape
query = query.view(batch_size, seq_len, self.num_heads, self.head_size) query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
......
...@@ -17,6 +17,7 @@ from vllm.attention.backends.utils import (CommonAttentionState, ...@@ -17,6 +17,7 @@ from vllm.attention.backends.utils import (CommonAttentionState,
CommonMetadataBuilder) CommonMetadataBuilder)
from vllm.attention.ops.paged_attn import (PagedAttention, from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata) PagedAttentionMetadata)
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.rocm import use_rocm_custom_paged_attention from vllm.platforms.rocm import use_rocm_custom_paged_attention
...@@ -37,11 +38,11 @@ def is_rocm_aiter_paged_attn_enabled() -> bool: ...@@ -37,11 +38,11 @@ def is_rocm_aiter_paged_attn_enabled() -> bool:
@cache @cache
def _get_paged_attn_module() -> PagedAttention: def _get_paged_attn_module() -> PagedAttention:
""" """
Initializes the appropriate PagedAttention module from `attention/ops`, Initializes the appropriate PagedAttention module from `attention/ops`,
which is used as helper function which is used as helper function
by `ROCmFlashAttentionImpl` and `ROCmFlashAttentionBackend`. by `ROCmFlashAttentionImpl` and `ROCmFlashAttentionBackend`.
The choice of attention module depends on whether The choice of attention module depends on whether
AITER paged attention is enabled: AITER paged attention is enabled:
- If enabled, `ROCmFlashAttentionImpl` uses `AITERPagedAttention`. - If enabled, `ROCmFlashAttentionImpl` uses `AITERPagedAttention`.
- Otherwise, it defaults to using the original `PagedAttention`. - Otherwise, it defaults to using the original `PagedAttention`.
...@@ -527,7 +528,6 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -527,7 +528,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
if sliding_window is not None else (-1, -1)) if sliding_window is not None else (-1, -1))
self.kv_cache_dtype = kv_cache_dtype self.kv_cache_dtype = kv_cache_dtype
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.paged_attn_module = _get_paged_attn_module() self.paged_attn_module = _get_paged_attn_module()
...@@ -584,6 +584,10 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -584,6 +584,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
logger.debug("Using naive (SDPA) attention in ROCmBackend") logger.debug("Using naive (SDPA) attention in ROCmBackend")
self.aiter_kv_scales_initialized = False self.aiter_kv_scales_initialized = False
self.force_fp8_attention = (
get_current_vllm_config() is not None
and get_current_vllm_config().model_config.override_attention_dtype
== "fp8")
def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)""" """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
...@@ -593,6 +597,15 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -593,6 +597,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
head_dim).reshape(tokens, n_kv_heads * n_rep, head_dim).reshape(tokens, n_kv_heads * n_rep,
head_dim)) head_dim))
def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
group_shape: tuple[int, int]):
if self.use_triton_flash_attn:
return dtype == current_platform.fp8_dtype(
) and static and group_shape == (-1, -1) # per-tensor
# Only supported in the Triton backend
return False
def forward( def forward(
self, self,
layer: AttentionLayer, layer: AttentionLayer,
...@@ -602,6 +615,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -602,6 +615,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: ROCmFlashAttentionMetadata, attn_metadata: ROCmFlashAttentionMetadata,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention. """Forward pass with FlashAttention and PagedAttention.
...@@ -655,6 +669,11 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -655,6 +669,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
""" """
assert output is not None, "Output tensor must be provided." assert output is not None, "Output tensor must be provided."
if output_scale is not None and not self.use_triton_flash_attn:
raise NotImplementedError(
"fused output quantization only supported for Triton"
" implementation in ROCMFlashAttentionImpl for now")
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
if key is not None: if key is not None:
assert value is not None assert value is not None
...@@ -770,9 +789,12 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -770,9 +789,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
query.dtype, query.dtype,
seq_lens, seq_lens,
make_attn_mask=causal_mask) # type: ignore make_attn_mask=causal_mask) # type: ignore
use_fp8_scales = (layer._q_scale and layer._k_scale use_fp8_scales = (layer._q_scale and layer._k_scale
and layer._v_scale and layer._prob_scale and layer._v_scale and layer._prob_scale
and self.kv_cache_dtype == "fp8") and (self.kv_cache_dtype == "fp8"
or self.force_fp8_attention))
full_scales = ( full_scales = (
layer._q_scale.item(), layer._k_scale.item(), layer._q_scale.item(), layer._k_scale.item(),
layer._v_scale.item(), layer._v_scale.item(),
...@@ -791,6 +813,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -791,6 +813,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
attn_masks[0][None] attn_masks[0][None]
if attn_masks is not None else None, if attn_masks is not None else None,
full_scales, full_scales,
output_scale,
) )
elif self.use_naive_attn: elif self.use_naive_attn:
if self.num_kv_heads != self.num_heads: if self.num_kv_heads != self.num_heads:
...@@ -880,7 +903,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -880,7 +903,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
assert _PARTITION_SIZE_ROCM % block_size == 0 assert _PARTITION_SIZE_ROCM % block_size == 0
tmp_output = torch.empty( tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size), size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=output.dtype, dtype=query.dtype,
device=output.device, device=output.device,
) )
exp_sums = torch.empty( exp_sums = torch.empty(
...@@ -914,9 +937,17 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -914,9 +937,17 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self.kv_cache_dtype, self.kv_cache_dtype,
layer._k_scale, layer._k_scale,
layer._v_scale, layer._v_scale,
output_scale,
) )
else: else:
output[num_prefill_tokens:] = paged_attn.forward_decode( # PagedAttention does not support fused quant, manually quantize
if output_scale is None:
out_pa = output[num_prefill_tokens:]
else:
out_pa = torch.empty_like(output[num_prefill_tokens:],
dtype=query.dtype)
out_pa[:] = paged_attn.forward_decode(
decode_query, decode_query,
key_cache, key_cache,
value_cache, value_cache,
...@@ -937,6 +968,14 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -937,6 +968,14 @@ class ROCmFlashAttentionImpl(AttentionImpl):
layer._v_scale, layer._v_scale,
) )
# Manually perform quantization
if output_scale is not None:
out_uq = out_pa.view(-1, self.num_heads * self.head_size)
out_q = output.view(-1, self.num_heads * self.head_size)
ops.scaled_fp8_quant(out_uq,
output_scale,
output=out_q[num_prefill_tokens:])
# Reshape the output tensor. # Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size) return output.view(-1, self.num_heads * self.head_size)
......
...@@ -65,7 +65,7 @@ class TorchSDPABackend(AttentionBackend): ...@@ -65,7 +65,7 @@ class TorchSDPABackend(AttentionBackend):
dst_kv_cache: torch.Tensor, dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor, src_to_dst: torch.Tensor,
) -> None: ) -> None:
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) raise NotImplementedError("Swap is not supported in TorchSDPABackend.")
@staticmethod @staticmethod
def copy_blocks( def copy_blocks(
...@@ -433,7 +433,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -433,7 +433,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
self.sliding_window = sliding_window self.sliding_window = sliding_window
self.kv_cache_dtype = kv_cache_dtype self.kv_cache_dtype = kv_cache_dtype
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.need_mask = (self.alibi_slopes is not None self.need_mask = (self.alibi_slopes is not None
or self.sliding_window is not None) or self.sliding_window is not None)
...@@ -459,6 +458,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -459,6 +458,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: TorchSDPAMetadata, # type: ignore attn_metadata: TorchSDPAMetadata, # type: ignore
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention. """Forward pass with torch SDPA and PagedAttention.
...@@ -473,6 +473,10 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -473,6 +473,10 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for TorchSDPABackendImpl")
# For warming-up # For warming-up
if attn_metadata is None: if attn_metadata is None:
......
...@@ -373,7 +373,7 @@ class CommonAttentionState(AttentionState): ...@@ -373,7 +373,7 @@ class CommonAttentionState(AttentionState):
f"Expected attn_backend name to be either 'XFORMERS'," \ f"Expected attn_backend name to be either 'XFORMERS'," \
f"'ROCM_FLASH', or 'FLASH_ATTN', but " \ f"'ROCM_FLASH', or 'FLASH_ATTN', but " \
f"got '{self.runner.attn_backend.get_name()}'" f"got '{self.runner.attn_backend.get_name()}'"
self._add_additonal_input_buffers_for_enc_dec_model( self._add_additional_input_buffers_for_enc_dec_model(
attn_metadata=attn_metadata, input_buffers=input_buffers) attn_metadata=attn_metadata, input_buffers=input_buffers)
return input_buffers return input_buffers
...@@ -427,7 +427,7 @@ class CommonAttentionState(AttentionState): ...@@ -427,7 +427,7 @@ class CommonAttentionState(AttentionState):
attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture
attn_metadata.num_encoder_tokens = 0 attn_metadata.num_encoder_tokens = 0
def _add_additonal_input_buffers_for_enc_dec_model( def _add_additional_input_buffers_for_enc_dec_model(
self, attn_metadata, input_buffers: Dict[str, Any]): self, attn_metadata, input_buffers: Dict[str, Any]):
""" """
Saves additional input buffers specific to the encoder-decoder model Saves additional input buffers specific to the encoder-decoder model
......
...@@ -415,7 +415,6 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): ...@@ -415,7 +415,6 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
self.sliding_window = sliding_window self.sliding_window = sliding_window
self.kv_cache_dtype = kv_cache_dtype self.kv_cache_dtype = kv_cache_dtype
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
supported_head_sizes = PagedAttention.get_supported_head_sizes() supported_head_sizes = PagedAttention.get_supported_head_sizes()
...@@ -435,6 +434,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): ...@@ -435,6 +434,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: "XFormersMetadata", attn_metadata: "XFormersMetadata",
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention. """Forward pass with xFormers and PagedAttention.
...@@ -487,6 +487,11 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): ...@@ -487,6 +487,11 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for XFormersImpl")
attn_type = self.attn_type attn_type = self.attn_type
# Check that appropriate attention metadata attributes are # Check that appropriate attention metadata attributes are
# selected for the desired attention type # selected for the desired attention type
......
...@@ -80,6 +80,9 @@ class Attention(nn.Module): ...@@ -80,6 +80,9 @@ class Attention(nn.Module):
calculate_kv_scales = False calculate_kv_scales = False
if num_kv_heads is None: if num_kv_heads is None:
num_kv_heads = num_heads num_kv_heads = num_heads
assert num_heads % num_kv_heads == 0, \
f"num_heads ({num_heads}) is not " \
f"divisible by num_kv_heads ({num_kv_heads})"
# The default k/v_scale is set to 1.0. This is ignored # The default k/v_scale is set to 1.0. This is ignored
# when kv-cache is not fp8, and should be used with # when kv-cache is not fp8, and should be used with
...@@ -206,7 +209,7 @@ class Attention(nn.Module): ...@@ -206,7 +209,7 @@ class Attention(nn.Module):
if self.use_output: if self.use_output:
output_shape = (output_shape output_shape = (output_shape
if output_shape is not None else query.shape) if output_shape is not None else query.shape)
output = torch.empty(output_shape, output = torch.zeros(output_shape,
dtype=query.dtype, dtype=query.dtype,
device=query.device) device=query.device)
hidden_size = output_shape[-1] hidden_size = output_shape[-1]
...@@ -291,7 +294,9 @@ class MultiHeadAttention(nn.Module): ...@@ -291,7 +294,9 @@ class MultiHeadAttention(nn.Module):
self.scale = scale self.scale = scale
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
assert self.num_heads % self.num_kv_heads == 0 assert self.num_heads % self.num_kv_heads == 0, \
f"num_heads ({self.num_heads}) is not " \
f"divisible by num_kv_heads ({self.num_kv_heads})"
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
dtype = torch.get_default_dtype() dtype = torch.get_default_dtype()
...@@ -301,12 +306,17 @@ class MultiHeadAttention(nn.Module): ...@@ -301,12 +306,17 @@ class MultiHeadAttention(nn.Module):
block_size=16, block_size=16,
is_attention_free=False) is_attention_free=False)
backend = backend_name_to_enum(attn_backend.get_name()) backend = backend_name_to_enum(attn_backend.get_name())
if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}: if current_platform.is_rocm():
backend = _Backend.XFORMERS # currently, only torch_sdpa is supported on rocm
self.attn_backend = _Backend.TORCH_SDPA
else:
if backend in (_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1,
_Backend.FLEX_ATTENTION):
backend = _Backend.XFORMERS
self.attn_backend = backend if backend in { self.attn_backend = backend if backend in {
_Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1 _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
} else _Backend.TORCH_SDPA } else _Backend.TORCH_SDPA
def forward( def forward(
self, self,
...@@ -430,6 +440,7 @@ def unified_attention_with_output( ...@@ -430,6 +440,7 @@ def unified_attention_with_output(
value: torch.Tensor, value: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
layer_name: str, layer_name: str,
output_scale: Optional[torch.Tensor] = None,
) -> None: ) -> None:
wait_for_kv_layer_from_connector(layer_name) wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
...@@ -444,7 +455,8 @@ def unified_attention_with_output( ...@@ -444,7 +455,8 @@ def unified_attention_with_output(
value, value,
kv_cache, kv_cache,
attn_metadata, attn_metadata,
output=output) output=output,
output_scale=output_scale)
maybe_save_kv_layer_to_connector(layer_name, kv_cache) maybe_save_kv_layer_to_connector(layer_name, kv_cache)
...@@ -455,6 +467,7 @@ def unified_attention_with_output_fake( ...@@ -455,6 +467,7 @@ def unified_attention_with_output_fake(
value: torch.Tensor, value: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
layer_name: str, layer_name: str,
output_scale: Optional[torch.Tensor] = None,
) -> None: ) -> None:
return return
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Dict, List, Optional, Tuple from typing import List, Optional, Tuple
try: try:
import intel_extension_for_pytorch.llm.modules as ipex_modules import intel_extension_for_pytorch.llm.modules as ipex_modules
...@@ -29,7 +29,7 @@ class _PagedAttention: ...@@ -29,7 +29,7 @@ class _PagedAttention:
head_size: int, head_size: int,
*args, *args,
) -> Tuple[int, ...]: ) -> Tuple[int, ...]:
return (2, num_blocks, block_size * num_kv_heads * head_size) return 2, num_blocks, block_size * num_kv_heads * head_size
@staticmethod @staticmethod
def split_kv_cache( def split_kv_cache(
...@@ -120,7 +120,7 @@ class _PagedAttention: ...@@ -120,7 +120,7 @@ class _PagedAttention:
@staticmethod @staticmethod
def copy_blocks( def copy_blocks(
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]], src_to_dists: torch.Tensor,
*args, *args,
) -> None: ) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches] key_caches = [kv_cache[0] for kv_cache in kv_caches]
......
...@@ -8,9 +8,7 @@ import torch ...@@ -8,9 +8,7 @@ import torch
from neuronxcc import nki from neuronxcc import nki
from neuronxcc.nki.language import par_dim from neuronxcc.nki.language import par_dim
from vllm.utils import cdiv
def ceil_div(a, b):
return (a + b - 1) // b
def is_power_of_2(x): def is_power_of_2(x):
...@@ -35,11 +33,10 @@ def load_block_tables(block_tables_hbm, num_tiles, num_blocks_per_tile): ...@@ -35,11 +33,10 @@ def load_block_tables(block_tables_hbm, num_tiles, num_blocks_per_tile):
(num_tiles, num_blocks_per_tile)) (num_tiles, num_blocks_per_tile))
block_tables_sbuf = nl.zeros( block_tables_sbuf = nl.zeros(
(ceil_div(num_tiles, (cdiv(num_tiles, B_P_SIZE), par_dim(B_P_SIZE), num_blocks_per_tile),
B_P_SIZE), par_dim(B_P_SIZE), num_blocks_per_tile),
dtype=nl.int32, dtype=nl.int32,
) )
for i in nl.affine_range(ceil_div(num_tiles, B_P_SIZE)): for i in nl.affine_range(cdiv(num_tiles, B_P_SIZE)):
i_p = nl.arange(B_P_SIZE)[:, None] i_p = nl.arange(B_P_SIZE)[:, None]
i_f = nl.arange(num_blocks_per_tile)[None, :] i_f = nl.arange(num_blocks_per_tile)[None, :]
block_tables_sbuf[i, i_p, i_f] = nl.load( block_tables_sbuf[i, i_p, i_f] = nl.load(
...@@ -83,7 +80,7 @@ def transform_block_tables_for_indirect_load( ...@@ -83,7 +80,7 @@ def transform_block_tables_for_indirect_load(
assert is_power_of_2( assert is_power_of_2(
num_blocks_per_tile), f"{num_blocks_per_tile=} is not power of 2" num_blocks_per_tile), f"{num_blocks_per_tile=} is not power of 2"
num_loads = ceil_div(num_blocks_per_tile, B_P_SIZE) num_loads = cdiv(num_blocks_per_tile, B_P_SIZE)
block_tables_transposed = nl.ndarray( block_tables_transposed = nl.ndarray(
( (
num_loads, num_loads,
...@@ -165,7 +162,7 @@ def load_kv_tile_from_cache( ...@@ -165,7 +162,7 @@ def load_kv_tile_from_cache(
equivalent to (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE) equivalent to (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE)
""" """
# load key cache # load key cache
num_loads = ceil_div(num_blocks_per_large_tile, B_P_SIZE) num_loads = cdiv(num_blocks_per_large_tile, B_P_SIZE)
for load_idx in nl.affine_range(num_loads): for load_idx in nl.affine_range(num_loads):
i_p = nl.arange(B_P_SIZE)[:, None] i_p = nl.arange(B_P_SIZE)[:, None]
i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :] i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :]
...@@ -605,7 +602,7 @@ def flash_paged_attention( ...@@ -605,7 +602,7 @@ def flash_paged_attention(
) )
for large_k_tile_idx in nl.sequential_range(0, num_large_k_tile): for large_k_tile_idx in nl.sequential_range(0, num_large_k_tile):
num_loads = ceil_div(num_blocks_per_large_tile, B_P_SIZE) num_loads = cdiv(num_blocks_per_large_tile, B_P_SIZE)
cur_k_tile = nl.ndarray( cur_k_tile = nl.ndarray(
(par_dim(B_D_SIZE), LARGE_TILE_SZ), (par_dim(B_D_SIZE), LARGE_TILE_SZ),
dtype=kernel_dtype, dtype=kernel_dtype,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from vllm.utils import cdiv
def _kv_cache_update_kernel(
# Prefetch
slices_ref, # [3, padded_num_slices], list of (kv_cache_start,
# new_kv_start, slice_len)
# Input
new_kv_hbm_ref, # [num_tokens, num_combined_kv_heads, head_dim]
kv_cache_hbm_ref, # [total_num_pages * page_size, num_combined_kv_heads,
# head_dim]
# Output
_, # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
# Scratch
scratch, # [num_slices_per_block, page_size, num_combined_kv_heads,
# head_dim]
sem,
):
async_copies = []
block_idx = pl.program_id(0)
num_slices_per_block = scratch.shape[0]
# Copy from new_kv_hbm_ref to scratch
for i in range(num_slices_per_block):
offset_i = i + block_idx * num_slices_per_block
new_kv_start = slices_ref[1, offset_i]
length = slices_ref[2, offset_i]
async_copy = pltpu.make_async_copy(
new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...],
scratch.at[i, pl.ds(0, length), ...],
sem,
)
async_copy.start()
async_copies.append(async_copy)
for async_copy in async_copies:
async_copy.wait()
# Copy from scratch to kv_cache_hbm_ref
async_copies.clear()
for i in range(num_slices_per_block):
offset_i = i + block_idx * num_slices_per_block
kv_cache_start = slices_ref[0, offset_i]
length = slices_ref[2, offset_i]
async_copy = pltpu.make_async_copy(
scratch.at[i, pl.ds(0, length), ...],
kv_cache_hbm_ref.at[pl.ds(kv_cache_start, length), ...],
sem,
)
async_copy.start()
async_copies.append(async_copy)
for async_copy in async_copies:
async_copy.wait()
@functools.partial(
jax.jit,
static_argnames=["page_size", "num_slices_per_block"],
)
def kv_cache_update(
new_kv: jax.Array, # [total_num_token, num_combined_kv_heads, head_dim]
slices: jax.
Array, # [3, slices], list of (kv_cache_start, new_kv_start, slice_len)
kv_cache: jax.
Array, # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
num_kv_update_slices: jax.Array, # [1]
*,
page_size: int = 32,
num_slices_per_block: int = 8,
):
assert slices.shape[1] % num_slices_per_block == 0
_, num_combined_kv_heads, head_dim = new_kv.shape
assert kv_cache.shape[1] == num_combined_kv_heads
assert kv_cache.shape[2] == head_dim
assert head_dim % 128 == 0
# TODO: Add dynamic check to make sure that the all the slice lengths are
# smaller or equal to page_size
in_specs = [
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
]
out_specs = [pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)]
out_shape = [jax.ShapeDtypeStruct(kv_cache.shape, dtype=kv_cache.dtype)]
scalar_prefetches = [slices]
scratch = pltpu.VMEM(
(num_slices_per_block, page_size, num_combined_kv_heads, head_dim),
new_kv.dtype,
)
scratch_shapes = [
scratch,
pltpu.SemaphoreType.DMA,
]
kernel = pl.pallas_call(
_kv_cache_update_kernel,
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=len(scalar_prefetches),
in_specs=in_specs,
out_specs=out_specs,
grid=(cdiv(num_kv_update_slices[0], num_slices_per_block), ),
scratch_shapes=scratch_shapes,
),
out_shape=out_shape,
input_output_aliases={len(scalar_prefetches) + 1: 0},
)
return kernel(*scalar_prefetches, new_kv, kv_cache)[0]
...@@ -25,9 +25,14 @@ Not currently supported: ...@@ -25,9 +25,14 @@ Not currently supported:
import torch import torch
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.rocm import on_gfx1x
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
# Avoid misleading ROCm warning.
if current_platform.is_rocm():
from vllm.platforms.rocm import on_gfx1x
else:
on_gfx1x = lambda *args, **kwargs: False
torch_dtype: tl.constexpr = torch.float16 torch_dtype: tl.constexpr = torch.float16
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment