Unverified commit 038bc5d5 authored by Cheng Wan, committed by GitHub

Support `--enable-llama4-multimodal` (#5254)

parent aee62d74
@@ -136,6 +136,7 @@ def load_model(server_args, port_args, tp_rank):
         context_length=server_args.context_length,
         model_override_args=server_args.json_model_override_args,
         is_embedding=server_args.is_embedding,
+        enable_multimodal=server_args.enable_multimodal,
         dtype=server_args.dtype,
         quantization=server_args.quantization,
     )
...
@@ -43,10 +43,12 @@ class ModelConfig:
         context_length: Optional[int] = None,
         model_override_args: Optional[str] = None,
         is_embedding: Optional[bool] = None,
+        enable_multimodal: Optional[bool] = None,
         dtype: str = "auto",
         quantization: Optional[str] = None,
         override_config_file: Optional[str] = None,
     ) -> None:
         self.model_path = model_path
         self.revision = revision
         self.quantization = quantization
@@ -70,14 +72,28 @@ class ModelConfig:
             self.hf_text_config, "attention_chunk_size", None
         )

+        if enable_multimodal is None:
+            if "Llama4ForConditionalGeneration" in self.hf_config.architectures:
+                enable_multimodal = False
+            else:
+                enable_multimodal = True
+
         # Check model type
         self.is_generation = is_generation_model(
             self.hf_config.architectures, is_embedding
         )
-        self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
-        self.is_multimodal_gen = is_multimodal_gen_model(self.hf_config.architectures)
-        self.is_image_gen = is_image_gen_model(self.hf_config.architectures)
-        self.is_audio_model = is_audio_model(self.hf_config.architectures)
+        self.is_multimodal = enable_multimodal and is_multimodal_model(
+            self.hf_config.architectures
+        )
+        self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model(
+            self.hf_config.architectures
+        )
+        self.is_image_gen = enable_multimodal and is_image_gen_model(
+            self.hf_config.architectures
+        )
+        self.is_audio_model = enable_multimodal and is_audio_model(
+            self.hf_config.architectures
+        )
         self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
...
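For reference, the defaulting rule above reduces to the following standalone sketch (`resolve_enable_multimodal` and the stub configs are illustrative stand-ins, not part of the diff; only the `Llama4ForConditionalGeneration` check comes from the commit). Note that `hf_config.architectures` is a list of class names, hence the membership test:

```python
from types import SimpleNamespace
from typing import Optional

def resolve_enable_multimodal(
    hf_config, enable_multimodal: Optional[bool] = None
) -> bool:
    # Unset flag: multimodal stays on for every architecture except Llama-4,
    # which must be opted in explicitly via --enable-llama4-multimodal.
    if enable_multimodal is None:
        enable_multimodal = (
            "Llama4ForConditionalGeneration" not in hf_config.architectures
        )
    return enable_multimodal

llama4 = SimpleNamespace(architectures=["Llama4ForConditionalGeneration"])
llava = SimpleNamespace(architectures=["LlavaLlamaForCausalLM"])
assert resolve_enable_multimodal(llama4) is False       # Llama-4 is opt-in
assert resolve_enable_multimodal(llava) is True         # others stay enabled
assert resolve_enable_multimodal(llama4, True) is True  # explicit flag wins
```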
@@ -437,6 +437,7 @@ class Scheduler(
             context_length=server_args.context_length,
             model_override_args=server_args.json_model_override_args,
             is_embedding=server_args.is_embedding,
+            enable_multimodal=server_args.enable_multimodal,
             dtype=server_args.dtype,
             quantization=server_args.quantization,
         )
...
@@ -163,6 +163,7 @@ class TokenizerManager:
             context_length=server_args.context_length,
             model_override_args=server_args.json_model_override_args,
             is_embedding=server_args.is_embedding,
+            enable_multimodal=server_args.enable_multimodal,
             dtype=server_args.dtype,
             quantization=server_args.quantization,
         )
...
@@ -68,6 +68,7 @@ class TpModelWorker:
             context_length=server_args.context_length,
             model_override_args=server_args.json_model_override_args,
             is_embedding=server_args.is_embedding,
+            enable_multimodal=server_args.enable_multimodal,
             dtype=server_args.dtype,
             quantization=server_args.quantization,
         )
...
@@ -281,7 +281,6 @@ class ModelRunner:
                 f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
                 f"because this is a multimodal model."
             )
-
             logger.info(
                 "Automatically turn off --chunked-prefill-size for multimodal model."
             )
...
@@ -156,6 +156,7 @@ class ServerArgs:
     disable_outlines_disk_cache: bool = False
     disable_custom_all_reduce: bool = False
     disable_mla: bool = False
+    enable_llama4_multimodal: Optional[bool] = None
     disable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_dp_attention: bool = False
@@ -294,6 +295,8 @@ class ServerArgs:
                 f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
             )

+        self.enable_multimodal: Optional[bool] = self.enable_llama4_multimodal
+
         # Data parallelism attention
         if self.enable_dp_attention:
             self.schedule_conservativeness = self.schedule_conservativeness * 0.3
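A minimal sketch of how the model-specific CLI flag feeds the generic field (a stub dataclass, assuming nothing beyond the assignment added above):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ServerArgsStub:
    enable_llama4_multimodal: Optional[bool] = None
    enable_multimodal: Optional[bool] = None

    def __post_init__(self):
        # Mirrors the assignment in this hunk: every ModelConfig call site
        # reads the generic field, not the Llama-4-specific one.
        self.enable_multimodal = self.enable_llama4_multimodal

print(ServerArgsStub().enable_multimodal)                               # None
print(ServerArgsStub(enable_llama4_multimodal=True).enable_multimodal)  # True
```

Keeping the generic `enable_multimodal` field separate presumably leaves room for similar flags for other model families without touching the ModelConfig call sites.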
@@ -974,6 +977,11 @@ class ServerArgs:
             action="store_true",
             help="Disable Multi-head Latent Attention (MLA) for DeepSeek V2/V3/R1 series models.",
         )
+        parser.add_argument(
+            "--enable-llama4-multimodal",
+            action="store_true",
+            help="Enable the multimodal functionality for Llama-4.",
+        )
         parser.add_argument(
             "--disable-overlap-schedule",
             action="store_true",
...
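A quick sanity check of how the new argument parses, using a stripped-down parser that mirrors only the lines added above (a sketch, not the real `ServerArgs.add_cli_args`):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--enable-llama4-multimodal",
    action="store_true",
    help="Enable the multimodal functionality for Llama-4.",
)

print(parser.parse_args(["--enable-llama4-multimodal"]).enable_llama4_multimodal)  # True
print(parser.parse_args([]).enable_llama4_multimodal)                              # False
```

One observation: `store_true` yields False (not None) when the flag is absent, so if the parsed value is passed straight through, the `is None` defaulting branch in ModelConfig would only fire for callers that construct ModelConfig without going through the CLI; adding `default=None` to `add_argument` would preserve the tri-state.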