Unverified Commit e2d66f60 authored by Ke Bao, committed by GitHub
Browse files

Skip llama4 vision module loading when multimodal disabled (#8272)


Co-authored-by: Mick <mickjagger19@icloud.com>
parent 01c00004
...@@ -106,6 +106,7 @@ GLOBAL_SERVER_ARGS_KEYS = [ ...@@ -106,6 +106,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
"num_reserved_decode_tokens", "num_reserved_decode_tokens",
"weight_loader_disable_mmap", "weight_loader_disable_mmap",
"enable_triton_kernel_moe", "enable_triton_kernel_moe",
"enable_multimodal",
] ]
# Put some global args for easy access # Put some global args for easy access
......
...@@ -23,6 +23,7 @@ from sglang.srt.managers.schedule_batch import ( ...@@ -23,6 +23,7 @@ from sglang.srt.managers.schedule_batch import (
Modality, Modality,
MultimodalDataItem, MultimodalDataItem,
MultimodalInputs, MultimodalInputs,
global_server_args_dict,
) )
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
...@@ -55,13 +56,17 @@ class Llama4ForConditionalGeneration(nn.Module): ...@@ -55,13 +56,17 @@ class Llama4ForConditionalGeneration(nn.Module):
self.quant_config = quant_config self.quant_config = quant_config
# Check if this is a text-only model (modelopt fp8 llama4 has no vision components) # Check if this is a text-only model (modelopt fp8 llama4 has no vision components)
self.has_vision = self._has_vision_weights(config) self.has_vision_weights = self._has_vision_weights(config)
if not self.has_vision: if not self.has_vision_weights:
logger.warning( logger.warning(
"No vision weights found in checkpoint. Model will run in text-only mode. " "No vision weights found in checkpoint. Model will run in text-only mode. "
"Multimodal capabilities (image processing) will be unavailable." "Multimodal capabilities (image processing) will be unavailable."
) )
self.has_vision = (
self.has_vision_weights and global_server_args_dict["enable_multimodal"]
)
if self.has_vision: if self.has_vision:
self.vision_model = Llama4VisionModel(config.vision_config) self.vision_model = Llama4VisionModel(config.vision_config)
self.multi_modal_projector = Llama4MultiModalProjector(config) self.multi_modal_projector = Llama4MultiModalProjector(config)
...@@ -269,7 +274,9 @@ class Llama4ForConditionalGeneration(nn.Module): ...@@ -269,7 +274,9 @@ class Llama4ForConditionalGeneration(nn.Module):
def _should_skip_weight(self, name: str) -> bool: def _should_skip_weight(self, name: str) -> bool:
"""Check if we should skip loading this weight.""" """Check if we should skip loading this weight."""
return "vision" in name and not self.has_vision return not self.has_vision and (
"vision" in name or "multi_modal_projector" in name
)
def _transform_weight_name(self, name: str) -> str: def _transform_weight_name(self, name: str) -> str:
"""Transform weight name by adding language_model prefix if needed.""" """Transform weight name by adding language_model prefix if needed."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment