# Copyright (c) 2023, Tri Dao.

import json
import math
import os
import re
from collections import OrderedDict
from pathlib import Path
from typing import Union

import torch
import torch.nn.functional as F
from transformers import GPT2Config, LlamaConfig


def remap_state_dict_meta_llama(
    state_dict: dict[str, torch.Tensor], config: GPT2Config
) -> dict[str, torch.Tensor]:
    """Convert the state_dict in Meta format to standard GPT format.

    This function modifies state_dict in place.
    """

    def key_mapping_layers(key):
        return f"transformer.{key}" if not key.startswith("output.") else key

    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
    # Word embedding
    def key_mapping_emb(key):
        return re.sub(
            r"^transformer.tok_embeddings.", "transformer.embeddings.word_embeddings.", key
        )

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
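    # (e.g. with pad_vocab_size_multiple=8, a 32001-row embedding is zero-padded to 32008 rows,
    # while a 32000-row one is left unchanged).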
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = (
        math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
    )
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    if getattr(config, "tie_word_embeddings"):
        state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
    else:
        output_embeddings = state_dict.pop("output.weight")
        # Need to recompute vocab_size since LLaMA shards the word embeddings and output embeddings
        # differently.
        vocab_size = (
            math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
            * pad_vocab_size_multiple
        )
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict["lm_head.weight"] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.norm.", r"transformer.ln_f.", key)
        key = re.sub(
            r"^transformer.layers.(\d+).attention_norm.", r"transformer.layers.\1.norm1.", key
        )
        key = re.sub(r"^transformer.layers.(\d+).ffn_norm.", r"transformer.layers.\1.norm2.", key)
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    for l in range(config.n_layer):
        w1 = state_dict.pop(f"transformer.layers.{l}.feed_forward.w1.weight")
        w3 = state_dict.pop(f"transformer.layers.{l}.feed_forward.w3.weight")
        # Our ordering of the fused fc1 weight is different: [w3, w1]
        state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0)

    def key_mapping_mlp(key):
        return re.sub(
            r"^transformer.layers.(\d+).feed_forward.w2.", r"transformer.layers.\1.mlp.fc2.", key
        )

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for l in range(config.n_layer):
        Wq = state_dict.pop(f"transformer.layers.{l}.attention.wq.weight")
        Wk = state_dict.pop(f"transformer.layers.{l}.attention.wk.weight")
        Wv = state_dict.pop(f"transformer.layers.{l}.attention.wv.weight")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
        # We don't store these
        state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None)

    def key_mapping_attn(key):
        return re.sub(
            r"^transformer.layers.(\d+).attention.wo.",
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    state_dict.pop("transformer.rope.freqs", None)

    return state_dict
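
# A minimal usage sketch for the Meta-format path (hedged: the checkpoint path and model name
# below are illustrative, and only the single-shard case is shown):
#
#     llama_config = config_from_meta_checkpoint("/path/to/llama-ckpts", "7B")
#     config = llama_config_to_gpt2_config(llama_config)
#     shards = state_dicts_from_checkpoint("/path/to/llama-ckpts", "7B")
#     state_dict = remap_state_dict_meta_llama(shards[0], config)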


def remap_state_dict_hf_llama(
    state_dict: dict[str, torch.Tensor], config: GPT2Config
) -> dict[str, torch.Tensor]:
    """Convert the state_dict in Hugging Face format to standard GPT format.

    This function modifies state_dict in place.
    """
    # Embedding
    def key_mapping_emb(key):
        return re.sub(r"^model.embed_tokens.", "transformer.embeddings.word_embeddings.", key)

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = (
        math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
    )
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )

    # LM head
    if getattr(config, "tie_word_embeddings"):
        state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
    else:
        output_embeddings = state_dict.pop("lm_head.weight")
        # Need to recompute vocab_size since LLaMA shards the word embeddings and output embeddings
        # differently.
        vocab_size = (
            math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
            * pad_vocab_size_multiple
        )
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict["lm_head.weight"] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )

    # MLP
    for l in range(config.n_layer):
        # Fusing weights this way based on difference in the following:
        # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/modeling_llama.py#L220
        # https://github.com/Dao-AILab/flash-attention/blob/c60851a8253257eb970e06a022c82517a8033e8c/flash_attn/modules/mlp.py#L115
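        # i.e. fc1.weight = cat([up_proj (w3), gate_proj (w1)]); the fused MLP chunks fc1's output
        # in that order to reproduce act_fn(gate_proj(x)) * up_proj(x).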
        w1 = state_dict.pop(f"model.layers.{l}.mlp.gate_proj.weight")
        w3 = state_dict.pop(f"model.layers.{l}.mlp.up_proj.weight")
        state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0)

    def key_mapping_mlp(key):
        return re.sub(r"^model.layers.(\d+).mlp.down_proj.", r"transformer.layers.\1.mlp.fc2.", key)

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^model.norm.", r"transformer.ln_f.", key)
        key = re.sub(r"^model.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key)
        key = re.sub(
            r"^model.layers.(\d+).post_attention_layernorm.", r"transformer.layers.\1.norm2.", key
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    def inv_permute(w):
        # Inverse of permute implemented in:
        # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/convert_llama_weights_to_hf.py#L114
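        # The HF conversion script moves q/k rows from the interleaved rotary layout of the
        # original Meta weights to the half-split layout used by HF's rotate_half; this
        # reshape/transpose undoes that, matching the rotary_emb_interleaved=True setting in
        # llama_config_to_gpt2_config below.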
        return (
            w.reshape(config.n_head, 2, config.n_embd // config.n_head // 2, config.n_embd)
            .transpose(1, 2)
            .reshape(config.n_embd, config.n_embd)
        )

    # Attention
    for l in range(config.n_layer):
        Wq = state_dict.pop(f"model.layers.{l}.self_attn.q_proj.weight")
        Wk = state_dict.pop(f"model.layers.{l}.self_attn.k_proj.weight")
        Wv = state_dict.pop(f"model.layers.{l}.self_attn.v_proj.weight")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat(
            [inv_permute(Wq), inv_permute(Wk), Wv], dim=0
        )
        # We don't store these
        state_dict.pop(f"model.layers.{l}.self_attn.rotary_emb.inv_freq", None)

    def key_mapping_attn(key):
        return re.sub(
            r"^model.layers.(\d+).self_attn.o_proj.", r"transformer.layers.\1.mixer.out_proj.", key
        )

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
    return state_dict
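
# A minimal usage sketch for the Hugging Face path (hedged: the model id is illustrative):
#
#     from transformers import LlamaForCausalLM
#     llama_config = LlamaConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
#     config = llama_config_to_gpt2_config(llama_config)
#     hf_state_dict = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf").state_dict()
#     gpt_state_dict = remap_state_dict_hf_llama(hf_state_dict, config)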


def inv_remap_state_dict_hf_llama(
    state_dict: dict[str, torch.Tensor], config: GPT2Config
) -> dict[str, torch.Tensor]:
    """Convert the state_dict in standard GPT format to Hugging Face format.

    This function is meant to be the inverse of remap_state_dict_hf_llama, up to padding of the
    embedding and lm_head to a multiple of pad_vocab_size_multiple. That is, if the original
    vocab size isn't a multiple of pad_vocab_size_multiple, then
    inv_remap_state_dict_hf_llama(remap_state_dict_hf_llama(state_dict)) != state_dict.

    This function modifies state_dict in place.
    """

    # Embedding
    def key_mapping_emb(key):
        return re.sub(r"^transformer.embeddings.word_embeddings.", "model.embed_tokens.", key)

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("model.embed_tokens.weight")
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = (
        math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
    )
    state_dict["model.embed_tokens.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )

    # LM head
    if getattr(config, "tie_word_embeddings"):
        state_dict["lm_head.weight"] = state_dict["model.embed_tokens.weight"]
    else:
        output_embeddings = state_dict.pop("lm_head.weight")
        vocab_size = (
            math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
            * pad_vocab_size_multiple
        )
        state_dict["lm_head.weight"] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )

    # MLP
    for l in range(config.n_layer):
        w3, w1 = torch.chunk(
            state_dict.pop(f"transformer.layers.{l}.mlp.fc1.weight"), chunks=2, dim=0
        )
        state_dict[f"model.layers.{l}.mlp.gate_proj.weight"] = w1
        state_dict[f"model.layers.{l}.mlp.up_proj.weight"] = w3

    def key_mapping_mlp(key):
        return re.sub(r"^transformer.layers.(\d+).mlp.fc2.", r"model.layers.\1.mlp.down_proj.", key)

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.ln_f.", r"model.norm.", key)
        key = re.sub(r"^transformer.layers.(\d+).norm1.", r"model.layers.\1.input_layernorm.", key)
        key = re.sub(
            r"^transformer.layers.(\d+).norm2.", r"model.layers.\1.post_attention_layernorm.", key
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    def permute(w):
        return (
            w.view(config.n_head, config.n_embd // config.n_head // 2, 2, config.n_embd)
            .transpose(1, 2)
            .reshape(config.n_embd, config.n_embd)
        )

    n_head = config.n_head
    n_head_kv = getattr(config, "n_head_kv", n_head)

    embed_dim = config.hidden_size
    head_dim = embed_dim // n_head

    q_dim = n_head * head_dim
    k_dim = v_dim = n_head_kv * head_dim

    # Attention
    for l in range(config.n_layer):
        Wqkv = state_dict.pop(f"transformer.layers.{l}.mixer.Wqkv.weight")
        Wq = Wqkv[:q_dim]
        Wk = Wqkv[q_dim : q_dim + k_dim]
        Wv = Wqkv[q_dim + k_dim : q_dim + k_dim + v_dim]
        state_dict[f"model.layers.{l}.self_attn.q_proj.weight"] = permute(Wq)
        state_dict[f"model.layers.{l}.self_attn.k_proj.weight"] = permute(Wk)
        state_dict[f"model.layers.{l}.self_attn.v_proj.weight"] = Wv
        state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None)

    def key_mapping_attn(key):
        return re.sub(
            r"^transformer.layers.(\d+).mixer.out_proj.", r"model.layers.\1.self_attn.o_proj.", key
        )

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
    return state_dict
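
# Round-trip sketch (hedged): when the vocab size is already a multiple of
# config.pad_vocab_size_multiple and word embeddings are untied, this should recover the original
# HF keys and weights:
#
#     gpt_sd = remap_state_dict_hf_llama(dict(hf_state_dict), config)
#     hf_sd_again = inv_remap_state_dict_hf_llama(gpt_sd, config)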


def config_from_meta_checkpoint(
    checkpoint_path: Union[str, os.PathLike], model_name: str
) -> LlamaConfig:
    """Load a LlamaConfig from a checkpoint path."""
    with open(Path(checkpoint_path) / model_name / "params.json") as f:
        params = json.load(f)
    config = LlamaConfig(
        hidden_size=params["dim"],
        intermediate_size=None,
        num_attention_heads=params["n_heads"],
        num_hidden_layers=params["n_layers"],
        rms_norm_eps=params["norm_eps"],
    )
    return config


def config_from_hf_checkpoint(
    checkpoint_path: Union[str, os.PathLike], model_name: str
) -> LlamaConfig:
    return LlamaConfig.from_pretrained(Path(checkpoint_path) / f"{model_name}-hf" / "config.json")


def config_from_checkpoint(
    checkpoint_path: Union[str, os.PathLike], model_name: str, checkpoint_format="meta"
) -> LlamaConfig:
    if checkpoint_format == "meta":
        return config_from_meta_checkpoint(checkpoint_path, model_name)
    else:
        return config_from_hf_checkpoint(checkpoint_path, model_name)


def state_dicts_from_checkpoint(
    checkpoint_path: Union[str, os.PathLike], model_name: str
) -> list[dict]:
    # Need to sort, otherwise we mess up the ordering and the weights are wrong
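    # (shards are named consolidated.00.pth, consolidated.01.pth, ..., one file per
    # model-parallel rank)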
    return [
        torch.load(path, map_location="cpu")
        for path in sorted((Path(checkpoint_path) / model_name).glob("consolidated.*.pth"))
    ]


def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config:
    return GPT2Config(
        vocab_size=llama_config.vocab_size,
        n_positions=0,  # No absolute position embedding
        n_embd=llama_config.hidden_size,
        n_layer=llama_config.num_hidden_layers,
        n_head=llama_config.num_attention_heads,
        n_inner=llama_config.intermediate_size,
        activation_function="swiglu",  # Hardcode since HF calls it 'silu'
        # Llama doesn't have dropout (possibly because only the inference code was released)
        resid_pdrop=0.0,
        embd_pdrop=0.0,
        attn_pdrop=0.0,
        layer_norm_epsilon=llama_config.rms_norm_eps,
        initializer_range=llama_config.initializer_range,
        bos_token_id=llama_config.bos_token_id,
        eos_token_id=llama_config.eos_token_id,
        # These are new arguments not in the original GPT2Config
        pad_token_id=llama_config.pad_token_id,  # Unclear whether this has any effect
        rms_norm=True,
        rotary_emb_fraction=1.0,
        rotary_emb_interleaved=True,
        tie_word_embeddings=False,
        qkv_proj_bias=False,
        out_proj_bias=False,
        mlp_fc1_bias=False,
        mlp_fc2_bias=False,
    )
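

# A minimal sketch of consuming the converted config and state_dict (hedged: GPTLMHeadModel
# refers to flash_attn.models.gpt.GPTLMHeadModel, and strict=False is used defensively):
#
#     model = GPTLMHeadModel(config)                    # config from llama_config_to_gpt2_config
#     model.load_state_dict(state_dict, strict=False)   # state_dict from remap_state_dict_*_llama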