"vllm/vscode:/vscode.git/clone" did not exist on "0d06b533a0fcca7a62603c868df68235659d6935"
Unverified commit 421c4629, authored by Kyle Sayers, committed by GitHub
Browse files

[SupportsQuant] Bert, Blip, Blip2, Bloom (#15573)


Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
parent 84884cd9
...@@ -26,7 +26,7 @@ from vllm.sequence import IntermediateTensors, PoolerOutput ...@@ -26,7 +26,7 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.transformers_utils.config import ( from vllm.transformers_utils.config import (
get_cross_encoder_activation_function) get_cross_encoder_activation_function)
from .interfaces import SupportsCrossEncoding, SupportsV0Only from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
from .utils import WeightsMapper, maybe_prefix from .utils import WeightsMapper, maybe_prefix
...@@ -313,7 +313,8 @@ class BertOutput(nn.Module): ...@@ -313,7 +313,8 @@ class BertOutput(nn.Module):
return hidden_states return hidden_states
class BertModel(nn.Module): class BertModel(nn.Module, SupportsQuant):
packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]}
def __init__(self, def __init__(self,
*, *,
...@@ -385,7 +386,7 @@ class BertModel(nn.Module): ...@@ -385,7 +386,7 @@ class BertModel(nn.Module):
return loaded_params return loaded_params
class BertEmbeddingModel(nn.Module, SupportsV0Only): class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
"""A model that uses Bert to provide embedding functionalities. """A model that uses Bert to provide embedding functionalities.
This class encapsulates the BertModel and provides an interface for This class encapsulates the BertModel and provides an interface for
...@@ -443,7 +444,8 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only): ...@@ -443,7 +444,8 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only):
softmax=False) softmax=False)
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding): class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
SupportsQuant):
"""A model that uses Bert to provide embedding functionalities. """A model that uses Bert to provide embedding functionalities.
This class encapsulates the BertModel and provides an interface for This class encapsulates the BertModel and provides an interface for
......
...@@ -16,6 +16,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -16,6 +16,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from .interfaces import SupportsQuant
def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int: def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
assert image_size % patch_size == 0 assert image_size % patch_size == 0
...@@ -243,9 +245,10 @@ class BlipEncoder(nn.Module): ...@@ -243,9 +245,10 @@ class BlipEncoder(nn.Module):
return hidden_states return hidden_states
class BlipVisionModel(nn.Module): class BlipVisionModel(nn.Module, SupportsQuant):
config_class = BlipVisionConfig config_class = BlipVisionConfig
main_input_name = "pixel_values" main_input_name = "pixel_values"
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
def __init__( def __init__(
self, self,
......
...@@ -24,7 +24,8 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs ...@@ -24,7 +24,8 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .blip import BlipVisionModel from .blip import BlipVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
SupportsQuant)
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
...@@ -498,7 +499,8 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]): ...@@ -498,7 +499,8 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor, @MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor,
info=Blip2ProcessingInfo, info=Blip2ProcessingInfo,
dummy_inputs=Blip2DummyInputsBuilder) dummy_inputs=Blip2DummyInputsBuilder)
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
SupportsQuant):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
......
...@@ -42,7 +42,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -42,7 +42,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP, SupportsV0Only from .interfaces import SupportsPP, SupportsQuant, SupportsV0Only
from .utils import (is_pp_missing_parameter, from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
...@@ -279,7 +279,7 @@ class BloomModel(nn.Module): ...@@ -279,7 +279,7 @@ class BloomModel(nn.Module):
return hidden_states return hidden_states
class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only): class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only, SupportsQuant):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment