Unverified commit f47a2c67 authored by Lianmin Zheng, committed by GitHub

[Auto Sync] Update load_config.py, model_config.py, configu... (20250923) (#10825)


Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent ee704e62
@@ -24,6 +24,8 @@ class LoadFormat(str, enum.Enum):
     JAX = "jax"
     REMOTE = "remote"
     REMOTE_INSTANCE = "remote_instance"
+    RDMA = "rdma"
+    LOCAL_CACHED = "local_cached"

 @dataclass
@@ -47,6 +49,7 @@ class LoadConfig:
             checkpoints.
         decryption_key_file: If set, decrypts the output files with a password read
             from this file (after PBKDF2).
+        decrypt_max_concurrency: The maximum number of concurrent processes to decrypt the safetensor files. -1 means no limit.
     """

     load_format: Union[str, LoadFormat] = LoadFormat.AUTO
@@ -54,6 +57,7 @@ class LoadConfig:
     model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
     ignore_patterns: Optional[Union[List[str], str]] = None
     decryption_key_file: Optional[str] = None
+    decrypt_max_concurrency: int = -1

     def __post_init__(self):
         model_loader_extra_config = self.model_loader_extra_config or {}
......
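For reviewers, a minimal sketch of how the new `LoadFormat` members and `LoadConfig` fields might be exercised. The import path below is an assumption about the repository layout (the commit only names load_config.py); the field and enum names are taken from the hunks above.

```python
# Usage sketch only, not part of the commit. The import path is assumed;
# the field and enum names come from the diff above.
from sglang.srt.configs.load_config import LoadConfig, LoadFormat

# The two new load formats added in this commit.
rdma_cfg = LoadConfig(load_format=LoadFormat.RDMA)
cached_cfg = LoadConfig(load_format=LoadFormat.LOCAL_CACHED)

# The new concurrency knob for encrypted checkpoints: cap decryption at
# 4 concurrent processes instead of the default -1 ("no limit").
encrypted_cfg = LoadConfig(
    load_format=LoadFormat.AUTO,
    decryption_key_file="/path/to/key.file",  # hypothetical path
    decrypt_max_concurrency=4,
)
```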
@@ -75,7 +75,10 @@ class ModelConfig:
         self.model_path = model_path
         self.revision = revision
         self.quantization = quantization
+        self.is_draft_model = is_draft_model
         self.model_impl = model_impl

+        # TODO: remove these fields
         self.tp_rank = tp_rank
         self.remote_instance_weight_loader_seed_instance_ip = (
             remote_instance_weight_loader_seed_instance_ip
@@ -87,12 +90,12 @@ class ModelConfig:
             remote_instance_weight_loader_send_weights_group_ports
         )

-        self.maybe_pull_model_tokenizer_from_remote()
+        # Get hf config
+        self._maybe_pull_model_tokenizer_from_remote()
         self.model_override_args = json.loads(model_override_args)
         kwargs = {}
         if override_config_file and override_config_file.strip():
             kwargs["_configuration_file"] = override_config_file.strip()
         self.hf_config = get_config(
             self.model_path,
             trust_remote_code=trust_remote_code,
@@ -100,7 +103,7 @@ class ModelConfig:
             model_override_args=self.model_override_args,
             **kwargs,
         )
+        self.hf_text_config = get_hf_text_config(self.hf_config)
         self.hf_generation_config = get_generation_config(
             self.model_path,
             trust_remote_code=trust_remote_code,
@@ -108,7 +111,25 @@ class ModelConfig:
             **kwargs,
         )
-        self.hf_text_config = get_hf_text_config(self.hf_config)
+        # Set enable_multimodal
+        if enable_multimodal is None:
+            mm_disabled_models = [
+                "Gemma3ForConditionalGeneration",
+                "Llama4ForConditionalGeneration",
+                "Step3VLForConditionalGeneration",
+            ]
+            if self.hf_config.architectures[0] in mm_disabled_models:
+                enable_multimodal = False
+                logger.info(
+                    f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal."
+                )
+            else:
+                enable_multimodal = True
+
+        # Config draft model
+        self._config_draft_model()
+
+        # Check model type
         self.attention_chunk_size = getattr(
             self.hf_text_config, "attention_chunk_size", None
         )
@@ -124,20 +145,73 @@ class ModelConfig:
                 self.hf_config.architectures, self.hf_text_config.num_hidden_layers
             )
         )

-        if enable_multimodal is None:
-            mm_disabled_models = [
-                "Gemma3ForConditionalGeneration",
-                "Llama4ForConditionalGeneration",
-                "Step3VLForConditionalGeneration",
-            ]
-            if self.hf_config.architectures[0] in mm_disabled_models:
-                enable_multimodal = False
-                logger.info(
-                    f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal."
-                )
-            else:
-                enable_multimodal = True
+        self.is_generation = is_generation_model(
+            self.hf_config.architectures, is_embedding
+        )
+        self.is_multimodal = enable_multimodal and is_multimodal_model(
+            self.hf_config.architectures
+        )
+        self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model(
+            self.hf_config.architectures
+        )
+        self.is_image_gen = enable_multimodal and is_image_gen_model(
+            self.hf_config.architectures
+        )
+        self.is_audio_model = enable_multimodal and is_audio_model(
+            self.hf_config.architectures
+        )
+        self.is_multimodal_chunked_prefill_supported = (
+            enable_multimodal
+            and is_multimodal_chunked_prefill_supported(self.hf_config.architectures)
+        )
+        self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
+        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
+
+        # Derive context length and model shapes
+        self._derive_context_length(context_length)
+        self._derive_model_shapes()
+
+        # Verify quantization
+        self._verify_quantization()
+
+        # Verify dual-chunk attention config
+        self._verify_dual_chunk_attention_config()
+
+        # Cache attributes
+        self.hf_eos_token_id = self._get_hf_eos_token_id()
+
+        # multimodal
+        self.image_token_id = getattr(
+            self.hf_config, "image_token_id", None
+        ) or getattr(self.hf_config, "image_token_index", None)
+
+    @staticmethod
+    def from_server_args(
+        server_args: ServerArgs,
+        model_path: str = None,
+        model_revision: str = None,
+        **kwargs,
+    ):
+        return ModelConfig(
+            model_path=model_path or server_args.model_path,
+            trust_remote_code=server_args.trust_remote_code,
+            revision=model_revision or server_args.revision,
+            context_length=server_args.context_length,
+            model_override_args=server_args.json_model_override_args,
+            is_embedding=server_args.is_embedding,
+            enable_multimodal=server_args.enable_multimodal,
+            dtype=server_args.dtype,
+            quantization=server_args.quantization,
+            hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
+            model_impl=server_args.model_impl,
+            remote_instance_weight_loader_seed_instance_ip=server_args.remote_instance_weight_loader_seed_instance_ip,
+            remote_instance_weight_loader_seed_instance_service_port=server_args.remote_instance_weight_loader_seed_instance_service_port,
+            remote_instance_weight_loader_send_weights_group_ports=server_args.remote_instance_weight_loader_send_weights_group_ports,
+            **kwargs,
+        )
+
+    def _config_draft_model(self):
+        is_draft_model = self.is_draft_model
+
         if (
             is_draft_model
@@ -172,31 +246,10 @@ class ModelConfig:
                 self.hf_config.architectures[0] = "Qwen3NextForCausalLMMTP"
                 self.hf_config.num_nextn_predict_layers = 1

-        # Check model type
-        self.is_generation = is_generation_model(
-            self.hf_config.architectures, is_embedding
-        )
-        self.is_multimodal = enable_multimodal and is_multimodal_model(
-            self.hf_config.architectures
-        )
-        self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model(
-            self.hf_config.architectures
-        )
-        self.is_image_gen = enable_multimodal and is_image_gen_model(
-            self.hf_config.architectures
-        )
-        self.is_audio_model = enable_multimodal and is_audio_model(
-            self.hf_config.architectures
-        )
-        self.is_multimodal_chunked_prefill_supported = (
-            enable_multimodal
-            and is_multimodal_chunked_prefill_supported(self.hf_config.architectures)
-        )
-        self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
-        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
-
-        # Derive context length
+    def _derive_context_length(self, context_length: int):
+        is_draft_model = self.is_draft_model
+
         derived_context_len = get_context_length(self.hf_text_config)
         if context_length is not None:
             if context_length > derived_context_len:
                 reason = "Target model's" if is_draft_model else "User-specified"
@@ -224,6 +277,10 @@ class ModelConfig:
         else:
             self.context_len = derived_context_len

+        # Transfer context_len to HuggingFace config so models can access it
+        self.hf_config.context_len = self.context_len
+
+    def _derive_model_shapes(self):
         # Unify the config keys for hf_text_config
         self.head_dim = getattr(
             self.hf_text_config,
@@ -318,45 +375,6 @@ class ModelConfig:
         )
         self.vocab_size = self.hf_text_config.vocab_size

-        # Verify quantization
-        self._verify_quantization()
-
-        # Verify dual-chunk attention config
-        self._verify_dual_chunk_attention_config()
-
-        # Cache attributes
-        self.hf_eos_token_id = self.get_hf_eos_token_id()
-
-        # multimodal
-        self.image_token_id = getattr(
-            self.hf_config, "image_token_id", None
-        ) or getattr(self.hf_config, "image_token_index", None)
-
-    @staticmethod
-    def from_server_args(
-        server_args: ServerArgs,
-        model_path: str = None,
-        model_revision: str = None,
-        **kwargs,
-    ):
-        return ModelConfig(
-            model_path=model_path or server_args.model_path,
-            trust_remote_code=server_args.trust_remote_code,
-            revision=model_revision or server_args.revision,
-            context_length=server_args.context_length,
-            model_override_args=server_args.json_model_override_args,
-            is_embedding=server_args.is_embedding,
-            enable_multimodal=server_args.enable_multimodal,
-            dtype=server_args.dtype,
-            quantization=server_args.quantization,
-            hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
-            model_impl=server_args.model_impl,
-            remote_instance_weight_loader_seed_instance_ip=server_args.remote_instance_weight_loader_seed_instance_ip,
-            remote_instance_weight_loader_seed_instance_service_port=server_args.remote_instance_weight_loader_seed_instance_service_port,
-            remote_instance_weight_loader_send_weights_group_ports=server_args.remote_instance_weight_loader_send_weights_group_ports,
-            **kwargs,
-        )
-
     def get_total_num_attention_heads(self) -> int:
         return self.num_attention_heads
@@ -591,7 +609,7 @@ class ModelConfig:
                 "sparse_attention_enabled"
             ] = True

-    def get_hf_eos_token_id(self) -> Optional[Set[int]]:
+    def _get_hf_eos_token_id(self) -> Optional[Set[int]]:
         eos_ids = getattr(self.hf_config, "eos_token_id", None)
         if eos_ids is not None:
             # it can be either int or list of int
@@ -611,7 +629,7 @@ class ModelConfig:
             eos_ids = eos_ids | generation_eos_ids
         return eos_ids

-    def maybe_pull_model_tokenizer_from_remote(self) -> None:
+    def _maybe_pull_model_tokenizer_from_remote(self) -> None:
         """
         Pull the model config files to a temporary
         directory in case of remote.
......
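A short sketch of the new `ModelConfig.from_server_args` entry point introduced above. Both import paths are assumptions about the repository layout; the method signature, the `ServerArgs` fields it reads, and the `is_draft_model` parameter come from the diff.

```python
# Usage sketch only, not part of the commit. Import paths are assumed;
# from_server_args and its keyword overrides come from the diff above.
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.server_args import ServerArgs

server_args = ServerArgs(model_path="meta-llama/Llama-3.1-8B-Instruct")

# Build the target-model config directly from the parsed server arguments.
model_config = ModelConfig.from_server_args(server_args)

# Extra keyword arguments are forwarded to ModelConfig.__init__, e.g. a draft
# model reusing the same server arguments but a different checkpoint
# (is_draft_model mirrors the new self.is_draft_model field in __init__).
draft_config = ModelConfig.from_server_args(
    server_args,
    model_path="example-org/draft-model",  # hypothetical checkpoint
    is_draft_model=True,
)
```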
 import logging

+import torch

 from sglang.srt.utils import get_bool_env_var, get_device_sm, is_blackwell

 logger = logging.getLogger(__name__)
......