Unverified commit 289fc48a, authored by Netanel Haber and committed by GitHub
Browse files

Use MMEncoderAttention (=use FlashAttention) instead of torch.sdpa in radio.py (#35653)

parent 2f2212e6
...@@ -10,7 +10,8 @@ ...@@ -10,7 +10,8 @@
import math import math
from collections.abc import Iterable from collections.abc import Iterable
from itertools import repeat from dataclasses import dataclass
from itertools import accumulate, repeat
from typing import TypeAlias from typing import TypeAlias
import torch import torch
...@@ -477,28 +478,27 @@ class ViTPatchLinear(nn.Linear): ...@@ -477,28 +478,27 @@ class ViTPatchLinear(nn.Linear):
self.patch_size = patch_size self.patch_size = patch_size
@dataclass(frozen=True, kw_only=True)
class MaskMetadata:
cu_seqlens: torch.Tensor
max_seqlen: torch.Tensor
class RadioParallelAttention(InternParallelAttention): class RadioParallelAttention(InternParallelAttention):
def forward( def forward(
self, x: torch.Tensor, attn_mask: torch.Tensor | None = None self, x: torch.Tensor, mask_meta: MaskMetadata | None = None
) -> torch.Tensor: ) -> torch.Tensor:
if attn_mask is None:
return super().forward(x)
B, N, _ = x.shape
qkv, _ = self.qkv(x) qkv, _ = self.qkv(x)
q, k, v = qkv.chunk(3, dim=-1) q, k, v = qkv.chunk(3, dim=-1)
if self.qk_normalization: if self.qk_normalization:
q, k = self._apply_qk_norm(q, k) q, k = self._apply_qk_norm(q, k)
q = q.view(B, N, self.num_heads_per_partition, self.head_dim) cu_seqlens, max_seqlen = None, None
k = k.view(B, N, self.num_heads_per_partition, self.head_dim) if mask_meta is not None:
v = v.view(B, N, self.num_heads_per_partition, self.head_dim) cu_seqlens = mask_meta.cu_seqlens
q, k, v = (t.transpose(1, 2) for t in (q, k, v)) max_seqlen = mask_meta.max_seqlen
out = F.scaled_dot_product_attention( out = self.attn(q, k, v, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
q, k, v, attn_mask=attn_mask, scale=self.scale
)
out = out.transpose(1, 2).reshape(B, N, -1)
out, _ = self.proj(out) out, _ = self.proj(out)
return out return out
...@@ -510,11 +510,11 @@ class RadioVisionEncoderLayer(InternVisionEncoderLayer): ...@@ -510,11 +510,11 @@ class RadioVisionEncoderLayer(InternVisionEncoderLayer):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attn_mask: torch.Tensor | None = None, mask_meta: MaskMetadata | None = None,
): ):
hidden_states = ( hidden_states = (
hidden_states hidden_states
+ self.attn(self.norm1(hidden_states), attn_mask=attn_mask) * self.ls1 + self.attn(self.norm1(hidden_states), mask_meta=mask_meta) * self.ls1
) )
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) * self.ls2 hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) * self.ls2
...@@ -529,11 +529,11 @@ class RadioVisionEncoder(InternVisionEncoder): ...@@ -529,11 +529,11 @@ class RadioVisionEncoder(InternVisionEncoder):
def forward( def forward(
self, self,
inputs_embeds: torch.Tensor, inputs_embeds: torch.Tensor,
attn_mask: torch.Tensor | None = None, mask_meta: MaskMetadata | None = None,
): ):
hidden_states = inputs_embeds hidden_states = inputs_embeds
for encoder_layer in self.layers: for encoder_layer in self.layers:
hidden_states = encoder_layer(hidden_states, attn_mask=attn_mask) hidden_states = encoder_layer(hidden_states, mask_meta=mask_meta)
return hidden_states return hidden_states
...@@ -590,44 +590,36 @@ class RadioInternVisionModel(nn.Module): ...@@ -590,44 +590,36 @@ class RadioInternVisionModel(nn.Module):
def get_input_embeddings(self): def get_input_embeddings(self):
return self.embeddings return self.embeddings
def create_inter_image_attention_mask( def inter_image_mask_metadata(
self, imgs_sizes: list[tuple[int, int]], device: torch.device self, imgs_sizes: list[tuple[int, int]], device: torch.device
) -> torch.Tensor: ) -> MaskMetadata:
patch_size = self.patch_generator.patch_size patch_size = self.patch_generator.patch_size
num_skip = self.patch_generator.num_skip num_skip = self.patch_generator.num_skip
seq_lens = calc_seq_lens(imgs_sizes, patch_size) seq_lens = calc_seq_lens(imgs_sizes, patch_size)
patch_counts = [seq_len + num_skip for seq_len in seq_lens] adjusted = [s + num_skip for s in seq_lens]
total_patches = sum(patch_counts) cu_seqlens = torch.tensor(
list(accumulate(adjusted, initial=0)), dtype=torch.int32, device=device
# Create attention mask - default to False (mask out)
mask = torch.zeros(
total_patches, total_patches, dtype=torch.bool, device=device
) )
# Keep max_seqlen on CPU to avoid .item() sync
# Each image's patches can only attend to patches from the same image # See: https://github.com/vllm-project/vllm/blob/20b6b01/vllm/v1/attention/ops/vit_attn_wrappers.py#L48
start_idx = 0 max_seqlen = torch.tensor(max(adjusted), dtype=torch.int32)
for patch_count in patch_counts: return MaskMetadata(cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
end_idx = start_idx + patch_count
# Allow attention within this image's patches
mask[start_idx:end_idx, start_idx:end_idx] = True
start_idx = end_idx
return mask
def forward( def forward(
self, self,
x: torch.Tensor, x: torch.Tensor,
imgs_sizes: torch.Tensor | None = None, imgs_sizes: list[tuple[int, int]] | None = None,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
hidden_states = self.patch_generator(x, imgs_sizes=imgs_sizes) hidden_states = self.patch_generator(x, imgs_sizes=imgs_sizes)
attn_mask = None mask_meta = None
if imgs_sizes is not None and len(imgs_sizes) > 1: if imgs_sizes is not None:
# Dynamic Resolution assert len(imgs_sizes) > 0
attn_mask = self.create_inter_image_attention_mask( # Dynamic resolution: process each image as an independent sequence.
imgs_sizes, device=x.device mask_meta = self.inter_image_mask_metadata(
imgs_sizes, device=hidden_states.device
) )
encoder_outputs = self.encoder(inputs_embeds=hidden_states, attn_mask=attn_mask) encoder_outputs = self.encoder(inputs_embeds=hidden_states, mask_meta=mask_meta)
return encoder_outputs return encoder_outputs
...@@ -670,7 +662,7 @@ class RadioModel(nn.Module): ...@@ -670,7 +662,7 @@ class RadioModel(nn.Module):
pixel_values: torch.Tensor | None = None, pixel_values: torch.Tensor | None = None,
pixel_embeds: torch.Tensor | None = None, pixel_embeds: torch.Tensor | None = None,
*, *,
imgs_sizes: torch.Tensor | None = None, imgs_sizes: list[tuple[int, int]] | None = None,
) -> tuple[torch.FloatTensor, torch.FloatTensor]: ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
y = self.model(pixel_values, imgs_sizes=imgs_sizes) y = self.model(pixel_values, imgs_sizes=imgs_sizes)
return self._extract_final(y, imgs_sizes=imgs_sizes) return self._extract_final(y, imgs_sizes=imgs_sizes)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment