Unverified Commit 99c1b9d2 authored by Mick, committed by GitHub

fix: apply cache size limit of attention mask for VisionAttention (#3657)

parent 634a3561
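
This commit replaces the unbounded class-level `_mask_cache` dict with `functools.lru_cache(maxsize=128)`, so old attention masks are evicted instead of accumulating for every new `(s, cu_seqlens)` combination seen by `VisionSdpaAttention`. Because `lru_cache` keys on its arguments' hashes, the `cu_seqlens` tensor is converted to a tuple of ints before the cached call. A minimal standalone sketch of the pattern (the names `_cached_block_mask` and `block_mask` are illustrative, not part of the patch):

```python
from functools import lru_cache

import torch


@lru_cache(maxsize=128)
def _cached_block_mask(s: int, cu_seqlens: tuple) -> torch.Tensor:
    # A tensor hashes by object identity, so equal tensors would always miss
    # the cache; a tuple of Python ints compares (and hashes) by value.
    mask = torch.zeros(s, s, dtype=torch.bool)
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        mask[start:end, start:end] = True  # block-diagonal: intra-sequence only
    return mask


def block_mask(cu_seqlens: torch.Tensor) -> torch.Tensor:
    # Convert to a tuple so repeated calls hit the cache; once 128 distinct
    # keys exist, the least recently used entry is evicted.
    key = tuple(cu_seqlens.cpu().tolist())
    return _cached_block_mask(int(key[-1]), key)


# e.g. two sequences of lengths 3 and 2, flattened to 5 tokens
print(block_mask(torch.tensor([0, 3, 5])).int())
```

Storing the cached mask as `torch.bool` also keeps each entry small and device/dtype-agnostic; conversion to the query's device and dtype now happens at the call site in `forward`.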
```diff
 from __future__ import annotations
+from functools import lru_cache
 from typing import Optional
 import torch
@@ -223,9 +224,6 @@ class VisionSdpaAttention(nn.Module):
     """
-    # TODO: Should it be released after used?
-    _mask_cache = {}
     def __init__(
         self,
         head_size: int,
@@ -239,75 +237,61 @@ class VisionSdpaAttention(nn.Module):
         self.use_full_precision_softmax = use_full_precision_softmax
         self.dropout = dropout

-    def generate_patch_attention_mask(
-        self,
-        s: int,
-        bsz: int,
-        device,
-        cu_seqlens: Optional[torch.Tensor],
-        flatten_batch: bool = False,
-        dtype=torch.bfloat16,
-    ) -> torch.Tensor:
-        r"""
-        Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
-        When `flatten_batch` is True:
-            - All sequences in the batch are flattened into a single dimension
-            - `s` represents the total number of tokens across all sequences in the batch
-            - Returns a unified mask of shape `(1, 1, s, s)`
-        When `flatten_batch` is False:
-            - Each sequence has its own attention mask
-            - `s` represents the maximum sequence length in the batch
-            - Returns separate masks of shape `(b, 1, s, s)`
-        Args:
-            flatten_batch: (bool):
-                If True, treats all sequences in the batch as a single flattened sequence
-                If False, generates separate masks for each sequence
-        Returns:
-            Tensor of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
-        """
-        cache_key = (s, bsz, flatten_batch, tuple(cu_seqlens.cpu().tolist()))
-        if cache_key in VisionSdpaAttention._mask_cache:
-            cached_mask = VisionSdpaAttention._mask_cache[cache_key]
-            # print(f"cache hit for key: {cache_key}")
-            return cached_mask.to(device=device, dtype=dtype)
-        if cu_seqlens is None:
-            raise ValueError("Internal Error: cu_seqlens cannot be None")
-        if flatten_batch:
-            mask = torch.zeros([1, s, s], device=device, dtype=torch.bool)
-            for i in range(1, len(cu_seqlens)):
-                start = cu_seqlens[i - 1]
-                end = cu_seqlens[i]
-                mask[
-                    ...,
-                    start:end,
-                    start:end,
-                ] = True
-        else:
-            # [1, 1, 1, s]
-            row_indices = torch.arange(s, device=device).view(1, 1, 1, s)
-            # [1, 1, s, 1]
-            col_indices = torch.arange(s, device=device).view(1, 1, s, 1)
-            # [b, 1, 1, 1]
-            seq_lens = (
-                (cu_seqlens[1:] - cu_seqlens[:-1]).to(device=device).view(-1, 1, 1, 1)
-            )
-            mask = (row_indices < seq_lens) & (col_indices < seq_lens)
-        # Convert to attention mask format (False -> 0, True -> -inf)
-        mask = (~mask).to(dtype) * torch.finfo(dtype).min
-        VisionSdpaAttention._mask_cache[cache_key] = mask
-        return mask
+    @staticmethod
+    @lru_cache(maxsize=128)
+    def _generate_mask_cache(
+        s: int, flatten_batch: bool, cu_seqlens: tuple
+    ) -> torch.BoolTensor:
+        """
+        Generate a boolean attention mask with caching mechanism.
+        Args:
+            s: sequence length
+            flatten_batch: whether to flatten batch dimension
+            cu_seqlens: tuple of cumulative sequence lengths
+        Returns:
+            attention mask tensor
+        """
+        if flatten_batch:
+            mask = torch.zeros([1, s, s], dtype=torch.bool)
+            for i in range(1, len(cu_seqlens)):
+                start = cu_seqlens[i - 1]
+                end = cu_seqlens[i]
+                mask[..., start:end, start:end] = True
+        else:
+            # [1, 1, 1, s]
+            row_indices = torch.arange(s).view(1, 1, 1, s)
+            # [1, 1, s, 1]
+            col_indices = torch.arange(s).view(1, 1, s, 1)
+            # [b, 1, 1, 1]
+            seq_lens = torch.tensor(
+                [end - start for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])],
+            ).view(-1, 1, 1, 1)
+            mask = (row_indices < seq_lens) & (col_indices < seq_lens)
+        return mask
+
+    def generate_patch_attention_mask(
+        self,
+        s: int,
+        cu_seqlens: Optional[torch.Tensor],
+        flatten_batch: bool = False,
+    ) -> Optional[torch.Tensor]:
+        r"""
+        Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
+        Args:
+            s: sequence length
+            cu_seqlens: cumulative sequence lengths tensor. If None, returns None
+            flatten_batch: whether to flatten batch dimension
+        Returns:
+            attention mask tensor or None
+        """
+        if cu_seqlens is None:
+            return None
+        cu_seqlens_tuple = tuple(cu_seqlens.cpu().tolist())
+        return self._generate_mask_cache(s, flatten_batch, cu_seqlens_tuple)
```
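
The non-flattened branch of the new `_generate_mask_cache` builds per-sequence masks by broadcasting: key-position indices of shape `(1, 1, 1, s)` and query-position indices of shape `(1, 1, s, 1)` are each compared against per-sequence lengths of shape `(b, 1, 1, 1)`, so entry `(i, j)` of batch element `k` is `True` exactly when both positions fall inside sequence `k`'s real length. A small worked example with made-up lengths:

```python
import torch

s = 4                                              # padded (max) sequence length
seq_lens = torch.tensor([2, 4]).view(-1, 1, 1, 1)  # per-sequence lengths, [b,1,1,1]
row = torch.arange(s).view(1, 1, 1, s)             # key positions,   [1,1,1,s]
col = torch.arange(s).view(1, 1, s, 1)             # query positions, [1,1,s,1]
mask = (row < seq_lens) & (col < seq_lens)         # broadcasts to [b,1,s,s]
print(mask[0, 0].int())
# tensor([[1, 1, 0, 0],
#         [1, 1, 0, 0],
#         [0, 0, 0, 0],
#         [0, 0, 0, 0]])
```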
```diff
     def forward(
         self,
@@ -330,15 +314,23 @@ class VisionSdpaAttention(nn.Module):
         # [b, 1, s, s]
         if attention_mask is None:
             attention_mask = self.generate_patch_attention_mask(
-                s, bsz, q.device, cu_seqlens, self.flatten_batch, q.dtype
+                s, cu_seqlens, flatten_batch=self.flatten_batch
             )
+
+        if attention_mask is None:
+            if self.use_full_precision_softmax:
+                raise RuntimeError("Empty attention mask")
+        else:
+            attention_mask = attention_mask.to(device=q.device)
+
         q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]]
-        # [b, 1, s]
         if self.use_full_precision_softmax:
             scale = self.head_size**-0.5
             k_transposed = rearrange(k, "b h s d -> b h d s")
             attn_weights = torch.matmul(q, k_transposed) * scale
             del k, k_transposed
+            attention_mask = (~attention_mask) * torch.finfo(q.dtype).min
             attn_weights = attn_weights + attention_mask
             del attention_mask
             # full-precision
```
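
Since the cache now holds a boolean mask rather than a pre-baked additive one, the full-precision path converts it on the fly: `(~mask) * torch.finfo(q.dtype).min` maps `True` (attend) to `0` and `False` to the most negative representable value, which drives the masked logits to zero probability after softmax. Roughly:

```python
import torch

bool_mask = torch.tensor([[True, False]])  # attend to position 0, mask position 1
additive = (~bool_mask) * torch.finfo(torch.bfloat16).min
logits = torch.zeros(1, 2) + additive      # masked logit becomes a huge negative
print(torch.softmax(logits, dim=-1))       # ~tensor([[1., 0.]])
```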
```diff
@@ -354,7 +346,12 @@ class VisionSdpaAttention(nn.Module):
             # SDPA
             # [b, h, s, head_size]
             output = F.scaled_dot_product_attention(
-                q, k, v, attention_mask, dropout_p=self.dropout
+                q,
+                k,
+                v,
+                attn_mask=attention_mask,
+                dropout_p=self.dropout,
+                is_causal=False,
             )
             # [b, h, s, head_size] --> [b * s, h, head_size]
```
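
On the SDPA path no conversion is needed: `F.scaled_dot_product_attention` accepts a boolean `attn_mask` directly (`True` means the position participates in attention), and `is_causal=False` is now stated explicitly since these patch masks are non-causal. A shape-level sketch with arbitrary dimensions:

```python
import torch
import torch.nn.functional as F

b, h, s, d = 2, 4, 6, 8
q, k, v = (torch.randn(b, h, s, d) for _ in range(3))
attn_mask = torch.ones(b, 1, s, s, dtype=torch.bool)  # broadcasts over heads
out = F.scaled_dot_product_attention(
    q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
)
assert out.shape == (b, h, s, d)
```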
```diff
@@ -380,7 +377,6 @@ class VisionTritonAttention(nn.Module):
         v: torch.Tensor,
         _bsz: int,
         cu_seqlens: Optional[torch.Tensor],
-        **kwargs,
     ) -> torch.Tensor:
         r"""
         Args:
...
```