convert_hf_to_nanotron.py 4.99 KB
Newer Older
chenzk's avatar
v1.0.8  
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""
Converts a Hugging Face (HF) Llama checkpoint to the nanotron format.

Command:
    torchrun --nproc_per_node=1 convert_hf_to_nanotron.py --checkpoint_path=hf_weights --save_path=nanotron_weights
"""

import dataclasses
import json
from argparse import ArgumentParser
from pathlib import Path

import nanotron
import torch
from convert_weights import get_config_mapping, get_weight_mapping, load_nanotron_model
from nanotron.config import LlamaConfig as NanotronLlamaConfig
from nanotron.models.llama import LlamaForTraining
from transformers import LlamaConfig as HFLlamaConfig
from transformers import LlamaForCausalLM


def _handle_attention_block(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, n_q_heads: int, n_kv_heads: int, d_qk: int
) -> torch.Tensor:
    """Fuse separate huggingface q, k, v weights into a single nanotron qkv weight.

    Huggingface Llama stores the q, k and v weights separately, whereas nanotron
    keeps them concatenated. In addition, the rotary embeddings in nanotron expect
    interleaved pairs of even and odd dimensions (GPT-J style), while the
    huggingface implementation expects the whole first half followed by the whole
    second half (GPT-NeoX style); see flash_attn.layers.rotary.RotaryEmbedding for
    more information. This function therefore permutes the rows of q and k head
    by head and then concatenates q, k and v.

    Args:
        q: Query projection weight; split along dim 0 into chunks of `d_qk` rows.
        k: Key projection weight; split along dim 0 into chunks of `d_qk` rows.
        v: Value projection weight; appended unchanged.
        n_q_heads: Number of query heads (accepted but not used here).
        n_kv_heads: Number of key/value heads (accepted but not used here).
        d_qk: Number of rows that make up one head.

    Returns:
        The permuted q, the permuted k and the untouched v, concatenated along dim 0.
    """
    half = d_qk // 2

    def interleave(w: torch.Tensor) -> torch.Tensor:
        # Within each head, view the rows as (first half, second half), then swap the
        # two leading axes so that row i of the first half is immediately followed by
        # row i of the second half.
        per_head = [
            chunk.view(2, half, -1).transpose(0, 1).reshape(d_qk, -1)
            for chunk in w.split(d_qk)
        ]
        return torch.cat(per_head)

    return torch.cat([interleave(q), interleave(k), v])


def convert_hf_to_nt(model_hf: LlamaForCausalLM, model_nt: LlamaForTraining, config: NanotronLlamaConfig):
    """Copy the weights of `model_hf` into `model_nt`, modifying `model_nt` in-place.

    Every parameter directly owned by a module of `model_nt` is looked up in the
    nanotron-to-huggingface name mapping and overwritten with the matching
    huggingface tensor(s). Fused nanotron parameters (`qkv_proj`, `gate_up_proj`)
    are rebuilt from several huggingface tensors.

    Args:
        model_hf: Source huggingface model; only its state dict is read.
        model_nt: Destination nanotron model; its parameters are overwritten.
        config: Nanotron config used to build the name mapping and to derive the
            attention head sizes.
    """
    hf_state = model_hf.state_dict()
    name_map = get_weight_mapping(config, nt_to_hf=True)

    for module_name_nt, module_nt in model_nt.named_modules():
        # recurse=False: visit each parameter exactly once, under its owning module.
        for param_name_nt, param_nt in module_nt.named_parameters(recurse=False):
            # Every branch below needs the mapping entry, so fetch it once.
            mapped = name_map[f"{module_name_nt}.{param_name_nt}"]

            if "qkv_proj" in module_name_nt:
                # For qkv_proj the mapping entry has exactly three keys, corresponding
                # to q, k, v. After sorting they are unpacked in the order k, q, v.
                k_key, q_key, v_key = sorted(mapped)
                new_value = _handle_attention_block(
                    hf_state[q_key],
                    hf_state[k_key],
                    hf_state[v_key],
                    config.num_attention_heads,
                    config.num_key_value_heads,
                    config.hidden_size // config.num_attention_heads,
                )
            elif "gate_up_proj" in module_name_nt:
                # For gate_up_proj the mapping entry has two keys; after sorting they
                # are unpacked in the order gate, up.
                gate_key, up_key = sorted(mapped)
                new_value = torch.cat([hf_state[gate_key], hf_state[up_key]])
            else:
                # All other cases are a simple 1-to-1 correspondence.
                new_value = hf_state[mapped]

            # The copy is a plain overwrite and must not be tracked by autograd.
            with torch.no_grad():
                param_nt.copy_(new_value)


def get_nanotron_config(config: HFLlamaConfig) -> NanotronLlamaConfig:
    """Build the nanotron configuration that corresponds to a huggingface one.

    The nanotron-to-huggingface config mapping gives, for each nanotron field
    name, the name of the huggingface attribute whose value it takes.
    """
    field_map = get_config_mapping(nt_to_hf=True)
    nt_kwargs = {}
    for nt_field, hf_attr in field_map.items():
        nt_kwargs[nt_field] = getattr(config, hf_attr)
    return NanotronLlamaConfig(**nt_kwargs)


def convert_checkpoint_and_save(checkpoint_path: Path, save_path: Path):
    """Convert the huggingface checkpoint at `checkpoint_path` and save it to `save_path`.

    Loads the huggingface model, creates a new nanotron model from the converted
    configuration, copies the huggingface weights into it, and then writes the
    nanotron weights together with a `model_config.json` file under `save_path`.

    Args:
        checkpoint_path: Location of the huggingface checkpoint to load.
        save_path: Folder that receives the nanotron weights and the config JSON.
    """
    # Source model (huggingface).
    model_hf = LlamaForCausalLM.from_pretrained(checkpoint_path)

    # Destination model (nanotron), built from the translated configuration.
    nt_config = get_nanotron_config(model_hf.config)
    model_nt = load_nanotron_model(model_config=nt_config)

    # The conversion uses a parallel context with every parallel size set to 1.
    context = nanotron.parallel.ParallelContext(
        data_parallel_size=1, pipeline_parallel_size=1, tensor_parallel_size=1
    )

    # Transfer the weights, then persist them.
    convert_hf_to_nt(model_hf, model_nt, nt_config)
    nanotron.serialize.save_weights(model=model_nt, parallel_context=context, root_folder=save_path)

    # Store the nanotron configuration next to the weights.
    config_file = save_path / "model_config.json"
    with open(config_file, "w+") as f:
        json.dump(dataclasses.asdict(nt_config), f)

    print(f"Model saved to {save_path}")


if __name__ == "__main__":
    # Command-line entry point: parse the two path options and run the conversion.
    cli = ArgumentParser(description="Convert HF weights to nanotron format")

    # (flag, default, help text); both options are parsed as pathlib.Path.
    cli_options = (
        ("--checkpoint_path", "llama-7b", "Path to the checkpoint"),
        ("--save_path", "llama-7b-hf", "Path to save the nanotron model"),
    )
    for flag, default_value, help_text in cli_options:
        cli.add_argument(flag, type=Path, default=default_value, help=help_text)
    options = cli.parse_args()

    # Convert HF model to nanotron format.
    convert_checkpoint_and_save(checkpoint_path=options.checkpoint_path, save_path=options.save_path)