"vscode:/vscode.git/clone" did not exist on "d1adb9b4032dd430bb28b8e91feb8164c3a1ca9c"
Unverified Commit 12913d17 authored by Kyle Sayers's avatar Kyle Sayers Committed by GitHub
Browse files

[Quant] Add `SupportsQuant` to phi3 and clip (#13104)

parent 80f63a39
...@@ -30,6 +30,7 @@ class QuarkConfig(QuantizationConfig): ...@@ -30,6 +30,7 @@ class QuarkConfig(QuantizationConfig):
kv_cache_group: Optional[List[str]] = None, kv_cache_group: Optional[List[str]] = None,
kv_cache_config: Optional[Dict[str, Any]] = None, kv_cache_config: Optional[Dict[str, Any]] = None,
pack_method: str = "reorder"): pack_method: str = "reorder"):
super().__init__()
if kv_cache_group is None: if kv_cache_group is None:
kv_cache_group = [] kv_cache_group = []
self.quant_config = quant_config self.quant_config = quant_config
......
...@@ -21,6 +21,7 @@ class Int8TpuConfig(QuantizationConfig): ...@@ -21,6 +21,7 @@ class Int8TpuConfig(QuantizationConfig):
self, self,
activation_scheme: str = "none", activation_scheme: str = "none",
) -> None: ) -> None:
super().__init__()
if activation_scheme not in ACTIVATION_SCHEMES: if activation_scheme not in ACTIVATION_SCHEMES:
raise ValueError( raise ValueError(
f"Unsupported activation scheme {activation_scheme}") f"Unsupported activation scheme {activation_scheme}")
......
...@@ -15,6 +15,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -15,6 +15,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsQuant
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
...@@ -335,10 +336,10 @@ class CLIPVisionTransformer(nn.Module): ...@@ -335,10 +336,10 @@ class CLIPVisionTransformer(nn.Module):
return encoder_outputs return encoder_outputs
class CLIPVisionModel(nn.Module): class CLIPVisionModel(nn.Module, SupportsQuant):
config_class = CLIPVisionConfig config_class = CLIPVisionConfig
main_input_name = "pixel_values" main_input_name = "pixel_values"
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
def __init__( def __init__(
self, self,
......
...@@ -7,6 +7,8 @@ import torch ...@@ -7,6 +7,8 @@ import torch
from typing_extensions import TypeIs, TypeVar from typing_extensions import TypeIs, TypeVar
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.utils import supports_kw from vllm.utils import supports_kw
from .interfaces_base import is_pooling_model from .interfaces_base import is_pooling_model
...@@ -443,6 +445,36 @@ def supports_cross_encoding( ...@@ -443,6 +445,36 @@ def supports_cross_encoding(
return is_pooling_model(model) and _supports_cross_encoding(model) return is_pooling_model(model) and _supports_cross_encoding(model)
class SupportsQuant:
"""The interface required for all models that support quantization."""
packed_modules_mapping: ClassVar[Dict[str, List[str]]] = {}
quant_config: Optional[QuantizationConfig] = None
def __new__(cls, *args, **kwargs) -> "SupportsQuant":
instance = super().__new__(cls)
quant_config = cls._find_quant_config(*args, **kwargs)
if quant_config is not None:
instance.quant_config = quant_config
instance.quant_config.packed_modules_mapping.update(
cls.packed_modules_mapping)
return instance
@staticmethod
def _find_quant_config(*args, **kwargs) -> Optional[QuantizationConfig]:
from vllm.config import VllmConfig # avoid circular import
args_values = list(args) + list(kwargs.values())
for arg in args_values:
if isinstance(arg, VllmConfig):
return arg.quant_config
if isinstance(arg, QuantizationConfig):
return arg
return None
@runtime_checkable @runtime_checkable
class SupportsTranscription(Protocol): class SupportsTranscription(Protocol):
"""The interface required for all models that support transcription.""" """The interface required for all models that support transcription."""
......
...@@ -50,7 +50,7 @@ from vllm.sequence import IntermediateTensors ...@@ -50,7 +50,7 @@ from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .clip import CLIPVisionModel from .clip import CLIPVisionModel
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP, SupportsQuant
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix, init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings) merge_multimodal_embeddings)
...@@ -498,7 +498,8 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): ...@@ -498,7 +498,8 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor, @MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor,
info=Phi3VProcessingInfo, info=Phi3VProcessingInfo,
dummy_inputs=Phi3VDummyInputsBuilder) dummy_inputs=Phi3VDummyInputsBuilder)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
SupportsQuant):
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={ orig_to_new_prefix={
"model.vision_embed_tokens.wte": "embed_tokens", "model.vision_embed_tokens.wte": "embed_tokens",
...@@ -510,7 +511,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -510,7 +511,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
...@@ -520,14 +520,14 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -520,14 +520,14 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size, org_num_embeddings=config.vocab_size,
quant_config=quant_config, quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "model.embed_tokens"), prefix=maybe_prefix(prefix, "model.embed_tokens"),
) )
# TODO: Optionally initializes this for supporting input embeddings. # TODO: Optionally initializes this for supporting input embeddings.
self.vision_embed_tokens = Phi3HDImageEmbedding( self.vision_embed_tokens = Phi3HDImageEmbedding(
config, config,
quant_config, self.quant_config,
prefix=maybe_prefix(prefix, "model.vision_embed_tokens")) prefix=maybe_prefix(prefix, "model.vision_embed_tokens"))
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment