Unverified Commit 4a9952ec authored by LoganJane's avatar LoganJane Committed by GitHub
Browse files

[Bugfix] Add quant_config in ViT of Kimi-K2.5 (#34501)


Signed-off-by: default avatarLoganJane <LoganJane73@hotmail.com>
Signed-off-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
parent 1dae7b78
...@@ -23,6 +23,10 @@ from transformers.processing_utils import ProcessorMixin ...@@ -23,6 +23,10 @@ from transformers.processing_utils import ProcessorMixin
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
CompressedTensorsConfig,
)
from vllm.model_executor.models.interfaces import ( from vllm.model_executor.models.interfaces import (
SupportsMultiModal, SupportsMultiModal,
SupportsPP, SupportsPP,
...@@ -361,6 +365,7 @@ class KimiK25ForConditionalGeneration( ...@@ -361,6 +365,7 @@ class KimiK25ForConditionalGeneration(
with self._mark_tower_model(vllm_config, "vision_chunk"): with self._mark_tower_model(vllm_config, "vision_chunk"):
self.vision_tower = MoonViT3dPretrainedModel( self.vision_tower = MoonViT3dPretrainedModel(
config.vision_config, config.vision_config,
quant_config=self._maybe_ignore_quant_config(quant_config),
prefix=maybe_prefix(prefix, "vision_tower"), prefix=maybe_prefix(prefix, "vision_tower"),
) )
self.vision_tower = self.vision_tower.to( self.vision_tower = self.vision_tower.to(
...@@ -370,6 +375,7 @@ class KimiK25ForConditionalGeneration( ...@@ -370,6 +375,7 @@ class KimiK25ForConditionalGeneration(
self.mm_projector = KimiK25MultiModalProjector( self.mm_projector = KimiK25MultiModalProjector(
config=config.vision_config, config=config.vision_config,
use_data_parallel=self.use_data_parallel, use_data_parallel=self.use_data_parallel,
quant_config=self._maybe_ignore_quant_config(quant_config),
prefix=maybe_prefix(prefix, "mm_projector"), prefix=maybe_prefix(prefix, "mm_projector"),
) )
self.mm_projector = self.mm_projector.to( self.mm_projector = self.mm_projector.to(
...@@ -389,6 +395,11 @@ class KimiK25ForConditionalGeneration( ...@@ -389,6 +395,11 @@ class KimiK25ForConditionalGeneration(
) )
self.media_placeholder: int = self.config.media_placeholder_token_id self.media_placeholder: int = self.config.media_placeholder_token_id
def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
if isinstance(quant_config, CompressedTensorsConfig):
return None
return quant_config
def _parse_and_validate_media_input( def _parse_and_validate_media_input(
self, **kwargs: object self, **kwargs: object
) -> KimiK25MediaPixelInputs | None: ) -> KimiK25MediaPixelInputs | None:
......
...@@ -28,6 +28,7 @@ from vllm.model_executor.layers.linear import ( ...@@ -28,6 +28,7 @@ from vllm.model_executor.layers.linear import (
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear, RowParallelLinear,
) )
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.utils import maybe_prefix from vllm.model_executor.models.utils import maybe_prefix
from vllm.model_executor.models.vision import ( from vllm.model_executor.models.vision import (
is_vit_use_data_parallel, is_vit_use_data_parallel,
...@@ -304,6 +305,7 @@ class MLP2(nn.Module): ...@@ -304,6 +305,7 @@ class MLP2(nn.Module):
dims: list[int], dims: list[int],
activation, activation,
bias: bool = True, bias: bool = True,
quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
): ):
...@@ -314,6 +316,7 @@ class MLP2(nn.Module): ...@@ -314,6 +316,7 @@ class MLP2(nn.Module):
dims[0], dims[0],
dims[1], dims[1],
bias=bias, bias=bias,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "fc0"), prefix=maybe_prefix(prefix, "fc0"),
disable_tp=self.use_data_parallel, disable_tp=self.use_data_parallel,
) )
...@@ -321,6 +324,7 @@ class MLP2(nn.Module): ...@@ -321,6 +324,7 @@ class MLP2(nn.Module):
dims[1], dims[1],
dims[2], dims[2],
bias=bias, bias=bias,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "fc1"), prefix=maybe_prefix(prefix, "fc1"),
disable_tp=self.use_data_parallel, disable_tp=self.use_data_parallel,
) )
...@@ -341,6 +345,7 @@ class MoonViTEncoderLayer(nn.Module): ...@@ -341,6 +345,7 @@ class MoonViTEncoderLayer(nn.Module):
num_heads: int, num_heads: int,
hidden_dim: int, hidden_dim: int,
mlp_dim: int, mlp_dim: int,
quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
*, *,
activation=F.gelu, activation=F.gelu,
...@@ -362,6 +367,7 @@ class MoonViTEncoderLayer(nn.Module): ...@@ -362,6 +367,7 @@ class MoonViTEncoderLayer(nn.Module):
self.mlp = MLP2( self.mlp = MLP2(
[hidden_dim, mlp_dim, hidden_dim], [hidden_dim, mlp_dim, hidden_dim],
activation, activation,
quant_config=quant_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
use_data_parallel=self.use_data_parallel, use_data_parallel=self.use_data_parallel,
) )
...@@ -371,6 +377,7 @@ class MoonViTEncoderLayer(nn.Module): ...@@ -371,6 +377,7 @@ class MoonViTEncoderLayer(nn.Module):
total_num_heads=num_heads, total_num_heads=num_heads,
total_num_kv_heads=num_heads, total_num_kv_heads=num_heads,
bias=attn_bias, bias=attn_bias,
quant_config=quant_config,
prefix=f"{prefix}.wqkv", prefix=f"{prefix}.wqkv",
disable_tp=self.use_data_parallel, disable_tp=self.use_data_parallel,
) )
...@@ -378,6 +385,7 @@ class MoonViTEncoderLayer(nn.Module): ...@@ -378,6 +385,7 @@ class MoonViTEncoderLayer(nn.Module):
hidden_dim, hidden_dim,
hidden_dim, hidden_dim,
bias=attn_bias, bias=attn_bias,
quant_config=quant_config,
prefix=f"{prefix}.wo", prefix=f"{prefix}.wo",
disable_tp=self.use_data_parallel, disable_tp=self.use_data_parallel,
) )
...@@ -461,6 +469,7 @@ class MoonViT3dEncoder(nn.Module): ...@@ -461,6 +469,7 @@ class MoonViT3dEncoder(nn.Module):
num_layers: int, num_layers: int,
block_cfg: dict, block_cfg: dict,
video_attn_type: str = "spatial_temporal", video_attn_type: str = "spatial_temporal",
quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -476,6 +485,7 @@ class MoonViT3dEncoder(nn.Module): ...@@ -476,6 +485,7 @@ class MoonViT3dEncoder(nn.Module):
[ [
MoonViTEncoderLayer( MoonViTEncoderLayer(
**block_cfg, **block_cfg,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}", prefix=f"{prefix}.blocks.{layer_idx}",
) )
for layer_idx in range(num_layers) for layer_idx in range(num_layers)
...@@ -544,6 +554,7 @@ class MoonViT3dPretrainedModel(nn.Module): ...@@ -544,6 +554,7 @@ class MoonViT3dPretrainedModel(nn.Module):
def __init__( def __init__(
self, self,
config: KimiK25VisionConfig, config: KimiK25VisionConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
...@@ -573,6 +584,7 @@ class MoonViT3dPretrainedModel(nn.Module): ...@@ -573,6 +584,7 @@ class MoonViT3dPretrainedModel(nn.Module):
"attn_bias": True, "attn_bias": True,
}, },
video_attn_type=config.video_attn_type, video_attn_type=config.video_attn_type,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "encoder"), prefix=maybe_prefix(prefix, "encoder"),
) )
...@@ -646,6 +658,7 @@ class KimiK25MultiModalProjector(nn.Module): ...@@ -646,6 +658,7 @@ class KimiK25MultiModalProjector(nn.Module):
self, self,
config: KimiK25VisionConfig, config: KimiK25VisionConfig,
use_data_parallel: bool = False, use_data_parallel: bool = False,
quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
...@@ -660,12 +673,14 @@ class KimiK25MultiModalProjector(nn.Module): ...@@ -660,12 +673,14 @@ class KimiK25MultiModalProjector(nn.Module):
self.hidden_size, self.hidden_size,
self.hidden_size, self.hidden_size,
bias=True, bias=True,
quant_config=quant_config,
prefix=f"{prefix}.linear_1", prefix=f"{prefix}.linear_1",
) )
self.linear_2 = ReplicatedLinear( self.linear_2 = ReplicatedLinear(
self.hidden_size, self.hidden_size,
config.mm_hidden_size, config.mm_hidden_size,
bias=True, bias=True,
quant_config=quant_config,
prefix=f"{prefix}.linear_2", prefix=f"{prefix}.linear_2",
) )
self.act = GELUActivation() self.act = GELUActivation()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment