Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e7ebb662
Unverified
Commit
e7ebb662
authored
Nov 18, 2024
by
Isotr0py
Committed by
GitHub
Nov 18, 2024
Browse files
[Model] Remove transformers attention porting in VITs (#10414)
Signed-off-by:
Isotr0py
<
2037008807@qq.com
>
parent
5be4e52b
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
139 additions
and
102 deletions
+139
-102
vllm/model_executor/models/blip.py
vllm/model_executor/models/blip.py
+36
-30
vllm/model_executor/models/clip.py
vllm/model_executor/models/clip.py
+36
-29
vllm/model_executor/models/intern_vit.py
vllm/model_executor/models/intern_vit.py
+22
-10
vllm/model_executor/models/molmo.py
vllm/model_executor/models/molmo.py
+1
-1
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+1
-1
vllm/model_executor/models/siglip.py
vllm/model_executor/models/siglip.py
+35
-28
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+8
-3
No files found.
vllm/model_executor/models/blip.py
View file @
e7ebb662
...
@@ -4,10 +4,11 @@ from typing import Iterable, Optional, Set, Tuple, Union
...
@@ -4,10 +4,11 @@ from typing import Iterable, Optional, Set, Tuple, Union
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
PIL
import
Image
from
PIL
import
Image
from
transformers
import
Blip2VisionConfig
,
BlipVisionConfig
from
transformers
import
Blip2VisionConfig
,
BlipVisionConfig
from
transformers.models.blip.modeling_blip
import
BlipAttention
from
vllm.attention.selector
import
_Backend
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.inputs
import
DecoderOnlyInputs
,
token_inputs
from
vllm.inputs
import
DecoderOnlyInputs
,
token_inputs
...
@@ -21,11 +22,7 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
...
@@ -21,11 +22,7 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens
)
repeat_and_pad_placeholder_tokens
)
from
vllm.sequence
import
SequenceData
from
vllm.sequence
import
SequenceData
try
:
from
.utils
import
get_vit_attn_backend
from
xformers
import
ops
as
xops
USE_XFORMERS_OPS
=
True
except
ImportError
:
USE_XFORMERS_OPS
=
False
def
get_blip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
def
get_blip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
...
@@ -168,7 +165,7 @@ class BlipVisionEmbeddings(nn.Module):
...
@@ -168,7 +165,7 @@ class BlipVisionEmbeddings(nn.Module):
return
embeddings
return
embeddings
class
Blip
Parallel
Attention
(
nn
.
Module
):
class
BlipAttention
(
nn
.
Module
):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def
__init__
(
def
__init__
(
...
@@ -208,6 +205,12 @@ class BlipParallelAttention(nn.Module):
...
@@ -208,6 +205,12 @@ class BlipParallelAttention(nn.Module):
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
num_heads_per_partition
=
divide
(
self
.
num_heads
,
self
.
tp_size
)
self
.
num_heads_per_partition
=
divide
(
self
.
num_heads
,
self
.
tp_size
)
# Detect attention implementation.
self
.
attn_backend
=
get_vit_attn_backend
(
support_fa
=
False
)
if
self
.
attn_backend
not
in
{
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
}:
raise
RuntimeError
(
f
"BLIP does not support
{
self
.
attn_backend
}
backend now."
)
def
_shape
(
self
,
tensor
:
torch
.
Tensor
,
seq_len
:
int
,
bsz
:
int
):
def
_shape
(
self
,
tensor
:
torch
.
Tensor
,
seq_len
:
int
,
bsz
:
int
):
return
tensor
.
view
(
bsz
,
seq_len
,
self
.
num_heads
,
return
tensor
.
view
(
bsz
,
seq_len
,
self
.
num_heads
,
self
.
head_dim
).
transpose
(
1
,
2
).
contiguous
()
self
.
head_dim
).
transpose
(
1
,
2
).
contiguous
()
...
@@ -231,11 +234,26 @@ class BlipParallelAttention(nn.Module):
...
@@ -231,11 +234,26 @@ class BlipParallelAttention(nn.Module):
self
.
num_heads_per_partition
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
self
.
head_dim
)
out
=
xops
.
memory_efficient_attention_forward
(
query_states
,
if
self
.
attn_backend
==
_Backend
.
XFORMERS
:
key_states
,
from
xformers
import
ops
as
xops
value_states
,
p
=
self
.
dropout
,
out
=
xops
.
memory_efficient_attention_forward
(
query_states
,
scale
=
self
.
scale
)
key_states
,
value_states
,
p
=
self
.
dropout
,
scale
=
self
.
scale
)
elif
self
.
attn_backend
==
_Backend
.
TORCH_SDPA
:
query_states
,
key_states
,
value_states
=
(
x
.
transpose
(
1
,
2
)
for
x
in
(
query_states
,
key_states
,
value_states
))
out
=
F
.
scaled_dot_product_attention
(
query_states
,
key_states
,
value_states
,
dropout_p
=
self
.
dropout
,
scale
=
self
.
scale
)
out
=
out
.
transpose
(
1
,
2
)
out
=
out
.
view
(
bsz
,
tgt_len
,
-
1
)
out
=
out
.
view
(
bsz
,
tgt_len
,
-
1
)
attn_output
,
_
=
self
.
projection
(
out
)
attn_output
,
_
=
self
.
projection
(
out
)
...
@@ -285,18 +303,11 @@ class BlipEncoderLayer(nn.Module):
...
@@ -285,18 +303,11 @@ class BlipEncoderLayer(nn.Module):
super
().
__init__
()
super
().
__init__
()
# fallback to sdpa attention if tp unavailable
# fallback to sdpa attention if tp unavailable
num_heads
=
config
.
num_attention_heads
self
.
self_attn
=
BlipAttention
(
tp_size
=
get_tensor_model_parallel_world_size
()
config
,
if
USE_XFORMERS_OPS
and
num_heads
%
tp_size
==
0
:
quant_config
=
quant_config
,
self
.
self_attn
=
BlipParallelAttention
(
prefix
=
f
"
{
prefix
}
.self_attn"
,
config
,
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
)
else
:
# Blip doesn't have SDPA attention implemented in transformers
# use eager attention instead for cpu backend
self
.
self_attn
=
BlipAttention
(
config
)
self
.
layer_norm1
=
nn
.
LayerNorm
(
config
.
hidden_size
,
self
.
layer_norm1
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_eps
)
eps
=
config
.
layer_norm_eps
)
self
.
mlp
=
BlipMLP
(
config
,
self
.
mlp
=
BlipMLP
(
config
,
...
@@ -374,11 +385,6 @@ class BlipVisionModel(nn.Module):
...
@@ -374,11 +385,6 @@ class BlipVisionModel(nn.Module):
prefix
:
str
=
""
,
prefix
:
str
=
""
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
tp_size
=
get_tensor_model_parallel_world_size
()
num_heads
=
config
.
num_attention_heads
self
.
shard_weight
=
USE_XFORMERS_OPS
and
num_heads
%
tp_size
==
0
self
.
config
=
config
self
.
config
=
config
self
.
embeddings
=
BlipVisionEmbeddings
(
config
)
self
.
embeddings
=
BlipVisionEmbeddings
(
config
)
...
@@ -422,7 +428,7 @@ class BlipVisionModel(nn.Module):
...
@@ -422,7 +428,7 @@ class BlipVisionModel(nn.Module):
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
]
if
self
.
shard_weight
else
[]
]
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
loaded_params
:
Set
[
str
]
=
set
()
layer_count
=
len
(
self
.
encoder
.
layers
)
layer_count
=
len
(
self
.
encoder
.
layers
)
...
...
vllm/model_executor/models/clip.py
View file @
e7ebb662
...
@@ -5,10 +5,11 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
...
@@ -5,10 +5,11 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
PIL
import
Image
from
PIL
import
Image
from
transformers
import
CLIPVisionConfig
from
transformers
import
CLIPVisionConfig
from
transformers.models.clip.modeling_clip
import
CLIPSdpaAttention
from
vllm.attention.selector
import
_Backend
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.inputs
import
DecoderOnlyInputs
,
token_inputs
from
vllm.inputs
import
DecoderOnlyInputs
,
token_inputs
...
@@ -23,11 +24,7 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
...
@@ -23,11 +24,7 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens
)
repeat_and_pad_placeholder_tokens
)
from
vllm.sequence
import
SequenceData
from
vllm.sequence
import
SequenceData
try
:
from
.utils
import
get_vit_attn_backend
from
xformers
import
ops
as
xops
USE_XFORMERS_OPS
=
True
except
ImportError
:
USE_XFORMERS_OPS
=
False
def
get_clip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
def
get_clip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
...
@@ -197,7 +194,7 @@ class CLIPVisionEmbeddings(nn.Module):
...
@@ -197,7 +194,7 @@ class CLIPVisionEmbeddings(nn.Module):
return
embeddings
return
embeddings
class
CLIP
Parallel
Attention
(
nn
.
Module
):
class
CLIPAttention
(
nn
.
Module
):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def
__init__
(
def
__init__
(
...
@@ -237,6 +234,12 @@ class CLIPParallelAttention(nn.Module):
...
@@ -237,6 +234,12 @@ class CLIPParallelAttention(nn.Module):
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
num_heads_per_partition
=
divide
(
self
.
num_heads
,
self
.
tp_size
)
self
.
num_heads_per_partition
=
divide
(
self
.
num_heads
,
self
.
tp_size
)
# Detect attention implementation.
self
.
attn_backend
=
get_vit_attn_backend
(
support_fa
=
False
)
if
self
.
attn_backend
not
in
{
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
}:
raise
RuntimeError
(
f
"CLIP does not support
{
self
.
attn_backend
}
backend now."
)
def
_shape
(
self
,
tensor
:
torch
.
Tensor
,
seq_len
:
int
,
bsz
:
int
):
def
_shape
(
self
,
tensor
:
torch
.
Tensor
,
seq_len
:
int
,
bsz
:
int
):
return
tensor
.
view
(
bsz
,
seq_len
,
self
.
num_heads
,
return
tensor
.
view
(
bsz
,
seq_len
,
self
.
num_heads
,
self
.
head_dim
).
transpose
(
1
,
2
).
contiguous
()
self
.
head_dim
).
transpose
(
1
,
2
).
contiguous
()
...
@@ -261,11 +264,26 @@ class CLIPParallelAttention(nn.Module):
...
@@ -261,11 +264,26 @@ class CLIPParallelAttention(nn.Module):
self
.
num_heads_per_partition
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
self
.
head_dim
)
out
=
xops
.
memory_efficient_attention_forward
(
query_states
,
if
self
.
attn_backend
==
_Backend
.
XFORMERS
:
key_states
,
from
xformers
import
ops
as
xops
value_states
,
p
=
self
.
dropout
,
out
=
xops
.
memory_efficient_attention_forward
(
query_states
,
scale
=
self
.
scale
)
key_states
,
value_states
,
p
=
self
.
dropout
,
scale
=
self
.
scale
)
elif
self
.
attn_backend
==
_Backend
.
TORCH_SDPA
:
query_states
,
key_states
,
value_states
=
(
x
.
transpose
(
1
,
2
)
for
x
in
(
query_states
,
key_states
,
value_states
))
out
=
F
.
scaled_dot_product_attention
(
query_states
,
key_states
,
value_states
,
dropout_p
=
self
.
dropout
,
scale
=
self
.
scale
)
out
=
out
.
transpose
(
1
,
2
)
out
=
out
.
view
(
bsz
,
tgt_len
,
-
1
)
out
=
out
.
view
(
bsz
,
tgt_len
,
-
1
)
attn_output
,
_
=
self
.
out_proj
(
out
)
attn_output
,
_
=
self
.
out_proj
(
out
)
...
@@ -311,17 +329,11 @@ class CLIPEncoderLayer(nn.Module):
...
@@ -311,17 +329,11 @@ class CLIPEncoderLayer(nn.Module):
prefix
:
str
=
""
,
prefix
:
str
=
""
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
self_attn
=
CLIPAttention
(
num_heads
=
config
.
num_attention_heads
config
,
tp_size
=
get_tensor_model_parallel_world_size
()
quant_config
=
quant_config
,
if
USE_XFORMERS_OPS
and
num_heads
%
tp_size
==
0
:
prefix
=
f
"
{
prefix
}
.self_attn"
,
self
.
self_attn
=
CLIPParallelAttention
(
)
config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
)
else
:
self
.
self_attn
=
CLIPSdpaAttention
(
config
)
self
.
layer_norm1
=
nn
.
LayerNorm
(
config
.
hidden_size
,
self
.
layer_norm1
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_eps
)
eps
=
config
.
layer_norm_eps
)
self
.
mlp
=
CLIPMLP
(
config
,
self
.
mlp
=
CLIPMLP
(
config
,
...
@@ -461,11 +473,6 @@ class CLIPVisionModel(nn.Module):
...
@@ -461,11 +473,6 @@ class CLIPVisionModel(nn.Module):
prefix
:
str
=
""
,
prefix
:
str
=
""
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
tp_size
=
get_tensor_model_parallel_world_size
()
num_heads
=
config
.
num_attention_heads
self
.
shard_weight
=
USE_XFORMERS_OPS
and
num_heads
%
tp_size
==
0
self
.
vision_model
=
CLIPVisionTransformer
(
self
.
vision_model
=
CLIPVisionTransformer
(
config
=
config
,
config
=
config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
...
@@ -490,7 +497,7 @@ class CLIPVisionModel(nn.Module):
...
@@ -490,7 +497,7 @@ class CLIPVisionModel(nn.Module):
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
]
if
self
.
shard_weight
else
[]
]
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
loaded_params
:
Set
[
str
]
=
set
()
layer_count
=
len
(
self
.
vision_model
.
encoder
.
layers
)
layer_count
=
len
(
self
.
vision_model
.
encoder
.
layers
)
...
...
vllm/model_executor/models/intern_vit.py
View file @
e7ebb662
...
@@ -12,6 +12,7 @@ import torch.nn as nn
...
@@ -12,6 +12,7 @@ import torch.nn as nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.attention.selector
import
_Backend
from
vllm.distributed
import
(
divide
,
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
divide
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
split_tensor_along_last_dim
,
split_tensor_along_last_dim
,
...
@@ -24,11 +25,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -24,11 +25,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
try
:
from
.utils
import
get_vit_attn_backend
from
xformers
import
ops
as
xops
USE_XFORMERS_OPS
=
True
except
ImportError
:
USE_XFORMERS_OPS
=
False
NORM2FN
=
{
NORM2FN
=
{
'rms_norm'
:
RMSNorm
,
'rms_norm'
:
RMSNorm
,
...
@@ -186,6 +183,11 @@ class InternParallelAttention(nn.Module):
...
@@ -186,6 +183,11 @@ class InternParallelAttention(nn.Module):
prefix
=
f
"
{
prefix
}
.proj"
,
prefix
=
f
"
{
prefix
}
.proj"
,
)
)
self
.
attn_backend
=
get_vit_attn_backend
(
support_fa
=
False
)
if
self
.
attn_backend
not
in
{
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
}:
raise
RuntimeError
(
f
"InternViT does not support
{
self
.
attn_backend
}
backend now."
)
def
_apply_qk_norm
(
self
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
):
def
_apply_qk_norm
(
self
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
):
if
self
.
tp_size
>
1
:
if
self
.
tp_size
>
1
:
q
=
tensor_model_parallel_all_gather
(
q
.
contiguous
())
q
=
tensor_model_parallel_all_gather
(
q
.
contiguous
())
...
@@ -211,11 +213,21 @@ class InternParallelAttention(nn.Module):
...
@@ -211,11 +213,21 @@ class InternParallelAttention(nn.Module):
k
=
k
.
view
(
B
,
N
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
k
=
k
.
view
(
B
,
N
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
v
=
v
.
view
(
B
,
N
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
v
=
v
.
view
(
B
,
N
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
x
=
xops
.
memory_efficient_attention_forward
(
q
,
k
,
v
,
scale
=
self
.
scale
)
if
self
.
attn_backend
==
_Backend
.
XFORMERS
:
x
=
x
.
view
(
B
,
N
,
-
1
)
from
xformers
import
ops
as
xops
x
,
_
=
self
.
proj
(
x
)
out
=
xops
.
memory_efficient_attention_forward
(
q
,
return
x
k
,
v
,
scale
=
self
.
scale
)
elif
self
.
attn_backend
==
_Backend
.
TORCH_SDPA
:
q
,
k
,
v
=
(
x
.
transpose
(
1
,
2
)
for
x
in
(
q
,
k
,
v
))
out
=
F
.
scaled_dot_product_attention
(
q
,
k
,
v
,
scale
=
self
.
scale
)
out
=
out
.
transpose
(
1
,
2
)
out
=
out
.
view
(
B
,
N
,
-
1
)
out
,
_
=
self
.
proj
(
out
)
return
out
class
InternSdpaAttention
(
nn
.
Module
):
class
InternSdpaAttention
(
nn
.
Module
):
...
@@ -362,7 +374,7 @@ class InternVisionEncoderLayer(nn.Module):
...
@@ -362,7 +374,7 @@ class InternVisionEncoderLayer(nn.Module):
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
num_heads
=
config
.
num_attention_heads
num_heads
=
config
.
num_attention_heads
if
USE_XFORMERS_OPS
and
(
num_heads
+
num_dummy_heads
)
%
tp_size
==
0
:
if
(
num_heads
+
num_dummy_heads
)
%
tp_size
==
0
:
return
InternParallelAttention
(
config
,
return
InternParallelAttention
(
config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
num_dummy_heads
=
num_dummy_heads
,
num_dummy_heads
=
num_dummy_heads
,
...
...
vllm/model_executor/models/molmo.py
View file @
e7ebb662
...
@@ -187,7 +187,7 @@ class MultiHeadDotProductAttention(nn.Module):
...
@@ -187,7 +187,7 @@ class MultiHeadDotProductAttention(nn.Module):
)
)
# Detect attention implementation.
# Detect attention implementation.
self
.
attn_backend
:
_Backend
=
get_vit_attn_backend
()
self
.
attn_backend
:
_Backend
=
get_vit_attn_backend
(
support_fa
=
True
)
if
self
.
attn_backend
not
in
{
if
self
.
attn_backend
not
in
{
_Backend
.
FLASH_ATTN
,
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
_Backend
.
FLASH_ATTN
,
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
}:
}:
...
...
vllm/model_executor/models/qwen2_vl.py
View file @
e7ebb662
...
@@ -260,7 +260,7 @@ class Qwen2VisionAttention(nn.Module):
...
@@ -260,7 +260,7 @@ class Qwen2VisionAttention(nn.Module):
prefix
=
f
"
{
prefix
}
.proj"
)
prefix
=
f
"
{
prefix
}
.proj"
)
# Detect attention implementation.
# Detect attention implementation.
self
.
attn_backend
:
_Backend
=
get_vit_attn_backend
()
self
.
attn_backend
:
_Backend
=
get_vit_attn_backend
(
support_fa
=
True
)
if
self
.
attn_backend
not
in
{
if
self
.
attn_backend
not
in
{
_Backend
.
FLASH_ATTN
,
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
_Backend
.
FLASH_ATTN
,
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
}:
}:
...
...
vllm/model_executor/models/siglip.py
View file @
e7ebb662
...
@@ -6,11 +6,12 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
...
@@ -6,11 +6,12 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
torch.nn.functional
as
F
from
PIL
import
Image
from
PIL
import
Image
from
torch
import
nn
from
torch
import
nn
from
transformers
import
SiglipVisionConfig
from
transformers
import
SiglipVisionConfig
from
transformers.models.siglip.modeling_siglip
import
SiglipSdpaAttention
from
vllm.attention.selector
import
_Backend
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.inputs
import
DecoderOnlyInputs
,
token_inputs
from
vllm.inputs
import
DecoderOnlyInputs
,
token_inputs
...
@@ -27,11 +28,7 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
...
@@ -27,11 +28,7 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens
)
repeat_and_pad_placeholder_tokens
)
from
vllm.sequence
import
SequenceData
from
vllm.sequence
import
SequenceData
try
:
from
.utils
import
get_vit_attn_backend
from
xformers
import
ops
as
xops
USE_XFORMERS_OPS
=
True
except
ImportError
:
USE_XFORMERS_OPS
=
False
def
get_siglip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
def
get_siglip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
...
@@ -254,7 +251,7 @@ class SiglipVisionEmbeddings(nn.Module):
...
@@ -254,7 +251,7 @@ class SiglipVisionEmbeddings(nn.Module):
return
embeddings
return
embeddings
class
Siglip
Parallel
Attention
(
nn
.
Module
):
class
SiglipAttention
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -293,6 +290,11 @@ class SiglipParallelAttention(nn.Module):
...
@@ -293,6 +290,11 @@ class SiglipParallelAttention(nn.Module):
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
num_heads_per_partition
=
divide
(
self
.
num_heads
,
self
.
tp_size
)
self
.
num_heads_per_partition
=
divide
(
self
.
num_heads
,
self
.
tp_size
)
self
.
attn_backend
=
get_vit_attn_backend
(
support_fa
=
False
)
if
self
.
attn_backend
not
in
{
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
}:
raise
RuntimeError
(
f
"SIGLIP does not support
{
self
.
attn_backend
}
backend now."
)
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
...
@@ -313,11 +315,26 @@ class SiglipParallelAttention(nn.Module):
...
@@ -313,11 +315,26 @@ class SiglipParallelAttention(nn.Module):
self
.
num_heads_per_partition
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
self
.
head_dim
)
out
=
xops
.
memory_efficient_attention_forward
(
query_states
,
if
self
.
attn_backend
==
_Backend
.
XFORMERS
:
key_states
,
from
xformers
import
ops
as
xops
value_states
,
p
=
self
.
dropout
,
out
=
xops
.
memory_efficient_attention_forward
(
query_states
,
scale
=
self
.
scale
)
key_states
,
value_states
,
p
=
self
.
dropout
,
scale
=
self
.
scale
)
elif
self
.
attn_backend
==
_Backend
.
TORCH_SDPA
:
query_states
,
key_states
,
value_states
=
(
x
.
transpose
(
1
,
2
)
for
x
in
(
query_states
,
key_states
,
value_states
))
out
=
F
.
scaled_dot_product_attention
(
query_states
,
key_states
,
value_states
,
dropout_p
=
self
.
dropout
,
scale
=
self
.
scale
)
out
=
out
.
transpose
(
1
,
2
)
out
=
out
.
view
(
batch_size
,
q_len
,
-
1
)
out
=
out
.
view
(
batch_size
,
q_len
,
-
1
)
attn_output
,
_
=
self
.
out_proj
(
out
)
attn_output
,
_
=
self
.
out_proj
(
out
)
...
@@ -372,17 +389,11 @@ class SiglipEncoderLayer(nn.Module):
...
@@ -372,17 +389,11 @@ class SiglipEncoderLayer(nn.Module):
self
.
embed_dim
=
config
.
hidden_size
self
.
embed_dim
=
config
.
hidden_size
num_heads
=
config
.
num_attention_heads
self
.
self_attn
=
SiglipAttention
(
tp_size
=
get_tensor_model_parallel_world_size
()
config
,
if
USE_XFORMERS_OPS
and
num_heads
%
tp_size
==
0
:
quant_config
=
quant_config
,
self
.
self_attn
=
SiglipParallelAttention
(
prefix
=
f
"
{
prefix
}
.self_attn"
,
config
,
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
)
else
:
self
.
self_attn
=
SiglipSdpaAttention
(
config
)
self
.
layer_norm1
=
nn
.
LayerNorm
(
self
.
embed_dim
,
self
.
layer_norm1
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_eps
)
eps
=
config
.
layer_norm_eps
)
self
.
mlp
=
SiglipMLP
(
self
.
mlp
=
SiglipMLP
(
...
@@ -569,10 +580,6 @@ class SiglipVisionModel(nn.Module):
...
@@ -569,10 +580,6 @@ class SiglipVisionModel(nn.Module):
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
num_heads
=
config
.
num_attention_heads
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
shard_weight
=
USE_XFORMERS_OPS
and
num_heads
%
tp_size
==
0
self
.
vision_model
=
SiglipVisionTransformer
(
self
.
vision_model
=
SiglipVisionTransformer
(
config
,
config
,
quant_config
,
quant_config
,
...
@@ -601,7 +608,7 @@ class SiglipVisionModel(nn.Module):
...
@@ -601,7 +608,7 @@ class SiglipVisionModel(nn.Module):
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
]
if
self
.
shard_weight
else
[]
]
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
loaded_params
:
Set
[
str
]
=
set
()
layer_count
=
len
(
self
.
vision_model
.
encoder
.
layers
)
layer_count
=
len
(
self
.
vision_model
.
encoder
.
layers
)
...
...
vllm/model_executor/models/utils.py
View file @
e7ebb662
...
@@ -587,7 +587,11 @@ class LLMWrapper(nn.Module):
...
@@ -587,7 +587,11 @@ class LLMWrapper(nn.Module):
return
llm
(
*
args
,
**
kwargs
)
return
llm
(
*
args
,
**
kwargs
)
def
get_vit_attn_backend
()
->
_Backend
:
def
get_vit_attn_backend
(
support_fa
:
bool
=
False
)
->
_Backend
:
"""
Get the available attention backend for Vision Transformer.
"""
# TODO(Isotr0py): Remove `support_fa` after support FA for all ViTs attn.
selected_backend
:
Optional
[
_Backend
]
=
get_global_forced_attn_backend
()
selected_backend
:
Optional
[
_Backend
]
=
get_global_forced_attn_backend
()
if
selected_backend
is
None
:
if
selected_backend
is
None
:
backend_by_env_var
:
Optional
[
str
]
=
envs
.
VLLM_ATTENTION_BACKEND
backend_by_env_var
:
Optional
[
str
]
=
envs
.
VLLM_ATTENTION_BACKEND
...
@@ -596,7 +600,7 @@ def get_vit_attn_backend() -> _Backend:
...
@@ -596,7 +600,7 @@ def get_vit_attn_backend() -> _Backend:
if
selected_backend
is
None
:
if
selected_backend
is
None
:
# For Volta and Turing GPUs, use xformers instead.
# For Volta and Turing GPUs, use xformers instead.
device_available
=
current_platform
.
has_device_capability
(
80
)
device_available
=
current_platform
.
has_device_capability
(
80
)
if
device_available
:
if
device_available
and
support_fa
:
from
transformers.utils
import
is_flash_attn_2_available
from
transformers.utils
import
is_flash_attn_2_available
if
is_flash_attn_2_available
():
if
is_flash_attn_2_available
():
selected_backend
=
_Backend
.
FLASH_ATTN
selected_backend
=
_Backend
.
FLASH_ATTN
...
@@ -606,7 +610,8 @@ def get_vit_attn_backend() -> _Backend:
...
@@ -606,7 +610,8 @@ def get_vit_attn_backend() -> _Backend:
"so we use xformers backend instead. You can run "
"so we use xformers backend instead. You can run "
"`pip install flash-attn` to use flash-attention backend."
)
"`pip install flash-attn` to use flash-attention backend."
)
selected_backend
=
_Backend
.
XFORMERS
selected_backend
=
_Backend
.
XFORMERS
elif
current_platform
.
is_cpu
():
elif
current_platform
.
is_cpu
()
or
current_platform
.
is_rocm
():
# ROCM doesn't support xformers
selected_backend
=
_Backend
.
TORCH_SDPA
selected_backend
=
_Backend
.
TORCH_SDPA
else
:
else
:
selected_backend
=
_Backend
.
XFORMERS
selected_backend
=
_Backend
.
XFORMERS
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment