# Copyright (c) Alibaba, Inc. and its affiliates.
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Optional, Union

import torch
from torch import nn

from swift.utils.logger import get_logger
from .module_mapping import MODEL_KEYS_MAPPING, ModelKeys
from .utils import SwiftAdapter, SwiftConfig, SwiftOutput

logger = get_logger()


@dataclass
class LLaMAProConfig(SwiftConfig):
    """
    The configuration class for the LLaMAPro module.

    See https://arxiv.org/abs/2401.02415

    Args:
        model_type (`str`): LLaMAPro only supports part of the LLM models, because some
            internal attributes of the decoder blocks need to be modified manually.
        num_new_blocks (`int`): The number of new blocks to insert.
        num_groups (`int`): The number of groups the new blocks are split into. Defaults to
            `num_new_blocks`, which means one new block is inserted after every
            `num_hidden_layers // num_new_blocks` original layers.
        model_key_mapping (`Union[dict, ModelKeys]`): The module key mapping to use when
            `model_type` is not listed in `MODEL_KEYS_MAPPING`.
    """

    model_type: str = field(
        default=None, metadata={
            'choices': list(MODEL_KEYS_MAPPING.keys()),
        })

    num_new_blocks: int = None

    num_groups: Optional[int] = None

    model_key_mapping: Optional[Union[dict, ModelKeys]] = None

    def __post_init__(self):
        from .mapping import SwiftTuners
        self.swift_type = SwiftTuners.LLAMAPRO


class LLaMAPro(SwiftAdapter):

    @staticmethod
    def prepare_model(model: nn.Module, config: LLaMAProConfig, adapter_name: str) -> SwiftOutput:
        """Prepare a model with `LLaMAProConfig`"""
        num_hidden_layers = None
        if hasattr(model.config, 'num_hidden_layers'):
            num_hidden_layers = model.config.num_hidden_layers
        elif hasattr(model.config, 'num_layers'):
            num_hidden_layers = model.config.num_layers

        assert num_hidden_layers is not None, 'Cannot find the number of layers in the model config'
        assert num_hidden_layers % config.num_new_blocks == 0, \
            f'The number of model layers {num_hidden_layers} should be divisible by {config.num_new_blocks}'
        if config.num_groups is None:
            config.num_groups = config.num_new_blocks

        num_stride = num_hidden_layers // config.num_groups

        # We only support decoder-only models for now.
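        # Interleaving sketch (illustrative numbers, not defaults): with num_hidden_layers=32 and
        # num_new_blocks=num_groups=8, num_stride is 4, so a deep copy of every 4th block is
        # appended right after it. The expanded list then has 40 blocks, and the copies sit at
        # indices 4, 9, 14, ..., 39; those indices are collected in `new_module_idx` and used by
        # the callbacks below to single out the newly added blocks.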
        module_list = LLaMAPro._find_module_list(config, model)
        new_module_list = nn.ModuleList()
        new_module_idx = []
        for idx, module in enumerate(module_list):
            new_module_list.append(module)
            if (idx + 1) % num_stride == 0:
                new_module = deepcopy(module)
                new_module_list.append(new_module)
                new_module_idx.append(idx + 1 + len(new_module_idx))

        LLaMAPro._update_module_weight(config, new_module_list, new_module_idx)
        LLaMAPro._update_module_attr(config, new_module_list)
        model.config.num_hidden_layers = len(new_module_list)
        LLaMAPro._set_module_list(config, model, new_module_list)

        def state_dict_callback(state_dict, adapter_name):
            # Keep only the parameters that belong to the newly inserted blocks.
            model_key_mapping = LLaMAPro._get_model_key_mapping(config.model_type, config)
            new_module_list = [model_key_mapping.module_list + f'.{i}' for i in new_module_idx]
            return {
                key: value
                for key, value in state_dict.items() if any([m_part in key for m_part in new_module_list])
            }

        def mark_trainable_callback(model):
            # Mark the parameters of the newly inserted blocks as trainable.
            model_key_mapping = LLaMAPro._get_model_key_mapping(config.model_type, config)
            new_module_list = [model_key_mapping.module_list + f'.{i}' for i in new_module_idx]
            for name, parameter in model.named_parameters():
                parameter: nn.Parameter
                if any([m_part in name for m_part in new_module_list]):
                    parameter.requires_grad = True

        return SwiftOutput(config, state_dict_callback, mark_trainable_callback)

    @staticmethod
    def _get_model_key_mapping(model_type, config) -> ModelKeys:
        if model_type in MODEL_KEYS_MAPPING.keys():
            model_key_mapping = MODEL_KEYS_MAPPING[model_type]
        else:
            model_key_mapping = config.model_key_mapping

        if model_key_mapping is None:
            raise ValueError(f'{model_type} is not defined in MODEL_KEYS_MAPPING, '
                             f'please consider passing the information through `config.model_key_mapping`')

        if isinstance(model_key_mapping, dict):
            model_key_mapping: ModelKeys = ModelKeys(**model_key_mapping)

        assert model_key_mapping.o_proj is not None and model_key_mapping.down_proj is not None, \
            'LLaMAPro only supports models with o_proj and down_proj components.'
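        # For a LLaMA-style model these keys are dotted module paths, typically something like
        # module_list='model.layers', o_proj='model.layers.{}.self_attn.o_proj' and
        # down_proj='model.layers.{}.mlp.down_proj' (the exact values come from
        # MODEL_KEYS_MAPPING); the '{}' placeholder stands for the layer index.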
        return model_key_mapping

    @staticmethod
    def _update_module_attr(config: LLaMAProConfig, module_list):
        model_type = config.model_type
        model_key_mapping = LLaMAPro._get_model_key_mapping(model_type, config)
        attention = model_key_mapping.attention
        attention = attention.split('{}.')[1]
        # Re-number the per-layer index stored on the attention modules so that it matches the
        # expanded layer list.
        if model_type in ('llama', 'mistral', 'qwen2', 'yi', 'gemma', 'deepseek', 'openbuddy', 'xverse', 'orion',
                          'bluelm', 'ziya', 'skywork'):
            for idx, module in enumerate(module_list):
                getattr(module, attention).layer_idx = idx
        elif model_type in ('chatglm', ):
            for idx, module in enumerate(module_list):
                getattr(module, attention).layer_number = idx
        elif model_type in ('phi2', ):
            for idx, module in enumerate(module_list):
                getattr(module, attention).block_idx = idx

    @staticmethod
    def _update_module_weight(config: LLaMAProConfig, module_list, new_module_idx):
        model_key_mapping = LLaMAPro._get_model_key_mapping(config.model_type, config)
        o_proj = model_key_mapping.o_proj.split('{}.')[1]
        down_proj = model_key_mapping.down_proj.split('{}.')[1]

        for idx, module in enumerate(module_list):
            if idx not in new_module_idx:
                continue
            _o_proj: nn.Linear = module.get_submodule(o_proj)
            _down_proj: nn.Linear = module.get_submodule(down_proj)
            # Zero-initialize the output projections of the copied blocks so that each new block
            # initially acts as an identity mapping through its residual connections.
            _o_proj.weight.data = torch.zeros_like(_o_proj.weight.data)
            _down_proj.weight.data = torch.zeros_like(_down_proj.weight.data)
            if hasattr(_o_proj, 'bias') and _o_proj.bias is not None:
                _o_proj.bias.data = torch.zeros_like(_o_proj.bias.data)
            if hasattr(_down_proj, 'bias') and _down_proj.bias is not None:
                _down_proj.bias.data = torch.zeros_like(_down_proj.bias.data)

    @staticmethod
    def _set_module_list(config, module: nn.Module, module_list: nn.ModuleList):
        model_key_mapping = LLaMAPro._get_model_key_mapping(config.model_type, config)
        idx = model_key_mapping.module_list.rfind('.')
        parent = module.get_submodule(model_key_mapping.module_list[:idx])
        setattr(parent, model_key_mapping.module_list[idx + 1:], module_list)

    @staticmethod
    def _find_module_list(config, module: nn.Module) -> nn.ModuleList:
        model_key_mapping = LLaMAPro._get_model_key_mapping(config.model_type, config)
        return module.get_submodule(model_key_mapping.module_list)

    @staticmethod
    def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
        for sub_module in module.modules():
            if isinstance(sub_module, torch.nn.Embedding):
                sub_module.nef_activated = activate

    @staticmethod
    def has_additional_modules():
        return True
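

# Usage sketch (illustrative, assuming the package-level `Swift.prepare_model` entry point
# dispatches to `LLaMAPro.prepare_model` for this config type):
#
#   from swift import Swift, LLaMAProConfig
#
#   config = LLaMAProConfig(model_type='llama', num_new_blocks=8)
#   model = Swift.prepare_model(model, config)
#
# Afterwards the model holds `num_hidden_layers + num_new_blocks` decoder blocks; only the
# newly inserted, zero-initialized blocks are saved by `state_dict_callback` and marked
# trainable by `mark_trainable_callback`.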