Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
10398b47
Unverified
Commit
10398b47
authored
Dec 05, 2024
by
Isotr0py
Committed by
GitHub
Dec 04, 2024
Browse files
[Model] Consolidate ViTs attention implementation without mask (#10893)
Signed-off-by:
Isotr0py
<
2037008807@qq.com
>
parent
01d079fd
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
109 additions
and
226 deletions
+109
-226
vllm/attention/layer.py
vllm/attention/layer.py
+63
-0
vllm/model_executor/models/blip.py
vllm/model_executor/models/blip.py
+4
-41
vllm/model_executor/models/clip.py
vllm/model_executor/models/clip.py
+4
-42
vllm/model_executor/models/glm4_vision_encoder.py
vllm/model_executor/models/glm4_vision_encoder.py
+6
-16
vllm/model_executor/models/idefics2_vision_model.py
vllm/model_executor/models/idefics2_vision_model.py
+4
-21
vllm/model_executor/models/intern_vit.py
vllm/model_executor/models/intern_vit.py
+4
-24
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+11
-12
vllm/model_executor/models/molmo.py
vllm/model_executor/models/molmo.py
+9
-29
vllm/model_executor/models/siglip.py
vllm/model_executor/models/siglip.py
+4
-41
No files found.
vllm/attention/layer.py
View file @
10398b47
...
...
@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
vllm.attention
import
AttentionMetadata
,
AttentionType
from
vllm.attention.selector
import
backend_name_to_enum
,
get_attn_backend
...
...
@@ -168,6 +169,68 @@ class Attention(nn.Module):
return
s
class
MultiHeadAttention
(
nn
.
Module
):
"""Multi-headed attention without any cache, used for ViT."""
def
__init__
(
self
,
num_heads
:
int
,
head_size
:
int
,
scale
:
float
,
num_kv_heads
:
Optional
[
int
]
=
None
,
):
super
().
__init__
()
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
scale
self
.
num_kv_heads
=
num_heads
if
num_kv_heads
is
None
else
num_kv_heads
dtype
=
torch
.
get_default_dtype
()
attn_backend
=
get_attn_backend
(
head_size
,
dtype
,
kv_cache_dtype
=
None
,
block_size
=
16
,
is_attention_free
=
False
)
if
attn_backend
in
{
_Backend
.
FLASH_ATTN
,
_Backend
.
FLASH_ATTN_VLLM_V1
}:
attn_backend
=
_Backend
.
XFORMERS
self
.
attn_backend
=
attn_backend
if
attn_backend
in
{
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
}
else
_Backend
.
TORCH_SDPA
def
forward
(
self
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""Input shape: batch_size x seq_len x hidden_size"""
# TODO(Isotr0py): Use existing backend implementations and support FA2
bsz
,
q_len
,
_
=
query
.
size
()
kv_len
=
key
.
size
(
1
)
query
=
query
.
view
(
bsz
,
q_len
,
self
.
num_heads
,
self
.
head_size
)
key
=
key
.
view
(
bsz
,
kv_len
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
bsz
,
kv_len
,
self
.
num_kv_heads
,
self
.
head_size
)
if
self
.
attn_backend
==
_Backend
.
XFORMERS
:
from
xformers
import
ops
as
xops
out
=
xops
.
memory_efficient_attention_forward
(
query
,
key
,
value
,
scale
=
self
.
scale
)
elif
self
.
attn_backend
==
_Backend
.
TORCH_SDPA
:
query
,
key
,
value
=
(
x
.
transpose
(
1
,
2
)
for
x
in
(
query
,
key
,
value
))
out
=
F
.
scaled_dot_product_attention
(
query
,
key
,
value
,
scale
=
self
.
scale
)
out
=
out
.
transpose
(
1
,
2
)
return
out
.
view
(
bsz
,
q_len
,
-
1
)
def
unified_attention
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/blip.py
View file @
10398b47
...
...
@@ -4,11 +4,10 @@ from typing import Iterable, Optional, Set, Tuple, Union
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
PIL
import
Image
from
transformers
import
Blip2VisionConfig
,
BlipVisionConfig
from
vllm.attention.
selecto
r
import
_Backend
from
vllm.attention.
laye
r
import
MultiHeadAttention
from
vllm.config
import
ModelConfig
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.inputs
import
DecoderOnlyInputs
,
token_inputs
...
...
@@ -22,8 +21,6 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens
)
from
vllm.sequence
import
SequenceData
from
.utils
import
get_vit_attn_backend
def
get_blip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
assert
image_size
%
patch_size
==
0
...
...
@@ -205,11 +202,8 @@ class BlipAttention(nn.Module):
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
num_heads_per_partition
=
divide
(
self
.
num_heads
,
self
.
tp_size
)
# Detect attention implementation.
self
.
attn_backend
=
get_vit_attn_backend
(
support_fa
=
False
)
if
self
.
attn_backend
not
in
{
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
}:
raise
RuntimeError
(
f
"BLIP does not support
{
self
.
attn_backend
}
backend now."
)
self
.
attn
=
MultiHeadAttention
(
self
.
num_heads_per_partition
,
self
.
head_dim
,
self
.
scale
)
def
_shape
(
self
,
tensor
:
torch
.
Tensor
,
seq_len
:
int
,
bsz
:
int
):
return
tensor
.
view
(
bsz
,
seq_len
,
self
.
num_heads
,
...
...
@@ -220,41 +214,10 @@ class BlipAttention(nn.Module):
hidden_states
:
torch
.
Tensor
,
):
"""Input shape: Batch x Time x Channel"""
bsz
,
tgt_len
,
_
=
hidden_states
.
size
()
qkv_states
,
_
=
self
.
qkv
(
hidden_states
)
query_states
,
key_states
,
value_states
=
qkv_states
.
chunk
(
3
,
dim
=-
1
)
query_states
=
query_states
.
view
(
bsz
,
tgt_len
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
key_states
=
key_states
.
view
(
bsz
,
tgt_len
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
value_states
=
value_states
.
view
(
bsz
,
tgt_len
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
if
self
.
attn_backend
==
_Backend
.
XFORMERS
:
from
xformers
import
ops
as
xops
out
=
xops
.
memory_efficient_attention_forward
(
query_states
,
key_states
,
value_states
,
p
=
self
.
dropout
,
scale
=
self
.
scale
)
elif
self
.
attn_backend
==
_Backend
.
TORCH_SDPA
:
query_states
,
key_states
,
value_states
=
(
x
.
transpose
(
1
,
2
)
for
x
in
(
query_states
,
key_states
,
value_states
))
out
=
F
.
scaled_dot_product_attention
(
query_states
,
key_states
,
value_states
,
dropout_p
=
self
.
dropout
,
scale
=
self
.
scale
)
out
=
out
.
transpose
(
1
,
2
)
out
=
out
.
view
(
bsz
,
tgt_len
,
-
1
)
out
=
self
.
attn
(
query_states
,
key_states
,
value_states
)
attn_output
,
_
=
self
.
projection
(
out
)
return
attn_output
,
None
...
...
vllm/model_executor/models/clip.py
View file @
10398b47
...
...
@@ -5,11 +5,10 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
PIL
import
Image
from
transformers
import
CLIPVisionConfig
from
vllm.attention.
selecto
r
import
_Backend
from
vllm.attention.
laye
r
import
MultiHeadAttention
from
vllm.config
import
ModelConfig
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.inputs
import
DecoderOnlyInputs
,
token_inputs
...
...
@@ -25,8 +24,6 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
resolve_visual_encoder_outputs
)
from
vllm.sequence
import
SequenceData
from
.utils
import
get_vit_attn_backend
def
get_clip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
assert
image_size
%
patch_size
==
0
...
...
@@ -235,11 +232,8 @@ class CLIPAttention(nn.Module):
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
num_heads_per_partition
=
divide
(
self
.
num_heads
,
self
.
tp_size
)
# Detect attention implementation.
self
.
attn_backend
=
get_vit_attn_backend
(
support_fa
=
False
)
if
self
.
attn_backend
not
in
{
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
}:
raise
RuntimeError
(
f
"CLIP does not support
{
self
.
attn_backend
}
backend now."
)
self
.
attn
=
MultiHeadAttention
(
self
.
num_heads_per_partition
,
self
.
head_dim
,
self
.
scale
)
def
_shape
(
self
,
tensor
:
torch
.
Tensor
,
seq_len
:
int
,
bsz
:
int
):
return
tensor
.
view
(
bsz
,
seq_len
,
self
.
num_heads
,
...
...
@@ -250,42 +244,10 @@ class CLIPAttention(nn.Module):
hidden_states
:
torch
.
Tensor
,
):
"""Input shape: Batch x Time x Channel"""
bsz
,
tgt_len
,
_
=
hidden_states
.
size
()
qkv_states
,
_
=
self
.
qkv_proj
(
hidden_states
)
query_states
,
key_states
,
value_states
=
qkv_states
.
chunk
(
3
,
dim
=-
1
)
query_states
=
query_states
.
view
(
bsz
,
tgt_len
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
key_states
=
key_states
.
view
(
bsz
,
tgt_len
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
value_states
=
value_states
.
view
(
bsz
,
tgt_len
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
if
self
.
attn_backend
==
_Backend
.
XFORMERS
:
from
xformers
import
ops
as
xops
out
=
xops
.
memory_efficient_attention_forward
(
query_states
,
key_states
,
value_states
,
p
=
self
.
dropout
,
scale
=
self
.
scale
)
elif
self
.
attn_backend
==
_Backend
.
TORCH_SDPA
:
query_states
,
key_states
,
value_states
=
(
x
.
transpose
(
1
,
2
)
for
x
in
(
query_states
,
key_states
,
value_states
))
out
=
F
.
scaled_dot_product_attention
(
query_states
,
key_states
,
value_states
,
dropout_p
=
self
.
dropout
,
scale
=
self
.
scale
)
out
=
out
.
transpose
(
1
,
2
)
out
=
out
.
view
(
bsz
,
tgt_len
,
-
1
)
out
=
self
.
attn
(
query_states
,
key_states
,
value_states
)
attn_output
,
_
=
self
.
out_proj
(
out
)
return
attn_output
,
None
...
...
vllm/model_executor/models/glm4_vision_encoder.py
View file @
10398b47
...
...
@@ -8,6 +8,7 @@ import torch
from
torch
import
nn
from
torch.nn
import
LayerNorm
from
vllm.attention.layer
import
MultiHeadAttention
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
SiluAndMul
,
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
...
@@ -77,27 +78,16 @@ class Attention(nn.Module):
quant_config
=
quant_config
,
)
self
.
attn
=
MultiHeadAttention
(
self
.
num_heads_per_rank
,
self
.
head_dim
,
self
.
scale
)
self
.
output_dropout
=
torch
.
nn
.
Dropout
(
config
.
dropout_prob
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
B
,
L
,
_
=
x
.
shape
qkv
,
_
=
self
.
query_key_value
(
x
)
# B, L, 3 * H * D
q
,
k
,
v
=
qkv
.
chunk
(
3
,
dim
=-
1
)
q
=
q
.
reshape
(
B
,
L
,
self
.
num_heads_per_rank
,
self
.
head_dim
).
permute
(
0
,
2
,
1
,
3
)
# B, H, L, D
k
=
k
.
reshape
(
B
,
L
,
self
.
num_heads_per_rank
,
self
.
head_dim
).
permute
(
0
,
2
,
1
,
3
)
# B, H, L, D
v
=
v
.
reshape
(
B
,
L
,
self
.
num_heads_per_rank
,
self
.
head_dim
).
permute
(
0
,
2
,
1
,
3
)
# B, H, L, D
out
=
torch
.
nn
.
functional
.
scaled_dot_product_attention
(
q
,
k
,
v
,
attn_mask
=
None
,
dropout_p
=
0.
,
is_causal
=
False
)
output
,
_
=
self
.
dense
(
out
.
transpose
(
1
,
2
).
view
(
B
,
L
,
-
1
))
out
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
dense
(
out
)
output
=
self
.
output_dropout
(
output
)
return
output
...
...
vllm/model_executor/models/idefics2_vision_model.py
View file @
10398b47
...
...
@@ -21,8 +21,8 @@ import torch
from
torch
import
nn
from
transformers.models.idefics2.configuration_idefics2
import
(
Idefics2Config
,
Idefics2VisionConfig
)
from
xformers
import
ops
as
xops
from
vllm.attention.layer
import
MultiHeadAttention
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
...
@@ -141,35 +141,18 @@ class Idefics2VisionAttention(nn.Module):
)
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
num_heads_per_partition
=
divide
(
self
.
num_heads
,
self
.
tp_size
)
self
.
is_causal
=
False
self
.
attn
=
MultiHeadAttention
(
self
.
num_heads_per_partition
,
self
.
head_dim
,
self
.
scale
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
batch_size
,
q_len
,
_
=
hidden_states
.
size
()
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
# batch_size, q_len, 3 * num_heads_per_partition * head_dim
query_states
,
key_states
,
value_states
=
qkv
.
chunk
(
3
,
dim
=-
1
)
query_states
=
query_states
.
view
(
batch_size
,
q_len
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
key_states
=
key_states
.
view
(
batch_size
,
q_len
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
value_states
=
value_states
.
view
(
batch_size
,
q_len
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
# see: https://facebookresearch.github.io/xformers/components/ops.html
out
=
xops
.
memory_efficient_attention_forward
(
query_states
,
key_states
,
value_states
,
p
=
self
.
dropout
,
scale
=
self
.
scale
,
)
out
=
out
.
view
(
batch_size
,
q_len
,
-
1
)
out
=
self
.
attn
(
query_states
,
key_states
,
value_states
)
attn_output
,
_
=
self
.
out_proj
(
out
)
return
attn_output
...
...
vllm/model_executor/models/intern_vit.py
View file @
10398b47
...
...
@@ -12,7 +12,7 @@ import torch.nn as nn
import
torch.nn.functional
as
F
from
transformers
import
PretrainedConfig
from
vllm.attention.
selecto
r
import
_Backend
from
vllm.attention.
laye
r
import
MultiHeadAttention
from
vllm.distributed
import
(
divide
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
split_tensor_along_last_dim
,
...
...
@@ -25,8 +25,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
.utils
import
get_vit_attn_backend
NORM2FN
=
{
'rms_norm'
:
RMSNorm
,
'layer_norm'
:
nn
.
LayerNorm
,
...
...
@@ -183,10 +181,8 @@ class InternParallelAttention(nn.Module):
prefix
=
f
"
{
prefix
}
.proj"
,
)
self
.
attn_backend
=
get_vit_attn_backend
(
support_fa
=
False
)
if
self
.
attn_backend
not
in
{
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
}:
raise
RuntimeError
(
f
"InternViT does not support
{
self
.
attn_backend
}
backend now."
)
self
.
attn
=
MultiHeadAttention
(
self
.
num_heads_per_partition
,
self
.
head_dim
,
self
.
scale
)
def
_apply_qk_norm
(
self
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
):
if
self
.
tp_size
>
1
:
...
...
@@ -209,23 +205,7 @@ class InternParallelAttention(nn.Module):
if
self
.
qk_normalization
:
q
,
k
=
self
.
_apply_qk_norm
(
q
,
k
)
q
=
q
.
view
(
B
,
N
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
k
=
k
.
view
(
B
,
N
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
v
=
v
.
view
(
B
,
N
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
if
self
.
attn_backend
==
_Backend
.
XFORMERS
:
from
xformers
import
ops
as
xops
out
=
xops
.
memory_efficient_attention_forward
(
q
,
k
,
v
,
scale
=
self
.
scale
)
elif
self
.
attn_backend
==
_Backend
.
TORCH_SDPA
:
q
,
k
,
v
=
(
x
.
transpose
(
1
,
2
)
for
x
in
(
q
,
k
,
v
))
out
=
F
.
scaled_dot_product_attention
(
q
,
k
,
v
,
scale
=
self
.
scale
)
out
=
out
.
transpose
(
1
,
2
)
out
=
out
.
view
(
B
,
N
,
-
1
)
out
=
self
.
attn
(
q
,
k
,
v
)
out
,
_
=
self
.
proj
(
out
)
return
out
...
...
vllm/model_executor/models/internvl.py
View file @
10398b47
...
...
@@ -482,6 +482,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self
.
mlp1
=
self
.
_init_mlp1
(
config
)
self
.
img_context_token_id
=
None
self
.
visual_token_mask
=
None
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
...
...
@@ -635,13 +636,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
return
image_embeds
def
_
g
et_visual_token_mask
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
_
s
et_visual_token_mask
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
self
.
is_mono
:
visual_token_mask
=
(
self
.
visual_token_mask
=
(
input_ids
==
self
.
img_context_token_id
).
reshape
(
-
1
,
1
)
else
:
visual_token_mask
=
None
return
visual_token_mask
self
.
visual_token_mask
=
None
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Optional
[
NestedTensors
]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
...
...
@@ -658,6 +658,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
if
multimodal_embeddings
is
not
None
:
assert
self
.
img_context_token_id
is
not
None
self
.
_set_visual_token_mask
(
input_ids
)
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
multimodal_embeddings
,
self
.
img_context_token_id
)
...
...
@@ -674,7 +675,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
**
kwargs
:
object
,
)
->
Union
[
SamplerOutput
,
IntermediateTensors
]:
visual_token_mask
=
None
if
intermediate_tensors
is
not
None
:
input_ids
=
None
inputs_embeds
=
None
...
...
@@ -695,16 +695,15 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
"intermediate_tensors"
:
intermediate_tensors
,
"inputs_embeds"
:
inputs_embeds
,
}
if
self
.
img_context_token_id
is
not
None
:
visual_token_mask
=
self
.
_get_visual_token_mask
(
input_ids
)
# We always overwrite it back to None after computing visual token
# mask so that this doesn't need to depend on encoder output
if
self
.
visual_token_mask
is
not
None
:
# overwrite visual_token_mask and img_context_token_id back to None,
# so that this doesn't need to depend on encoder output
forward_kwargs
.
update
(
{
"visual_token_mask"
:
self
.
visual_token_mask
})
self
.
visual_token_mask
=
None
self
.
img_context_token_id
=
None
if
self
.
is_mono
:
forward_kwargs
.
update
({
"visual_token_mask"
:
visual_token_mask
})
hidden_states
=
self
.
language_model
.
model
(
**
forward_kwargs
)
return
hidden_states
...
...
vllm/model_executor/models/molmo.py
View file @
10398b47
...
...
@@ -13,6 +13,7 @@ from torch.nn import functional as F
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention.layer
import
MultiHeadAttention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
...
...
@@ -38,14 +39,12 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal.inputs
import
NestedTensors
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.platforms
import
_Backend
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
SequenceData
)
from
vllm.transformers_utils.processor
import
get_processor
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
get_vit_attn_backend
,
is_pp_missing_parameter
,
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
...
...
@@ -188,13 +187,11 @@ class MultiHeadDotProductAttention(nn.Module):
quant_config
=
quant_config
,
)
# Detect attention implementation.
self
.
attn_backend
:
_Backend
=
get_vit_attn_backend
(
support_fa
=
True
)
if
self
.
attn_backend
not
in
{
_Backend
.
FLASH_ATTN
,
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
}:
raise
RuntimeError
(
f
"Molmo does not support
{
self
.
attn_backend
}
backend now."
)
self
.
scale
=
self
.
head_dim
**-
0.5
self
.
attn
=
MultiHeadAttention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scale
,
num_kv_heads
=
self
.
num_kv_heads
)
def
forward
(
self
,
inputs_q
:
torch
.
Tensor
,
...
...
@@ -210,25 +207,8 @@ class MultiHeadDotProductAttention(nn.Module):
xq
,
_
=
self
.
wq
(
inputs_q
)
xk
,
_
=
self
.
wk
(
inputs_k
)
xv
,
_
=
self
.
wv
(
inputs_v
)
q_shape
=
xq
.
size
()[:
-
1
]
+
(
self
.
num_heads
,
self
.
head_dim
)
kv_shape
=
xk
.
size
()[:
-
1
]
+
(
self
.
num_kv_heads
,
self
.
head_dim
)
xq
=
xq
.
view
(
*
q_shape
)
xk
=
xk
.
view
(
*
kv_shape
)
xv
=
xv
.
view
(
*
kv_shape
)
if
self
.
attn_backend
==
_Backend
.
FLASH_ATTN
:
from
flash_attn
import
flash_attn_func
output
=
flash_attn_func
(
xq
,
xk
,
xv
,
dropout_p
=
0.0
,
causal
=
False
)
elif
self
.
attn_backend
==
_Backend
.
TORCH_SDPA
:
xq
,
xk
,
xv
=
(
rearrange
(
x
,
"b s h d -> b h s d"
)
for
x
in
(
xq
,
xk
,
xv
))
output
=
F
.
scaled_dot_product_attention
(
xq
,
xk
,
xv
)
output
=
rearrange
(
output
,
"b h s d -> b s h d "
)
elif
self
.
attn_backend
==
_Backend
.
XFORMERS
:
from
xformers
import
ops
as
xops
output
=
xops
.
memory_efficient_attention_forward
(
xq
,
xk
,
xv
,
p
=
0
)
output
=
rearrange
(
output
,
"b s h d -> b s (h d)"
).
contiguous
()
output
=
self
.
attn
(
xq
,
xk
,
xv
)
output
,
_
=
self
.
wo
(
output
)
return
output
...
...
vllm/model_executor/models/siglip.py
View file @
10398b47
...
...
@@ -6,12 +6,11 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
import
numpy
as
np
import
torch
import
torch.nn.functional
as
F
from
PIL
import
Image
from
torch
import
nn
from
transformers
import
SiglipVisionConfig
from
vllm.attention.
selecto
r
import
_Backend
from
vllm.attention.
laye
r
import
MultiHeadAttention
from
vllm.config
import
ModelConfig
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.inputs
import
DecoderOnlyInputs
,
token_inputs
...
...
@@ -29,8 +28,6 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
resolve_visual_encoder_outputs
)
from
vllm.sequence
import
SequenceData
from
.utils
import
get_vit_attn_backend
def
get_siglip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
# Since interpolation is applied, the image size need not be divisible
...
...
@@ -291,52 +288,18 @@ class SiglipAttention(nn.Module):
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
num_heads_per_partition
=
divide
(
self
.
num_heads
,
self
.
tp_size
)
self
.
attn_backend
=
get_vit_attn_backend
(
support_fa
=
False
)
if
self
.
attn_backend
not
in
{
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
}:
raise
RuntimeError
(
f
"SIGLIP does not support
{
self
.
attn_backend
}
backend now."
)
self
.
attn
=
MultiHeadAttention
(
self
.
num_heads_per_partition
,
self
.
head_dim
,
self
.
scale
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""Input shape: Batch x Time x Channel"""
batch_size
,
q_len
,
_
=
hidden_states
.
size
()
qkv_states
,
_
=
self
.
qkv_proj
(
hidden_states
)
query_states
,
key_states
,
value_states
=
qkv_states
.
chunk
(
3
,
dim
=-
1
)
query_states
=
query_states
.
view
(
batch_size
,
q_len
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
key_states
=
key_states
.
view
(
batch_size
,
q_len
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
value_states
=
value_states
.
view
(
batch_size
,
q_len
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
if
self
.
attn_backend
==
_Backend
.
XFORMERS
:
from
xformers
import
ops
as
xops
out
=
xops
.
memory_efficient_attention_forward
(
query_states
,
key_states
,
value_states
,
p
=
self
.
dropout
,
scale
=
self
.
scale
)
elif
self
.
attn_backend
==
_Backend
.
TORCH_SDPA
:
query_states
,
key_states
,
value_states
=
(
x
.
transpose
(
1
,
2
)
for
x
in
(
query_states
,
key_states
,
value_states
))
out
=
F
.
scaled_dot_product_attention
(
query_states
,
key_states
,
value_states
,
dropout_p
=
self
.
dropout
,
scale
=
self
.
scale
)
out
=
out
.
transpose
(
1
,
2
)
out
=
out
.
view
(
batch_size
,
q_len
,
-
1
)
out
=
self
.
attn
(
query_states
,
key_states
,
value_states
)
attn_output
,
_
=
self
.
out_proj
(
out
)
return
attn_output
,
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment