# Copyright 2024 Tencent Inc. and the LlamaFactory team.
#
# This code is inspired by the Tencent's LLaMA-Pro library.
# https://github.com/TencentARC/LLaMA-Pro/blob/main/scripts/block_expansion.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from collections import OrderedDict
from typing import TYPE_CHECKING

import fire
import torch
from safetensors.torch import save_file
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.modeling_utils import (
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    shard_checkpoint,
)


if TYPE_CHECKING:
    from transformers import PretrainedConfig, PreTrainedModel


def change_name(name: str, old_index: int, new_index: int) -> str:
    r"""
    Renames a parameter key by swapping its dot-delimited layer index.

    Every occurrence of ``.{old_index}.`` in ``name`` is replaced with ``.{new_index}.``. Because the
    index is matched together with its surrounding dots, ``.3.`` does not match inside ``.13.``.
    Names that do not contain ``.{old_index}.`` are returned unchanged.

    Args:
        name: The parameter key, e.g. ``model.layers.3.mlp.down_proj.weight``.
        old_index: The index currently embedded in ``name``.
        new_index: The index to write in its place.

    Returns:
        The renamed parameter key.
    """
    return name.replace(f".{old_index:d}.", f".{new_index:d}.")


def block_expansion(
    model_name_or_path: str,
    output_dir: str,
    num_expand: int,
    shard_size: str = "2GB",
    save_safetensors: bool = True,
):
    r"""
    Performs block expansion for LLaMA, Mistral, Qwen1.5 or Yi models.

    The original layers are split into ``num_expand`` equally sized groups. After the last layer of each
    group a new layer is inserted: it is a clone of that layer, except that tensors whose key contains
    ``down_proj`` or ``o_proj`` are zero-filled. The updated config, the tokenizer and the expanded
    weights are written to ``output_dir``.

    Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8

    Args:
        model_name_or_path: Model identifier or local path passed to the ``from_pretrained`` loaders.
        output_dir: Directory that receives the config, tokenizer and weight files.
        num_expand: Number of layers to insert; must be positive and divide ``num_hidden_layers``.
        shard_size: Maximum size of each weight shard, forwarded to ``shard_checkpoint``.
        save_safetensors: Save weights with safetensors if True, otherwise with ``torch.save``.

    Raises:
        ValueError: If ``num_expand`` is not positive, or does not divide the number of layers.
    """
    # Reject invalid values up front: zero would divide by zero below, and a negative value would
    # pass the divisibility check while producing an inconsistent layer layout.
    if num_expand <= 0:
        raise ValueError(f"`num_expand` should be a positive integer, got {num_expand}.")

    config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path)
    num_layers = getattr(config, "num_hidden_layers")
    # Validate before anything is written to `output_dir` or the (potentially large) model is loaded.
    if num_layers % num_expand != 0:
        raise ValueError(f"`num_layers` {num_layers} should be divisible by `num_expand` {num_expand}.")

    setattr(config, "num_hidden_layers", num_layers + num_expand)
    config.save_pretrained(output_dir)

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    tokenizer.save_pretrained(output_dir)

    # Reload the original (unexpanded) config so the model is built with its real layer count.
    load_config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path)
    if save_safetensors:
        setattr(load_config, "tie_word_embeddings", False)  # safetensors does not allow shared weights

    model: "PreTrainedModel" = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        config=load_config,
        torch_dtype="auto",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    state_dict = model.state_dict()

    split = num_layers // num_expand  # number of original layers between two inserted layers
    layer_cnt = 0  # index of the next layer in the expanded model
    output_state_dict = OrderedDict()
    for i in range(num_layers):
        # Collect this layer's tensors once; they are reused if a new layer is inserted after it.
        marker = f".{i:d}."
        layer_items = [(key, value) for key, value in state_dict.items() if marker in key]

        for key, value in layer_items:
            output_state_dict[change_name(key, i, layer_cnt)] = value

        print(f"Add layer {layer_cnt} copied from layer {i}")
        layer_cnt += 1
        if (i + 1) % split == 0:
            for key, value in layer_items:
                if "down_proj" in key or "o_proj" in key:
                    output_state_dict[change_name(key, i, layer_cnt)] = torch.zeros_like(value)
                else:
                    output_state_dict[change_name(key, i, layer_cnt)] = torch.clone(value)

            print(f"Add layer {layer_cnt} expanded from layer {i}")
            layer_cnt += 1

    # Carry over every tensor whose key was not produced above (keys without a layer index).
    for key, value in state_dict.items():
        if key not in output_state_dict:
            output_state_dict[key] = value

    weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
    shards, index = shard_checkpoint(output_state_dict, max_shard_size=shard_size, weights_name=weights_name)

    for shard_file, shard in tqdm(shards.items(), desc="Save weights"):
        if save_safetensors:
            save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
        else:
            torch.save(shard, os.path.join(output_dir, shard_file))

    if index is None:
        print(f"Model weights saved in {os.path.join(output_dir, weights_name)}")
    else:
        # An index was returned alongside the shards, so write it out as the weights index file.
        index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
        with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
            json.dump(index, f, indent=2, sort_keys=True)
        print(f"Model weights saved in {output_dir}")

    print("- Fine-tune this model with:")
    print(f"model_name_or_path: {output_dir}")
    print("finetuning_type: freeze")
    print(f"freeze_trainable_layers: {num_expand}")
    print("use_llama_pro: true")


if __name__ == "__main__":
    # Hand `block_expansion` to python-fire as the command-line entry point.
    fire.Fire(component=block_expansion)