Unverified Commit f154bb9f authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Simplify weight loading in Transformers backend (#21382)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 3ec7170f
...@@ -177,7 +177,7 @@ TEXT_GENERATION_MODELS = { ...@@ -177,7 +177,7 @@ TEXT_GENERATION_MODELS = {
"ai21labs/Jamba-tiny-dev": PPTestSettings.fast(), "ai21labs/Jamba-tiny-dev": PPTestSettings.fast(),
"meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.detailed(), "meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.detailed(),
# Tests TransformersForCausalLM # Tests TransformersForCausalLM
"ArthurZ/Ilama-3.2-1B": PPTestSettings.fast(), "hmellor/Ilama-3.2-1B": PPTestSettings.fast(),
"openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(), "openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(),
"openbmb/MiniCPM3-4B": PPTestSettings.fast(), "openbmb/MiniCPM3-4B": PPTestSettings.fast(),
# Uses Llama # Uses Llama
...@@ -249,7 +249,7 @@ TEST_MODELS = [ ...@@ -249,7 +249,7 @@ TEST_MODELS = [
# [LANGUAGE GENERATION] # [LANGUAGE GENERATION]
"microsoft/Phi-3.5-MoE-instruct", "microsoft/Phi-3.5-MoE-instruct",
"meta-llama/Llama-3.2-1B-Instruct", "meta-llama/Llama-3.2-1B-Instruct",
"ArthurZ/Ilama-3.2-1B", "hmellor/Ilama-3.2-1B",
"ibm/PowerLM-3b", "ibm/PowerLM-3b",
"deepseek-ai/DeepSeek-V2-Lite-Chat", "deepseek-ai/DeepSeek-V2-Lite-Chat",
# [LANGUAGE EMBEDDING] # [LANGUAGE EMBEDDING]
......
...@@ -9,7 +9,7 @@ from vllm.platforms import current_platform ...@@ -9,7 +9,7 @@ from vllm.platforms import current_platform
from ..utils import create_new_process_for_each_test, multi_gpu_test from ..utils import create_new_process_for_each_test, multi_gpu_test
MODEL_PATH = "ArthurZ/ilama-3.2-1B" MODEL_PATH = "hmellor/Ilama-3.2-1B"
PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501 PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501
......
...@@ -500,7 +500,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { ...@@ -500,7 +500,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
} }
_TRANSFORMERS_MODELS = { _TRANSFORMERS_MODELS = {
"TransformersForCausalLM": _HfExamplesInfo("ArthurZ/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501 "TransformersForCausalLM": _HfExamplesInfo("hmellor/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501
"TransformersForMultimodalLM": _HfExamplesInfo("OpenGVLab/InternVL3-1B-hf"), "TransformersForMultimodalLM": _HfExamplesInfo("OpenGVLab/InternVL3-1B-hf"),
} }
......
...@@ -56,7 +56,7 @@ def check_implementation( ...@@ -56,7 +56,7 @@ def check_implementation(
"model,model_impl", "model,model_impl",
[ [
("meta-llama/Llama-3.2-1B-Instruct", "transformers"), ("meta-llama/Llama-3.2-1B-Instruct", "transformers"),
("ArthurZ/Ilama-3.2-1B", "auto"), # CUSTOM CODE ("hmellor/Ilama-3.2-1B", "auto"), # CUSTOM CODE
]) # trust_remote_code=True by default ]) # trust_remote_code=True by default
def test_models( def test_models(
hf_runner: type[HfRunner], hf_runner: type[HfRunner],
......
...@@ -624,13 +624,9 @@ class SupportsQuant: ...@@ -624,13 +624,9 @@ class SupportsQuant:
instance.quant_config = quant_config instance.quant_config = quant_config
# apply model mappings to config for proper config-model matching # apply model mappings to config for proper config-model matching
# NOTE: `TransformersForCausalLM` is not supported due to how this if (hf_to_vllm_mapper := instance.hf_to_vllm_mapper) is not None:
# class defines `hf_to_vllm_mapper` as a post-init `@property`. instance.quant_config.apply_vllm_mapper(hf_to_vllm_mapper)
# After this is fixed, get `instance.hf_to_vllm_mapper` directly if instance.packed_modules_mapping is not None:
if getattr(instance, "hf_to_vllm_mapper", None) is not None:
instance.quant_config.apply_vllm_mapper(
instance.hf_to_vllm_mapper)
if getattr(instance, "packed_modules_mapping", None) is not None:
instance.quant_config.packed_modules_mapping.update( instance.quant_config.packed_modules_mapping.update(
instance.packed_modules_mapping) instance.packed_modules_mapping)
......
...@@ -414,7 +414,7 @@ class ConfigOverride: ...@@ -414,7 +414,7 @@ class ConfigOverride:
setattr(self.config, key, value) setattr(self.config, key, value)
class TransformersModel(nn.Module): class TransformersModel:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
...@@ -454,9 +454,6 @@ class TransformersModel(nn.Module): ...@@ -454,9 +454,6 @@ class TransformersModel(nn.Module):
# method after v4.54.0 is released # method after v4.54.0 is released
self.text_config._attn_implementation = "vllm" self.text_config._attn_implementation = "vllm"
with init_on_device_without_buffers("meta"), config_override: with init_on_device_without_buffers("meta"), config_override:
# FIXME(Isotr0py): We need to refactor this part in the future to
# avoid registering an extra model layer, otherwise we will need a
# weights mapper to rename weights.
self.model: PreTrainedModel = AutoModel.from_config( self.model: PreTrainedModel = AutoModel.from_config(
config, config,
torch_dtype=model_config.dtype, torch_dtype=model_config.dtype,
...@@ -620,9 +617,6 @@ class TransformersModel(nn.Module): ...@@ -620,9 +617,6 @@ class TransformersModel(nn.Module):
for child in module.children(): for child in module.children():
self.init_parameters(child) self.init_parameters(child)
def get_input_embeddings(self) -> nn.Module:
return self.model.get_input_embeddings()
def forward( def forward(
self, self,
input_ids: Optional[torch.Tensor], input_ids: Optional[torch.Tensor],
...@@ -694,7 +688,9 @@ class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA, ...@@ -694,7 +688,9 @@ class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA,
self.config = config self.config = config
self.model = TransformersModel(vllm_config=vllm_config, prefix=prefix) self.transformers_model = TransformersModel(vllm_config=vllm_config,
prefix=prefix)
self.model = self.transformers_model.model
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
...@@ -716,22 +712,7 @@ class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA, ...@@ -716,22 +712,7 @@ class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA,
self.lm_head = PPMissingLayer() self.lm_head = PPMissingLayer()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.transformers_model.make_empty_intermediate_tensors)
# FIXME(Isotr0py): Don't use any weights mapper for Transformers backend,
# this makes thing complicated. We need to remove this mapper after refactor
# `TransformersModel` in the future.
# NOTE: `SupportsQuant` can be updated after property decorator is removed
@property
def hf_to_vllm_mapper(self):
prefix_mapper = {
name: "model." + name
for name, _ in self.model.model.named_children()
}
return WeightsMapper(
orig_to_new_substr={"model.": "model.model."},
orig_to_new_prefix=prefix_mapper,
)
def forward( def forward(
self, self,
...@@ -740,8 +721,9 @@ class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA, ...@@ -740,8 +721,9 @@ class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(input_ids, positions, intermediate_tensors, model_output = self.transformers_model.forward(input_ids, positions,
inputs_embeds) intermediate_tensors,
inputs_embeds)
return model_output return model_output
def compute_logits( def compute_logits(
...@@ -755,12 +737,10 @@ class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA, ...@@ -755,12 +737,10 @@ class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA,
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader( skip_prefixes = ["lm_head."
self, ] if self.config.tie_word_embeddings else None
skip_prefixes=(["lm_head."] loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
if self.config.tie_word_embeddings else None), return loader.load_weights(weights)
)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
@MULTIMODAL_REGISTRY.register_processor( @MULTIMODAL_REGISTRY.register_processor(
...@@ -772,6 +752,29 @@ class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA, ...@@ -772,6 +752,29 @@ class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
embedding_padding_modules = ["lm_head"] embedding_padding_modules = ["lm_head"]
embedding_modules = ["embed_tokens"] embedding_modules = ["embed_tokens"]
# Backwards compatibility for prev released models. State dicts back then
# had different formats and cannot be loaded with `AutoModel` mapping as is
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"language_model.model": "model.language_model",
"text_model.model": "model.text_model",
"vision_tower": "model.vision_tower",
"vqmodel": "model.vqmodel",
"visual": "model.visual",
"vision_model": "model.vision_model",
"vision_embed_tokens": "model.vision_embed_tokens",
"image_newline": "model.image_newline",
"multi_modal_projector": "model.multi_modal_projector",
"text_model.lm_head": "lm_head",
"language_model.lm_head": "lm_head",
# Qwen models used "model" as the name for the language model.
# Therefore, we must map each of submodule explicitly to avoid
# conflicts with newer models that use "model.language_model".
"model.embed_tokens": "model.language_model.embed_tokens",
"model.layers": "model.language_model.layers",
"model.norm": "model.language_model.norm",
})
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config: PretrainedConfig = vllm_config.model_config.hf_config config: PretrainedConfig = vllm_config.model_config.hf_config
...@@ -780,7 +783,9 @@ class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA, ...@@ -780,7 +783,9 @@ class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
self.config = config self.config = config
self.dtype = vllm_config.model_config.dtype self.dtype = vllm_config.model_config.dtype
self.model = TransformersModel(vllm_config=vllm_config, prefix=prefix) self.transformers_model = TransformersModel(vllm_config=vllm_config,
prefix=prefix)
self.model = self.transformers_model.model
text_config = config.get_text_config() text_config = config.get_text_config()
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
...@@ -803,32 +808,7 @@ class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA, ...@@ -803,32 +808,7 @@ class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
self.lm_head = PPMissingLayer() self.lm_head = PPMissingLayer()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.transformers_model.make_empty_intermediate_tensors)
@property
def hf_to_vllm_mapper(self):
# Backwards compatibility for prev released models
# State dicts back then had different formats
# and cannot be loaded with `AutoModel` mapping
# as is
prefix_mapper = {
"language_model.model": "model.language_model",
"text_model.model": "model.text_model",
"vision_tower": "model.vision_tower",
"vqmodel": "model.vqmodel",
"vision_model": "model.vision_model",
"vision_embed_tokens": "model.vision_embed_tokens",
"image_newline": "model.image_newline",
"multi_modal_projector": "model.multi_modal_projector",
"text_model.lm_head": "lm_head",
"language_model.lm_head": "lm_head",
}
# Don't change the order for QwenVL
if 'Qwen2' in self.config.__class__.__name__:
prefix_mapper["model"] = "model.language_model"
prefix_mapper["visual"] = "model.visual"
return WeightsMapper(orig_to_new_prefix=prefix_mapper, )
def forward( def forward(
self, self,
...@@ -848,8 +828,9 @@ class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA, ...@@ -848,8 +828,9 @@ class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
input_ids, multimodal_embeds) input_ids, multimodal_embeds)
input_ids = None input_ids = None
model_output = self.model(input_ids, positions, intermediate_tensors, model_output = self.transformers_model.forward(input_ids, positions,
inputs_embeds) intermediate_tensors,
inputs_embeds)
return model_output return model_output
def compute_logits( def compute_logits(
...@@ -898,7 +879,7 @@ class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA, ...@@ -898,7 +879,7 @@ class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
if isinstance(num_image_patches, list): if isinstance(num_image_patches, list):
num_image_patches = torch.cat(num_image_patches) num_image_patches = torch.cat(num_image_patches)
vision_embeddings = self.model.model.get_image_features( vision_embeddings = self.model.get_image_features(
pixel_values, pixel_values,
**{ **{
k: v.flatten(0, 1) k: v.flatten(0, 1)
...@@ -928,7 +909,7 @@ class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA, ...@@ -928,7 +909,7 @@ class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings=None, multimodal_embeddings=None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.model.model.get_input_embeddings()(input_ids) inputs_embeds = self.model.get_input_embeddings()(input_ids)
if (multimodal_embeddings is not None if (multimodal_embeddings is not None
and len(multimodal_embeddings) != 0): and len(multimodal_embeddings) != 0):
mask = (input_ids == self.config.image_token_id) mask = (input_ids == self.config.image_token_id)
......
...@@ -10,7 +10,7 @@ MODELS_ON_S3 = [ ...@@ -10,7 +10,7 @@ MODELS_ON_S3 = [
"allenai/OLMoE-1B-7B-0924-Instruct", "allenai/OLMoE-1B-7B-0924-Instruct",
"amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test", "amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test",
"AMead10/Llama-3.2-1B-Instruct-AWQ", "AMead10/Llama-3.2-1B-Instruct-AWQ",
"ArthurZ/Ilama-3.2-1B", "hmellor/Ilama-3.2-1B",
"BAAI/bge-base-en-v1.5", "BAAI/bge-base-en-v1.5",
"BAAI/bge-multilingual-gemma2", "BAAI/bge-multilingual-gemma2",
"BAAI/bge-reranker-v2-m3", "BAAI/bge-reranker-v2-m3",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment