Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cf1d62a6
Unverified
Commit
cf1d62a6
authored
Oct 16, 2024
by
Isotr0py
Committed by
GitHub
Oct 16, 2024
Browse files
[Model] Support SDPA attention for Molmo vision backbone (#9410)
parent
59230ef3
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
61 additions
and
78 deletions
+61
-78
vllm/model_executor/models/molmo.py
vllm/model_executor/models/molmo.py
+15
-37
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+12
-40
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+34
-1
No files found.
vllm/model_executor/models/molmo.py
View file @
cf1d62a6
import
logging
import
math
import
math
import
re
import
re
from
array
import
array
from
array
import
array
...
@@ -14,10 +13,8 @@ from torch import nn
...
@@ -14,10 +13,8 @@ from torch import nn
from
torch.nn
import
functional
as
F
from
torch.nn
import
functional
as
F
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
import
vllm.envs
as
envs
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention.selector
import
(
_Backend
,
backend_name_to_enum
,
from
vllm.attention.selector
import
_Backend
get_global_forced_attn_backend
)
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
...
@@ -43,12 +40,11 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...
@@ -43,12 +40,11 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.models.interfaces
import
SupportsMultiModal
from
vllm.model_executor.models.interfaces
import
SupportsMultiModal
from
vllm.model_executor.models.utils
import
make_layers
from
vllm.model_executor.models.utils
import
make_layers
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalInputs
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalInputs
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
SequenceData
)
SequenceData
)
from
vllm.transformers_utils.processor
import
get_processor
from
vllm.transformers_utils.processor
import
get_processor
log
=
logging
.
getLogger
(
__name__
)
from
.utils
import
get_vit_attn_backend
# TODO: hard-coded for now. Consider making it configurable.
# TODO: hard-coded for now. Consider making it configurable.
VIT_LAYERS
=
[
-
2
,
-
9
]
VIT_LAYERS
=
[
-
2
,
-
9
]
...
@@ -190,35 +186,12 @@ class MultiHeadDotProductAttention(nn.Module):
...
@@ -190,35 +186,12 @@ class MultiHeadDotProductAttention(nn.Module):
)
)
# Detect attention implementation.
# Detect attention implementation.
selected_backend
:
Optional
[
_Backend
]
=
get_global_forced_attn_backend
()
self
.
attn_backend
:
_Backend
=
get_vit_attn_backend
()
if
selected_backend
is
None
:
if
self
.
attn_backend
not
in
{
backend_by_env_var
:
Optional
[
str
]
=
envs
.
VLLM_ATTENTION_BACKEND
_Backend
.
FLASH_ATTN
,
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
if
backend_by_env_var
is
not
None
:
}:
selected_backend
=
backend_name_to_enum
(
backend_by_env_var
)
raise
RuntimeError
(
if
selected_backend
is
None
:
f
"Molmo does not support
{
self
.
attn_backend
}
backend now."
)
# For Volta and Turing GPUs, use xformers instead.
device_available
=
current_platform
.
get_device_capability
()[
0
]
>=
8
if
device_available
:
from
transformers.utils
import
is_flash_attn_2_available
if
is_flash_attn_2_available
():
self
.
_use_flash_attn
=
True
else
:
log
.
warning
(
"Current Molmo implementation has a bug with "
"`vllm-flash-attn` inside vision module, so we use "
"xformers backend instead. You can run `pip install "
"flash-attn to use flash-attention backend."
)
self
.
_use_flash_attn
=
False
else
:
self
.
_use_flash_attn
=
False
else
:
if
selected_backend
==
_Backend
.
FLASH_ATTN
:
self
.
_use_flash_attn
=
True
elif
selected_backend
==
_Backend
.
XFORMERS
:
self
.
_use_flash_attn
=
False
else
:
raise
RuntimeError
(
f
"Molmo does not support
{
selected_backend
}
backend now."
)
def
forward
(
self
,
def
forward
(
self
,
inputs_q
:
torch
.
Tensor
,
inputs_q
:
torch
.
Tensor
,
...
@@ -240,10 +213,15 @@ class MultiHeadDotProductAttention(nn.Module):
...
@@ -240,10 +213,15 @@ class MultiHeadDotProductAttention(nn.Module):
xk
=
xk
.
view
(
*
kv_shape
)
xk
=
xk
.
view
(
*
kv_shape
)
xv
=
xv
.
view
(
*
kv_shape
)
xv
=
xv
.
view
(
*
kv_shape
)
if
self
.
_use_flash_attn
:
if
self
.
attn_backend
==
_Backend
.
FLASH_ATTN
:
from
flash_attn
import
flash_attn_func
from
flash_attn
import
flash_attn_func
output
=
flash_attn_func
(
xq
,
xk
,
xv
,
dropout_p
=
0.0
,
causal
=
False
)
output
=
flash_attn_func
(
xq
,
xk
,
xv
,
dropout_p
=
0.0
,
causal
=
False
)
else
:
elif
self
.
attn_backend
==
_Backend
.
TORCH_SDPA
:
xq
,
xk
,
xv
=
(
rearrange
(
x
,
"b s h d -> b h s d"
)
for
x
in
(
xq
,
xk
,
xv
))
output
=
F
.
scaled_dot_product_attention
(
xq
,
xk
,
xv
)
output
=
rearrange
(
output
,
"b h s d -> b s h d "
)
elif
self
.
attn_backend
==
_Backend
.
XFORMERS
:
from
xformers
import
ops
as
xops
from
xformers
import
ops
as
xops
output
=
xops
.
memory_efficient_attention_forward
(
xq
,
xk
,
xv
,
p
=
0
)
output
=
xops
.
memory_efficient_attention_forward
(
xq
,
xk
,
xv
,
p
=
0
)
...
...
vllm/model_executor/models/qwen2_vl.py
View file @
cf1d62a6
...
@@ -39,10 +39,8 @@ from transformers.models.qwen2_vl.configuration_qwen2_vl import (
...
@@ -39,10 +39,8 @@ from transformers.models.qwen2_vl.configuration_qwen2_vl import (
from
transformers.models.qwen2_vl.image_processing_qwen2_vl
import
(
from
transformers.models.qwen2_vl.image_processing_qwen2_vl
import
(
make_batched_images
,
make_batched_videos
,
smart_resize
)
make_batched_images
,
make_batched_videos
,
smart_resize
)
import
vllm.envs
as
envs
from
vllm.attention
import
AttentionMetadata
from
vllm.attention
import
AttentionMetadata
from
vllm.attention.selector
import
(
_Backend
,
backend_name_to_enum
,
from
vllm.attention.selector
import
_Backend
get_global_forced_attn_backend
)
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.distributed
import
get_pp_group
,
parallel_state
from
vllm.distributed
import
get_pp_group
,
parallel_state
from
vllm.distributed
import
utils
as
dist_utils
from
vllm.distributed
import
utils
as
dist_utils
...
@@ -63,14 +61,13 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
...
@@ -63,14 +61,13 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
MultiModalInputs
)
MultiModalInputs
)
from
vllm.multimodal.base
import
MultiModalData
from
vllm.multimodal.base
import
MultiModalData
from
vllm.multimodal.image
import
cached_get_image_processor
from
vllm.multimodal.image
import
cached_get_image_processor
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
,
SequenceData
from
vllm.sequence
import
IntermediateTensors
,
SequenceData
from
vllm.transformers_utils.config
import
uses_mrope
from
vllm.transformers_utils.config
import
uses_mrope
from
vllm.transformers_utils.processor
import
get_processor
from
vllm.transformers_utils.processor
import
get_processor
from
vllm.utils
import
is_cpu
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.utils
import
(
PPMissingLayer
,
is_pp_missing_parameter
,
from
.utils
import
(
PPMissingLayer
,
get_vit_attn_backend
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
)
make_empty_intermediate_tensors_factory
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -215,37 +212,12 @@ class Qwen2VisionAttention(nn.Module):
...
@@ -215,37 +212,12 @@ class Qwen2VisionAttention(nn.Module):
quant_config
=
quant_config
)
quant_config
=
quant_config
)
# Detect attention implementation.
# Detect attention implementation.
selected_backend
:
Optional
[
_Backend
]
=
get_global_forced_attn_backend
()
self
.
attn_backend
:
_Backend
=
get_vit_attn_backend
()
if
selected_backend
is
None
:
if
self
.
attn_backend
not
in
{
backend_by_env_var
:
Optional
[
str
]
=
envs
.
VLLM_ATTENTION_BACKEND
_Backend
.
FLASH_ATTN
,
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
if
backend_by_env_var
is
not
None
:
}:
selected_backend
=
backend_name_to_enum
(
backend_by_env_var
)
raise
RuntimeError
(
if
selected_backend
is
None
:
f
"Qwen2-VL does not support
{
self
.
attn_backend
}
backend now."
)
# For Volta and Turing GPUs, use xformers instead.
device_available
=
current_platform
.
has_device_capability
(
80
)
if
device_available
:
from
transformers.utils
import
is_flash_attn_2_available
if
is_flash_attn_2_available
():
self
.
_use_flash_attn
=
True
else
:
logger
.
warning
(
"Current Qwen2-VL implementation has a bug with "
"`vllm-flash-attn` inside vision module, so we use "
"xformers backend instead. You can run `pip install "
"flash-attn to use flash-attention backend."
)
self
.
_use_flash_attn
=
False
else
:
self
.
_use_flash_attn
=
False
else
:
if
selected_backend
==
_Backend
.
FLASH_ATTN
:
self
.
_use_flash_attn
=
True
elif
selected_backend
==
_Backend
.
XFORMERS
:
self
.
_use_flash_attn
=
False
else
:
raise
RuntimeError
(
f
"Qwen2-VL does not support
{
selected_backend
}
backend now."
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -274,7 +246,7 @@ class Qwen2VisionAttention(nn.Module):
...
@@ -274,7 +246,7 @@ class Qwen2VisionAttention(nn.Module):
q
=
apply_rotary_pos_emb_vision
(
q
,
rotary_pos_emb
)
q
=
apply_rotary_pos_emb_vision
(
q
,
rotary_pos_emb
)
k
=
apply_rotary_pos_emb_vision
(
k
,
rotary_pos_emb
)
k
=
apply_rotary_pos_emb_vision
(
k
,
rotary_pos_emb
)
if
self
.
_use_flash_attn
:
if
self
.
attn_backend
==
_Backend
.
FLASH_ATTN
:
# from vllm_flash_attn.flash_attn_interface import (
# from vllm_flash_attn.flash_attn_interface import (
# flash_attn_varlen_func)
# flash_attn_varlen_func)
from
flash_attn
import
flash_attn_varlen_func
from
flash_attn
import
flash_attn_varlen_func
...
@@ -295,7 +267,7 @@ class Qwen2VisionAttention(nn.Module):
...
@@ -295,7 +267,7 @@ class Qwen2VisionAttention(nn.Module):
context_layer
=
rearrange
(
output
,
context_layer
=
rearrange
(
output
,
"(b s) ... -> b s ..."
,
"(b s) ... -> b s ..."
,
b
=
batch_size
)
b
=
batch_size
)
elif
is_cpu
()
:
elif
self
.
attn_backend
==
_Backend
.
TORCH_SDPA
:
seq_length
=
q
.
size
(
1
)
seq_length
=
q
.
size
(
1
)
q
,
k
,
v
=
[
rearrange
(
x
,
"b s h d -> b h s d"
)
for
x
in
[
q
,
k
,
v
]]
q
,
k
,
v
=
[
rearrange
(
x
,
"b s h d -> b h s d"
)
for
x
in
[
q
,
k
,
v
]]
attention_mask
=
torch
.
zeros
([
1
,
seq_length
,
seq_length
],
attention_mask
=
torch
.
zeros
([
1
,
seq_length
,
seq_length
],
...
@@ -310,7 +282,7 @@ class Qwen2VisionAttention(nn.Module):
...
@@ -310,7 +282,7 @@ class Qwen2VisionAttention(nn.Module):
attention_mask
,
attention_mask
,
dropout_p
=
0.0
)
dropout_p
=
0.0
)
context_layer
=
rearrange
(
output
,
"b h s d -> b s h d "
)
context_layer
=
rearrange
(
output
,
"b h s d -> b s h d "
)
el
se
:
el
if
self
.
attn_backend
==
_Backend
.
XFORMERS
:
from
xformers
import
ops
as
xops
from
xformers
import
ops
as
xops
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalMask
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalMask
...
...
vllm/model_executor/models/utils.py
View file @
cf1d62a6
...
@@ -8,15 +8,22 @@ import torch.nn as nn
...
@@ -8,15 +8,22 @@ import torch.nn as nn
from
torch.func
import
functional_call
from
torch.func
import
functional_call
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
import
vllm.envs
as
envs
from
vllm.attention.selector
import
(
_Backend
,
backend_name_to_enum
,
get_global_forced_attn_backend
)
from
vllm.config
import
(
CacheConfig
,
LoRAConfig
,
MultiModalConfig
,
from
vllm.config
import
(
CacheConfig
,
LoRAConfig
,
MultiModalConfig
,
SchedulerConfig
)
SchedulerConfig
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.model_loader.loader
import
build_model
from
vllm.model_executor.model_loader.loader
import
build_model
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.multimodal.base
import
NestedTensors
from
vllm.multimodal.base
import
NestedTensors
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
is_pin_memory_available
from
vllm.utils
import
is_cpu
,
is_pin_memory_available
logger
=
init_logger
(
__name__
)
WeightsMapping
=
Mapping
[
str
,
Optional
[
str
]]
WeightsMapping
=
Mapping
[
str
,
Optional
[
str
]]
"""If a key maps to a value of `None`, the corresponding weight is ignored."""
"""If a key maps to a value of `None`, the corresponding weight is ignored."""
...
@@ -487,3 +494,29 @@ class LLMWrapper(nn.Module):
...
@@ -487,3 +494,29 @@ class LLMWrapper(nn.Module):
def
__call__
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
Any
:
def
__call__
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
Any
:
llm
=
super
().
__getattr__
(
self
.
model_name
)
llm
=
super
().
__getattr__
(
self
.
model_name
)
return
llm
(
*
args
,
**
kwargs
)
return
llm
(
*
args
,
**
kwargs
)
def get_vit_attn_backend() -> _Backend:
    """Choose the attention backend used by vision-transformer modules.

    Resolution order:

    1. The globally forced backend returned by
       ``get_global_forced_attn_backend()``.
    2. The ``VLLM_ATTENTION_BACKEND`` environment setting, converted with
       ``backend_name_to_enum``.
    3. Auto-detection: when ``current_platform.has_device_capability(80)``
       holds, use FLASH_ATTN if ``flash-attn`` is importable and XFORMERS
       (with a warning) otherwise; when it does not hold, use TORCH_SDPA
       on CPU and XFORMERS everywhere else.

    Returns:
        The selected ``_Backend`` member.
    """
    backend: Optional[_Backend] = get_global_forced_attn_backend()
    if backend is None:
        env_choice: Optional[str] = envs.VLLM_ATTENTION_BACKEND
        if env_choice is not None:
            backend = backend_name_to_enum(env_choice)

    # An explicit choice (forced or via env var) short-circuits detection.
    if backend is not None:
        return backend

    # For Volta and Turing GPUs, use xformers instead.
    if not current_platform.has_device_capability(80):
        return _Backend.TORCH_SDPA if is_cpu() else _Backend.XFORMERS

    # Imported lazily so the probe only runs on the auto-detection path.
    from transformers.utils import is_flash_attn_2_available

    if is_flash_attn_2_available():
        return _Backend.FLASH_ATTN

    logger.warning(
        "Current `vllm-flash-attn` has a bug inside vision module, "
        "so we use xformers backend instead. You can run "
        "`pip install flash-attn` to use flash-attention backend.")
    return _Backend.XFORMERS
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment