# Copyright (c) 2023, Tri Dao.
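"""GPT model built on FlashAttention components: MHA / ParallelMHA attention, (fused) MLP
variants, and optionally fused dropout + residual-add + LayerNorm/RMSNorm, with tensor-parallel
support and helpers to remap HuggingFace (GPT-2, OPT, GPT-J, GPT-NeoX, Falcon) and Megatron-LM
checkpoints to this layout."""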

import logging
import math
import re
from functools import partial
from collections import namedtuple, OrderedDict
from collections.abc import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import GPT2Config

from einops import rearrange

from flash_attn.ops.activations import sqrelu_fwd
from flash_attn.modules.mha import MHA, ParallelMHA
from flash_attn.modules.mlp import Mlp, ParallelMLP, FusedMLP, ParallelFusedMLP
from flash_attn.modules.mlp import GatedMlp, ParallelGatedMlp
from flash_attn.modules.block import Block, ParallelBlock
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
from flash_attn.utils.distributed import sync_shared_params, all_gather_raw
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.generation import GenerationMixin
from flash_attn.models.opt import remap_state_dict_hf_opt
from flash_attn.models.gptj import remap_state_dict_hf_gptj
from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox
from flash_attn.models.falcon import remap_state_dict_hf_falcon

try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear
except ImportError:
    ColumnParallelLinear = None

try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm
except ImportError:
    dropout_add_layer_norm = None

try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
except ImportError:
    dropout_add_layer_norm_parallel_residual = None

try:
    from flash_attn.ops.rms_norm import RMSNorm, dropout_add_rms_norm
except ImportError:
    RMSNorm, dropout_add_rms_norm = None, None

try:
    from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual
except ImportError:
    dropout_add_rms_norm_parallel_residual = None

try:
    from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
except ImportError:
    FusedDenseSqreluDense = None


logger = logging.getLogger(__name__)


def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
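    """Return the attention ("mixer") class, partially applied with the options read from
    `config`: MHA in the single-device case, ParallelMHA when a tensor-parallel `process_group`
    is given. The caller (Block) is expected to finish construction with just the hidden size.
    """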
    factory_kwargs = {'device': device, 'dtype': dtype}
    head_dim = getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads)
    softmax_scale = 1.0 if not config.scale_attn_weights else head_dim ** (-0.5)
    if config.scale_attn_by_inverse_layer_idx:
        assert layer_idx is not None
        softmax_scale /= float(layer_idx + 1)
    dwconv = getattr(config, 'attn_dwconv', False)
    if dwconv:
        assert process_group is None, 'TensorParallel MHA does not support dwconv yet'
    qkv_proj_bias = getattr(config, 'qkv_proj_bias', True)
    out_proj_bias = getattr(config, 'out_proj_bias', True)
    rotary_emb_dim = int(getattr(config, 'rotary_emb_fraction', 0.0) * head_dim)
    rotary_emb_base = getattr(config, 'rotary_emb_base', 10000.0)
    rotary_emb_scale_base = getattr(config, 'rotary_emb_scale_base', None)
    rotary_emb_interleaved = getattr(config, 'rotary_emb_interleaved', False)
    use_flash_attn = getattr(config, 'use_flash_attn', False)
    fused_bias_fc = getattr(config, 'fused_bias_fc', False)
    if not fused_bias_fc:
        assert process_group is None, 'TensorParallel MHA requires fused_bias_fc'
    mha_cls = MHA if process_group is None else ParallelMHA
    serial_kwargs = ({'fused_bias_fc': fused_bias_fc, 'dwconv': dwconv}
                     if process_group is None else {})
    parallel_kwargs = ({'process_group': process_group,
                        'sequence_parallel': getattr(config, 'sequence_parallel', True)}
                       if process_group is not None else {})
    num_heads_kv = getattr(config, "n_head_kv", None)
    mixer_cls = partial(mha_cls, num_heads=config.num_attention_heads,
                        num_heads_kv=num_heads_kv,
                        qkv_proj_bias=qkv_proj_bias, out_proj_bias=out_proj_bias,
                        dropout=config.attn_pdrop,
                        softmax_scale=softmax_scale, causal=True, layer_idx=layer_idx,
                        rotary_emb_dim=rotary_emb_dim, rotary_emb_base=rotary_emb_base,
                        rotary_emb_scale_base=rotary_emb_scale_base,
                        rotary_emb_interleaved=rotary_emb_interleaved,
                        use_flash_attn=use_flash_attn,
                        **serial_kwargs, **parallel_kwargs, **factory_kwargs)
    return mixer_cls


def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
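    """Return the MLP class, partially applied from `config`: GatedMlp for glu/swiglu/geglu
    activations, FusedMLP or FusedDenseSqreluDense when the corresponding fused options are set,
    and plain Mlp otherwise (with the Parallel* variants under tensor parallelism).
    """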
    factory_kwargs = {'device': device, 'dtype': dtype}
    mlp_fc1_bias = getattr(config, 'mlp_fc1_bias', True)
    mlp_fc2_bias = getattr(config, 'mlp_fc2_bias', True)
    fused_mlp = getattr(config, 'fused_mlp', False)
    if fused_mlp:
        assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu', 'sqrelu']
    fused_dense_sqrelu_dense = getattr(config, 'fused_dense_sqrelu_dense', False)
    if fused_dense_sqrelu_dense:
        assert config.activation_function == 'sqrelu', ('fused_dense_sqrelu_dense only '
                                               'supports approximate activation_function sqrelu')
    assert not (fused_dense_sqrelu_dense and fused_mlp)
    if not fused_mlp and not fused_dense_sqrelu_dense:
        assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'gelu_approx', 'relu',
                                              'sqrelu', 'glu', 'swiglu', 'geglu']
        if config.activation_function in ['glu', 'swiglu', 'geglu']:
            activation = (F.sigmoid if config.activation_function == 'glu'
                          else (F.silu if config.activation_function == 'swiglu'
                                else F.gelu))
            mlp_cls = GatedMlp if process_group is None else ParallelGatedMlp
            parallel_kwargs = ({'process_group': process_group,
                                'sequence_parallel': getattr(config, 'sequence_parallel', True)}
                               if process_group is not None else {})
            mlp_cls = partial(mlp_cls, hidden_features=config.n_inner, activation=activation,
                              bias1=mlp_fc1_bias, bias2=mlp_fc2_bias,
                              **parallel_kwargs, **factory_kwargs)
        else:
            if config.activation_function == 'relu':
                activation = partial(F.relu, inplace=True)
            elif config.activation_function == 'sqrelu':
                activation = sqrelu_fwd
            else:
                approximate = ('tanh' if config.activation_function
                            in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'none')
                activation = partial(F.gelu, approximate=approximate)
            mlp_cls = Mlp if process_group is None else ParallelMLP
            parallel_kwargs = ({'process_group': process_group,
                                'sequence_parallel': getattr(config, 'sequence_parallel', True)}
                               if process_group is not None else {})
            mlp_cls = partial(mlp_cls, hidden_features=config.n_inner, activation=activation,
                              bias1=mlp_fc1_bias, bias2=mlp_fc2_bias,
                              **parallel_kwargs, **factory_kwargs)
    else:
        mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
        # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
        if isinstance(mlp_checkpoint_lvl, Sequence):
            assert layer_idx is not None
            mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
        if fused_mlp:
            if FusedMLP is None:
                raise ImportError('fused_dense is not installed')
            activation = ('gelu_approx' if config.activation_function
                          in ['gelu_new', 'gelu_fast', 'gelu_approx'] else config.activation_function)
            mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP
            parallel_kwargs = ({'process_group': process_group,
                                'sequence_parallel': getattr(config, 'sequence_parallel', True)}
                               if process_group is not None else {})
            mlp_cls = partial(mlp_cls, hidden_features=config.n_inner, activation=activation,
                              checkpoint_lvl=mlp_checkpoint_lvl,
                              bias1=mlp_fc1_bias, bias2=mlp_fc2_bias,
                              **parallel_kwargs, **factory_kwargs)
        elif fused_dense_sqrelu_dense:
            if process_group is not None:
                assert fused_mlp, 'Tensor Parallel is not implemented for FusedDenseSqreluDense'
            assert FusedDenseSqreluDense is not None
            mlp_cls = partial(FusedDenseSqreluDense, hidden_features=config.n_inner,
                              checkpoint_lvl=mlp_checkpoint_lvl, **factory_kwargs)
        else:
            raise RuntimeError('MLP type not supported')
    return mlp_cls


def create_block(config, layer_idx=None, process_group=None, device=None, dtype=None):
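    """Build the transformer block for layer `layer_idx`: a Block, or a ParallelBlock when
    config.parallel_block is set (the GPT-J / GPT-NeoX style where attention and MLP are
    computed in parallel).
    """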
    factory_kwargs = {'device': device, 'dtype': dtype}
    sequence_parallel = getattr(config, 'sequence_parallel', True)
    mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
    mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
    use_rms_norm = getattr(config, 'rms_norm', False)
    norm_cls = partial(nn.LayerNorm if not use_rms_norm else RMSNorm,
                       eps=config.layer_norm_epsilon, **factory_kwargs)
    # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
    residual_in_fp32 = getattr(config, 'residual_in_fp32', False)
    resid_dropout1 = config.resid_pdrop if layer_idx is None or layer_idx > 0 else config.embd_pdrop
    prenorm = getattr(config, 'prenorm', True)
    parallel_block = getattr(config, 'parallel_block', False)
    if not parallel_block:
        block = Block(
            config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
            prenorm=prenorm, resid_dropout1=resid_dropout1, resid_dropout2=config.resid_pdrop,
            fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
            residual_in_fp32=residual_in_fp32,
            sequence_parallel=sequence_parallel and process_group is not None,
            mark_shared_params=process_group is not None
        )
    else:
        assert prenorm
        block = ParallelBlock(
            config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
            resid_dropout1=resid_dropout1, resid_dropout2=config.resid_pdrop,
            tied_norm=getattr(config, 'parallel_block_tied_norm', False),
            fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
            residual_in_fp32=residual_in_fp32,
            sequence_parallel=sequence_parallel and process_group is not None,
            mark_shared_params=process_group is not None
        )
    block.layer_idx = layer_idx
    return block


class GPTPreTrainedModel(nn.Module):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    def __init__(self, config, *inputs, **kwargs):
        super().__init__()
        if not isinstance(config, GPT2Config):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                ))
        self.config = config

    @classmethod
    def from_pretrained(cls, model_name, config, *args, strict=True, device=None, dtype=None,
                        world_size=1, rank=0, **kwargs):
        """
        Instantiate a GPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.
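
        A hedged usage sketch (assumes the HF 'gpt2' checkpoint and a compatible config):
            config = GPT2Config.from_pretrained('gpt2')
            model = GPTLMHeadModel.from_pretrained('gpt2', config, dtype=torch.float16)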
        """
        # Instantiate model.
        model = cls(config, *args, device=device, dtype=dtype, **kwargs)
        # Load the state_dict on CPU because we already initialized the model on GPU, and we
        # don't want an extra copy of the weights taking up more GPU memory
        state_dict = state_dict_from_pretrained(
            model_name, device='cpu', dtype=dtype
        )
        if model_name.startswith('gpt2'):
            state_dict = remap_state_dict_hf_gpt2(state_dict, config)
        elif model_name.startswith('facebook/opt'):
            state_dict = remap_state_dict_hf_opt(state_dict, config)
        elif model_name.startswith('EleutherAI/gpt-j-'):
            state_dict = remap_state_dict_hf_gptj(state_dict, config)
        elif model_name.startswith('EleutherAI/gpt-neox-'):
            state_dict = remap_state_dict_hf_gpt_neox(state_dict, config)
        elif model_name.startswith('tiiuae/falcon-'):
            state_dict = remap_state_dict_hf_falcon(state_dict, config)
        else:
            raise NotImplementedError(f'Model {model_name} not supported')
        if world_size > 1:
            state_dict = shard_state_dict_tp(state_dict, config, world_size, rank)
        load_return = model.load_state_dict(state_dict, strict=strict)
        logger.info(load_return)
        return model


# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True):
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, std=initializer_range)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if rescale_prenorm_residual:
        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
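                # e.g. with initializer_range=0.02 and n_layer=12 (GPT-2 small), these residual
                # output projections are initialized with std 0.02 / math.sqrt(2 * 12) ≈ 0.0041.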
                nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer))


class GPTModel(GPTPreTrainedModel):
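    """The transformer backbone: (parallel) GPT-2 embeddings, a stack of Blocks built by
    create_block, and, when prenorm is used, a final LayerNorm/RMSNorm (ln_f)."""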

    def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
        super().__init__(config)
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.process_group = process_group
        self.sequence_parallel = getattr(config, 'sequence_parallel', True)
        assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'gelu_approx',
                                              'relu', 'sqrelu', 'glu', 'swiglu', 'geglu']
        pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
        vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
                      * pad_vocab_size_multiple)
        # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
        self.residual_in_fp32 = getattr(config, 'residual_in_fp32', False)
        # These 2 options are for OPT-350m
        self.prenorm = getattr(config, 'prenorm', True)
        use_rms_norm = getattr(config, 'rms_norm', False)
        word_embed_proj_dim = getattr(config, 'word_embed_proj_dim', None)
        # For GPT-J, GPT-NeoX
        self.parallel_block = getattr(config, 'parallel_block', False)

        if process_group is None:
            self.embeddings = GPT2Embeddings(
                config.hidden_size, vocab_size, config.max_position_embeddings,
                word_embed_proj_dim=word_embed_proj_dim, **factory_kwargs
            )
        else:
            self.embeddings = ParallelGPT2Embeddings(
                config.hidden_size, vocab_size, config.max_position_embeddings,
                process_group=process_group, sequence_parallel=self.sequence_parallel,
                **factory_kwargs
            )

        # We change the order of dropout, residual and layer norm:
        # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
        # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and
        # the main branch (output of MLP). The model definition is unchanged, but the mapping of the
        # nn.Dropout probabilities is changed.
        # This is for performance reasons: we can fuse dropout + add + layer_norm.
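        # Schematically, each prenorm Block then computes (a sketch of the reordering, not
        # literal code):
        #     residual = dropout(hidden_states) + residual   # the Add now comes first
        #     hidden_states = mixer(norm1(residual))          # then LN -> Attn
        #     residual = dropout(hidden_states) + residual
        #     hidden_states = mlp(norm2(residual))            # then LN -> MLP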
        self.layers = nn.ModuleList([create_block(config, layer_idx=i, process_group=process_group,
                                                  **factory_kwargs)
                                     for i in range(config.num_hidden_layers)])

        self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
        if self.fused_dropout_add_ln:
            if ((not self.parallel_block and dropout_add_layer_norm is None)
                or (self.parallel_block and dropout_add_layer_norm_parallel_residual is None)):
                raise ImportError('dropout_layer_norm is not installed')
        if self.prenorm:
            self.drop_f = nn.Dropout(config.resid_pdrop)
            norm_cls = nn.LayerNorm if not use_rms_norm else RMSNorm
            self.ln_f = norm_cls(config.hidden_size, eps=config.layer_norm_epsilon,
                                 **factory_kwargs)
        if process_group is not None:
            for p in self.ln_f.parameters():
                # Mark the norm parameters as "shared_params" so that we sync their values at init.
                p._shared_params = True
                # Mark the norm params as "sequence_parallel" so we run all-reduce on their grads.
                if self.sequence_parallel:
                    p._sequence_parallel = True

        self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
                           initializer_range=config.initializer_range))
        self.tie_weights()

    def tie_weights(self):
        if self.process_group is not None:
            sync_shared_params(self, self.process_group)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return {i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
                for i, layer in enumerate(self.layers)}

    def forward(self, input_ids, position_ids=None, inference_params=None):
        # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen
        # dimensions so that we can split on it easily, in case of small batch size.
        # Only the attention layers need to know the seqlen.
        embedding_kwargs = ({'combine_batch_seqlen_dim': True}
                            if self.process_group is not None and self.sequence_parallel else {})
        hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs)
        if self.parallel_block:
            hidden_states2 = None
        residual = None
        mixer_kwargs = ({'seqlen': input_ids.shape[1]}
                        if self.process_group is not None and self.sequence_parallel else {})
        if inference_params is not None:
            mixer_kwargs['inference_params'] = inference_params
        for layer in self.layers:
            if self.prenorm:
                if not self.parallel_block:
                    hidden_states, residual = layer(hidden_states, residual,
                                                    mixer_kwargs=mixer_kwargs)
                else:
                    hidden_states, hidden_states2, residual = layer(
                        hidden_states, hidden_states2, residual, mixer_kwargs=mixer_kwargs
                    )
            else:
                hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
        if self.prenorm:
            if not self.fused_dropout_add_ln:
                dropped = self.drop_f(hidden_states)
                if not self.parallel_block:
                    residual = (dropped + residual) if residual is not None else dropped
                else:
                    dropped2 = self.drop_f(hidden_states2)
                    residual = ((residual + dropped + dropped2)
                                if residual is not None else dropped + dropped2)
                hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
            else:
                # Set prenorm=False here since we don't need the residual
                if not self.parallel_block:
                    fused_add_norm_fn = (dropout_add_rms_norm if isinstance(self.ln_f, RMSNorm)
                                         else dropout_add_layer_norm)
                    hidden_states = fused_add_norm_fn(
                        hidden_states, residual, self.ln_f.weight, self.ln_f.bias,
                        self.drop_f.p if self.training else 0.0, self.ln_f.eps, prenorm=False,
                        residual_in_fp32=self.residual_in_fp32
                    )
                else:
                    fused_add_norm_fn = (dropout_add_rms_norm_parallel_residual
                                         if isinstance(self.ln_f, RMSNorm)
                                         else dropout_add_layer_norm_parallel_residual)
                    hidden_states, _ = fused_add_norm_fn(
                        hidden_states, hidden_states2, residual, self.ln_f.weight, self.ln_f.bias,
                        None, None, self.drop_f.p if self.training else 0.0, self.ln_f.eps,
                        prenorm=False, residual_in_fp32=self.residual_in_fp32
                    )
        return hidden_states


class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
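    """GPT backbone plus an (optionally weight-tied) language-modeling head; generation comes
    from GenerationMixin.

    A minimal usage sketch (hypothetical sizes, default non-fused code paths):
        config = GPT2Config(n_embd=256, n_head=8, n_layer=4)
        model = GPTLMHeadModel(config)
        logits = model(input_ids).logits  # input_ids: (batch_size, seqlen) token ids
    """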

    def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(config)
        self.process_group = process_group
        self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs)
        self.tie_word_embeddings = getattr(config, 'tie_word_embeddings', True)
        lm_head_bias = getattr(config, 'lm_head_bias', False)
        pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
        vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
                      * pad_vocab_size_multiple)
        # This option is for OPT-350m
        word_embed_proj_dim = getattr(config, 'word_embed_proj_dim', None)
        embed_dim = config.n_embd if word_embed_proj_dim is None else word_embed_proj_dim
        if word_embed_proj_dim is not None:
            self.project_out = nn.Linear(config.n_embd, embed_dim, bias=False, **factory_kwargs)
        else:
            self.project_out = None
        if process_group is None:
            self.lm_head = nn.Linear(embed_dim, vocab_size, bias=lm_head_bias, **factory_kwargs)
        else:
            if ColumnParallelLinear is None:
                raise ImportError('fused_dense_lib is not installed')
            self.lm_head = ColumnParallelLinear(
                embed_dim, vocab_size, process_group, bias=lm_head_bias,
                sequence_parallel=getattr(config, 'sequence_parallel', True), **factory_kwargs
            )
        # Initialize weights and apply final processing
        self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
                           initializer_range=config.initializer_range))
        self.tie_weights()

    def tie_weights(self):
        if self.tie_word_embeddings:
            self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight
        if self.process_group is not None:
            sync_shared_params(self, self.process_group)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.transformer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype,
                                                         **kwargs)

    def forward(self, input_ids, position_ids=None, inference_params=None, last_token_only=False):
        """
            inference_params: for generation. Adapted from Megatron-LM (and Apex)
            https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
            last_token_only: whether to return the logits for the last token only,
                of shape (batch_size, vocab_size)
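            Returns a CausalLMOutput namedtuple whose single field `logits` has shape
                (batch_size, seqlen, vocab_size), or (batch_size, vocab_size) with last_token_only.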
        """
        hidden_states = self.transformer(input_ids, position_ids=position_ids,
                                         inference_params=inference_params)
        if last_token_only:
            hidden_states = hidden_states[:, -1]
        if self.project_out is not None:
            hidden_states = self.project_out(hidden_states)
        lm_logits = self.lm_head(hidden_states)
        # During inference, we want the full logit for sampling
        if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
            lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
            lm_logits = rearrange(lm_logits, '(n b) ... d -> b ... (n d)', b=hidden_states.shape[0])
        CausalLMOutput = namedtuple('CausalLMOutput', ['logits'])
        return CausalLMOutput(logits=lm_logits)

    def load_state_dict(self, state_dict, strict=True):
        # Remapping from our checkpoints that used a different ordering of layers in the block
        # Previous: Attn / MLP -> Dropout -> Add -> LN
        # Current: Dropout -> Add -> LN -> Attn / MLP
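        # e.g. the old 'transformer.ln_0.*' becomes 'transformer.layers.0.norm1.*', each layer's
        # old norm1 becomes its norm2, the previous layer's norm2 becomes its norm1, and the last
        # layer's old norm2 becomes the new 'transformer.ln_f.*'.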
        if 'transformer.ln_0.weight' in state_dict:
            n_layers = len(self.transformer.layers)
            ln_weight = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.weight')
            ln_bias = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.bias')
            state_dict['transformer.ln_f.weight'] = ln_weight
            state_dict['transformer.ln_f.bias'] = ln_bias
            for l in reversed(range(n_layers)):
                ln_weight = state_dict.pop(f'transformer.layers.{l}.norm1.weight')
                ln_bias = state_dict.pop(f'transformer.layers.{l}.norm1.bias')
                state_dict[f'transformer.layers.{l}.norm2.weight'] = ln_weight
                state_dict[f'transformer.layers.{l}.norm2.bias'] = ln_bias
                if l > 0:
                    ln_weight = state_dict.pop(f'transformer.layers.{l - 1}.norm2.weight')
                    ln_bias = state_dict.pop(f'transformer.layers.{l - 1}.norm2.bias')
                    state_dict[f'transformer.layers.{l}.norm1.weight'] = ln_weight
                    state_dict[f'transformer.layers.{l}.norm1.bias'] = ln_bias
            ln_weight = state_dict.pop('transformer.ln_0.weight')
            ln_bias = state_dict.pop('transformer.ln_0.bias')
            state_dict['transformer.layers.0.norm1.weight'] = ln_weight
            state_dict['transformer.layers.0.norm1.bias'] = ln_bias
        return super().load_state_dict(state_dict, strict=strict)


def shard_state_dict_tp(state_dict, config, world_size, rank):
    """Convert the state_dict of a standard GPT model to the state_dict of a GPT model
    with tensor parallel.
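
    Usage sketch (hypothetical world_size): each rank keeps only its own shard of the full
    checkpoint, e.g. sharded = shard_state_dict_tp(full_state_dict, config, world_size=2, rank=0).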
    """
    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
    vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
    assert vocab_size % world_size == 0
    assert config.hidden_size % world_size == 0
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    assert inner_dim % world_size == 0

    def shard_first_dim(state_dict, key):
        if key in state_dict:
            x = state_dict[key]
            dim = x.shape[0] // world_size
            state_dict[key] = x[rank * dim:(rank + 1) * dim]

    def shard_last_dim(state_dict, key):
        if key in state_dict:
            x = state_dict[key]
            dim = x.shape[-1] // world_size
            state_dict[key] = x[..., rank * dim:(rank + 1) * dim]

    def shard_gatedmlp_fc1_dim(state_dict, key):
        if key in state_dict:
            x = state_dict[key]
            dim = x.shape[0] // world_size // 2
            state_dict[key] = rearrange(
                rearrange(x, "(two o) ... -> two o ...", two=2)[:, rank * dim:(rank + 1) * dim],
                "two o ... -> (two o) ..."
            )

    def shard_qkv_headdim(state_dict, key):
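        # Wqkv packs Q, K and V along dim 0: (3 * n_head * headdim, ...) when n_head_kv == n_head,
        # else ((n_head + 2 * n_head_kv) * headdim, ...) for MQA/GQA, so we shard per attention
        # head instead of naively splitting the first dimension.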
        if key in state_dict:
            n_head = config.n_head
            n_head_kv = getattr(config, 'n_head_kv', n_head)
            assert n_head % world_size == 0 and n_head_kv % world_size == 0
            if n_head_kv == n_head:
                x = rearrange(state_dict[key], '(three d) ... -> three d ...', three=3)
                dim = x.shape[1] // world_size
                state_dict[key] = rearrange(x[:, rank * dim:(rank + 1) * dim],
                                            'three d ... -> (three d) ...')
            else:
                n_head_per_rank = n_head // world_size
                n_head_kv_per_rank = n_head_kv // world_size
                x = rearrange(state_dict[key], '(nheadqkv headdim) ... -> nheadqkv headdim ...',
                              nheadqkv=n_head + 2 * n_head_kv)
                state_dict[key] = rearrange(torch.cat([
                    x[rank * n_head_per_rank:(rank + 1) * n_head_per_rank],
                    x[n_head + rank * n_head_kv_per_rank:n_head + (rank + 1) * n_head_kv_per_rank],
                    x[n_head + n_head_kv + rank * n_head_kv_per_rank:n_head + n_head_kv + (rank + 1) * n_head_kv_per_rank],
                ], dim=0), "nheadqkv headdim ... -> (nheadqkv headdim) ...")

    shard_first_dim(state_dict, 'transformer.embeddings.word_embeddings.weight')
    if 'lm_head.weight' in state_dict:
        shard_first_dim(state_dict, 'lm_head.weight')
    if 'transformer.embeddings.position_embeddings.weight' in state_dict:
        shard_last_dim(state_dict, 'transformer.embeddings.position_embeddings.weight')
    for i in range(config.num_hidden_layers):
        shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.weight')
        shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.bias')
        shard_last_dim(state_dict, f'transformer.layers.{i}.mixer.out_proj.weight')
        if rank != 0:
            state_dict.pop(f'transformer.layers.{i}.mixer.out_proj.bias', None)
        if config.activation_function in ["glu", "swiglu", "geglu"]:
            shard_gatedmlp_fc1_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.weight')
            shard_gatedmlp_fc1_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.bias')
        else:
            shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.weight')
            shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.bias')
        shard_last_dim(state_dict, f'transformer.layers.{i}.mlp.fc2.weight')
        if rank != 0:
            state_dict.pop(f'transformer.layers.{i}.mlp.fc2.bias', None)
    return state_dict


def combine_state_dicts_tp(state_dicts, config):
    """Convert the list of sharded state_dicts of a GPT model with tensor parallel back into
    the state_dict of a standard (non-parallel) GPT model.
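
    This is the inverse of shard_state_dict_tp, e.g.
        full_state_dict = combine_state_dicts_tp([sd_rank0, sd_rank1], config)
    (sd_rank0 / sd_rank1 are illustrative names for the per-rank shards).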
    """
    world_size = len(state_dicts)
    keys = state_dicts[0].keys()
    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
    vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
    assert vocab_size % world_size == 0
    assert config.hidden_size % world_size == 0
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    assert inner_dim % world_size == 0

    # Sometimes the word embeddings are sharded on the 0th dim, sometimes on the 1st dim.
    # We detect which dim is sharded by checking whether the 0th dim equals vocab_size // world_size.
    def combine_word_embeddings(state_dicts, state_dict, key):
        dim = 0 if state_dicts[0][key].shape[0] == vocab_size // world_size else 1
        state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)

    def combine_dim(state_dicts, state_dict, key, dim=-1):
        if key in state_dict:
            state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)

    def combine_qkv_headdim(state_dicts, state_dict, key):
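        # Inverse of the sharding in shard_state_dict_tp: re-concatenate the per-rank Q, K and V
        # head blocks of the packed Wqkv along dim 0.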
        n_head = config.n_head
        n_head_kv = getattr(config, 'n_head_kv', n_head)
        assert n_head % world_size == 0 and n_head_kv % world_size == 0
        n_head_per_rank = n_head // world_size
        n_head_kv_per_rank = n_head_kv // world_size
        if key in state_dict:
            if n_head_kv == n_head:
                xs = [rearrange(s[key], '(three d) ... -> three d ...', three=3) for s in state_dicts]
                state_dict[key] = rearrange(torch.cat(xs, dim=1), 'three d ... -> (three d) ...')
            else:
                xs = [rearrange(s[key], '(nheadqkv headdim) ... -> nheadqkv headdim ...',
                                nheadqkv=n_head + 2 * n_head_kv) for s in state_dicts]
                state_dict[key] = rearrange(torch.cat([
                    torch.cat([x[:n_head_per_rank] for x in xs], dim=0),
                    torch.cat([x[n_head_per_rank:n_head_per_rank + n_head_kv_per_rank] for x in xs], dim=0),
                    torch.cat([x[-n_head_kv_per_rank:] for x in xs], dim=0),
                ], dim=0), "nheadqkv headdim ... -> (nheadqkv headdim) ...")

    def combine_gated_mlp(state_dicts, state_dict, key):
        if key in state_dict:
            xs = [rearrange(s[key], '(two d) ... -> two d ...', two=2) for s in state_dicts]
            state_dict[key] = rearrange(torch.cat(xs, dim=1), 'two d ... -> (two d) ...')

    state_dict = state_dicts[0].copy()  # don't modify state_dict[0] inplace
    combine_word_embeddings(state_dicts, state_dict, 'transformer.embeddings.word_embeddings.weight')
    if 'lm_head.weight' in state_dict:
        combine_word_embeddings(state_dicts, state_dict, 'lm_head.weight')
    if 'transformer.embeddings.position_embeddings.weight' in state_dict:
        combine_dim(state_dicts, state_dict, 'transformer.embeddings.position_embeddings.weight', -1)
    mlp_combine_fn = (combine_gated_mlp if config.activation_function in ['glu', 'swiglu', 'geglu']
                      else partial(combine_dim, dim=0))
    for i in range(config.num_hidden_layers):
        combine_qkv_headdim(state_dicts, state_dict, f'transformer.layers.{i}.mixer.Wqkv.weight')
        combine_qkv_headdim(state_dicts, state_dict, f'transformer.layers.{i}.mixer.Wqkv.bias')
        combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mixer.out_proj.weight', -1)
        mlp_combine_fn(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc1.weight')
        combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc1.bias', 0)
        combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc2.weight', -1)
    return state_dict


def remap_state_dict_hf_gpt2(state_dict, config):
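    """Remap a HuggingFace GPT-2 state_dict to this module's parameter names; HF's Conv1D
    weights are transposed into nn.Linear layout and the word embedding is padded up to
    pad_vocab_size_multiple.
    """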
    # Word embedding and position embedding
    def key_mapping_pos_emb(key):
        return re.sub(r'^wpe.', 'transformer.embeddings.position_embeddings.', key)
    state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop('wte.weight')
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
    vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
    state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r'^ln_f.(weight|bias)', r'transformer.ln_f.\1', key)
        key = re.sub(r'^h.(\d+).ln_(1|2).(weight|bias)', r'transformer.layers.\1.norm\2.\3', key)
        return key
    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    for d in range(config.num_hidden_layers):
        W1 = state_dict.pop(f'h.{d}.mlp.c_fc.weight')
        state_dict[f'transformer.layers.{d}.mlp.fc1.weight'] = W1.t()
        W2 = state_dict.pop(f'h.{d}.mlp.c_proj.weight')
        state_dict[f'transformer.layers.{d}.mlp.fc2.weight'] = W2.t()
    def key_mapping_mlp(key):
        key = re.sub(r'^h.(\d+).mlp.c_fc.bias', r'transformer.layers.\1.mlp.fc1.bias', key)
        key = re.sub(r'^h.(\d+).mlp.c_proj.bias', r'transformer.layers.\1.mlp.fc2.bias', key)
        return key
    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for d in range(config.num_hidden_layers):
        state_dict.pop(f'h.{d}.attn.bias')  # We don't store this bias
        Wqkv = state_dict.pop(f'h.{d}.attn.c_attn.weight')
        state_dict[f'transformer.layers.{d}.mixer.Wqkv.weight'] = Wqkv.t()
        Wout = state_dict.pop(f'h.{d}.attn.c_proj.weight')
        state_dict[f'transformer.layers.{d}.mixer.out_proj.weight'] = Wout.t()
    def key_mapping_attn(key):
        key = re.sub(r'^h.(\d+).attn.c_attn.bias', r'transformer.layers.\1.mixer.Wqkv.bias', key)
        key = re.sub(r'^h.(\d+).attn.c_proj.bias', r'transformer.layers.\1.mixer.out_proj.bias', key)
        return key
    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict


def remap_state_dict_megatron(state_dict, config):
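    """Remap a Megatron-LM GPT state_dict to this module's parameter names; the packed Wqkv
    is reordered from Megatron's (nheads, 3, headdim) layout to our (3, nheads, headdim) layout.
    """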
    def key_mapping_transformer(key):
        key = re.sub(r'^language_model.encoder.', 'transformer.', key)
        key = re.sub(r'^language_model.', 'transformer.', key)
        return key
    state_dict = OrderedDict((key_mapping_transformer(k), v) for k, v in state_dict.items())
    # Word embedding and position embedding
    def key_mapping_pos_emb(key):
        return re.sub(r'^wpe.', 'transformer.embeddings.position_embeddings.', key)
    state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop('transformer.embedding.word_embeddings.weight')
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
    vocab_size = (math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple)
                  * pad_vocab_size_multiple)
    state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r'^transformer.final_layernorm.(weight|bias)', r'transformer.ln_f.\1', key)
        key = re.sub(r'^transformer.layers.(\d+).input_layernorm.(weight|bias)',
                     r'transformer.layers.\1.norm1.\2', key)
        key = re.sub(r'^transformer.layers.(\d+).post_attention_layernorm.(weight|bias)',
                     r'transformer.layers.\1.norm2.\2', key)
        return key
    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        key = re.sub(r'^transformer.layers.(\d+).mlp.dense_h_to_4h.(weight|bias)',
                     r'transformer.layers.\1.mlp.fc1.\2', key)
        key = re.sub(r'^transformer.layers.(\d+).mlp.dense_4h_to_h.(weight|bias)',
                     r'transformer.layers.\1.mlp.fc2.\2', key)
        return key
    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    def key_mapping_attn(key):
        key = re.sub(r'^transformer.layers.(\d+).self_attention.rotary_emb.inv_freq',
                     r'transformer.layers.\1.mixer.rotary_emb.inv_freq', key)
        key = re.sub(r'^transformer.layers.(\d+).self_attention.query_key_value.(weight|bias)',
                     r'transformer.layers.\1.mixer.Wqkv.\2', key)
        key = re.sub(r'^transformer.layers.(\d+).self_attention.dense.(weight|bias)',
                     r'transformer.layers.\1.mixer.out_proj.\2', key)
        return key
    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
    # Megatron stores Wqkv as ((nheads 3 headdim), hidden_dim)
    # while we store Wqkv as ((3 nheads headdim), hidden_dim)
    headdim = config.hidden_size // config.num_attention_heads
    for d in range(config.num_hidden_layers):
        Wqkv = state_dict.pop(f'transformer.layers.{d}.mixer.Wqkv.weight')
        state_dict[f'transformer.layers.{d}.mixer.Wqkv.weight'] = rearrange(
            Wqkv, '(nheads three headdim) ... -> (three nheads headdim) ...',
            three=3, headdim=headdim
        )
        bqkv = state_dict.pop(f'transformer.layers.{d}.mixer.Wqkv.bias')
        state_dict[f'transformer.layers.{d}.mixer.Wqkv.bias'] = rearrange(
            bqkv, '(nheads three headdim) -> (three nheads headdim)',
            three=3, headdim=headdim
        )

    return state_dict