# coding=utf-8 # Adapted from # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py # Copyright 2023 The vLLM team. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX # and OPT implementations in this library. It has been modified from its # original forms to accommodate minor architectural differences compared # to GPT-NeoX and OPT used by the Meta AI team that trained the model. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
import math
import re
from functools import partial
from typing import Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
import torch.types
from PIL import Image
from torch import nn
from torch.nn.init import trunc_normal_
from transformers.configuration_utils import PretrainedConfig

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsVision
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.models.minicpm import MiniCPMModel
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import (cached_get_image_processor,
                                   cached_get_tokenizer)
from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData

_KEYS_TO_MODIFY_MAPPING = {
    "llm.lm_head": "lm_head",
    "llm.model": "llm",
}


def get_abs_pos(abs_pos: torch.Tensor, tgt_size: torch.Tensor):
    """Bicubically resample a square absolute position table to a new grid.

    Args:
        abs_pos: position table of shape ``(L, C)``. ``L`` is treated as a
            square ``side * side`` grid with ``side = int(sqrt(L))``.
        tgt_size: target grid ``(H, W)``; only its first two entries are read.

    Returns:
        Tensor of shape ``(H * W, C)`` in the same dtype as ``abs_pos``.
    """
    original_dtype = abs_pos.dtype
    side = int(math.sqrt(abs_pos.size(0)))
    # (L, C) -> (1, side, side, C) -> (1, C, side, side): F.interpolate
    # expects channels-first input. Interpolation runs in float32.
    grid = abs_pos.float().reshape(1, side, side, -1).permute(0, 3, 1, 2)
    resized = F.interpolate(grid,
                            size=(tgt_size[0], tgt_size[1]),
                            mode="bicubic",
                            align_corners=False)
    # (1, C, H, W) -> (1, H, W, C) -> (H * W, C), back in the input dtype.
    flat = resized.permute(0, 2, 3, 1).flatten(0, 2)
    return flat.to(dtype=original_dtype)
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_2d_sincos_pos_embed(embed_dim: int,
                            grid_size: Union[int, Tuple[int, int]],
                            cls_token: bool = False,
                            version: Tuple[int, int] = (2, 0)):
    """Build a fixed 2D sine-cosine positional embedding table.

    Args:
        embed_dim: output width D. It is halved once per axis and again for
            the sin/cos split, so it must be divisible by 4.
        grid_size: grid height and width. Pass a single int for a square
            grid or an ``(H, W)`` pair.
        cls_token: for ``version == (2, 0)`` only, prepend one all-zero row.
            Ignored for other versions.
        version: MiniCPM-V version; selects the output layout.

    Returns:
        ndarray of shape ``(H * W, D)`` (``(1 + H * W, D)`` with
        ``cls_token``) for version (2, 0); otherwise ``(H, W, D)``.
    """
    if isinstance(grid_size, int):
        grid_h_size, grid_w_size = grid_size, grid_size
    else:
        grid_h_size, grid_w_size = grid_size[0], grid_size[1]

    grid_h = np.arange(grid_h_size, dtype=np.float32)
    grid_w = np.arange(grid_w_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)  # (2, H, W)

    if version == (2, 0):
        grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
        pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
        if cls_token:
            pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed],
                                       axis=0)
    else:
        pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim: int,
                                      grid: np.ndarray,
                                      version: Tuple[int, int] = (2, 0)):
    """Encode a stacked 2-channel coordinate grid with sin-cos embeddings.

    The first ``embed_dim // 2`` output features encode ``grid[0]`` and the
    second half encode ``grid[1]``.

    Returns:
        ``(H * W, D)`` for version (2, 0), otherwise ``(H, W, D)``.
    """
    assert embed_dim % 2 == 0

    # Half of the dimensions for each grid channel.
    emb_h = get_1d_sincos_pos_embed_from_grid(
        embed_dim // 2, grid[0], version)  # (H*W, D/2) or (H, W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(
        embed_dim // 2, grid[1], version)  # (H*W, D/2) or (H, W, D/2)

    if version == (2, 0):
        emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    else:
        emb = np.concatenate([emb_h, emb_w], axis=-1)  # (H, W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim: int,
                                      pos: np.ndarray,
                                      version: Tuple[int, int] = (2, 0)):
    """Encode positions with sinusoids of geometrically spaced frequencies.

    Args:
        embed_dim: output dimension D for each position (must be even).
        pos: positions to encode. Any shape for version (2, 0), where it is
            flattened to ``(M,)``; exactly ``(H, W)`` otherwise.

    Returns:
        ``(M, D)`` for version (2, 0), otherwise ``(H, W, D)``. The first
        half of D is ``sin``, the second half ``cos``.
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    if version == (2, 0):
        pos = pos.reshape(-1)  # (M,)
        out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
        emb_sin = np.sin(out)  # (M, D/2)
        emb_cos = np.cos(out)  # (M, D/2)
        emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    else:
        out = np.einsum('hw,d->hwd', pos, omega)  # (H, W, D/2), outer product
        emb_sin = np.sin(out)  # (H, W, D/2)
        emb_cos = np.cos(out)  # (H, W, D/2)
        emb = np.concatenate([emb_sin, emb_cos], axis=-1)  # (H, W, D)
    return emb


class Resampler(nn.Module):
    """2D perceiver-style resampler.

    One cross-attention layer in which a fixed set of learnable queries
    attends to projected visual features, using 2D sin-cos positional
    embeddings on the keys.

    For version (2, 0) the number of queries is ``grid_size ** 2``;
    otherwise it is ``num_queries``. ``forward`` returns a tensor of shape
    ``(batch, num_queries, embed_dim)``.
    """

    default_norm_layer = partial(nn.LayerNorm, eps=1e-6)

    def __init__(self,
                 num_queries: int,
                 grid_size: int,
                 embed_dim: int,
                 num_heads: int,
                 kv_dim: Optional[int] = None,
                 norm_layer: nn.Module = default_norm_layer,
                 adaptive: bool = False,
                 max_size: Tuple[int, int] = (70, 70),
                 version: Tuple[int, int] = (2, 0)):
        super().__init__()

        self.version = version
        if self.version == (2, 0):
            self.num_queries = grid_size**2
        else:
            self.num_queries = num_queries
            self.max_size = max_size

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.adaptive = adaptive

        self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
        trunc_normal_(self.query, std=.02)

        # Project visual features to the LLM width only when they differ.
        if kv_dim is not None and kv_dim != embed_dim:
            self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
        else:
            self.kv_proj = nn.Identity()

        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.ln_q = norm_layer(embed_dim)
        self.ln_kv = norm_layer(embed_dim)

        self.ln_post = norm_layer(embed_dim)
        self.proj = nn.Parameter(
            (embed_dim**-0.5) * torch.randn(embed_dim, embed_dim))

        if self.version == (2, 0):
            # Frozen (H*W, D) table for the query grid.
            self.pos_embed = nn.Parameter(
                torch.from_numpy(
                    get_2d_sincos_pos_embed(
                        embed_dim, grid_size,
                        version=self.version)).float()).requires_grad_(False)
        else:
            # Non-persistent (H, W, D) cache; enlarged by _adjust_pos_cache.
            self._set_2d_pos_cache(self.max_size)

        self.apply(self._init_weights)

    def _set_2d_pos_cache(self,
                          max_size: Tuple[int, int],
                          device: torch.types.Device = 'cpu'):
        """(Re)build the non-persistent ``pos_embed`` buffer for ``max_size``."""
        pos_embed = torch.from_numpy(
            get_2d_sincos_pos_embed(self.embed_dim, max_size,
                                    version=self.version)).float().to(device)
        self.register_buffer("pos_embed", pos_embed, persistent=False)

    def _adjust_pos_cache(self, tgt_sizes: torch.Tensor,
                          device: torch.types.Device):
        """Grow the positional cache so it covers every ``(h, w)`` in
        ``tgt_sizes`` (a 2-column tensor)."""
        # Convert to Python ints: the sizes are handed to numpy to build the
        # table, and a plain-int max_size keeps the cache key unambiguous.
        max_h = int(torch.max(tgt_sizes[:, 0]))
        max_w = int(torch.max(tgt_sizes[:, 1]))
        if max_h > self.max_size[0] or max_w > self.max_size[1]:
            self.max_size = (max(max_h, self.max_size[0]),
                             max(max_w, self.max_size[1]))
            self._set_2d_pos_cache(self.max_size, device)

    def _init_weights(self, m: nn.Module):
        """Truncated-normal Linear weights, zero biases, identity LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_2_5(self,
                    x: torch.Tensor,
                    tgt_sizes: Optional[torch.Tensor] = None):
        """Resample a padded batch of variable-size patch grids (v2.5+).

        Args:
            x: ``(B, L, kv_dim)`` features, right-padded to the longest
                sample.
            tgt_sizes: ``(B, 2)`` patch-grid ``(h, w)`` per sample. Only the
                first ``h * w`` positions of each row are attended to.
        """
        assert x.shape[0] == tgt_sizes.shape[0]
        bs = x.shape[0]

        device = x.device
        dtype = x.dtype

        patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]

        self._adjust_pos_cache(tgt_sizes, device=device)

        max_patch_len = torch.max(patch_len)
        # True marks padded key positions that attention must ignore.
        key_padding_mask = torch.zeros((bs, max_patch_len),
                                       dtype=torch.bool,
                                       device=device)

        pos_embed = []
        for i in range(bs):
            tgt_h, tgt_w = tgt_sizes[i]
            pos_embed.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape(
                (tgt_h * tgt_w, -1)).to(dtype))  # patches * D
            key_padding_mask[i, patch_len[i]:] = True

        pos_embed = torch.nn.utils.rnn.pad_sequence(
            pos_embed, batch_first=True,
            padding_value=0.0).permute(1, 0, 2)  # BLD => L * B * D

        x = self.kv_proj(x)  # B * L * D
        x = self.ln_kv(x).permute(1, 0, 2)  # L * B * D

        q = self.ln_q(self.query)  # Q * D

        out = self.attn(
            self._repeat(q, bs),  # Q * B * D
            x + pos_embed,  # L * B * D +  L * B * D
            x,
            key_padding_mask=key_padding_mask)[0]
        # out: Q * B * D
        x = out.permute(1, 0, 2)  # B * Q * D

        x = self.ln_post(x)
        x = x @ self.proj
        return x

    def forward_2(self,
                  x: torch.Tensor,
                  tgt_sizes: Optional[torch.Tensor] = None,
                  attn_mask: Optional[torch.Tensor] = None):
        """Resample features of a single grid size (v2.0).

        Args:
            x: ``(B, L, kv_dim)`` features.
            tgt_sizes: ``(H, W)`` patch grid used to build the key
                positional embedding.
            attn_mask: forwarded to ``nn.MultiheadAttention``.
        """
        if self.adaptive:
            pos_embed = torch.Tensor(
                get_2d_sincos_pos_embed(self.embed_dim,
                                        tgt_sizes)).float().to(device=x.device,
                                                               dtype=x.dtype)
        else:
            pos_embed = get_abs_pos(self.pos_embed, tgt_sizes)

        x = self.kv_proj(x)
        x = self.ln_kv(x).permute(1, 0, 2)  # L * B * D

        N = x.shape[1]
        q = self.ln_q(self.query)
        # Queries carry the fixed query-grid embedding; keys get the
        # embedding of the input grid.
        out = self.attn(self._repeat(q, N) + self.pos_embed.unsqueeze(1),
                        x + pos_embed.unsqueeze(1),
                        x,
                        attn_mask=attn_mask)[0]
        x = out.permute(1, 0, 2)

        x = self.ln_post(x)
        x = x @ self.proj
        return x

    def forward(self,
                x: torch.Tensor,
                tgt_sizes: Optional[torch.Tensor] = None,
                attn_mask: Optional[torch.Tensor] = None):
        """Dispatch to the version-specific forward.

        ``attn_mask`` is only used by version (2, 0).
        """
        if self.version == (2, 0):
            return self.forward_2(x, tgt_sizes=tgt_sizes, attn_mask=attn_mask)
        else:
            return self.forward_2_5(x, tgt_sizes=tgt_sizes)

    def _repeat(self, query, N: int):
        """Tile ``(Q, D)`` queries to ``(Q, N, D)``."""
        return query.unsqueeze(1).repeat(1, N, 1)


def get_max_minicpmv_image_tokens(ctx: InputContext):
    """Return the config's ``query_num`` (default 64) as the max image tokens."""
    hf_config = ctx.get_hf_config(PretrainedConfig)
    return getattr(hf_config, "query_num", 64)


def dummy_seq_data_for_minicpmv(seq_len: int):
    """Return a ``SequenceData`` of ``seq_len`` zero token ids."""
    token_ids = [0] * seq_len
    return SequenceData(token_ids)


def dummy_image_for_minicpmv(hf_config: PretrainedConfig):
    """Return a black square RGB image of side ``hf_config.image_size``."""
    width = height = hf_config.image_size
    image = Image.new("RGB", (width, height), color=0)
    return {"image": image}


def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int):
    """Dummy (sequence, multimodal) data used for memory profiling."""
    hf_config = ctx.get_hf_config(PretrainedConfig)

    seq_data = dummy_seq_data_for_minicpmv(seq_len)
    mm_data = dummy_image_for_minicpmv(hf_config)

    return seq_data, mm_data


def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
    """Expand the image tag in the prompt into the slice-image placeholder.

    Inputs without image data, or whose prompt contains no image tag, are
    returned unchanged. At most one image tag is supported.
    """
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config

    tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                     trust_remote_code=True)

    prompt = llm_inputs.get("prompt")
    if prompt is None:
        token_ids = llm_inputs.get("prompt_token_ids")
        prompt = tokenizer.decode(token_ids)

    pattern = "(./)"
    image = multi_modal_data["image"]
    # The tag is matched literally (as str.split below does). Without
    # re.escape, "(", "." and ")" would be regex syntax and match other text.
    image_tags = re.findall(re.escape(pattern), prompt)
    assert len(image_tags) <= 1, (
        f"At most one image tag is supported, got {len(image_tags)}")
    if not image_tags:
        # Nothing to expand; splitting would yield a single chunk and the
        # text_chunks[1] access below would raise IndexError.
        return llm_inputs

    image_processor = cached_get_image_processor(model_config.tokenizer)
    text_chunks = prompt.split(pattern)
    new_prompt = text_chunks[0] \
        + image_processor.get_slice_image_placeholder(image.size) \
        + text_chunks[1]

    new_token_ids = tokenizer.encode(new_prompt)

    llm_inputs = LLMInputs(prompt_token_ids=new_token_ids,
                           prompt=new_prompt,
                           multi_modal_data=multi_modal_data)
    return llm_inputs


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_minicpmv_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_minicpmv)
@INPUT_REGISTRY.register_input_processor(input_processor_for_minicpmv)
class MiniCPMV(nn.Module, SupportsVision):
    """MiniCPM-V vision-language model.

    Composed of a vision tower (``vpm``), a ``Resampler`` that maps visual
    features to the LLM hidden size, and a language model (``llm``). All
    three are chosen by ``self.version``:

    - (2, 0): timm SigLIP + MiniCPM
    - (2, 5): Idefics2 vision tower + Llama
    - anything else: na_vit SigLIP + Qwen2
    """

    def __init__(
        self,
        config: PretrainedConfig,
        multimodal_config: MultiModalConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        if not hasattr(self.config, "version"):
            # Older configs have no explicit version; infer it from shape.
            if self.config.hidden_size == 2304 and self.config.query_num == 64:
                self.version = (2, 0)
            else:
                self.version = (2, 5)
        else:
            # e.g. "2.6" -> (2, 6)
            self.version = str(self.config.version).split(".")
            self.version = tuple([int(x) for x in self.version])
        self.llm = self.init_llm(config, cache_config, quant_config)
        self.vpm = self.init_vision_module()
        param_dtype = torch.get_default_dtype()
        self.vpm.to(dtype=param_dtype)
        self.vision_dim = self.vpm.embed_dim if self.version == (2, 0) \
            else self.vpm.embeddings.embed_dim
        self.embed_dim = self.config.hidden_size
        self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
        self.resampler.to(device="cuda", dtype=param_dtype)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def init_llm(self,
                 config: PretrainedConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        """Build the language model backbone for ``self.version``."""
        if self.version == (2, 0):
            return MiniCPMModel(config,
                                cache_config=cache_config,
                                quant_config=quant_config)
        elif self.version == (2, 5):
            return LlamaModel(config,
                              cache_config=cache_config,
                              quant_config=quant_config)
        else:
            return Qwen2Model(config,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def init_vision_module(self):
        """Build the vision tower for ``self.version``.

        If ``config.drop_vision_last_layer`` is set, the last transformer
        block is removed.

        Raises:
            ImportError: version (2, 0) needs ``timm``, and it is missing.
        """
        if self.version == (2, 0):
            try:
                import timm
            except ImportError as err:
                raise ImportError('Please install timm==0.9.10') from err

            default_dtype = torch.get_default_dtype()
            torch.set_default_dtype(torch.float16)
            try:
                model = timm.create_model(
                    'vit_so400m_patch14_siglip_384.webli',
                    pretrained=False,
                    num_classes=0,
                    dynamic_img_size=True,
                    dynamic_img_pad=True)
            finally:
                # Restore the global default even if model creation fails.
                torch.set_default_dtype(default_dtype)

            if isinstance(model, timm.models.VisionTransformer
                          ) and model.attn_pool is not None:
                model.attn_pool = torch.nn.Identity()

            if self.config.drop_vision_last_layer:
                model.blocks = model.blocks[:-1]
        elif self.version == (2, 5):
            from transformers.models.idefics2.modeling_idefics2 import (
                Idefics2VisionTransformer)
            model = Idefics2VisionTransformer(self.config.vision_config)
            if self.config.drop_vision_last_layer:
                model.encoder.layers = model.encoder.layers[:-1]
        else:
            from vllm.model_executor.models.na_vit import (
                SiglipVisionTransformer)
            if self.config._attn_implementation == 'flash_attention_2':
                self.config.vision_config._attn_implementation \
                    = 'flash_attention_2'
            else:
                # not support sdpa
                self.config.vision_config._attn_implementation = 'eager'
            model = SiglipVisionTransformer(self.config.vision_config)
            if self.config.drop_vision_last_layer:
                model.encoder.layers = model.encoder.layers[:-1]
        return model

    def init_resampler(self, embed_dim: int, vision_dim: int):
        """Build the Resampler with fp16 as the default dtype during init."""
        default_dtype = torch.get_default_dtype()
        torch.set_default_dtype(torch.float16)
        try:
            if self.version == (2, 0):
                resampler = Resampler(grid_size=int(
                    math.sqrt(self.config.query_num)),
                                      num_queries=None,
                                      embed_dim=embed_dim,
                                      num_heads=embed_dim // 128,
                                      kv_dim=vision_dim,
                                      adaptive=True,
                                      version=self.version)
            else:
                resampler = Resampler(num_queries=self.config.query_num,
                                      grid_size=None,
                                      embed_dim=embed_dim,
                                      num_heads=embed_dim // 128,
                                      kv_dim=vision_dim,
                                      adaptive=True,
                                      version=self.version)
        finally:
            # Restore the global default even if construction fails.
            torch.set_default_dtype(default_dtype)
        return resampler

    def get_vision_embedding(self,
                             pixel_values: List[List[torch.Tensor]],
                             patch_attn_mask: Optional[torch.Tensor] = None,
                             tgt_sizes: Optional[torch.Tensor] = None,
                             version: Tuple[int, int] = (2, 0)):
        """Run the vision tower and resampler.

        For version (2, 0), ``pixel_values`` is an iterable of per-image
        tensors, each processed separately; the results are stacked with
        ``torch.vstack``. For other versions, ``pixel_values`` is a batched
        tensor, and ``patch_attn_mask`` / ``tgt_sizes`` go to the tower and
        the resampler.
        """
        if version == (2, 0):
            res = []
            dtype = self.vpm.pos_embed.data.dtype
            for pixel_value in pixel_values:
                # V2.0 start
                H, W = pixel_value[0].shape[-2:]
                tgt_size = (math.ceil(H / self.vpm.patch_embed.patch_size[0]),
                            math.ceil(W / self.vpm.patch_embed.patch_size[0]))
                # V2.0 end
                vision_embedding = self.vpm.forward_features(
                    pixel_value.unsqueeze(0).type(dtype))
                if hasattr(self.vpm, 'num_prefix_tokens'
                           ) and self.vpm.num_prefix_tokens > 0:
                    vision_embedding = vision_embedding[:, self.vpm.
                                                        num_prefix_tokens:]
                res.append(self.resampler(vision_embedding, tgt_size))
            return torch.vstack(res)

        # Non-2.0 towers: take the dtype from the position-embedding weights,
        # as get_vision_hidden_states does (previously undefined here).
        dtype = self.vpm.embeddings.position_embedding.weight.dtype
        if version == (2, 5):
            vision_embedding = self.vpm(
                pixel_values.type(dtype),
                patch_attention_mask=patch_attn_mask).last_hidden_state
        else:
            vision_embedding = self.vpm(pixel_values.type(dtype),
                                        patch_attention_mask=patch_attn_mask,
                                        tgt_sizes=tgt_sizes).last_hidden_state
        return self.resampler(vision_embedding, tgt_sizes)

    def get_image_bounds(self, input_ids: torch.Tensor):
        """Locate image spans in ``input_ids``.

        Returns:
            An ``(N, 2)`` tensor of ``[start, end)`` index pairs. ``start``
            is the token after an image/slice start marker and ``end`` is
            the end marker. Returns ``[]`` if there are no complete pairs.
        """
        tokenizer = cached_get_tokenizer(self.config._name_or_path,
                                         trust_remote_code=True)
        start_cond = input_ids == tokenizer.im_start_id
        end_cond = input_ids == tokenizer.im_end_id
        if hasattr(tokenizer, "slice_start_id"):
            start_cond = start_cond | (input_ids == tokenizer.slice_start_id)
            end_cond = end_cond | (input_ids == tokenizer.slice_end_id)

        image_start_tokens = torch.where(start_cond)[0]
        image_start_tokens += 1
        image_end_tokens = torch.where(end_cond)[0]
        # Only as many spans as there are matching start/end markers can be
        # paired; using max here made hstack fail on unequal counts.
        valid_image_nums = min(len(image_start_tokens), len(image_end_tokens))

        if valid_image_nums == 0:
            return []
        image_bound = torch.hstack([
            image_start_tokens[:valid_image_nums].unsqueeze(-1),
            image_end_tokens[:valid_image_nums].unsqueeze(-1),
        ])

        return image_bound

    def get_vision_hidden_states(self, data: Dict[str,
                                                  Union[List[torch.Tensor],
                                                        torch.Tensor]]):
        """Return resampled visual features for the images in ``data``.

        If ``data["vision_hidden_states"]`` is present, it is returned as-is.
        Otherwise ``data["pixel_values"]`` (a flat list of image tensors) is
        encoded. For non-2.0 versions, ``data["tgt_sizes"]`` is also used.
        Returns an empty list/tensor when there are no images.
        """
        if "vision_hidden_states" in data:
            return data["vision_hidden_states"]

        pixel_values = data["pixel_values"]
        tgt_sizes = data["tgt_sizes"]

        if self.version == (2, 0):
            if pixel_values is not None and len(pixel_values) > 0:
                return self.get_vision_embedding(pixel_values)
            return torch.tensor([]).to(data["input_ids"].device)

        device = self.vpm.embeddings.position_embedding.weight.device
        dtype = self.vpm.embeddings.position_embedding.weight.dtype
        # Flatten the first two dims of each image and move the last dim
        # first, so images of different widths can be right-padded together.
        # (Assumes each entry is (C, H, W) - TODO confirm.)
        all_pixel_values = [
            i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
        ]
        if not all_pixel_values:
            return []  # no image

        tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)
        max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])
        all_pixel_values = torch.nn.utils.rnn.pad_sequence(all_pixel_values,
                                                           batch_first=True,
                                                           padding_value=0.0)
        B, L, _ = all_pixel_values.shape
        all_pixel_values = all_pixel_values.permute(0, 2,
                                                    1).reshape(B, 3, -1, L)

        patch_attn_mask = torch.zeros((B, 1, max_patches),
                                      dtype=torch.bool,
                                      device=device)
        if self.version == (2, 5):
            # NOTE(review): this indexes dim 1 (size 1), so every patch of
            # sample i is marked valid, padding included, unlike the branch
            # below. Left unchanged; presumably it mirrors the reference
            # implementation. Confirm before changing.
            for i in range(B):
                patch_attn_mask[i, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
            vision_embedding = self.vpm(
                all_pixel_values.type(dtype),
                patch_attention_mask=patch_attn_mask).last_hidden_state
        else:
            for i in range(B):
                patch_attn_mask[i, 0, :tgt_sizes[i][0] *
                                tgt_sizes[i][1]] = True
            vision_embedding = self.vpm(all_pixel_values.type(dtype),
                                        patch_attention_mask=patch_attn_mask,
                                        tgt_sizes=tgt_sizes).last_hidden_state

        return self.resampler(vision_embedding, tgt_sizes)

    def get_embedding(self, data: Dict[str, Union[List[torch.Tensor],
                                                  torch.Tensor]]):
        """Embed ``data["input_ids"]`` and scatter visual features into the
        image spans found by ``get_image_bounds``.

        Returns:
            ``(vlm_embedding, vision_hidden_states)``.
        """
        input_ids = data["input_ids"]

        vision_hidden_states = self.get_vision_hidden_states(data)
        if vision_hidden_states is not None and len(vision_hidden_states) > 0:
            image_bounds = self.get_image_bounds(input_ids)
        else:
            image_bounds = []

        if hasattr(self.config, 'scale_emb'):
            vlm_embedding = self.llm.embed_tokens(
                input_ids) * self.config.scale_emb
        else:
            vlm_embedding = self.llm.embed_tokens(input_ids)
        vision_hidden_states = [
            i.type(vlm_embedding.dtype) if isinstance(i, torch.Tensor) else i
            for i in vision_hidden_states
        ]

        if len(vision_hidden_states) > 0 and len(image_bounds) > 0:
            vision_hidden_states = torch.cat(vision_hidden_states, dim=0)
            image_indices = torch.stack([
                torch.arange(r[0], r[1], dtype=torch.long)
                for r in image_bounds
            ]).to(vlm_embedding.device)
            # Overwrite the rows at image_indices with the visual features.
            vlm_embedding.scatter_(
                0,
                image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
                vision_hidden_states.view(-1, vision_hidden_states.shape[-1]))
        return vlm_embedding, vision_hidden_states

    def process_multimodal_inputs(self, inputs: Dict[str,
                                                     Union[List[torch.Tensor],
                                                           torch.Tensor]]):
        """Flatten per-request ``pixel_values``/``tgt_sizes`` into flat lists.

        ``tgt_sizes`` may be ``None`` (``forward``'s default). In that case
        the flattened ``tgt_sizes`` list stays empty instead of raising.
        """
        batch_pixel_values = inputs["pixel_values"]
        batch_tgt_sizes = inputs["tgt_sizes"]
        pixel_values = []
        tgt_sizes = []
        for b in range(len(batch_pixel_values)):
            pixel_values += batch_pixel_values[b]
            if batch_tgt_sizes is not None:
                tgt_sizes += batch_tgt_sizes[b]
        return {
            "pixel_values": pixel_values,
            "input_ids": inputs["input_ids"],
            "tgt_sizes": tgt_sizes
        }

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ):
        """Build multimodal input embeddings and run the language model."""
        inputs = {
            "pixel_values": kwargs.pop("pixel_values", []),
            "input_ids": input_ids,
            "tgt_sizes": kwargs.pop("tgt_sizes", None),
        }

        inputs = self.process_multimodal_inputs(inputs)

        vlm_embeddings, _ = self.get_embedding(inputs)

        output = self.llm(input_ids=None,
                          positions=positions,
                          kv_caches=kv_caches,
                          attn_metadata=attn_metadata,
                          intermediate_tensors=intermediate_tensors,
                          inputs_embeds=vlm_embeddings)
        return output

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights, renaming keys via _KEYS_TO_MODIFY_MAPPING.

        Separate q/k/v and gate/up projections are loaded as shards of the
        fused LLM parameters. Vision-tower and resampler weights always use
        the parameter's default loader, with no sharding.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            use_default_weight_loading = False
            if "vpm" in name or 'resampler' in name:
                # We only do sharding for language model and
                # not vision model for now.
                use_default_weight_loading = True
            else:
                for (param_name, weight_name,
                     shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    param = params_dict[name.replace(weight_name, param_name)]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    use_default_weight_loading = True
            if use_default_weight_loading:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)