llama_pro.py 5.27 KB
Newer Older
chenych's avatar
chenych committed
1
# Copyright 2025 Tencent Inc. and the LlamaFactory team.
chenych's avatar
chenych committed
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
#
# This code is inspired by Tencent's LLaMA-Pro library.
# https://github.com/TencentARC/LLaMA-Pro/blob/main/scripts/block_expansion.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
17
18
19
20

import json
import os
from collections import OrderedDict
chenych's avatar
chenych committed
21
from typing import TYPE_CHECKING
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
22
23
24

import fire
import torch
chenych's avatar
chenych committed
25
from huggingface_hub import split_torch_state_dict_into_shards
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
26
27
from safetensors.torch import save_file
from tqdm import tqdm
luopl's avatar
luopl committed
28
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
chenych's avatar
chenych committed
29
from transformers.modeling_utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
30
31
32


if TYPE_CHECKING:
luopl's avatar
luopl committed
33
    from transformers import PretrainedConfig
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
34
35
36


def change_name(name: str, old_index: int, new_index: int) -> str:
    r"""Renumber a parameter key from layer ``old_index`` to layer ``new_index``.

    Only the first ``.{old_index}.`` component is rewritten. In keys such as
    ``model.layers.3.mlp.down_proj.weight`` that is the decoder-layer index, so any later numeric
    component that happens to equal ``old_index`` (e.g. a nested module index) is left intact.

    Args:
        name: Parameter key to rename.
        old_index: Layer index currently present in ``name``.
        new_index: Layer index to write in its place.

    Returns:
        The renamed key, or ``name`` unchanged if it contains no ``.{old_index}.`` component.
    """
    return name.replace(f".{old_index:d}.", f".{new_index:d}.", 1)
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
38
39
40
41
42
43


def _expand_state_dict(
    state_dict: dict[str, torch.Tensor], num_layers: int, num_expand: int
) -> dict[str, torch.Tensor]:
    r"""Return a new state dict with ``num_expand`` extra layers interleaved among ``num_layers`` originals.

    The original layers are split into ``num_expand`` equal groups of ``num_layers // num_expand`` layers.
    Every original layer is renumbered to its position in the expanded stack. After the last layer of each
    group, a new layer is inserted that clones it, except that tensors whose key contains ``down_proj`` or
    ``o_proj`` are zero-filled. Finally, every key of ``state_dict`` not already present in the output is
    appended unchanged.

    The caller must ensure ``num_expand`` is positive and divides ``num_layers``.

    Note: a key is treated as belonging to layer ``i`` when it contains the substring ``.{i}.``.
    Architectures whose parameter names carry additional numeric components (e.g. per-expert indices)
    are therefore not handled correctly.
    """
    split = num_layers // num_expand
    layer_cnt = 0
    output_state_dict: dict[str, torch.Tensor] = OrderedDict()
    for i in range(num_layers):
        # Collect this layer's entries once; they are reused for the expanded copy below.
        layer_items = [(key, value) for key, value in state_dict.items() if f".{i:d}." in key]
        for key, value in layer_items:
            output_state_dict[change_name(key, i, layer_cnt)] = value

        print(f"Add layer {layer_cnt} copied from layer {i}.")
        layer_cnt += 1
        if (i + 1) % split == 0:
            for key, value in layer_items:
                # LLaMA-Pro identity initialization: zeroing the output projections means that, through the
                # residual connections, the new block initially passes its input through unchanged
                # (assumes o_proj / down_proj are the last linear layers of their sub-blocks).
                if "down_proj" in key or "o_proj" in key:
                    output_state_dict[change_name(key, i, layer_cnt)] = torch.zeros_like(value)
                else:
                    output_state_dict[change_name(key, i, layer_cnt)] = torch.clone(value)

            print(f"Add layer {layer_cnt} expanded from layer {i}.")
            layer_cnt += 1

    # Non-layer tensors are carried over as-is.
    for key, value in state_dict.items():
        if key not in output_state_dict:
            output_state_dict[key] = value

    return output_state_dict


def _save_state_dict(
    state_dict: dict[str, torch.Tensor], output_dir: str, shard_size: str, save_safetensors: bool
) -> None:
    r"""Write ``state_dict`` into ``output_dir`` as weight shards, plus an index JSON when more than one shard."""
    weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
    # Insert a `{suffix}` placeholder before the extension so the sharding helper can number the shards.
    filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
    state_dict_split = split_torch_state_dict_into_shards(
        state_dict, filename_pattern=filename_pattern, max_shard_size=shard_size
    )
    for shard_file, tensors in tqdm(state_dict_split.filename_to_tensors.items(), desc="Save weights"):
        # safetensors refuses non-contiguous tensors, so every tensor is made contiguous before saving.
        shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
        if save_safetensors:
            save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
        else:
            torch.save(shard, os.path.join(output_dir, shard_file))

    if not state_dict_split.is_sharded:
        print(f"Model weights saved in {os.path.join(output_dir, weights_name)}.")
    else:
        index = {
            "metadata": state_dict_split.metadata,
            "weight_map": state_dict_split.tensor_to_filename,
        }
        index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
        with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
            json.dump(index, f, indent=2, sort_keys=True)

        print(f"Model weights saved in {output_dir}.")


def block_expansion(
    model_name_or_path: str,
    output_dir: str,
    num_expand: int,
    shard_size: str = "5GB",
    save_safetensors: bool = True,
):
    r"""Perform block expansion for LLaMA, Mistral, Qwen2 or Yi models.

    Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8

    Args:
        model_name_or_path: Model id or local path passed to the ``transformers`` ``from_pretrained`` loaders.
        output_dir: Directory that receives the updated config, the tokenizer and the expanded weights.
        num_expand: Number of layers to add; must be positive and divide the model's ``num_hidden_layers``.
        shard_size: Maximum shard size, forwarded to ``split_torch_state_dict_into_shards``.
        save_safetensors: Save weights as safetensors when ``True``, otherwise with ``torch.save``.

    Raises:
        ValueError: If ``num_expand`` is not positive or does not divide ``num_hidden_layers``.
    """
    # Validate before any loading. num_expand == 0 would otherwise fail with ZeroDivisionError, and a negative
    # value can pass the divisibility check and write a config whose layer count disagrees with the weights.
    if num_expand <= 0:
        raise ValueError(f"`num_expand` should be a positive integer, got {num_expand}.")

    config: PretrainedConfig = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
    num_layers = getattr(config, "num_hidden_layers")
    if num_layers % num_expand != 0:
        raise ValueError(f"`num_layers` {num_layers} should be divisible by `num_expand` {num_expand}.")

    # Only `num_hidden_layers` is updated in the saved config.
    # NOTE(review): any per-layer config fields sized by the original layer count are saved unchanged —
    # confirm this is safe for architectures beyond the ones listed above.
    setattr(config, "num_hidden_layers", num_layers + num_expand)
    config.save_pretrained(output_dir)

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
    tokenizer.save_pretrained(output_dir)

    print(f"Expanding model of {num_layers} layers to {num_layers + num_expand} layers.")
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path, torch_dtype="auto", device_map="cpu", trust_remote_code=True, low_cpu_mem_usage=True
    )
    assert isinstance(model, PreTrainedModel)  # type hint
    if save_safetensors and getattr(model.config, "tie_word_embeddings", False):
        del model.lm_head  # safetensors does not allow shared weights

    output_state_dict = _expand_state_dict(model.state_dict(), num_layers, num_expand)
    _save_state_dict(output_state_dict, output_dir, shard_size, save_safetensors)

    print("- Fine-tune this model with:")
    print(f"model_name_or_path: {output_dir}")
    print("finetuning_type: freeze")
    print(f"freeze_trainable_layers: {num_expand}")
    print("use_llama_pro: true")
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
126
127
128
129


if __name__ == "__main__":
    # Expose block_expansion's parameters as command-line flags (see its docstring for usage).
    fire.Fire(component=block_expansion)