"""
Use this script to import Huggingface LlamaForCausalLM weights into the ALLaMo checkpoint format.
"""
import argparse
import dataclasses
import json
import os
import torch
from transformers import LlamaForCausalLM
from allamo.logging import configure_logger, logger
from allamo.model.model import AllamoTransformerConfig, AllamoTransformer

def import_model(hf_model_path, output_model_path):
    logger.info(f"Importing Huggingface LlamaForCausalLM weights")
    hf_model = LlamaForCausalLM.from_pretrained(hf_model_path, torch_dtype=torch.float32, low_cpu_mem_usage=True)
    logger.info(f"Huggingface LlamaForCausalLM model loaded")

    # The ALLaMo feed-forward mirrors Llama's SwiGLU-style MLP, so only the SiLU activation is supported here.
    assert hf_model.config.hidden_act == "silu"
    
    config = AllamoTransformerConfig()
    config.block_size = hf_model.config.max_position_embeddings
    config.vocab_size = hf_model.config.vocab_size
    config.n_layer = hf_model.config.num_hidden_layers
    config.n_head = hf_model.config.num_attention_heads
    config.n_embd = hf_model.config.hidden_size
    config.intermediate_size = hf_model.config.intermediate_size
    config.head_size = config.n_embd // config.n_head
    config.num_kv_heads = hf_model.config.num_key_value_heads
    config.dropout = 0.0
    config.bias = False
    config.norm_eps = hf_model.config.rms_norm_eps

    logger.info(f"initializing vanilla ALLaMo model")
    model = AllamoTransformer(config)
    
    logger.info(f"preparing weights")
    state_dicts_map = {}
    sd_hf_model = hf_model.state_dict()
    model_sd = model.state_dict()
    for layer_i in range(config.n_layer):
        state_dicts_map[f"layers.{layer_i}.attention.q_proj.weight"] = f"model.layers.{layer_i}.self_attn.q_proj.weight"
        state_dicts_map[f"layers.{layer_i}.attention.k_proj.weight"] = f"model.layers.{layer_i}.self_attn.k_proj.weight"
        state_dicts_map[f"layers.{layer_i}.attention.v_proj.weight"] = f"model.layers.{layer_i}.self_attn.v_proj.weight"
        state_dicts_map[f"layers.{layer_i}.attention.c_proj.weight"] = f"model.layers.{layer_i}.self_attn.o_proj.weight"
        state_dicts_map[f"layers.{layer_i}.attention.rotary_emb.inv_freq"] = f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"
        state_dicts_map[f"layers.{layer_i}.feed_forward.gate_proj.weight"] = f"model.layers.{layer_i}.mlp.gate_proj.weight"
        state_dicts_map[f"layers.{layer_i}.feed_forward.down_proj.weight"] = f"model.layers.{layer_i}.mlp.down_proj.weight"
        state_dicts_map[f"layers.{layer_i}.feed_forward.up_proj.weight"] = f"model.layers.{layer_i}.mlp.up_proj.weight"
        state_dicts_map[f"layers.{layer_i}.attention_norm.weight"] = f"model.layers.{layer_i}.input_layernorm.weight"
        state_dicts_map[f"layers.{layer_i}.ffn_norm.weight"] = f"model.layers.{layer_i}.post_attention_layernorm.weight"
    state_dicts_map["tok_embeddings.weight"] = "model.embed_tokens.weight"
    state_dicts_map["norm.weight"] = "model.norm.weight"
    state_dicts_map["lm_head.weight"] = "lm_head.weight"
    
    logger.info(f"checking params coverage")
    for k in model_sd:
        if k not in state_dicts_map:
            logger.info(f"{k} param won't be updated in the ALLaMo model!")
            
    for k in sd_hf_model:
        if k not in state_dicts_map.values():
            logger.info(f"{k} param won't be copied to the ALLaMo model!")
    
    logger.info(f"copying params to the ALLaMo model")
    param_count = 0
    for k, v in state_dicts_map.items():
        if not k.endswith('rotary_emb.inv_freq'):
            assert sd_hf_model[v].shape == model_sd[k].shape
            with torch.no_grad():
                model_sd[k].copy_(sd_hf_model[v])
            param_count += model_sd[k].numel()
    logger.info(f"{param_count} params copied to the ALLaMo model")
    
    # Verify the copy for every param that was actually mapped and copied
    # (skip unmapped params and rotary_emb.inv_freq buffers, matching the copy loop above).
    for k in model_sd:
        if k not in state_dicts_map or k.endswith('rotary_emb.inv_freq'):
            continue
        if not torch.all(torch.eq(model_sd[k], sd_hf_model[state_dicts_map[k]])):
            logger.info(f"{k} param in the ALLaMo model is not the same as {state_dicts_map[k]} param in the source model!")
    logger.info(f"params verified")
    
    # Save the imported model as an ALLaMo checkpoint pair: a JSON config file and a .pt state dict,
    # both named with the 'import_ckpt' suffix.
    ckpt_file_name = 'import_ckpt'
    config_checkpoint = {
        'model_args': dataclasses.asdict(config)
    }
    ckpt_file_path = os.path.join(output_model_path, f'config_{ckpt_file_name}.json')
    logger.info(f"saving config checkpoint to {ckpt_file_path}")
    with open(ckpt_file_path, "w", encoding="utf-8") as f:
        json.dump(config_checkpoint, f, indent=4, ensure_ascii=False)
    ckpt_file_path = os.path.join(output_model_path, f'model_{ckpt_file_name}.pt')
    logger.info(f"saving model checkpoint to {ckpt_file_path}")
    torch.save(model_sd, ckpt_file_path)
    logger.info(f"checkpoint files saved in {output_model_path}")
    
if __name__ == '__main__':
    configure_logger()
    parser = argparse.ArgumentParser(description='Import Huggingface LlamaForCausalLM weights to ALLaMo model')
    parser.add_argument(
        "--huggingface_model",
        help="Huggingface model path",
        required=True,
    )
    parser.add_argument(
        "--output_dir",
        help="Path to output directory",
        required=True,
    )
    args = parser.parse_args()
    
    os.makedirs(args.output_dir, exist_ok=True)
    import_model(
        hf_model_path=args.huggingface_model,
        output_model_path=args.output_dir,
    )
    logger.info("import completed")