# Copyright (c) 2022, Tri Dao.

import math
from functools import partial

from collections import namedtuple
from collections.abc import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.models.gpt2.configuration_gpt2 import GPT2Config

from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense
from flash_attn.modules.block import Block
from flash_attn.modules.embedding import GPT2Embeddings

try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm
except ImportError:
    dropout_add_layer_norm = None

try:
    from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
except ImportError:
    FusedDenseSqreluDense = None


def create_mixer_cls(config, layer_idx=None):
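    """Return a partially-applied MHA constructor configured from config.

    The softmax scale honors scale_attn_weights and scale_attn_by_inverse_layer_idx; optional
    features (depthwise conv, rotary embeddings, fused bias FC, FlashAttention) are read from
    extra config attributes via getattr, defaulting to off.
    """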
    head_dim = getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads)
    softmax_scale = 1.0 if not config.scale_attn_weights else head_dim ** (-0.5)
    if config.scale_attn_by_inverse_layer_idx:
        assert layer_idx is not None
        softmax_scale /= float(layer_idx + 1)
    dwconv = getattr(config, 'attn_dwconv', False)
    rotary_emb_dim = int(getattr(config, 'rotary_emb_fraction', 0.0) * head_dim)
    use_flash_attn = getattr(config, 'use_flash_attn', False)
    fused_bias_fc = getattr(config, 'fused_bias_fc', False)
    mixer_cls = partial(MHA, num_heads=config.num_attention_heads, dropout=config.attn_pdrop,
                        softmax_scale=softmax_scale, causal=True, dwconv=dwconv,
                        rotary_emb_dim=rotary_emb_dim,
                        fused_bias_fc=fused_bias_fc, use_flash_attn=use_flash_attn)
    return mixer_cls


def create_mlp_cls(config, layer_idx=None):
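    """Return a partially-applied MLP constructor.

    The inner dimension is n_inner, or 4 * hidden_size if unset. Defaults to Mlp with
    tanh-approximated GELU; fused_dense_gelu_dense / fused_dense_sqrelu_dense in the config select
    the fused variants instead (the two options are mutually exclusive).
    """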
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    fused_dense_gelu_dense = getattr(config, 'fused_dense_gelu_dense', False)
    fused_dense_sqrelu_dense = getattr(config, 'fused_dense_sqrelu_dense', False)
    assert not (fused_dense_sqrelu_dense and fused_dense_gelu_dense)
    if not fused_dense_gelu_dense and not fused_dense_sqrelu_dense:
        mlp_cls = partial(Mlp, hidden_features=inner_dim,
                          activation=partial(F.gelu, approximate='tanh'))
    else:
        mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
        # mlp_checkpoint_lvl can be a Sequence giving the checkpoint_lvl for each layer
        if isinstance(mlp_checkpoint_lvl, Sequence):
            assert layer_idx is not None
            mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
        if fused_dense_gelu_dense:
            mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim,
                              checkpoint_lvl=mlp_checkpoint_lvl)
        elif fused_dense_sqrelu_dense:
            assert FusedDenseSqreluDense is not None
            mlp_cls = partial(FusedDenseSqreluDense, hidden_features=inner_dim,
                              checkpoint_lvl=mlp_checkpoint_lvl)
        else:
            raise RuntimeError('MLP type not supported')
    return mlp_cls


def create_block(config, layer_idx=None):
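    """Build one pre-norm Transformer Block, wiring together the attention mixer, the MLP, and
    LayerNorm for the given layer index."""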
    mixer_cls = create_mixer_cls(config, layer_idx)
    mlp_cls = create_mlp_cls(config, layer_idx)
    norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_epsilon)
    block = Block(config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
                  prenorm=True, resid_dropout=config.resid_pdrop,
                  fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False))
    block.layer_idx = layer_idx
    return block


# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True):
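    """GPT-2-style initialization: normal(0, initializer_range) for Linear and Embedding weights,
    zero biases, and (optionally) residual-path output projections (out_proj.weight, fc2.weight)
    reinitialized with std scaled by 1/sqrt(2 * n_layer)."""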
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, std=initializer_range)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if rescale_prenorm_residual:
        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer))


class GPT2Model(nn.Module):
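    """GPT-2 backbone: GPT2Embeddings, embedding dropout, and a stack of pre-norm Blocks using the
    dropout -> add -> LayerNorm reordering described in __init__ below."""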

    def __init__(self, config: GPT2Config):
        super().__init__()
        self.pad_vocab_size_multiple_8 = getattr(config, 'pad_vocab_size_multiple_8', False)
        if self.pad_vocab_size_multiple_8:
            if config.vocab_size % 8 != 0:
                config.vocab_size += 8 - (config.vocab_size % 8)

        self.embeddings = GPT2Embeddings(config.hidden_size, config.vocab_size,
                                         config.max_position_embeddings)
        self.emb_drop = nn.Dropout(config.embd_pdrop)

        # We change the order of residual and layer norm:
        # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
        # Attn / MLP -> Dropout -> Add -> LN, returning both the residual branch (output of Add) and
        # the main branch (output of LN). The model definition is unchanged, but the mapping of the
        # nn.LayerNorm weights is changed.
        # This is done for performance reasons: it lets us fuse dropout + add + layer_norm.
        self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
        if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
            raise ImportError('dropout_add_layer_norm is not installed')
        # self.ln_0 is the first layer norm in the model, while self.ln_f (in the pretrained weights)
        # is the final layer norm.
        self.ln_0 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        self.layers = nn.ModuleList([create_block(config, layer_idx=i)
                                     for i in range(config.num_hidden_layers)])

        self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
                           initializer_range=config.initializer_range))

    def forward(self, input_ids, position_ids=None):
        hidden_states = self.embeddings(input_ids, position_ids=position_ids)
        # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
        if not self.fused_dropout_add_ln:
            residual = self.emb_drop(hidden_states).float()
            hidden_states = self.ln_0(residual.to(dtype=self.ln_0.weight.dtype))
        else:
            hidden_states, residual = dropout_add_layer_norm(
                hidden_states, None, self.ln_0.weight, self.ln_0.bias,
                self.emb_drop.p if self.training else 0.0, self.ln_0.eps, prenorm=True,
                residual_in_fp32=True
            )
        for layer in self.layers:
            hidden_states, residual = layer(hidden_states, residual)
        return hidden_states


class GPT2LMHeadModel(nn.Module):
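    """GPT2Model backbone plus a language-modeling head whose weight is tied to the word
    embeddings."""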

    def __init__(self, config: GPT2Config):
        super().__init__()
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
                           initializer_range=config.initializer_range))
        self.tie_weights()

    def tie_weights(self):
        self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight

    def forward(self, input_ids, position_ids=None):
        hidden_states = self.transformer(input_ids, position_ids=position_ids)
        lm_logits = self.lm_head(hidden_states)
        CausalLMOutput = namedtuple('CausalLMOutput', ['logits'])
        return CausalLMOutput(logits=lm_logits)
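

# A minimal usage sketch added for illustration (not part of the upstream file). With the default
# GPT2Config, the fused / FlashAttention code paths above are disabled, so this only needs the
# plain flash_attn modules imported at the top; depending on how flash_attn is built, a CUDA
# device may still be required.
if __name__ == '__main__':
    config = GPT2Config(n_layer=2, n_head=4, n_embd=128, n_positions=64, vocab_size=1024)
    model = GPT2LMHeadModel(config)
    model.eval()  # disable dropout for a deterministic shape check
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    logits = model(input_ids).logits  # (batch, seqlen, vocab_size)
    print(logits.shape)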