Unverified Commit 038bc5d5 authored by Cheng Wan's avatar Cheng Wan Committed by GitHub
Browse files

Support `--enable-llama4-multimodal` (#5254)

parent aee62d74
......@@ -136,6 +136,7 @@ def load_model(server_args, port_args, tp_rank):
context_length=server_args.context_length,
model_override_args=server_args.json_model_override_args,
is_embedding=server_args.is_embedding,
enable_multimodal=server_args.enable_multimodal,
dtype=server_args.dtype,
quantization=server_args.quantization,
)
......
......@@ -43,10 +43,12 @@ class ModelConfig:
context_length: Optional[int] = None,
model_override_args: Optional[str] = None,
is_embedding: Optional[bool] = None,
enable_multimodal: Optional[bool] = None,
dtype: str = "auto",
quantization: Optional[str] = None,
override_config_file: Optional[str] = None,
) -> None:
self.model_path = model_path
self.revision = revision
self.quantization = quantization
......@@ -70,14 +72,28 @@ class ModelConfig:
self.hf_text_config, "attention_chunk_size", None
)
if enable_multimodal is None:
if self.hf_config.architectures == "Llama4ForConditionalGeneration":
enable_multimodal = False
else:
enable_multimodal = True
# Check model type
self.is_generation = is_generation_model(
self.hf_config.architectures, is_embedding
)
self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
self.is_multimodal_gen = is_multimodal_gen_model(self.hf_config.architectures)
self.is_image_gen = is_image_gen_model(self.hf_config.architectures)
self.is_audio_model = is_audio_model(self.hf_config.architectures)
self.is_multimodal = enable_multimodal and is_multimodal_model(
self.hf_config.architectures
)
self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model(
self.hf_config.architectures
)
self.is_image_gen = enable_multimodal and is_image_gen_model(
self.hf_config.architectures
)
self.is_audio_model = enable_multimodal and is_audio_model(
self.hf_config.architectures
)
self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
......
......@@ -437,6 +437,7 @@ class Scheduler(
context_length=server_args.context_length,
model_override_args=server_args.json_model_override_args,
is_embedding=server_args.is_embedding,
enable_multimodal=server_args.enable_multimodal,
dtype=server_args.dtype,
quantization=server_args.quantization,
)
......
......@@ -163,6 +163,7 @@ class TokenizerManager:
context_length=server_args.context_length,
model_override_args=server_args.json_model_override_args,
is_embedding=server_args.is_embedding,
enable_multimodal=server_args.enable_multimodal,
dtype=server_args.dtype,
quantization=server_args.quantization,
)
......
......@@ -68,6 +68,7 @@ class TpModelWorker:
context_length=server_args.context_length,
model_override_args=server_args.json_model_override_args,
is_embedding=server_args.is_embedding,
enable_multimodal=server_args.enable_multimodal,
dtype=server_args.dtype,
quantization=server_args.quantization,
)
......
......@@ -281,7 +281,6 @@ class ModelRunner:
f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
f"because this is a multimodal model."
)
logger.info(
"Automatically turn off --chunked-prefill-size for multimodal model."
)
......
......@@ -156,6 +156,7 @@ class ServerArgs:
disable_outlines_disk_cache: bool = False
disable_custom_all_reduce: bool = False
disable_mla: bool = False
enable_llama4_multimodal: Optional[bool] = None
disable_overlap_schedule: bool = False
enable_mixed_chunk: bool = False
enable_dp_attention: bool = False
......@@ -294,6 +295,8 @@ class ServerArgs:
f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
)
self.enable_multimodal: Optional[bool] = self.enable_llama4_multimodal
# Data parallelism attention
if self.enable_dp_attention:
self.schedule_conservativeness = self.schedule_conservativeness * 0.3
......@@ -974,6 +977,11 @@ class ServerArgs:
action="store_true",
help="Disable Multi-head Latent Attention (MLA) for DeepSeek V2/V3/R1 series models.",
)
parser.add_argument(
"--enable-llama4-multimodal",
action="store_true",
help="Enable the multimodal functionality for Llama-4.",
)
parser.add_argument(
"--disable-overlap-schedule",
action="store_true",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment