Unverified Commit 5d98d560 authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

Support Pixtral-Large HF by using llava multimodal_projector_bias config (#12710)


Signed-off-by: default avatarmgoin <michael@neuralmagic.com>
parent 73b35cca
...@@ -75,19 +75,20 @@ class LlavaMultiModalProjector(nn.Module): ...@@ -75,19 +75,20 @@ class LlavaMultiModalProjector(nn.Module):
vision_hidden_size: int, vision_hidden_size: int,
text_hidden_size: int, text_hidden_size: int,
projector_hidden_act: str, projector_hidden_act: str,
multimodal_projector_bias: bool,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""): prefix: str = ""):
super().__init__() super().__init__()
self.linear_1 = ColumnParallelLinear(vision_hidden_size, self.linear_1 = ColumnParallelLinear(vision_hidden_size,
text_hidden_size, text_hidden_size,
bias=True, bias=multimodal_projector_bias,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.linear_1") prefix=f"{prefix}.linear_1")
self.act = get_act_fn(projector_hidden_act) self.act = get_act_fn(projector_hidden_act)
self.linear_2 = RowParallelLinear(text_hidden_size, self.linear_2 = RowParallelLinear(text_hidden_size,
text_hidden_size, text_hidden_size,
bias=True, bias=multimodal_projector_bias,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.linear_2") prefix=f"{prefix}.linear_2")
...@@ -503,6 +504,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -503,6 +504,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
vision_hidden_size=config.vision_config.hidden_size, vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size, text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act, projector_hidden_act=config.projector_hidden_act,
multimodal_projector_bias=config.multimodal_projector_bias,
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "multi_modal_projector")) prefix=maybe_prefix(prefix, "multi_modal_projector"))
......
...@@ -231,7 +231,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -231,7 +231,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
self.multi_modal_projector = LlavaMultiModalProjector( self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=vision_hidden_size, vision_hidden_size=vision_hidden_size,
text_hidden_size=config.text_config.hidden_size, text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act) projector_hidden_act=config.projector_hidden_act,
multimodal_projector_bias=config.multimodal_projector_bias)
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
vllm_config=vllm_config, vllm_config=vllm_config,
......
...@@ -253,16 +253,16 @@ class LlavaNextVideoPooler(nn.Module): ...@@ -253,16 +253,16 @@ class LlavaNextVideoPooler(nn.Module):
class LlavaNextMultiModalProjector(nn.Module): class LlavaNextMultiModalProjector(nn.Module):
def __init__(self, vision_hidden_size: int, text_hidden_size: int, def __init__(self, vision_hidden_size: int, text_hidden_size: int,
projector_hidden_act: str): projector_hidden_act: str, multimodal_projector_bias: bool):
super().__init__() super().__init__()
self.linear_1 = nn.Linear(vision_hidden_size, self.linear_1 = nn.Linear(vision_hidden_size,
text_hidden_size, text_hidden_size,
bias=True) bias=multimodal_projector_bias)
self.act = get_act_fn(projector_hidden_act) self.act = get_act_fn(projector_hidden_act)
self.linear_2 = nn.Linear(text_hidden_size, self.linear_2 = nn.Linear(text_hidden_size,
text_hidden_size, text_hidden_size,
bias=True) bias=multimodal_projector_bias)
def forward(self, image_features: torch.Tensor) -> torch.Tensor: def forward(self, image_features: torch.Tensor) -> torch.Tensor:
hidden_states = self.linear_1(image_features) hidden_states = self.linear_1(image_features)
...@@ -298,7 +298,8 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -298,7 +298,8 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
self.multi_modal_projector = LlavaNextMultiModalProjector( self.multi_modal_projector = LlavaNextMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size, vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size, text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act) projector_hidden_act=config.projector_hidden_act,
multimodal_projector_bias=config.multimodal_projector_bias)
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
vllm_config=vllm_config, vllm_config=vllm_config,
hf_config=config.text_config, hf_config=config.text_config,
......
...@@ -372,11 +372,11 @@ class LlavaOnevisionMultiModalProjector(nn.Module): ...@@ -372,11 +372,11 @@ class LlavaOnevisionMultiModalProjector(nn.Module):
self.linear_1 = nn.Linear(config.vision_config.hidden_size, self.linear_1 = nn.Linear(config.vision_config.hidden_size,
config.text_config.hidden_size, config.text_config.hidden_size,
bias=True) bias=config.multimodal_projector_bias)
self.act = get_act_fn(config.projector_hidden_act) self.act = get_act_fn(config.projector_hidden_act)
self.linear_2 = nn.Linear(config.text_config.hidden_size, self.linear_2 = nn.Linear(config.text_config.hidden_size,
config.text_config.hidden_size, config.text_config.hidden_size,
bias=True) bias=config.multimodal_projector_bias)
def forward(self, image_features: torch.Tensor) -> torch.Tensor: def forward(self, image_features: torch.Tensor) -> torch.Tensor:
hidden_states = self.linear_1(image_features) hidden_states = self.linear_1(image_features)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment