Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e7ebb662
Unverified
Commit
e7ebb662
authored
Nov 18, 2024
by
Isotr0py
Committed by
GitHub
Nov 18, 2024
Browse files
[Model] Remove transformers attention porting in VITs (#10414)
Signed-off-by:
Isotr0py
<
2037008807@qq.com
>
parent
5be4e52b
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
139 additions
and
102 deletions
+139
-102
vllm/model_executor/models/blip.py
vllm/model_executor/models/blip.py
+36
-30
vllm/model_executor/models/clip.py
vllm/model_executor/models/clip.py
+36
-29
vllm/model_executor/models/intern_vit.py
vllm/model_executor/models/intern_vit.py
+22
-10
vllm/model_executor/models/molmo.py
vllm/model_executor/models/molmo.py
+1
-1
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+1
-1
vllm/model_executor/models/siglip.py
vllm/model_executor/models/siglip.py
+35
-28
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+8
-3
No files found.
vllm/model_executor/models/blip.py
View file @
e7ebb662
...
...
@@ -4,10 +4,11 @@ from typing import Iterable, Optional, Set, Tuple, Union
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
PIL
import
Image
from
transformers
import
Blip2VisionConfig
,
BlipVisionConfig
from
transformers.models.blip.modeling_blip
import
BlipAttention
from
vllm.attention.selector
import
_Backend
from
vllm.config
import
ModelConfig
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.inputs
import
DecoderOnlyInputs
,
token_inputs
...
...
@@ -21,11 +22,7 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens
)
from
vllm.sequence
import
SequenceData
try
:
from
xformers
import
ops
as
xops
USE_XFORMERS_OPS
=
True
except
ImportError
:
USE_XFORMERS_OPS
=
False
from
.utils
import
get_vit_attn_backend
def
get_blip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
...
...
@@ -168,7 +165,7 @@ class BlipVisionEmbeddings(nn.Module):
return
embeddings
class
Blip
Parallel
Attention
(
nn
.
Module
):
class
BlipAttention
(
nn
.
Module
):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def
__init__
(
...
...
@@ -208,6 +205,12 @@ class BlipParallelAttention(nn.Module):
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
num_heads_per_partition
=
divide
(
self
.
num_heads
,
self
.
tp_size
)
# Detect attention implementation.
self
.
attn_backend
=
get_vit_attn_backend
(
support_fa
=
False
)
if
self
.
attn_backend
not
in
{
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
}:
raise
RuntimeError
(
f
"BLIP does not support
{
self
.
attn_backend
}
backend now."
)
def
_shape
(
self
,
tensor
:
torch
.
Tensor
,
seq_len
:
int
,
bsz
:
int
):
return
tensor
.
view
(
bsz
,
seq_len
,
self
.
num_heads
,
self
.
head_dim
).
transpose
(
1
,
2
).
contiguous
()
...
...
@@ -231,11 +234,26 @@ class BlipParallelAttention(nn.Module):
self
.
num_heads_per_partition
,
self
.
head_dim
)
if
self
.
attn_backend
==
_Backend
.
XFORMERS
:
from
xformers
import
ops
as
xops
out
=
xops
.
memory_efficient_attention_forward
(
query_states
,
key_states
,
value_states
,
p
=
self
.
dropout
,
scale
=
self
.
scale
)
elif
self
.
attn_backend
==
_Backend
.
TORCH_SDPA
:
query_states
,
key_states
,
value_states
=
(
x
.
transpose
(
1
,
2
)
for
x
in
(
query_states
,
key_states
,
value_states
))
out
=
F
.
scaled_dot_product_attention
(
query_states
,
key_states
,
value_states
,
dropout_p
=
self
.
dropout
,
scale
=
self
.
scale
)
out
=
out
.
transpose
(
1
,
2
)
out
=
out
.
view
(
bsz
,
tgt_len
,
-
1
)
attn_output
,
_
=
self
.
projection
(
out
)
...
...
@@ -285,18 +303,11 @@ class BlipEncoderLayer(nn.Module):
super
().
__init__
()
# fallback to sdpa attention if tp unavailable
num_heads
=
config
.
num_attention_heads
tp_size
=
get_tensor_model_parallel_world_size
()
if
USE_XFORMERS_OPS
and
num_heads
%
tp_size
==
0
:
self
.
self_attn
=
BlipParallelAttention
(
self
.
self_attn
=
BlipAttention
(
config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
)
else
:
# Blip doesn't have SDPA attention implemented in transformers
# use eager attention instead for cpu backend
self
.
self_attn
=
BlipAttention
(
config
)
self
.
layer_norm1
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_eps
)
self
.
mlp
=
BlipMLP
(
config
,
...
...
@@ -374,11 +385,6 @@ class BlipVisionModel(nn.Module):
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
tp_size
=
get_tensor_model_parallel_world_size
()
num_heads
=
config
.
num_attention_heads
self
.
shard_weight
=
USE_XFORMERS_OPS
and
num_heads
%
tp_size
==
0
self
.
config
=
config
self
.
embeddings
=
BlipVisionEmbeddings
(
config
)
...
...
@@ -422,7 +428,7 @@ class BlipVisionModel(nn.Module):
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
]
if
self
.
shard_weight
else
[]
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
layer_count
=
len
(
self
.
encoder
.
layers
)
...
...
vllm/model_executor/models/clip.py
View file @
e7ebb662
...
...
@@ -5,10 +5,11 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
PIL
import
Image
from
transformers
import
CLIPVisionConfig
from
transformers.models.clip.modeling_clip
import
CLIPSdpaAttention
from
vllm.attention.selector
import
_Backend
from
vllm.config
import
ModelConfig
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.inputs
import
DecoderOnlyInputs
,
token_inputs
...
...
@@ -23,11 +24,7 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens
)
from
vllm.sequence
import
SequenceData
try
:
from
xformers
import
ops
as
xops
USE_XFORMERS_OPS
=
True
except
ImportError
:
USE_XFORMERS_OPS
=
False
from
.utils
import
get_vit_attn_backend
def
get_clip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
...
...
@@ -197,7 +194,7 @@ class CLIPVisionEmbeddings(nn.Module):
return
embeddings
class
CLIP
Parallel
Attention
(
nn
.
Module
):
class
CLIPAttention
(
nn
.
Module
):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def
__init__
(
...
...
@@ -237,6 +234,12 @@ class CLIPParallelAttention(nn.Module):
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
num_heads_per_partition
=
divide
(
self
.
num_heads
,
self
.
tp_size
)
# Detect attention implementation.
self
.
attn_backend
=
get_vit_attn_backend
(
support_fa
=
False
)
if
self
.
attn_backend
not
in
{
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
}:
raise
RuntimeError
(
f
"CLIP does not support
{
self
.
attn_backend
}
backend now."
)
def
_shape
(
self
,
tensor
:
torch
.
Tensor
,
seq_len
:
int
,
bsz
:
int
):
return
tensor
.
view
(
bsz
,
seq_len
,
self
.
num_heads
,
self
.
head_dim
).
transpose
(
1
,
2
).
contiguous
()
...
...
@@ -261,11 +264,26 @@ class CLIPParallelAttention(nn.Module):
self
.
num_heads_per_partition
,
self
.
head_dim
)
if
self
.
attn_backend
==
_Backend
.
XFORMERS
:
from
xformers
import
ops
as
xops
out
=
xops
.
memory_efficient_attention_forward
(
query_states
,
key_states
,
value_states
,
p
=
self
.
dropout
,
scale
=
self
.
scale
)
elif
self
.
attn_backend
==
_Backend
.
TORCH_SDPA
:
query_states
,
key_states
,
value_states
=
(
x
.
transpose
(
1
,
2
)
for
x
in
(
query_states
,
key_states
,
value_states
))
out
=
F
.
scaled_dot_product_attention
(
query_states
,
key_states
,
value_states
,
dropout_p
=
self
.
dropout
,
scale
=
self
.
scale
)
out
=
out
.
transpose
(
1
,
2
)
out
=
out
.
view
(
bsz
,
tgt_len
,
-
1
)
attn_output
,
_
=
self
.
out_proj
(
out
)
...
...
@@ -311,17 +329,11 @@ class CLIPEncoderLayer(nn.Module):
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
num_heads
=
config
.
num_attention_heads
tp_size
=
get_tensor_model_parallel_world_size
()
if
USE_XFORMERS_OPS
and
num_heads
%
tp_size
==
0
:
self
.
self_attn
=
CLIPParallelAttention
(
self
.
self_attn
=
CLIPAttention
(
config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
)
else
:
self
.
self_attn
=
CLIPSdpaAttention
(
config
)
self
.
layer_norm1
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_eps
)
self
.
mlp
=
CLIPMLP
(
config
,
...
...
@@ -461,11 +473,6 @@ class CLIPVisionModel(nn.Module):
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
tp_size
=
get_tensor_model_parallel_world_size
()
num_heads
=
config
.
num_attention_heads
self
.
shard_weight
=
USE_XFORMERS_OPS
and
num_heads
%
tp_size
==
0
self
.
vision_model
=
CLIPVisionTransformer
(
config
=
config
,
quant_config
=
quant_config
,
...
...
@@ -490,7 +497,7 @@ class CLIPVisionModel(nn.Module):
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
]
if
self
.
shard_weight
else
[]
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
layer_count
=
len
(
self
.
vision_model
.
encoder
.
layers
)
...
...
vllm/model_executor/models/intern_vit.py
View file @
e7ebb662
...
...
@@ -12,6 +12,7 @@ import torch.nn as nn
import
torch.nn.functional
as
F
from
transformers
import
PretrainedConfig
from
vllm.attention.selector
import
_Backend
from
vllm.distributed
import
(
divide
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
split_tensor_along_last_dim
,
...
...
@@ -24,11 +25,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
try
:
from
xformers
import
ops
as
xops
USE_XFORMERS_OPS
=
True
except
ImportError
:
USE_XFORMERS_OPS
=
False
from
.utils
import
get_vit_attn_backend
NORM2FN
=
{
'rms_norm'
:
RMSNorm
,
...
...
@@ -186,6 +183,11 @@ class InternParallelAttention(nn.Module):
prefix
=
f
"
{
prefix
}
.proj"
,
)
self
.
attn_backend
=
get_vit_attn_backend
(
support_fa
=
False
)
if
self
.
attn_backend
not
in
{
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
}:
raise
RuntimeError
(
f
"InternViT does not support
{
self
.
attn_backend
}
backend now."
)
def
_apply_qk_norm
(
self
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
):
if
self
.
tp_size
>
1
:
q
=
tensor_model_parallel_all_gather
(
q
.
contiguous
())
...
...
@@ -211,11 +213,21 @@ class InternParallelAttention(nn.Module):
k
=
k
.
view
(
B
,
N
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
v
=
v
.
view
(
B
,
N
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
x
=
xops
.
memory_efficient_attention_forward
(
q
,
k
,
v
,
scale
=
self
.
scale
)
x
=
x
.
view
(
B
,
N
,
-
1
)
if
self
.
attn_backend
==
_Backend
.
XFORMERS
:
from
xformers
import
ops
as
xops
x
,
_
=
self
.
proj
(
x
)
return
x
out
=
xops
.
memory_efficient_attention_forward
(
q
,
k
,
v
,
scale
=
self
.
scale
)
elif
self
.
attn_backend
==
_Backend
.
TORCH_SDPA
:
q
,
k
,
v
=
(
x
.
transpose
(
1
,
2
)
for
x
in
(
q
,
k
,
v
))
out
=
F
.
scaled_dot_product_attention
(
q
,
k
,
v
,
scale
=
self
.
scale
)
out
=
out
.
transpose
(
1
,
2
)
out
=
out
.
view
(
B
,
N
,
-
1
)
out
,
_
=
self
.
proj
(
out
)
return
out
class
InternSdpaAttention
(
nn
.
Module
):
...
...
@@ -362,7 +374,7 @@ class InternVisionEncoderLayer(nn.Module):
tp_size
=
get_tensor_model_parallel_world_size
()
num_heads
=
config
.
num_attention_heads
if
USE_XFORMERS_OPS
and
(
num_heads
+
num_dummy_heads
)
%
tp_size
==
0
:
if
(
num_heads
+
num_dummy_heads
)
%
tp_size
==
0
:
return
InternParallelAttention
(
config
,
quant_config
=
quant_config
,
num_dummy_heads
=
num_dummy_heads
,
...
...
vllm/model_executor/models/molmo.py
View file @
e7ebb662
...
...
@@ -187,7 +187,7 @@ class MultiHeadDotProductAttention(nn.Module):
)
# Detect attention implementation.
self
.
attn_backend
:
_Backend
=
get_vit_attn_backend
()
self
.
attn_backend
:
_Backend
=
get_vit_attn_backend
(
support_fa
=
True
)
if
self
.
attn_backend
not
in
{
_Backend
.
FLASH_ATTN
,
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
}:
...
...
vllm/model_executor/models/qwen2_vl.py
View file @
e7ebb662
...
...
@@ -260,7 +260,7 @@ class Qwen2VisionAttention(nn.Module):
prefix
=
f
"
{
prefix
}
.proj"
)
# Detect attention implementation.
self
.
attn_backend
:
_Backend
=
get_vit_attn_backend
()
self
.
attn_backend
:
_Backend
=
get_vit_attn_backend
(
support_fa
=
True
)
if
self
.
attn_backend
not
in
{
_Backend
.
FLASH_ATTN
,
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
}:
...
...
vllm/model_executor/models/siglip.py
View file @
e7ebb662
...
...
@@ -6,11 +6,12 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
import
numpy
as
np
import
torch
import
torch.nn.functional
as
F
from
PIL
import
Image
from
torch
import
nn
from
transformers
import
SiglipVisionConfig
from
transformers.models.siglip.modeling_siglip
import
SiglipSdpaAttention
from
vllm.attention.selector
import
_Backend
from
vllm.config
import
ModelConfig
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.inputs
import
DecoderOnlyInputs
,
token_inputs
...
...
@@ -27,11 +28,7 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens
)
from
vllm.sequence
import
SequenceData
try
:
from
xformers
import
ops
as
xops
USE_XFORMERS_OPS
=
True
except
ImportError
:
USE_XFORMERS_OPS
=
False
from
.utils
import
get_vit_attn_backend
def
get_siglip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
...
...
@@ -254,7 +251,7 @@ class SiglipVisionEmbeddings(nn.Module):
return
embeddings
class
Siglip
Parallel
Attention
(
nn
.
Module
):
class
SiglipAttention
(
nn
.
Module
):
def
__init__
(
self
,
...
...
@@ -293,6 +290,11 @@ class SiglipParallelAttention(nn.Module):
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
num_heads_per_partition
=
divide
(
self
.
num_heads
,
self
.
tp_size
)
self
.
attn_backend
=
get_vit_attn_backend
(
support_fa
=
False
)
if
self
.
attn_backend
not
in
{
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
}:
raise
RuntimeError
(
f
"SIGLIP does not support
{
self
.
attn_backend
}
backend now."
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
@@ -313,11 +315,26 @@ class SiglipParallelAttention(nn.Module):
self
.
num_heads_per_partition
,
self
.
head_dim
)
if
self
.
attn_backend
==
_Backend
.
XFORMERS
:
from
xformers
import
ops
as
xops
out
=
xops
.
memory_efficient_attention_forward
(
query_states
,
key_states
,
value_states
,
p
=
self
.
dropout
,
scale
=
self
.
scale
)
elif
self
.
attn_backend
==
_Backend
.
TORCH_SDPA
:
query_states
,
key_states
,
value_states
=
(
x
.
transpose
(
1
,
2
)
for
x
in
(
query_states
,
key_states
,
value_states
))
out
=
F
.
scaled_dot_product_attention
(
query_states
,
key_states
,
value_states
,
dropout_p
=
self
.
dropout
,
scale
=
self
.
scale
)
out
=
out
.
transpose
(
1
,
2
)
out
=
out
.
view
(
batch_size
,
q_len
,
-
1
)
attn_output
,
_
=
self
.
out_proj
(
out
)
...
...
@@ -372,17 +389,11 @@ class SiglipEncoderLayer(nn.Module):
self
.
embed_dim
=
config
.
hidden_size
num_heads
=
config
.
num_attention_heads
tp_size
=
get_tensor_model_parallel_world_size
()
if
USE_XFORMERS_OPS
and
num_heads
%
tp_size
==
0
:
self
.
self_attn
=
SiglipParallelAttention
(
self
.
self_attn
=
SiglipAttention
(
config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
)
else
:
self
.
self_attn
=
SiglipSdpaAttention
(
config
)
self
.
layer_norm1
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_eps
)
self
.
mlp
=
SiglipMLP
(
...
...
@@ -569,10 +580,6 @@ class SiglipVisionModel(nn.Module):
)
->
None
:
super
().
__init__
()
num_heads
=
config
.
num_attention_heads
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
shard_weight
=
USE_XFORMERS_OPS
and
num_heads
%
tp_size
==
0
self
.
vision_model
=
SiglipVisionTransformer
(
config
,
quant_config
,
...
...
@@ -601,7 +608,7 @@ class SiglipVisionModel(nn.Module):
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
]
if
self
.
shard_weight
else
[]
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
layer_count
=
len
(
self
.
vision_model
.
encoder
.
layers
)
...
...
vllm/model_executor/models/utils.py
View file @
e7ebb662
...
...
@@ -587,7 +587,11 @@ class LLMWrapper(nn.Module):
return
llm
(
*
args
,
**
kwargs
)
def
get_vit_attn_backend
()
->
_Backend
:
def
get_vit_attn_backend
(
support_fa
:
bool
=
False
)
->
_Backend
:
"""
Get the available attention backend for Vision Transformer.
"""
# TODO(Isotr0py): Remove `support_fa` after support FA for all ViTs attn.
selected_backend
:
Optional
[
_Backend
]
=
get_global_forced_attn_backend
()
if
selected_backend
is
None
:
backend_by_env_var
:
Optional
[
str
]
=
envs
.
VLLM_ATTENTION_BACKEND
...
...
@@ -596,7 +600,7 @@ def get_vit_attn_backend() -> _Backend:
if
selected_backend
is
None
:
# For Volta and Turing GPUs, use xformers instead.
device_available
=
current_platform
.
has_device_capability
(
80
)
if
device_available
:
if
device_available
and
support_fa
:
from
transformers.utils
import
is_flash_attn_2_available
if
is_flash_attn_2_available
():
selected_backend
=
_Backend
.
FLASH_ATTN
...
...
@@ -606,7 +610,8 @@ def get_vit_attn_backend() -> _Backend:
"so we use xformers backend instead. You can run "
"`pip install flash-attn` to use flash-attention backend."
)
selected_backend
=
_Backend
.
XFORMERS
elif
current_platform
.
is_cpu
():
elif
current_platform
.
is_cpu
()
or
current_platform
.
is_rocm
():
# ROCM doesn't support xformers
selected_backend
=
_Backend
.
TORCH_SDPA
else
:
selected_backend
=
_Backend
.
XFORMERS
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment