# Copyright (c) 2023, Tri Dao.

import logging
import math
import re
from functools import partial

from collections import namedtuple, OrderedDict
from collections.abc import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import GPT2Config

from einops import rearrange

from flash_attn.modules.mha import MHA, ParallelMHA
from flash_attn.modules.mlp import Mlp, FusedMLP, ParallelFusedMLP
from flash_attn.modules.block import Block
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
from flash_attn.utils.distributed import sync_shared_params, all_gather_raw
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.generation import GenerationMixin
from flash_attn.models.opt import remap_state_dict_opt

try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear
except ImportError:
    ColumnParallelLinear = None

try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm
except ImportError:
    dropout_add_layer_norm = None

try:
    from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
except ImportError:
    FusedDenseSqreluDense = None


logger = logging.getLogger(__name__)


def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
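    """Build a partial constructor for the attention mixer: MHA on a single device,
    or ParallelMHA when a tensor-parallel process_group is given."""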
    factory_kwargs = {'device': device, 'dtype': dtype}
    head_dim = getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads)
    softmax_scale = 1.0 if not config.scale_attn_weights else head_dim ** (-0.5)
    if config.scale_attn_by_inverse_layer_idx:
        assert layer_idx is not None
        softmax_scale /= float(layer_idx + 1)
    dwconv = getattr(config, 'attn_dwconv', False)
    if dwconv:
        assert process_group is None, 'TensorParallel MHA does not support dwconv yet'
    rotary_emb_dim = int(getattr(config, 'rotary_emb_fraction', 0.0) * head_dim)
    rotary_emb_scale_base = getattr(config, 'rotary_emb_scale_base', 0)
    use_flash_attn = getattr(config, 'use_flash_attn', False)
    fused_bias_fc = getattr(config, 'fused_bias_fc', False)
    if not fused_bias_fc:
        assert process_group is None, 'TensorParallel MHA requires fused_bias_fc'
    mha_cls = MHA if process_group is None else ParallelMHA
    serial_kwargs = ({'fused_bias_fc': fused_bias_fc, 'dwconv': dwconv}
                     if process_group is None else {})
    parallel_kwargs = ({'process_group': process_group,
                        'sequence_parallel': getattr(config, 'sequence_parallel', True)}
                       if process_group is not None else {})
    mixer_cls = partial(mha_cls, num_heads=config.num_attention_heads, dropout=config.attn_pdrop,
                        softmax_scale=softmax_scale, causal=True, layer_idx=layer_idx,
                        rotary_emb_dim=rotary_emb_dim, rotary_emb_scale_base=rotary_emb_scale_base,
                        use_flash_attn=use_flash_attn,
                        **serial_kwargs, **parallel_kwargs, **factory_kwargs)
    return mixer_cls


def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
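    """Build a partial constructor for the MLP: plain Mlp, FusedMLP/ParallelFusedMLP,
    or FusedDenseSqreluDense, depending on the config flags."""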
    factory_kwargs = {'device': device, 'dtype': dtype}
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    fused_mlp = getattr(config, 'fused_mlp', False)
    if fused_mlp:
        assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu']
    fused_dense_sqrelu_dense = getattr(config, 'fused_dense_sqrelu_dense', False)
    if fused_dense_sqrelu_dense:
        assert config.activation_function == 'sqrelu', ('fused_dense_sqrelu_dense only '
                                                        'supports activation_function sqrelu')
    assert not (fused_dense_sqrelu_dense and fused_mlp)
    if process_group is not None:
        assert fused_mlp, 'Tensor Parallel is only implemented for FusedMLP'
    if not fused_mlp and not fused_dense_sqrelu_dense:
        if config.activation_function == 'relu':
            activation = partial(F.relu, inplace=True)
        else:
            approximate = ('tanh' if config.activation_function
                           in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'none')
            activation = partial(F.gelu, approximate=approximate)
        mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=activation, **factory_kwargs)
    else:
        mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
        # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
        if isinstance(mlp_checkpoint_lvl, Sequence):
            assert layer_idx is not None
            mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
        if fused_mlp:
            if FusedMLP is None:
                raise ImportError('fused_dense is not installed')
            activation = ('gelu_approx' if config.activation_function
                          in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'relu')
            mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP
            parallel_kwargs = ({'process_group': process_group,
                                'sequence_parallel': getattr(config, 'sequence_parallel', True)}
                               if process_group is not None else {})
            mlp_cls = partial(mlp_cls, hidden_features=inner_dim, activation=activation,
                              checkpoint_lvl=mlp_checkpoint_lvl,
                              **parallel_kwargs, **factory_kwargs)
        elif fused_dense_sqrelu_dense:
            assert FusedDenseSqreluDense is not None
            mlp_cls = partial(FusedDenseSqreluDense, hidden_features=inner_dim,
                              checkpoint_lvl=mlp_checkpoint_lvl, **factory_kwargs)
        else:
            raise RuntimeError('MLP type not supported')
    return mlp_cls


def create_block(config, layer_idx=None, process_group=None, device=None, dtype=None):
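    """Assemble one Transformer block (mixer + MLP + layer norms) for layer `layer_idx`."""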
    factory_kwargs = {'device': device, 'dtype': dtype}
    sequence_parallel = getattr(config, 'sequence_parallel', True)
    mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
    mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
    norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_epsilon, **factory_kwargs)
    # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
    residual_in_fp32 = getattr(config, 'residual_in_fp32', False)
    resid_dropout1 = config.resid_pdrop if layer_idx is None or layer_idx > 0 else config.embd_pdrop
    prenorm = getattr(config, 'prenorm', True)
    block = Block(config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
                  prenorm=prenorm, resid_dropout1=resid_dropout1, resid_dropout2=config.resid_pdrop,
                  fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
                  residual_in_fp32=residual_in_fp32,
                  sequence_parallel=sequence_parallel and process_group is not None,
                  mark_shared_params=process_group is not None)
    block.layer_idx = layer_idx
    return block


class GPTPreTrainedModel(nn.Module):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    def __init__(self, config, *inputs, **kwargs):
        super().__init__()
        if not isinstance(config, GPT2Config):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                ))
        self.config = config

    @classmethod
    def from_pretrained(cls, model_name, config, *args, strict=True, device=None, dtype=None,
                        world_size=1, rank=0, **kwargs):
        """
        Instantiate a GPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.
        """
        # Instantiate model.
        model = cls(config, *args, device=device, dtype=dtype, **kwargs)
        # If we're going to shard the model, then don't load fp32 weights to GPU.
        state_dict = state_dict_from_pretrained(
            model_name, device=device if world_size == 1 else None, dtype=dtype
        )
        if model_name.startswith('gpt2'):
            state_dict = remap_state_dict_gpt2(state_dict, config)
        elif model_name.startswith('facebook/opt'):
            state_dict = remap_state_dict_opt(state_dict, config)
        else:
            raise NotImplementedError(f'Model {model_name} not supported')
        if world_size > 1:
            state_dict = shard_state_dict_tp(state_dict, config, world_size, rank)
            state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
        load_return = model.load_state_dict(state_dict, strict=strict)
        logger.info(load_return)
        return model


# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True):
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, std=initializer_range)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if rescale_prenorm_residual:
        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
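        # e.g. with n_layer=12 and initializer_range=0.02, the rescaled std is
        # 0.02 / math.sqrt(2 * 12) ≈ 0.0041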
        for name, p in module.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer))


class GPTModel(GPTPreTrainedModel):

    def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
        super().__init__(config)
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.process_group = process_group
        self.sequence_parallel = getattr(config, 'sequence_parallel', True)
        assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'gelu_approx',
                                              'relu', 'sqrelu']
        pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
        vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
                      * pad_vocab_size_multiple)
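        # e.g. GPT-2's vocab_size=50257 with pad_vocab_size_multiple=8 is padded to 50264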
        # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
        self.residual_in_fp32 = getattr(config, 'residual_in_fp32', False)
        # These 2 options are for OPT-350m
        self.prenorm = getattr(config, 'prenorm', True)
        word_embed_proj_dim = getattr(config, 'word_embed_proj_dim', None)

        if process_group is None:
            self.embeddings = GPT2Embeddings(
                config.hidden_size, vocab_size, config.max_position_embeddings,
                word_embed_proj_dim=word_embed_proj_dim, **factory_kwargs
            )
        else:
            self.embeddings = ParallelGPT2Embeddings(
                config.hidden_size, vocab_size, config.max_position_embeddings,
                process_group=process_group, sequence_parallel=self.sequence_parallel,
                **factory_kwargs
            )

        # We change the order of dropout, residual and layer norm:
        # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
        # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and
        # the main branch (output of MLP). The model definition is unchanged, but the mapping of the
        # nn.Dropout probabilities is changed.
        # This is for performance reasons: we can fuse dropout + add + layer_norm.
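        # E.g. layer 0's resid_dropout1 takes the role of the embedding dropout (embd_pdrop),
        # while later layers use resid_pdrop; see create_block above.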
        self.layers = nn.ModuleList([create_block(config, layer_idx=i, process_group=process_group,
                                                  **factory_kwargs)
                                     for i in range(config.num_hidden_layers)])

        self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
        if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
            raise ImportError('dropout_add_layer_norm is not installed')
        if self.prenorm:
            self.drop_f = nn.Dropout(config.resid_pdrop)
            self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon,
                                     **factory_kwargs)
        if process_group is not None:
            for p in self.ln_f.parameters():
                # Mark the norm parameters as "shared_params" so that we sync their values at init.
                p._shared_params = True
                # Mark the norm params as "sequence_parallel" so we run all-reduce on their grads.
                if self.sequence_parallel:
                    p._sequence_parallel = True

        self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
                           initializer_range=config.initializer_range))
        self.tie_weights()

    def tie_weights(self):
        if self.process_group is not None:
            sync_shared_params(self, self.process_group)

    def forward(self, input_ids, position_ids=None, inference_params=None):
        # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen
        # dimensions so that we can split on it easily, in case of small batch size.
        # Only the attention layers need to know the seqlen.
        embedding_kwargs = ({'combine_batch_seqlen_dim': True}
                            if self.process_group is not None and self.sequence_parallel else {})
        hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs)
        residual = None
        mixer_kwargs = ({'seqlen': input_ids.shape[1]}
                        if self.process_group is not None and self.sequence_parallel else {})
        if inference_params is not None:
            mixer_kwargs['inference_params'] = inference_params
        for layer in self.layers:
            if self.prenorm:
                hidden_states, residual = layer(hidden_states, residual, mixer_kwargs=mixer_kwargs)
            else:
                hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
        if self.prenorm:
            if not self.fused_dropout_add_ln:
                dropped = self.drop_f(hidden_states)
                residual = (dropped + residual) if residual is not None else dropped
                hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
            else:
                # Set prenorm=False here since we don't need the residual
                hidden_states = dropout_add_layer_norm(
                    hidden_states, residual, self.ln_f.weight, self.ln_f.bias,
                    self.drop_f.p if self.training else 0.0, self.ln_f.eps, prenorm=False,
                    residual_in_fp32=self.residual_in_fp32
                )
        return hidden_states


class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):

    def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(config)
        self.process_group = process_group
        self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs)
        pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
        vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
                      * pad_vocab_size_multiple)
        # This option is for OPT-350m
        word_embed_proj_dim = getattr(config, 'word_embed_proj_dim', None)
        embed_dim = config.n_embd if word_embed_proj_dim is None else word_embed_proj_dim
        if word_embed_proj_dim is not None:
            self.project_out = nn.Linear(config.n_embd, embed_dim, bias=False, **factory_kwargs)
        else:
            self.project_out = None
        if process_group is None:
            self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False, **factory_kwargs)
        else:
            if ColumnParallelLinear is None:
                raise ImportError('fused_dense_lib is not installed')
            self.lm_head = ColumnParallelLinear(
                embed_dim, vocab_size, process_group, bias=False,
                sequence_parallel=getattr(config, 'sequence_parallel', True), **factory_kwargs
            )
        # Initialize weights and apply final processing
        self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
                           initializer_range=config.initializer_range))
        self.tie_weights()

    def tie_weights(self):
        self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight
        if self.process_group is not None:
            sync_shared_params(self, self.process_group)

    def forward(self, input_ids, position_ids=None, inference_params=None):
        """
            inference_params: for generation. Adapted from Megatron-LM (and Apex)
            https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        """
        hidden_states = self.transformer(input_ids, position_ids=position_ids,
                                         inference_params=inference_params)
        if self.project_out is not None:
            hidden_states = self.project_out(hidden_states)
        lm_logits = self.lm_head(hidden_states)
        # During inference, we want the full logit for sampling
        if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
            lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
            lm_logits = rearrange(lm_logits, '(n b) s d -> b s (n d)', b=hidden_states.shape[0])
        CausalLMOutput = namedtuple('CausalLMOutput', ['logits'])
        return CausalLMOutput(logits=lm_logits)

    def load_state_dict(self, state_dict, strict=True):
        # Remapping from our checkpoints that used a different ordering of layers in the block
        # Previous: Attn / MLP -> Dropout -> Add -> LN
        # Current: Dropout -> Add -> LN -> Attn / MLP
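        # Concretely: old ln_0 -> layers.0.norm1, old layers.{i}.norm1 -> layers.{i}.norm2,
        # old layers.{i}.norm2 -> layers.{i+1}.norm1, and the last norm2 -> ln_f.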
        if 'transformer.ln_0.weight' in state_dict:
            n_layers = len(self.transformer.layers)
            ln_weight = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.weight')
            ln_bias = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.bias')
            state_dict['transformer.ln_f.weight'] = ln_weight
            state_dict['transformer.ln_f.bias'] = ln_bias
            for l in reversed(range(n_layers)):
                ln_weight = state_dict.pop(f'transformer.layers.{l}.norm1.weight')
                ln_bias = state_dict.pop(f'transformer.layers.{l}.norm1.bias')
                state_dict[f'transformer.layers.{l}.norm2.weight'] = ln_weight
                state_dict[f'transformer.layers.{l}.norm2.bias'] = ln_bias
                if l > 0:
                    ln_weight = state_dict.pop(f'transformer.layers.{l - 1}.norm2.weight')
                    ln_bias = state_dict.pop(f'transformer.layers.{l - 1}.norm2.bias')
                    state_dict[f'transformer.layers.{l}.norm1.weight'] = ln_weight
                    state_dict[f'transformer.layers.{l}.norm1.bias'] = ln_bias
            ln_weight = state_dict.pop('transformer.ln_0.weight')
            ln_bias = state_dict.pop('transformer.ln_0.bias')
            state_dict['transformer.layers.0.norm1.weight'] = ln_weight
            state_dict['transformer.layers.0.norm1.bias'] = ln_bias
        return super().load_state_dict(state_dict, strict=strict)


def remap_state_dict_gpt2(state_dict, config):
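    """Remap a HuggingFace GPT-2 state_dict to this module's naming scheme.
    HF stores the attention / MLP weights in Conv1D (in, out) layout, so they are
    transposed to the standard nn.Linear (out, in) layout along the way."""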
    # Word embedding and position embedding
    def key_mapping_pos_emb(key):
        return re.sub(r'^wpe.', 'transformer.embeddings.position_embeddings.', key)
    state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop('wte.weight')
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
    vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
    state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r'^ln_f.(weight|bias)', r'transformer.ln_f.\1', key)
        key = re.sub(r'^h.(\d+).ln_(1|2).(weight|bias)', r'transformer.layers.\1.norm\2.\3', key)
        return key
    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    for d in range(config.num_hidden_layers):
        W1 = state_dict.pop(f'h.{d}.mlp.c_fc.weight')
        state_dict[f'transformer.layers.{d}.mlp.fc1.weight'] = W1.t()
        W2 = state_dict.pop(f'h.{d}.mlp.c_proj.weight')
        state_dict[f'transformer.layers.{d}.mlp.fc2.weight'] = W2.t()
    def key_mapping_mlp(key):
        key = re.sub(r'^h.(\d+).mlp.c_fc.bias', r'transformer.layers.\1.mlp.fc1.bias', key)
        key = re.sub(r'^h.(\d+).mlp.c_proj.bias', r'transformer.layers.\1.mlp.fc2.bias', key)
        return key
    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for d in range(config.num_hidden_layers):
        state_dict.pop(f'h.{d}.attn.bias')  # We don't store this bias
        Wqkv = state_dict.pop(f'h.{d}.attn.c_attn.weight')
        state_dict[f'transformer.layers.{d}.mixer.Wqkv.weight'] = Wqkv.t()
        Wout = state_dict.pop(f'h.{d}.attn.c_proj.weight')
        state_dict[f'transformer.layers.{d}.mixer.out_proj.weight'] = Wout.t()
    def key_mapping_attn(key):
        key = re.sub(r'^h.(\d+).attn.c_attn.bias', r'transformer.layers.\1.mixer.Wqkv.bias', key)
        key = re.sub(r'^h.(\d+).attn.c_proj.bias', r'transformer.layers.\1.mixer.out_proj.bias', key)
        return key
    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict


def shard_state_dict_tp(state_dict, config, world_size, rank):
    """Convert the state_dict of a standard GPT model to the state_dict of a GPT model
    with tensor parallel.
    """
    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
    vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
    assert vocab_size % world_size == 0
    assert config.hidden_size % world_size == 0
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    assert inner_dim % world_size == 0

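    # Helpers: each rank keeps its contiguous 1/world_size slice. The packed Wqkv
    # parameters have shape (3 * hidden_size, ...), so q, k, and v are each sharded
    # separately along the head dimension.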
    def shard_first_dim(state_dict, key):
        x = state_dict[key]
        dim = x.shape[0] // world_size
        state_dict[key] = x[rank * dim:(rank + 1) * dim]

    def shard_last_dim(state_dict, key):
        x = state_dict[key]
        dim = x.shape[-1] // world_size
        state_dict[key] = x[..., rank * dim:(rank + 1) * dim]

    def shard_qkv_headdim(state_dict, key):
        x = rearrange(state_dict[key], '(three d) ... -> three d ...', three=3)
        dim = x.shape[1] // world_size
        state_dict[key] = rearrange(x[:, rank * dim:(rank + 1) * dim],
                                    'three d ... -> (three d) ...')

    shard_first_dim(state_dict, 'transformer.embeddings.word_embeddings.weight')
    if 'lm_head.weight' in state_dict:
        shard_first_dim(state_dict, 'lm_head.weight')
    if 'transformer.embeddings.position_embeddings.weight' in state_dict:
        shard_last_dim(state_dict, 'transformer.embeddings.position_embeddings.weight')
    for i in range(config.num_hidden_layers):
        shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.weight')
        shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.bias')
        shard_last_dim(state_dict, f'transformer.layers.{i}.mixer.out_proj.weight')
        if rank != 0:
            state_dict.pop(f'transformer.layers.{i}.mixer.out_proj.bias')
        shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.weight')
        shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.bias')
        shard_last_dim(state_dict, f'transformer.layers.{i}.mlp.fc2.weight')
        if rank != 0:
            state_dict.pop(f'transformer.layers.{i}.mlp.fc2.bias')
    return state_dict
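

# A minimal usage sketch (kept as a comment; it assumes a CUDA device, the
# HuggingFace 'gpt2' checkpoint, and the optional fused kernels only if the
# corresponding config flags are set):
#
#     from transformers import GPT2Config
#     config = GPT2Config()
#     config.use_flash_attn = True  # optional, requires the flash-attn CUDA kernels
#     model = GPTLMHeadModel.from_pretrained('gpt2', config, device='cuda',
#                                            dtype=torch.float16)
#     input_ids = torch.randint(0, config.vocab_size, (1, 16), device='cuda')
#     logits = model(input_ids).logits  # (batch, seqlen, padded vocab_size)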