# Copyright (c) 2023, Tri Dao.

import json
import math
import os
import re
from collections import OrderedDict
from pathlib import Path
from typing import Union

import torch
import torch.nn.functional as F
from transformers import GPT2Config, LlamaConfig


def remap_state_dict_meta_llama(state_dict, config):
    def key_mapping_layers(key):
        return f"transformer.{key}" if not key.startswith("output.") else key

    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
    # Word embedding
    def key_mapping_emb(key):
        return re.sub(
            r"^transformer.tok_embeddings.", "transformer.embeddings.word_embeddings.", key
        )

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = (
        math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
    )
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    if getattr(config, "tie_word_embeddings"):
        state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
    else:
        output_embeddings = state_dict.pop("output.weight")
        # Need to recompute vocab_size since LLaMA shards the word embeddings and output embeddings
        # differently.
        vocab_size = (
            math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
            * pad_vocab_size_multiple
        )
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict["lm_head.weight"] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.norm.", r"transformer.ln_f.", key)
        key = re.sub(
            r"^transformer.layers.(\d+).attention_norm.", r"transformer.layers.\1.norm1.", key
        )
        key = re.sub(r"^transformer.layers.(\d+).ffn_norm.", r"transformer.layers.\1.norm2.", key)
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    for l in range(config.n_layer):
        w1 = state_dict.pop(f"transformer.layers.{l}.feed_forward.w1.weight")
        w3 = state_dict.pop(f"transformer.layers.{l}.feed_forward.w3.weight")
        # Our fused fc1 stores w3 (the up projection) first and w1 (the gate) second,
        # which differs from Meta's ordering.
        state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0)

    def key_mapping_mlp(key):
        return re.sub(
            r"^transformer.layers.(\d+).feed_forward.w2.", r"transformer.layers.\1.mlp.fc2.", key
        )

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for l in range(config.n_layer):
        Wq = state_dict.pop(f"transformer.layers.{l}.attention.wq.weight")
        Wk = state_dict.pop(f"transformer.layers.{l}.attention.wk.weight")
        Wv = state_dict.pop(f"transformer.layers.{l}.attention.wv.weight")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
        # We don't store these
        state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None)

    def key_mapping_attn(key):
        return re.sub(
            r"^transformer.layers.(\d+).attention.wo.",
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    state_dict.pop("transformer.rope.freqs", None)

    return state_dict
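

# A minimal sanity-check sketch, not part of the original module: it builds a tiny, made-up
# Meta-format state dict and runs remap_state_dict_meta_llama on it to show the resulting key
# layout (fused Wqkv, fused fc1, norm1/norm2, ln_f, lm_head). The helper name
# _example_remap_meta_llama and every dimension below are hypothetical, for illustration only.
def _example_remap_meta_llama():
    n_embd, n_head, n_layer, vocab_size, ffn_dim = 8, 2, 1, 16, 32
    config = GPT2Config(
        n_embd=n_embd,
        n_head=n_head,
        n_layer=n_layer,
        vocab_size=vocab_size,
        tie_word_embeddings=False,
    )
    state_dict = {
        "tok_embeddings.weight": torch.randn(vocab_size, n_embd),
        "norm.weight": torch.ones(n_embd),
        "output.weight": torch.randn(vocab_size, n_embd),
    }
    for l in range(n_layer):
        state_dict[f"layers.{l}.attention_norm.weight"] = torch.ones(n_embd)
        state_dict[f"layers.{l}.ffn_norm.weight"] = torch.ones(n_embd)
        for proj in ("wq", "wk", "wv", "wo"):
            state_dict[f"layers.{l}.attention.{proj}.weight"] = torch.randn(n_embd, n_embd)
        state_dict[f"layers.{l}.feed_forward.w1.weight"] = torch.randn(ffn_dim, n_embd)
        state_dict[f"layers.{l}.feed_forward.w2.weight"] = torch.randn(n_embd, ffn_dim)
        state_dict[f"layers.{l}.feed_forward.w3.weight"] = torch.randn(ffn_dim, n_embd)
    remapped = remap_state_dict_meta_llama(state_dict, config)
    # Wq / Wk / Wv are stacked into one (3 * n_embd, n_embd) projection, w3 / w1 into one fc1.
    assert remapped["transformer.layers.0.mixer.Wqkv.weight"].shape == (3 * n_embd, n_embd)
    assert remapped["transformer.layers.0.mlp.fc1.weight"].shape == (2 * ffn_dim, n_embd)
    assert "lm_head.weight" in remapped and "transformer.ln_f.weight" in remapped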


def remap_state_dict_hf_llama(state_dict, config):
    # Embedding
    def key_mapping_emb(key):
        return re.sub(r"^model.embed_tokens.", "transformer.embeddings.word_embeddings.", key)

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = (
        math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
    )
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )

    # LM head
    if getattr(config, "tie_word_embeddings"):
        state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
    else:
        output_embeddings = state_dict.pop("lm_head.weight")
        # Need to recompute vocab_size since LLaMA shards the word embeddings and output embeddings
        # differently.
        vocab_size = (
            math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
            * pad_vocab_size_multiple
        )
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict["lm_head.weight"] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )

    # MLP
    for l in range(config.n_layer):
        # Fusing weights this way based on difference in the following:
        # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/modeling_llama.py#L220
        # https://github.com/Dao-AILab/flash-attention/blob/c60851a8253257eb970e06a022c82517a8033e8c/flash_attn/modules/mlp.py#L115
        w1 = state_dict.pop(f"model.layers.{l}.mlp.gate_proj.weight")
        w3 = state_dict.pop(f"model.layers.{l}.mlp.up_proj.weight")
        state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0)

    def key_mapping_mlp(key):
        return re.sub(r"^model.layers.(\d+).mlp.down_proj.", r"transformer.layers.\1.mlp.fc2.", key)

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^model.norm.", r"transformer.ln_f.", key)
        key = re.sub(r"^model.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key)
        key = re.sub(
            r"^model.layers.(\d+).post_attention_layernorm.", r"transformer.layers.\1.norm2.", key
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    def inv_permute(w):
        # Inverse of permute implemented in:
        # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/convert_llama_weights_to_hf.py#L114
        return (
            w.reshape(config.n_head, 2, config.n_embd // config.n_head // 2, config.n_embd)
            .transpose(1, 2)
            .reshape(config.n_embd, config.n_embd)
        )

    # Attention
    for l in range(config.n_layer):
        Wq = state_dict.pop(f"model.layers.{l}.self_attn.q_proj.weight")
        Wk = state_dict.pop(f"model.layers.{l}.self_attn.k_proj.weight")
        Wv = state_dict.pop(f"model.layers.{l}.self_attn.v_proj.weight")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat(
            [inv_permute(Wq), inv_permute(Wk), Wv], dim=0
        )
        # We don't store these
        state_dict.pop(f"model.layers.{l}.self_attn.rotary_emb.inv_freq", None)

    def key_mapping_attn(key):
        return re.sub(
            r"^model.layers.(\d+).self_attn.o_proj.", r"transformer.layers.\1.mixer.out_proj.", key
        )

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
    return state_dict
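

# A small, self-contained sketch, not part of the original file, of why inv_permute above undoes
# the `permute` applied by Hugging Face's convert_llama_weights_to_hf.py (that script reorders
# q/k rows for HF's rotary layout; this codebase keeps Meta's interleaved layout, see
# rotary_emb_interleaved=True below). Helper names and dimensions here are hypothetical.
def _example_inv_permute_roundtrip(n_head=4, n_embd=16):
    def hf_permute(w):
        # Same reshape/transpose as the HF conversion script.
        return (
            w.view(n_head, n_embd // n_head // 2, 2, n_embd)
            .transpose(1, 2)
            .reshape(n_embd, n_embd)
        )

    def inv_permute(w):
        # Mirrors the nested inv_permute defined inside remap_state_dict_hf_llama.
        return (
            w.reshape(n_head, 2, n_embd // n_head // 2, n_embd)
            .transpose(1, 2)
            .reshape(n_embd, n_embd)
        )

    w = torch.randn(n_embd, n_embd)
    # Applying the HF permutation and then its inverse recovers the original weight exactly.
    assert torch.equal(inv_permute(hf_permute(w)), w)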


def config_from_meta_checkpoint(
    checkpoint_path: Union[str, os.PathLike], model_name: str
) -> LlamaConfig:
    """Load a LlamaConfig from a checkpoint path."""
    with open(Path(checkpoint_path) / model_name / "params.json") as f:
        params = json.load(f)
    config = LlamaConfig(
        hidden_size=params["dim"],
        intermediate_size=None,
        num_attention_heads=params["n_heads"],
        num_hidden_layers=params["n_layers"],
        rms_norm_eps=params["norm_eps"],
    )
    return config


def config_from_hf_checkpoint(
    checkpoint_path: Union[str, os.PathLike], model_name: str
) -> LlamaConfig:
    return LlamaConfig.from_pretrained(Path(checkpoint_path) / f"{model_name}-hf" / "config.json")


def config_from_checkpoint(
    checkpoint_path: Union[str, os.PathLike], model_name: str, checkpoint_format="meta"
) -> LlamaConfig:
    if checkpoint_format == "meta":
        return config_from_meta_checkpoint(checkpoint_path, model_name)
    else:
        return config_from_hf_checkpoint(checkpoint_path, model_name)


def state_dicts_from_checkpoint(
    checkpoint_path: Union[str, os.PathLike], model_name: str
) -> list[dict]:
    # Need to sort, otherwise the shards load in arbitrary glob order and the weights are wrong
    return [
        torch.load(path, map_location="cpu")
        for path in sorted((Path(checkpoint_path) / model_name).glob("consolidated.*.pth"))
    ]


def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config:
    return GPT2Config(
        vocab_size=llama_config.vocab_size,
        n_positions=0,  # No absolute position embedding
        n_embd=llama_config.hidden_size,
        n_layer=llama_config.num_hidden_layers,
        n_head=llama_config.num_attention_heads,
        n_inner=llama_config.intermediate_size,
        activation_function="swiglu",  # Hardcode since HF calls it 'silu'
        # Llama doesn't have dropout, possibly because only the inference code was released
        resid_pdrop=0.0,
        embd_pdrop=0.0,
        attn_pdrop=0.0,
        layer_norm_epsilon=llama_config.rms_norm_eps,
        initializer_range=llama_config.initializer_range,
        bos_token_id=llama_config.bos_token_id,
        eos_token_id=llama_config.eos_token_id,
        # These are new arguments not in the original GPT2Config
        pad_token_id=llama_config.pad_token_id,  # Unclear whether this has any effect
        rms_norm=True,
        rotary_emb_fraction=1.0,
        rotary_emb_interleaved=True,
        tie_word_embeddings=False,
        qkv_proj_bias=False,
        out_proj_bias=False,
        mlp_fc1_bias=False,
        mlp_fc2_bias=False,
    )
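

# End-to-end usage sketch, not part of the original module: load a single-shard Meta checkpoint,
# remap it, and build a model from the converted GPT2Config. The directory layout
# ("<checkpoint_dir>/<model_name>/params.json" plus "consolidated.*.pth"), the "7B" default,
# and the GPTLMHeadModel import from flash_attn.models.gpt are assumptions about the surrounding
# package; multi-shard (tensor-parallel) checkpoints would first need their shards combined,
# which this sketch does not attempt.
def _example_load_meta_llama(checkpoint_dir="/path/to/llama", model_name="7B"):
    from flash_attn.models.gpt import GPTLMHeadModel  # assumed import path

    llama_config = config_from_checkpoint(checkpoint_dir, model_name, checkpoint_format="meta")
    config = llama_config_to_gpt2_config(llama_config)
    shards = state_dicts_from_checkpoint(checkpoint_dir, model_name)
    assert len(shards) == 1, "combining tensor-parallel shards is out of scope for this sketch"
    state_dict = remap_state_dict_meta_llama(shards[0], config)
    model = GPTLMHeadModel(config, device="cpu", dtype=torch.float32)  # assumed signature
    model.load_state_dict(state_dict)
    return model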