Unverified commit 16ee07f2, authored by Isotr0py, committed by GitHub
Browse files

[Model] Refactor Molmo weights loading to use AutoWeightsLoader (#10771)


Signed-off-by: Isotr0py <2037008807@qq.com>
parent 40bc2425
...@@ -3,7 +3,7 @@ import re ...@@ -3,7 +3,7 @@ import re
from array import array from array import array
from dataclasses import dataclass from dataclasses import dataclass
from functools import lru_cache, partial from functools import lru_cache, partial
from typing import Iterable, List, Mapping, Optional, Tuple, TypedDict from typing import Iterable, List, Mapping, Optional, Set, Tuple, TypedDict
import torch import torch
from einops import rearrange from einops import rearrange
...@@ -44,7 +44,8 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, ...@@ -44,7 +44,8 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
from vllm.transformers_utils.processor import get_processor from vllm.transformers_utils.processor import get_processor
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (get_vit_attn_backend, from .utils import (AutoWeightsLoader, WeightsMapper, get_vit_attn_backend,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
...@@ -720,6 +721,42 @@ class MolmoVisionBackbone(nn.Module): ...@@ -720,6 +721,42 @@ class MolmoVisionBackbone(nn.Module):
# image_features: (batch_size, num_image, num_patch, d_model) # image_features: (batch_size, num_image, num_patch, d_model)
return image_features return image_features
def load_weights(self, weights: Iterable[Tuple[str,
                                               torch.Tensor]]) -> Set[str]:
    """Load weights into this module, fusing gate/up projection shards.

    A tensor whose name contains ``gate_proj`` or ``up_proj`` is routed
    into the fused ``gate_up_proj`` parameter as shard 0 or shard 1
    respectively. Every other tensor is loaded into the parameter with
    the same name.

    Args:
        weights: ``(name, tensor)`` pairs; after renaming, each name is
            looked up in ``self.named_parameters()``.

    Returns:
        The (renamed) names of all parameters that were loaded.

    Raises:
        KeyError: if a tensor that is not skipped matches no parameter.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    params_dict = dict(self.named_parameters())
    loaded_params: Set[str] = set()

    for name, loaded_weight in weights:
        # Resolve the fused-parameter mapping exactly once and stop at the
        # first hit. "up_proj" is a substring of "gate_up_proj", so letting
        # later entries see the renamed name would rewrite it a second time
        # into "gate_gate_up_proj".
        shard_id = None
        for param_name, shard_name, candidate_id in stacked_params_mapping:
            if shard_name in name:
                name = name.replace(shard_name, param_name)
                shard_id = candidate_id
                break

        # Skip loading extra bias for GPTQ models.
        if name.endswith(".bias") and name not in params_dict:
            continue
        if is_pp_missing_parameter(name, self):
            continue

        param = params_dict[name]
        if shard_id is None:
            # Plain parameter: fall back to the default loader when the
            # parameter does not provide its own.
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
        else:
            # Fused parameter: its own loader places the shard.
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
        loaded_params.add(name)
    return loaded_params
@support_torch_compile @support_torch_compile
class MolmoModel(nn.Module): class MolmoModel(nn.Module):
...@@ -804,6 +841,28 @@ class MolmoModel(nn.Module): ...@@ -804,6 +841,28 @@ class MolmoModel(nn.Module):
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
                                               torch.Tensor]]) -> Set[str]:
    """Load weights into this module's parameters.

    Any tensor whose name contains ``gate_up_proj`` is split into two
    halves along dim 0 and the halves are swapped before loading.

    Args:
        weights: ``(name, tensor)`` pairs; each name is looked up in
            ``self.named_parameters()``.

    Returns:
        The names of all parameters that were loaded.
    """
    known_params = dict(self.named_parameters())
    loaded: Set[str] = set()

    for name, tensor in weights:
        if "gate_up_proj" in name:
            # Swap the two halves: the first half of the incoming tensor
            # becomes the second half of what is loaded.
            first_half, second_half = tensor.chunk(2, dim=0)
            tensor = torch.cat([second_half, first_half], dim=0)

        is_extra_bias = name.endswith(".bias") and name not in known_params
        if is_extra_bias or is_pp_missing_parameter(name, self):
            continue

        target = known_params[name]
        load_fn = getattr(target, "weight_loader", default_weight_loader)
        load_fn(target, tensor)
        loaded.add(name)

    return loaded
cached_get_processor = lru_cache(get_processor) cached_get_processor = lru_cache(get_processor)
...@@ -1200,103 +1259,53 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -1200,103 +1259,53 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
hf_to_vllm_mapper = WeightsMapper(
params_mapping = [ orig_to_new_substr={
("model.transformer.ln_f.weight", "model.norm.weight"), # vision backbone mapping
("attn_out", "self_attn.o_proj"), "image_projector.w1.": "image_projector.gate_proj.",
("att_proj", "self_attn.qkv_proj"), "image_projector.w3.": "image_projector.up_proj.",
("q_norm", "self_attn.q_norm"), "image_projector.w2.": "image_projector.down_proj.",
("k_norm", "self_attn.k_norm"), # language backbone mapping
("attn_norm", "input_layernorm"), "att_proj": "self_attn.qkv_proj",
("ff_norm", "post_attention_layernorm"), "attn_out": "self_attn.o_proj",
] "q_norm": "self_attn.q_norm",
"k_norm": "self_attn.k_norm",
params_dict = dict(self.named_parameters(remove_duplicate=False)) "ff_proj": "mlp.gate_up_proj",
"ff_out": "mlp.down_proj",
embedding_weight = dict() "attn_norm": "input_layernorm",
projector_weight = dict() "ff_norm": "post_attention_layernorm",
for name, loaded_weight in weights: },
if "rotary_emb.inv_freq" in name: orig_to_new_prefix={
continue # vision backbone mapping
if self.config.tie_word_embeddings and "lm_head.weight" in name: "model.vision_backbone.": "vision_backbone.",
continue # language backbone mapping
"model.transformer.blocks.": "model.layers.",
if "wte.embedding" in name: "model.transformer.ln_f.": "model.norm.",
embedding_weight["embedding"] = loaded_weight # lm_head is renamed to model.transformer.mlp.down_proj firstly,
continue # we need to run a second renaming for it
"model.transformer.mlp.down_proj.": "lm_head.",
if "wte.new_embedding" in name: },
embedding_weight["new_embedding"] = loaded_weight )
continue loader = AutoWeightsLoader(self)
weights = _get_weights_with_merged_embedding(weights)
if "vision_backbone" in name: return loader.load_weights(weights, mapper=hf_to_vllm_mapper)
if name.startswith("model"):
name = name[len("model."):]
if 'image_projector' in name: def _get_weights_with_merged_embedding(
if 'w1' in name: weights: Iterable[Tuple[str, torch.Tensor]]
projector_weight['gate_proj'] = loaded_weight ) -> Iterable[Tuple[str, torch.Tensor]]:
elif 'w3' in name: embedding_weights = {}
projector_weight['up_proj'] = loaded_weight for name, weight in weights:
elif 'w2' in name: if "wte.embedding" in name:
projector_weight['down_proj'] = loaded_weight embedding_weights["embedding"] = weight
else: elif "wte.new_embedding" in name:
raise ValueError( embedding_weights["new_embedding"] = weight
f"Unexpected projector weight: {name}") else:
continue yield (name, weight)
else: # this is compatible with most of quantization,
if "transformer.blocks" in name: # because they won't quantize embed_tokens
name = name.replace("transformer.blocks", "layers") embedding_weights = torch.cat(
[embedding_weights["embedding"], embedding_weights["new_embedding"]],
if "ff_proj" in name: dim=0,
name = name.replace("ff_proj", "mlp.gate_up_proj") )
assert 'weight' in name yield ("model.embed_tokens.weight", embedding_weights)
up_weight, gate_weight = loaded_weight.chunk(2, dim=0)
loaded_weight = torch.cat([gate_weight, up_weight], dim=0)
elif "ff_out" in name:
if "layers" in name:
name = name.replace("ff_out", "mlp.down_proj")
else:
# lm head
name = name.replace("model.transformer.ff_out",
"lm_head")
else:
for (param_name, weight_name) in params_mapping:
if param_name in name:
name = name.replace(param_name, weight_name)
break
try:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
except KeyError:
raise ValueError(f"Unexpected weight: {name}") from None
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
gate_up_proj_weight = torch.cat(
[projector_weight["gate_proj"], projector_weight["up_proj"]],
dim=0)
name = "vision_backbone.image_projector.gate_up_proj.weight"
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, gate_up_proj_weight)
down_proj_weight = projector_weight["down_proj"]
name = "vision_backbone.image_projector.down_proj.weight"
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, down_proj_weight)
embedding_weight = torch.cat(
[embedding_weight["embedding"], embedding_weight["new_embedding"]],
dim=0)
name = "model.embed_tokens.weight"
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, embedding_weight)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment