# Copyright (c) 2023, Tri Dao.

import logging
import math
import re
from collections import OrderedDict, namedtuple
from collections.abc import Sequence
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import GPT2Config

from flash_attn.models.falcon import remap_state_dict_hf_falcon
from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox
from flash_attn.models.gptj import remap_state_dict_hf_gptj
from flash_attn.models.opt import remap_state_dict_hf_opt
from flash_attn.modules.block import Block, ParallelBlock
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
from flash_attn.modules.mha import MHA, ParallelMHA
from flash_attn.modules.mlp import (FusedMLP, GatedMlp, Mlp, ParallelFusedMLP,
                                    ParallelGatedMlp, ParallelMLP)
from flash_attn.ops.activations import sqrelu_fwd
from flash_attn.utils.distributed import (all_gather_raw,
                                          get_dim_for_local_rank,
                                          sync_shared_params)
from flash_attn.utils.generation import GenerationMixin
from flash_attn.utils.pretrained import state_dict_from_pretrained

try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear
except ImportError:
    ColumnParallelLinear = None

try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm
except ImportError:
    dropout_add_layer_norm = None

try:
    from flash_attn.ops.layer_norm import \
        dropout_add_layer_norm_parallel_residual
except ImportError:
    dropout_add_layer_norm_parallel_residual = None

try:
    from flash_attn.ops.rms_norm import RMSNorm, dropout_add_rms_norm
except ImportError:
    RMSNorm, dropout_add_rms_norm = None, None

try:
    from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual
except ImportError:
    dropout_add_rms_norm_parallel_residual = None

try:
    from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
except ImportError:
    FusedDenseSqreluDense = None

logger = logging.getLogger(__name__)


def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
    factory_kwargs = {"device": device, "dtype": dtype}
    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
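    # Attention scaling mirrors GPT-2: 1 / sqrt(head_dim) when scale_attn_weights is set, and it is
    # further divided by (layer_idx + 1) when scale_attn_by_inverse_layer_idx is set.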
    softmax_scale = 1.0 if not config.scale_attn_weights else head_dim ** (-0.5)
    if config.scale_attn_by_inverse_layer_idx:
        assert layer_idx is not None
        softmax_scale /= float(layer_idx + 1)
    dwconv = getattr(config, "attn_dwconv", False)
    if dwconv:
        assert process_group is None, "TensorParallel MHA does not support dwconv yet"
    qkv_proj_bias = getattr(config, "qkv_proj_bias", True)
    out_proj_bias = getattr(config, "out_proj_bias", True)
    rotary_emb_dim = int(getattr(config, "rotary_emb_fraction", 0.0) * head_dim)
    rotary_emb_base = getattr(config, "rotary_emb_base", 10000.0)
    rotary_emb_scale_base = getattr(config, "rotary_emb_scale_base", None)
    rotary_emb_interleaved = getattr(config, "rotary_emb_interleaved", False)
    use_flash_attn = getattr(config, "use_flash_attn", False)
    fused_bias_fc = getattr(config, "fused_bias_fc", False)
    if not fused_bias_fc:
        assert process_group is None, "TensorParallel MHA requires fused_bias_fc"
    mha_cls = MHA if process_group is None else ParallelMHA
    serial_kwargs = (
        {"fused_bias_fc": fused_bias_fc, "dwconv": dwconv} if process_group is None else {}
    )
    parallel_kwargs = (
        {
            "process_group": process_group,
            "sequence_parallel": getattr(config, "sequence_parallel", True),
        }
        if process_group is not None
        else {}
    )
    num_heads_kv = getattr(config, "n_head_kv", None)
    mixer_cls = partial(
        mha_cls,
        num_heads=config.num_attention_heads,
        num_heads_kv=num_heads_kv,
        qkv_proj_bias=qkv_proj_bias,
        out_proj_bias=out_proj_bias,
        dropout=config.attn_pdrop,
        softmax_scale=softmax_scale,
        causal=True,
        layer_idx=layer_idx,
        rotary_emb_dim=rotary_emb_dim,
        rotary_emb_base=rotary_emb_base,
        rotary_emb_scale_base=rotary_emb_scale_base,
        rotary_emb_interleaved=rotary_emb_interleaved,
        use_flash_attn=use_flash_attn,
        **serial_kwargs,
        **parallel_kwargs,
        **factory_kwargs,
    )
    return mixer_cls


def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
    factory_kwargs = {"device": device, "dtype": dtype}
    mlp_fc1_bias = getattr(config, "mlp_fc1_bias", True)
    mlp_fc2_bias = getattr(config, "mlp_fc2_bias", True)
    fused_mlp = getattr(config, "fused_mlp", False)
    if fused_mlp:
        assert config.activation_function in [
            "gelu_new",
            "gelu_fast",
            "gelu_approx",
            "relu",
            "sqrelu",
        ]
    fused_dense_sqrelu_dense = getattr(config, "fused_dense_sqrelu_dense", False)
    if fused_dense_sqrelu_dense:
        assert config.activation_function == "sqrelu", (
            "fused_dense_sqrelu_dense only supports approximate activation_function sqrelu"
        )
    assert not (fused_dense_sqrelu_dense and fused_mlp)
    if not fused_mlp and not fused_dense_sqrelu_dense:
        assert config.activation_function in [
            "gelu",
            "gelu_new",
            "gelu_fast",
            "gelu_approx",
            "relu",
            "sqrelu",
            "glu",
            "swiglu",
            "geglu",
        ]
        if config.activation_function in ["glu", "swiglu", "geglu"]:
            activation = (
                F.sigmoid
                if config.activation_function == "glu"
                else (F.silu if config.activation_function == "swiglu" else F.gelu)
            )
            mlp_cls = GatedMlp if process_group is None else ParallelGatedMlp
            parallel_kwargs = (
                {
                    "process_group": process_group,
                    "sequence_parallel": getattr(config, "sequence_parallel", True),
                }
                if process_group is not None
                else {}
            )
            mlp_cls = partial(
                mlp_cls,
                hidden_features=config.n_inner,
                activation=activation,
                bias1=mlp_fc1_bias,
                bias2=mlp_fc2_bias,
                **parallel_kwargs,
                **factory_kwargs,
            )
        else:
            if config.activation_function == "relu":
                activation = partial(F.relu, inplace=True)
            elif config.activation_function == "sqrelu":
                activation = sqrelu_fwd
            else:
                approximate = (
                    "tanh"
                    if config.activation_function in ["gelu_new", "gelu_fast", "gelu_approx"]
                    else "none"
                )
                activation = partial(F.gelu, approximate=approximate)
            mlp_cls = Mlp if process_group is None else ParallelMLP
            parallel_kwargs = (
                {
                    "process_group": process_group,
                    "sequence_parallel": getattr(config, "sequence_parallel", True),
                }
                if process_group is not None
                else {}
            )
            mlp_cls = partial(
                mlp_cls,
                hidden_features=config.n_inner,
                activation=activation,
                bias1=mlp_fc1_bias,
                bias2=mlp_fc2_bias,
                **parallel_kwargs,
                **factory_kwargs,
            )
    else:
        mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
        # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
        if isinstance(mlp_checkpoint_lvl, Sequence):
            assert layer_idx is not None
            mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
        if fused_mlp:
            if FusedMLP is None:
                raise ImportError("fused_dense is not installed")
            activation = (
                "gelu_approx"
                if config.activation_function in ["gelu_new", "gelu_fast", "gelu_approx"]
                else config.activation_function
            )
            mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP
            parallel_kwargs = (
                {
                    "process_group": process_group,
                    "sequence_parallel": getattr(config, "sequence_parallel", True),
                }
                if process_group is not None
                else {}
            )
            mlp_cls = partial(
                mlp_cls,
                hidden_features=config.n_inner,
                activation=activation,
                checkpoint_lvl=mlp_checkpoint_lvl,
                bias1=mlp_fc1_bias,
                bias2=mlp_fc2_bias,
                **parallel_kwargs,
                **factory_kwargs,
            )
        elif fused_dense_sqrelu_dense:
            assert (
                process_group is None
            ), "Tensor Parallel is not implemented for FusedDenseSqreluDense"
            assert FusedDenseSqreluDense is not None
            mlp_cls = partial(
                FusedDenseSqreluDense,
                hidden_features=config.n_inner,
                checkpoint_lvl=mlp_checkpoint_lvl,
                **factory_kwargs,
            )
        else:
            raise RuntimeError("MLP type not supported")
    return mlp_cls


def create_block(config, layer_idx=None, process_group=None, device=None, dtype=None):
    factory_kwargs = {"device": device, "dtype": dtype}
    sequence_parallel = getattr(config, "sequence_parallel", True)
    mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
    mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
    use_rms_norm = getattr(config, "rms_norm", False)
    norm_cls = partial(
        nn.LayerNorm if not use_rms_norm else RMSNorm,
        eps=config.layer_norm_epsilon,
        **factory_kwargs,
    )
    # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
    residual_in_fp32 = getattr(config, "residual_in_fp32", False)
    resid_dropout1 = config.resid_pdrop if layer_idx is None or layer_idx > 0 else config.embd_pdrop
    prenorm = getattr(config, "prenorm", True)
    parallel_block = getattr(config, "parallel_block", False)
    if not parallel_block:
        block = Block(
            config.hidden_size,
            mixer_cls,
            mlp_cls,
            norm_cls=norm_cls,
            prenorm=prenorm,
            resid_dropout1=resid_dropout1,
            resid_dropout2=config.resid_pdrop,
            fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
            residual_in_fp32=residual_in_fp32,
            sequence_parallel=sequence_parallel and process_group is not None,
            mark_shared_params=process_group is not None,
        )
    else:
        assert prenorm
        block = ParallelBlock(
            config.hidden_size,
            mixer_cls,
            mlp_cls,
            norm_cls=norm_cls,
            resid_dropout1=resid_dropout1,
            resid_dropout2=config.resid_pdrop,
            tied_norm=getattr(config, "parallel_block_tied_norm", False),
            fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
            residual_in_fp32=residual_in_fp32,
            sequence_parallel=sequence_parallel and process_group is not None,
            mark_shared_params=process_group is not None,
        )
    block.layer_idx = layer_idx
    return block


class GPTPreTrainedModel(nn.Module):
    """An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    def __init__(self, config, *inputs, **kwargs):
        super().__init__()
        if not isinstance(config, GPT2Config):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
                "To create a model from a pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                )
            )
        self.config = config

    @classmethod
    def from_pretrained(
        cls,
        model_name,
        config,
        *args,
        strict=True,
        device=None,
        dtype=None,
        world_size=1,
        rank=0,
        **kwargs,
    ):
        """
        Instantiate a GPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.
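
        Example (illustrative sketch; assumes a compatible checkpoint/config pair)::

            config = GPT2Config.from_pretrained("gpt2")
            model = GPTLMHeadModel.from_pretrained(
                "gpt2", config, device="cuda", dtype=torch.float16
            )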
        """
        # Instantiate model.
        model = cls(config, *args, device=device, dtype=dtype, **kwargs)
        # Load state_dict in cpu because we already initialized the model in GPU, and we don't
        # want extra stuff taking up more GPU memory
        state_dict = state_dict_from_pretrained(model_name, device="cpu", dtype=dtype)
        if model_name.startswith("gpt2"):
            state_dict = remap_state_dict_hf_gpt2(state_dict, config)
        elif model_name.startswith("facebook/opt"):
            state_dict = remap_state_dict_hf_opt(state_dict, config)
        elif model_name.startswith("EleutherAI/gpt-j-"):
            state_dict = remap_state_dict_hf_gptj(state_dict, config)
        elif model_name.startswith("EleutherAI/gpt-neox-"):
            state_dict = remap_state_dict_hf_gpt_neox(state_dict, config)
        elif model_name.startswith("tiiuae/falcon-"):
            state_dict = remap_state_dict_hf_falcon(state_dict, config)
        else:
            raise NotImplementedError(f"Model {model_name} not supported")
        if world_size > 1:
            state_dict = shard_state_dict_tp(state_dict, config, world_size, rank)
        load_return = model.load_state_dict(state_dict, strict=strict)
        logger.info(load_return)
        return model


# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True):
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, std=initializer_range)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if rescale_prenorm_residual:
        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer))


class GPTModel(GPTPreTrainedModel):
    def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
        super().__init__(config)
        factory_kwargs = {"device": device, "dtype": dtype}
        self.process_group = process_group
        self.sequence_parallel = getattr(config, "sequence_parallel", True)
        assert config.activation_function in [
            "gelu",
            "gelu_new",
            "gelu_fast",
            "gelu_approx",
            "relu",
            "sqrelu",
            "glu",
            "swiglu",
            "geglu",
        ]
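        # Pad the vocab size up to a multiple of pad_vocab_size_multiple (e.g. 8); padded
        # embedding / lm_head dimensions tend to be friendlier to tensor cores and TP sharding.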
        pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
        vocab_size = (
            math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
        )
        # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
        self.residual_in_fp32 = getattr(config, "residual_in_fp32", False)
        # These 2 options are for OPT-350m
        self.prenorm = getattr(config, "prenorm", True)
        use_rms_norm = getattr(config, "rms_norm", False)
        word_embed_proj_dim = getattr(config, "word_embed_proj_dim", None)
        # For GPT-J, GPT-NeoX
        self.parallel_block = getattr(config, "parallel_block", False)

        if process_group is None:
            self.embeddings = GPT2Embeddings(
                config.hidden_size,
                vocab_size,
                config.max_position_embeddings,
                word_embed_proj_dim=word_embed_proj_dim,
                **factory_kwargs,
            )
        else:
            self.embeddings = ParallelGPT2Embeddings(
                config.hidden_size,
                vocab_size,
                config.max_position_embeddings,
                process_group=process_group,
                sequence_parallel=self.sequence_parallel,
                **factory_kwargs,
            )

        # We change the order of dropout, residual and layer norm:
        # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
        # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and
        # the main branch (output of MLP). The model definition is unchanged, but the mapping of the
        # nn.Dropout probabilities is changed.
        # This is for performance reasons: we can fuse dropout + add + layer_norm.
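        # Illustrative sketch of the reordering (comments only, nothing is executed here):
        #   standard prenorm step:  x = x + Dropout(Mixer(LN(x)))
        #   here (per Block):       residual = Dropout(prev_output) + residual
        #                           hidden_states = Mixer(LN(residual))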
        self.layers = nn.ModuleList(
            [
                create_block(config, layer_idx=i, process_group=process_group, **factory_kwargs)
                for i in range(config.num_hidden_layers)
            ]
        )

        self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
        if self.fused_dropout_add_ln:
            if (not self.parallel_block and dropout_add_layer_norm is None) or (
                self.parallel_block and dropout_add_layer_norm_parallel_residual is None
            ):
                raise ImportError("dropout_layer_norm is not installed")
        if self.prenorm:
            self.drop_f = nn.Dropout(config.resid_pdrop)
            norm_cls = nn.LayerNorm if not use_rms_norm else RMSNorm
            self.ln_f = norm_cls(
                config.hidden_size, eps=config.layer_norm_epsilon, **factory_kwargs
            )
        if process_group is not None:
            for p in self.ln_f.parameters():
                # Mark the norm parameters as "shared_params" so that we sync their values at init.
                p._shared_params = True
                # Mark the norm params as "sequence_parallel" so we run all-reduce on their grads.
                if self.sequence_parallel:
                    p._sequence_parallel = True

        self.apply(
            partial(
                _init_weights,
                n_layer=config.num_hidden_layers,
                initializer_range=config.initializer_range,
            )
        )
        self.tie_weights()

    def tie_weights(self):
        if self.process_group is not None:
            sync_shared_params(self, self.process_group)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return {
            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
            for i, layer in enumerate(self.layers)
        }

    def forward(self, input_ids, position_ids=None, inference_params=None):
        # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen
        # dimensions so that we can split on it easily, in case of small batch size.
        # Only the attention layers need to know the seqlen.
        embedding_kwargs = (
            {"combine_batch_seqlen_dim": True}
            if self.process_group is not None and self.sequence_parallel
            else {}
        )
        hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs)
        if self.parallel_block:
            hidden_states2 = None
        residual = None
        mixer_kwargs = (
            {"seqlen": input_ids.shape[1]}
            if self.process_group is not None and self.sequence_parallel
            else {}
        )
        if inference_params is not None:
            mixer_kwargs["inference_params"] = inference_params
        for layer in self.layers:
            if self.prenorm:
                if not self.parallel_block:
                    hidden_states, residual = layer(
                        hidden_states, residual, mixer_kwargs=mixer_kwargs
                    )
                else:
                    hidden_states, hidden_states2, residual = layer(
                        hidden_states, hidden_states2, residual, mixer_kwargs=mixer_kwargs
                    )
            else:
                hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
        if self.prenorm:
            if not self.fused_dropout_add_ln:
                dropped = self.drop_f(hidden_states)
                if not self.parallel_block:
                    residual = (dropped + residual) if residual is not None else dropped
                else:
                    dropped2 = self.drop_f(hidden_states2)
                    residual = (
                        (residual + dropped + dropped2)
                        if residual is not None
                        else dropped + dropped2
                    )
                hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
            else:
                # Set prenorm=False here since we don't need the residual
                if not self.parallel_block:
                    fused_add_norm_fn = (
                        dropout_add_rms_norm
                        if isinstance(self.ln_f, RMSNorm)
                        else dropout_add_layer_norm
                    )
                    hidden_states = fused_add_norm_fn(
                        hidden_states,
                        residual,
                        self.ln_f.weight,
                        self.ln_f.bias,
                        self.drop_f.p if self.training else 0.0,
                        self.ln_f.eps,
                        prenorm=False,
                        residual_in_fp32=self.residual_in_fp32,
                    )
                else:
                    fused_add_norm_fn = (
                        dropout_add_rms_norm_parallel_residual
                        if isinstance(self.ln_f, RMSNorm)
                        else dropout_add_layer_norm_parallel_residual
                    )
                    hidden_states, _ = fused_add_norm_fn(
                        hidden_states,
                        hidden_states2,
                        residual,
                        self.ln_f.weight,
                        self.ln_f.bias,
                        None,
                        None,
                        self.drop_f.p if self.training else 0.0,
                        self.ln_f.eps,
                        prenorm=False,
                        residual_in_fp32=self.residual_in_fp32,
                    )
        return hidden_states


class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
    def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(config)
        self.process_group = process_group
        self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs)
        self.tie_word_embeddings = getattr(config, "tie_word_embeddings", True)
        lm_head_bias = getattr(config, "lm_head_bias", False)
        pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
        vocab_size = (
            math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
        )
        # This option is for OPT-350m
        word_embed_proj_dim = getattr(config, "word_embed_proj_dim", None)
        embed_dim = config.n_embd if word_embed_proj_dim is None else word_embed_proj_dim
        if word_embed_proj_dim is not None:
            self.project_out = nn.Linear(config.n_embd, embed_dim, bias=False, **factory_kwargs)
        else:
            self.project_out = None
        if process_group is None:
            self.lm_head = nn.Linear(embed_dim, vocab_size, bias=lm_head_bias, **factory_kwargs)
        else:
            if ColumnParallelLinear is None:
                raise ImportError("fused_dense_lib is not installed")
            self.lm_head = ColumnParallelLinear(
                embed_dim,
                vocab_size,
                process_group,
                bias=lm_head_bias,
                sequence_parallel=getattr(config, "sequence_parallel", True),
                **factory_kwargs,
            )
        # Initialize weights and apply final processing
        self.apply(
            partial(
                _init_weights,
                n_layer=config.num_hidden_layers,
                initializer_range=config.initializer_range,
            )
        )
        self.tie_weights()

    def tie_weights(self):
        if self.tie_word_embeddings:
            self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight
        if self.process_group is not None:
            sync_shared_params(self, self.process_group)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.transformer.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

    def forward(self, input_ids, position_ids=None, inference_params=None, last_token_only=False):
        """
        inference_params: for generation. Adapted from Megatron-LM (and Apex)
        https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        last_token_only: whether to return the logit for the last token only,
            of shape (batch_size, vocab_size)
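
        Example (illustrative sketch)::

            logits = model(input_ids).logits  # (batch_size, seqlen, vocab_size)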
        """
        hidden_states = self.transformer(
            input_ids, position_ids=position_ids, inference_params=inference_params
        )
        if last_token_only:
            hidden_states = hidden_states[:, -1]
        if self.project_out is not None:
            hidden_states = self.project_out(hidden_states)
        lm_logits = self.lm_head(hidden_states)
        # During inference, we want the full logit for sampling
        if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
            lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
            lm_logits = rearrange(lm_logits, "(n b) ... d -> b ... (n d)", b=hidden_states.shape[0])
        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
        return CausalLMOutput(logits=lm_logits)

    def load_state_dict(self, state_dict, strict=True):
        # Remapping from our checkpoints that used a different ordering of layers in the block
        # Previous: Attn / MLP -> Dropout -> Add -> LN
        # Current: Dropout -> Add -> LN -> Attn / MLP
        if "transformer.ln_0.weight" in state_dict:
            n_layers = len(self.transformer.layers)
            ln_weight = state_dict.pop(f"transformer.layers.{n_layers - 1}.norm2.weight")
            ln_bias = state_dict.pop(f"transformer.layers.{n_layers - 1}.norm2.bias")
            state_dict["transformer.ln_f.weight"] = ln_weight
            state_dict["transformer.ln_f.bias"] = ln_bias
            for l in reversed(range(n_layers)):
                ln_weight = state_dict.pop(f"transformer.layers.{l}.norm1.weight")
                ln_bias = state_dict.pop(f"transformer.layers.{l}.norm1.bias")
                state_dict[f"transformer.layers.{l}.norm2.weight"] = ln_weight
                state_dict[f"transformer.layers.{l}.norm2.bias"] = ln_bias
                if l > 0:
                    ln_weight = state_dict.pop(f"transformer.layers.{l - 1}.norm2.weight")
                    ln_bias = state_dict.pop(f"transformer.layers.{l - 1}.norm2.bias")
                    state_dict[f"transformer.layers.{l}.norm1.weight"] = ln_weight
                    state_dict[f"transformer.layers.{l}.norm1.bias"] = ln_bias
            ln_weight = state_dict.pop("transformer.ln_0.weight")
            ln_bias = state_dict.pop("transformer.ln_0.bias")
            state_dict["transformer.layers.0.norm1.weight"] = ln_weight
            state_dict["transformer.layers.0.norm1.bias"] = ln_bias
        return super().load_state_dict(state_dict, strict=strict)


def shard_state_dict_tp(state_dict, config, world_size, rank):
    """Convert the state_dict of a standard GPT model to the state_dict of a GPT model
    with tensor parallel.

    This function modifies state_dict in place.
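
    Example (illustrative sketch; each dict is shallow-copied since sharding is in place)::

        shards = [
            shard_state_dict_tp(dict(full_state_dict), config, world_size=2, rank=r)
            for r in range(2)
        ]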
    """
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    assert vocab_size % world_size == 0
    assert config.hidden_size % world_size == 0
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    assert inner_dim % world_size == 0

    n_head = config.n_head
    n_head_kv = getattr(config, "n_head_kv", n_head)

    embed_dim = config.hidden_size
    head_dim = embed_dim // n_head

    def shard_first_dim(state_dict, key):
        if key in state_dict:
            x = state_dict[key]
            dim = x.shape[0] // world_size
            state_dict[key] = x[rank * dim : (rank + 1) * dim]

    def shard_last_dim(state_dict, key, multiple_of=1):
        if key in state_dict:
            x = state_dict[key]
            dim_each_rank = [
                get_dim_for_local_rank(x.size(-1), world_size, local_rank, multiple_of)
                for local_rank in range(world_size)
            ]
            beg, end = tuple(sum(dim_each_rank[:pos]) for pos in (rank, rank + 1))
            state_dict[key] = x[..., beg:end]

    def shard_gatedmlp_fc1_dim(state_dict, key):
        if key in state_dict:
            x = state_dict[key]
            dim = x.shape[0] // world_size // 2
            state_dict[key] = rearrange(
                rearrange(x, "(two o) ... -> two o ...", two=2)[:, rank * dim : (rank + 1) * dim],
                "two o ... -> (two o) ...",
            )

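    # Wqkv is stored as ((3 * n_head * head_dim), ...) when n_head_kv == n_head, and as
    # (((n_head + 2 * n_head_kv) * head_dim), ...) for MQA/GQA, so we shard along the head dimension.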
    def shard_qkv_headdim(state_dict, key):
        if key in state_dict:
            n_head_each_rank = [
                get_dim_for_local_rank(n_head, world_size, local_rank)
                for local_rank in range(world_size)
            ]
            n_head_kv_each_rank = [
                get_dim_for_local_rank(n_head_kv, world_size, local_rank)
                for local_rank in range(world_size)
            ]

            beg_n_head = sum(n_head_each_rank[:rank])
            end_n_head = sum(n_head_each_rank[: rank + 1])

            beg_n_head_kv = sum(n_head_kv_each_rank[:rank])
            end_n_head_kv = sum(n_head_kv_each_rank[: rank + 1])

            if n_head_kv == n_head:
                x = rearrange(state_dict[key], "(three d) ... -> three d ...", three=3)
                state_dict[key] = rearrange(
                    x[:, beg_n_head * head_dim : end_n_head * head_dim],
                    "three d ... -> (three d) ...",
                )
            else:
                x = rearrange(
                    state_dict[key],
                    "(nheadqkv headdim) ... -> nheadqkv headdim ...",
                    nheadqkv=n_head + 2 * n_head_kv,
                )
                state_dict[key] = rearrange(
                    torch.cat(
                        [
                            x[beg_n_head:end_n_head],
                            x[n_head + beg_n_head_kv : n_head + end_n_head_kv],
                            x[
                                n_head
                                + n_head_kv
                                + beg_n_head_kv : n_head
                                + n_head_kv
                                + end_n_head_kv
                            ],
                        ],
                        dim=0,
                    ),
                    "nheadqkv headdim ... -> (nheadqkv headdim) ...",
                )

    shard_first_dim(state_dict, "transformer.embeddings.word_embeddings.weight")
    if "lm_head.weight" in state_dict:
        shard_first_dim(state_dict, "lm_head.weight")
    if "transformer.embeddings.position_embeddings.weight" in state_dict:
        shard_last_dim(state_dict, "transformer.embeddings.position_embeddings.weight")
    for i in range(config.num_hidden_layers):
        shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.weight")
        shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.bias")
        shard_last_dim(
            state_dict, f"transformer.layers.{i}.mixer.out_proj.weight", multiple_of=head_dim
        )
        if rank != 0:
            state_dict.pop(f"transformer.layers.{i}.mixer.out_proj.bias", None)
        if config.activation_function in ["glu", "swiglu", "geglu"]:
            shard_gatedmlp_fc1_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.weight")
            shard_gatedmlp_fc1_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.bias")
        else:
            shard_first_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.weight")
            shard_first_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.bias")
        shard_last_dim(state_dict, f"transformer.layers.{i}.mlp.fc2.weight")
        if rank != 0:
            state_dict.pop(f"transformer.layers.{i}.mlp.fc2.bias", None)
    return state_dict


def combine_state_dicts_tp(state_dicts: list[dict[str, torch.Tensor]], config: GPT2Config):
    """Convert the list of sharded state_dict of a GPT model with tensor parallel to
    the state_dict of a standard GPT model.

    This function is meant to be the "reverse" of shard_state_dict_tp.

    Precondition:
        - state_dicts should be ordered in the same way as the shards were created.
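
    Example (illustrative sketch)::

        full_state_dict = combine_state_dicts_tp(shards, config)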
    """
    world_size = len(state_dicts)
    keys = state_dicts[0].keys()
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    assert vocab_size % world_size == 0
    assert config.hidden_size % world_size == 0
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    assert inner_dim % world_size == 0

    # Sometimes the word embeddings are sharded on the 0th dim, sometimes on the 1st dim.
    # vocab_size // world_size coordinates are nonzero.
    def combine_word_embeddings(state_dicts, state_dict, key):
        dim = 0 if state_dicts[0][key].shape[0] == vocab_size // world_size else 1
        state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)

    def combine_dim(state_dicts, state_dict, key, dim=-1):
        if key in state_dict:
            state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)

    def combine_qkv_headdim(state_dicts, state_dict, key):
        n_head = config.n_head
        n_head_kv = getattr(config, "n_head_kv", n_head)
        if key in state_dict:
            if n_head_kv == n_head:
                xs = [
                    rearrange(s[key], "(three d) ... -> three d ...", three=3) for s in state_dicts
                ]
                state_dict[key] = rearrange(torch.cat(xs, dim=1), "three d ... -> (three d) ...")
            else:
                xs = [
                    rearrange(
                        s[key],
                        "(nheadqkv headdim) ... -> nheadqkv headdim ...",
                        nheadqkv=n_head + 2 * n_head_kv,
                    )
                    for s in state_dicts
                ]
                n_head_each_rank = [
                    get_dim_for_local_rank(n_head, world_size, local_rank)
                    for local_rank in range(world_size)
                ]
                n_head_kv_each_rank = [
                    get_dim_for_local_rank(n_head_kv, world_size, local_rank)
                    for local_rank in range(world_size)
                ]
                state_dict[key] = rearrange(
                    torch.cat(
                        [
                            torch.cat(
                                [x[: n_head_each_rank[rank]] for rank, x in enumerate(xs)], dim=0
                            ),
                            torch.cat(
                                [
                                    x[
                                        n_head_each_rank[rank] : n_head_each_rank[rank]
                                        + n_head_kv_each_rank[rank]
                                    ]
                                    for rank, x in enumerate(xs)
                                ],
                                dim=0,
                            ),
                            torch.cat(
                                [
                                    x[n_head_each_rank[rank] + n_head_kv_each_rank[rank] :]
                                    for rank, x in enumerate(xs)
                                ],
                                dim=0,
                            ),
                        ],
                        dim=0,
                    ),
                    "nheadqkv headdim ... -> (nheadqkv headdim) ...",
                )

    def combine_gated_mlp(state_dicts, state_dict, key):
        if key in state_dict:
            xs = [rearrange(s[key], "(two d) ... -> two d ...", two=2) for s in state_dicts]
            state_dict[key] = rearrange(torch.cat(xs, dim=1), "two d ... -> (two d) ...")

    state_dict = state_dicts[0].copy()  # don't modify state_dict[0] inplace
    combine_word_embeddings(
        state_dicts, state_dict, "transformer.embeddings.word_embeddings.weight"
    )
    if "lm_head.weight" in state_dict:
        combine_word_embeddings(state_dicts, state_dict, "lm_head.weight")
    if "transformer.embeddings.position_embeddings.weight" in state_dict:
        combine_dim(
            state_dicts, state_dict, "transformer.embeddings.position_embeddings.weight", -1
        )
    mlp_combine_fn = (
        combine_gated_mlp
        if config.activation_function in ["glu", "swiglu", "geglu"]
        else partial(combine_dim, dim=0)
    )
    for i in range(config.num_hidden_layers):
        combine_qkv_headdim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.Wqkv.weight")
        combine_qkv_headdim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.Wqkv.bias")
        combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.out_proj.weight", -1)
        mlp_combine_fn(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc1.weight")
        combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc1.bias", 0)
        combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc2.weight", -1)
    return state_dict


def remap_state_dict_hf_gpt2(state_dict, config):
    # Word embedding and position embedding
    def key_mapping_pos_emb(key):
        return re.sub(r"^wpe.", "transformer.embeddings.position_embeddings.", key)

    state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("wte.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^ln_f.(weight|bias)", r"transformer.ln_f.\1", key)
        key = re.sub(r"^h.(\d+).ln_(1|2).(weight|bias)", r"transformer.layers.\1.norm\2.\3", key)
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    for d in range(config.num_hidden_layers):
        W1 = state_dict.pop(f"h.{d}.mlp.c_fc.weight")
        state_dict[f"transformer.layers.{d}.mlp.fc1.weight"] = W1.t()
        W2 = state_dict.pop(f"h.{d}.mlp.c_proj.weight")
        state_dict[f"transformer.layers.{d}.mlp.fc2.weight"] = W2.t()

    def key_mapping_mlp(key):
        key = re.sub(r"^h.(\d+).mlp.c_fc.bias", r"transformer.layers.\1.mlp.fc1.bias", key)
        key = re.sub(r"^h.(\d+).mlp.c_proj.bias", r"transformer.layers.\1.mlp.fc2.bias", key)
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for d in range(config.num_hidden_layers):
        state_dict.pop(f"h.{d}.attn.bias")  # We don't store this bias
        Wqkv = state_dict.pop(f"h.{d}.attn.c_attn.weight")
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = Wqkv.t()
        Wout = state_dict.pop(f"h.{d}.attn.c_proj.weight")
        state_dict[f"transformer.layers.{d}.mixer.out_proj.weight"] = Wout.t()

    def key_mapping_attn(key):
        key = re.sub(r"^h.(\d+).attn.c_attn.bias", r"transformer.layers.\1.mixer.Wqkv.bias", key)
        key = re.sub(
            r"^h.(\d+).attn.c_proj.bias", r"transformer.layers.\1.mixer.out_proj.bias", key
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict


def remap_state_dict_megatron(state_dict, config):
    def key_mapping_transformer(key):
        key = re.sub(r"^language_model.encoder.", "transformer.", key)
        key = re.sub(r"^language_model.", "transformer.", key)
        return key

    state_dict = OrderedDict((key_mapping_transformer(k), v) for k, v in state_dict.items())

    # Word embedding and position embedding
    def key_mapping_pos_emb(key):
        return re.sub(r"^wpe.", "transformer.embeddings.position_embeddings.", key)

    state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embedding.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = (
        math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
    )
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.final_layernorm.(weight|bias)", r"transformer.ln_f.\1", key)
        key = re.sub(
            r"^transformer.layers.(\d+).input_layernorm.(weight|bias)",
            r"transformer.layers.\1.norm1.\2",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).post_attention_layernorm.(weight|bias)",
            r"transformer.layers.\1.norm2.\2",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.dense_h_to_4h.(weight|bias)",
            r"transformer.layers.\1.mlp.fc1.\2",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.dense_4h_to_h.(weight|bias)",
            r"transformer.layers.\1.mlp.fc2.\2",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    def key_mapping_attn(key):
        key = re.sub(
            r"^transformer.layers.(\d+).self_attention.rotary_emb.inv_freq",
            r"transformer.layers.\1.mixer.rotary_emb.inv_freq",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).self_attention.query_key_value.(weight|bias)",
            r"transformer.layers.\1.mixer.Wqkv.\2",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).self_attention.dense.(weight|bias)",
            r"transformer.layers.\1.mixer.out_proj.\2",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
    # Megatron stores Wqkv as ((nheads 3 headdim), hidden_dim)
    # while we store Wqkv as ((3 nheads headdim), hidden_dim)
    headdim = config.hidden_size // config.num_attention_heads
    for d in range(config.num_hidden_layers):
        Wqkv = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.weight")
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = rearrange(
            Wqkv,
            "(nheads three headdim) ... -> (three nheads headdim) ...",
            three=3,
            headdim=headdim,
        )
        bqkv = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.bias")
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.bias"] = rearrange(
            bqkv, "(nheads three headdim) -> (three nheads headdim)", three=3, headdim=headdim
        )

    return state_dict