Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
421c4629
Unverified
Commit
421c4629
authored
Apr 03, 2025
by
Kyle Sayers
Committed by
GitHub
Apr 03, 2025
Browse files
[SupportsQuant] Bert, Blip, Blip2, Bloom (#15573)
Signed-off-by:
Kyle Sayers
<
kylesayrs@gmail.com
>
parent
84884cd9
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
16 additions
and
9 deletions
+16
-9
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+6
-4
vllm/model_executor/models/blip.py
vllm/model_executor/models/blip.py
+4
-1
vllm/model_executor/models/blip2.py
vllm/model_executor/models/blip2.py
+4
-2
vllm/model_executor/models/bloom.py
vllm/model_executor/models/bloom.py
+2
-2
No files found.
vllm/model_executor/models/bert.py
View file @
421c4629
...
@@ -26,7 +26,7 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
...
@@ -26,7 +26,7 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
from
vllm.transformers_utils.config
import
(
from
vllm.transformers_utils.config
import
(
get_cross_encoder_activation_function
)
get_cross_encoder_activation_function
)
from
.interfaces
import
SupportsCrossEncoding
,
SupportsV0Only
from
.interfaces
import
SupportsCrossEncoding
,
SupportsQuant
,
SupportsV0Only
from
.utils
import
WeightsMapper
,
maybe_prefix
from
.utils
import
WeightsMapper
,
maybe_prefix
...
@@ -313,7 +313,8 @@ class BertOutput(nn.Module):
...
@@ -313,7 +313,8 @@ class BertOutput(nn.Module):
return
hidden_states
return
hidden_states
class
BertModel
(
nn
.
Module
):
class
BertModel
(
nn
.
Module
,
SupportsQuant
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"query"
,
"key"
,
"value"
]}
def
__init__
(
self
,
def
__init__
(
self
,
*
,
*
,
...
@@ -385,7 +386,7 @@ class BertModel(nn.Module):
...
@@ -385,7 +386,7 @@ class BertModel(nn.Module):
return
loaded_params
return
loaded_params
class
BertEmbeddingModel
(
nn
.
Module
,
SupportsV0Only
):
class
BertEmbeddingModel
(
nn
.
Module
,
SupportsV0Only
,
SupportsQuant
):
"""A model that uses Bert to provide embedding functionalities.
"""A model that uses Bert to provide embedding functionalities.
This class encapsulates the BertModel and provides an interface for
This class encapsulates the BertModel and provides an interface for
...
@@ -443,7 +444,8 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only):
...
@@ -443,7 +444,8 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only):
softmax
=
False
)
softmax
=
False
)
class
BertForSequenceClassification
(
nn
.
Module
,
SupportsCrossEncoding
):
class
BertForSequenceClassification
(
nn
.
Module
,
SupportsCrossEncoding
,
SupportsQuant
):
"""A model that uses Bert to provide embedding functionalities.
"""A model that uses Bert to provide embedding functionalities.
This class encapsulates the BertModel and provides an interface for
This class encapsulates the BertModel and provides an interface for
...
...
vllm/model_executor/models/blip.py
View file @
421c4629
...
@@ -16,6 +16,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -16,6 +16,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
.interfaces
import
SupportsQuant
def
get_blip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
def
get_blip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
assert
image_size
%
patch_size
==
0
assert
image_size
%
patch_size
==
0
...
@@ -243,9 +245,10 @@ class BlipEncoder(nn.Module):
...
@@ -243,9 +245,10 @@ class BlipEncoder(nn.Module):
return
hidden_states
return
hidden_states
class
BlipVisionModel
(
nn
.
Module
):
class
BlipVisionModel
(
nn
.
Module
,
SupportsQuant
):
config_class
=
BlipVisionConfig
config_class
=
BlipVisionConfig
main_input_name
=
"pixel_values"
main_input_name
=
"pixel_values"
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
]}
def
__init__
(
def
__init__
(
self
,
self
,
...
...
vllm/model_executor/models/blip2.py
View file @
421c4629
...
@@ -24,7 +24,8 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
...
@@ -24,7 +24,8 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.blip
import
BlipVisionModel
from
.blip
import
BlipVisionModel
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
,
SupportsQuant
)
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
maybe_prefix
,
merge_multimodal_embeddings
)
maybe_prefix
,
merge_multimodal_embeddings
)
...
@@ -498,7 +499,8 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
...
@@ -498,7 +499,8 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
@
MULTIMODAL_REGISTRY
.
register_processor
(
Blip2MultiModalProcessor
,
@
MULTIMODAL_REGISTRY
.
register_processor
(
Blip2MultiModalProcessor
,
info
=
Blip2ProcessingInfo
,
info
=
Blip2ProcessingInfo
,
dummy_inputs
=
Blip2DummyInputsBuilder
)
dummy_inputs
=
Blip2DummyInputsBuilder
)
class
Blip2ForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
class
Blip2ForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
,
SupportsQuant
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
...
vllm/model_executor/models/bloom.py
View file @
421c4629
...
@@ -42,7 +42,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...
@@ -42,7 +42,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsPP
,
SupportsV0Only
from
.interfaces
import
SupportsPP
,
SupportsQuant
,
SupportsV0Only
from
.utils
import
(
is_pp_missing_parameter
,
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
maybe_prefix
)
...
@@ -279,7 +279,7 @@ class BloomModel(nn.Module):
...
@@ -279,7 +279,7 @@ class BloomModel(nn.Module):
return
hidden_states
return
hidden_states
class
BloomForCausalLM
(
nn
.
Module
,
SupportsPP
,
SupportsV0Only
):
class
BloomForCausalLM
(
nn
.
Module
,
SupportsPP
,
SupportsV0Only
,
SupportsQuant
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment