"langchain_demo/main.py" did not exist on "d057250751af519923b789b75fd2c0c8ba93eaf9"
gpt_neox.py 5.08 KB
Newer Older
Tri Dao's avatar
Tri Dao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright (c) 2023, Tri Dao.

import math
import re
from collections import OrderedDict

import torch
import torch.nn.functional as F
from einops import rearrange
from transformers import GPT2Config, GPTNeoXConfig


def remap_state_dict_hf_gpt_neox(state_dict, config):
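    """Remap a HuggingFace GPT-NeoX state dict to the layout expected by this repo's GPT model:
    rename keys (embed_in -> embeddings.word_embeddings, dense_h_to_4h -> mlp.fc1,
    attention.dense -> mixer.out_proj, etc.), pad the embedding matrices up to
    pad_vocab_size_multiple, and reorder the packed QKV weight/bias from
    (nheads, 3, headdim) to (3, nheads, headdim) along the output dimension.
    """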
    def key_mapping_layers(key):
        return re.sub(r"^gpt_neox.", "transformer.", key)

    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
    # Word embedding
    def key_mapping_emb(key):
        return re.sub(r"^transformer.embed_in.", "transformer.embeddings.word_embeddings.", key)

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    if getattr(config, "tie_word_embeddings"):
        state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
    else:
        output_embeddings = state_dict.pop("embed_out.weight")
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict["lm_head.weight"] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.final_layer_norm.", r"transformer.ln_f.", key)
        key = re.sub(
            r"^transformer.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key
        )
        key = re.sub(
            r"^transformer.layers.(\d+).post_attention_layernorm.",
            r"transformer.layers.\1.norm2.",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.dense_h_to_4h.", r"transformer.layers.\1.mlp.fc1.", key
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.dense_4h_to_h.", r"transformer.layers.\1.mlp.fc2.", key
        )
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for l in range(config.n_layer):
        # We don't store these biases
        state_dict.pop(f"transformer.layers.{l}.attention.bias")
        state_dict.pop(f"transformer.layers.{l}.attention.masked_bias")
        # GPT-NeoX stores Wqkv as ((nheads 3 headdim), hidden_dim)
        # while we store Wqkv as ((3 nheads headdim), hidden_dim)
        headdim = config.hidden_size // config.num_attention_heads
        Wqkv = state_dict.pop(f"transformer.layers.{l}.attention.query_key_value.weight")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = rearrange(
            Wqkv,
            "(nheads three headdim) ... -> (three nheads headdim) ...",
            three=3,
            headdim=headdim,
        )
        bqkv = state_dict.pop(f"transformer.layers.{l}.attention.query_key_value.bias")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.bias"] = rearrange(
            bqkv, "(nheads three headdim) -> (three nheads headdim)", three=3, headdim=headdim
        )

    def key_mapping_attn(key):
        key = re.sub(
            r"^transformer.layers.(\d+).attention.dense.",
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).attention.rotary_emb.",
            r"transformer.layers.\1.mixer.rotary_emb.",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict


def gpt_neox_config_to_gpt2_config(gpt_neox_config: GPTNeoXConfig) -> GPT2Config:
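    """Translate a GPTNeoXConfig into the GPT2Config (extended with the extra fields below,
    e.g. prenorm, parallel_block, rotary_emb_fraction) used by this repo's GPT model.
    """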
    assert gpt_neox_config.rotary_emb_base == 10000
    return GPT2Config(
        vocab_size=gpt_neox_config.vocab_size,
        n_positions=0,  # No absolute position embedding
        n_embd=gpt_neox_config.hidden_size,
        n_layer=gpt_neox_config.num_hidden_layers,
        n_head=gpt_neox_config.num_attention_heads,
        n_inner=gpt_neox_config.intermediate_size,
        activation_function=gpt_neox_config.hidden_act,
        resid_pdrop=0.0,  # No dropout
        embd_pdrop=0.0,
        attn_pdrop=0.0,
        layer_norm_epsilon=gpt_neox_config.layer_norm_eps,
        initializer_range=gpt_neox_config.initializer_range,
        bos_token_id=gpt_neox_config.bos_token_id,
        eos_token_id=gpt_neox_config.eos_token_id,
        # These are new arguments not in the original GPT2Config
        prenorm=True,
        parallel_block=gpt_neox_config.use_parallel_residual,
        parallel_block_tied_norm=False,
        rotary_emb_fraction=gpt_neox_config.rotary_pct,
        tie_word_embeddings=gpt_neox_config.tie_word_embeddings,
    )
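

# Minimal usage sketch (not part of the original module): remap a GPT-NeoX-style checkpoint
# and inspect the resulting keys. "EleutherAI/pythia-160m" is only an example checkpoint,
# and the sketch assumes a transformers version matching this code, i.e. one whose GPT-NeoX
# attention still registers the `bias` / `masked_bias` buffers that get popped above.
# Loading the remapped weights into the GPT model of this repo is left out to keep the
# sketch self-contained.
if __name__ == "__main__":
    from transformers import GPTNeoXForCausalLM

    model_name = "EleutherAI/pythia-160m"  # example; any HF GPT-NeoX checkpoint should work
    hf_config = GPTNeoXConfig.from_pretrained(model_name)
    config = gpt_neox_config_to_gpt2_config(hf_config)
    hf_model = GPTNeoXForCausalLM.from_pretrained(model_name)
    remapped = remap_state_dict_hf_gpt_neox(hf_model.state_dict(), config)
    # Keys now follow the transformer.layers.{i}.mixer / .mlp naming, with Wqkv packed
    # as (3 * nheads * headdim, hidden_size).
    for key in list(remapped)[:8]:
        print(key, tuple(remapped[key].shape))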