"magic_pdf/pre_proc/resolve_bbox_conflict.py.bak" did not exist on "f01cb89f019d6ddb98ba637cff0a59ef2f2f58d9"
"""
Use this script to import Hugging Face MistralForCausalLM weights into the ALLaMo format.
"""
import argparse
import dataclasses
import json
import os
import torch
from transformers import MistralForCausalLM
from allamo.logging import configure_logger, logger
from allamo.model.model import AllamoTransformerConfig, AllamoTransformer
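
# Example usage (the paths below are illustrative, not part of this repository):
#   python import_hf_mistral_weights.py \
#       --huggingface_model /models/Mistral-7B-v0.1 \
#       --output_dir /models/allamo-mistral-7b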

def import_model(hf_model_path, output_model_path):
    logger.info("Importing Huggingface MistralForCausalLM weights")
    hf_model = MistralForCausalLM.from_pretrained(hf_model_path, torch_dtype=torch.float32, low_cpu_mem_usage=True)
    logger.info("Huggingface MistralForCausalLM model loaded")

    assert hf_model.config.hidden_act == "silu"
    
    config = AllamoTransformerConfig()
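    # Mirror the Hugging Face Mistral config fields onto the ALLaMo config.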
    config.block_size = hf_model.config.sliding_window  # alternatively: hf_model.config.max_position_embeddings
    config.vocab_size = hf_model.config.vocab_size
    config.n_layer = hf_model.config.num_hidden_layers
    config.n_head = hf_model.config.num_attention_heads
    config.n_embd = hf_model.config.hidden_size
    config.intermediate_size = hf_model.config.intermediate_size
    config.head_size = config.n_embd // config.n_head
    config.num_kv_heads = hf_model.config.num_key_value_heads
    config.sliding_window = hf_model.config.sliding_window
    config.dropout = 0.0
    config.bias = False
    config.norm_eps = hf_model.config.rms_norm_eps
    config.rope_freq_base = int(hf_model.config.rope_theta)

    logger.info("initializing vanilla ALLaMo model")
    model = AllamoTransformer(config)
    
    logger.info("preparing weights")
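    # Keys are ALLaMo parameter names, values are the corresponding Hugging Face parameter names.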
    state_dicts_map = {}
    sd_hf_model = hf_model.state_dict()
    model_sd = model.state_dict()
    for layer_i in range(config.n_layer):
        state_dicts_map[f"layers.{layer_i}.attention.q_proj.weight"] = f"model.layers.{layer_i}.self_attn.q_proj.weight"
        state_dicts_map[f"layers.{layer_i}.attention.k_proj.weight"] = f"model.layers.{layer_i}.self_attn.k_proj.weight"
        state_dicts_map[f"layers.{layer_i}.attention.v_proj.weight"] = f"model.layers.{layer_i}.self_attn.v_proj.weight"
        state_dicts_map[f"layers.{layer_i}.attention.c_proj.weight"] = f"model.layers.{layer_i}.self_attn.o_proj.weight"
        state_dicts_map[f"layers.{layer_i}.attention.rotary_emb.inv_freq"] = f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"
        state_dicts_map[f"layers.{layer_i}.feed_forward.gate_proj.weight"] = f"model.layers.{layer_i}.mlp.gate_proj.weight"
        state_dicts_map[f"layers.{layer_i}.feed_forward.down_proj.weight"] = f"model.layers.{layer_i}.mlp.down_proj.weight"
        state_dicts_map[f"layers.{layer_i}.feed_forward.up_proj.weight"] = f"model.layers.{layer_i}.mlp.up_proj.weight"
        state_dicts_map[f"layers.{layer_i}.attention_norm.weight"] = f"model.layers.{layer_i}.input_layernorm.weight"
        state_dicts_map[f"layers.{layer_i}.ffn_norm.weight"] = f"model.layers.{layer_i}.post_attention_layernorm.weight"
    state_dicts_map["tok_embeddings.weight"] = "model.embed_tokens.weight"
    state_dicts_map["norm.weight"] = "model.norm.weight"
    state_dicts_map["lm_head.weight"] = "lm_head.weight"
    
    logger.info("checking params coverage")
    for k in model_sd:
        if k not in state_dicts_map:
            logger.info(f"{k} param won't be updated in the ALLaMo model!")

    hf_mapped_keys = set(state_dicts_map.values())
    for k in sd_hf_model:
        if k not in hf_mapped_keys:
            logger.info(f"{k} param won't be copied to the ALLaMo model!")
    
    logger.info("copying params to the ALLaMo model")
    param_count = 0
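    # Copy tensors by name; rotary inv_freq buffers are skipped because they are deterministic buffers derived from the config, not learned weights.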
    for k, v in state_dicts_map.items():
        if not k.endswith('rotary_emb.inv_freq'):
            assert sd_hf_model[v].shape == model_sd[k].shape
            with torch.no_grad():
                model_sd[k].copy_(sd_hf_model[v])
            param_count += model_sd[k].numel()
    logger.info(f"{param_count} params copied to the ALLaMo model")
    
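    # Verify that every copied tensor in the ALLaMo state dict matches its Hugging Face source.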
    for k in model_sd:
        if k not in state_dicts_map or k.endswith('rotary_emb.inv_freq'):
            continue
        if not torch.all(torch.eq(model_sd[k], sd_hf_model[state_dicts_map[k]])):
            logger.info(f"{k} param in the ALLaMo model is not the same as {state_dicts_map[k]} param in the source model!")
    logger.info("params verified")
    
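    # Persist the import: the model config as JSON and the state dict as a PyTorch checkpoint.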
    ckpt_file_name = 'import_ckpt'
    config_checkpoint = {
        'model_args': dataclasses.asdict(config)
    }
    ckpt_file_path = os.path.join(output_model_path, f'config_{ckpt_file_name}.json')
    logger.info(f"saving config checkpoint to {ckpt_file_path}")
    with open(ckpt_file_path, "w", encoding="utf-8") as f:
        json.dump(config_checkpoint, f, indent=4, ensure_ascii=False)
    ckpt_file_path = os.path.join(output_model_path, f'model_{ckpt_file_name}.pt')
    logger.info(f"saving model checkpoint to {ckpt_file_path}")
    torch.save(model_sd, ckpt_file_path)
    logger.info(f"checkpoint files saved in {output_model_path}")
    
if __name__ == '__main__':
    configure_logger()
    parser = argparse.ArgumentParser(description='Import Huggingface MistralForCausalLM weights to ALLaMo format')
    parser.add_argument(
        "--huggingface_model",
        help="Huggingface model path",
        required=True,
    )
    parser.add_argument(
        "--output_dir",
        help="Path to output directory",
        required=True,
    )
    args = parser.parse_args()
    
    os.makedirs(args.output_dir, exist_ok=True)
    import_model(
        hf_model_path=args.huggingface_model,
        output_model_path=args.output_dir,
    )
    logger.info("import completed")