Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
10398b47
Unverified
Commit
10398b47
authored
Dec 05, 2024
by
Isotr0py
Committed by
GitHub
Dec 04, 2024
Browse files
[Model] Consolidate ViTs attention implementation without mask (#10893)
Signed-off-by:
Isotr0py
<
2037008807@qq.com
>
parent
01d079fd
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
109 additions
and
226 deletions
+109
-226
vllm/attention/layer.py
vllm/attention/layer.py
+63
-0
vllm/model_executor/models/blip.py
vllm/model_executor/models/blip.py
+4
-41
vllm/model_executor/models/clip.py
vllm/model_executor/models/clip.py
+4
-42
vllm/model_executor/models/glm4_vision_encoder.py
vllm/model_executor/models/glm4_vision_encoder.py
+6
-16
vllm/model_executor/models/idefics2_vision_model.py
vllm/model_executor/models/idefics2_vision_model.py
+4
-21
vllm/model_executor/models/intern_vit.py
vllm/model_executor/models/intern_vit.py
+4
-24
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+11
-12
vllm/model_executor/models/molmo.py
vllm/model_executor/models/molmo.py
+9
-29
vllm/model_executor/models/siglip.py
vllm/model_executor/models/siglip.py
+4
-41
No files found.
vllm/attention/layer.py
View file @
10398b47
...
@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional
...
@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
vllm.attention
import
AttentionMetadata
,
AttentionType
from
vllm.attention
import
AttentionMetadata
,
AttentionType
from
vllm.attention.selector
import
backend_name_to_enum
,
get_attn_backend
from
vllm.attention.selector
import
backend_name_to_enum
,
get_attn_backend
...
@@ -168,6 +169,68 @@ class Attention(nn.Module):
...
@@ -168,6 +169,68 @@ class Attention(nn.Module):
return
s
return
s
class MultiHeadAttention(nn.Module):
    """Multi-headed attention without any cache, used for ViT.

    The layer takes already-projected query/key/value tensors, splits the
    hidden dimension into heads and runs un-masked attention through either
    xFormers or ``torch.nn.functional.scaled_dot_product_attention``.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
    ):
        """Store the head layout and pick the attention implementation.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each attention head.
            scale: Softmax scaling factor forwarded to the attention kernel.
            num_kv_heads: Number of key/value heads; defaults to
                ``num_heads`` when ``None``.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        dtype = torch.get_default_dtype()
        attn_backend = get_attn_backend(head_size,
                                        dtype,
                                        kv_cache_dtype=None,
                                        block_size=16,
                                        is_attention_free=False)
        # NOTE(review): the checks below only match if get_attn_backend
        # returns a `_Backend` member. If it returns a backend class instead,
        # neither test can succeed and TORCH_SDPA is always chosen -- confirm
        # against the selector (backend_name_to_enum may be needed here).
        # Flash-attention backends are mapped to xFormers because forward()
        # does not implement a flash-attention path yet.
        if attn_backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
            attn_backend = _Backend.XFORMERS

        # Anything other than the two implemented paths falls back to SDPA.
        self.attn_backend = attn_backend if attn_backend in {
            _Backend.TORCH_SDPA, _Backend.XFORMERS
        } else _Backend.TORCH_SDPA

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Input shape: batch_size x seq_len x hidden_size

        Args:
            query: Projected queries; the last dimension is split into
                ``num_heads * head_size``.
            key: Projected keys; the last dimension is split into
                ``num_kv_heads * head_size``.
            value: Projected values, laid out like ``key``.

        Returns:
            Attention output with the heads merged back into the last
            dimension, i.e. ``batch_size x q_len x (num_heads * head_size)``.
        """
        # TODO(Isotr0py): Use existing backend implementations and support FA2
        bsz, q_len, _ = query.size()
        kv_len = key.size(1)

        query = query.view(bsz, q_len, self.num_heads, self.head_size)
        key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

        # Handle MQA / GQA: both kernels below are called with 4-D tensors
        # and expect key/value to carry as many heads as query, so every
        # key/value head is repeated for the query heads that share it.
        if (self.num_kv_heads != self.num_heads
                and self.num_heads % self.num_kv_heads == 0):
            num_repeat = self.num_heads // self.num_kv_heads
            key = torch.repeat_interleave(key, num_repeat, dim=2)
            value = torch.repeat_interleave(value, num_repeat, dim=2)

        if self.attn_backend == _Backend.XFORMERS:
            # Imported lazily so xFormers is only required when selected.
            from xformers import ops as xops

            out = xops.memory_efficient_attention_forward(query,
                                                          key,
                                                          value,
                                                          scale=self.scale)
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # SDPA is called with (batch, heads, seq, head_size) here.
            query, key, value = (x.transpose(1, 2)
                                 for x in (query, key, value))
            out = F.scaled_dot_product_attention(query,
                                                 key,
                                                 value,
                                                 scale=self.scale)
            out = out.transpose(1, 2)
        # reshape() instead of view(): after the transpose above the tensor
        # may not be view-compatible, in which case view() raises while
        # reshape() copies only when it has to.
        return out.reshape(bsz, q_len, -1)
def
unified_attention
(
def
unified_attention
(
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/blip.py
View file @
10398b47
...
@@ -4,11 +4,10 @@ from typing import Iterable, Optional, Set, Tuple, Union
...
@@ -4,11 +4,10 @@ from typing import Iterable, Optional, Set, Tuple, Union
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
PIL
import
Image
from
PIL
import
Image
from
transformers
import
Blip2VisionConfig
,
BlipVisionConfig
from
transformers
import
Blip2VisionConfig
,
BlipVisionConfig
from
vllm.attention.
selecto
r
import
_Backend
from
vllm.attention.
laye
r
import
MultiHeadAttention
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.inputs
import
DecoderOnlyInputs
,
token_inputs
from
vllm.inputs
import
DecoderOnlyInputs
,
token_inputs
...
@@ -22,8 +21,6 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
...
@@ -22,8 +21,6 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens
)
repeat_and_pad_placeholder_tokens
)
from
vllm.sequence
import
SequenceData
from
vllm.sequence
import
SequenceData
from
.utils
import
get_vit_attn_backend
def
get_blip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
def
get_blip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
assert
image_size
%
patch_size
==
0
assert
image_size
%
patch_size
==
0
...
@@ -205,11 +202,8 @@ class BlipAttention(nn.Module):
...
@@ -205,11 +202,8 @@ class BlipAttention(nn.Module):
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
num_heads_per_partition
=
divide
(
self
.
num_heads
,
self
.
tp_size
)
self
.
num_heads_per_partition
=
divide
(
self
.
num_heads
,
self
.
tp_size
)
# Detect attention implementation.
self
.
attn
=
MultiHeadAttention
(
self
.
num_heads_per_partition
,
self
.
attn_backend
=
get_vit_attn_backend
(
support_fa
=
False
)
self
.
head_dim
,
self
.
scale
)
if
self
.
attn_backend
not
in
{
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
}:
raise
RuntimeError
(
f
"BLIP does not support
{
self
.
attn_backend
}
backend now."
)
def
_shape
(
self
,
tensor
:
torch
.
Tensor
,
seq_len
:
int
,
bsz
:
int
):
def
_shape
(
self
,
tensor
:
torch
.
Tensor
,
seq_len
:
int
,
bsz
:
int
):
return
tensor
.
view
(
bsz
,
seq_len
,
self
.
num_heads
,
return
tensor
.
view
(
bsz
,
seq_len
,
self
.
num_heads
,
...
@@ -220,41 +214,10 @@ class BlipAttention(nn.Module):
...
@@ -220,41 +214,10 @@ class BlipAttention(nn.Module):
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
):
):
"""Input shape: Batch x Time x Channel"""
"""Input shape: Batch x Time x Channel"""
bsz
,
tgt_len
,
_
=
hidden_states
.
size
()
qkv_states
,
_
=
self
.
qkv
(
hidden_states
)
qkv_states
,
_
=
self
.
qkv
(
hidden_states
)
query_states
,
key_states
,
value_states
=
qkv_states
.
chunk
(
3
,
dim
=-
1
)
query_states
,
key_states
,
value_states
=
qkv_states
.
chunk
(
3
,
dim
=-
1
)
query_states
=
query_states
.
view
(
bsz
,
tgt_len
,
out
=
self
.
attn
(
query_states
,
key_states
,
value_states
)
self
.
num_heads_per_partition
,
self
.
head_dim
)
key_states
=
key_states
.
view
(
bsz
,
tgt_len
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
value_states
=
value_states
.
view
(
bsz
,
tgt_len
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
if
self
.
attn_backend
==
_Backend
.
XFORMERS
:
from
xformers
import
ops
as
xops
out
=
xops
.
memory_efficient_attention_forward
(
query_states
,
key_states
,
value_states
,
p
=
self
.
dropout
,
scale
=
self
.
scale
)
elif
self
.
attn_backend
==
_Backend
.
TORCH_SDPA
:
query_states
,
key_states
,
value_states
=
(
x
.
transpose
(
1
,
2
)
for
x
in
(
query_states
,
key_states
,
value_states
))
out
=
F
.
scaled_dot_product_attention
(
query_states
,
key_states
,
value_states
,
dropout_p
=
self
.
dropout
,
scale
=
self
.
scale
)
out
=
out
.
transpose
(
1
,
2
)
out
=
out
.
view
(
bsz
,
tgt_len
,
-
1
)
attn_output
,
_
=
self
.
projection
(
out
)
attn_output
,
_
=
self
.
projection
(
out
)
return
attn_output
,
None
return
attn_output
,
None
...
...
vllm/model_executor/models/clip.py
View file @
10398b47
...
@@ -5,11 +5,10 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
...
@@ -5,11 +5,10 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
PIL
import
Image
from
PIL
import
Image
from
transformers
import
CLIPVisionConfig
from
transformers
import
CLIPVisionConfig
from
vllm.attention.
selecto
r
import
_Backend
from
vllm.attention.
laye
r
import
MultiHeadAttention
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.inputs
import
DecoderOnlyInputs
,
token_inputs
from
vllm.inputs
import
DecoderOnlyInputs
,
token_inputs
...
@@ -25,8 +24,6 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
...
@@ -25,8 +24,6 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
resolve_visual_encoder_outputs
)
resolve_visual_encoder_outputs
)
from
vllm.sequence
import
SequenceData
from
vllm.sequence
import
SequenceData
from
.utils
import
get_vit_attn_backend
def
get_clip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
def
get_clip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
assert
image_size
%
patch_size
==
0
assert
image_size
%
patch_size
==
0
...
@@ -235,11 +232,8 @@ class CLIPAttention(nn.Module):
...
@@ -235,11 +232,8 @@ class CLIPAttention(nn.Module):
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
num_heads_per_partition
=
divide
(
self
.
num_heads
,
self
.
tp_size
)
self
.
num_heads_per_partition
=
divide
(
self
.
num_heads
,
self
.
tp_size
)
# Detect attention implementation.
self
.
attn
=
MultiHeadAttention
(
self
.
num_heads_per_partition
,
self
.
attn_backend
=
get_vit_attn_backend
(
support_fa
=
False
)
self
.
head_dim
,
self
.
scale
)
if
self
.
attn_backend
not
in
{
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
}:
raise
RuntimeError
(
f
"CLIP does not support
{
self
.
attn_backend
}
backend now."
)
def
_shape
(
self
,
tensor
:
torch
.
Tensor
,
seq_len
:
int
,
bsz
:
int
):
def
_shape
(
self
,
tensor
:
torch
.
Tensor
,
seq_len
:
int
,
bsz
:
int
):
return
tensor
.
view
(
bsz
,
seq_len
,
self
.
num_heads
,
return
tensor
.
view
(
bsz
,
seq_len
,
self
.
num_heads
,
...
@@ -250,42 +244,10 @@ class CLIPAttention(nn.Module):
...
@@ -250,42 +244,10 @@ class CLIPAttention(nn.Module):
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
):
):
"""Input shape: Batch x Time x Channel"""
"""Input shape: Batch x Time x Channel"""
bsz
,
tgt_len
,
_
=
hidden_states
.
size
()
qkv_states
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv_states
,
_
=
self
.
qkv_proj
(
hidden_states
)
query_states
,
key_states
,
value_states
=
qkv_states
.
chunk
(
3
,
dim
=-
1
)
query_states
,
key_states
,
value_states
=
qkv_states
.
chunk
(
3
,
dim
=-
1
)
out
=
self
.
attn
(
query_states
,
key_states
,
value_states
)
query_states
=
query_states
.
view
(
bsz
,
tgt_len
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
key_states
=
key_states
.
view
(
bsz
,
tgt_len
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
value_states
=
value_states
.
view
(
bsz
,
tgt_len
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
if
self
.
attn_backend
==
_Backend
.
XFORMERS
:
from
xformers
import
ops
as
xops
out
=
xops
.
memory_efficient_attention_forward
(
query_states
,
key_states
,
value_states
,
p
=
self
.
dropout
,
scale
=
self
.
scale
)
elif
self
.
attn_backend
==
_Backend
.
TORCH_SDPA
:
query_states
,
key_states
,
value_states
=
(
x
.
transpose
(
1
,
2
)
for
x
in
(
query_states
,
key_states
,
value_states
))
out
=
F
.
scaled_dot_product_attention
(
query_states
,
key_states
,
value_states
,
dropout_p
=
self
.
dropout
,
scale
=
self
.
scale
)
out
=
out
.
transpose
(
1
,
2
)
out
=
out
.
view
(
bsz
,
tgt_len
,
-
1
)
attn_output
,
_
=
self
.
out_proj
(
out
)
attn_output
,
_
=
self
.
out_proj
(
out
)
return
attn_output
,
None
return
attn_output
,
None
...
...
vllm/model_executor/models/glm4_vision_encoder.py
View file @
10398b47
...
@@ -8,6 +8,7 @@ import torch
...
@@ -8,6 +8,7 @@ import torch
from
torch
import
nn
from
torch
import
nn
from
torch.nn
import
LayerNorm
from
torch.nn
import
LayerNorm
from
vllm.attention.layer
import
MultiHeadAttention
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
SiluAndMul
,
get_act_fn
from
vllm.model_executor.layers.activation
import
SiluAndMul
,
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
@@ -77,27 +78,16 @@ class Attention(nn.Module):
...
@@ -77,27 +78,16 @@ class Attention(nn.Module):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
)
)
self
.
attn
=
MultiHeadAttention
(
self
.
num_heads_per_rank
,
self
.
head_dim
,
self
.
scale
)
self
.
output_dropout
=
torch
.
nn
.
Dropout
(
config
.
dropout_prob
)
self
.
output_dropout
=
torch
.
nn
.
Dropout
(
config
.
dropout_prob
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
B
,
L
,
_
=
x
.
shape
qkv
,
_
=
self
.
query_key_value
(
x
)
# B, L, 3 * H * D
qkv
,
_
=
self
.
query_key_value
(
x
)
# B, L, 3 * H * D
q
,
k
,
v
=
qkv
.
chunk
(
3
,
dim
=-
1
)
q
,
k
,
v
=
qkv
.
chunk
(
3
,
dim
=-
1
)
q
=
q
.
reshape
(
B
,
L
,
self
.
num_heads_per_rank
,
self
.
head_dim
).
permute
(
0
,
2
,
1
,
3
)
# B, H, L, D
out
=
self
.
attn
(
q
,
k
,
v
)
k
=
k
.
reshape
(
B
,
L
,
self
.
num_heads_per_rank
,
output
,
_
=
self
.
dense
(
out
)
self
.
head_dim
).
permute
(
0
,
2
,
1
,
3
)
# B, H, L, D
v
=
v
.
reshape
(
B
,
L
,
self
.
num_heads_per_rank
,
self
.
head_dim
).
permute
(
0
,
2
,
1
,
3
)
# B, H, L, D
out
=
torch
.
nn
.
functional
.
scaled_dot_product_attention
(
q
,
k
,
v
,
attn_mask
=
None
,
dropout_p
=
0.
,
is_causal
=
False
)
output
,
_
=
self
.
dense
(
out
.
transpose
(
1
,
2
).
view
(
B
,
L
,
-
1
))
output
=
self
.
output_dropout
(
output
)
output
=
self
.
output_dropout
(
output
)
return
output
return
output
...
...
vllm/model_executor/models/idefics2_vision_model.py
View file @
10398b47
...
@@ -21,8 +21,8 @@ import torch
...
@@ -21,8 +21,8 @@ import torch
from
torch
import
nn
from
torch
import
nn
from
transformers.models.idefics2.configuration_idefics2
import
(
from
transformers.models.idefics2.configuration_idefics2
import
(
Idefics2Config
,
Idefics2VisionConfig
)
Idefics2Config
,
Idefics2VisionConfig
)
from
xformers
import
ops
as
xops
from
vllm.attention.layer
import
MultiHeadAttention
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
@@ -141,35 +141,18 @@ class Idefics2VisionAttention(nn.Module):
...
@@ -141,35 +141,18 @@ class Idefics2VisionAttention(nn.Module):
)
)
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
num_heads_per_partition
=
divide
(
self
.
num_heads
,
self
.
tp_size
)
self
.
num_heads_per_partition
=
divide
(
self
.
num_heads
,
self
.
tp_size
)
self
.
is_causal
=
False
self
.
attn
=
MultiHeadAttention
(
self
.
num_heads_per_partition
,
self
.
head_dim
,
self
.
scale
)
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
batch_size
,
q_len
,
_
=
hidden_states
.
size
()
qkv
,
_
=
self
.
qkv_proj
(
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
hidden_states
)
# batch_size, q_len, 3 * num_heads_per_partition * head_dim
)
# batch_size, q_len, 3 * num_heads_per_partition * head_dim
query_states
,
key_states
,
value_states
=
qkv
.
chunk
(
3
,
dim
=-
1
)
query_states
,
key_states
,
value_states
=
qkv
.
chunk
(
3
,
dim
=-
1
)
query_states
=
query_states
.
view
(
batch_size
,
q_len
,
out
=
self
.
attn
(
query_states
,
key_states
,
value_states
)
self
.
num_heads_per_partition
,
self
.
head_dim
)
key_states
=
key_states
.
view
(
batch_size
,
q_len
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
value_states
=
value_states
.
view
(
batch_size
,
q_len
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
# see: https://facebookresearch.github.io/xformers/components/ops.html
out
=
xops
.
memory_efficient_attention_forward
(
query_states
,
key_states
,
value_states
,
p
=
self
.
dropout
,
scale
=
self
.
scale
,
)
out
=
out
.
view
(
batch_size
,
q_len
,
-
1
)
attn_output
,
_
=
self
.
out_proj
(
out
)
attn_output
,
_
=
self
.
out_proj
(
out
)
return
attn_output
return
attn_output
...
...
vllm/model_executor/models/intern_vit.py
View file @
10398b47
...
@@ -12,7 +12,7 @@ import torch.nn as nn
...
@@ -12,7 +12,7 @@ import torch.nn as nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.attention.
selecto
r
import
_Backend
from
vllm.attention.
laye
r
import
MultiHeadAttention
from
vllm.distributed
import
(
divide
,
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
divide
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
split_tensor_along_last_dim
,
split_tensor_along_last_dim
,
...
@@ -25,8 +25,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -25,8 +25,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
.utils
import
get_vit_attn_backend
NORM2FN
=
{
NORM2FN
=
{
'rms_norm'
:
RMSNorm
,
'rms_norm'
:
RMSNorm
,
'layer_norm'
:
nn
.
LayerNorm
,
'layer_norm'
:
nn
.
LayerNorm
,
...
@@ -183,10 +181,8 @@ class InternParallelAttention(nn.Module):
...
@@ -183,10 +181,8 @@ class InternParallelAttention(nn.Module):
prefix
=
f
"
{
prefix
}
.proj"
,
prefix
=
f
"
{
prefix
}
.proj"
,
)
)
self
.
attn_backend
=
get_vit_attn_backend
(
support_fa
=
False
)
self
.
attn
=
MultiHeadAttention
(
self
.
num_heads_per_partition
,
if
self
.
attn_backend
not
in
{
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
}:
self
.
head_dim
,
self
.
scale
)
raise
RuntimeError
(
f
"InternViT does not support
{
self
.
attn_backend
}
backend now."
)
def
_apply_qk_norm
(
self
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
):
def
_apply_qk_norm
(
self
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
):
if
self
.
tp_size
>
1
:
if
self
.
tp_size
>
1
:
...
@@ -209,23 +205,7 @@ class InternParallelAttention(nn.Module):
...
@@ -209,23 +205,7 @@ class InternParallelAttention(nn.Module):
if
self
.
qk_normalization
:
if
self
.
qk_normalization
:
q
,
k
=
self
.
_apply_qk_norm
(
q
,
k
)
q
,
k
=
self
.
_apply_qk_norm
(
q
,
k
)
q
=
q
.
view
(
B
,
N
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
out
=
self
.
attn
(
q
,
k
,
v
)
k
=
k
.
view
(
B
,
N
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
v
=
v
.
view
(
B
,
N
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
if
self
.
attn_backend
==
_Backend
.
XFORMERS
:
from
xformers
import
ops
as
xops
out
=
xops
.
memory_efficient_attention_forward
(
q
,
k
,
v
,
scale
=
self
.
scale
)
elif
self
.
attn_backend
==
_Backend
.
TORCH_SDPA
:
q
,
k
,
v
=
(
x
.
transpose
(
1
,
2
)
for
x
in
(
q
,
k
,
v
))
out
=
F
.
scaled_dot_product_attention
(
q
,
k
,
v
,
scale
=
self
.
scale
)
out
=
out
.
transpose
(
1
,
2
)
out
=
out
.
view
(
B
,
N
,
-
1
)
out
,
_
=
self
.
proj
(
out
)
out
,
_
=
self
.
proj
(
out
)
return
out
return
out
...
...
vllm/model_executor/models/internvl.py
View file @
10398b47
...
@@ -482,6 +482,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -482,6 +482,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self
.
mlp1
=
self
.
_init_mlp1
(
config
)
self
.
mlp1
=
self
.
_init_mlp1
(
config
)
self
.
img_context_token_id
=
None
self
.
img_context_token_id
=
None
self
.
visual_token_mask
=
None
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
self
.
language_model
.
make_empty_intermediate_tensors
)
...
@@ -635,13 +636,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -635,13 +636,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
return
image_embeds
return
image_embeds
def
_
g
et_visual_token_mask
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
_
s
et_visual_token_mask
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
self
.
is_mono
:
if
self
.
is_mono
:
visual_token_mask
=
(
self
.
visual_token_mask
=
(
input_ids
==
self
.
img_context_token_id
).
reshape
(
-
1
,
1
)
input_ids
==
self
.
img_context_token_id
).
reshape
(
-
1
,
1
)
else
:
else
:
visual_token_mask
=
None
self
.
visual_token_mask
=
None
return
visual_token_mask
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Optional
[
NestedTensors
]:
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Optional
[
NestedTensors
]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
...
@@ -658,6 +658,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -658,6 +658,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
if
multimodal_embeddings
is
not
None
:
if
multimodal_embeddings
is
not
None
:
assert
self
.
img_context_token_id
is
not
None
assert
self
.
img_context_token_id
is
not
None
self
.
_set_visual_token_mask
(
input_ids
)
inputs_embeds
=
merge_multimodal_embeddings
(
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
multimodal_embeddings
,
input_ids
,
inputs_embeds
,
multimodal_embeddings
,
self
.
img_context_token_id
)
self
.
img_context_token_id
)
...
@@ -674,7 +675,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -674,7 +675,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
**
kwargs
:
object
,
**
kwargs
:
object
,
)
->
Union
[
SamplerOutput
,
IntermediateTensors
]:
)
->
Union
[
SamplerOutput
,
IntermediateTensors
]:
visual_token_mask
=
None
if
intermediate_tensors
is
not
None
:
if
intermediate_tensors
is
not
None
:
input_ids
=
None
input_ids
=
None
inputs_embeds
=
None
inputs_embeds
=
None
...
@@ -695,16 +695,15 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -695,16 +695,15 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
"intermediate_tensors"
:
intermediate_tensors
,
"intermediate_tensors"
:
intermediate_tensors
,
"inputs_embeds"
:
inputs_embeds
,
"inputs_embeds"
:
inputs_embeds
,
}
}
if
self
.
img_context_token_id
is
not
None
:
visual_token_mask
=
self
.
_get_visual_token_mask
(
input_ids
)
# We always overwrite it back to None after computing visual token
if
self
.
visual_token_mask
is
not
None
:
# mask so that this doesn't need to depend on encoder output
# overwrite visual_token_mask and img_context_token_id back to None,
# so that this doesn't need to depend on encoder output
forward_kwargs
.
update
(
{
"visual_token_mask"
:
self
.
visual_token_mask
})
self
.
visual_token_mask
=
None
self
.
img_context_token_id
=
None
self
.
img_context_token_id
=
None
if
self
.
is_mono
:
forward_kwargs
.
update
({
"visual_token_mask"
:
visual_token_mask
})
hidden_states
=
self
.
language_model
.
model
(
**
forward_kwargs
)
hidden_states
=
self
.
language_model
.
model
(
**
forward_kwargs
)
return
hidden_states
return
hidden_states
...
...
vllm/model_executor/models/molmo.py
View file @
10398b47
...
@@ -13,6 +13,7 @@ from torch.nn import functional as F
...
@@ -13,6 +13,7 @@ from torch.nn import functional as F
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention.layer
import
MultiHeadAttention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
...
@@ -38,14 +39,12 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...
@@ -38,14 +39,12 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal.inputs
import
NestedTensors
from
vllm.multimodal.inputs
import
NestedTensors
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.platforms
import
_Backend
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
SequenceData
)
SequenceData
)
from
vllm.transformers_utils.processor
import
get_processor
from
vllm.transformers_utils.processor
import
get_processor
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
get_vit_attn_backend
,
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
is_pp_missing_parameter
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
maybe_prefix
)
...
@@ -188,13 +187,11 @@ class MultiHeadDotProductAttention(nn.Module):
...
@@ -188,13 +187,11 @@ class MultiHeadDotProductAttention(nn.Module):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
)
)
# Detect attention implementation.
self
.
scale
=
self
.
head_dim
**-
0.5
self
.
attn_backend
:
_Backend
=
get_vit_attn_backend
(
support_fa
=
True
)
self
.
attn
=
MultiHeadAttention
(
self
.
num_heads
,
if
self
.
attn_backend
not
in
{
self
.
head_dim
,
_Backend
.
FLASH_ATTN
,
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
self
.
scale
,
}:
num_kv_heads
=
self
.
num_kv_heads
)
raise
RuntimeError
(
f
"Molmo does not support
{
self
.
attn_backend
}
backend now."
)
def
forward
(
self
,
def
forward
(
self
,
inputs_q
:
torch
.
Tensor
,
inputs_q
:
torch
.
Tensor
,
...
@@ -210,25 +207,8 @@ class MultiHeadDotProductAttention(nn.Module):
...
@@ -210,25 +207,8 @@ class MultiHeadDotProductAttention(nn.Module):
xq
,
_
=
self
.
wq
(
inputs_q
)
xq
,
_
=
self
.
wq
(
inputs_q
)
xk
,
_
=
self
.
wk
(
inputs_k
)
xk
,
_
=
self
.
wk
(
inputs_k
)
xv
,
_
=
self
.
wv
(
inputs_v
)
xv
,
_
=
self
.
wv
(
inputs_v
)
q_shape
=
xq
.
size
()[:
-
1
]
+
(
self
.
num_heads
,
self
.
head_dim
)
kv_shape
=
xk
.
size
()[:
-
1
]
+
(
self
.
num_kv_heads
,
self
.
head_dim
)
output
=
self
.
attn
(
xq
,
xk
,
xv
)
xq
=
xq
.
view
(
*
q_shape
)
xk
=
xk
.
view
(
*
kv_shape
)
xv
=
xv
.
view
(
*
kv_shape
)
if
self
.
attn_backend
==
_Backend
.
FLASH_ATTN
:
from
flash_attn
import
flash_attn_func
output
=
flash_attn_func
(
xq
,
xk
,
xv
,
dropout_p
=
0.0
,
causal
=
False
)
elif
self
.
attn_backend
==
_Backend
.
TORCH_SDPA
:
xq
,
xk
,
xv
=
(
rearrange
(
x
,
"b s h d -> b h s d"
)
for
x
in
(
xq
,
xk
,
xv
))
output
=
F
.
scaled_dot_product_attention
(
xq
,
xk
,
xv
)
output
=
rearrange
(
output
,
"b h s d -> b s h d "
)
elif
self
.
attn_backend
==
_Backend
.
XFORMERS
:
from
xformers
import
ops
as
xops
output
=
xops
.
memory_efficient_attention_forward
(
xq
,
xk
,
xv
,
p
=
0
)
output
=
rearrange
(
output
,
"b s h d -> b s (h d)"
).
contiguous
()
output
,
_
=
self
.
wo
(
output
)
output
,
_
=
self
.
wo
(
output
)
return
output
return
output
...
...
vllm/model_executor/models/siglip.py
View file @
10398b47
...
@@ -6,12 +6,11 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
...
@@ -6,12 +6,11 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
torch.nn.functional
as
F
from
PIL
import
Image
from
PIL
import
Image
from
torch
import
nn
from
torch
import
nn
from
transformers
import
SiglipVisionConfig
from
transformers
import
SiglipVisionConfig
from
vllm.attention.
selecto
r
import
_Backend
from
vllm.attention.
laye
r
import
MultiHeadAttention
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.inputs
import
DecoderOnlyInputs
,
token_inputs
from
vllm.inputs
import
DecoderOnlyInputs
,
token_inputs
...
@@ -29,8 +28,6 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
...
@@ -29,8 +28,6 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
resolve_visual_encoder_outputs
)
resolve_visual_encoder_outputs
)
from
vllm.sequence
import
SequenceData
from
vllm.sequence
import
SequenceData
from
.utils
import
get_vit_attn_backend
def
get_siglip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
def
get_siglip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
# Since interpolation is applied, the image size need not be divisible
# Since interpolation is applied, the image size need not be divisible
...
@@ -291,52 +288,18 @@ class SiglipAttention(nn.Module):
...
@@ -291,52 +288,18 @@ class SiglipAttention(nn.Module):
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
num_heads_per_partition
=
divide
(
self
.
num_heads
,
self
.
tp_size
)
self
.
num_heads_per_partition
=
divide
(
self
.
num_heads
,
self
.
tp_size
)
self
.
attn_backend
=
get_vit_attn_backend
(
support_fa
=
False
)
self
.
attn
=
MultiHeadAttention
(
self
.
num_heads_per_partition
,
if
self
.
attn_backend
not
in
{
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
}:
self
.
head_dim
,
self
.
scale
)
raise
RuntimeError
(
f
"SIGLIP does not support
{
self
.
attn_backend
}
backend now."
)
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Input shape: Batch x Time x Channel"""
"""Input shape: Batch x Time x Channel"""
batch_size
,
q_len
,
_
=
hidden_states
.
size
()
qkv_states
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv_states
,
_
=
self
.
qkv_proj
(
hidden_states
)
query_states
,
key_states
,
value_states
=
qkv_states
.
chunk
(
3
,
dim
=-
1
)
query_states
,
key_states
,
value_states
=
qkv_states
.
chunk
(
3
,
dim
=-
1
)
query_states
=
query_states
.
view
(
batch_size
,
q_len
,
out
=
self
.
attn
(
query_states
,
key_states
,
value_states
)
self
.
num_heads_per_partition
,
self
.
head_dim
)
key_states
=
key_states
.
view
(
batch_size
,
q_len
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
value_states
=
value_states
.
view
(
batch_size
,
q_len
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
if
self
.
attn_backend
==
_Backend
.
XFORMERS
:
from
xformers
import
ops
as
xops
out
=
xops
.
memory_efficient_attention_forward
(
query_states
,
key_states
,
value_states
,
p
=
self
.
dropout
,
scale
=
self
.
scale
)
elif
self
.
attn_backend
==
_Backend
.
TORCH_SDPA
:
query_states
,
key_states
,
value_states
=
(
x
.
transpose
(
1
,
2
)
for
x
in
(
query_states
,
key_states
,
value_states
))
out
=
F
.
scaled_dot_product_attention
(
query_states
,
key_states
,
value_states
,
dropout_p
=
self
.
dropout
,
scale
=
self
.
scale
)
out
=
out
.
transpose
(
1
,
2
)
out
=
out
.
view
(
batch_size
,
q_len
,
-
1
)
attn_output
,
_
=
self
.
out_proj
(
out
)
attn_output
,
_
=
self
.
out_proj
(
out
)
return
attn_output
,
None
return
attn_output
,
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment