# Copyright (c) 2022, Tri Dao.

import logging
import math
import re
from functools import partial

from collections import namedtuple, OrderedDict
from collections.abc import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import GPT2Config

from einops import rearrange

from flash_attn.modules.mha import MHA, ParallelMHA
from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense, ParallelFusedDenseGeluDense
from flash_attn.modules.block import Block
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
from flash_attn.utils.distributed import sync_shared_params, all_gather_raw
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.generation import GenerationMixin

try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear
except ImportError:
    ColumnParallelLinear = None

try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm
except ImportError:
    dropout_add_layer_norm = None

try:
    from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
except ImportError:
    FusedDenseSqreluDense = None


logger = logging.getLogger(__name__)


def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
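    """Return the attention ("mixer") class for one block: a partial over MHA (or ParallelMHA
    when `process_group` is given), with rotary embedding, FlashAttention, fused bias fc,
    dwconv and softmax-scaling options read from `config`."""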
    factory_kwargs = {'device': device, 'dtype': dtype}
    head_dim = getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads)
    softmax_scale = 1.0 if not config.scale_attn_weights else head_dim ** (-0.5)
    if config.scale_attn_by_inverse_layer_idx:
        assert layer_idx is not None
        softmax_scale /= float(layer_idx + 1)
    dwconv = getattr(config, 'attn_dwconv', False)
    if dwconv:
        assert process_group is None, 'TensorParallel MHA does not support dwconv yet'
    rotary_emb_dim = int(getattr(config, 'rotary_emb_fraction', 0.0) * head_dim)
    rotary_emb_scale_base = getattr(config, 'rotary_emb_scale_base', 0)
    use_flash_attn = getattr(config, 'use_flash_attn', False)
    fused_bias_fc = getattr(config, 'fused_bias_fc', False)
    if not fused_bias_fc:
        assert process_group is None, 'TensorParallel MHA requires fused_bias_fc'
    mha_cls = MHA if process_group is None else ParallelMHA
    serial_kwargs = ({'fused_bias_fc': fused_bias_fc, 'dwconv': dwconv}
                     if process_group is None else {})
    parallel_kwargs = ({'process_group': process_group,
                        'sequence_parallel': getattr(config, 'sequence_parallel', True)}
                       if process_group is not None else {})
    mixer_cls = partial(mha_cls, num_heads=config.num_attention_heads, dropout=config.attn_pdrop,
                        softmax_scale=softmax_scale, causal=True, layer_idx=layer_idx,
                        rotary_emb_dim=rotary_emb_dim, rotary_emb_scale_base=rotary_emb_scale_base,
                        use_flash_attn=use_flash_attn,
                        **serial_kwargs, **parallel_kwargs, **factory_kwargs)
    return mixer_cls


def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
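    """Return the MLP class for one block: a plain Mlp with GELU (tanh approximation for
    'gelu_new' / 'gelu_fast'), or one of the fused variants (FusedDenseGeluDense,
    ParallelFusedDenseGeluDense, FusedDenseSqreluDense) when the corresponding config flags
    are set."""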
    factory_kwargs = {'device': device, 'dtype': dtype}
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    fused_dense_gelu_dense = getattr(config, 'fused_dense_gelu_dense', False)
    if fused_dense_gelu_dense:
        assert config.activation_function in ['gelu_new', 'gelu_fast'], ('fused_dense_gelu_dense only '
                                                                         'supports approximate gelu')
    fused_dense_sqrelu_dense = getattr(config, 'fused_dense_sqrelu_dense', False)
    if fused_dense_sqrelu_dense:
        assert config.activation_function == 'sqrelu', ('fused_dense_sqrelu_dense only '
                                                        'supports activation_function sqrelu')
    assert not (fused_dense_sqrelu_dense and fused_dense_gelu_dense)
    if process_group is not None:
        assert fused_dense_gelu_dense, 'Tensor Parallel is only implemented for FusedDenseGeluDense'
    if not fused_dense_gelu_dense and not fused_dense_sqrelu_dense:
        approximate = 'tanh' if config.activation_function in ['gelu_new', 'gelu_fast'] else 'none'
        mlp_cls = partial(Mlp, hidden_features=inner_dim,
                          activation=partial(F.gelu, approximate=approximate), **factory_kwargs)
    else:
        mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
        # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
        if isinstance(mlp_checkpoint_lvl, Sequence):
            assert layer_idx is not None
            mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
        if fused_dense_gelu_dense:
            if FusedDenseGeluDense is None:
                raise ImportError('fused_dense is not installed')
            mlp_cls = FusedDenseGeluDense if process_group is None else ParallelFusedDenseGeluDense
            parallel_kwargs = ({'process_group': process_group,
                                'sequence_parallel': getattr(config, 'sequence_parallel', True)}
                               if process_group is not None else {})
            mlp_cls = partial(mlp_cls, hidden_features=inner_dim, checkpoint_lvl=mlp_checkpoint_lvl,
                              **parallel_kwargs, **factory_kwargs)
        elif fused_dense_sqrelu_dense:
            assert FusedDenseSqreluDense is not None
            mlp_cls = partial(FusedDenseSqreluDense, hidden_features=inner_dim,
                              checkpoint_lvl=mlp_checkpoint_lvl, **factory_kwargs)
        else:
            raise RuntimeError('MLP type not supported')
    return mlp_cls


def create_block(config, layer_idx=None, process_group=None, device=None, dtype=None):
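    """Assemble one pre-norm transformer Block from the mixer, MLP and LayerNorm classes
    derived from `config`. `layer_idx` is stored on the block for per-layer options."""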
    factory_kwargs = {'device': device, 'dtype': dtype}
    sequence_parallel = getattr(config, 'sequence_parallel', True)
    mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
    mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
    norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_epsilon, **factory_kwargs)
    block = Block(config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
                  prenorm=True, resid_dropout=config.resid_pdrop,
                  fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
                  sequence_parallel=sequence_parallel and process_group is not None,
                  mark_shared_params=process_group is not None)
    block.layer_idx = layer_idx
    return block


class GPTPreTrainedModel(nn.Module):
    """ An abstract class to handle weights initialization and
        a simple interface for dowloading and loading pretrained models.
    """
    def __init__(self, config, *inputs, **kwargs):
        super().__init__()
        if not isinstance(config, GPT2Config):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                ))
        self.config = config

    @classmethod
    def from_pretrained(cls, model_name, config, *args, strict=True, device=None, dtype=None,
                        world_size=1, rank=0, **kwargs):
        """
        Instantiate a GPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.
        """
        # Instantiate model.
        model = cls(config, *args, device=device, dtype=dtype, **kwargs)
        state_dict = remap_state_dict_gpt2(
            # If we're going to shard the model, then don't load fp32 weights to GPU.
            state_dict_from_pretrained(model_name, device=device if world_size == 1 else None,
                                       dtype=dtype), config
        )
        if world_size > 1:
            state_dict = shard_state_dict_tp(state_dict, config, world_size, rank)
            state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
        load_return = model.load_state_dict(state_dict, strict=strict)
        logger.info(load_return)
        return model

# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True):
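    """GPT-2 style initialization: normal(0, initializer_range) for Linear / Embedding weights,
    zeros for biases, and optionally rescale the residual-path projections (out_proj, fc2)
    by 1 / sqrt(2 * n_layer)."""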
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, std=initializer_range)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if rescale_prenorm_residual:
        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer))


class GPTModel(GPTPreTrainedModel):
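    """GPT-2 style backbone: token / position embeddings followed by a stack of pre-norm Blocks,
    with optional tensor parallelism (and sequence parallelism) via `process_group`."""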

    def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
        super().__init__(config)
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.process_group = process_group
        self.sequence_parallel = getattr(config, 'sequence_parallel', True)
        assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'sqrelu']
        pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
        vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
                      * pad_vocab_size_multiple)

        if process_group is None:
            self.embeddings = GPT2Embeddings(config.hidden_size, vocab_size,
                                             config.max_position_embeddings, **factory_kwargs)
        else:
            self.embeddings = ParallelGPT2Embeddings(
                config.hidden_size, vocab_size, config.max_position_embeddings,
                process_group=process_group, sequence_parallel=self.sequence_parallel,
                **factory_kwargs
            )
        self.emb_drop = nn.Dropout(config.embd_pdrop)

        # We change the order of the residual connection and layer norm:
        # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
        # Attn / MLP -> Dropout -> Add -> LN, returning both the residual branch (output of Add) and
        # the main branch (output of LN). The model definition is unchanged, but the mapping of the
        # nn.LayerNorm weights is changed.
        # This is for performance reasons: we can fuse dropout + add + layer_norm.
        self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
        if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
            raise ImportError('dropout_add_layer_norm is not installed')
        # self.ln_0 is the first layer norm in the model, while self.ln_f (in the pretrained weight)
        # is the final layer norm.
        self.ln_0 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon,
                                 **factory_kwargs)
        if process_group is not None:
            for p in self.ln_0.parameters():
                # Mark the norm parameters as "shared_params" so that we sync their values at init.
                p._shared_params = True
                # Mark the norm params as "sequence_parallel" so we run all-reduce on their grads.
                if self.sequence_parallel:
                    p._sequence_parallel = True

        self.layers = nn.ModuleList([create_block(config, layer_idx=i, process_group=process_group,
                                                  **factory_kwargs)
                                     for i in range(config.num_hidden_layers)])

        self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
                           initializer_range=config.initializer_range))
        self.tie_weights()

    def tie_weights(self):
        if self.process_group is not None:
            sync_shared_params(self, self.process_group)

    def forward(self, input_ids, position_ids=None, inference_params=None):
        # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen
        # dimensions so that we can split along the combined dimension even when the batch size
        # is small. Only the attention layers need to know the seqlen.
        embedding_kwargs = ({'combine_batch_seqlen_dim': True}
                            if self.process_group is not None and self.sequence_parallel else {})
        hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs)
        # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
        if not self.fused_dropout_add_ln:
            residual = self.emb_drop(hidden_states)
            hidden_states = self.ln_0(residual.to(dtype=self.ln_0.weight.dtype))
            residual = residual.float()
        else:
            hidden_states, residual = dropout_add_layer_norm(
                hidden_states, None, self.ln_0.weight, self.ln_0.bias,
                self.emb_drop.p if self.training else 0.0, self.ln_0.eps, prenorm=True,
                residual_in_fp32=True
            )
        mixer_kwargs = ({'seqlen': input_ids.shape[1]}
                        if self.process_group is not None and self.sequence_parallel else {})
        if inference_params is not None:
            mixer_kwargs['inference_params'] = inference_params
        for layer in self.layers:
            hidden_states, residual = layer(hidden_states, residual, mixer_kwargs=mixer_kwargs)
        return hidden_states


class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
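    """GPTModel backbone plus a language-modeling head whose weight is tied to the word
    embeddings; generation support is provided by GenerationMixin."""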

    def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(config)
        self.process_group = process_group
        self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs)
        pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
        vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
                      * pad_vocab_size_multiple)
        if process_group is None:
            self.lm_head = nn.Linear(config.n_embd, vocab_size, bias=False, **factory_kwargs)
        else:
            if ColumnParallelLinear is None:
                raise ImportError('fused_dense_lib is not installed')
            self.lm_head = ColumnParallelLinear(
                config.n_embd, vocab_size, process_group, bias=False,
                sequence_parallel=getattr(config, 'sequence_parallel', True), **factory_kwargs
            )
        # Initialize weights and apply final processing
        self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
                           initializer_range=config.initializer_range))
        self.tie_weights()

    def tie_weights(self):
        self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight
        if self.process_group is not None:
            sync_shared_params(self, self.process_group)

    def forward(self, input_ids, position_ids=None, inference_params=None):
        """
            inference_params: for generation. Adapted from Megatron-LM (and Apex)
            https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        """
        hidden_states = self.transformer(input_ids, position_ids=position_ids,
                                         inference_params=inference_params)
        lm_logits = self.lm_head(hidden_states)
        # During inference, we want the full logits for sampling
        if (ColumnParallelLinear is not None and isinstance(self.lm_head, ColumnParallelLinear)
                and inference_params is not None):
            lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
            lm_logits = rearrange(lm_logits, '(n b) s d -> b s (n d)', b=hidden_states.shape[0])
        CausalLMOutput = namedtuple('CausalLMOutput', ['logits'])
        return CausalLMOutput(logits=lm_logits)
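
# Usage sketch (not executed on import): loading GPT-2 small weights into this implementation
# and running a forward pass. Assumes a CUDA device; the optional fused kernels are only needed
# if the corresponding config flags (use_flash_attn, fused_bias_fc, ...) are set.
#
#     config = GPT2Config.from_pretrained('gpt2')
#     model = GPTLMHeadModel.from_pretrained('gpt2', config, device='cuda', dtype=torch.float16)
#     input_ids = torch.randint(0, config.vocab_size, (1, 16), device='cuda')
#     logits = model(input_ids).logits  # (batch, seqlen, vocab_size)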


def remap_state_dict_gpt2(state_dict, config):
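    """Remap a Hugging Face GPT-2 checkpoint to this implementation's parameter names: pad the
    vocabulary, transpose the Conv1D weights to Linear layout, and shift the LayerNorms to match
    the Dropout -> Add -> LN ordering (h.0.ln_1 becomes transformer.ln_0, h.{d}.ln_1 becomes
    layer d-1's norm2, ln_2 becomes norm1, and ln_f becomes the last layer's norm2)."""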
    # Word embedding and position embedding
    def key_mapping_pos_emb(key):
        return re.sub(r'^wpe.', 'transformer.embeddings.position_embeddings.', key)
    state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop('wte.weight')
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
    vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
    state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']

    # LayerNorm
    ln_weight, ln_bias = state_dict.pop('ln_f.weight'), state_dict.pop('ln_f.bias')
    state_dict[f'transformer.layers.{config.num_hidden_layers - 1}.norm2.weight'] = ln_weight
    state_dict[f'transformer.layers.{config.num_hidden_layers - 1}.norm2.bias'] = ln_bias
    ln_weight, ln_bias = state_dict.pop('h.0.ln_1.weight'), state_dict.pop('h.0.ln_1.bias')
    state_dict['transformer.ln_0.weight'] = ln_weight
    state_dict['transformer.ln_0.bias'] = ln_bias
    for d in range(config.num_hidden_layers):
        ln_weight = state_dict.pop(f'h.{d}.ln_2.weight')
        ln_bias = state_dict.pop(f'h.{d}.ln_2.bias')
        state_dict[f'transformer.layers.{d}.norm1.weight'] = ln_weight
        state_dict[f'transformer.layers.{d}.norm1.bias'] = ln_bias
        if d > 0:
            ln_weight = state_dict.pop(f'h.{d}.ln_1.weight')
            ln_bias = state_dict.pop(f'h.{d}.ln_1.bias')
            state_dict[f'transformer.layers.{d - 1}.norm2.weight'] = ln_weight
            state_dict[f'transformer.layers.{d - 1}.norm2.bias'] = ln_bias

    # MLP
    for d in range(config.num_hidden_layers):
        W1 = state_dict.pop(f'h.{d}.mlp.c_fc.weight')
        state_dict[f'transformer.layers.{d}.mlp.fc1.weight'] = W1.t()
        W2 = state_dict.pop(f'h.{d}.mlp.c_proj.weight')
        state_dict[f'transformer.layers.{d}.mlp.fc2.weight'] = W2.t()
    def key_mapping_mlp(key):
        key = re.sub(r'^h.(\d+).mlp.c_fc.bias', r'transformer.layers.\1.mlp.fc1.bias', key)
        key = re.sub(r'^h.(\d+).mlp.c_proj.bias', r'transformer.layers.\1.mlp.fc2.bias', key)
        return key
    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for d in range(config.num_hidden_layers):
        state_dict.pop(f'h.{d}.attn.bias')  # We don't store this bias
        Wqkv = state_dict.pop(f'h.{d}.attn.c_attn.weight')
        state_dict[f'transformer.layers.{d}.mixer.Wqkv.weight'] = Wqkv.t()
        Wout = state_dict.pop(f'h.{d}.attn.c_proj.weight')
        state_dict[f'transformer.layers.{d}.mixer.out_proj.weight'] = Wout.t()
    def key_mapping_attn(key):
        key = re.sub(r'^h.(\d+).attn.c_attn.bias', r'transformer.layers.\1.mixer.Wqkv.bias', key)
        key = re.sub(r'^h.(\d+).attn.c_proj.bias', r'transformer.layers.\1.mixer.out_proj.bias', key)
        return key
    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict


def shard_state_dict_tp(state_dict, config, world_size, rank):
    """Convert the state_dict of a standard GPT model to the state_dict of a GPT model
    with tensor parallel.
    """
    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
    vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
    assert vocab_size % world_size == 0
    assert config.hidden_size % world_size == 0
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    assert inner_dim % world_size == 0

    def shard_first_dim(state_dict, key):
        x = state_dict[key]
        dim = x.shape[0] // world_size
        state_dict[key] = x[rank * dim:(rank + 1) * dim]

    def shard_last_dim(state_dict, key):
        x = state_dict[key]
        dim = x.shape[-1] // world_size
        state_dict[key] = x[..., rank * dim:(rank + 1) * dim]

    def shard_qkv_headdim(state_dict, key):
        x = rearrange(state_dict[key], '(three d) ... -> three d ...', three=3)
        dim = x.shape[1] // world_size
        state_dict[key] = rearrange(x[:, rank * dim:(rank + 1) * dim],
                                    'three d ... -> (three d) ...')

    shard_first_dim(state_dict, 'transformer.embeddings.word_embeddings.weight')
    if 'lm_head.weight' in state_dict:
        shard_first_dim(state_dict, 'lm_head.weight')
    if 'transformer.embeddings.position_embeddings.weight' in state_dict:
        shard_last_dim(state_dict, 'transformer.embeddings.position_embeddings.weight')
    for i in range(config.num_hidden_layers):
        shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.weight')
        shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.bias')
        shard_last_dim(state_dict, f'transformer.layers.{i}.mixer.out_proj.weight')
        if rank != 0:
            state_dict.pop(f'transformer.layers.{i}.mixer.out_proj.bias')
        shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.weight')
        shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.bias')
        shard_last_dim(state_dict, f'transformer.layers.{i}.mlp.fc2.weight')
        if rank != 0:
            state_dict.pop(f'transformer.layers.{i}.mlp.fc2.bias')
    return state_dict
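
# Tensor-parallel usage sketch (hypothetical `process_group` / `world_size` / `rank` from an
# initialized torch.distributed job; requires the fused kernels and the corresponding config
# flags, e.g. fused_bias_fc and fused_dense_gelu_dense):
#
#     model = GPTLMHeadModel.from_pretrained('gpt2', config, process_group=process_group,
#                                            world_size=world_size, rank=rank,
#                                            device='cuda', dtype=torch.float16)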