Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
289fc48a
Unverified
Commit
289fc48a
authored
Mar 04, 2026
by
Netanel Haber
Committed by
GitHub
Mar 04, 2026
Browse files
Use MMEncoderAttention (=use FlashAttention) instead of torch.sdpa in radio.py (#35653)
parent
2f2212e6
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
36 additions
and
44 deletions
+36
-44
vllm/model_executor/models/radio.py
vllm/model_executor/models/radio.py
+36
-44
No files found.
vllm/model_executor/models/radio.py
View file @
289fc48a
...
...
@@ -10,7 +10,8 @@
import
math
from
collections.abc
import
Iterable
from
itertools
import
repeat
from
dataclasses
import
dataclass
from
itertools
import
accumulate
,
repeat
from
typing
import
TypeAlias
import
torch
...
...
@@ -477,28 +478,27 @@ class ViTPatchLinear(nn.Linear):
self
.
patch_size
=
patch_size
@
dataclass
(
frozen
=
True
,
kw_only
=
True
)
class
MaskMetadata
:
cu_seqlens
:
torch
.
Tensor
max_seqlen
:
torch
.
Tensor
class
RadioParallelAttention
(
InternParallelAttention
):
def
forward
(
self
,
x
:
torch
.
Tensor
,
attn_mask
:
torch
.
Tensor
|
None
=
None
self
,
x
:
torch
.
Tensor
,
mask_meta
:
MaskMetadata
|
None
=
None
)
->
torch
.
Tensor
:
if
attn_mask
is
None
:
return
super
().
forward
(
x
)
B
,
N
,
_
=
x
.
shape
qkv
,
_
=
self
.
qkv
(
x
)
q
,
k
,
v
=
qkv
.
chunk
(
3
,
dim
=-
1
)
if
self
.
qk_normalization
:
q
,
k
=
self
.
_apply_qk_norm
(
q
,
k
)
q
=
q
.
view
(
B
,
N
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
k
=
k
.
view
(
B
,
N
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
v
=
v
.
view
(
B
,
N
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
q
,
k
,
v
=
(
t
.
transpose
(
1
,
2
)
for
t
in
(
q
,
k
,
v
))
out
=
F
.
scaled_dot_product_attention
(
q
,
k
,
v
,
attn_mask
=
attn_mask
,
scale
=
self
.
scale
)
out
=
out
.
transpose
(
1
,
2
).
reshape
(
B
,
N
,
-
1
)
cu_seqlens
,
max_seqlen
=
None
,
None
if
mask_meta
is
not
None
:
cu_seqlens
=
mask_meta
.
cu_seqlens
max_seqlen
=
mask_meta
.
max_seqlen
out
=
self
.
attn
(
q
,
k
,
v
,
cu_seqlens
=
cu_seqlens
,
max_seqlen
=
max_seqlen
)
out
,
_
=
self
.
proj
(
out
)
return
out
...
...
@@ -510,11 +510,11 @@ class RadioVisionEncoderLayer(InternVisionEncoderLayer):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
attn_mask
:
torch
.
Tensor
|
None
=
None
,
mask_meta
:
MaskMetadata
|
None
=
None
,
):
hidden_states
=
(
hidden_states
+
self
.
attn
(
self
.
norm1
(
hidden_states
),
attn_mask
=
attn_mask
)
*
self
.
ls1
+
self
.
attn
(
self
.
norm1
(
hidden_states
),
mask_meta
=
mask_meta
)
*
self
.
ls1
)
hidden_states
=
hidden_states
+
self
.
mlp
(
self
.
norm2
(
hidden_states
))
*
self
.
ls2
...
...
@@ -529,11 +529,11 @@ class RadioVisionEncoder(InternVisionEncoder):
def
forward
(
self
,
inputs_embeds
:
torch
.
Tensor
,
attn_mask
:
torch
.
Tensor
|
None
=
None
,
mask_meta
:
MaskMetadata
|
None
=
None
,
):
hidden_states
=
inputs_embeds
for
encoder_layer
in
self
.
layers
:
hidden_states
=
encoder_layer
(
hidden_states
,
attn_mask
=
attn_mask
)
hidden_states
=
encoder_layer
(
hidden_states
,
mask_meta
=
mask_meta
)
return
hidden_states
...
...
@@ -590,44 +590,36 @@ class RadioInternVisionModel(nn.Module):
def
get_input_embeddings
(
self
):
return
self
.
embeddings
def
create_
inter_image_
attention_mask
(
def
inter_image_
mask_metadata
(
self
,
imgs_sizes
:
list
[
tuple
[
int
,
int
]],
device
:
torch
.
device
)
->
torch
.
Tensor
:
)
->
MaskMetadata
:
patch_size
=
self
.
patch_generator
.
patch_size
num_skip
=
self
.
patch_generator
.
num_skip
seq_lens
=
calc_seq_lens
(
imgs_sizes
,
patch_size
)
patch_counts
=
[
seq_len
+
num_skip
for
seq_len
in
seq_lens
]
total_patches
=
sum
(
patch_counts
)
# Create attention mask - default to False (mask out)
mask
=
torch
.
zeros
(
total_patches
,
total_patches
,
dtype
=
torch
.
bool
,
device
=
device
adjusted
=
[
s
+
num_skip
for
s
in
seq_lens
]
cu_seqlens
=
torch
.
tensor
(
list
(
accumulate
(
adjusted
,
initial
=
0
)),
dtype
=
torch
.
int32
,
device
=
device
)
# Each image's patches can only attend to patches from the same image
start_idx
=
0
for
patch_count
in
patch_counts
:
end_idx
=
start_idx
+
patch_count
# Allow attention within this image's patches
mask
[
start_idx
:
end_idx
,
start_idx
:
end_idx
]
=
True
start_idx
=
end_idx
return
mask
# Keep max_seqlen on CPU to avoid .item() sync
# See: https://github.com/vllm-project/vllm/blob/20b6b01/vllm/v1/attention/ops/vit_attn_wrappers.py#L48
max_seqlen
=
torch
.
tensor
(
max
(
adjusted
),
dtype
=
torch
.
int32
)
return
MaskMetadata
(
cu_seqlens
=
cu_seqlens
,
max_seqlen
=
max_seqlen
)
def
forward
(
self
,
x
:
torch
.
Tensor
,
imgs_sizes
:
torch
.
Tensor
|
None
=
None
,
imgs_sizes
:
list
[
tuple
[
int
,
int
]]
|
None
=
None
,
)
->
torch
.
FloatTensor
:
hidden_states
=
self
.
patch_generator
(
x
,
imgs_sizes
=
imgs_sizes
)
attn_mask
=
None
if
imgs_sizes
is
not
None
and
len
(
imgs_sizes
)
>
1
:
# Dynamic Resolution
attn_mask
=
self
.
create_inter_image_attention_mask
(
imgs_sizes
,
device
=
x
.
device
mask_meta
=
None
if
imgs_sizes
is
not
None
:
assert
len
(
imgs_sizes
)
>
0
# Dynamic resolution: process each image as an independent sequence.
mask_meta
=
self
.
inter_image_mask_metadata
(
imgs_sizes
,
device
=
hidden_states
.
device
)
encoder_outputs
=
self
.
encoder
(
inputs_embeds
=
hidden_states
,
attn_mask
=
attn_mask
)
encoder_outputs
=
self
.
encoder
(
inputs_embeds
=
hidden_states
,
mask_meta
=
mask_meta
)
return
encoder_outputs
...
...
@@ -670,7 +662,7 @@ class RadioModel(nn.Module):
pixel_values
:
torch
.
Tensor
|
None
=
None
,
pixel_embeds
:
torch
.
Tensor
|
None
=
None
,
*
,
imgs_sizes
:
torch
.
Tensor
|
None
=
None
,
imgs_sizes
:
list
[
tuple
[
int
,
int
]]
|
None
=
None
,
)
->
tuple
[
torch
.
FloatTensor
,
torch
.
FloatTensor
]:
y
=
self
.
model
(
pixel_values
,
imgs_sizes
=
imgs_sizes
)
return
self
.
_extract_final
(
y
,
imgs_sizes
=
imgs_sizes
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment