gpt.py 45.1 KB
Newer Older
Tri Dao's avatar
Tri Dao committed
1
# Copyright (c) 2023, Tri Dao.
Tri Dao's avatar
Tri Dao committed
2

3
import logging
Tri Dao's avatar
Tri Dao committed
4
import math
5
import re
Tri Dao's avatar
Tri Dao committed
6
from collections import OrderedDict, namedtuple
Tri Dao's avatar
Tri Dao committed
7
from collections.abc import Sequence
Tri Dao's avatar
Tri Dao committed
8
from functools import partial
Tri Dao's avatar
Tri Dao committed
9
10
11
12

import torch
import torch.nn as nn
import torch.nn.functional as F
13
from einops import rearrange
Tri Dao's avatar
Tri Dao committed
14
15
from transformers import GPT2Config

Tri Dao's avatar
Tri Dao committed
16
17
18
from flash_attn.models.falcon import remap_state_dict_hf_falcon
from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox
from flash_attn.models.gptj import remap_state_dict_hf_gptj
19
from flash_attn.models.llama import remap_state_dict_hf_llama
Tri Dao's avatar
Tri Dao committed
20
from flash_attn.models.opt import remap_state_dict_hf_opt
Tri Dao's avatar
Tri Dao committed
21
from flash_attn.modules.block import Block, ParallelBlock
22
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
Tri Dao's avatar
Tri Dao committed
23
from flash_attn.modules.mha import MHA, ParallelMHA
24
25
from flash_attn.modules.mlp import (FusedMLP, GatedMlp, Mlp, ParallelFusedMLP,
                                    ParallelGatedMlp, ParallelMLP)
Tri Dao's avatar
Tri Dao committed
26
from flash_attn.ops.activations import sqrelu_fwd
27
28
29
from flash_attn.utils.distributed import (all_gather_raw,
                                          get_dim_for_local_rank,
                                          sync_shared_params)
Tri Dao's avatar
Tri Dao committed
30
from flash_attn.utils.generation import GenerationMixin
Tri Dao's avatar
Tri Dao committed
31
from flash_attn.utils.pretrained import state_dict_from_pretrained
32
33
34
35
36

try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear
except ImportError:
    ColumnParallelLinear = None
Tri Dao's avatar
Tri Dao committed
37
38
39
40
41
42

try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm
except ImportError:
    dropout_add_layer_norm = None

43
try:
44
45
    from flash_attn.ops.layer_norm import \
        dropout_add_layer_norm_parallel_residual
46
47
48
except ImportError:
    dropout_add_layer_norm_parallel_residual = None

Tri Dao's avatar
Tri Dao committed
49
50
51
try:
    from flash_attn.ops.rms_norm import RMSNorm, dropout_add_rms_norm
except ImportError:
52
    RMSNorm, dropout_add_rms_norm = None, None
Tri Dao's avatar
Tri Dao committed
53
54
55
56
57
58

try:
    from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual
except ImportError:
    dropout_add_rms_norm_parallel_residual = None

Tri Dao's avatar
Tri Dao committed
59
try:
Tri Dao's avatar
Tri Dao committed
60
    from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
Tri Dao's avatar
Tri Dao committed
61
62
63
except ImportError:
    FusedDenseSqreluDense = None

64
65
66
logger = logging.getLogger(__name__)


67
def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
    """Return a partially-applied attention ("mixer") constructor built from ``config``.

    The result is a ``functools.partial`` over ``MHA`` when ``process_group`` is None and
    over ``ParallelMHA`` otherwise. Head counts, projection biases, attention dropout,
    softmax scale, rotary-embedding settings and ``device``/``dtype`` are pre-bound.
    Optional settings are read with ``getattr`` so a config lacking them falls back to
    the defaults written below.

    Args:
        config: Model config; must provide ``hidden_size``, ``num_attention_heads``,
            ``scale_attn_weights``, ``scale_attn_by_inverse_layer_idx`` and ``attn_pdrop``.
        layer_idx: Index of the layer this mixer belongs to. Required when
            ``config.scale_attn_by_inverse_layer_idx`` is set.
        process_group: Tensor-parallel process group, or None for the serial module.
        device, dtype: Forwarded to the attention module constructor.
    """
    tensor_parallel = process_group is not None

    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)

    # 1/sqrt(head_dim) scaling is optional, and can additionally be divided by the
    # (1-based) layer index.
    if config.scale_attn_weights:
        softmax_scale = head_dim ** (-0.5)
    else:
        softmax_scale = 1.0
    if config.scale_attn_by_inverse_layer_idx:
        assert layer_idx is not None
        softmax_scale /= float(layer_idx + 1)

    dwconv = getattr(config, "attn_dwconv", False)
    assert not dwconv or not tensor_parallel, "TensorParallel MHA does not support dwconv yet"

    fused_bias_fc = getattr(config, "fused_bias_fc", False)
    assert fused_bias_fc or not tensor_parallel, "TensorParallel MHA requires fused_bias_fc"

    # The serial and tensor-parallel modules accept different extra keyword arguments.
    if tensor_parallel:
        mha_cls = ParallelMHA
        mode_kwargs = {
            "process_group": process_group,
            "sequence_parallel": getattr(config, "sequence_parallel", True),
        }
    else:
        mha_cls = MHA
        mode_kwargs = {"fused_bias_fc": fused_bias_fc, "dwconv": dwconv}

    mixer_kwargs = {
        "num_heads": config.num_attention_heads,
        "num_heads_kv": getattr(config, "n_head_kv", None),
        "qkv_proj_bias": getattr(config, "qkv_proj_bias", True),
        "out_proj_bias": getattr(config, "out_proj_bias", True),
        "dropout": config.attn_pdrop,
        "softmax_scale": softmax_scale,
        "causal": True,
        "layer_idx": layer_idx,
        # Rotary dimension is a fraction of the head dimension (0.0 disables it).
        "rotary_emb_dim": int(getattr(config, "rotary_emb_fraction", 0.0) * head_dim),
        "rotary_emb_base": getattr(config, "rotary_emb_base", 10000.0),
        "rotary_emb_scale_base": getattr(config, "rotary_emb_scale_base", None),
        "rotary_emb_interleaved": getattr(config, "rotary_emb_interleaved", False),
        "use_flash_attn": getattr(config, "use_flash_attn", False),
    }
    mixer_kwargs.update(mode_kwargs)
    mixer_kwargs["device"] = device
    mixer_kwargs["dtype"] = dtype
    return partial(mha_cls, **mixer_kwargs)


122
def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
    """Return a partially-applied MLP constructor selected by ``config``.

    Three families are supported:

    * unfused (neither ``fused_mlp`` nor ``fused_dense_sqrelu_dense`` set): ``GatedMlp`` /
      ``ParallelGatedMlp`` for ``glu``/``swiglu``/``geglu``, otherwise ``Mlp`` /
      ``ParallelMLP``;
    * ``fused_mlp``: ``FusedMLP`` / ``ParallelFusedMLP``;
    * ``fused_dense_sqrelu_dense``: ``FusedDenseSqreluDense`` (no tensor parallelism).

    The "Parallel" variant is chosen whenever ``process_group`` is not None.

    Args:
        config: Model config; must provide ``activation_function`` and ``n_inner``.
        layer_idx: Layer index; required when ``config.mlp_checkpoint_lvl`` is a sequence
            holding one checkpoint level per layer.
        process_group: Tensor-parallel process group, or None.
        device, dtype: Forwarded to the MLP constructor.

    Raises:
        ImportError: ``fused_mlp`` is requested but ``FusedMLP`` is unavailable.
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    mlp_fc1_bias = getattr(config, "mlp_fc1_bias", True)
    mlp_fc2_bias = getattr(config, "mlp_fc2_bias", True)
    fused_mlp = getattr(config, "fused_mlp", False)
    fused_dense_sqrelu_dense = getattr(config, "fused_dense_sqrelu_dense", False)
    act_name = config.activation_function

    if fused_mlp:
        assert act_name in ["gelu_new", "gelu_fast", "gelu_approx", "relu", "sqrelu"]
    if fused_dense_sqrelu_dense:
        assert (
            act_name == "sqrelu"
        ), "fused_dense_sqrelu_dense only supports approximate activation_function sqrelu"
    assert not (fused_dense_sqrelu_dense and fused_mlp)

    # Extra keyword arguments understood only by the tensor-parallel MLP variants.
    if process_group is None:
        tp_kwargs = {}
    else:
        tp_kwargs = {
            "process_group": process_group,
            "sequence_parallel": getattr(config, "sequence_parallel", True),
        }

    if not (fused_mlp or fused_dense_sqrelu_dense):
        # ---- Unfused path ----
        assert act_name in [
            "gelu",
            "gelu_new",
            "gelu_fast",
            "gelu_approx",
            "relu",
            "sqrelu",
            "glu",
            "swiglu",
            "geglu",
        ]
        if act_name in ("glu", "swiglu", "geglu"):
            gate_fns = {"glu": F.sigmoid, "swiglu": F.silu, "geglu": F.gelu}
            activation = gate_fns[act_name]
            base_cls = GatedMlp if process_group is None else ParallelGatedMlp
        else:
            if act_name == "relu":
                activation = partial(F.relu, inplace=True)
            elif act_name == "sqrelu":
                activation = sqrelu_fwd
            else:
                # The "new"/"fast"/"approx" GELU names all map to the tanh approximation.
                if act_name in ["gelu_new", "gelu_fast", "gelu_approx"]:
                    approximate = "tanh"
                else:
                    approximate = "none"
                activation = partial(F.gelu, approximate=approximate)
            base_cls = Mlp if process_group is None else ParallelMLP
        return partial(
            base_cls,
            hidden_features=config.n_inner,
            activation=activation,
            bias1=mlp_fc1_bias,
            bias2=mlp_fc2_bias,
            **tp_kwargs,
            **factory_kwargs,
        )

    # ---- Fused paths ----
    mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
    # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
    if isinstance(mlp_checkpoint_lvl, Sequence):
        assert layer_idx is not None
        mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]

    if fused_mlp:
        if FusedMLP is None:
            raise ImportError("fused_dense is not installed")
        if act_name in ["gelu_new", "gelu_fast", "gelu_approx"]:
            activation = "gelu_approx"
        else:
            activation = act_name
        fused_cls = FusedMLP if process_group is None else ParallelFusedMLP
        return partial(
            fused_cls,
            hidden_features=config.n_inner,
            activation=activation,
            checkpoint_lvl=mlp_checkpoint_lvl,
            bias1=mlp_fc1_bias,
            bias2=mlp_fc2_bias,
            **tp_kwargs,
            **factory_kwargs,
        )

    if fused_dense_sqrelu_dense:
        # fused_mlp is falsy on this path, so this assertion rejects any process group.
        assert (
            process_group is None or fused_mlp
        ), "Tensor Parallel is not implemented for FusedDenseSqreluDense"
        assert FusedDenseSqreluDense is not None
        return partial(
            FusedDenseSqreluDense,
            hidden_features=config.n_inner,
            checkpoint_lvl=mlp_checkpoint_lvl,
            **factory_kwargs,
        )

    raise RuntimeError("MLP type not supported")


255
def create_block(config, layer_idx=None, process_group=None, device=None, dtype=None):
    """Construct one transformer block (``Block`` or ``ParallelBlock``) from ``config``.

    The mixer and MLP constructors come from ``create_mixer_cls`` and ``create_mlp_cls``;
    the norm is ``RMSNorm`` when ``config.rms_norm`` is set and ``nn.LayerNorm`` otherwise.
    ``ParallelBlock`` is used when ``config.parallel_block`` is set, which requires
    ``prenorm``. The created block gets a ``layer_idx`` attribute before it is returned.

    Args:
        config: Model config; must provide ``hidden_size``, ``layer_norm_epsilon``,
            ``resid_pdrop`` and ``embd_pdrop``.
        layer_idx: Index of this block within the model.
        process_group: Tensor-parallel process group, or None.
        device, dtype: Forwarded to the sub-module constructors.
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    tensor_parallel = process_group is not None
    sequence_parallel = getattr(config, "sequence_parallel", True)

    mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
    mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs)

    norm_base = RMSNorm if getattr(config, "rms_norm", False) else nn.LayerNorm
    norm_cls = partial(norm_base, eps=config.layer_norm_epsilon, **factory_kwargs)

    # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
    residual_in_fp32 = getattr(config, "residual_in_fp32", False)

    # The first dropout of layer 0 uses the embedding dropout probability; every other
    # layer (or an unspecified layer index) uses the residual dropout probability.
    if layer_idx is None or layer_idx > 0:
        resid_dropout1 = config.resid_pdrop
    else:
        resid_dropout1 = config.embd_pdrop

    prenorm = getattr(config, "prenorm", True)
    if getattr(config, "parallel_block", False):
        assert prenorm
        block_cls = ParallelBlock
        variant_kwargs = {"tied_norm": getattr(config, "parallel_block_tied_norm", False)}
    else:
        block_cls = Block
        variant_kwargs = {"prenorm": prenorm}

    block = block_cls(
        config.hidden_size,
        mixer_cls,
        mlp_cls,
        norm_cls=norm_cls,
        resid_dropout1=resid_dropout1,
        resid_dropout2=config.resid_pdrop,
        fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
        residual_in_fp32=residual_in_fp32,
        sequence_parallel=sequence_parallel and tensor_parallel,
        mark_shared_params=tensor_parallel,
        **variant_kwargs,
    )
    block.layer_idx = layer_idx
    return block


304
class GPTPreTrainedModel(nn.Module):
    """Abstract base class that stores the model config and offers a simple
    interface for downloading and loading pretrained models.
    """

    def __init__(self, config, *inputs, **kwargs):
        """Validate that ``config`` is a ``GPT2Config`` and keep it on ``self.config``.

        Raises:
            ValueError: ``config`` is not a ``GPT2Config`` instance.
        """
        super().__init__()
        if not isinstance(config, GPT2Config):
            cls_name = self.__class__.__name__
            raise ValueError(
                f"Parameter config in `{cls_name}(config)` should be an instance of class "
                "`GPT2Config`. "
                "To create a model from a Google pretrained model use "
                f"`model = {cls_name}.from_pretrained(PRETRAINED_MODEL_NAME)`"
            )
        self.config = config

    @classmethod
    def from_pretrained(
        cls,
        model_name,
        config,
        *args,
        strict=True,
        device=None,
        dtype=None,
        world_size=1,
        rank=0,
        **kwargs,
    ):
        """
        Instantiate a GPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.

        The checkpoint is remapped to this module's parameter names by the converter
        matching the prefix of ``model_name``, optionally sharded for tensor parallelism
        when ``world_size > 1``, and then loaded with ``load_state_dict(strict=strict)``.

        Raises:
            NotImplementedError: ``model_name`` matches none of the supported prefixes.
        """
        # Instantiate model.
        model = cls(config, *args, device=device, dtype=dtype, **kwargs)
        # Load state_dict in cpu because we already initialized the model in GPU, and we don't
        # want extra stuff taking up more GPU memory
        state_dict = state_dict_from_pretrained(model_name, device="cpu", dtype=dtype)

        # (model-name prefix, converter) pairs, checked in order; first match wins.
        remap_by_prefix = (
            ("gpt2", remap_state_dict_hf_gpt2),
            ("facebook/opt", remap_state_dict_hf_opt),
            ("EleutherAI/gpt-j-", remap_state_dict_hf_gptj),
            ("EleutherAI/gpt-neox-", remap_state_dict_hf_gpt_neox),
            ("tiiuae/falcon-", remap_state_dict_hf_falcon),
            ("meta-llama/Llama-", remap_state_dict_hf_llama),
        )
        for prefix, remap_fn in remap_by_prefix:
            if model_name.startswith(prefix):
                state_dict = remap_fn(state_dict, config)
                break
        else:
            raise NotImplementedError(f"Model {model_name} not supported")

        if world_size > 1:
            state_dict = shard_state_dict_tp(state_dict, config, world_size, rank)
        load_return = model.load_state_dict(state_dict, strict=strict)
        logger.info(load_return)
        return model

Tri Dao's avatar
Tri Dao committed
363

Tri Dao's avatar
Tri Dao committed
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True):
    """Initialize ``module`` in place following the GPT-2 scheme.

    ``nn.Linear`` and ``nn.Embedding`` weights are drawn from a zero-mean normal with
    standard deviation ``initializer_range``; ``nn.Linear`` biases are zeroed. When
    ``rescale_prenorm_residual`` is true, every parameter of ``module`` whose name
    (relative to ``module``) is exactly ``out_proj.weight`` or ``fc2.weight`` is then
    re-drawn with standard deviation ``initializer_range / sqrt(2 * n_layer)``.

    Args:
        module: The module to initialize.
        n_layer: Number of transformer layers, used for the residual rescaling.
        initializer_range: Standard deviation of the base normal initialization.
        rescale_prenorm_residual: Whether to apply the scaled residual initialization.
    """
    if isinstance(module, (nn.Linear, nn.Embedding)):
        nn.init.normal_(module.weight, std=initializer_range)
        if isinstance(module, nn.Linear) and module.bias is not None:
            nn.init.zeros_(module.bias)

    if not rescale_prenorm_residual:
        return

    # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
    #   > A modified initialization which accounts for the accumulation on the residual path
    #   > with model depth. Scale the weights of residual layers at initialization by a factor
    #   > of 1/√N where N is the # of residual layers.
    #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
    #
    # Reference (Megatron-LM):
    # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
    residual_weight_names = ("out_proj.weight", "fc2.weight")
    for param_name, param in module.named_parameters():
        if param_name not in residual_weight_names:
            continue
        # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
        nn.init.normal_(param, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer))


386
class GPTModel(GPTPreTrainedModel):
    """GPT-style transformer backbone: embeddings, a stack of blocks and, in prenorm
    mode, a final dropout + norm.

    With a ``process_group`` the tensor-parallel embedding is used and the blocks are
    created with that group; the final-norm parameters are then flagged so they can be
    synchronized across ranks.
    """

    def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
        """Build the embeddings, the ``config.num_hidden_layers`` blocks and the final norm.

        Args:
            config: Model configuration. Optional fields are read with ``getattr`` and
                fall back to the defaults written below.
            process_group: Tensor-parallel process group, or None for a serial model.
            device, dtype: Forwarded to every sub-module constructor.

        Raises:
            ImportError: ``config.fused_dropout_add_ln`` is set but the fused
                dropout-add-layer-norm kernel needed for this block type is unavailable.
        """
        super().__init__(config)
        factory_kwargs = {"device": device, "dtype": dtype}
        self.process_group = process_group
        self.sequence_parallel = getattr(config, "sequence_parallel", True)
        assert config.activation_function in [
            "gelu",
            "gelu_new",
            "gelu_fast",
            "gelu_approx",
            "relu",
            "sqrelu",
            "glu",
            "swiglu",
            "geglu",
        ]
        # Round the vocabulary size up to a multiple of pad_vocab_size_multiple.
        pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
        vocab_size = (
            math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
        )
        # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
        self.residual_in_fp32 = getattr(config, "residual_in_fp32", False)
        # These 2 options are for OPT-350m
        self.prenorm = getattr(config, "prenorm", True)
        use_rms_norm = getattr(config, "rms_norm", False)
        word_embed_proj_dim = getattr(config, "word_embed_proj_dim", None)
        # For GPT-J, GPT-NeoX
        self.parallel_block = getattr(config, "parallel_block", False)

        if process_group is None:
            self.embeddings = GPT2Embeddings(
                config.hidden_size,
                vocab_size,
                config.max_position_embeddings,
                word_embed_proj_dim=word_embed_proj_dim,
                **factory_kwargs,
            )
        else:
            self.embeddings = ParallelGPT2Embeddings(
                config.hidden_size,
                vocab_size,
                config.max_position_embeddings,
                process_group=process_group,
                sequence_parallel=self.sequence_parallel,
                **factory_kwargs,
            )

        # We change the order of dropout, residual and layer norm:
        # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
        # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and
        # the main branch (output of MLP). The model definition is unchanged, but the mapping of the
        # nn.Dropout probabilities are changed.
        # This is for performance reason: we can fuse dropout + add + layer_norm.
        self.layers = nn.ModuleList(
            [
                create_block(config, layer_idx=i, process_group=process_group, **factory_kwargs)
                for i in range(config.num_hidden_layers)
            ]
        )

        self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
        if self.fused_dropout_add_ln:
            if (not self.parallel_block and dropout_add_layer_norm is None) or (
                self.parallel_block and dropout_add_layer_norm_parallel_residual is None
            ):
                raise ImportError("dropout_layer_norm is not installed")
        if self.prenorm:
            self.drop_f = nn.Dropout(config.resid_pdrop)
            norm_cls = nn.LayerNorm if not use_rms_norm else RMSNorm
            self.ln_f = norm_cls(
                config.hidden_size, eps=config.layer_norm_epsilon, **factory_kwargs
            )
            # ln_f only exists in prenorm mode, so its parameters are flagged here.
            # Doing this outside the prenorm branch raised AttributeError for
            # prenorm=False combined with a process group.
            if process_group is not None:
                for p in self.ln_f.parameters():
                    # Mark the norm parameters as "shared_params" so that we sync their
                    # values at init.
                    p._shared_params = True
                    # Mark the norm params as "sequence_parallel" so we run all-reduce on
                    # their grads.
                    if self.sequence_parallel:
                        p._sequence_parallel = True

        self.apply(
            partial(
                _init_weights,
                n_layer=config.num_hidden_layers,
                initializer_range=config.initializer_range,
            )
        )
        self.tie_weights()

    def tie_weights(self):
        """Synchronize shared parameters across the process group (no-op when serial)."""
        if self.process_group is not None:
            sync_shared_params(self, self.process_group)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Return a dict mapping each layer index to that layer's inference cache."""
        return {
            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
            for i, layer in enumerate(self.layers)
        }

    def _final_norm_is_rms(self):
        """Return True when ``self.ln_f`` is an ``RMSNorm``.

        ``RMSNorm`` is None when its optional import failed; ``isinstance(x, None)`` would
        raise ``TypeError``, and in that case ``ln_f`` cannot be an ``RMSNorm`` anyway.
        """
        return RMSNorm is not None and isinstance(self.ln_f, RMSNorm)

    def forward(self, input_ids, position_ids=None, inference_params=None):
        """Embed ``input_ids``, run every block and, in prenorm mode, apply the final norm.

        Args:
            input_ids: Token ids; ``input_ids.shape[1]`` is used as the sequence length
                (presumably shaped (batch, seqlen) -- confirm against callers).
            position_ids: Optional position ids forwarded to the embedding module.
            inference_params: Optional object forwarded to every mixer when not None.

        Returns:
            The final hidden states.
        """
        use_sequence_parallel = self.process_group is not None and self.sequence_parallel
        # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen
        # dimensions so that we can split on it easily, in case of small batch size.
        # Only the attention layers need to know the seqlen.
        embedding_kwargs = {"combine_batch_seqlen_dim": True} if use_sequence_parallel else {}
        hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs)
        if self.parallel_block:
            hidden_states2 = None
        residual = None
        mixer_kwargs = {"seqlen": input_ids.shape[1]} if use_sequence_parallel else {}
        if inference_params is not None:
            mixer_kwargs["inference_params"] = inference_params

        for layer in self.layers:
            if self.prenorm:
                if not self.parallel_block:
                    hidden_states, residual = layer(
                        hidden_states, residual, mixer_kwargs=mixer_kwargs
                    )
                else:
                    hidden_states, hidden_states2, residual = layer(
                        hidden_states, hidden_states2, residual, mixer_kwargs=mixer_kwargs
                    )
            else:
                hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)

        if self.prenorm:
            if not self.fused_dropout_add_ln:
                # Unfused final step: dropout -> add to residual -> norm.
                dropped = self.drop_f(hidden_states)
                if not self.parallel_block:
                    residual = (dropped + residual) if residual is not None else dropped
                else:
                    dropped2 = self.drop_f(hidden_states2)
                    residual = (
                        (residual + dropped + dropped2)
                        if residual is not None
                        else dropped + dropped2
                    )
                hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
            else:
                # Set prenorm=False here since we don't need the residual
                # Dropout is only active in training mode.
                dropout_p = self.drop_f.p if self.training else 0.0
                if not self.parallel_block:
                    fused_add_norm_fn = (
                        dropout_add_rms_norm
                        if self._final_norm_is_rms()
                        else dropout_add_layer_norm
                    )
                    hidden_states = fused_add_norm_fn(
                        hidden_states,
                        residual,
                        self.ln_f.weight,
                        self.ln_f.bias,
                        dropout_p,
                        self.ln_f.eps,
                        prenorm=False,
                        residual_in_fp32=self.residual_in_fp32,
                    )
                else:
                    fused_add_norm_fn = (
                        dropout_add_rms_norm_parallel_residual
                        if self._final_norm_is_rms()
                        else dropout_add_layer_norm_parallel_residual
                    )
                    hidden_states, _ = fused_add_norm_fn(
                        hidden_states,
                        hidden_states2,
                        residual,
                        self.ln_f.weight,
                        self.ln_f.bias,
                        None,
                        None,
                        dropout_p,
                        self.ln_f.eps,
                        prenorm=False,
                        residual_in_fp32=self.residual_in_fp32,
                    )
        return hidden_states


Tri Dao's avatar
Tri Dao committed
571
class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
    """GPT backbone (``GPTModel``) followed by a language-modeling head.

    The head is an ``nn.Linear`` when ``process_group`` is None and a
    ``ColumnParallelLinear`` otherwise. If ``config.word_embed_proj_dim`` is set, the
    hidden states are first projected to that width by ``project_out``.
    """

    def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
        """Build the backbone, the optional output projection and the LM head.

        Args:
            config: model configuration. Optional attributes read here (with defaults):
                ``tie_word_embeddings`` (True), ``lm_head_bias`` (False),
                ``pad_vocab_size_multiple`` (1), ``word_embed_proj_dim`` (None),
                ``sequence_parallel`` (True).
            process_group: when not None, the head is a ``ColumnParallelLinear`` built
                on this group.
            device, dtype: forwarded to the sub-module constructors.

        Raises:
            ImportError: if ``process_group`` is given but ``ColumnParallelLinear``
                could not be imported.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(config)
        self.process_group = process_group
        self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs)
        self.tie_word_embeddings = getattr(config, "tie_word_embeddings", True)
        lm_head_bias = getattr(config, "lm_head_bias", False)
        pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
        # Round the vocabulary size up to the next multiple of pad_vocab_size_multiple.
        vocab_size = (
            math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
        )
        # This option is for OPT-350m
        word_embed_proj_dim = getattr(config, "word_embed_proj_dim", None)
        embed_dim = config.n_embd if word_embed_proj_dim is None else word_embed_proj_dim
        if word_embed_proj_dim is not None:
            self.project_out = nn.Linear(config.n_embd, embed_dim, bias=False, **factory_kwargs)
        else:
            self.project_out = None
        if process_group is None:
            self.lm_head = nn.Linear(embed_dim, vocab_size, bias=lm_head_bias, **factory_kwargs)
        else:
            if ColumnParallelLinear is None:
                raise ImportError("fused_dense_lib is not installed")
            self.lm_head = ColumnParallelLinear(
                embed_dim,
                vocab_size,
                process_group,
                bias=lm_head_bias,
                sequence_parallel=getattr(config, "sequence_parallel", True),
                **factory_kwargs,
            )
        # Initialize weights and apply final processing
        self.apply(
            partial(
                _init_weights,
                n_layer=config.num_hidden_layers,
                initializer_range=config.initializer_range,
            )
        )
        self.tie_weights()

    def tie_weights(self):
        """Share the word-embedding matrix with the LM head when tying is enabled.

        When a process group is set, ``sync_shared_params`` is additionally called on
        the module.
        """
        if self.tie_word_embeddings:
            self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight
        if self.process_group is not None:
            sync_shared_params(self, self.process_group)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Delegate inference-cache allocation to the backbone and return its result."""
        return self.transformer.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

    def forward(self, input_ids, position_ids=None, inference_params=None, last_token_only=False):
        """
        inference_params: for generation. Adapted from Megatron-LM (and Apex)
        https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        last_token_only: whether to return the logit for the last token only,
            of shape (batch_size, vocab_size)

        Returns a ``CausalLMOutput`` namedtuple with a single ``logits`` field.
        """
        hidden_states = self.transformer(
            input_ids, position_ids=position_ids, inference_params=inference_params
        )
        if last_token_only:
            hidden_states = hidden_states[:, -1]
        if self.project_out is not None:
            hidden_states = self.project_out(hidden_states)
        lm_logits = self.lm_head(hidden_states)
        # During inference, we want the full logit for sampling.
        # ColumnParallelLinear is None when the fused dense extension is unavailable;
        # isinstance(obj, None) raises TypeError, so test for that first.
        if (
            ColumnParallelLinear is not None
            and isinstance(self.lm_head, ColumnParallelLinear)
            and inference_params is not None
        ):
            lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
            # The gathered tensor carries the ranks in its leading dim; fold them into
            # the last (vocabulary) dim.
            lm_logits = rearrange(lm_logits, "(n b) ... d -> b ... (n d)", b=hidden_states.shape[0])
        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
        return CausalLMOutput(logits=lm_logits)

    def load_state_dict(self, state_dict, strict=True):
        """Load ``state_dict``, converting the legacy block layout first if needed.

        A checkpoint is treated as legacy when it contains ``transformer.ln_0.weight``.
        In that case the norm entries are renamed in place (``state_dict`` is mutated)
        before delegating to the parent implementation.
        """
        # Remapping from our checkpoints that used a different ordering of layers in the block
        # Previous: Attn / MLP -> Dropout -> Add -> LN
        # Current: Dropout -> Add -> LN -> Attn / MLP
        if "transformer.ln_0.weight" in state_dict:
            n_layers = len(self.transformer.layers)
            # The last block's norm2 becomes the final norm.
            ln_weight = state_dict.pop(f"transformer.layers.{n_layers - 1}.norm2.weight")
            ln_bias = state_dict.pop(f"transformer.layers.{n_layers - 1}.norm2.bias")
            state_dict["transformer.ln_f.weight"] = ln_weight
            state_dict["transformer.ln_f.bias"] = ln_bias
            # Walk backwards so each entry is moved before its slot is overwritten:
            # norm1 of block l -> norm2 of block l, norm2 of block l-1 -> norm1 of block l.
            for l in reversed(range(n_layers)):
                ln_weight = state_dict.pop(f"transformer.layers.{l}.norm1.weight")
                ln_bias = state_dict.pop(f"transformer.layers.{l}.norm1.bias")
                state_dict[f"transformer.layers.{l}.norm2.weight"] = ln_weight
                state_dict[f"transformer.layers.{l}.norm2.bias"] = ln_bias
                if l > 0:
                    ln_weight = state_dict.pop(f"transformer.layers.{l - 1}.norm2.weight")
                    ln_bias = state_dict.pop(f"transformer.layers.{l - 1}.norm2.bias")
                    state_dict[f"transformer.layers.{l}.norm1.weight"] = ln_weight
                    state_dict[f"transformer.layers.{l}.norm1.bias"] = ln_bias
            # The legacy leading norm becomes norm1 of the first block.
            ln_weight = state_dict.pop("transformer.ln_0.weight")
            ln_bias = state_dict.pop("transformer.ln_0.bias")
            state_dict["transformer.layers.0.norm1.weight"] = ln_weight
            state_dict["transformer.layers.0.norm1.bias"] = ln_bias
        return super().load_state_dict(state_dict, strict=strict)

672

Tri Dao's avatar
Tri Dao committed
673
674
675
def shard_state_dict_tp(state_dict, config, world_size, rank):
    """Reduce a full GPT state_dict to the slice owned by one tensor-parallel rank.

    ``state_dict`` is edited in place and is also the return value.

    Args:
        state_dict: parameters of the non-parallel model, keyed by name.
        config: model config; reads ``vocab_size``, ``hidden_size``, ``n_inner``,
            ``n_head``, ``num_hidden_layers``, ``activation_function`` and the optional
            ``pad_vocab_size_multiple`` / ``n_head_kv``.
        world_size: number of tensor-parallel ranks.
        rank: index of the rank whose slice is kept.
    """
    pad_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_multiple) * pad_multiple
    assert vocab_size % world_size == 0
    assert config.hidden_size % world_size == 0
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    assert inner_dim % world_size == 0

    n_head = config.n_head
    n_head_kv = getattr(config, "n_head_kv", n_head)
    head_dim = config.hidden_size // n_head

    def rank_bounds(sizes):
        # Start/end offsets of this rank given every rank's share.
        return sum(sizes[:rank]), sum(sizes[: rank + 1])

    def shard_rows(key):
        # Keep an equal-sized contiguous chunk of the first dim.
        if key not in state_dict:
            return
        full = state_dict[key]
        rows = full.shape[0] // world_size
        state_dict[key] = full[rank * rows : (rank + 1) * rows]

    def shard_cols(key, multiple_of=1):
        # Keep this rank's share of the last dim, sized by get_dim_for_local_rank.
        if key not in state_dict:
            return
        full = state_dict[key]
        sizes = [
            get_dim_for_local_rank(full.size(-1), world_size, local_rank, multiple_of)
            for local_rank in range(world_size)
        ]
        lo, hi = rank_bounds(sizes)
        state_dict[key] = full[..., lo:hi]

    def shard_gated_fc1(key):
        # fc1 of a gated MLP stacks two halves along dim 0; slice each half separately.
        if key not in state_dict:
            return
        full = state_dict[key]
        rows = full.shape[0] // world_size // 2
        halves = rearrange(full, "(two o) ... -> two o ...", two=2)
        state_dict[key] = rearrange(
            halves[:, rank * rows : (rank + 1) * rows], "two o ... -> (two o) ..."
        )

    def shard_qkv(key):
        # Slice the fused QKV tensor by whole heads.
        if key not in state_dict:
            return
        q_heads = [
            get_dim_for_local_rank(n_head, world_size, local_rank)
            for local_rank in range(world_size)
        ]
        kv_heads = [
            get_dim_for_local_rank(n_head_kv, world_size, local_rank)
            for local_rank in range(world_size)
        ]
        q_lo, q_hi = rank_bounds(q_heads)
        kv_lo, kv_hi = rank_bounds(kv_heads)

        if n_head_kv == n_head:
            stacked = rearrange(state_dict[key], "(three d) ... -> three d ...", three=3)
            state_dict[key] = rearrange(
                stacked[:, q_lo * head_dim : q_hi * head_dim],
                "three d ... -> (three d) ...",
            )
            return

        # Fewer KV heads than query heads: rows are laid out as all Q heads, then all
        # K heads, then all V heads.
        by_head = rearrange(
            state_dict[key],
            "(nheadqkv headdim) ... -> nheadqkv headdim ...",
            nheadqkv=n_head + 2 * n_head_kv,
        )
        k_start = n_head
        v_start = n_head + n_head_kv
        pieces = [
            by_head[q_lo:q_hi],
            by_head[k_start + kv_lo : k_start + kv_hi],
            by_head[v_start + kv_lo : v_start + kv_hi],
        ]
        state_dict[key] = rearrange(
            torch.cat(pieces, dim=0), "nheadqkv headdim ... -> (nheadqkv headdim) ..."
        )

    shard_rows("transformer.embeddings.word_embeddings.weight")
    shard_rows("lm_head.weight")
    shard_cols("transformer.embeddings.position_embeddings.weight")
    for i in range(config.num_hidden_layers):
        shard_qkv(f"transformer.layers.{i}.mixer.Wqkv.weight")
        shard_qkv(f"transformer.layers.{i}.mixer.Wqkv.bias")
        shard_cols(f"transformer.layers.{i}.mixer.out_proj.weight", multiple_of=head_dim)
        # Output-side biases are kept on rank 0 only.
        if rank != 0:
            state_dict.pop(f"transformer.layers.{i}.mixer.out_proj.bias", None)
        is_gated = config.activation_function in ["glu", "swiglu", "geglu"]
        shard_fc1 = shard_gated_fc1 if is_gated else shard_rows
        shard_fc1(f"transformer.layers.{i}.mlp.fc1.weight")
        shard_fc1(f"transformer.layers.{i}.mlp.fc1.bias")
        shard_cols(f"transformer.layers.{i}.mlp.fc2.weight")
        if rank != 0:
            state_dict.pop(f"transformer.layers.{i}.mlp.fc2.bias", None)
    return state_dict


789
790
791
def combine_state_dicts_tp(state_dicts: list[dict[str, torch.Tensor]], config: GPT2Config):
    """Convert the list of sharded state_dict of a GPT model with tensor parallel to
    the state_dict of a standard GPT model.

    This function is meant to be the "reverse" of shard_state_dict_tp.

    Precondition:
        - state_dicts should be ordered in the same way as the shards were created.

    Args:
        state_dicts: one state_dict per tensor-parallel rank, in rank order.
        config: model config; reads ``vocab_size``, ``hidden_size``, ``n_inner``,
            ``n_head``, ``num_hidden_layers``, ``activation_function`` and the optional
            ``pad_vocab_size_multiple`` / ``n_head_kv``.

    Returns:
        A new dict (shallow copy of ``state_dicts[0]``) whose sharded entries are
        replaced by the merged tensors. Entries that are not merged here (e.g. the
        ``out_proj`` / ``fc2`` biases) keep the value from the first shard. The input
        dicts are not modified.
    """
    world_size = len(state_dicts)
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    assert vocab_size % world_size == 0
    assert config.hidden_size % world_size == 0
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    assert inner_dim % world_size == 0

    # Sometimes the word embeddings are sharded on the 0th dim, sometimes on the 1st dim.
    # A shard whose first dim is vocab_size // world_size is treated as split along the
    # vocabulary dim; anything else is concatenated along dim 1.
    def combine_word_embeddings(state_dicts, state_dict, key):
        dim = 0 if state_dicts[0][key].shape[0] == vocab_size // world_size else 1
        state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)

    def combine_dim(state_dicts, state_dict, key, dim=-1):
        # Plain concatenation of the shards along ``dim``; a missing key is skipped.
        if key in state_dict:
            state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)

    def combine_qkv_headdim(state_dicts, state_dict, key):
        # Inverse of the per-head QKV slicing: regroup all Q, then all K, then all V.
        n_head = config.n_head
        n_head_kv = getattr(config, "n_head_kv", n_head)
        if key in state_dict:
            if n_head_kv == n_head:
                xs = [
                    rearrange(s[key], "(three d) ... -> three d ...", three=3) for s in state_dicts
                ]
                state_dict[key] = rearrange(torch.cat(xs, dim=1), "three d ... -> (three d) ...")
            else:
                xs = [
                    rearrange(
                        s[key],
                        "(nheadqkv headdim) ... -> nheadqkv headdim ...",
                        nheadqkv=n_head + 2 * n_head_kv,
                    )
                    for s in state_dicts
                ]
                n_head_each_rank = [
                    get_dim_for_local_rank(n_head, world_size, local_rank)
                    for local_rank in range(world_size)
                ]
                n_head_kv_each_rank = [
                    get_dim_for_local_rank(n_head_kv, world_size, local_rank)
                    for local_rank in range(world_size)
                ]
                # Within each shard the rows are [Q heads | K heads | V heads].
                q_parts = [x[: n_head_each_rank[r]] for r, x in enumerate(xs)]
                k_parts = [
                    x[n_head_each_rank[r] : n_head_each_rank[r] + n_head_kv_each_rank[r]]
                    for r, x in enumerate(xs)
                ]
                v_parts = [
                    x[n_head_each_rank[r] + n_head_kv_each_rank[r] :] for r, x in enumerate(xs)
                ]
                state_dict[key] = rearrange(
                    torch.cat(
                        [
                            torch.cat(q_parts, dim=0),
                            torch.cat(k_parts, dim=0),
                            torch.cat(v_parts, dim=0),
                        ],
                        dim=0,
                    ),
                    "nheadqkv headdim ... -> (nheadqkv headdim) ...",
                )

    def combine_gated_mlp(state_dicts, state_dict, key):
        # Each shard stacks its two halves along dim 0. Un-stack to (2, d, ...),
        # join the shards along d, then re-stack so the result is
        # [all first halves | all second halves].
        if key in state_dict:
            xs = [s[key].reshape(2, s[key].shape[0] // 2, *s[key].shape[1:]) for s in state_dicts]
            state_dict[key] = torch.cat(xs, dim=1).flatten(0, 1)

    state_dict = state_dicts[0].copy()  # don't modify state_dict[0] inplace
    combine_word_embeddings(
        state_dicts, state_dict, "transformer.embeddings.word_embeddings.weight"
    )
    if "lm_head.weight" in state_dict:
        combine_word_embeddings(state_dicts, state_dict, "lm_head.weight")
    if "transformer.embeddings.position_embeddings.weight" in state_dict:
        combine_dim(
            state_dicts, state_dict, "transformer.embeddings.position_embeddings.weight", -1
        )
    mlp_combine_fn = (
        combine_gated_mlp
        if config.activation_function in ["glu", "swiglu", "geglu"]
        else partial(combine_dim, dim=0)
    )
    for i in range(config.num_hidden_layers):
        combine_qkv_headdim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.Wqkv.weight")
        combine_qkv_headdim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.Wqkv.bias")
        combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.out_proj.weight", -1)
        mlp_combine_fn(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc1.weight")
        # The fc1 bias is sharded with the same layout as the fc1 weight, so it must be
        # merged the same way; a plain dim-0 concat would interleave the two halves of
        # a gated MLP.
        mlp_combine_fn(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc1.bias")
        combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc2.weight", -1)
    return state_dict


def remap_state_dict_hf_gpt2(state_dict, config):
    """Rename a Hugging Face GPT-2 state_dict to the parameter names used in this file.

    Besides renaming, the word-embedding matrix is padded along the vocabulary dim to
    a multiple of ``config.pad_vocab_size_multiple`` and reused as ``lm_head.weight``,
    and the MLP / attention projection weights are transposed.

    Args:
        state_dict: mapping with keys such as ``wte.weight``, ``wpe.weight``,
            ``h.<i>.attn.c_attn.weight``. It is not modified.
        config: reads ``vocab_size``, ``num_hidden_layers`` and the optional
            ``pad_vocab_size_multiple``.

    Returns:
        A new ``OrderedDict`` with the remapped entries.

    Raises:
        KeyError: if a required weight (``wte.weight`` or a per-layer projection
            weight) is missing.
    """

    def rename_keys(mapping, rules):
        # Apply every (pattern, replacement) rule, in order, to each key.
        renamed = OrderedDict()
        for key, value in mapping.items():
            for pattern, replacement in rules:
                key = re.sub(pattern, replacement, key)
            renamed[key] = value
        return renamed

    # Word embedding and position embedding
    state_dict = rename_keys(
        state_dict, [(r"^wpe\.", "transformer.embeddings.position_embeddings.")]
    )
    word_embeddings = state_dict.pop("wte.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    # The output head reuses the (padded) embedding tensor.
    state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]

    # LayerNorm
    state_dict = rename_keys(
        state_dict,
        [
            (r"^ln_f\.(weight|bias)", r"transformer.ln_f.\1"),
            (r"^h\.(\d+)\.ln_(1|2)\.(weight|bias)", r"transformer.layers.\1.norm\2.\3"),
        ],
    )

    # MLP: the source weights are stored transposed relative to the target layout.
    for d in range(config.num_hidden_layers):
        W1 = state_dict.pop(f"h.{d}.mlp.c_fc.weight")
        state_dict[f"transformer.layers.{d}.mlp.fc1.weight"] = W1.t()
        W2 = state_dict.pop(f"h.{d}.mlp.c_proj.weight")
        state_dict[f"transformer.layers.{d}.mlp.fc2.weight"] = W2.t()
    state_dict = rename_keys(
        state_dict,
        [
            (r"^h\.(\d+)\.mlp\.c_fc\.bias", r"transformer.layers.\1.mlp.fc1.bias"),
            (r"^h\.(\d+)\.mlp\.c_proj\.bias", r"transformer.layers.\1.mlp.fc2.bias"),
        ],
    )

    # Attention
    for d in range(config.num_hidden_layers):
        # We don't store this bias. Not every checkpoint serializes it, so a missing
        # entry is not an error.
        state_dict.pop(f"h.{d}.attn.bias", None)
        Wqkv = state_dict.pop(f"h.{d}.attn.c_attn.weight")
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = Wqkv.t()
        Wout = state_dict.pop(f"h.{d}.attn.c_proj.weight")
        state_dict[f"transformer.layers.{d}.mixer.out_proj.weight"] = Wout.t()
    state_dict = rename_keys(
        state_dict,
        [
            (r"^h\.(\d+)\.attn\.c_attn\.bias", r"transformer.layers.\1.mixer.Wqkv.bias"),
            (r"^h\.(\d+)\.attn\.c_proj\.bias", r"transformer.layers.\1.mixer.out_proj.bias"),
        ],
    )

    return state_dict
957
958


Tri Dao's avatar
Tri Dao committed
959
960
def remap_state_dict_megatron(state_dict, config):
    """Rename a Megatron-style state_dict to the parameter names used in this file.

    The word-embedding matrix is padded along its first dim to a multiple of
    ``config.pad_vocab_size_multiple`` (computed from the checkpoint's own row count)
    and reused as ``lm_head.weight``. The fused QKV weight and bias of every layer are
    reordered from a per-head ``(nheads, 3, headdim)`` row layout to
    ``(3, nheads, headdim)``.

    Args:
        state_dict: source mapping; it is not modified.
        config: reads ``hidden_size``, ``num_attention_heads``, ``num_hidden_layers``
            and the optional ``pad_vocab_size_multiple``.

    Returns:
        A new ``OrderedDict`` with the remapped entries.
    """

    def rename_keys(mapping, rules):
        # Run each key through the (pattern, replacement) rules in order.
        result = OrderedDict()
        for old_key, tensor in mapping.items():
            new_key = old_key
            for pattern, replacement in rules:
                new_key = re.sub(pattern, replacement, new_key)
            result[new_key] = tensor
        return result

    remapped = rename_keys(
        state_dict,
        [
            (r"^language_model.encoder.", "transformer."),
            (r"^language_model.", "transformer."),
        ],
    )

    # Word embedding and position embedding
    remapped = rename_keys(
        remapped, [(r"^wpe.", "transformer.embeddings.position_embeddings.")]
    )
    word_embeddings = remapped.pop("transformer.embedding.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    n_rows = word_embeddings.shape[0]
    vocab_size = math.ceil(n_rows / pad_multiple) * pad_multiple
    padded = F.pad(word_embeddings, (0, 0, 0, vocab_size - n_rows))
    remapped["transformer.embeddings.word_embeddings.weight"] = padded
    remapped["lm_head.weight"] = remapped["transformer.embeddings.word_embeddings.weight"]

    # LayerNorm
    remapped = rename_keys(
        remapped,
        [
            (r"^transformer.final_layernorm.(weight|bias)", r"transformer.ln_f.\1"),
            (
                r"^transformer.layers.(\d+).input_layernorm.(weight|bias)",
                r"transformer.layers.\1.norm1.\2",
            ),
            (
                r"^transformer.layers.(\d+).post_attention_layernorm.(weight|bias)",
                r"transformer.layers.\1.norm2.\2",
            ),
        ],
    )

    # MLP
    remapped = rename_keys(
        remapped,
        [
            (
                r"^transformer.layers.(\d+).mlp.dense_h_to_4h.(weight|bias)",
                r"transformer.layers.\1.mlp.fc1.\2",
            ),
            (
                r"^transformer.layers.(\d+).mlp.dense_4h_to_h.(weight|bias)",
                r"transformer.layers.\1.mlp.fc2.\2",
            ),
        ],
    )

    # Attention
    remapped = rename_keys(
        remapped,
        [
            (
                r"^transformer.layers.(\d+).self_attention.rotary_emb.inv_freq",
                r"transformer.layers.\1.mixer.rotary_emb.inv_freq",
            ),
            (
                r"^transformer.layers.(\d+).self_attention.query_key_value.(weight|bias)",
                r"transformer.layers.\1.mixer.Wqkv.\2",
            ),
            (
                r"^transformer.layers.(\d+).self_attention.dense.(weight|bias)",
                r"transformer.layers.\1.mixer.out_proj.\2",
            ),
        ],
    )

    # Megatron stores Wqkv as ((nheads 3 headdim), hidden_dim)
    # while we store Wqkv as ((3 nheads headdim), hidden_dim)
    headdim = config.hidden_size // config.num_attention_heads
    for layer_idx in range(config.num_hidden_layers):
        weight_key = f"transformer.layers.{layer_idx}.mixer.Wqkv.weight"
        qkv_weight = remapped.pop(weight_key)
        remapped[weight_key] = rearrange(
            qkv_weight,
            "(nheads three headdim) ... -> (three nheads headdim) ...",
            three=3,
            headdim=headdim,
        )
        bias_key = f"transformer.layers.{layer_idx}.mixer.Wqkv.bias"
        qkv_bias = remapped.pop(bias_key)
        remapped[bias_key] = rearrange(
            qkv_bias, "(nheads three headdim) -> (three nheads headdim)", three=3, headdim=headdim
        )

    return remapped