gpt.py 45.7 KB
Newer Older
Tri Dao's avatar
Tri Dao committed
1
# Copyright (c) 2023, Tri Dao.
Tri Dao's avatar
Tri Dao committed
2

3
import logging
Tri Dao's avatar
Tri Dao committed
4
import math
5
import re
Tri Dao's avatar
Tri Dao committed
6
from collections import OrderedDict, namedtuple
Tri Dao's avatar
Tri Dao committed
7
from collections.abc import Sequence
Tri Dao's avatar
Tri Dao committed
8
from functools import partial
Tri Dao's avatar
Tri Dao committed
9
10
11
12

import torch
import torch.nn as nn
import torch.nn.functional as F
13
from einops import rearrange
Tri Dao's avatar
Tri Dao committed
14
15
from transformers import GPT2Config

Kevin Hu's avatar
Kevin Hu committed
16
from flash_attn.models.bigcode import remap_state_dict_hf_bigcode
Tri Dao's avatar
Tri Dao committed
17
18
19
from flash_attn.models.falcon import remap_state_dict_hf_falcon
from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox
from flash_attn.models.gptj import remap_state_dict_hf_gptj
20
from flash_attn.models.llama import remap_state_dict_hf_llama
Tri Dao's avatar
Tri Dao committed
21
from flash_attn.models.opt import remap_state_dict_hf_opt
Tri Dao's avatar
Tri Dao committed
22
from flash_attn.modules.block import Block, ParallelBlock
23
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
Tri Dao's avatar
Tri Dao committed
24
from flash_attn.modules.mha import MHA, ParallelMHA
Kevin Hu's avatar
Kevin Hu committed
25
26
27
28
29
30
31
32
from flash_attn.modules.mlp import (
    FusedMLP,
    GatedMlp,
    Mlp,
    ParallelFusedMLP,
    ParallelGatedMlp,
    ParallelMLP,
)
Tri Dao's avatar
Tri Dao committed
33
from flash_attn.ops.activations import sqrelu_fwd
Kevin Hu's avatar
Kevin Hu committed
34
from flash_attn.utils.distributed import all_gather_raw, get_dim_for_local_rank, sync_shared_params
Tri Dao's avatar
Tri Dao committed
35
from flash_attn.utils.generation import GenerationMixin
Tri Dao's avatar
Tri Dao committed
36
from flash_attn.utils.pretrained import state_dict_from_pretrained
37
38
39
40
41

try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear
except ImportError:
    ColumnParallelLinear = None
Tri Dao's avatar
Tri Dao committed
42
43
44
45
46
47

try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm
except ImportError:
    dropout_add_layer_norm = None

48
try:
Kevin Hu's avatar
Kevin Hu committed
49
    from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
50
51
52
except ImportError:
    dropout_add_layer_norm_parallel_residual = None

Tri Dao's avatar
Tri Dao committed
53
54
55
try:
    from flash_attn.ops.rms_norm import RMSNorm, dropout_add_rms_norm
except ImportError:
56
    RMSNorm, dropout_add_rms_norm = None, None
Tri Dao's avatar
Tri Dao committed
57
58
59
60
61
62

try:
    from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual
except ImportError:
    dropout_add_rms_norm_parallel_residual = None

Tri Dao's avatar
Tri Dao committed
63
try:
Tri Dao's avatar
Tri Dao committed
64
    from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
Tri Dao's avatar
Tri Dao committed
65
66
67
except ImportError:
    FusedDenseSqreluDense = None

68
69
70
# Module-level logger; GPTPreTrainedModel.from_pretrained reports the result of
# load_state_dict (missing / unexpected keys) through it.
logger = logging.getLogger(__name__)


71
def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
    """Return a ``functools.partial`` of the attention (mixer) class for one layer.

    ``MHA`` is used when ``process_group`` is None, otherwise ``ParallelMHA``.
    Hyper-parameters come from ``config``; optional ones are read with getattr
    defaults. The resulting attention is always causal.

    Args:
        config: model config (GPT2Config-like).
        layer_idx: layer index; required when
            ``config.scale_attn_by_inverse_layer_idx`` is set.
        process_group: tensor-parallel process group, or None for single-device attention.
        device, dtype: forwarded to the attention constructor.

    Raises:
        AssertionError: if tensor parallelism is combined with ``attn_dwconv``, is
            requested without ``fused_bias_fc``, or if inverse-layer-idx scaling is
            requested without ``layer_idx``.
    """
    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)

    # Logit scale: 1/sqrt(head_dim) unless disabled, optionally divided by (layer_idx + 1).
    if config.scale_attn_weights:
        softmax_scale = head_dim ** (-0.5)
    else:
        softmax_scale = 1.0
    if config.scale_attn_by_inverse_layer_idx:
        assert layer_idx is not None
        softmax_scale /= float(layer_idx + 1)

    # Feature combinations the tensor-parallel attention cannot handle.
    dwconv = getattr(config, "attn_dwconv", False)
    if dwconv:
        assert process_group is None, "TensorParallel MHA does not support dwconv yet"
    fused_bias_fc = getattr(config, "fused_bias_fc", False)
    if not fused_bias_fc:
        assert process_group is None, "TensorParallel MHA requires fused_bias_fc"

    # Serial attention takes the fused-bias / dwconv switches; the parallel one takes the
    # process group and the sequence-parallel flag instead.
    if process_group is None:
        mha_cls = MHA
        variant_kwargs = {"fused_bias_fc": fused_bias_fc, "dwconv": dwconv}
    else:
        mha_cls = ParallelMHA
        variant_kwargs = {
            "process_group": process_group,
            "sequence_parallel": getattr(config, "sequence_parallel", True),
        }

    return partial(
        mha_cls,
        num_heads=config.num_attention_heads,
        num_heads_kv=getattr(config, "n_head_kv", None),
        qkv_proj_bias=getattr(config, "qkv_proj_bias", True),
        out_proj_bias=getattr(config, "out_proj_bias", True),
        dropout=config.attn_pdrop,
        softmax_scale=softmax_scale,
        causal=True,
        layer_idx=layer_idx,
        rotary_emb_dim=int(getattr(config, "rotary_emb_fraction", 0.0) * head_dim),
        rotary_emb_base=getattr(config, "rotary_emb_base", 10000.0),
        rotary_emb_scale_base=getattr(config, "rotary_emb_scale_base", None),
        rotary_emb_interleaved=getattr(config, "rotary_emb_interleaved", False),
        use_flash_attn=getattr(config, "use_flash_attn", False),
        device=device,
        dtype=dtype,
        **variant_kwargs,
    )


126
def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
    """Return a ``functools.partial`` of the MLP class for one transformer layer.

    The MLP flavour is chosen from ``config``:

    * neither ``fused_mlp`` nor ``fused_dense_sqrelu_dense``: ``Mlp`` / ``ParallelMLP``,
      or ``GatedMlp`` / ``ParallelGatedMlp`` for "glu", "swiglu" and "geglu";
    * ``fused_mlp``: ``FusedMLP`` / ``ParallelFusedMLP``;
    * ``fused_dense_sqrelu_dense``: ``FusedDenseSqreluDense`` (single device only).

    The ``Parallel*`` classes are used when ``process_group`` is not None.

    Args:
        config: model config (GPT2Config-like); optional attributes are read with
            getattr defaults.
        layer_idx: layer index; required when ``config.mlp_checkpoint_lvl`` is a
            per-layer sequence (fused paths only).
        process_group: tensor-parallel process group, or None.
        device, dtype: forwarded to the MLP constructor.

    Returns:
        A partial that still needs the remaining constructor arguments from the caller.

    Raises:
        AssertionError: on an unsupported activation or option combination, including
            tensor parallelism with ``fused_dense_sqrelu_dense``.
        ImportError: if the requested fused implementation could not be imported.
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    mlp_fc1_bias = getattr(config, "mlp_fc1_bias", True)
    mlp_fc2_bias = getattr(config, "mlp_fc2_bias", True)
    fused_mlp = getattr(config, "fused_mlp", False)
    if fused_mlp:
        assert config.activation_function in [
            "gelu_new",
            "gelu_fast",
            "gelu_approx",
            "gelu_pytorch_tanh",
            "relu",
            "sqrelu",
        ]
    fused_dense_sqrelu_dense = getattr(config, "fused_dense_sqrelu_dense", False)
    if fused_dense_sqrelu_dense:
        assert config.activation_function == "sqrelu", (
            "fused_dense_sqrelu_dense only supports approximate activation_function sqrelu"
        )
    assert not (fused_dense_sqrelu_dense and fused_mlp)

    # Tensor-parallel kwargs shared by every Parallel* MLP class below.
    parallel_kwargs = (
        {
            "process_group": process_group,
            "sequence_parallel": getattr(config, "sequence_parallel", True),
        }
        if process_group is not None
        else {}
    )
    # Activation names that map to the tanh approximation of GELU.
    tanh_gelu_names = ("gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh")

    if not fused_mlp and not fused_dense_sqrelu_dense:
        assert config.activation_function in [
            "gelu",
            "gelu_new",
            "gelu_fast",
            "gelu_approx",
            "gelu_pytorch_tanh",
            "relu",
            "sqrelu",
            "glu",
            "swiglu",
            "geglu",
        ]
        if config.activation_function in ["glu", "swiglu", "geglu"]:
            # Keep F.sigmoid (not torch.sigmoid): the gated MLP may special-case it by identity.
            activation = (
                F.sigmoid
                if config.activation_function == "glu"
                else (F.silu if config.activation_function == "swiglu" else F.gelu)
            )
            mlp_cls = GatedMlp if process_group is None else ParallelGatedMlp
        else:
            if config.activation_function == "relu":
                activation = partial(F.relu, inplace=True)
            elif config.activation_function == "sqrelu":
                activation = sqrelu_fwd
            else:
                approximate = (
                    "tanh" if config.activation_function in tanh_gelu_names else "none"
                )
                activation = partial(F.gelu, approximate=approximate)
            mlp_cls = Mlp if process_group is None else ParallelMLP
        return partial(
            mlp_cls,
            hidden_features=config.n_inner,
            activation=activation,
            bias1=mlp_fc1_bias,
            bias2=mlp_fc2_bias,
            **parallel_kwargs,
            **factory_kwargs,
        )

    # Fused paths: exactly one of fused_mlp / fused_dense_sqrelu_dense is set here.
    mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
    # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
    if isinstance(mlp_checkpoint_lvl, Sequence):
        assert layer_idx is not None
        mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]

    if fused_mlp:
        if FusedMLP is None:
            raise ImportError("fused_dense is not installed")
        activation = (
            "gelu_approx"
            if config.activation_function in tanh_gelu_names
            else config.activation_function
        )
        mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP
        return partial(
            mlp_cls,
            hidden_features=config.n_inner,
            activation=activation,
            checkpoint_lvl=mlp_checkpoint_lvl,
            bias1=mlp_fc1_bias,
            bias2=mlp_fc2_bias,
            **parallel_kwargs,
            **factory_kwargs,
        )

    # fused_dense_sqrelu_dense: FusedDenseSqreluDense has no tensor-parallel variant.
    assert process_group is None, "Tensor Parallel is not implemented for FusedDenseSqreluDense"
    if FusedDenseSqreluDense is None:
        # An explicit raise survives `python -O`, unlike a bare assert.
        raise ImportError(
            "FusedDenseSqreluDense is not available (flash_attn.ops.triton.mlp failed to import)"
        )
    return partial(
        FusedDenseSqreluDense,
        hidden_features=config.n_inner,
        checkpoint_lvl=mlp_checkpoint_lvl,
        **factory_kwargs,
    )


263
def create_block(config, layer_idx=None, process_group=None, device=None, dtype=None):
    """Build one transformer block (attention mixer + MLP + norms) for layer ``layer_idx``.

    A sequential ``Block`` is built by default; ``config.parallel_block`` selects a
    ``ParallelBlock`` (used for GPT-J / GPT-NeoX style models), which requires prenorm.

    Args:
        config: model config (GPT2Config-like).
        layer_idx: index of this layer; layer 0 uses ``config.embd_pdrop`` for its
            first residual dropout, all others use ``config.resid_pdrop``.
        process_group: tensor-parallel process group, or None.
        device, dtype: forwarded to the submodule constructors.

    Returns:
        The constructed block, with ``block.layer_idx`` set.

    Raises:
        ImportError: if ``config.rms_norm`` is set but ``RMSNorm`` failed to import.
        AssertionError: if ``parallel_block`` is requested without ``prenorm``.
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    sequence_parallel = getattr(config, "sequence_parallel", True)
    mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
    mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
    use_rms_norm = getattr(config, "rms_norm", False)
    if use_rms_norm and RMSNorm is None:
        # Without this check, partial(None, ...) fails with an opaque
        # "the first argument must be callable" TypeError.
        raise ImportError("config.rms_norm requires flash_attn.ops.rms_norm, which failed to import")
    norm_cls = partial(
        RMSNorm if use_rms_norm else nn.LayerNorm,
        eps=config.layer_norm_epsilon,
        **factory_kwargs,
    )
    # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
    residual_in_fp32 = getattr(config, "residual_in_fp32", False)
    # The first layer's leading dropout is applied to the embedding output.
    resid_dropout1 = config.resid_pdrop if layer_idx is None or layer_idx > 0 else config.embd_pdrop
    prenorm = getattr(config, "prenorm", True)
    parallel_block = getattr(config, "parallel_block", False)
    # Arguments shared by Block and ParallelBlock.
    common_kwargs = dict(
        norm_cls=norm_cls,
        resid_dropout1=resid_dropout1,
        resid_dropout2=config.resid_pdrop,
        fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
        residual_in_fp32=residual_in_fp32,
        sequence_parallel=sequence_parallel and process_group is not None,
        mark_shared_params=process_group is not None,
    )
    if not parallel_block:
        block = Block(config.hidden_size, mixer_cls, mlp_cls, prenorm=prenorm, **common_kwargs)
    else:
        assert prenorm
        block = ParallelBlock(
            config.hidden_size,
            mixer_cls,
            mlp_cls,
            tied_norm=getattr(config, "parallel_block_tied_norm", False),
            **common_kwargs,
        )
    block.layer_idx = layer_idx
    return block


312
class GPTPreTrainedModel(nn.Module):
    """Abstract base for the GPT models in this file.

    It validates the config, and ``from_pretrained`` builds a model and loads a
    pretrained Hugging Face checkpoint into it. The checkpoint's keys are remapped
    for the model family and the weights are sharded when ``world_size > 1``.
    """

    def __init__(self, config, *inputs, **kwargs):
        super().__init__()
        if not isinstance(config, GPT2Config):
            cls_name = self.__class__.__name__
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(cls_name, cls_name)
            )
        self.config = config

    @staticmethod
    def _remap_hf_state_dict(model_name, state_dict, config):
        """Rename a HF checkpoint's keys to this file's layout, chosen by ``model_name`` prefix.

        Raises:
            NotImplementedError: if ``model_name`` matches none of the known prefixes.
        """
        if model_name.startswith("gpt2"):
            return remap_state_dict_hf_gpt2(state_dict, config)
        if model_name.startswith("facebook/opt"):
            return remap_state_dict_hf_opt(state_dict, config)
        if model_name.startswith("EleutherAI/gpt-j-"):
            return remap_state_dict_hf_gptj(state_dict, config)
        if model_name.startswith("EleutherAI/gpt-neox-"):
            return remap_state_dict_hf_gpt_neox(state_dict, config)
        if model_name.startswith("tiiuae/falcon-"):
            return remap_state_dict_hf_falcon(state_dict, config)
        if model_name.startswith("meta-llama/Llama-"):
            return remap_state_dict_hf_llama(state_dict, config)
        if model_name.startswith(("bigcode/", "WizardLM/")):
            return remap_state_dict_hf_bigcode(state_dict, config)
        raise NotImplementedError(f"Model {model_name} not supported")

    @classmethod
    def from_pretrained(
        cls,
        model_name,
        config,
        *args,
        strict=True,
        device=None,
        dtype=None,
        world_size=1,
        rank=0,
        **kwargs,
    ):
        """Instantiate the model and load pretrained weights for ``model_name`` into it.

        Args:
            model_name: checkpoint identifier; its prefix selects the key remapping.
            config: config passed to ``cls``.
            *args, **kwargs: extra constructor arguments for ``cls``.
            strict: forwarded to ``load_state_dict``.
            device, dtype: forwarded to the constructor; ``dtype`` also to the checkpoint loader.
            world_size, rank: when ``world_size > 1`` the state dict is sharded for ``rank``.

        Returns:
            The model with the weights loaded.

        Raises:
            NotImplementedError: for an unrecognised ``model_name`` prefix.
        """
        model = cls(config, *args, device=device, dtype=dtype, **kwargs)
        # The model is already initialised (possibly on GPU); load the checkpoint on CPU
        # so it does not take extra GPU memory.
        state_dict = state_dict_from_pretrained(model_name, device="cpu", dtype=dtype)
        state_dict = cls._remap_hf_state_dict(model_name, state_dict, config)
        if world_size > 1:
            state_dict = shard_state_dict_tp(state_dict, config, world_size, rank)
        load_result = model.load_state_dict(state_dict, strict=strict)
        logger.info(load_result)
        return model

Tri Dao's avatar
Tri Dao committed
373

Tri Dao's avatar
Tri Dao committed
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True):
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, std=initializer_range)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if rescale_prenorm_residual:
        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer))


396
class GPTModel(GPTPreTrainedModel):
    """GPT-style decoder backbone.

    It applies token/position embeddings, then ``config.num_hidden_layers`` blocks built
    by ``create_block``, then (when ``config.prenorm``) a final dropout + residual
    add + norm (``drop_f`` / ``ln_f``). With a ``process_group``, the tensor-parallel
    embedding and blocks are used.
    """

    def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
        super().__init__(config)
        factory_kwargs = {"device": device, "dtype": dtype}
        self.process_group = process_group
        self.sequence_parallel = getattr(config, "sequence_parallel", True)
        assert config.activation_function in [
            "gelu",
            "gelu_new",
            "gelu_fast",
            "gelu_approx",
            "gelu_pytorch_tanh",
            "relu",
            "sqrelu",
            "glu",
            "swiglu",
            "geglu",
        ]
        # Round the vocabulary size up to a multiple of pad_vocab_size_multiple.
        pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
        vocab_size = (
            math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
        )
        # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
        self.residual_in_fp32 = getattr(config, "residual_in_fp32", False)
        # These 2 options are for OPT-350m
        self.prenorm = getattr(config, "prenorm", True)
        use_rms_norm = getattr(config, "rms_norm", False)
        word_embed_proj_dim = getattr(config, "word_embed_proj_dim", None)
        # For GPT-J, GPT-NeoX
        self.parallel_block = getattr(config, "parallel_block", False)
        if use_rms_norm and RMSNorm is None:
            raise ImportError("config.rms_norm requires flash_attn.ops.rms_norm, which failed to import")

        if process_group is None:
            self.embeddings = GPT2Embeddings(
                config.hidden_size,
                vocab_size,
                config.max_position_embeddings,
                word_embed_proj_dim=word_embed_proj_dim,
                **factory_kwargs,
            )
        else:
            self.embeddings = ParallelGPT2Embeddings(
                config.hidden_size,
                vocab_size,
                config.max_position_embeddings,
                process_group=process_group,
                sequence_parallel=self.sequence_parallel,
                **factory_kwargs,
            )

        # We change the order of dropout, residual and layer norm:
        # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
        # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and
        # the main branch (output of MLP). The model definition is unchanged, but the mapping of the
        # nn.Dropout probabilities are changed.
        # This is for performance reason: we can fuse dropout + add + layer_norm.
        self.layers = nn.ModuleList(
            [
                create_block(config, layer_idx=i, process_group=process_group, **factory_kwargs)
                for i in range(config.num_hidden_layers)
            ]
        )

        self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
        if self.fused_dropout_add_ln:
            if (not self.parallel_block and dropout_add_layer_norm is None) or (
                self.parallel_block and dropout_add_layer_norm_parallel_residual is None
            ):
                raise ImportError("dropout_layer_norm is not installed")

        if self.prenorm:
            self.drop_f = nn.Dropout(config.resid_pdrop)
            norm_cls = RMSNorm if use_rms_norm else nn.LayerNorm
            self.ln_f = norm_cls(
                config.hidden_size, eps=config.layer_norm_epsilon, **factory_kwargs
            )
            # ln_f only exists with prenorm, so the tensor-parallel marking must be nested
            # here; otherwise prenorm=False with a process_group raised AttributeError.
            if process_group is not None:
                for p in self.ln_f.parameters():
                    # Mark the norm parameters as "shared_params" so that we sync their values at init.
                    p._shared_params = True
                    # Mark the norm params as "sequence_parallel" so we run all-reduce on their grads.
                    if self.sequence_parallel:
                        p._sequence_parallel = True

        self.apply(
            partial(
                _init_weights,
                n_layer=config.num_hidden_layers,
                initializer_range=config.initializer_range,
            )
        )
        self.tie_weights()

    def tie_weights(self):
        """Synchronise parameters marked as shared across ``process_group`` (no-op without TP)."""
        if self.process_group is not None:
            sync_shared_params(self, self.process_group)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Return ``{layer_index: cache}`` built by each layer's own ``allocate_inference_cache``."""
        return {
            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
            for i, layer in enumerate(self.layers)
        }

    def forward(self, input_ids, position_ids=None, inference_params=None):
        """Embed ``input_ids``, run every layer, then apply the final norm when ``prenorm``.

        Args:
            input_ids: token ids; dim 1 is used as the sequence length
                (presumably shaped (batch, seqlen) — TODO confirm).
            position_ids: optional, forwarded to the embeddings.
            inference_params: optional, passed to every layer through ``mixer_kwargs``.

        Returns:
            The final hidden states.
        """
        # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen
        # dimensions so that we can split on it easily, in case of small batch size.
        # Only the attention layers need to know the seqlen.
        combine_dims = self.process_group is not None and self.sequence_parallel
        embedding_kwargs = {"combine_batch_seqlen_dim": True} if combine_dims else {}
        hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs)
        if self.parallel_block:
            hidden_states2 = None
        residual = None
        mixer_kwargs = {"seqlen": input_ids.shape[1]} if combine_dims else {}
        if inference_params is not None:
            mixer_kwargs["inference_params"] = inference_params
        for layer in self.layers:
            if self.prenorm:
                if not self.parallel_block:
                    hidden_states, residual = layer(
                        hidden_states, residual, mixer_kwargs=mixer_kwargs
                    )
                else:
                    hidden_states, hidden_states2, residual = layer(
                        hidden_states, hidden_states2, residual, mixer_kwargs=mixer_kwargs
                    )
            else:
                hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
        if self.prenorm:
            if not self.fused_dropout_add_ln:
                dropped = self.drop_f(hidden_states)
                if not self.parallel_block:
                    residual = (dropped + residual) if residual is not None else dropped
                else:
                    dropped2 = self.drop_f(hidden_states2)
                    residual = (
                        (residual + dropped + dropped2)
                        if residual is not None
                        else dropped + dropped2
                    )
                hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
            else:
                # RMSNorm is None when flash_attn.ops.rms_norm failed to import, and
                # isinstance(x, None) raises TypeError, so guard the check.
                use_rms = RMSNorm is not None and isinstance(self.ln_f, RMSNorm)
                dropout_p = self.drop_f.p if self.training else 0.0
                # Set prenorm=False here since we don't need the residual
                if not self.parallel_block:
                    fused_add_norm_fn = dropout_add_rms_norm if use_rms else dropout_add_layer_norm
                    hidden_states = fused_add_norm_fn(
                        hidden_states,
                        residual,
                        self.ln_f.weight,
                        self.ln_f.bias,
                        dropout_p,
                        self.ln_f.eps,
                        prenorm=False,
                        residual_in_fp32=self.residual_in_fp32,
                    )
                else:
                    fused_add_norm_fn = (
                        dropout_add_rms_norm_parallel_residual
                        if use_rms
                        else dropout_add_layer_norm_parallel_residual
                    )
                    hidden_states, _ = fused_add_norm_fn(
                        hidden_states,
                        hidden_states2,
                        residual,
                        self.ln_f.weight,
                        self.ln_f.bias,
                        None,
                        None,
                        dropout_p,
                        self.ln_f.eps,
                        prenorm=False,
                        residual_in_fp32=self.residual_in_fp32,
                    )
        return hidden_states


Tri Dao's avatar
Tri Dao committed
582
class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
    """GPT backbone (``GPTModel``) followed by a language-modeling head.

    The head is a plain ``nn.Linear`` when ``process_group`` is None, otherwise a
    ``ColumnParallelLinear`` that splits the (padded) vocabulary across ranks.
    """

    # Return type of ``forward``; created once instead of on every forward call.
    _CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])

    def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(config)
        self.process_group = process_group
        self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs)
        self.tie_word_embeddings = getattr(config, "tie_word_embeddings", True)
        lm_head_bias = getattr(config, "lm_head_bias", False)
        pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
        # Round the vocabulary up to a multiple of pad_vocab_size_multiple.
        vocab_size = (
            math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
        )
        # This option is for OPT-350m
        word_embed_proj_dim = getattr(config, "word_embed_proj_dim", None)
        embed_dim = config.n_embd if word_embed_proj_dim is None else word_embed_proj_dim
        if word_embed_proj_dim is not None:
            self.project_out = nn.Linear(config.n_embd, embed_dim, bias=False, **factory_kwargs)
        else:
            self.project_out = None
        if process_group is None:
            self.lm_head = nn.Linear(embed_dim, vocab_size, bias=lm_head_bias, **factory_kwargs)
        else:
            if ColumnParallelLinear is None:
                raise ImportError("fused_dense_lib is not installed")
            self.lm_head = ColumnParallelLinear(
                embed_dim,
                vocab_size,
                process_group,
                bias=lm_head_bias,
                sequence_parallel=getattr(config, "sequence_parallel", True),
                **factory_kwargs,
            )
        # Initialize weights and apply final processing
        self.apply(
            partial(
                _init_weights,
                n_layer=config.num_hidden_layers,
                initializer_range=config.initializer_range,
            )
        )
        self.tie_weights()

    def tie_weights(self):
        """Share the LM-head weight with the input word embeddings (if configured),
        then sync shared parameters across the tensor-parallel group."""
        if self.tie_word_embeddings:
            self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight
        if self.process_group is not None:
            sync_shared_params(self, self.process_group)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Delegate inference-cache allocation to the backbone."""
        return self.transformer.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

    def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
        """
        input_ids: (batch, seqlen) int tensor
        inference_params: for generation. Adapted from Megatron-LM (and Apex)
        https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        num_last_tokens: if > 0, only return the logits for the last n tokens

        Returns a ``CausalLMOutput`` namedtuple with a single ``logits`` field.
        """
        assert (
            input_ids.ndim == 2
        ), f"Expected `input_ids` to have shape [b, slen], but got shape {input_ids.shape}"
        b = input_ids.shape[0]
        hidden_states = self.transformer(
            input_ids, position_ids=position_ids, inference_params=inference_params
        )
        if inference_params is not None:
            assert hidden_states.ndim == 3, "sequence_parallel is not supported in generation mode"
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        if self.project_out is not None:
            hidden_states = self.project_out(hidden_states)
        lm_logits = self.lm_head(hidden_states)
        # During inference, we want the full logit for sampling
        if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
            lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
            # all_gather_raw stacks the per-rank shards along dim 0; move the rank axis
            # back into the vocabulary (last) dimension.
            lm_logits = rearrange(lm_logits, "(n b) ... d -> b ... (n d)", b=b)
        return self._CausalLMOutput(logits=lm_logits)

    def load_state_dict(self, state_dict, strict=True, **kwargs):
        """Load a state_dict, remapping checkpoints that use the legacy LayerNorm layout.

        Remapping from our checkpoints that used a different ordering of layers in the block
        Previous: Attn / MLP -> Dropout -> Add -> LN
        Current: Dropout -> Add -> LN -> Attn / MLP

        The caller's ``state_dict`` is not modified. Extra keyword arguments (e.g. ``assign``
        on recent PyTorch versions) are forwarded to the parent ``load_state_dict``.
        """
        if "transformer.ln_0.weight" in state_dict:
            # Remap a shallow copy so the caller's dict is left untouched; keep the
            # `_metadata` attribute that PyTorch state_dicts may carry.
            metadata = getattr(state_dict, "_metadata", None)
            state_dict = OrderedDict(state_dict)
            if metadata is not None:
                state_dict._metadata = metadata
            n_layers = len(self.transformer.layers)
            # Old norm2 of the last layer (after its MLP) becomes the final LayerNorm.
            ln_weight = state_dict.pop(f"transformer.layers.{n_layers - 1}.norm2.weight")
            ln_bias = state_dict.pop(f"transformer.layers.{n_layers - 1}.norm2.bias")
            state_dict["transformer.ln_f.weight"] = ln_weight
            state_dict["transformer.ln_f.bias"] = ln_bias
            # Walk backwards so every old key is popped before its new name is written.
            for l in reversed(range(n_layers)):
                # Old norm1 (after attention) -> new norm2 (before MLP) of the same layer.
                ln_weight = state_dict.pop(f"transformer.layers.{l}.norm1.weight")
                ln_bias = state_dict.pop(f"transformer.layers.{l}.norm1.bias")
                state_dict[f"transformer.layers.{l}.norm2.weight"] = ln_weight
                state_dict[f"transformer.layers.{l}.norm2.bias"] = ln_bias
                if l > 0:
                    # Old norm2 of the previous layer -> new norm1 of this layer.
                    ln_weight = state_dict.pop(f"transformer.layers.{l - 1}.norm2.weight")
                    ln_bias = state_dict.pop(f"transformer.layers.{l - 1}.norm2.bias")
                    state_dict[f"transformer.layers.{l}.norm1.weight"] = ln_weight
                    state_dict[f"transformer.layers.{l}.norm1.bias"] = ln_bias
            # The old standalone ln_0 becomes norm1 of the first layer.
            ln_weight = state_dict.pop("transformer.ln_0.weight")
            ln_bias = state_dict.pop("transformer.ln_0.bias")
            state_dict["transformer.layers.0.norm1.weight"] = ln_weight
            state_dict["transformer.layers.0.norm1.bias"] = ln_bias
        return super().load_state_dict(state_dict, strict=strict, **kwargs)


def shard_state_dict_tp(state_dict, config, world_size, rank):
    """Convert the state_dict of a standard GPT model to the state_dict of a GPT model
    with tensor parallel.

    Args:
        state_dict: full (unsharded) state_dict in this module's key layout.
        config: model config; reads vocab_size, pad_vocab_size_multiple, hidden_size,
            n_inner, n_head, n_head_kv (optional), num_hidden_layers, activation_function.
        world_size: number of tensor-parallel ranks.
        rank: index of the rank whose shard is produced.

    Returns:
        The same dict object, now holding only this rank's shard.

    This function modifies state_dict in place.
    """
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    assert vocab_size % world_size == 0
    assert config.hidden_size % world_size == 0
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    assert inner_dim % world_size == 0

    n_head = config.n_head
    n_head_kv = getattr(config, "n_head_kv", n_head)

    embed_dim = config.hidden_size
    head_dim = embed_dim // n_head

    def shard_first_dim(state_dict, key):
        # Even split of dim 0; no-op if the key is absent.
        if key in state_dict:
            x = state_dict[key]
            dim = x.shape[0] // world_size
            state_dict[key] = x[rank * dim : (rank + 1) * dim]

    def shard_last_dim(state_dict, key, multiple_of=1):
        # Split of the last dim in chunks of `multiple_of` (possibly uneven across ranks).
        if key in state_dict:
            x = state_dict[key]
            dim_each_rank = [
                get_dim_for_local_rank(x.size(-1), world_size, local_rank, multiple_of)
                for local_rank in range(world_size)
            ]
            beg, end = tuple(sum(dim_each_rank[:pos]) for pos in (rank, rank + 1))
            state_dict[key] = x[..., beg:end]

    def shard_gatedmlp_fc1_dim(state_dict, key):
        # Gated fc1 stacks two projections along dim 0; shard each half separately.
        if key in state_dict:
            x = state_dict[key]
            dim = x.shape[0] // world_size // 2
            state_dict[key] = rearrange(
                rearrange(x, "(two o) ... -> two o ...", two=2)[:, rank * dim : (rank + 1) * dim],
                "two o ... -> (two o) ...",
            )

    def shard_qkv_headdim(state_dict, key):
        # Shard the fused Wqkv by attention heads (supports n_head_kv != n_head).
        if key in state_dict:
            n_head_each_rank = [
                get_dim_for_local_rank(n_head, world_size, local_rank)
                for local_rank in range(world_size)
            ]
            n_head_kv_each_rank = [
                get_dim_for_local_rank(n_head_kv, world_size, local_rank)
                for local_rank in range(world_size)
            ]

            beg_n_head = sum(n_head_each_rank[:rank])
            end_n_head = sum(n_head_each_rank[: rank + 1])

            beg_n_head_kv = sum(n_head_kv_each_rank[:rank])
            end_n_head_kv = sum(n_head_kv_each_rank[: rank + 1])

            if n_head_kv == n_head:
                x = rearrange(state_dict[key], "(three d) ... -> three d ...", three=3)
                state_dict[key] = rearrange(
                    x[:, beg_n_head * head_dim : end_n_head * head_dim],
                    "three d ... -> (three d) ...",
                )
            else:
                x = rearrange(
                    state_dict[key],
                    "(nheadqkv headdim) ... -> nheadqkv headdim ...",
                    nheadqkv=n_head + 2 * n_head_kv,
                )
                k_start = n_head
                v_start = n_head + n_head_kv
                state_dict[key] = rearrange(
                    torch.cat(
                        [
                            x[beg_n_head:end_n_head],
                            x[k_start + beg_n_head_kv : k_start + end_n_head_kv],
                            x[v_start + beg_n_head_kv : v_start + end_n_head_kv],
                        ],
                        dim=0,
                    ),
                    "nheadqkv headdim ... -> (nheadqkv headdim) ...",
                )

    shard_first_dim(state_dict, "transformer.embeddings.word_embeddings.weight")
    # The LM head is column-parallel over the vocabulary: both its weight rows and its
    # bias (when the config enables lm_head_bias) must be split along dim 0.
    shard_first_dim(state_dict, "lm_head.weight")
    shard_first_dim(state_dict, "lm_head.bias")
    shard_last_dim(state_dict, "transformer.embeddings.position_embeddings.weight")
    for i in range(config.num_hidden_layers):
        shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.weight")
        shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.bias")
        shard_last_dim(
            state_dict, f"transformer.layers.{i}.mixer.out_proj.weight", multiple_of=head_dim
        )
        # Row-parallel output biases are only kept on rank 0 (added once after reduction).
        if rank != 0:
            state_dict.pop(f"transformer.layers.{i}.mixer.out_proj.bias", None)
        if config.activation_function in ["glu", "swiglu", "geglu"]:
            shard_gatedmlp_fc1_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.weight")
            shard_gatedmlp_fc1_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.bias")
        else:
            shard_first_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.weight")
            shard_first_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.bias")
        shard_last_dim(state_dict, f"transformer.layers.{i}.mlp.fc2.weight")
        if rank != 0:
            state_dict.pop(f"transformer.layers.{i}.mlp.fc2.bias", None)
    return state_dict


def combine_state_dicts_tp(state_dicts: list[dict[str, torch.Tensor]], config: GPT2Config):
    """Convert the list of sharded state_dict of a GPT model with tensor parallel to
    the state_dict of a standard GPT model.
809
810

    This function is meant to be the "reverse" of shard_state_dict_tp.
811
812
813

    Precondition:
        - state_dicts should be ordered in the same way as the shards were created.
Tri Dao's avatar
Tri Dao committed
814
815
816
    """
    world_size = len(state_dicts)
    keys = state_dicts[0].keys()
Tri Dao's avatar
Tri Dao committed
817
818
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
Tri Dao's avatar
Tri Dao committed
819
820
821
822
    assert vocab_size % world_size == 0
    assert config.hidden_size % world_size == 0
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    assert inner_dim % world_size == 0
823
824
    assert config.hidden_size % config.n_head == 0
    headdim = config.hidden_size // config.n_head
Tri Dao's avatar
Tri Dao committed
825

Tri Dao's avatar
Tri Dao committed
826
    # Sometimes the word embeddings are sharded on the 0th dim, sometimes on the 1st dim.
Tri Dao's avatar
Tri Dao committed
827
828
    # vocab_size // world_size coordinates are nonzero.
    def combine_word_embeddings(state_dicts, state_dict, key):
Tri Dao's avatar
Tri Dao committed
829
830
        dim = 0 if state_dicts[0][key].shape[0] == vocab_size // world_size else 1
        state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)
Tri Dao's avatar
Tri Dao committed
831
832

    def combine_dim(state_dicts, state_dict, key, dim=-1):
Tri Dao's avatar
Tri Dao committed
833
834
        if key in state_dict:
            state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)
Tri Dao's avatar
Tri Dao committed
835
836

    def combine_qkv_headdim(state_dicts, state_dict, key):
Tri Dao's avatar
Tri Dao committed
837
        n_head = config.n_head
Tri Dao's avatar
Tri Dao committed
838
        n_head_kv = getattr(config, "n_head_kv", n_head)
Tri Dao's avatar
Tri Dao committed
839
        if key in state_dict:
Tri Dao's avatar
Tri Dao committed
840
            if n_head_kv == n_head:
Tri Dao's avatar
Tri Dao committed
841
842
843
844
                xs = [
                    rearrange(s[key], "(three d) ... -> three d ...", three=3) for s in state_dicts
                ]
                state_dict[key] = rearrange(torch.cat(xs, dim=1), "three d ... -> (three d) ...")
Tri Dao's avatar
Tri Dao committed
845
            else:
846
847
848
849
850
851
852
853
                n_head_each_rank = [
                    get_dim_for_local_rank(n_head, world_size, local_rank)
                    for local_rank in range(world_size)
                ]
                n_head_kv_each_rank = [
                    get_dim_for_local_rank(n_head_kv, world_size, local_rank)
                    for local_rank in range(world_size)
                ]
854
855
856
857
858
859
860
                xs = [
                    rearrange(
                        s[key],
                        "(nheadqkv headdim) ... -> nheadqkv headdim ...",
                        nheadqkv=rank_n_head + 2 * rank_n_head_kv,
                        headdim=headdim,
                    )
Kevin Hu's avatar
Kevin Hu committed
861
862
863
                    for s, rank_n_head, rank_n_head_kv in zip(
                        state_dicts, n_head_each_rank, n_head_kv_each_rank
                    )
864
                ]
Kevin Hu's avatar
Kevin Hu committed
865
                wq = torch.cat([x[: n_head_each_rank[rank]] for rank, x in enumerate(xs)], dim=0)
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
                wk = torch.cat(
                    [
                        x[
                            n_head_each_rank[rank] : n_head_each_rank[rank]
                            + n_head_kv_each_rank[rank]
                        ]
                        for rank, x in enumerate(xs)
                    ],
                    dim=0,
                )
                wv = torch.cat(
                    [
                        x[n_head_each_rank[rank] + n_head_kv_each_rank[rank] :]
                        for rank, x in enumerate(xs)
                    ],
                    dim=0,
                )
                wqkv = torch.cat(
                    [wq, wk, wv],
                    dim=0,
                )
Tri Dao's avatar
Tri Dao committed
887
                state_dict[key] = rearrange(
888
                    wqkv,
Tri Dao's avatar
Tri Dao committed
889
890
                    "nheadqkv headdim ... -> (nheadqkv headdim) ...",
                )
Tri Dao's avatar
Tri Dao committed
891
892
893

    def combine_gated_mlp(state_dicts, state_dict, key):
        if key in state_dict:
Tri Dao's avatar
Tri Dao committed
894
895
            xs = [rearrange(s[key], "(two d) ... -> two d ...", two=2) for s in state_dicts]
            state_dict[key] = rearrange(torch.cat(xs, dim=1), "two d ... -> (two d) ...")
Tri Dao's avatar
Tri Dao committed
896
897

    state_dict = state_dicts[0].copy()  # don't modify state_dict[0] inplace
Tri Dao's avatar
Tri Dao committed
898
899
900
901
902
903
904
905
906
907
908
909
910
911
    combine_word_embeddings(
        state_dicts, state_dict, "transformer.embeddings.word_embeddings.weight"
    )
    if "lm_head.weight" in state_dict:
        combine_word_embeddings(state_dicts, state_dict, "lm_head.weight")
    if "transformer.embeddings.position_embeddings.weight" in state_dict:
        combine_dim(
            state_dicts, state_dict, "transformer.embeddings.position_embeddings.weight", -1
        )
    mlp_combine_fn = (
        combine_gated_mlp
        if config.activation_function in ["glu", "swiglu", "geglu"]
        else partial(combine_dim, dim=0)
    )
Tri Dao's avatar
Tri Dao committed
912
    for i in range(config.num_hidden_layers):
Tri Dao's avatar
Tri Dao committed
913
914
915
916
917
918
        combine_qkv_headdim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.Wqkv.weight")
        combine_qkv_headdim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.Wqkv.bias")
        combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.out_proj.weight", -1)
        mlp_combine_fn(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc1.weight")
        combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc1.bias", 0)
        combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc2.weight", -1)
Tri Dao's avatar
Tri Dao committed
919
920
921
922
    return state_dict


def remap_state_dict_hf_gpt2(state_dict, config):
    """Remap a Hugging Face GPT-2 state_dict to this module's GPT key layout.

    The attention and MLP weight matrices are transposed (``.t()``) because the HF
    GPT-2 checkpoint stores them as (in_features, out_features). The word embedding is
    zero-padded to the padded vocabulary size and shared with ``lm_head.weight``.
    Non-parameter attention buffers (``attn.bias`` and, if present, ``attn.masked_bias``)
    are dropped; checkpoints that do not store them are accepted.

    The input dict is not modified; a new OrderedDict is returned.
    """
    # Word embedding and position embedding
    def key_mapping_pos_emb(key):
        return re.sub(r"^wpe\.", "transformer.embeddings.position_embeddings.", key)

    state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("wte.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^ln_f\.(weight|bias)", r"transformer.ln_f.\1", key)
        key = re.sub(r"^h\.(\d+)\.ln_(1|2)\.(weight|bias)", r"transformer.layers.\1.norm\2.\3", key)
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    for d in range(config.num_hidden_layers):
        W1 = state_dict.pop(f"h.{d}.mlp.c_fc.weight")
        state_dict[f"transformer.layers.{d}.mlp.fc1.weight"] = W1.t()
        W2 = state_dict.pop(f"h.{d}.mlp.c_proj.weight")
        state_dict[f"transformer.layers.{d}.mlp.fc2.weight"] = W2.t()

    def key_mapping_mlp(key):
        key = re.sub(r"^h\.(\d+)\.mlp\.c_fc\.bias", r"transformer.layers.\1.mlp.fc1.bias", key)
        key = re.sub(r"^h\.(\d+)\.mlp\.c_proj\.bias", r"transformer.layers.\1.mlp.fc2.bias", key)
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for d in range(config.num_hidden_layers):
        # We don't store the attention-mask buffers. Not every checkpoint contains them
        # (non-persistent buffers are not saved), so pop with a default.
        state_dict.pop(f"h.{d}.attn.bias", None)
        state_dict.pop(f"h.{d}.attn.masked_bias", None)
        Wqkv = state_dict.pop(f"h.{d}.attn.c_attn.weight")
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = Wqkv.t()
        Wout = state_dict.pop(f"h.{d}.attn.c_proj.weight")
        state_dict[f"transformer.layers.{d}.mixer.out_proj.weight"] = Wout.t()

    def key_mapping_attn(key):
        key = re.sub(
            r"^h\.(\d+)\.attn\.c_attn\.bias", r"transformer.layers.\1.mixer.Wqkv.bias", key
        )
        key = re.sub(
            r"^h\.(\d+)\.attn\.c_proj\.bias", r"transformer.layers.\1.mixer.out_proj.bias", key
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict


def remap_state_dict_megatron(state_dict, config):
    """Remap a flattened Megatron-LM GPT state_dict to this module's GPT key layout.

    - ``language_model.encoder.*`` / ``language_model.*`` prefixes become ``transformer.*``.
    - Word and (if present) learned position embeddings are moved under
      ``transformer.embeddings``; the word embedding is zero-padded to a multiple of
      ``pad_vocab_size_multiple`` and shared with ``lm_head.weight``.
    - LayerNorm, MLP and attention keys are renamed, and the fused QKV weight/bias are
      reordered from Megatron's per-head interleaving to ours (see comment below).

    The input dict is not modified; a new OrderedDict is returned.
    """
    def key_mapping_transformer(key):
        key = re.sub(r"^language_model\.encoder\.", "transformer.", key)
        key = re.sub(r"^language_model\.", "transformer.", key)
        return key

    state_dict = OrderedDict((key_mapping_transformer(k), v) for k, v in state_dict.items())

    # Word embedding and position embedding
    def key_mapping_pos_emb(key):
        # Megatron's "language_model.embedding.position_embeddings" was renamed to
        # "transformer.embedding.position_embeddings" above; "wpe." is kept for backward
        # compatibility with the previous mapping.
        key = re.sub(
            r"^transformer\.embedding\.position_embeddings\.",
            "transformer.embeddings.position_embeddings.",
            key,
        )
        return re.sub(r"^wpe\.", "transformer.embeddings.position_embeddings.", key)

    state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embedding.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = (
        math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
    )
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer\.final_layernorm\.(weight|bias)", r"transformer.ln_f.\1", key)
        key = re.sub(
            r"^transformer\.layers\.(\d+)\.input_layernorm\.(weight|bias)",
            r"transformer.layers.\1.norm1.\2",
            key,
        )
        key = re.sub(
            r"^transformer\.layers\.(\d+)\.post_attention_layernorm\.(weight|bias)",
            r"transformer.layers.\1.norm2.\2",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        key = re.sub(
            r"^transformer\.layers\.(\d+)\.mlp\.dense_h_to_4h\.(weight|bias)",
            r"transformer.layers.\1.mlp.fc1.\2",
            key,
        )
        key = re.sub(
            r"^transformer\.layers\.(\d+)\.mlp\.dense_4h_to_h\.(weight|bias)",
            r"transformer.layers.\1.mlp.fc2.\2",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    def key_mapping_attn(key):
        key = re.sub(
            r"^transformer\.layers\.(\d+)\.self_attention\.rotary_emb\.inv_freq",
            r"transformer.layers.\1.mixer.rotary_emb.inv_freq",
            key,
        )
        key = re.sub(
            r"^transformer\.layers\.(\d+)\.self_attention\.query_key_value\.(weight|bias)",
            r"transformer.layers.\1.mixer.Wqkv.\2",
            key,
        )
        key = re.sub(
            r"^transformer\.layers\.(\d+)\.self_attention\.dense\.(weight|bias)",
            r"transformer.layers.\1.mixer.out_proj.\2",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
    # Megatron stores Wqkv as ((nheads 3 headdim), hidden_dim)
    # while we store Wqkv as ((3 nheads headdim), hidden_dim)
    headdim = config.hidden_size // config.num_attention_heads
    for d in range(config.num_hidden_layers):
        Wqkv = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.weight")
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = rearrange(
            Wqkv,
            "(nheads three headdim) ... -> (three nheads headdim) ...",
            three=3,
            headdim=headdim,
        )
        # The QKV bias is optional (bias-less checkpoints simply do not have it).
        bias_key = f"transformer.layers.{d}.mixer.Wqkv.bias"
        if bias_key in state_dict:
            bqkv = state_dict.pop(bias_key)
            state_dict[bias_key] = rearrange(
                bqkv, "(nheads three headdim) -> (three nheads headdim)", three=3, headdim=headdim
            )

    return state_dict