"...replknet/deploy/replknet-XL-deploy_32xb64_in1k-320px.py" did not exist on "8f9dd0edefa849b2552ba149141ddb369bdbec4e"
gpt.py 46 KB
Newer Older
Tri Dao's avatar
Tri Dao committed
1
# Copyright (c) 2023, Tri Dao.

import logging
import math
import re
from collections import OrderedDict, namedtuple
from collections.abc import Sequence
from functools import partial
from typing import Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import GPT2Config

from flash_attn.models.bigcode import remap_state_dict_hf_bigcode
from flash_attn.models.falcon import remap_state_dict_hf_falcon
from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox
from flash_attn.models.gptj import remap_state_dict_hf_gptj
from flash_attn.models.llama import remap_state_dict_hf_llama
from flash_attn.models.opt import remap_state_dict_hf_opt
from flash_attn.modules.block import Block, ParallelBlock
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
from flash_attn.modules.mha import MHA, ParallelMHA
from flash_attn.modules.mlp import (
    FusedMLP,
    GatedMlp,
    Mlp,
    ParallelFusedMLP,
    ParallelGatedMlp,
    ParallelMLP,
)
from flash_attn.ops.activations import sqrelu_fwd
from flash_attn.utils.distributed import all_gather_raw, get_dim_for_local_rank, sync_shared_params
from flash_attn.utils.generation import GenerationMixin
from flash_attn.utils.pretrained import state_dict_from_pretrained

try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear
except ImportError:
    ColumnParallelLinear = None

try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm
except ImportError:
    dropout_add_layer_norm = None

try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
except ImportError:
    dropout_add_layer_norm_parallel_residual = None

try:
    from flash_attn.ops.rms_norm import RMSNorm, dropout_add_rms_norm
except ImportError:
    RMSNorm, dropout_add_rms_norm = None, None

try:
    from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual
except ImportError:
    dropout_add_rms_norm_parallel_residual = None

try:
    from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
except ImportError:
    FusedDenseSqreluDense = None

logger = logging.getLogger(__name__)


def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
    factory_kwargs = {"device": device, "dtype": dtype}
    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
    softmax_scale = 1.0 if not config.scale_attn_weights else head_dim ** (-0.5)
    if config.scale_attn_by_inverse_layer_idx:
        assert layer_idx is not None
        softmax_scale /= float(layer_idx + 1)
    dwconv = getattr(config, "attn_dwconv", False)
    if dwconv:
        assert process_group is None, "TensorParallel MHA does not support dwconv yet"
    qkv_proj_bias = getattr(config, "qkv_proj_bias", True)
    out_proj_bias = getattr(config, "out_proj_bias", True)
    rotary_emb_dim = int(getattr(config, "rotary_emb_fraction", 0.0) * head_dim)
    rotary_emb_base = getattr(config, "rotary_emb_base", 10000.0)
    rotary_emb_scale_base = getattr(config, "rotary_emb_scale_base", None)
    rotary_emb_interleaved = getattr(config, "rotary_emb_interleaved", False)
    use_flash_attn = getattr(config, "use_flash_attn", False)
    fused_bias_fc = getattr(config, "fused_bias_fc", False)
    if not fused_bias_fc:
        assert process_group is None, "TensorParallel MHA requires fused_bias_fc"
    mha_cls = MHA if process_group is None else ParallelMHA
    serial_kwargs = (
        {"fused_bias_fc": fused_bias_fc, "dwconv": dwconv} if process_group is None else {}
    )
    parallel_kwargs = (
        {
            "process_group": process_group,
            "sequence_parallel": getattr(config, "sequence_parallel", True),
        }
        if process_group is not None
        else {}
    )
    num_heads_kv = getattr(config, "n_head_kv", None)
    mixer_cls = partial(
        mha_cls,
        num_heads=config.num_attention_heads,
        num_heads_kv=num_heads_kv,
        qkv_proj_bias=qkv_proj_bias,
        out_proj_bias=out_proj_bias,
        dropout=config.attn_pdrop,
        softmax_scale=softmax_scale,
        causal=True,
        layer_idx=layer_idx,
        rotary_emb_dim=rotary_emb_dim,
        rotary_emb_base=rotary_emb_base,
        rotary_emb_scale_base=rotary_emb_scale_base,
        rotary_emb_interleaved=rotary_emb_interleaved,
        use_flash_attn=use_flash_attn,
        **serial_kwargs,
        **parallel_kwargs,
        **factory_kwargs,
    )
    return mixer_cls
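

# Example (an illustrative sketch, not part of this module's API): the returned
# `mixer_cls` is still missing the embedding dimension, which `Block` supplies later.
#
#     config = GPT2Config(n_embd=256, n_head=8)
#     mixer = create_mixer_cls(config, layer_idx=0)(config.hidden_size)
#     # -> a causal MHA module with softmax scale (256 // 8) ** -0.5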


def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
    factory_kwargs = {"device": device, "dtype": dtype}
    mlp_fc1_bias = getattr(config, "mlp_fc1_bias", True)
    mlp_fc2_bias = getattr(config, "mlp_fc2_bias", True)
    fused_mlp = getattr(config, "fused_mlp", False)
    if fused_mlp:
        assert config.activation_function in [
            "gelu_new",
            "gelu_fast",
            "gelu_approx",
            "gelu_pytorch_tanh",
            "relu",
            "sqrelu",
        ]
    fused_dense_sqrelu_dense = getattr(config, "fused_dense_sqrelu_dense", False)
    if fused_dense_sqrelu_dense:
        assert config.activation_function == "sqrelu", (
            "fused_dense_sqrelu_dense only " "supports approximate activation_function sqrelu"
        )
    assert not (fused_dense_sqrelu_dense and fused_mlp)
    if not fused_mlp and not fused_dense_sqrelu_dense:
        assert config.activation_function in [
            "gelu",
            "gelu_new",
            "gelu_fast",
            "gelu_approx",
            "gelu_pytorch_tanh",
            "relu",
            "sqrelu",
            "glu",
            "swiglu",
            "geglu",
        ]
        if config.activation_function in ["glu", "swiglu", "geglu"]:
            activation = (
                F.sigmoid
                if config.activation_function == "glu"
                else (F.silu if config.activation_function == "swiglu" else F.gelu)
            )
            mlp_cls = GatedMlp if process_group is None else ParallelGatedMlp
            parallel_kwargs = (
                {
                    "process_group": process_group,
                    "sequence_parallel": getattr(config, "sequence_parallel", True),
                }
                if process_group is not None
                else {}
            )
            mlp_cls = partial(
                mlp_cls,
                hidden_features=config.n_inner,
                activation=activation,
                bias1=mlp_fc1_bias,
                bias2=mlp_fc2_bias,
                **parallel_kwargs,
                **factory_kwargs,
            )
        else:
            if config.activation_function == "relu":
                activation = partial(F.relu, inplace=True)
            elif config.activation_function == "sqrelu":
                activation = sqrelu_fwd
            else:
                approximate = (
                    "tanh"
                    if config.activation_function
                    in ["gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh"]
                    else "none"
                )
                activation = partial(F.gelu, approximate=approximate)
            mlp_cls = Mlp if process_group is None else ParallelMLP
            parallel_kwargs = (
                {
                    "process_group": process_group,
                    "sequence_parallel": getattr(config, "sequence_parallel", True),
                }
                if process_group is not None
                else {}
            )
            mlp_cls = partial(
                mlp_cls,
                hidden_features=config.n_inner,
                activation=activation,
                bias1=mlp_fc1_bias,
                bias2=mlp_fc2_bias,
                **parallel_kwargs,
                **factory_kwargs,
            )
    else:
        mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
        # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
        if isinstance(mlp_checkpoint_lvl, Sequence):
            assert layer_idx is not None
            mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
        if fused_mlp:
            if FusedMLP is None:
                raise ImportError("fused_dense is not installed")
            activation = (
                "gelu_approx"
                if config.activation_function
                in ["gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh"]
                else config.activation_function
            )
            mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP
            parallel_kwargs = (
                {
                    "process_group": process_group,
                    "sequence_parallel": getattr(config, "sequence_parallel", True),
                }
                if process_group is not None
                else {}
            )
            mlp_cls = partial(
                mlp_cls,
                hidden_features=config.n_inner,
                activation=activation,
                checkpoint_lvl=mlp_checkpoint_lvl,
                bias1=mlp_fc1_bias,
                bias2=mlp_fc2_bias,
                **parallel_kwargs,
                **factory_kwargs,
            )
        elif fused_dense_sqrelu_dense:
            if process_group is not None:
                assert fused_mlp, "Tensor Parallel is not implemented for FusedDenseSqreluDense"
            assert FusedDenseSqreluDense is not None
            mlp_cls = partial(
                FusedDenseSqreluDense,
                hidden_features=config.n_inner,
                checkpoint_lvl=mlp_checkpoint_lvl,
                **factory_kwargs,
            )
        else:
            raise RuntimeError("MLP type not supported")
    return mlp_cls
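

# Example (an illustrative sketch): as with create_mixer_cls, the result is a partial
# constructor; `Block` passes in the input dimension.
#
#     config = GPT2Config(n_embd=256, n_head=8, activation_function="gelu_new")
#     mlp = create_mlp_cls(config)(config.hidden_size)
#     # -> an Mlp module whose activation is F.gelu(approximate="tanh")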


def create_block(config, layer_idx=None, process_group=None, device=None, dtype=None):
    factory_kwargs = {"device": device, "dtype": dtype}
    sequence_parallel = getattr(config, "sequence_parallel", True)
    mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
    mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
    use_rms_norm = getattr(config, "rms_norm", False)
    norm_cls = partial(
        nn.LayerNorm if not use_rms_norm else RMSNorm,
        eps=config.layer_norm_epsilon,
        **factory_kwargs,
    )
    # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
    residual_in_fp32 = getattr(config, "residual_in_fp32", False)
    resid_dropout1 = config.resid_pdrop if layer_idx is None or layer_idx > 0 else config.embd_pdrop
    prenorm = getattr(config, "prenorm", True)
    parallel_block = getattr(config, "parallel_block", False)
    if not parallel_block:
        block = Block(
            config.hidden_size,
            mixer_cls,
            mlp_cls,
            norm_cls=norm_cls,
            prenorm=prenorm,
            resid_dropout1=resid_dropout1,
            resid_dropout2=config.resid_pdrop,
            fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
            residual_in_fp32=residual_in_fp32,
            sequence_parallel=sequence_parallel and process_group is not None,
            mark_shared_params=process_group is not None,
        )
    else:
        assert prenorm
        block = ParallelBlock(
            config.hidden_size,
            mixer_cls,
            mlp_cls,
            norm_cls=norm_cls,
            resid_dropout1=resid_dropout1,
            resid_dropout2=config.resid_pdrop,
            tied_norm=getattr(config, "parallel_block_tied_norm", False),
            fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
            residual_in_fp32=residual_in_fp32,
            sequence_parallel=sequence_parallel and process_group is not None,
            mark_shared_params=process_group is not None,
        )
    block.layer_idx = layer_idx
    return block
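

# Example (an illustrative sketch): building and running a single block.
#
#     config = GPT2Config(n_embd=256, n_head=8)
#     block = create_block(config, layer_idx=0)
#     x = torch.randn(2, 16, config.hidden_size)
#     hidden_states, residual = block(x)  # prenorm blocks return (main, residual)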


class GPTPreTrainedModel(nn.Module):
    """An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    def __init__(self, config, *inputs, **kwargs):
        super().__init__()
        if not isinstance(config, GPT2Config):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
                "To create a model from a pretrained checkpoint use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                )
            )
        self.config = config

    @classmethod
    def from_pretrained(
        cls,
        model_name,
        config,
        *args,
        strict=True,
        device=None,
        dtype=None,
        world_size=1,
        rank=0,
        **kwargs,
    ):
        """
        Instantiate a GPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.
        """
        # Instantiate model.
        model = cls(config, *args, device=device, dtype=dtype, **kwargs)
        # Load state_dict on CPU because we already initialized the model on GPU, and we don't
        # want extra copies taking up more GPU memory
        state_dict = state_dict_from_pretrained(model_name, device="cpu", dtype=dtype)
        if model_name.startswith("gpt2"):
            state_dict = remap_state_dict_hf_gpt2(state_dict, config)
        elif model_name.startswith("facebook/opt"):
            state_dict = remap_state_dict_hf_opt(state_dict, config)
        elif (
            model_name.startswith("EleutherAI/gpt-j-")
            or model_name.startswith("togethercomputer/GPT-JT-")
        ):
            state_dict = remap_state_dict_hf_gptj(state_dict, config)
        elif (
            model_name.startswith("EleutherAI/gpt-neox-")
            or model_name.startswith("EleutherAI/pythia-")
            or model_name.startswith("togethercomputer/RedPajama-INCITE-")
        ):
            state_dict = remap_state_dict_hf_gpt_neox(state_dict, config)
        elif model_name.startswith("tiiuae/falcon-"):
            state_dict = remap_state_dict_hf_falcon(state_dict, config)
        elif model_name.startswith("meta-llama/Llama-"):
            state_dict = remap_state_dict_hf_llama(state_dict, config)
        elif model_name.startswith("bigcode/") or model_name.startswith("WizardLM/"):
            state_dict = remap_state_dict_hf_bigcode(state_dict, config)
        else:
            raise NotImplementedError(f"Model {model_name} not supported")
        if world_size > 1:
            state_dict = shard_state_dict_tp(state_dict, config, world_size, rank)
        load_return = model.load_state_dict(state_dict, strict=strict)
        logger.info(load_return)
        return model
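

# Example of loading pretrained weights (a hedged sketch; which config tweaks you need
# depends on the checkpoint, and GPTLMHeadModel is defined further down in this file):
#
#     from transformers import GPT2Config
#     config = GPT2Config.from_pretrained("gpt2")
#     config.use_flash_attn = True
#     model = GPTLMHeadModel.from_pretrained("gpt2", config, device="cuda", dtype=torch.float16)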


# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True):
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, std=initializer_range)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if rescale_prenorm_residual:
        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer))
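
# For example, with initializer_range=0.02 and n_layer=12 (GPT-2 small), the out_proj /
# fc2 weights are re-drawn with std 0.02 / sqrt(2 * 12) ≈ 0.0041, so the variance each
# residual branch adds to the stream stays roughly constant with depth.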


class GPTModel(GPTPreTrainedModel):
    def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
        super().__init__(config)
        factory_kwargs = {"device": device, "dtype": dtype}
        self.process_group = process_group
        self.sequence_parallel = getattr(config, "sequence_parallel", True)
        assert config.activation_function in [
            "gelu",
            "gelu_new",
            "gelu_fast",
            "gelu_approx",
            "gelu_pytorch_tanh",
            "relu",
            "sqrelu",
            "glu",
            "swiglu",
            "geglu",
        ]
        pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
        vocab_size = (
            math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
        )
        # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
        self.residual_in_fp32 = getattr(config, "residual_in_fp32", False)
        # These 2 options are for OPT-350m
        self.prenorm = getattr(config, "prenorm", True)
        use_rms_norm = getattr(config, "rms_norm", False)
        word_embed_proj_dim = getattr(config, "word_embed_proj_dim", None)
        # For GPT-J, GPT-NeoX
        self.parallel_block = getattr(config, "parallel_block", False)

        if process_group is None:
            self.embeddings = GPT2Embeddings(
                config.hidden_size,
                vocab_size,
                config.max_position_embeddings,
                word_embed_proj_dim=word_embed_proj_dim,
                **factory_kwargs,
            )
        else:
            self.embeddings = ParallelGPT2Embeddings(
                config.hidden_size,
                vocab_size,
                config.max_position_embeddings,
                process_group=process_group,
                sequence_parallel=self.sequence_parallel,
                **factory_kwargs,
            )

        # We change the order of dropout, residual and layer norm:
        # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
        # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and
        # the main branch (output of MLP). The model definition is unchanged, but the mapping of the
        # nn.Dropout probabilities is changed.
        # This is for performance reasons: we can fuse dropout + add + layer_norm.
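        # Concretely: the dropout that would follow the embeddings becomes resid_dropout1
        # of layer 0, and each layer's own dropout moves to the start of the next layer
        # (see the resid_dropout1 logic in create_block above).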
        self.layers = nn.ModuleList(
            [
                create_block(config, layer_idx=i, process_group=process_group, **factory_kwargs)
                for i in range(config.num_hidden_layers)
            ]
        )

        self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
        if self.fused_dropout_add_ln:
            if (not self.parallel_block and dropout_add_layer_norm is None) or (
                self.parallel_block and dropout_add_layer_norm_parallel_residual is None
            ):
                raise ImportError("dropout_layer_norm is not installed")
        if self.prenorm:
            self.drop_f = nn.Dropout(config.resid_pdrop)
            norm_cls = nn.LayerNorm if not use_rms_norm else RMSNorm
            self.ln_f = norm_cls(
                config.hidden_size, eps=config.layer_norm_epsilon, **factory_kwargs
            )
        if process_group is not None:
            for p in self.ln_f.parameters():
                # Mark the norm parameters as "shared_params" so that we sync their values at init.
                p._shared_params = True
                # Mark the norm params as "sequence_parallel" so we run all-reduce on their grads.
                if self.sequence_parallel:
                    p._sequence_parallel = True

        self.apply(
            partial(
                _init_weights,
                n_layer=config.num_hidden_layers,
                initializer_range=config.initializer_range,
            )
        )
        self.tie_weights()

    def tie_weights(self):
        if self.process_group is not None:
            sync_shared_params(self, self.process_group)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return {
            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
            for i, layer in enumerate(self.layers)
        }

    def forward(self, input_ids, position_ids=None, inference_params=None):
        # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen
        # dimensions so that we can split on it easily, in case of small batch size.
        # Only the attention layers need to know the seqlen.
        embedding_kwargs = (
            {"combine_batch_seqlen_dim": True}
            if self.process_group is not None and self.sequence_parallel
            else {}
        )
        hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs)
        if self.parallel_block:
            hidden_states2 = None
        residual = None
        mixer_kwargs = (
            {"seqlen": input_ids.shape[1]}
            if self.process_group is not None and self.sequence_parallel
            else {}
        )
        if inference_params is not None:
            mixer_kwargs["inference_params"] = inference_params
        for layer in self.layers:
            if self.prenorm:
                if not self.parallel_block:
                    hidden_states, residual = layer(
                        hidden_states, residual, mixer_kwargs=mixer_kwargs
                    )
                else:
                    hidden_states, hidden_states2, residual = layer(
                        hidden_states, hidden_states2, residual, mixer_kwargs=mixer_kwargs
                    )
            else:
                hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
        if self.prenorm:
            if not self.fused_dropout_add_ln:
                dropped = self.drop_f(hidden_states)
                if not self.parallel_block:
                    residual = (dropped + residual) if residual is not None else dropped
                else:
                    dropped2 = self.drop_f(hidden_states2)
                    residual = (
                        (residual + dropped + dropped2)
                        if residual is not None
                        else dropped + dropped2
                    )
                hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
            else:
                # Set prenorm=False here since we don't need the residual
                if not self.parallel_block:
                    fused_add_norm_fn = (
                        dropout_add_rms_norm
                        if isinstance(self.ln_f, RMSNorm)
                        else dropout_add_layer_norm
                    )
                    hidden_states = fused_add_norm_fn(
                        hidden_states,
                        residual,
                        self.ln_f.weight,
                        self.ln_f.bias,
                        self.drop_f.p if self.training else 0.0,
                        self.ln_f.eps,
                        prenorm=False,
                        residual_in_fp32=self.residual_in_fp32,
                    )
                else:
                    fused_add_norm_fn = (
                        dropout_add_rms_norm_parallel_residual
                        if isinstance(self.ln_f, RMSNorm)
                        else dropout_add_layer_norm_parallel_residual
                    )
                    hidden_states, _ = fused_add_norm_fn(
                        hidden_states,
                        hidden_states2,
                        residual,
                        self.ln_f.weight,
                        self.ln_f.bias,
                        None,
                        None,
                        self.drop_f.p if self.training else 0.0,
                        self.ln_f.eps,
                        prenorm=False,
                        residual_in_fp32=self.residual_in_fp32,
                    )
        return hidden_states


class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
    def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(config)
        self.process_group = process_group
        self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs)
        self.tie_word_embeddings = getattr(config, "tie_word_embeddings", True)
        lm_head_bias = getattr(config, "lm_head_bias", False)
        pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
        vocab_size = (
            math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
        )
        # This option is for OPT-350m
        word_embed_proj_dim = getattr(config, "word_embed_proj_dim", None)
        embed_dim = config.n_embd if word_embed_proj_dim is None else word_embed_proj_dim
        if word_embed_proj_dim is not None:
            self.project_out = nn.Linear(config.n_embd, embed_dim, bias=False, **factory_kwargs)
        else:
            self.project_out = None
        if process_group is None:
            self.lm_head = nn.Linear(embed_dim, vocab_size, bias=lm_head_bias, **factory_kwargs)
        else:
            if ColumnParallelLinear is None:
                raise ImportError("fused_dense_lib is not installed")
            self.lm_head = ColumnParallelLinear(
                embed_dim,
                vocab_size,
                process_group,
                bias=lm_head_bias,
                sequence_parallel=getattr(config, "sequence_parallel", True),
                **factory_kwargs,
            )
        # Initialize weights and apply final processing
        self.apply(
            partial(
                _init_weights,
                n_layer=config.num_hidden_layers,
                initializer_range=config.initializer_range,
            )
        )
        self.tie_weights()

    def tie_weights(self):
        if self.tie_word_embeddings:
            self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight
        if self.process_group is not None:
            sync_shared_params(self, self.process_group)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.transformer.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

    def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
        """
        input_ids: (batch, seqlen) int tensor
        inference_params: for generation. Adapted from Megatron-LM (and Apex)
        https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        num_last_tokens: if > 0, only return the logits for the last n tokens
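
        Example (an illustrative sketch):
            out = model(input_ids)                      # out.logits: (batch, seqlen, vocab_size)
            out = model(input_ids, num_last_tokens=1)   # logits for the final position only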
        """
        assert (
            input_ids.ndim == 2
        ), f"Expected `input_ids` to have shape [b, slen], but got shape {input_ids.shape}"
        b, slen = input_ids.shape
        hidden_states = self.transformer(
            input_ids, position_ids=position_ids, inference_params=inference_params
        )
        if inference_params is not None:
            assert hidden_states.ndim == 3, "sequence_parallel is not supported in generation mode"
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        if self.project_out is not None:
            hidden_states = self.project_out(hidden_states)
        lm_logits = self.lm_head(hidden_states)
        # During inference, we want the full logit for sampling
        if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
            lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
            lm_logits = rearrange(lm_logits, "(n b) ... d -> b ... (n d)", b=b)
        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
        return CausalLMOutput(logits=lm_logits)

    def load_state_dict(self, state_dict, strict=True):
        # Remapping from our checkpoints that used a different ordering of layers in the block
        # Previous: Attn / MLP -> Dropout -> Add -> LN
        # Current: Dropout -> Add -> LN -> Attn / MLP
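        # Concretely, every norm shifts down one position (old key -> new key):
        #   transformer.ln_0                      -> transformer.layers.0.norm1
        #   transformer.layers.{l-1}.norm2        -> transformer.layers.{l}.norm1
        #   transformer.layers.{l}.norm1          -> transformer.layers.{l}.norm2
        #   transformer.layers.{n_layers-1}.norm2 -> transformer.ln_f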
        if "transformer.ln_0.weight" in state_dict:
            n_layers = len(self.transformer.layers)
            ln_weight = state_dict.pop(f"transformer.layers.{n_layers - 1}.norm2.weight")
            ln_bias = state_dict.pop(f"transformer.layers.{n_layers - 1}.norm2.bias")
            state_dict["transformer.ln_f.weight"] = ln_weight
            state_dict["transformer.ln_f.bias"] = ln_bias
            for l in reversed(range(n_layers)):
                ln_weight = state_dict.pop(f"transformer.layers.{l}.norm1.weight")
                ln_bias = state_dict.pop(f"transformer.layers.{l}.norm1.bias")
                state_dict[f"transformer.layers.{l}.norm2.weight"] = ln_weight
                state_dict[f"transformer.layers.{l}.norm2.bias"] = ln_bias
                if l > 0:
                    ln_weight = state_dict.pop(f"transformer.layers.{l - 1}.norm2.weight")
                    ln_bias = state_dict.pop(f"transformer.layers.{l - 1}.norm2.bias")
                    state_dict[f"transformer.layers.{l}.norm1.weight"] = ln_weight
                    state_dict[f"transformer.layers.{l}.norm1.bias"] = ln_bias
            ln_weight = state_dict.pop("transformer.ln_0.weight")
            ln_bias = state_dict.pop("transformer.ln_0.bias")
            state_dict[f"transformer.layers.0.norm1.weight"] = ln_weight
            state_dict[f"transformer.layers.0.norm1.bias"] = ln_bias
        return super().load_state_dict(state_dict, strict=strict)


def shard_state_dict_tp(state_dict, config, world_size, rank):
    """Convert the state_dict of a standard GPT model to the state_dict of a GPT model
    with tensor parallel.

    This function modifies state_dict in place.
    """
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    assert vocab_size % world_size == 0
    assert config.hidden_size % world_size == 0
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    assert inner_dim % world_size == 0

    n_head = config.n_head
    n_head_kv = getattr(config, "n_head_kv", n_head)

    embed_dim = config.hidden_size
    head_dim = embed_dim // n_head

    def shard_first_dim(state_dict, key):
        if key in state_dict:
            x = state_dict[key]
            dim = x.shape[0] // world_size
            state_dict[key] = x[rank * dim : (rank + 1) * dim]

    def shard_last_dim(state_dict, key, multiple_of=1):
        if key in state_dict:
            x = state_dict[key]
            dim_each_rank = [
                get_dim_for_local_rank(x.size(-1), world_size, local_rank, multiple_of)
                for local_rank in range(world_size)
            ]
            beg, end = tuple(sum(dim_each_rank[:pos]) for pos in (rank, rank + 1))
            state_dict[key] = x[..., beg:end]

    def shard_gatedmlp_fc1_dim(state_dict, key):
        if key in state_dict:
            x = state_dict[key]
            dim = x.shape[0] // world_size // 2
            state_dict[key] = rearrange(
                rearrange(x, "(two o) ... -> two o ...", two=2)[:, rank * dim : (rank + 1) * dim],
                "two o ... -> (two o) ...",
            )

    def shard_qkv_headdim(state_dict, key):
        if key in state_dict:
            n_head_each_rank = [
                get_dim_for_local_rank(n_head, world_size, local_rank)
                for local_rank in range(world_size)
            ]
            n_head_kv_each_rank = [
                get_dim_for_local_rank(n_head_kv, world_size, local_rank)
                for local_rank in range(world_size)
            ]

            beg_n_head = sum(n_head_each_rank[:rank])
            end_n_head = sum(n_head_each_rank[: rank + 1])

            beg_n_head_kv = sum(n_head_kv_each_rank[:rank])
            end_n_head_kv = sum(n_head_kv_each_rank[: rank + 1])

            if n_head_kv == n_head:
                x = rearrange(state_dict[key], "(three d) ... -> three d ...", three=3)
                state_dict[key] = rearrange(
                    x[:, beg_n_head * head_dim : end_n_head * head_dim],
                    "three d ... -> (three d) ...",
                )
            else:
                x = rearrange(
                    state_dict[key],
                    "(nheadqkv headdim) ... -> nheadqkv headdim ...",
                    nheadqkv=n_head + 2 * n_head_kv,
                )
                state_dict[key] = rearrange(
                    torch.cat(
                        [
                            x[beg_n_head:end_n_head],
                            x[n_head + beg_n_head_kv : n_head + end_n_head_kv],
                            x[
                                n_head
                                + n_head_kv
                                + beg_n_head_kv : n_head
                                + n_head_kv
                                + end_n_head_kv
                            ],
                        ],
                        dim=0,
                    ),
                    "nheadqkv headdim ... -> (nheadqkv headdim) ...",
                )

    shard_first_dim(state_dict, "transformer.embeddings.word_embeddings.weight")
    if "lm_head.weight" in state_dict:
        shard_first_dim(state_dict, "lm_head.weight")
    if "transformer.embeddings.position_embeddings.weight" in state_dict:
        shard_last_dim(state_dict, "transformer.embeddings.position_embeddings.weight")
    for i in range(config.num_hidden_layers):
        shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.weight")
        shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.bias")
        shard_last_dim(
            state_dict, f"transformer.layers.{i}.mixer.out_proj.weight", multiple_of=head_dim
        )
        if rank != 0:
            state_dict.pop(f"transformer.layers.{i}.mixer.out_proj.bias", None)
        if config.activation_function in ["glu", "swiglu", "geglu"]:
            shard_gatedmlp_fc1_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.weight")
            shard_gatedmlp_fc1_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.bias")
        else:
            shard_first_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.weight")
            shard_first_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.bias")
        shard_last_dim(state_dict, f"transformer.layers.{i}.mlp.fc2.weight")
        if rank != 0:
            state_dict.pop(f"transformer.layers.{i}.mlp.fc2.bias", None)
    return state_dict
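

# Example (a hedged sketch; `model` and `config` are illustrative names). Because the
# sharding happens in place, give each rank its own copy:
#
#     import copy
#     full_sd = model.state_dict()
#     shards = [
#         shard_state_dict_tp(copy.deepcopy(full_sd), config, world_size=2, rank=r)
#         for r in range(2)
#     ]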


def combine_state_dicts_tp(state_dicts: List[Dict[str, torch.Tensor]], config: GPT2Config):
    """Convert the list of sharded state_dicts of a GPT model with tensor parallel to
    the state_dict of a standard GPT model.

    This function is meant to be the "reverse" of shard_state_dict_tp.

    Precondition:
        - state_dicts should be ordered in the same way as the shards were created.
    """
    world_size = len(state_dicts)
    keys = state_dicts[0].keys()
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    assert vocab_size % world_size == 0
    assert config.hidden_size % world_size == 0
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    assert inner_dim % world_size == 0
    assert config.hidden_size % config.n_head == 0
    headdim = config.hidden_size // config.n_head

    # Sometimes the word embeddings are sharded on the 0th dim, sometimes on the 1st dim.
    # We detect which by checking whether shard dim 0 equals vocab_size // world_size.
    def combine_word_embeddings(state_dicts, state_dict, key):
        dim = 0 if state_dicts[0][key].shape[0] == vocab_size // world_size else 1
        state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)

    def combine_dim(state_dicts, state_dict, key, dim=-1):
        if key in state_dict:
            state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)

    def combine_qkv_headdim(state_dicts, state_dict, key):
        n_head = config.n_head
        n_head_kv = getattr(config, "n_head_kv", n_head)
        if key in state_dict:
            if n_head_kv == n_head:
                xs = [
                    rearrange(s[key], "(three d) ... -> three d ...", three=3) for s in state_dicts
                ]
                state_dict[key] = rearrange(torch.cat(xs, dim=1), "three d ... -> (three d) ...")
            else:
                n_head_each_rank = [
                    get_dim_for_local_rank(n_head, world_size, local_rank)
                    for local_rank in range(world_size)
                ]
                n_head_kv_each_rank = [
                    get_dim_for_local_rank(n_head_kv, world_size, local_rank)
                    for local_rank in range(world_size)
                ]
                xs = [
                    rearrange(
                        s[key],
                        "(nheadqkv headdim) ... -> nheadqkv headdim ...",
                        nheadqkv=rank_n_head + 2 * rank_n_head_kv,
                        headdim=headdim,
                    )
                    for s, rank_n_head, rank_n_head_kv in zip(
                        state_dicts, n_head_each_rank, n_head_kv_each_rank
                    )
                ]
                wq = torch.cat([x[: n_head_each_rank[rank]] for rank, x in enumerate(xs)], dim=0)
                wk = torch.cat(
                    [
                        x[
                            n_head_each_rank[rank] : n_head_each_rank[rank]
                            + n_head_kv_each_rank[rank]
                        ]
                        for rank, x in enumerate(xs)
                    ],
                    dim=0,
                )
                wv = torch.cat(
                    [
                        x[n_head_each_rank[rank] + n_head_kv_each_rank[rank] :]
                        for rank, x in enumerate(xs)
                    ],
                    dim=0,
                )
                wqkv = torch.cat([wq, wk, wv], dim=0)
                state_dict[key] = rearrange(
                    wqkv,
                    "nheadqkv headdim ... -> (nheadqkv headdim) ...",
                )

    def combine_gated_mlp(state_dicts, state_dict, key):
        if key in state_dict:
            xs = [rearrange(s[key], "(two d) ... -> two d ...", two=2) for s in state_dicts]
            state_dict[key] = rearrange(torch.cat(xs, dim=1), "two d ... -> (two d) ...")

    state_dict = state_dicts[0].copy()  # don't modify state_dicts[0] in place
    combine_word_embeddings(
        state_dicts, state_dict, "transformer.embeddings.word_embeddings.weight"
    )
    if "lm_head.weight" in state_dict:
        combine_word_embeddings(state_dicts, state_dict, "lm_head.weight")
    if "transformer.embeddings.position_embeddings.weight" in state_dict:
        combine_dim(
            state_dicts, state_dict, "transformer.embeddings.position_embeddings.weight", -1
        )
    mlp_combine_fn = (
        combine_gated_mlp
        if config.activation_function in ["glu", "swiglu", "geglu"]
        else partial(combine_dim, dim=0)
    )
    for i in range(config.num_hidden_layers):
        combine_qkv_headdim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.Wqkv.weight")
        combine_qkv_headdim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.Wqkv.bias")
        combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.out_proj.weight", -1)
        mlp_combine_fn(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc1.weight")
        combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc1.bias", 0)
        combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc2.weight", -1)
    return state_dict
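

# Continuing the sketch above shard_state_dict_tp, combining the shards should give
# back the original (unsharded) state dict:
#
#     combined = combine_state_dicts_tp(shards, config)
#     assert all(torch.equal(combined[k], full_sd[k]) for k in full_sd)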


def remap_state_dict_hf_gpt2(state_dict, config):
    # Word embedding and position embedding
    def key_mapping_pos_emb(key):
        return re.sub(r"^wpe.", "transformer.embeddings.position_embeddings.", key)

    state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("wte.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^ln_f.(weight|bias)", r"transformer.ln_f.\1", key)
        key = re.sub(r"^h.(\d+).ln_(1|2).(weight|bias)", r"transformer.layers.\1.norm\2.\3", key)
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    for d in range(config.num_hidden_layers):
        W1 = state_dict.pop(f"h.{d}.mlp.c_fc.weight")
        state_dict[f"transformer.layers.{d}.mlp.fc1.weight"] = W1.t()
        W2 = state_dict.pop(f"h.{d}.mlp.c_proj.weight")
        state_dict[f"transformer.layers.{d}.mlp.fc2.weight"] = W2.t()

    def key_mapping_mlp(key):
        key = re.sub(r"^h.(\d+).mlp.c_fc.bias", r"transformer.layers.\1.mlp.fc1.bias", key)
        key = re.sub(r"^h.(\d+).mlp.c_proj.bias", r"transformer.layers.\1.mlp.fc2.bias", key)
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for d in range(config.num_hidden_layers):
        state_dict.pop(f"h.{d}.attn.bias")  # We don't store this bias
        Wqkv = state_dict.pop(f"h.{d}.attn.c_attn.weight")
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = Wqkv.t()
        Wout = state_dict.pop(f"h.{d}.attn.c_proj.weight")
        state_dict[f"transformer.layers.{d}.mixer.out_proj.weight"] = Wout.t()

    def key_mapping_attn(key):
        key = re.sub(r"^h.(\d+).attn.c_attn.bias", r"transformer.layers.\1.mixer.Wqkv.bias", key)
        key = re.sub(
            r"^h.(\d+).attn.c_proj.bias", r"transformer.layers.\1.mixer.out_proj.bias", key
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict
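

# Note: HF GPT-2 stores c_attn / c_proj / c_fc as Conv1D modules, whose weights have
# shape (in_features, out_features); the .t() calls above transpose them into the
# (out_features, in_features) layout that nn.Linear expects.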


def remap_state_dict_megatron(state_dict, config):
    def key_mapping_transformer(key):
        key = re.sub(r"^language_model.encoder.", "transformer.", key)
        key = re.sub(r"^language_model.", "transformer.", key)
        return key

    state_dict = OrderedDict((key_mapping_transformer(k), v) for k, v in state_dict.items())

    # Word embedding and position embedding
    def key_mapping_pos_emb(key):
        return re.sub(r"^wpe.", "transformer.embeddings.position_embeddings.", key)

    state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embedding.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = (
        math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
    )
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.final_layernorm.(weight|bias)", r"transformer.ln_f.\1", key)
        key = re.sub(
            r"^transformer.layers.(\d+).input_layernorm.(weight|bias)",
            r"transformer.layers.\1.norm1.\2",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).post_attention_layernorm.(weight|bias)",
            r"transformer.layers.\1.norm2.\2",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.dense_h_to_4h.(weight|bias)",
            r"transformer.layers.\1.mlp.fc1.\2",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.dense_4h_to_h.(weight|bias)",
            r"transformer.layers.\1.mlp.fc2.\2",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    def key_mapping_attn(key):
        key = re.sub(
            r"^transformer.layers.(\d+).self_attention.rotary_emb.inv_freq",
            r"transformer.layers.\1.mixer.rotary_emb.inv_freq",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).self_attention.query_key_value.(weight|bias)",
            r"transformer.layers.\1.mixer.Wqkv.\2",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).self_attention.dense.(weight|bias)",
            r"transformer.layers.\1.mixer.out_proj.\2",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
    # Megatron stores Wqkv as ((nheads 3 headdim), hidden_dim)
    # while we store Wqkv as ((3 nheads headdim), hidden_dim)
    headdim = config.hidden_size // config.num_attention_heads
    for d in range(config.num_hidden_layers):
        Wqkv = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.weight")
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = rearrange(
            Wqkv,
            "(nheads three headdim) ... -> (three nheads headdim) ...",
            three=3,
            headdim=headdim,
        )
        bqkv = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.bias")
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.bias"] = rearrange(
            bqkv, "(nheads three headdim) -> (three nheads headdim)", three=3, headdim=headdim
        )

    return state_dict
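

# For example, with 2 heads and headdim 2, Megatron orders the qkv rows as
# [q0 q0 k0 k0 v0 v0 q1 q1 k1 k1 v1 v1] (q, k, v interleaved per head), while this
# module expects [q0 q0 q1 q1 k0 k0 k1 k1 v0 v0 v1 v1]; the rearranges above apply
# exactly that permutation to both the weight and the bias.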