# Copyright (c) 2023, Tri Dao.

import math
import re
from collections import OrderedDict

import torch
import torch.nn.functional as F
from transformers import GPT2Config, GPTJConfig


def remap_state_dict_hf_gptj(state_dict, config):
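    """Remap the keys of a Hugging Face GPT-J state dict to the names used by
    the flash-attn GPT model: transformer.h.* -> transformer.layers.*, separate
    q/k/v projections fused into a single Wqkv, and word embeddings padded up to
    config.vocab_size (possibly rounded by pad_vocab_size_multiple).
    """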
    def key_mapping_layers(key):
        return re.sub(r"^transformer.h.", "transformer.layers.", key)

    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
    # Word embedding
    def key_mapping_emb(key):
        return re.sub(r"^transformer.wte.", "transformer.embeddings.word_embeddings.", key)

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    if getattr(config, "tie_word_embeddings"):
        state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
    else:
        output_embeddings = state_dict.pop("lm_head.weight")
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict["lm_head.weight"] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )
        output_embeddings_bias = state_dict.pop("lm_head.bias")
        state_dict["lm_head.bias"] = F.pad(
            output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0])
        )

    # LayerNorm
    def key_mapping_ln(key):
        return re.sub(r"^transformer.layers.(\d+).ln_1.", r"transformer.layers.\1.norm1.", key)

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.fc_in.", r"transformer.layers.\1.mlp.fc1.", key
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.fc_out.", r"transformer.layers.\1.mlp.fc2.", key
        )
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for l in range(config.n_layer):
        Wq = state_dict.pop(f"transformer.layers.{l}.attn.q_proj.weight")
        Wk = state_dict.pop(f"transformer.layers.{l}.attn.k_proj.weight")
        Wv = state_dict.pop(f"transformer.layers.{l}.attn.v_proj.weight")
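        # Fuse the separate q/k/v projection weights into the single Wqkv weight expected by the attention "mixer"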
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
        # These are the causal-mask buffers (attn.bias / attn.masked_bias), not learned parameters, so we don't store them
        state_dict.pop(f"transformer.layers.{l}.attn.bias")
        state_dict.pop(f"transformer.layers.{l}.attn.masked_bias")

    def key_mapping_attn(key):
        return re.sub(
            r"^transformer.layers.(\d+).attn.out_proj.",
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict


def gptj_config_to_gpt2_config(gptj_config: GPTJConfig) -> GPT2Config:
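    """Translate a GPTJConfig into the (extended) GPT2Config consumed by the
    flash-attn GPT model, encoding the GPT-J specifics: no absolute position
    embeddings, a parallel attention/MLP block sharing one LayerNorm, partial
    interleaved rotary embeddings, and an untied output embedding with bias.
    """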
    headdim = gptj_config.n_embd // gptj_config.n_head
    return GPT2Config(
        vocab_size=gptj_config.vocab_size,
        n_positions=0,  # No absolute position embedding
        n_embd=gptj_config.n_embd,
        n_layer=gptj_config.n_layer,
        n_head=gptj_config.n_head,
        n_inner=gptj_config.n_inner,
        activation_function=gptj_config.activation_function,
        resid_pdrop=gptj_config.resid_pdrop,
        embd_pdrop=gptj_config.embd_pdrop,
        attn_pdrop=gptj_config.attn_pdrop,
        layer_norm_epsilon=gptj_config.layer_norm_epsilon,
        initializer_range=gptj_config.initializer_range,
        bos_token_id=gptj_config.bos_token_id,
        eos_token_id=gptj_config.eos_token_id,
        # These are new arguments not in the original GPT2Config
        prenorm=True,
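        # GPT-J computes attention and MLP in parallel from the output of a single LayerNorm,
        # hence the parallel block with a tied (shared) norm.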
        parallel_block=True,
        parallel_block_tied_norm=True,
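        # GPT-J applies rotary embeddings only to the first rotary_dim dims of each head, with interleaved (even/odd) layout.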
        rotary_emb_fraction=gptj_config.rotary_dim / headdim,
        rotary_emb_interleaved=True,
        tie_word_embeddings=False,
        qkv_proj_bias=False,
        out_proj_bias=False,
        lm_head_bias=True,
    )
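

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file): one way these two
# helpers can be combined to load a Hugging Face GPT-J checkpoint into the
# flash-attn GPT model. The `GPTLMHeadModel` import path and the checkpoint
# name are assumptions for this example.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import GPTJForCausalLM

    from flash_attn.models.gpt import GPTLMHeadModel  # assumed import path

    model_name = "EleutherAI/gpt-j-6B"
    # Build the flash-attn config from the HF config first, since the remap
    # below reads pad_vocab_size_multiple / tie_word_embeddings from it.
    config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
    # Load the HF weights and remap their keys to the flash-attn naming.
    hf_state_dict = GPTJForCausalLM.from_pretrained(model_name).state_dict()
    state_dict = remap_state_dict_hf_gptj(hf_state_dict, config)
    # Instantiate the flash-attn model and load the remapped weights.
    model = GPTLMHeadModel(config)
    model.load_state_dict(state_dict)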