Unverified commit 0ff70821, authored by Roger Wang, committed by GitHub
Browse files

[Core] Deprecate `xformers` (#29262)


Signed-off-by: Roger Wang <hey@rogerw.io>
parent 5253f427
...@@ -309,7 +309,6 @@ class Glm4vVisionAttention(nn.Module): ...@@ -309,7 +309,6 @@ class Glm4vVisionAttention(nn.Module):
if self.attn_backend not in { if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA, AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
}: }:
raise RuntimeError( raise RuntimeError(
...@@ -345,7 +344,6 @@ class Glm4vVisionAttention(nn.Module): ...@@ -345,7 +344,6 @@ class Glm4vVisionAttention(nn.Module):
rotary_pos_emb_cos: torch.Tensor, rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor, rotary_pos_emb_sin: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention max_seqlen: int | None = None, # Only used for Flash Attention
seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor: ) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim] # [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x) x, _ = self.qkv(x)
...@@ -400,20 +398,6 @@ class Glm4vVisionAttention(nn.Module): ...@@ -400,20 +398,6 @@ class Glm4vVisionAttention(nn.Module):
context_layer = rearrange( context_layer = rearrange(
context_layer, "b s h d -> s b (h d)" context_layer, "b s h d -> s b (h d)"
).contiguous() ).contiguous()
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
attn_bias = BlockDiagonalMask.from_seqlens(
q_seqlen=seqlens, kv_seqlen=None, device=q.device
)
context_layer = xops.memory_efficient_attention_forward(
q, k, v, attn_bias=attn_bias, p=0, scale=None
)
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
output, _ = self.proj(context_layer) output, _ = self.proj(context_layer)
return output return output
...@@ -461,7 +445,6 @@ class Glm4vVisionBlock(nn.Module): ...@@ -461,7 +445,6 @@ class Glm4vVisionBlock(nn.Module):
rotary_pos_emb_cos: torch.Tensor, rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor, rotary_pos_emb_sin: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention max_seqlen: int | None = None, # Only used for Flash Attention
seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor: ) -> torch.Tensor:
x_attn = self.attn( x_attn = self.attn(
self.norm1(x), self.norm1(x),
...@@ -469,7 +452,6 @@ class Glm4vVisionBlock(nn.Module): ...@@ -469,7 +452,6 @@ class Glm4vVisionBlock(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos, rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin, rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
seqlens=seqlens,
) )
x_fused_norm, residual = self.norm2(x, residual=x_attn) x_fused_norm, residual = self.norm2(x, residual=x_attn)
x = residual + self.mlp(x_fused_norm) x = residual + self.mlp(x_fused_norm)
...@@ -803,15 +785,14 @@ class Glm4vVisionTransformer(nn.Module): ...@@ -803,15 +785,14 @@ class Glm4vVisionTransformer(nn.Module):
def compute_attn_mask_seqlen( def compute_attn_mask_seqlen(
self, self,
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
) -> tuple[int | None, list[int] | None]: ) -> int | None:
max_seqlen, seqlens = None, None max_seqlen = None
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
if ( if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
): ):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
return max_seqlen, seqlens return max_seqlen
def forward( def forward(
self, self,
...@@ -836,8 +817,9 @@ class Glm4vVisionTransformer(nn.Module): ...@@ -836,8 +817,9 @@ class Glm4vVisionTransformer(nn.Module):
).cumsum(dim=0, dtype=torch.int32) ).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0) cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
# pre-compute seqlens for attn mask to reduce cuMemcpy operations # pre-compute max_seqlen for attn mask to reduce cuMemcpy operations
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
x = self.embeddings( x = self.embeddings(
x, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1] x, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1]
) )
...@@ -851,7 +833,6 @@ class Glm4vVisionTransformer(nn.Module): ...@@ -851,7 +833,6 @@ class Glm4vVisionTransformer(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos, rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin, rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
seqlens=seqlens,
) )
# adapter # adapter
......
...@@ -9,6 +9,7 @@ from typing import Annotated, Any, Literal, TypeAlias, TypeVar ...@@ -9,6 +9,7 @@ from typing import Annotated, Any, Literal, TypeAlias, TypeVar
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from transformers import PretrainedConfig from transformers import PretrainedConfig
from transformers.activations import GELUActivation from transformers.activations import GELUActivation
...@@ -424,7 +425,7 @@ class KeyeSiglipAttention(nn.Module): ...@@ -424,7 +425,7 @@ class KeyeSiglipAttention(nn.Module):
if self.attn_backend not in { if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.XFORMERS, AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
}: }:
raise RuntimeError( raise RuntimeError(
...@@ -451,7 +452,6 @@ class KeyeSiglipAttention(nn.Module): ...@@ -451,7 +452,6 @@ class KeyeSiglipAttention(nn.Module):
) )
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
batch_size = q.shape[0] batch_size = q.shape[0]
if rope_emb is None: if rope_emb is None:
...@@ -498,17 +498,21 @@ class KeyeSiglipAttention(nn.Module): ...@@ -498,17 +498,21 @@ class KeyeSiglipAttention(nn.Module):
softmax_scale=self.scale, softmax_scale=self.scale,
) )
context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size) context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
elif self.attn_backend == AttentionBackendEnum.XFORMERS: elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
from xformers import ops as xops outputs = []
from xformers.ops.fmha.attn_bias import BlockDiagonalMask for i in range(1, len(cu_seqlens)):
start_idx = cu_seqlens[i - 1]
attn_bias = BlockDiagonalMask.from_seqlens( end_idx = cu_seqlens[i]
q_seqlen=seqlens, kv_seqlen=None, device=q.device q_i = q[:, start_idx:end_idx]
) k_i = k[:, start_idx:end_idx]
v_i = v[:, start_idx:end_idx]
context_layer = xops.memory_efficient_attention_forward( q_i, k_i, v_i = (
q, k, v, attn_bias=attn_bias, p=0, scale=None rearrange(x, "b s h d -> b h s d") for x in (q_i, k_i, v_i)
) )
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
output_i = rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0]
context_layer = rearrange(context_layer, "b s h d -> b s (h d)").contiguous() context_layer = rearrange(context_layer, "b s h d -> b s (h d)").contiguous()
......
...@@ -38,7 +38,6 @@ from vllm.attention.layer import ( ...@@ -38,7 +38,6 @@ from vllm.attention.layer import (
) )
from vllm.attention.ops.vit_attn_wrappers import ( from vllm.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper, vit_flash_attn_wrapper,
vit_xformers_attn_wrapper,
) )
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
...@@ -657,7 +656,6 @@ class SiglipAttention(nn.Module): ...@@ -657,7 +656,6 @@ class SiglipAttention(nn.Module):
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor | None, rotary_pos_emb: torch.Tensor | None,
max_seqlen: torch.Tensor | None, max_seqlen: torch.Tensor | None,
seqlens: torch.Tensor | None,
) -> torch.Tensor: ) -> torch.Tensor:
batch_size, _, _ = hidden_states.shape batch_size, _, _ = hidden_states.shape
...@@ -703,10 +701,6 @@ class SiglipAttention(nn.Module): ...@@ -703,10 +701,6 @@ class SiglipAttention(nn.Module):
context_layer = rearrange( context_layer = rearrange(
context_layer, "b s h d -> s b (h d)" context_layer, "b s h d -> s b (h d)"
).contiguous() ).contiguous()
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
if seqlens is None:
raise ValueError("xFormers attention backend requires seqlens tensor.")
context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)
else: else:
raise RuntimeError( raise RuntimeError(
f"PaddleOCR-VL does not support {self.attn_backend} backend now." f"PaddleOCR-VL does not support {self.attn_backend} backend now."
...@@ -818,7 +812,6 @@ class SiglipEncoderLayer(nn.Module): ...@@ -818,7 +812,6 @@ class SiglipEncoderLayer(nn.Module):
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor | None, rotary_pos_emb: torch.Tensor | None,
max_seqlen: torch.Tensor | None, max_seqlen: torch.Tensor | None,
seqlens: torch.Tensor | None,
) -> torch.Tensor: ) -> torch.Tensor:
residual = hidden_states residual = hidden_states
...@@ -828,7 +821,6 @@ class SiglipEncoderLayer(nn.Module): ...@@ -828,7 +821,6 @@ class SiglipEncoderLayer(nn.Module):
cu_seqlens=cu_seqlens, cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb, rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
seqlens=seqlens,
) )
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
...@@ -870,7 +862,6 @@ class SiglipEncoder(nn.Module): ...@@ -870,7 +862,6 @@ class SiglipEncoder(nn.Module):
if self.attn_backend not in { if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA, AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
}: }:
raise RuntimeError( raise RuntimeError(
...@@ -943,14 +934,11 @@ class SiglipEncoder(nn.Module): ...@@ -943,14 +934,11 @@ class SiglipEncoder(nn.Module):
cu_seqlens = cu_seqlens.to(device=device) cu_seqlens = cu_seqlens.to(device=device)
max_seqlen = None max_seqlen = None
seqlens = None
if self.attn_backend in { if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
}: }:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
hidden_states = inputs_embeds hidden_states = inputs_embeds
for encoder_layer in self.layers: for encoder_layer in self.layers:
...@@ -959,7 +947,6 @@ class SiglipEncoder(nn.Module): ...@@ -959,7 +947,6 @@ class SiglipEncoder(nn.Module):
cu_seqlens=cu_seqlens, cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb, rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
seqlens=seqlens,
) )
return hidden_states return hidden_states
......
...@@ -74,6 +74,7 @@ from .vision import ( ...@@ -74,6 +74,7 @@ from .vision import (
) )
try: try:
# Note: vLLM does not install xformers by default.
from xformers import ops as xops from xformers import ops as xops
if current_platform.is_cuda() and current_platform.has_device_capability(100): if current_platform.is_cuda() and current_platform.has_device_capability(100):
......
...@@ -46,7 +46,6 @@ from vllm.attention.layer import maybe_get_vit_flash_attn_backend ...@@ -46,7 +46,6 @@ from vllm.attention.layer import maybe_get_vit_flash_attn_backend
from vllm.attention.ops.vit_attn_wrappers import ( from vllm.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper, vit_flash_attn_wrapper,
vit_torch_sdpa_wrapper, vit_torch_sdpa_wrapper,
vit_xformers_attn_wrapper,
) )
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig from vllm.config import VllmConfig
...@@ -375,7 +374,6 @@ class Qwen2_5_VisionAttention(nn.Module): ...@@ -375,7 +374,6 @@ class Qwen2_5_VisionAttention(nn.Module):
rotary_pos_emb_cos: torch.Tensor, rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor, rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention max_seqlen: torch.Tensor, # Only used for Flash Attention
seqlens: torch.Tensor, # Only used for xFormers
) -> torch.Tensor: ) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim] # [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x) x, _ = self.qkv(x)
...@@ -435,8 +433,6 @@ class Qwen2_5_VisionAttention(nn.Module): ...@@ -435,8 +433,6 @@ class Qwen2_5_VisionAttention(nn.Module):
v, v,
cu_seqlens, cu_seqlens,
) )
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)
output, _ = self.proj(context_layer) output, _ = self.proj(context_layer)
return output return output
...@@ -448,9 +444,7 @@ class Qwen2_5_VisionAttention(nn.Module): ...@@ -448,9 +444,7 @@ class Qwen2_5_VisionAttention(nn.Module):
"cu_seqlens": 0, "cu_seqlens": 0,
"rotary_pos_emb_cos": 0, "rotary_pos_emb_cos": 0,
"rotary_pos_emb_sin": 0, "rotary_pos_emb_sin": 0,
"seqlens": 0,
}, },
mark_unbacked_dims={"seqlens": 0},
enable_if=should_torch_compile_mm_vit, enable_if=should_torch_compile_mm_vit,
) )
class Qwen2_5_VisionBlock(nn.Module): class Qwen2_5_VisionBlock(nn.Module):
...@@ -501,7 +495,6 @@ class Qwen2_5_VisionBlock(nn.Module): ...@@ -501,7 +495,6 @@ class Qwen2_5_VisionBlock(nn.Module):
rotary_pos_emb_cos: torch.Tensor, rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor, rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention max_seqlen: torch.Tensor, # Only used for Flash Attention
seqlens: torch.Tensor, # Only used for xFormers
) -> torch.Tensor: ) -> torch.Tensor:
x_attn = self.attn( x_attn = self.attn(
self.norm1(x), self.norm1(x),
...@@ -509,7 +502,6 @@ class Qwen2_5_VisionBlock(nn.Module): ...@@ -509,7 +502,6 @@ class Qwen2_5_VisionBlock(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos, rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin, rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
seqlens=seqlens,
) )
x_fused_norm, residual = self.norm2(x, residual=x_attn) x_fused_norm, residual = self.norm2(x, residual=x_attn)
x = residual + self.mlp(x_fused_norm) x = residual + self.mlp(x_fused_norm)
...@@ -670,7 +662,6 @@ class Qwen2_5_VisionTransformer(nn.Module): ...@@ -670,7 +662,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
if self.attn_backend not in { if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA, AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
}: }:
raise RuntimeError( raise RuntimeError(
...@@ -822,17 +813,14 @@ class Qwen2_5_VisionTransformer(nn.Module): ...@@ -822,17 +813,14 @@ class Qwen2_5_VisionTransformer(nn.Module):
def compute_attn_mask_seqlen( def compute_attn_mask_seqlen(
self, self,
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
max_seqlen = torch.zeros([], device=cu_seqlens.device) max_seqlen = torch.zeros([], device=cu_seqlens.device)
seqlens = torch.zeros(1, device=cu_seqlens.device)
if self.attn_backend in { if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
}: }:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
elif self.attn_backend == AttentionBackendEnum.XFORMERS: return max_seqlen
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
return max_seqlen, seqlens
@staticmethod @staticmethod
def invert_permutation(perm: torch.Tensor) -> torch.Tensor: def invert_permutation(perm: torch.Tensor) -> torch.Tensor:
...@@ -897,10 +885,8 @@ class Qwen2_5_VisionTransformer(nn.Module): ...@@ -897,10 +885,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
# transformers # transformers
# pre-compute seqlens for window/full attn to reduce cuMemcpy operations # pre-compute seqlens for window/full attn to reduce cuMemcpy operations
max_seqlen_full, seqlens_full = self.compute_attn_mask_seqlen(cu_seqlens) max_seqlen_full = self.compute_attn_mask_seqlen(cu_seqlens)
max_seqlen_window, seqlens_window = self.compute_attn_mask_seqlen( max_seqlen_window = self.compute_attn_mask_seqlen(cu_window_seqlens)
cu_window_seqlens
)
cu_seqlens = cu_seqlens.to(device=self.device, non_blocking=True) cu_seqlens = cu_seqlens.to(device=self.device, non_blocking=True)
cu_window_seqlens = cu_window_seqlens.to(device=self.device, non_blocking=True) cu_window_seqlens = cu_window_seqlens.to(device=self.device, non_blocking=True)
...@@ -927,11 +913,9 @@ class Qwen2_5_VisionTransformer(nn.Module): ...@@ -927,11 +913,9 @@ class Qwen2_5_VisionTransformer(nn.Module):
if layer_num in self.fullatt_block_indexes: if layer_num in self.fullatt_block_indexes:
cu_seqlens_now = cu_seqlens cu_seqlens_now = cu_seqlens
max_seqlen_now = max_seqlen_full max_seqlen_now = max_seqlen_full
seqlens_now = seqlens_full
else: else:
cu_seqlens_now = cu_window_seqlens cu_seqlens_now = cu_window_seqlens
max_seqlen_now = max_seqlen_window max_seqlen_now = max_seqlen_window
seqlens_now = seqlens_window
hidden_states = blk( hidden_states = blk(
hidden_states, hidden_states,
...@@ -939,7 +923,6 @@ class Qwen2_5_VisionTransformer(nn.Module): ...@@ -939,7 +923,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos, rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin, rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen_now, max_seqlen=max_seqlen_now,
seqlens=seqlens_now,
) )
# For Qwen2.5-VL-3B, float16 will overflow at last block # For Qwen2.5-VL-3B, float16 will overflow at last block
......
...@@ -348,7 +348,6 @@ class Qwen2VisionAttention(nn.Module): ...@@ -348,7 +348,6 @@ class Qwen2VisionAttention(nn.Module):
if self.attn_backend not in { if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA, AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
}: }:
raise RuntimeError( raise RuntimeError(
...@@ -384,7 +383,6 @@ class Qwen2VisionAttention(nn.Module): ...@@ -384,7 +383,6 @@ class Qwen2VisionAttention(nn.Module):
rotary_pos_emb_cos: torch.Tensor, rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor, rotary_pos_emb_sin: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention max_seqlen: int | None = None, # Only used for Flash Attention
seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor: ) -> torch.Tensor:
# [s, b, c] --> [s, b, 3 * head * head_dim] # [s, b, c] --> [s, b, 3 * head * head_dim]
x, _ = self.qkv(x) x, _ = self.qkv(x)
...@@ -445,20 +443,6 @@ class Qwen2VisionAttention(nn.Module): ...@@ -445,20 +443,6 @@ class Qwen2VisionAttention(nn.Module):
context_layer = rearrange( context_layer = rearrange(
context_layer, "b s h d -> s b (h d)" context_layer, "b s h d -> s b (h d)"
).contiguous() ).contiguous()
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
attn_bias = BlockDiagonalMask.from_seqlens(
q_seqlen=seqlens, kv_seqlen=None, device=q.device
)
context_layer = xops.memory_efficient_attention_forward(
q, k, v, attn_bias=attn_bias, p=0, scale=None
)
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
output, _ = self.proj(context_layer) output, _ = self.proj(context_layer)
return output return output
...@@ -509,7 +493,6 @@ class Qwen2VisionBlock(nn.Module): ...@@ -509,7 +493,6 @@ class Qwen2VisionBlock(nn.Module):
rotary_pos_emb_cos: torch.Tensor, rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor, rotary_pos_emb_sin: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention max_seqlen: int | None = None, # Only used for Flash Attention
seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor: ) -> torch.Tensor:
x = x + self.attn( x = x + self.attn(
self.norm1(x), self.norm1(x),
...@@ -517,7 +500,6 @@ class Qwen2VisionBlock(nn.Module): ...@@ -517,7 +500,6 @@ class Qwen2VisionBlock(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos, rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin, rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
seqlens=seqlens,
) )
x = x + self.mlp(self.norm2(x)) x = x + self.mlp(self.norm2(x))
...@@ -728,18 +710,14 @@ class Qwen2VisionTransformer(nn.Module): ...@@ -728,18 +710,14 @@ class Qwen2VisionTransformer(nn.Module):
sin_combined = sin[pos_ids].flatten(1) sin_combined = sin[pos_ids].flatten(1)
return cos_combined, sin_combined return cos_combined, sin_combined
def compute_attn_mask_seqlen( def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None:
self, cu_seqlens: torch.Tensor max_seqlen = None
) -> tuple[int | None, list[int] | None]:
max_seqlen, seqlens = None, None
if self.attn_backend in { if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
}: }:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
elif self.attn_backend == AttentionBackendEnum.XFORMERS: return max_seqlen
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
return max_seqlen, seqlens
def forward( def forward(
self, self,
...@@ -771,7 +749,7 @@ class Qwen2VisionTransformer(nn.Module): ...@@ -771,7 +749,7 @@ class Qwen2VisionTransformer(nn.Module):
x = x.unsqueeze(1) x = x.unsqueeze(1)
# pre-compute seqlens for attn mask to reduce cuMemcpy operations # pre-compute seqlens for attn mask to reduce cuMemcpy operations
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
cu_seqlens = cu_seqlens.to(self.device, non_blocking=True) cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
for blk in self.blocks: for blk in self.blocks:
x = blk( x = blk(
...@@ -780,7 +758,6 @@ class Qwen2VisionTransformer(nn.Module): ...@@ -780,7 +758,6 @@ class Qwen2VisionTransformer(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos, rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin, rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
seqlens=seqlens,
) )
# adapter # adapter
......
...@@ -224,7 +224,6 @@ class Qwen3_VisionBlock(nn.Module): ...@@ -224,7 +224,6 @@ class Qwen3_VisionBlock(nn.Module):
rotary_pos_emb_cos: torch.Tensor, rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor, rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention max_seqlen: torch.Tensor, # Only used for Flash Attention
seqlens: torch.Tensor, # Only used for xFormers
) -> torch.Tensor: ) -> torch.Tensor:
x = x + self.attn( x = x + self.attn(
self.norm1(x), self.norm1(x),
...@@ -232,7 +231,6 @@ class Qwen3_VisionBlock(nn.Module): ...@@ -232,7 +231,6 @@ class Qwen3_VisionBlock(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos, rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin, rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
seqlens=seqlens,
) )
x = x + self.mlp(self.norm2(x)) x = x + self.mlp(self.norm2(x))
...@@ -500,14 +498,11 @@ class Qwen3Omni_VisionTransformer(nn.Module): ...@@ -500,14 +498,11 @@ class Qwen3Omni_VisionTransformer(nn.Module):
def compute_attn_mask_seqlen( def compute_attn_mask_seqlen(
self, self,
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
max_seqlen = torch.zeros([], device=cu_seqlens.device) max_seqlen = torch.zeros([], device=cu_seqlens.device)
seqlens = torch.zeros(1, device=cu_seqlens.device)
if self.attn_backend == AttentionBackendEnum.FLASH_ATTN: if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
elif self.attn_backend == AttentionBackendEnum.XFORMERS: return max_seqlen
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
return max_seqlen, seqlens
def forward( def forward(
self, self,
...@@ -533,7 +528,7 @@ class Qwen3Omni_VisionTransformer(nn.Module): ...@@ -533,7 +528,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
hidden_states = hidden_states.unsqueeze(1) hidden_states = hidden_states.unsqueeze(1)
rotary_pos_emb_cos = rotary_pos_emb_cos.to(hidden_states.device) rotary_pos_emb_cos = rotary_pos_emb_cos.to(hidden_states.device)
rotary_pos_emb_sin = rotary_pos_emb_sin.to(hidden_states.device) rotary_pos_emb_sin = rotary_pos_emb_sin.to(hidden_states.device)
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
hidden_states_list = [] hidden_states_list = []
deepstack_visual_indexes = self.deepstack_visual_indexes deepstack_visual_indexes = self.deepstack_visual_indexes
...@@ -545,7 +540,6 @@ class Qwen3Omni_VisionTransformer(nn.Module): ...@@ -545,7 +540,6 @@ class Qwen3Omni_VisionTransformer(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos, rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin, rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
seqlens=seqlens,
) )
if ( if (
deepstack_visual_indexes is not None deepstack_visual_indexes is not None
......
...@@ -235,7 +235,6 @@ class Qwen3_VisionBlock(nn.Module): ...@@ -235,7 +235,6 @@ class Qwen3_VisionBlock(nn.Module):
rotary_pos_emb_cos: torch.Tensor, rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor, rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention max_seqlen: torch.Tensor, # Only used for Flash Attention
seqlens: torch.Tensor, # Only used for xFormers
) -> torch.Tensor: ) -> torch.Tensor:
x = x + self.attn( x = x + self.attn(
self.norm1(x), self.norm1(x),
...@@ -243,7 +242,6 @@ class Qwen3_VisionBlock(nn.Module): ...@@ -243,7 +242,6 @@ class Qwen3_VisionBlock(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos, rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin, rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
seqlens=seqlens,
) )
x = x + self.mlp(self.norm2(x)) x = x + self.mlp(self.norm2(x))
...@@ -391,7 +389,6 @@ class Qwen3_VisionTransformer(nn.Module): ...@@ -391,7 +389,6 @@ class Qwen3_VisionTransformer(nn.Module):
if self.attn_backend not in { if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA, AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
}: }:
raise RuntimeError( raise RuntimeError(
...@@ -531,17 +528,14 @@ class Qwen3_VisionTransformer(nn.Module): ...@@ -531,17 +528,14 @@ class Qwen3_VisionTransformer(nn.Module):
def compute_attn_mask_seqlen( def compute_attn_mask_seqlen(
self, self,
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
max_seqlen = torch.zeros([], device=cu_seqlens.device) max_seqlen = torch.zeros([], device=cu_seqlens.device)
seqlens = torch.zeros(1, device=cu_seqlens.device)
if ( if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
): ):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
elif self.attn_backend == AttentionBackendEnum.XFORMERS: return max_seqlen
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
return max_seqlen, seqlens
def forward( def forward(
self, self,
...@@ -569,7 +563,7 @@ class Qwen3_VisionTransformer(nn.Module): ...@@ -569,7 +563,7 @@ class Qwen3_VisionTransformer(nn.Module):
cu_seqlens = torch.from_numpy(cu_seqlens) cu_seqlens = torch.from_numpy(cu_seqlens)
hidden_states = hidden_states.unsqueeze(1) hidden_states = hidden_states.unsqueeze(1)
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
cu_seqlens = cu_seqlens.to(self.device, non_blocking=True) cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
deepstack_feature_lists = [] deepstack_feature_lists = []
...@@ -580,7 +574,6 @@ class Qwen3_VisionTransformer(nn.Module): ...@@ -580,7 +574,6 @@ class Qwen3_VisionTransformer(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos, rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin, rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
seqlens=seqlens,
) )
if layer_num in self.deepstack_visual_indexes: if layer_num in self.deepstack_visual_indexes:
deepstack_merger_idx = self.deepstack_visual_indexes.index(layer_num) deepstack_merger_idx = self.deepstack_visual_indexes.index(layer_num)
......
...@@ -277,12 +277,7 @@ class CudaPlatformBase(Platform): ...@@ -277,12 +277,7 @@ class CudaPlatformBase(Platform):
except ImportError: except ImportError:
pass pass
if cls.has_device_capability(100):
# xFormers doesn't support Blackwell, fall back to SDPA
# See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
return AttentionBackendEnum.TORCH_SDPA return AttentionBackendEnum.TORCH_SDPA
else:
return AttentionBackendEnum.XFORMERS
@classmethod @classmethod
def get_valid_backends( def get_valid_backends(
......
...@@ -49,7 +49,6 @@ STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND" ...@@ -49,7 +49,6 @@ STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
# Possible string values of STR_BACKEND_ENV_VAR # Possible string values of STR_BACKEND_ENV_VAR
# register, corresponding to possible backends # register, corresponding to possible backends
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER" STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN" STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID" STR_INVALID_VAL: str = "INVALID"
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with XFormersAttention."""
from dataclasses import dataclass
from typing import ClassVar, Optional
import torch
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
AttentionType,
MultipleOf,
)
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec
try:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (
AttentionBias,
PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
)
XFORMERS_AVAILABLE = True
except ImportError:
XFORMERS_AVAILABLE = False
from vllm import _custom_ops as ops
logger = init_logger(__name__)
class XFormersAttentionBackend(AttentionBackend):
    """Backend descriptor for the XFormers attention path.

    Exposes the static capabilities of the backend (dtypes, head sizes,
    kernel block sizes) and hands out the implementation and
    metadata-builder classes that go with it.
    """

    accept_output_buffer: bool = True
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        """Return the accepted kernel block sizes: any multiple of 16."""
        return [MultipleOf(16)]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the supported head sizes.

        These are all multiples of 8 from 32 up to and including 256.
        A fresh list is built on every call.
        """
        return list(range(32, 257, 8))

    @staticmethod
    def get_name() -> str:
        """Return the identifier of this backend."""
        return "XFORMERS"

    @staticmethod
    def get_impl_cls() -> type["XFormersAttentionImpl"]:
        """Return the attention implementation class for this backend."""
        return XFormersAttentionImpl

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the shape of the combined KV-cache tensor.

        The leading dimension of size 2 holds the key cache and the value
        cache (the implementation splits them with ``unbind(0)``).
        ``cache_dtype_str`` is accepted but not used here.

        Raises:
            ValueError: if ``block_size`` is not a multiple of 16.
        """
        if block_size % 16:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_builder_cls() -> type["XFormersAttentionMetadataBuilder"]:
        """Return the metadata-builder class for this backend."""
        return XFormersAttentionMetadataBuilder

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        """Cascade attention is never used by this backend."""
        return False
@dataclass
class XFormersAttentionMetadata:
    """Per-batch attention metadata for the XFormers backend.

    The slicing done by ``prefill_metadata`` and ``decode_metadata`` treats
    the first ``num_decodes`` requests (and the first ``num_decode_tokens``
    tokens) of the batch as decodes and everything after them as prefills.
    """

    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    num_prefill_tokens: int = 0
    num_decode_tokens: int = 0
    num_prefills: int = 0
    num_decodes: int = 0

    # Biases for different attention types.
    attn_bias: Optional["AttentionBias"] = None

    # Lazily built views of the prefill / decode parts of this batch.
    _cached_prefill_metadata: Optional["XFormersAttentionMetadata"] = None
    _cached_decode_metadata: Optional["XFormersAttentionMetadata"] = None

    @property
    def prefill_metadata(self) -> Optional["XFormersAttentionMetadata"]:
        """Metadata restricted to the prefill requests, or None if there are none.

        The result is built on first access and cached. Its
        ``query_start_loc`` is rebased so that it starts at zero, and it
        carries no ``attn_bias``.
        """
        if self.num_prefills == 0:
            return None

        cached = self._cached_prefill_metadata
        if cached is None:
            first_prefill = self.num_decodes
            starts = self.query_start_loc[first_prefill:]
            query_lens = torch.diff(starts)
            kv_lens = self.seq_lens[first_prefill:]
            cached = XFormersAttentionMetadata(
                num_actual_tokens=self.num_prefill_tokens,
                max_query_len=int(query_lens.max().item()),
                query_start_loc=starts - starts[0],
                max_seq_len=int(kv_lens.max().item()),
                seq_lens=kv_lens,
                block_table=self.block_table[first_prefill:],
                slot_mapping=self.slot_mapping[self.num_decode_tokens :],
            )
            self._cached_prefill_metadata = cached
        return cached

    @property
    def decode_metadata(self) -> Optional["XFormersAttentionMetadata"]:
        """Metadata restricted to the decode requests, or None if there are none.

        The result is built on first access and cached. It shares this
        object's ``attn_bias``.
        """
        if self.num_decode_tokens == 0:
            return None

        cached = self._cached_decode_metadata
        if cached is None:
            decode_count = self.num_decodes
            starts = self.query_start_loc
            query_lens = torch.diff(starts)
            kv_lens = self.seq_lens[:decode_count]
            cached = XFormersAttentionMetadata(
                num_actual_tokens=self.num_decode_tokens,
                max_query_len=int(query_lens[:decode_count].max().item()),
                query_start_loc=starts[: decode_count + 1],
                max_seq_len=int(kv_lens.max().item()),
                seq_lens=kv_lens,
                block_table=self.block_table[:decode_count],
                slot_mapping=self.slot_mapping[: self.num_decode_tokens],
                attn_bias=self.attn_bias,
            )
            self._cached_decode_metadata = cached
        return cached
class XFormersAttentionMetadataBuilder(
    AttentionMetadataBuilder[XFormersAttentionMetadata]
):
    """Builds ``XFormersAttentionMetadata`` from the common attention metadata."""

    # Passed as ``decode_threshold`` when splitting the batch into
    # decodes and prefills.
    reorder_batch_threshold: int = 1

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Initialise the builder; requires the xformers package to be importable."""
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        assert XFORMERS_AVAILABLE
        self.block_size = kv_cache_spec.block_size
        self._num_decodes = 0
        self._num_decode_tokens = 0

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> XFormersAttentionMetadata:
        """Assemble the backend metadata for one batch.

        A paged block-diagonal causal bias is constructed for the decode
        requests when there is at least one; otherwise ``attn_bias`` is
        None. ``common_prefix_len`` and ``fast_build`` are not used here.
        """
        common = common_attn_metadata
        counts = split_decodes_and_prefills(
            common, decode_threshold=self.reorder_batch_threshold
        )
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = counts

        query_starts = common.query_start_loc
        query_lens = torch.diff(query_starts)
        kv_lens = common.seq_lens
        block_table = common.block_table_tensor

        decode_bias = None
        if num_decodes > 0:
            # Only the leading ``num_decodes`` requests take part in the
            # decode bias.
            decode_bias = PagedBlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
                q_seqlen=query_lens[:num_decodes].tolist(),
                kv_seqlen=kv_lens[:num_decodes].tolist(),
                page_size=self.block_size,
                block_tables=block_table[:num_decodes],
                device=block_table.device,
            )

        return XFormersAttentionMetadata(
            num_actual_tokens=common.num_actual_tokens,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_decodes=num_decodes,
            max_query_len=common.max_query_len,
            query_start_loc=query_starts,
            max_seq_len=common.max_seq_len,
            seq_lens=kv_lens,
            block_table=block_table,
            slot_mapping=common.slot_mapping,
            attn_bias=decode_bias,
        )
class XFormersAttentionImpl(AttentionImpl):
    """Decoder-only attention implementation for the XFormers backend.

    Prefill tokens are handled by ``unified_attention`` over the paged KV
    cache; decode tokens are handled by xformers'
    ``memory_efficient_attention_forward`` with the bias prepared by the
    metadata builder.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
    ) -> None:
        """Store the attention hyper-parameters.

        Raises:
            NotImplementedError: if KV sharing or alibi slopes are requested,
                or if ``attn_type`` is anything other than
                ``AttentionType.DECODER``.
        """
        if kv_sharing_target_layer_name is not None:
            raise NotImplementedError("KV sharing is not supported in V0.")
        if alibi_slopes is not None:
            raise NotImplementedError("XFormers does not support alibi slopes yet.")

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        # Alibi slopes were rejected above, so there is never a tensor to
        # build here; the attribute is kept because forward() passes it on.
        self.alibi_slopes = None

        # Forwarded as ``window_size`` to unified_attention.
        self.sliding_window = (
            (-1, -1) if sliding_window is None else (sliding_window - 1, 0)
        )

        # Setting logits_soft_cap to 0 means no soft cap.
        self.logits_soft_cap = 0 if logits_soft_cap is None else logits_soft_cap

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "Encoder self-attention and "
                "encoder/decoder cross-attention "
                "are not implemented for "
                "XFormersAttentionImpl."
            )

    def _expand_kv_cache(self, cache: torch.Tensor) -> torch.Tensor:
        """View a key or value cache as [1, Bkv_T, G, 1, D] and broadcast it
        to [1, Bkv_T, G, H, D] so it lines up with the grouped decode query."""
        return cache.view(1, -1, self.num_kv_heads, 1, self.head_size).expand(
            1,
            -1,
            self.num_kv_heads,
            self.num_queries_per_kv,
            self.head_size,
        )

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: XFormersAttentionMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass with XFormers.

        Args:
            layer: Module providing the ``_k_scale`` / ``_v_scale`` tensors.
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
            output: Pre-allocated buffer the result is written into.
            output_scale: Must be None (fused output quantization is
                unsupported).
            output_block_scale: Must be None (fused output quantization is
                unsupported).
        Returns:
            shape = [num_tokens, num_heads * head_size]
        Raises:
            NotImplementedError: if an output scale is supplied.
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for XFormersAttentionImpl"
            )

        if attn_metadata is None:
            # Profiling run.
            return output.fill_(0)

        # Cache the input KVs.
        key_cache, value_cache = kv_cache.unbind(0)
        if self.kv_sharing_target_layer_name is None:
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
            ops.reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        num_actual_tokens = attn_metadata.num_actual_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens

        # Prefill tokens sit after the decode tokens in the flattened batch.
        if prefill_meta := attn_metadata.prefill_metadata:
            # One descale entry per (prefill request, kv head).
            descale_shape = (prefill_meta.query_start_loc.shape[0] - 1, key.shape[1])
            unified_attention(
                q=query[num_decode_tokens:num_actual_tokens],
                k=key_cache,
                v=value_cache,
                out=output[num_decode_tokens:num_actual_tokens],
                cu_seqlens_q=prefill_meta.query_start_loc,
                max_seqlen_q=prefill_meta.max_query_len,
                seqused_k=prefill_meta.seq_lens,
                max_seqlen_k=prefill_meta.max_seq_len,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
                window_size=self.sliding_window,
                block_table=prefill_meta.block_table,
                softcap=self.logits_soft_cap,
                q_descale=None,  # Not supported
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
            )

        if decode_meta := attn_metadata.decode_metadata:
            # Query for decode. KV is not needed because it is already cached.
            decode_query = query[:num_decode_tokens]
            # Reshape query to [1, B_T, G, H, D].
            q = decode_query.view(
                1, -1, self.num_kv_heads, self.num_queries_per_kv, self.head_size
            )
            # Reshape the k and v caches to [1, Bkv_T, G, H, D]
            cache_k = self._expand_kv_cache(key_cache)
            cache_v = self._expand_kv_cache(value_cache)

            output[:num_decode_tokens] = xops.memory_efficient_attention_forward(
                q,
                cache_k,
                cache_v,
                attn_bias=decode_meta.attn_bias,
                p=0.0,
                scale=self.scale,
            ).view(decode_query.shape)

        # Reshape the output tensor.
        return output
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment