# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from collections import OrderedDict
from typing import Any

import fire
import torch
from huggingface_hub import split_torch_state_dict_into_shards
from safetensors import safe_open
from safetensors.torch import save_file
from tqdm import tqdm
from transformers.modeling_utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
from transformers.utils import check_min_version


try:
    check_min_version("4.34.0")
except Exception:
    raise ValueError("Please upgrade `transformers` to 4.34.0 or above.")


CONFIG_NAME = "config.json"


def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool) -> str:
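    """Load the Qwen safetensors shards from input_dir, remap the tensors to LLaMA
    parameter names, and save the converted weights to output_dir.

    Returns the tensor dtype of the source weights as a string.
    """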
    qwen_state_dict: dict[str, torch.Tensor] = OrderedDict()
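    # Gather every tensor from the *.safetensors files in input_dir into a single dict.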
    for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
        if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".safetensors"):
            with safe_open(os.path.join(input_dir, filepath), framework="pt", device="cpu") as f:
                for key in f.keys():
                    qwen_state_dict[key] = f.get_tensor(key)

    llama_state_dict: dict[str, torch.Tensor] = OrderedDict()
    torch_dtype = None
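    # Remap each Qwen parameter name onto the LLaMA module layout, remembering the
    # dtype of the source tensors for the converted config.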
    for key, value in tqdm(qwen_state_dict.items(), desc="Convert format"):
        if torch_dtype is None:
            torch_dtype = value.dtype
        if "wte" in key:
            llama_state_dict["model.embed_tokens.weight"] = value
        elif "ln_f" in key:
            llama_state_dict["model.norm.weight"] = value
        else:
            key = key.replace("transformer.h", "model.layers")
            if "attn.c_attn" in key:
                proj_size = value.size(0) // 3
                llama_state_dict[key.replace("attn.c_attn", "self_attn.q_proj")] = value[:proj_size, ...]
                llama_state_dict[key.replace("attn.c_attn", "self_attn.k_proj")] = value[
                    proj_size : 2 * proj_size, ...
                ]
                llama_state_dict[key.replace("attn.c_attn", "self_attn.v_proj")] = value[2 * proj_size :, ...]
            elif "attn.c_proj" in key:
                llama_state_dict[key.replace("attn.c_proj", "self_attn.o_proj")] = value
                llama_state_dict[key.replace("attn.c_proj.weight", "self_attn.o_proj.bias")] = torch.zeros_like(
                    value[:, 0]
                ).squeeze()
            elif "ln_1" in key:
                llama_state_dict[key.replace("ln_1", "input_layernorm")] = value
            elif "ln_2" in key:
                llama_state_dict[key.replace("ln_2", "post_attention_layernorm")] = value
            elif "mlp.w1" in key:
                llama_state_dict[key.replace("mlp.w1", "mlp.up_proj")] = value
            elif "mlp.w2" in key:
                llama_state_dict[key.replace("mlp.w2", "mlp.gate_proj")] = value
            elif "mlp.c_proj" in key:
                llama_state_dict[key.replace("mlp.c_proj", "mlp.down_proj")] = value
            elif "lm_head" in key:
                llama_state_dict[key] = value
            else:
                raise KeyError(f"Unable to process key {key}")

    weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
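    # Split the converted state dict into shards of at most `shard_size`, using the
    # same filename pattern that transformers produces for sharded checkpoints.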
    filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
    state_dict_split = split_torch_state_dict_into_shards(
        llama_state_dict, filename_pattern=filename_pattern, max_shard_size=shard_size
    )
    for shard_file, tensors in tqdm(state_dict_split.filename_to_tensors.items(), desc="Save weights"):
        shard = {tensor: llama_state_dict[tensor].contiguous() for tensor in tensors}
        if save_safetensors:
            save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
        else:
            torch.save(shard, os.path.join(output_dir, shard_file))

    if not state_dict_split.is_sharded:
        print(f"Model weights saved in {os.path.join(output_dir, weights_name)}.")
    else:
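        # For sharded checkpoints, also write the index file mapping each tensor
        # name to the shard that contains it.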
        index = {
            "metadata": state_dict_split.metadata,
            "weight_map": state_dict_split.tensor_to_filename,
        }
        index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
        with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
            json.dump(index, f, indent=2, sort_keys=True)

        print(f"Model weights saved in {output_dir}.")

    return str(torch_dtype).replace("torch.", "")


def save_config(input_dir: str, output_dir: str, torch_dtype: str):
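    """Translate the Qwen config.json into an equivalent LLaMA2 config.json."""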
    with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
        qwen_config_dict: dict[str, Any] = json.load(f)

    llama2_config_dict: dict[str, Any] = OrderedDict()
    llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
    llama2_config_dict["hidden_act"] = "silu"
    llama2_config_dict["hidden_size"] = qwen_config_dict["hidden_size"]
    llama2_config_dict["initializer_range"] = qwen_config_dict["initializer_range"]
    llama2_config_dict["intermediate_size"] = qwen_config_dict["intermediate_size"] // 2
    llama2_config_dict["max_position_embeddings"] = qwen_config_dict["max_position_embeddings"]
    llama2_config_dict["model_type"] = "llama"
    llama2_config_dict["num_attention_heads"] = qwen_config_dict["num_attention_heads"]
    llama2_config_dict["num_hidden_layers"] = qwen_config_dict["num_hidden_layers"]
    llama2_config_dict["num_key_value_heads"] = qwen_config_dict["hidden_size"] // qwen_config_dict["kv_channels"]
    llama2_config_dict["pretraining_tp"] = 1
    llama2_config_dict["rms_norm_eps"] = qwen_config_dict["layer_norm_epsilon"]
    llama2_config_dict["rope_scaling"] = None
    llama2_config_dict["tie_word_embeddings"] = qwen_config_dict["tie_word_embeddings"]
    llama2_config_dict["torch_dtype"] = torch_dtype
    llama2_config_dict["transformers_version"] = "4.34.0"
    llama2_config_dict["use_cache"] = True
    llama2_config_dict["vocab_size"] = qwen_config_dict["vocab_size"]
    llama2_config_dict["attention_bias"] = True

    with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
        json.dump(llama2_config_dict, f, indent=2)

    print(f"Model config saved in {os.path.join(output_dir, CONFIG_NAME)}")


def llamafy_qwen(
    input_dir: str,
    output_dir: str,
    shard_size: str = "2GB",
    save_safetensors: bool = False,
):
    r"""Convert the Qwen models in the same format as LLaMA2.

    Usage: python llamafy_qwen.py --input_dir input --output_dir output
    Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied
    """
    try:
        os.makedirs(output_dir, exist_ok=False)
    except Exception as e:
        raise RuntimeError(f"Output dir {output_dir} already exists.") from e

    torch_dtype = save_weight(input_dir, output_dir, shard_size, save_safetensors)
    save_config(input_dir, output_dir, torch_dtype)


if __name__ == "__main__":
    fire.Fire(llamafy_qwen)