Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ca47e176
Unverified
Commit
ca47e176
authored
Jan 09, 2025
by
Cyrus Leung
Committed by
GitHub
Jan 08, 2025
Browse files
[Misc] Move some model utils into vision file (#11848)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
78f4590b
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
94 additions
and
92 deletions
+94
-92
vllm/model_executor/models/clip.py
vllm/model_executor/models/clip.py
+2
-3
vllm/model_executor/models/pixtral.py
vllm/model_executor/models/pixtral.py
+2
-3
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+2
-1
vllm/model_executor/models/siglip.py
vllm/model_executor/models/siglip.py
+2
-3
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+1
-36
vllm/model_executor/models/vision.py
vllm/model_executor/models/vision.py
+82
-1
vllm/multimodal/inputs.py
vllm/multimodal/inputs.py
+3
-1
vllm/multimodal/utils.py
vllm/multimodal/utils.py
+0
-44
No files found.
vllm/model_executor/models/clip.py
View file @
ca47e176
...
...
@@ -20,11 +20,10 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.multimodal.utils
import
(
cached_get_tokenizer
,
consecutive_placeholder_ranges
,
repeat_and_pad_placeholder_tokens
,
resolve_visual_encoder_outputs
)
repeat_and_pad_placeholder_tokens
)
from
vllm.sequence
import
SequenceData
from
.vision
import
VisionEncoderInfo
from
.vision
import
VisionEncoderInfo
,
resolve_visual_encoder_outputs
def
get_clip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
...
...
vllm/model_executor/models/pixtral.py
View file @
ca47e176
...
...
@@ -31,14 +31,13 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal.inputs
import
NestedTensors
,
PlaceholderRange
from
vllm.multimodal.utils
import
(
cached_get_tokenizer
,
consecutive_placeholder_ranges
,
resolve_visual_encoder_outputs
)
consecutive_placeholder_ranges
)
from
vllm.sequence
import
IntermediateTensors
,
SequenceData
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.utils
import
(
init_vllm_registered_model
,
maybe_prefix
,
merge_multimodal_embeddings
)
from
.vision
import
VisionEncoderInfo
from
.vision
import
VisionEncoderInfo
,
resolve_visual_encoder_outputs
try
:
from
xformers
import
ops
as
xops
...
...
vllm/model_executor/models/qwen2_vl.py
View file @
ca47e176
...
...
@@ -66,8 +66,9 @@ from vllm.sequence import IntermediateTensors
from
vllm.transformers_utils.config
import
uses_mrope
from
.interfaces
import
SupportsLoRA
,
SupportsMultiModal
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
get_vit_attn_backend
,
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
init_vllm_registered_model
,
maybe_prefix
)
from
.vision
import
get_vit_attn_backend
logger
=
init_logger
(
__name__
)
...
...
vllm/model_executor/models/siglip.py
View file @
ca47e176
...
...
@@ -24,11 +24,10 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.multimodal.utils
import
(
cached_get_tokenizer
,
consecutive_placeholder_ranges
,
repeat_and_pad_placeholder_tokens
,
resolve_visual_encoder_outputs
)
repeat_and_pad_placeholder_tokens
)
from
vllm.sequence
import
SequenceData
from
.vision
import
VisionEncoderInfo
from
.vision
import
VisionEncoderInfo
,
resolve_visual_encoder_outputs
def
get_siglip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
...
...
vllm/model_executor/models/utils.py
View file @
ca47e176
...
...
@@ -8,16 +8,12 @@ import torch.nn as nn
from
torch.func
import
functional_call
from
transformers
import
PretrainedConfig
import
vllm.envs
as
envs
from
vllm.attention.selector
import
(
backend_name_to_enum
,
get_global_forced_attn_backend
)
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.multimodal
import
MultiModalPlaceholderMap
,
NestedTensors
from
vllm.platforms
import
_Backend
,
current_platform
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
is_pin_memory_available
,
print_warning_once
from
vllm.utils
import
is_pin_memory_available
logger
=
init_logger
(
__name__
)
...
...
@@ -612,37 +608,6 @@ def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
return
make_empty_intermediate_tensors
def
get_vit_attn_backend
(
support_fa
:
bool
=
False
)
->
_Backend
:
"""
Get the available attention backend for Vision Transformer.
"""
# TODO(Isotr0py): Remove `support_fa` after support FA for all ViTs attn.
selected_backend
:
Optional
[
_Backend
]
=
get_global_forced_attn_backend
()
if
selected_backend
is
None
:
backend_by_env_var
:
Optional
[
str
]
=
envs
.
VLLM_ATTENTION_BACKEND
if
backend_by_env_var
is
not
None
:
selected_backend
=
backend_name_to_enum
(
backend_by_env_var
)
if
selected_backend
is
None
:
# For Volta and Turing GPUs, use xformers instead.
device_available
=
current_platform
.
has_device_capability
(
80
)
if
device_available
and
support_fa
:
from
transformers.utils
import
is_flash_attn_2_available
if
is_flash_attn_2_available
():
selected_backend
=
_Backend
.
FLASH_ATTN
else
:
print_warning_once
(
"Current `vllm-flash-attn` has a bug inside vision module, "
"so we use xformers backend instead. You can run "
"`pip install flash-attn` to use flash-attention backend."
)
selected_backend
=
_Backend
.
XFORMERS
elif
current_platform
.
is_cpu
()
or
current_platform
.
is_rocm
():
# ROCM doesn't support xformers
selected_backend
=
_Backend
.
TORCH_SDPA
else
:
selected_backend
=
_Backend
.
XFORMERS
return
selected_backend
def
maybe_prefix
(
prefix
:
str
,
name
:
str
)
->
str
:
"""Add a prefix to a name if the prefix is non-empty.
...
...
vllm/model_executor/models/vision.py
View file @
ca47e176
from
abc
import
ABC
,
abstractmethod
from
typing
import
Final
,
Generic
,
Protocol
,
TypeVar
from
typing
import
Final
,
Generic
,
Optional
,
Protocol
,
TypeVar
,
Union
import
torch
from
transformers
import
PretrainedConfig
import
vllm.envs
as
envs
from
vllm.attention.selector
import
(
backend_name_to_enum
,
get_global_forced_attn_backend
)
from
vllm.platforms
import
_Backend
,
current_platform
from
vllm.utils
import
print_warning_once
_C
=
TypeVar
(
"_C"
,
bound
=
PretrainedConfig
)
...
...
@@ -60,3 +67,77 @@ def get_vision_encoder_info(
msg
=
f
"Unsupported vision config:
{
type
(
vision_config
)
}
"
raise
NotImplementedError
(
msg
)
def
get_vit_attn_backend
(
support_fa
:
bool
=
False
)
->
_Backend
:
"""
Get the available attention backend for Vision Transformer.
"""
# TODO(Isotr0py): Remove `support_fa` after support FA for all ViTs attn.
selected_backend
:
Optional
[
_Backend
]
=
get_global_forced_attn_backend
()
if
selected_backend
is
None
:
backend_by_env_var
:
Optional
[
str
]
=
envs
.
VLLM_ATTENTION_BACKEND
if
backend_by_env_var
is
not
None
:
selected_backend
=
backend_name_to_enum
(
backend_by_env_var
)
if
selected_backend
is
None
:
# For Volta and Turing GPUs, use xformers instead.
device_available
=
current_platform
.
has_device_capability
(
80
)
if
device_available
and
support_fa
:
from
transformers.utils
import
is_flash_attn_2_available
if
is_flash_attn_2_available
():
selected_backend
=
_Backend
.
FLASH_ATTN
else
:
print_warning_once
(
"Current `vllm-flash-attn` has a bug inside vision module, "
"so we use xformers backend instead. You can run "
"`pip install flash-attn` to use flash-attention backend."
)
selected_backend
=
_Backend
.
XFORMERS
elif
current_platform
.
is_cpu
()
or
current_platform
.
is_rocm
():
# ROCM doesn't support xformers
selected_backend
=
_Backend
.
TORCH_SDPA
else
:
selected_backend
=
_Backend
.
XFORMERS
return
selected_backend
def
resolve_visual_encoder_outputs
(
encoder_outputs
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
feature_sample_layers
:
Optional
[
list
[
int
]],
post_layer_norm
:
Optional
[
torch
.
nn
.
LayerNorm
],
max_possible_layers
:
int
,
)
->
torch
.
Tensor
:
"""Given the outputs a visual encoder module that may correspond to the
output of the last layer, or a list of hidden states to be stacked,
handle post normalization and resolve it into a single output tensor.
Args:
encoder_outputs: Output of encoder's last layer or all hidden states.
feature_sample_layers: Optional layer indices to grab from the encoder
outputs; if provided, encoder outputs must be a list.
post_layer_norm: Post norm to apply to the output of the encoder.
max_possible_layers: Total layers in the fully loaded visual encoder.
"""
if
feature_sample_layers
is
None
:
if
post_layer_norm
is
not
None
:
return
post_layer_norm
(
encoder_outputs
)
return
encoder_outputs
# Get the hidden states corresponding to the layer indices.
# Negative values are relative to the full visual encoder,
# so offset them depending on how many layers were loaded.
# NOTE: this assumes that encoder_outputs contains a list
# of hidden states in the same order as the encoder layers
# that produced them.
offset
=
max_possible_layers
-
len
(
encoder_outputs
)
hs_pool
=
[
encoder_outputs
[
layer_idx
]
if
layer_idx
>=
0
else
encoder_outputs
[
layer_idx
+
offset
]
for
layer_idx
in
feature_sample_layers
]
# Apply post-norm on the final hidden state if we are using it
uses_last_layer
=
feature_sample_layers
[
-
1
]
in
(
len
(
hs_pool
)
-
1
,
-
1
)
if
post_layer_norm
is
not
None
and
uses_last_layer
:
hs_pool
[
-
1
]
=
post_layer_norm
(
encoder_outputs
)
return
torch
.
cat
(
hs_pool
,
dim
=-
1
)
vllm/multimodal/inputs.py
View file @
ca47e176
...
...
@@ -99,6 +99,8 @@ class MultiModalDataBuiltins(TypedDict, total=False):
MultiModalDataDict
:
TypeAlias
=
Mapping
[
str
,
ModalityData
[
Any
]]
"""
A dictionary containing an entry for each modality type to input.
The built-in modalities are defined by :class:`MultiModalDataBuiltins`.
"""
...
...
@@ -485,7 +487,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
MultiModalPlaceholderDict
=
Mapping
[
str
,
Sequence
[
PlaceholderRange
]]
"""
A dictionary containing placeholder ranges.
A dictionary containing placeholder ranges
for each modality
.
"""
...
...
vllm/multimodal/utils.py
View file @
ca47e176
...
...
@@ -5,7 +5,6 @@ from urllib.parse import ParseResult, urlparse
import
numpy
as
np
import
numpy.typing
as
npt
import
torch
from
PIL
import
Image
import
vllm.envs
as
envs
...
...
@@ -285,49 +284,6 @@ def encode_video_base64(frames: npt.NDArray) -> str:
return
video_io
.
encode_base64
(
frames
)
def
resolve_visual_encoder_outputs
(
encoder_outputs
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
feature_sample_layers
:
Optional
[
list
[
int
]],
post_layer_norm
:
Optional
[
torch
.
nn
.
LayerNorm
],
max_possible_layers
:
int
,
)
->
torch
.
Tensor
:
"""Given the outputs a visual encoder module that may correspond to the
output of the last layer, or a list of hidden states to be stacked,
handle post normalization and resolve it into a single output tensor.
Args:
encoder_outputs: Output of encoder's last layer or all hidden states.
feature_sample_layers: Optional layer indices to grab from the encoder
outputs; if provided, encoder outputs must be a list.
post_layer_norm: Post norm to apply to the output of the encoder.
max_possible_layers: Total layers in the fully loaded visual encoder.
"""
if
feature_sample_layers
is
None
:
if
post_layer_norm
is
not
None
:
return
post_layer_norm
(
encoder_outputs
)
return
encoder_outputs
# Get the hidden states corresponding to the layer indices.
# Negative values are relative to the full visual encoder,
# so offset them depending on how many layers were loaded.
# NOTE: this assumes that encoder_outputs contains a list
# of hidden states in the same order as the encoder layers
# that produced them.
offset
=
max_possible_layers
-
len
(
encoder_outputs
)
hs_pool
=
[
encoder_outputs
[
layer_idx
]
if
layer_idx
>=
0
else
encoder_outputs
[
layer_idx
+
offset
]
for
layer_idx
in
feature_sample_layers
]
# Apply post-norm on the final hidden state if we are using it
uses_last_layer
=
feature_sample_layers
[
-
1
]
in
(
len
(
hs_pool
)
-
1
,
-
1
)
if
post_layer_norm
is
not
None
and
uses_last_layer
:
hs_pool
[
-
1
]
=
post_layer_norm
(
encoder_outputs
)
return
torch
.
cat
(
hs_pool
,
dim
=-
1
)
# Utilities for input processors
_T
=
TypeVar
(
"_T"
,
str
,
int
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment