"vscode:/vscode.git/clone" did not exist on "4aaaf8c8ce517dd97a1cb2610e57fc161755a3a3"
Unverified commit 423e9f1c, authored by Harry Mellor and committed by GitHub
Browse files

Use Transformers helper `get_text_config()` instead of checking for `text_config` (#17105)


Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 0bd7f8fc
...@@ -553,9 +553,8 @@ def main(args: argparse.Namespace): ...@@ -553,9 +553,8 @@ def main(args: argparse.Namespace):
intermediate_size = config.moe_intermediate_size intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size shard_intermediate_size = 2 * intermediate_size // args.tp_size
else: else:
if not hasattr(config, "hidden_size"):
# Support for llama4 # Support for llama4
config = config.text_config config = config.get_text_config()
# Default: Mixtral. # Default: Mixtral.
E = config.num_local_experts E = config.num_local_experts
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
......
...@@ -24,10 +24,7 @@ def test_can_initialize(model_arch): ...@@ -24,10 +24,7 @@ def test_can_initialize(model_arch):
def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig: def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
hf_config.update(model_info.hf_overrides) hf_config.update(model_info.hf_overrides)
if hasattr(hf_config, "text_config"): text_config = hf_config.get_text_config()
text_config: PretrainedConfig = hf_config.text_config
else:
text_config = hf_config
text_config.update({ text_config.update({
"num_layers": 1, "num_layers": 1,
......
...@@ -2841,12 +2841,10 @@ def _get_and_verify_dtype( ...@@ -2841,12 +2841,10 @@ def _get_and_verify_dtype(
) -> torch.dtype: ) -> torch.dtype:
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
# because config.torch_dtype can be None. # because config.torch_dtype can be None.
config_dtype = getattr(config, "torch_dtype", None) config_dtype = getattr(config.get_text_config(), "torch_dtype", None)
# Fallbacks for multi-modal models if the root config # Fallback for multi-modal models if the root config
# does not define torch_dtype # does not define torch_dtype
if config_dtype is None and hasattr(config, "text_config"):
config_dtype = getattr(config.text_config, "torch_dtype", None)
if config_dtype is None and hasattr(config, "vision_config"): if config_dtype is None and hasattr(config, "vision_config"):
config_dtype = getattr(config.vision_config, "torch_dtype", None) config_dtype = getattr(config.vision_config, "torch_dtype", None)
......
...@@ -760,19 +760,22 @@ def get_hf_text_config(config: PretrainedConfig): ...@@ -760,19 +760,22 @@ def get_hf_text_config(config: PretrainedConfig):
"""Get the "sub" config relevant to llm for multi modal models. """Get the "sub" config relevant to llm for multi modal models.
No op for pure text models. No op for pure text models.
""" """
if hasattr(config, "text_config"): # This block should be unnecessary after https://github.com/huggingface/transformers/pull/37517
# The code operates under the assumption that text_config should have if hasattr(config, "thinker_config"):
# `num_attention_heads` (among others). Assert here to fail early
# if transformers config doesn't align with this assumption.
assert hasattr(config.text_config, "num_attention_heads")
return config.text_config
elif hasattr(config, "thinker_config"):
# TODO(suyang.fy): Refactor code. # TODO(suyang.fy): Refactor code.
# For Qwen2.5-Omni, change hf_text_config to # For Qwen2.5-Omni, change hf_text_config to
# thinker_config.text_config. # thinker_config.text_config.
return config.thinker_config.text_config return config.thinker_config.text_config
else:
return config text_config = config.get_text_config()
if text_config is not config:
# The code operates under the assumption that text_config should have
# `num_attention_heads` (among others). Assert here to fail early
# if transformers config doesn't align with this assumption.
assert hasattr(text_config, "num_attention_heads")
return text_config
def try_get_generation_config( def try_get_generation_config(
......
...@@ -508,13 +508,8 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]): ...@@ -508,13 +508,8 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
logger.warning("Regarding multimodal models, vLLM currently " logger.warning("Regarding multimodal models, vLLM currently "
"only supports adding LoRA to language model.") "only supports adding LoRA to language model.")
# It's necessary to distinguish between the max_position_embeddings # Use get_text_config() in case of multimodal models
# of VLMs and LLMs. text_config = self.model_config.hf_config.get_text_config()
if hasattr(self.model.config, "max_position_embeddings"):
max_pos_embeddings = self.model.config.max_position_embeddings
else:
max_pos_embeddings = (
self.model.config.text_config.max_position_embeddings)
self.lora_manager = LRUCacheWorkerLoRAManager( self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs, self.scheduler_config.max_num_seqs,
...@@ -524,7 +519,7 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]): ...@@ -524,7 +519,7 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
self.device, self.device,
self.model.embedding_modules, self.model.embedding_modules,
self.model.embedding_padding_modules, self.model.embedding_padding_modules,
max_position_embeddings=max_pos_embeddings, max_position_embeddings=text_config.max_position_embeddings,
) )
self.model = self.lora_manager.create_lora_manager(self.model) self.model = self.lora_manager.create_lora_manager(self.model)
......
...@@ -724,14 +724,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): ...@@ -724,14 +724,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
"Bias support in LoRA is not enabled in HPU yet." "Bias support in LoRA is not enabled in HPU yet."
assert not self.lora_config.fully_sharded_loras, \ assert not self.lora_config.fully_sharded_loras, \
"Fully sharded LoRAs is not enabled in HPU yet." "Fully sharded LoRAs is not enabled in HPU yet."
# It's necessary to distinguish between the
# max_position_embeddings of VLMs and LLMs. # Use get_text_config() in case of multimodal models
if hasattr(self.model.config, "max_position_embeddings"): text_config = self.model_config.hf_config.get_text_config()
max_pos_embeddings = (
self.model.config.max_position_embeddings)
else:
max_pos_embeddings = (
self.model.config.text_config.max_position_embeddings)
self.lora_manager = LRUCacheWorkerLoRAManager( self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs, self.scheduler_config.max_num_seqs,
...@@ -741,7 +736,8 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): ...@@ -741,7 +736,8 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
self.device, self.device,
self.model.embedding_modules, self.model.embedding_modules,
self.model.embedding_padding_modules, self.model.embedding_padding_modules,
max_position_embeddings=max_pos_embeddings, max_position_embeddings=text_config.
max_position_embeddings,
) )
self.model = self.lora_manager.create_lora_manager(self.model) self.model = self.lora_manager.create_lora_manager(self.model)
......
...@@ -1130,14 +1130,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1130,14 +1130,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
logger.warning( logger.warning(
"Regarding multimodal models, vLLM currently " "Regarding multimodal models, vLLM currently "
"only supports adding LoRA to language model.") "only supports adding LoRA to language model.")
# It's necessary to distinguish between the
# max_position_embeddings of VLMs and LLMs. # Use get_text_config() in case of multimodal models
if hasattr(self.model.config, "max_position_embeddings"): text_config = self.model_config.hf_config.get_text_config()
max_pos_embeddings = (
self.model.config.max_position_embeddings)
else:
max_pos_embeddings = (
self.model.config.text_config.max_position_embeddings)
self.lora_manager = LRUCacheWorkerLoRAManager( self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs, self.scheduler_config.max_num_seqs,
...@@ -1147,7 +1142,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1147,7 +1142,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.device, self.device,
self.model.embedding_modules, self.model.embedding_modules,
self.model.embedding_padding_modules, self.model.embedding_padding_modules,
max_position_embeddings=max_pos_embeddings, max_position_embeddings=text_config.
max_position_embeddings,
) )
self.model = self.lora_manager.create_lora_manager(self.model) self.model = self.lora_manager.create_lora_manager(self.model)
time_after_load = time.perf_counter() time_after_load = time.perf_counter()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment