Unverified Commit c66c7f86 authored by Roger Wang's avatar Roger Wang Committed by GitHub
Browse files

[Bugfix] Fix PaliGemma MMP (#6930)

parent 6e063ea3
......@@ -9,7 +9,6 @@ from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
......@@ -133,12 +132,10 @@ class PaliGemmaMultiModalProjector(nn.Module):
def __init__(self, vision_hidden_size: int, projection_dim: int):
super().__init__()
self.linear = ColumnParallelLinear(vision_hidden_size,
projection_dim,
bias=True)
self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)
def forward(self, image_features: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.linear(image_features)
hidden_states = self.linear(image_features)
return hidden_states
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment