from __future__ import annotations

from typing import Sequence

from .constants import MODEL_ARCH, MODEL_TENSOR, MODEL_TENSORS, TENSOR_NAMES


class TensorNameMap:
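    """Map tensor names from upstream checkpoint layouts to GGUF names.

    `mappings_cfg` holds per-model tensors (embeddings, output head, final
    norm) and `block_mappings_cfg` holds tensors repeated once per block,
    with "{bid}" as the block-index placeholder; the trailing comment on
    each entry names the model family that uses that layout.
    """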
    mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
        # Token embeddings
        MODEL_TENSOR.TOKEN_EMBD: (
            "gpt_neox.embed_in",                         # gptneox
            "transformer.wte",                           # gpt2 gpt-j mpt refact qwen dbrx jais
            "transformer.word_embeddings",               # falcon
            "word_embeddings",                           # bloom
            "model.embed_tokens",                        # llama-hf
            "tok_embeddings",                            # llama-pth
            "embeddings.word_embeddings",                # bert nomic-bert
            "language_model.embedding.word_embeddings",  # persimmon
            "wte",                                       # gpt2
            "transformer.embd.wte",                      # phi2
            "model.tok_embeddings",                      # internlm2
            "model.embedding",                           # mamba-qbert
            "backbone.embedding",                        # mamba
            "backbone.embeddings",                       # mamba-hf
            "transformer.in_out_embed",                  # Grok
            "embedding.word_embeddings",                 # chatglm
            "transformer.token_embeddings",              # openelm
            "shared",                                    # t5
        ),

        # Token type embeddings
        MODEL_TENSOR.TOKEN_TYPES: (
            "embeddings.token_type_embeddings",  # bert nomic-bert
        ),

        # Normalization of token embeddings
        MODEL_TENSOR.TOKEN_EMBD_NORM: (
            "word_embeddings_layernorm",  # bloom
            "embeddings.LayerNorm",       # bert
            "emb_ln",                     # nomic-bert
            "transformer.norm",           # openelm
        ),

        # Position embeddings
        MODEL_TENSOR.POS_EMBD: (
            "transformer.wpe",                 # gpt2
            "embeddings.position_embeddings",  # bert
            "wpe",                             # gpt2
        ),

        # Output
        MODEL_TENSOR.OUTPUT: (
            "embed_out",                 # gptneox
            "lm_head",                   # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais
            "output",                    # llama-pth bloom internlm2
            "word_embeddings_for_head",  # persimmon
            "lm_head.linear",            # phi2
            "output_layer",              # chatglm
        ),

        # Output norm
        MODEL_TENSOR.OUTPUT_NORM: (
            "gpt_neox.final_layer_norm",               # gptneox
            "transformer.ln_f",                        # gpt2 gpt-j falcon jais
            "model.norm",                              # llama-hf baichuan internlm2
            "norm",                                    # llama-pth
            "transformer.norm_f",                      # mpt dbrx
            "ln_f",                                    # refact bloom qwen gpt2
            "language_model.encoder.final_layernorm",  # persimmon
            "model.final_layernorm",                   # persimmon
            "lm_head.ln",                              # phi2
            "model.norm_f",                            # mamba-qbert
            "backbone.norm_f",                         # mamba
            "transformer.rms_norm",                    # Grok
            "encoder.final_layernorm",                 # chatglm
            "transformer.norm",                        # openelm
        ),

        # Rope frequencies
        MODEL_TENSOR.ROPE_FREQS: (
            "rope.freqs",  # llama-pth
            "rotary_pos_emb.inv_freq",  # chatglm
        ),
    }

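    # Per-block tensors; "{bid}" is replaced with the block index when the
    # lookup table is built in __init__.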
    block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
        # Attention norm
        MODEL_TENSOR.ATTN_NORM: (
            "gpt_neox.layers.{bid}.input_layernorm",                # gptneox
            "transformer.h.{bid}.ln_1",                             # gpt2 gpt-j refact qwen jais
            "transformer.blocks.{bid}.norm_1",                      # mpt
            "transformer.h.{bid}.input_layernorm",                  # falcon7b
            "h.{bid}.input_layernorm",                              # bloom
            "transformer.h.{bid}.ln_mlp",                           # falcon40b
            "model.layers.{bid}.input_layernorm",                   # llama-hf
            "layers.{bid}.attention_norm",                          # llama-pth
            "language_model.encoder.layers.{bid}.input_layernorm",  # persimmon
            "model.layers.{bid}.ln1",                               # yi
            "h.{bid}.ln_1",                                         # gpt2
            "transformer.h.{bid}.ln",                               # phi2
            "model.layers.layers.{bid}.norm",                       # plamo
            "model.layers.{bid}.attention_norm",                    # internlm2
            "model.layers.{bid}.norm",                              # mamba-qbert
            "backbone.layers.{bid}.norm",                           # mamba
            "transformer.decoder_layer.{bid}.rms_norm",             # Grok
            "transformer.blocks.{bid}.norm_attn_norm.norm_1",       # dbrx
            "encoder.layers.{bid}.input_layernorm",                 # chatglm
            "transformer.layers.{bid}.attn_norm",                   # openelm
        ),

        # Attention norm 2
        MODEL_TENSOR.ATTN_NORM_2: (
            "transformer.h.{bid}.ln_attn",  # falcon40b
            "encoder.layer.{bid}.layer_norm_1",             # jina-v2-code
        ),

        # Attention query-key-value
        MODEL_TENSOR.ATTN_QKV: (
            "gpt_neox.layers.{bid}.attention.query_key_value",                     # gptneox
            "transformer.h.{bid}.attn.c_attn",                                     # gpt2 qwen jais
            "transformer.blocks.{bid}.attn.Wqkv",                                  # mpt
            "transformer.blocks.{bid}.norm_attn_norm.attn.Wqkv",                   # dbrx
            "transformer.h.{bid}.self_attention.query_key_value",                  # falcon
            "h.{bid}.self_attention.query_key_value",                              # bloom
            "language_model.encoder.layers.{bid}.self_attention.query_key_value",  # persimmon
            "model.layers.{bid}.self_attn.query_key_value",                        # persimmon
            "h.{bid}.attn.c_attn",                                                 # gpt2
            "transformer.h.{bid}.mixer.Wqkv",                                      # phi2
            "encoder.layers.{bid}.attn.Wqkv",                                      # nomic-bert
            "model.layers.{bid}.self_attn.qkv_proj",                               # phi3
            "encoder.layers.{bid}.self_attention.query_key_value",                 # chatglm
            "transformer.layers.{bid}.attn.qkv_proj",                              # openelm
        ),

        # Attention query
        MODEL_TENSOR.ATTN_Q: (
            "model.layers.{bid}.self_attn.q_proj",                       # llama-hf
            "layers.{bid}.attention.wq",                                 # llama-pth
            "encoder.layer.{bid}.attention.self.query",                  # bert
            "transformer.h.{bid}.attn.q_proj",                           # gpt-j
            "model.layers.layers.{bid}.self_attn.q_proj",                # plamo
            "model.layers.{bid}.attention.wq",                           # internlm2
            "transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok
        ),

        # Attention key
        MODEL_TENSOR.ATTN_K: (
            "model.layers.{bid}.self_attn.k_proj",                     # llama-hf
            "layers.{bid}.attention.wk",                               # llama-pth
            "encoder.layer.{bid}.attention.self.key",                  # bert
            "transformer.h.{bid}.attn.k_proj",                         # gpt-j
            "transformer.h.{bid}.attn.k",                              # refact
            "model.layers.layers.{bid}.self_attn.k_proj",              # plamo
            "model.layers.{bid}.attention.wk",                         # internlm2
            "transformer.decoder_layer.{bid}.multi_head_attention.key",# Grok
        ),

        # Attention value
        MODEL_TENSOR.ATTN_V: (
            "model.layers.{bid}.self_attn.v_proj",                       # llama-hf
            "layers.{bid}.attention.wv",                                 # llama-pth
            "encoder.layer.{bid}.attention.self.value",                  # bert
            "transformer.h.{bid}.attn.v_proj",                           # gpt-j
            "transformer.h.{bid}.attn.v",                                # refact
            "model.layers.layers.{bid}.self_attn.v_proj",                # plamo
            "model.layers.{bid}.attention.wv",                           # internlm2
            "transformer.decoder_layer.{bid}.multi_head_attention.value" # Grok
        ),

        # Attention output
        MODEL_TENSOR.ATTN_OUT: (
            "gpt_neox.layers.{bid}.attention.dense",                        # gptneox
            "transformer.h.{bid}.attn.c_proj",                              # gpt2 refact qwen jais
            "transformer.blocks.{bid}.attn.out_proj",                       # mpt
            "transformer.h.{bid}.self_attention.dense",                     # falcon
            "h.{bid}.self_attention.dense",                                 # bloom
            "model.layers.{bid}.self_attn.o_proj",                          # llama-hf
            "layers.{bid}.attention.wo",                                    # llama-pth
            "encoder.layer.{bid}.attention.output.dense",                   # bert
            "transformer.h.{bid}.attn.out_proj",                            # gpt-j
            "language_model.encoder.layers.{bid}.self_attention.dense",     # persimmon
            "model.layers.{bid}.self_attn.dense",                           # persimmon
            "h.{bid}.attn.c_proj",                                          # gpt2
            "transformer.h.{bid}.mixer.out_proj",                           # phi2
            "model.layers.layers.{bid}.self_attn.o_proj",                   # plamo
            "model.layers.{bid}.attention.wo",                              # internlm2
            "encoder.layers.{bid}.attn.out_proj",                           # nomic-bert
            "transformer.decoder_layer.{bid}.multi_head_attention.linear",  # Grok
            "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj",        # dbrx
            "encoder.layers.{bid}.self_attention.dense",                    # chatglm
            "transformer.layers.{bid}.attn.out_proj",                       # openelm
        ),

        # Attention output norm
        MODEL_TENSOR.ATTN_OUT_NORM: (
            "encoder.layer.{bid}.attention.output.LayerNorm",  # bert
            "encoder.layers.{bid}.norm1",                      # nomic-bert
            "transformer.decoder_layer.{bid}.rms_norm_1",      # Grok
            "transformer.blocks.{bid}.norm_attn_norm.norm_2",  # dbrx
        ),

        # Post attention norm
        MODEL_TENSOR.ATTN_POST_NORM: (
            "model.layers.{bid}.post_attention_layernorm",     # gemma2
        ),

        # Rotary embeddings
        MODEL_TENSOR.ATTN_ROT_EMBD: (
            "model.layers.{bid}.self_attn.rotary_emb.inv_freq",        # llama-hf
            "layers.{bid}.attention.inner_attention.rope.freqs",       # llama-pth
            "model.layers.layers.{bid}.self_attn.rotary_emb.inv_freq", # plamo
            "transformer.h.{bid}.attn.rotary_emb.inv_freq",            # codeshell
        ),

        # Feed-forward norm
        MODEL_TENSOR.FFN_NORM: (
            "gpt_neox.layers.{bid}.post_attention_layernorm",                # gptneox
            "transformer.h.{bid}.ln_2",                                      # gpt2 refact qwen jais
            "h.{bid}.post_attention_layernorm",                              # bloom
            "transformer.blocks.{bid}.norm_2",                               # mpt
            "model.layers.{bid}.post_attention_layernorm",                   # llama-hf
            "layers.{bid}.ffn_norm",                                         # llama-pth
            "language_model.encoder.layers.{bid}.post_attention_layernorm",  # persimmon
            "model.layers.{bid}.ln2",                                        # yi
            "h.{bid}.ln_2",                                                  # gpt2
            "model.layers.{bid}.ffn_norm",                                   # internlm2
            "transformer.decoder_layer.{bid}.rms_norm_2",                    # Grok
            "encoder.layers.{bid}.post_attention_layernorm",                 # chatglm
            "transformer.layers.{bid}.ffn_norm",                             # openelm
        ),

        # Pre feed-forward norm
        MODEL_TENSOR.FFN_PRE_NORM: (
            "model.layers.{bid}.pre_feedforward_layernorm", # gemma2
        ),

        # Post feed-forward norm
        MODEL_TENSOR.FFN_POST_NORM: (
            "model.layers.{bid}.post_feedforward_layernorm", # gemma2
        ),

        MODEL_TENSOR.FFN_GATE_INP: (
            "layers.{bid}.feed_forward.gate",             # mixtral
            "model.layers.{bid}.block_sparse_moe.gate",   # mixtral
            "model.layers.{bid}.mlp.gate",                # qwen2moe
            "transformer.decoder_layer.{bid}.router",     # Grok
            "transformer.blocks.{bid}.ffn.router.layer",  # dbrx
        ),

        MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
            "model.layers.{bid}.mlp.shared_expert_gate", # qwen2moe
        ),

        # Feed-forward up
        MODEL_TENSOR.FFN_UP: (
            "gpt_neox.layers.{bid}.mlp.dense_h_to_4h",                # gptneox
            "transformer.h.{bid}.mlp.c_fc",                           # gpt2 jais
            "transformer.blocks.{bid}.ffn.up_proj",                   # mpt
            "transformer.h.{bid}.mlp.dense_h_to_4h",                  # falcon
            "h.{bid}.mlp.dense_h_to_4h",                              # bloom
            "model.layers.{bid}.mlp.up_proj",                         # llama-hf refact
            "layers.{bid}.feed_forward.w3",                           # llama-pth
            "encoder.layer.{bid}.intermediate.dense",                 # bert
            "transformer.h.{bid}.mlp.fc_in",                          # gpt-j
            "transformer.h.{bid}.mlp.linear_3",                       # refact
            "language_model.encoder.layers.{bid}.mlp.dense_h_to_4h",  # persimmon
            "model.layers.{bid}.mlp.dense_h_to_4h",                   # persimmon
            "transformer.h.{bid}.mlp.w1",                             # qwen
            "h.{bid}.mlp.c_fc",                                       # gpt2
            "transformer.h.{bid}.mlp.fc1",                            # phi2
            "model.layers.{bid}.mlp.fc1",                             # phi2
            "model.layers.{bid}.mlp.gate_up_proj",                    # phi3
            "model.layers.layers.{bid}.mlp.up_proj",                  # plamo
            "model.layers.{bid}.feed_forward.w3",                     # internlm2
            "encoder.layers.{bid}.mlp.fc11",                          # nomic-bert
            "model.layers.{bid}.mlp.c_fc",                            # starcoder2
            "encoder.layer.{bid}.mlp.gated_layers_v",                 # jina-bert-v2
            "model.layers.{bid}.residual_mlp.w3",                     # arctic
            "encoder.layers.{bid}.mlp.dense_h_to_4h",                 # chatglm
        ),

        MODEL_TENSOR.FFN_UP_EXP: (
            "layers.{bid}.feed_forward.experts.w3",          # mixtral (merged)
            "transformer.decoder_layer.{bid}.moe.linear_v",  # Grok (merged)
            "transformer.blocks.{bid}.ffn.experts.mlp.v1",   # dbrx
            "model.layers.{bid}.mlp.experts.up_proj",        # qwen2moe (merged)
        ),

        MODEL_TENSOR.FFN_UP_SHEXP: (
            "model.layers.{bid}.mlp.shared_expert.up_proj",  # qwen2moe
            "model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek2
        ),

        # AWQ-activation gate
        MODEL_TENSOR.FFN_ACT: (
            "transformer.blocks.{bid}.ffn.act",  # mpt
        ),

        # Feed-forward gate
        MODEL_TENSOR.FFN_GATE: (
            "model.layers.{bid}.mlp.gate_proj",           # llama-hf refact
            "layers.{bid}.feed_forward.w1",               # llama-pth
            "transformer.h.{bid}.mlp.w2",                 # qwen
            "transformer.h.{bid}.mlp.c_fc2",              # jais
            "model.layers.layers.{bid}.mlp.gate_proj",    # plamo
            "model.layers.{bid}.feed_forward.w1",         # internlm2
            "encoder.layers.{bid}.mlp.fc12",              # nomic-bert
            "encoder.layer.{bid}.mlp.gated_layers_w",     # jina-bert-v2
            "transformer.h.{bid}.mlp.linear_1",           # refact
            "model.layers.{bid}.residual_mlp.w1",         # arctic
        ),

        MODEL_TENSOR.FFN_GATE_EXP: (
            "layers.{bid}.feed_forward.experts.w1",         # mixtral (merged)
            "transformer.decoder_layer.{bid}.moe.linear",   # Grok (merged)
            "transformer.blocks.{bid}.ffn.experts.mlp.w1",  # dbrx
            "model.layers.{bid}.mlp.experts.gate_proj",     # qwen2moe (merged)
        ),

        MODEL_TENSOR.FFN_GATE_SHEXP: (
            "model.layers.{bid}.mlp.shared_expert.gate_proj",  # qwen2moe
            "model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek2
        ),

        # Feed-forward down
        MODEL_TENSOR.FFN_DOWN: (
            "gpt_neox.layers.{bid}.mlp.dense_4h_to_h",                # gptneox
            "transformer.h.{bid}.mlp.c_proj",                         # gpt2 refact qwen jais
            "transformer.blocks.{bid}.ffn.down_proj",                 # mpt
            "transformer.h.{bid}.mlp.dense_4h_to_h",                  # falcon
            "h.{bid}.mlp.dense_4h_to_h",                              # bloom
            "model.layers.{bid}.mlp.down_proj",                       # llama-hf
            "layers.{bid}.feed_forward.w2",                           # llama-pth
            "encoder.layer.{bid}.output.dense",                       # bert
            "transformer.h.{bid}.mlp.fc_out",                         # gpt-j
            "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h",  # persimmon
            "model.layers.{bid}.mlp.dense_4h_to_h",                   # persimmon
            "h.{bid}.mlp.c_proj",                                     # gpt2
            "transformer.h.{bid}.mlp.fc2",                            # phi2
            "model.layers.{bid}.mlp.fc2",                             # phi2
            "model.layers.layers.{bid}.mlp.down_proj",                # plamo
            "model.layers.{bid}.feed_forward.w2",                     # internlm2
            "encoder.layers.{bid}.mlp.fc2",                           # nomic-bert
            "model.layers.{bid}.mlp.c_proj",                          # starcoder2
            "encoder.layer.{bid}.mlp.wo",                             # jina-bert-v2
            "transformer.layers.{bid}.ffn.proj_2",                    # openelm
            "model.layers.{bid}.residual_mlp.w2",                     # arctic
            "encoder.layer.{bid}.mlp.down_layer",                     # jina-bert-v2
            "encoder.layers.{bid}.mlp.dense_4h_to_h",                 # chatglm
        ),

        MODEL_TENSOR.FFN_DOWN_EXP: (
            "layers.{bid}.feed_forward.experts.w2",          # mixtral (merged)
            "transformer.decoder_layer.{bid}.moe.linear_1",  # Grok (merged)
            "transformer.blocks.{bid}.ffn.experts.mlp.w2",   # dbrx
            "model.layers.{bid}.mlp.experts.down_proj",      # qwen2moe (merged)
        ),

        MODEL_TENSOR.FFN_DOWN_SHEXP: (
            "model.layers.{bid}.mlp.shared_expert.down_proj",  # qwen2moe
            "model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek2
        ),

        MODEL_TENSOR.ATTN_Q_NORM: (
            "language_model.encoder.layers.{bid}.self_attention.q_layernorm",
            "model.layers.{bid}.self_attn.q_layernorm",                       # persimmon
            "model.layers.{bid}.self_attn.q_norm",                            # cohere
            "transformer.blocks.{bid}.attn.q_ln",                             # sea-lion
            "encoder.layer.{bid}.attention.self.layer_norm_q",                # jina-bert-v2
            "transformer.layers.{bid}.attn.q_norm",                           # openelm
        ),

        MODEL_TENSOR.ATTN_K_NORM: (
            "language_model.encoder.layers.{bid}.self_attention.k_layernorm",
            "model.layers.{bid}.self_attn.k_layernorm",                       # persimmon
            "model.layers.{bid}.self_attn.k_norm",                            # cohere
            "transformer.blocks.{bid}.attn.k_ln",                             # sea-lion
            "encoder.layer.{bid}.attention.self.layer_norm_k",                # jina-bert-v2
            "transformer.layers.{bid}.attn.k_norm",                           # openelm
        ),

        MODEL_TENSOR.ROPE_FREQS: (
            "language_model.encoder.layers.{bid}.self_attention.rotary_emb.inv_freq",  # persimmon
        ),

        MODEL_TENSOR.LAYER_OUT_NORM: (
            "encoder.layer.{bid}.output.LayerNorm",         # bert
            "encoder.layers.{bid}.norm2",                   # nomic-bert
            "transformer.decoder_layer.{bid}.rms_norm_3",   # Grok
            "encoder.layer.{bid}.mlp.layernorm",            # jina-bert-v2
            "encoder.layer.{bid}.layer_norm_2"              # jina-v2-code
        ),

        MODEL_TENSOR.SSM_IN: (
            "model.layers.{bid}.in_proj",
            "backbone.layers.{bid}.mixer.in_proj",
        ),

        MODEL_TENSOR.SSM_CONV1D: (
            "model.layers.{bid}.conv1d",
            "backbone.layers.{bid}.mixer.conv1d",
        ),

        MODEL_TENSOR.SSM_X: (
            "model.layers.{bid}.x_proj",
            "backbone.layers.{bid}.mixer.x_proj",
        ),

        MODEL_TENSOR.SSM_DT: (
            "model.layers.{bid}.dt_proj",
            "backbone.layers.{bid}.mixer.dt_proj",
        ),

        MODEL_TENSOR.SSM_A: (
            "model.layers.{bid}.A_log",
            "backbone.layers.{bid}.mixer.A_log",
        ),

        MODEL_TENSOR.SSM_D: (
            "model.layers.{bid}.D",
            "backbone.layers.{bid}.mixer.D",
        ),

        MODEL_TENSOR.SSM_OUT: (
            "model.layers.{bid}.out_proj",
            "backbone.layers.{bid}.mixer.out_proj",
        ),

        MODEL_TENSOR.ATTN_Q_A: (
            "model.layers.{bid}.self_attn.q_a_proj", # deepseek2
        ),

        MODEL_TENSOR.ATTN_Q_B: (
            "model.layers.{bid}.self_attn.q_b_proj", # deepseek2
        ),

        MODEL_TENSOR.ATTN_KV_A_MQA: (
            "model.layers.{bid}.self_attn.kv_a_proj_with_mqa", # deepseek2
        ),

        MODEL_TENSOR.ATTN_KV_B: (
            "model.layers.{bid}.self_attn.kv_b_proj", # deepseek2
        ),

        MODEL_TENSOR.ATTN_Q_A_NORM: (
            "model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2
        ),

        MODEL_TENSOR.ATTN_KV_A_NORM: (
            "model.layers.{bid}.self_attn.kv_a_layernorm", # deepseek2
        ),

        MODEL_TENSOR.ATTN_SUB_NORM: (
            "model.layers.{bid}.self_attn.inner_attn_ln",  # bitnet
        ),

        MODEL_TENSOR.FFN_SUB_NORM: (
            "model.layers.{bid}.mlp.ffn_layernorm",  # bitnet
        ),

        MODEL_TENSOR.DEC_ATTN_NORM: (
            "decoder.block.{bid}.layer.0.layer_norm", # t5
        ),

        MODEL_TENSOR.DEC_ATTN_Q: (
            "decoder.block.{bid}.layer.0.SelfAttention.q", # t5
        ),

        MODEL_TENSOR.DEC_ATTN_K: (
            "decoder.block.{bid}.layer.0.SelfAttention.k", # t5
        ),

        MODEL_TENSOR.DEC_ATTN_V: (
            "decoder.block.{bid}.layer.0.SelfAttention.v", # t5
        ),

        MODEL_TENSOR.DEC_ATTN_OUT: (
            "decoder.block.{bid}.layer.0.SelfAttention.o", # t5
        ),

        MODEL_TENSOR.DEC_ATTN_REL_B: (
            "decoder.block.{bid}.layer.0.SelfAttention.relative_attention_bias", # t5
        ),

        MODEL_TENSOR.DEC_CROSS_ATTN_NORM: (
            "decoder.block.{bid}.layer.1.layer_norm", # t5
        ),

        MODEL_TENSOR.DEC_CROSS_ATTN_Q: (
            "decoder.block.{bid}.layer.1.EncDecAttention.q", # t5
        ),

        MODEL_TENSOR.DEC_CROSS_ATTN_K: (
            "decoder.block.{bid}.layer.1.EncDecAttention.k", # t5
        ),

        MODEL_TENSOR.DEC_CROSS_ATTN_V: (
            "decoder.block.{bid}.layer.1.EncDecAttention.v", # t5
        ),

        MODEL_TENSOR.DEC_CROSS_ATTN_OUT: (
            "decoder.block.{bid}.layer.1.EncDecAttention.o", # t5
        ),

        MODEL_TENSOR.DEC_CROSS_ATTN_REL_B: (
            "decoder.block.{bid}.layer.1.EncDecAttention.relative_attention_bias", # t5
        ),

        MODEL_TENSOR.DEC_FFN_NORM: (
            "decoder.block.{bid}.layer.2.layer_norm", # t5
        ),

        MODEL_TENSOR.DEC_FFN_GATE: (
            "decoder.block.{bid}.layer.2.DenseReluDense.wi_0", # flan-t5
        ),

        MODEL_TENSOR.DEC_FFN_UP: (
            "decoder.block.{bid}.layer.2.DenseReluDense.wi",   # t5
            "decoder.block.{bid}.layer.2.DenseReluDense.wi_1", # flan-t5
        ),

        MODEL_TENSOR.DEC_FFN_DOWN: (
            "decoder.block.{bid}.layer.2.DenseReluDense.wo", # t5
        ),

        MODEL_TENSOR.DEC_OUTPUT_NORM: (
            "decoder.final_layer_norm", # t5
        ),

        MODEL_TENSOR.ENC_ATTN_NORM: (
            "encoder.block.{bid}.layer.0.layer_norm", # t5
        ),

        MODEL_TENSOR.ENC_ATTN_Q: (
            "encoder.block.{bid}.layer.0.SelfAttention.q", # t5
        ),

        MODEL_TENSOR.ENC_ATTN_K: (
            "encoder.block.{bid}.layer.0.SelfAttention.k", # t5
        ),

        MODEL_TENSOR.ENC_ATTN_V: (
            "encoder.block.{bid}.layer.0.SelfAttention.v", # t5
        ),

        MODEL_TENSOR.ENC_ATTN_OUT: (
            "encoder.block.{bid}.layer.0.SelfAttention.o", # t5
        ),

        MODEL_TENSOR.ENC_ATTN_REL_B: (
            "encoder.block.{bid}.layer.0.SelfAttention.relative_attention_bias", # t5
        ),

        MODEL_TENSOR.ENC_FFN_NORM: (
            "encoder.block.{bid}.layer.1.layer_norm", # t5
        ),

        MODEL_TENSOR.ENC_FFN_GATE: (
            "encoder.block.{bid}.layer.1.DenseReluDense.wi_0", # flan-t5
        ),

        MODEL_TENSOR.ENC_FFN_UP: (
            "encoder.block.{bid}.layer.1.DenseReluDense.wi",   # t5
            "encoder.block.{bid}.layer.1.DenseReluDense.wi_1", # flan-t5
        ),

        MODEL_TENSOR.ENC_FFN_DOWN: (
            "encoder.block.{bid}.layer.1.DenseReluDense.wo", # t5
        ),

        MODEL_TENSOR.ENC_OUTPUT_NORM: (
            "encoder.final_layer_norm", # t5
        ),
    }

    # architecture-specific block mappings, merged into block_mappings_cfg in __init__
    arch_block_mappings_cfg: dict[MODEL_ARCH, dict[MODEL_TENSOR, tuple[str, ...]]] = {
        MODEL_ARCH.ARCTIC: {
            MODEL_TENSOR.FFN_NORM: (
                "model.layers.{bid}.residual_layernorm",
            ),
            MODEL_TENSOR.FFN_NORM_EXP: (
                "model.layers.{bid}.post_attention_layernorm",
            ),
        },
    }

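    # populated in __init__: maps every accepted input name (canonical GGUF
    # name or checkpoint alias) to its (tensor type, concrete GGUF name) pair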
    mapping: dict[str, tuple[MODEL_TENSOR, str]]

    def __init__(self, arch: MODEL_ARCH, n_blocks: int):
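        # Build one flat lookup table for this architecture: the canonical
        # GGUF name and every known checkpoint alias all map to the same
        # (MODEL_TENSOR, gguf_name) pair, restricted to MODEL_TENSORS[arch].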
        self.mapping = {}
        for tensor, keys in self.mappings_cfg.items():
            if tensor not in MODEL_TENSORS[arch]:
                continue
            tensor_name = TENSOR_NAMES[tensor]
            self.mapping[tensor_name] = (tensor, tensor_name)
            for key in keys:
                self.mapping[key] = (tensor, tensor_name)
        if arch in self.arch_block_mappings_cfg:
            # note: this updates the class-level dict shared by all instances
            self.block_mappings_cfg.update(self.arch_block_mappings_cfg[arch])
        for bid in range(n_blocks):
            for tensor, keys in self.block_mappings_cfg.items():
                if tensor not in MODEL_TENSORS[arch]:
                    continue

                tensor_name = TENSOR_NAMES[tensor].format(bid = bid)
                self.mapping[tensor_name] = (tensor, tensor_name)
                for key in keys:
                    key = key.format(bid = bid)
                    self.mapping[key] = (tensor, tensor_name)

    def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> tuple[MODEL_TENSOR, str] | None:
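        """Resolve a checkpoint tensor name to (MODEL_TENSOR, gguf_name).

        If there is no exact match, each suffix in `try_suffixes` (typically
        ".weight" / ".bias") is stripped before retrying the lookup, and is
        re-appended to the returned GGUF name on success. Returns None when
        nothing matches.
        """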
        result = self.mapping.get(key)
        if result is not None:
            return result
        for suffix in try_suffixes:
            if key.endswith(suffix):
                result = self.mapping.get(key[:-len(suffix)])
                if result is not None:
                    return result[0], result[1] + suffix
        return None

    def get_name(self, key: str, try_suffixes: Sequence[str] = ()) -> str | None:
        result = self.get_type_and_name(key, try_suffixes = try_suffixes)
        if result is None:
            return None
        return result[1]

    def get_type(self, key: str, try_suffixes: Sequence[str] = ()) -> MODEL_TENSOR | None:
        result = self.get_type_and_name(key, try_suffixes = try_suffixes)
        if result is None:
            return None
        return result[0]

    def __getitem__(self, key: str) -> str:
        try:
            return self.mapping[key][1]
        except KeyError:
            raise KeyError(key)

    def __contains__(self, key: str) -> bool:
        return key in self.mapping

    def __repr__(self) -> str:
        return repr(self.mapping)


def get_tensor_name_map(arch: MODEL_ARCH, n_blocks: int) -> TensorNameMap:
    return TensorNameMap(arch, n_blocks)
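

# Minimal usage sketch (illustrative, not part of the module's API): run as
# `python -m gguf.tensor_mapping` so the relative imports resolve. It assumes
# the usual constants.py definitions, e.g. that MODEL_ARCH.LLAMA maps
# "model.layers.{bid}.self_attn.q_proj" to "blk.{bid}.attn_q".
if __name__ == "__main__":
    name_map = get_tensor_name_map(MODEL_ARCH.LLAMA, n_blocks=32)
    # Suffixes are stripped for the lookup and re-appended to the result, so
    # full checkpoint names resolve too; expected output: blk.0.attn_q.weight
    print(name_map.get_name("model.layers.0.self_attn.q_proj.weight",
                            try_suffixes=(".weight", ".bias")))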