Unverified Commit 99c1b9d2 authored by Mick, committed by GitHub

fix: apply cache size limit of attention mask for VisionAttention (#3657)

parent 634a3561
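
Reviewer note: the previous implementation memoized every generated attention mask in a class-level dict (`_mask_cache`) keyed by shape and `cu_seqlens`; nothing was ever evicted, so a workload with many distinct image sizes grew the cache without bound (hence the removed `TODO: Should it be released after used?`). This commit replaces the hand-rolled dict with `functools.lru_cache(maxsize=128)`. A minimal sketch of the before/after caching pattern, with a placeholder `expensive_mask` function (an assumption for illustration, not the sglang code itself):

from functools import lru_cache

def expensive_mask(key: tuple) -> list:
    # stand-in for the real mask construction
    return [k * 2 for k in key]

class Before:
    _mask_cache = {}  # class-level dict: entries are never evicted

    def get(self, key: tuple) -> list:
        if key not in Before._mask_cache:
            Before._mask_cache[key] = expensive_mask(key)
        return Before._mask_cache[key]

class After:
    @staticmethod
    @lru_cache(maxsize=128)  # least-recently-used entry is evicted past 128 keys
    def get(key: tuple) -> list:
        return expensive_mask(key)
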
 from __future__ import annotations
+from functools import lru_cache
 from typing import Optional
 import torch
@@ -223,9 +224,6 @@ class VisionSdpaAttention(nn.Module):
"""
# TODO: Should it be released after used?
_mask_cache = {}
def __init__(
self,
head_size: int,
@@ -239,75 +237,61 @@ class VisionSdpaAttention(nn.Module):
         self.use_full_precision_softmax = use_full_precision_softmax
         self.dropout = dropout

-    def generate_patch_attention_mask(
-        self,
-        s: int,
-        bsz: int,
-        device,
-        cu_seqlens: Optional[torch.Tensor],
-        flatten_batch: bool = False,
-        dtype=torch.bfloat16,
-    ) -> torch.Tensor:
-        r"""
-        Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
-
-        When `flatten_batch` is True:
-            - All sequences in the batch are flattened into a single dimension
-            - `s` represents the total number of tokens across all sequences in the batch
-            - Returns a unified mask of shape `(1, 1, s, s)`
-
-        When `flatten_batch` is False:
-            - Each sequence has its own attention mask
-            - `s` represents the maximum sequence length in the batch
-            - Returns separate masks of shape `(b, 1, s, s)`
+    @staticmethod
+    @lru_cache(maxsize=128)
+    def _generate_mask_cache(
+        s: int, flatten_batch: bool, cu_seqlens: tuple
+    ) -> torch.BoolTensor:
+        """
+        Generate a boolean attention mask with caching mechanism.

         Args:
-            flatten_batch: (bool):
-                If True, treats all sequences in the batch as a single flattened sequence
-                If False, generates separate masks for each sequence
+            s: sequence length
+            flatten_batch: whether to flatten batch dimension
+            cu_seqlens: tuple of cumulative sequence lengths

         Returns:
-            Tensor of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
+            attention mask tensor
         """
-        cache_key = (s, bsz, flatten_batch, tuple(cu_seqlens.cpu().tolist()))
-        if cache_key in VisionSdpaAttention._mask_cache:
-            cached_mask = VisionSdpaAttention._mask_cache[cache_key]
-            # print(f"cache hit for key: {cache_key}")
-            return cached_mask.to(device=device, dtype=dtype)
-        if cu_seqlens is None:
-            raise ValueError("Internal Error: cu_seqlens cannot be None")
         if flatten_batch:
-            mask = torch.zeros([1, s, s], device=device, dtype=torch.bool)
+            mask = torch.zeros([1, s, s], dtype=torch.bool)
             for i in range(1, len(cu_seqlens)):
                 start = cu_seqlens[i - 1]
                 end = cu_seqlens[i]
-                mask[
-                    ...,
-                    start:end,
-                    start:end,
-                ] = True
+                mask[..., start:end, start:end] = True
         else:
             # [1, 1, 1, s]
-            row_indices = torch.arange(s, device=device).view(1, 1, 1, s)
+            row_indices = torch.arange(s).view(1, 1, 1, s)
             # [1, 1, s, 1]
-            col_indices = torch.arange(s, device=device).view(1, 1, s, 1)
+            col_indices = torch.arange(s).view(1, 1, s, 1)
             # [b, 1, 1, 1]
-            seq_lens = (
-                (cu_seqlens[1:] - cu_seqlens[:-1]).to(device=device).view(-1, 1, 1, 1)
-            )
+            seq_lens = torch.tensor(
+                [end - start for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])],
+            ).view(-1, 1, 1, 1)
             mask = (row_indices < seq_lens) & (col_indices < seq_lens)

-        # Convert to attention mask format (False -> 0, True -> -inf)
-        mask = (~mask).to(dtype) * torch.finfo(dtype).min
+        return mask

+    def generate_patch_attention_mask(
+        self,
+        s: int,
+        cu_seqlens: Optional[torch.Tensor],
+        flatten_batch: bool = False,
+    ) -> Optional[torch.Tensor]:
+        r"""
+        Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
+
+        Args:
+            s: sequence length
+            cu_seqlens: cumulative sequence lengths tensor. If None, returns None
+            flatten_batch: whether to flatten batch dimension
+
+        Returns:
+            attention mask tensor or None
+        """
+        if cu_seqlens is None:
+            return None

-        VisionSdpaAttention._mask_cache[cache_key] = mask
+        cu_seqlens_tuple = tuple(cu_seqlens.cpu().tolist())

-        return mask
+        return self._generate_mask_cache(s, flatten_batch, cu_seqlens_tuple)

     def forward(
         self,
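
Note on the new cache key: `functools.lru_cache` stores results in a dict keyed by the call arguments, so every argument must be hashable and compare by value; that is why `generate_patch_attention_mask` converts the `cu_seqlens` tensor to a plain tuple before calling `_generate_mask_cache`. As a standalone illustration of the `flatten_batch=True` branch (tiny sizes chosen for readability, not taken from the commit):

import torch

s, cu_seqlens = 5, (0, 2, 5)  # two sequences of lengths 2 and 3, flattened
mask = torch.zeros([1, s, s], dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
    start, end = cu_seqlens[i - 1], cu_seqlens[i]
    # block-diagonal: tokens attend only within their own sequence
    mask[..., start:end, start:end] = True
print(mask.int())
# tensor([[[1, 1, 0, 0, 0],
#          [1, 1, 0, 0, 0],
#          [0, 0, 1, 1, 1],
#          [0, 0, 1, 1, 1],
#          [0, 0, 1, 1, 1]]], dtype=torch.int32)
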
@@ -330,15 +314,23 @@ class VisionSdpaAttention(nn.Module):
         # [b, 1, s, s]
         if attention_mask is None:
             attention_mask = self.generate_patch_attention_mask(
-                s, bsz, q.device, cu_seqlens, self.flatten_batch, q.dtype
+                s, cu_seqlens, flatten_batch=self.flatten_batch
             )

+        if attention_mask is None:
+            if self.use_full_precision_softmax:
+                raise RuntimeError("Empty attention mask")
+        else:
+            attention_mask = attention_mask.to(device=q.device)

         q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]]

         # [b, 1, s]
         if self.use_full_precision_softmax:
             scale = self.head_size**-0.5
             k_transposed = rearrange(k, "b h s d -> b h d s")
             attn_weights = torch.matmul(q, k_transposed) * scale
             del k, k_transposed
+            attention_mask = (~attention_mask) * torch.finfo(q.dtype).min
             attn_weights = attn_weights + attention_mask
             del attention_mask
             # full-precision
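
With the cached mask now boolean, the full-precision path converts it to an additive bias at forward time, in the query dtype: `~mask` flips kept positions to False (contributing 0) and masked positions to True, which the multiply scales to the most negative representable value before it is added to the logits. A small sketch of the conversion (float32 shown for concreteness; the real code uses `q.dtype`):

import torch

keep = torch.tensor([[True, False],
                     [False, True]])             # True = attend
bias = (~keep) * torch.finfo(torch.float32).min  # kept -> 0.0, masked -> ~ -3.4e38
assert bias[0, 0].item() == 0.0
assert bias[0, 1].item() == torch.finfo(torch.float32).min
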
@@ -354,7 +346,12 @@ class VisionSdpaAttention(nn.Module):
             # SDPA
             # [b, h, s, head_size]
             output = F.scaled_dot_product_attention(
-                q, k, v, attention_mask, dropout_p=self.dropout
+                q,
+                k,
+                v,
+                attn_mask=attention_mask,
+                dropout_p=self.dropout,
+                is_causal=False,
             )

         # [b, h, s, head_size] --> [b * s, h, head_size]
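
Why no `-inf` conversion is needed on this path: `F.scaled_dot_product_attention` accepts a boolean `attn_mask` directly, treating `True` as "may attend", so the cached boolean mask can be passed as-is once moved to the right device; `is_causal=False` is now stated explicitly instead of relying on the default. A minimal usage sketch (shapes are illustrative assumptions):

import torch
import torch.nn.functional as F

b, h, s, d = 2, 4, 6, 8
q, k, v = (torch.randn(b, h, s, d) for _ in range(3))
attn_mask = torch.ones(b, 1, s, s, dtype=torch.bool)  # [b, 1, s, s] broadcasts over heads
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=False)
print(out.shape)  # torch.Size([2, 4, 6, 8])
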
@@ -380,7 +377,6 @@ class VisionTritonAttention(nn.Module):
         v: torch.Tensor,
         _bsz: int,
         cu_seqlens: Optional[torch.Tensor],
-        **kwargs,
     ) -> torch.Tensor:
         r"""
         Args:
......
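
End-to-end behavior of the bounded cache, as a standalone sketch (the `make_mask` helper is hypothetical, mirroring `_generate_mask_cache` only in shape): repeated calls with equal `(s, cu_seqlens)` arguments are served from the LRU cache, and `cache_info()` exposes the hit/miss counts and the `maxsize=128` bound.

import torch
from functools import lru_cache

@lru_cache(maxsize=128)
def make_mask(s: int, cu_seqlens: tuple) -> torch.Tensor:
    mask = torch.zeros(1, s, s, dtype=torch.bool)
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        mask[..., start:end, start:end] = True
    return mask

cu_seqlens = torch.tensor([0, 2, 5])
key = tuple(cu_seqlens.cpu().tolist())  # value-based key: equal tensors share one entry
make_mask(5, key)
make_mask(5, key)                       # second call hits the cache
print(make_mask.cache_info())           # CacheInfo(hits=1, misses=1, maxsize=128, currsize=1)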