Unverified Commit 9de1320b authored by Mick, committed by GitHub

fix: fp8 mllama4 without vision modules being quantized (#10611)

parent dda34c2f
@@ -291,7 +291,7 @@ class Llama4UnfoldConvolution(nn.Module):
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.unfold(hidden_states)
-        hidden_states = hidden_states.permute(0, 2, 1)
+        hidden_states = hidden_states.permute(0, 2, 1).contiguous()
         hidden_states, _ = self.linear(hidden_states)
         return hidden_states
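The added .contiguous() call matters because torch.Tensor.permute returns a strided view rather than copying data, and the (possibly fp8-quantized) linear layer that follows generally expects a contiguous row-major input. A minimal standalone sketch of the behavior, illustrative only and not part of the commit:

import torch

x = torch.randn(2, 16, 8)               # (batch, channels, patches)
y = x.permute(0, 2, 1)                  # (batch, patches, channels) -- a view
print(y.is_contiguous())                # False: same storage, new strides
print(y.contiguous().is_contiguous())   # True: data materialized in new layout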
@@ -446,9 +446,20 @@ class Llama4ForConditionalGeneration(nn.Module):
         )
         if self.has_vision:
             # TODO: make this more general
+            ignore_quant_layers = getattr(config, "quantization_config", {}).get(
+                "ignore", {}
+            )
+            if (
+                "model.layers.vision_model*" in ignore_quant_layers
+                and "model.layers.multi_modal_projector*" in ignore_quant_layers
+            ):
+                vision_quant_config = None
+            else:
+                vision_quant_config = quant_config
             self.vision_model = Llama4VisionModel(
                 config.vision_config,
-                quant_config=quant_config,
+                quant_config=vision_quant_config,
                 prefix=add_prefix("vision_model", prefix),
             )
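This hunk reads the ignore list from the checkpoint's quantization_config and, when both vision-related glob patterns are listed there, passes quant_config=None so the vision tower is built unquantized. An illustrative sketch of that decision, where the config dict and the "fp8" string are hypothetical stand-ins (a real checkpoint supplies the quantization_config, and sglang passes its quant config object rather than a string):

# Hypothetical stand-ins for checkpoint metadata and the quant config.
quantization_config = {
    "ignore": [
        "model.layers.vision_model*",
        "model.layers.multi_modal_projector*",
    ],
}
quant_config = "fp8"

ignore_quant_layers = quantization_config.get("ignore", [])
vision_quant_config = (
    None
    if "model.layers.vision_model*" in ignore_quant_layers
    and "model.layers.multi_modal_projector*" in ignore_quant_layers
    else quant_config
)
print(vision_quant_config)  # None -> vision modules skip fp8 quantization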
@@ -560,7 +571,7 @@ class Llama4ForConditionalGeneration(nn.Module):
             forward_batch=forward_batch,
             language_model=self.language_model,
             data_embedding_funcs={
-                Modality.IMAGE: self.get_image_feature,
+                Modality.IMAGE: image_embedding_func,
             },
             positions=positions,
         )