mlp.py 9.89 KB
Newer Older
liangjing's avatar
liangjing committed
1
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
liangjing's avatar
v1  
liangjing committed
2

liangjing's avatar
liangjing committed
3
4
5
6
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import numpy as np
liangjing's avatar
v1  
liangjing committed
7
8
9
import torch
import torch.nn.functional as F

liangjing's avatar
liangjing committed
10
11
12
13
14
15
from megatron.core import parallel_state
from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.dist_checkpointing.mapping import (
    ReplicaId,
    ShardedStateDict,
    ShardedTensorFactory,
liangjing's avatar
v1  
liangjing committed
16
)
liangjing's avatar
liangjing committed
17
18
19
from megatron.core.fusions.fused_bias_geglu import bias_geglu_impl
from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl
from megatron.core.fusions.fused_bias_swiglu import bias_swiglu_impl
liangjing's avatar
v1  
liangjing committed
20
from megatron.core.transformer.module import MegatronModule
liangjing's avatar
liangjing committed
21
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
liangjing's avatar
v1  
liangjing committed
22
from megatron.core.transformer.transformer_config import TransformerConfig
liangjing's avatar
liangjing committed
23
24
25
26
27
28
29
from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint


@dataclass
class MLPSubmodules:
    """Specs (or classes) for the two linear layers built by ``MLP``.

    ``MLP.__init__`` passes each field as the first argument to ``build_module``.
    A field may be left as ``None``; building from it will then fail inside
    ``build_module``, so callers are expected to set both.
    """

    # Input projection: input_size -> ffn_hidden_size (x2 for gated linear units).
    linear_fc1: Optional[Union[ModuleSpec, type]] = None
    # Output projection: ffn_hidden_size -> hidden_size.
    linear_fc2: Optional[Union[ModuleSpec, type]] = None
liangjing's avatar
v1  
liangjing committed
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48


class MLP(MegatronModule):
    """
    MLP will take the input with h hidden state, project it to the FFN hidden
    dimension (``config.ffn_hidden_size``, doubled when
    ``config.gated_linear_unit`` is set), perform a nonlinear transformation,
    and project the state back into h hidden dimension.

    Returns an output and a bias to be added to the output.
    If config.add_bias_linear is False, the bias returned is None.

    We use the following notation:
     h: hidden size
     p: number of tensor model parallel partitions
     b: batch size
     s: sequence length
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: MLPSubmodules,
        is_expert: bool = False,
        input_size: Optional[int] = None,
    ):
        """Build the two linear layers of the MLP.

        Args:
            config: transformer configuration; supplies sizes, init methods,
                bias/activation settings.
            submodules: specs for ``linear_fc1`` and ``linear_fc2``.
            is_expert: forwarded to both ``build_module`` calls.
            input_size: input width of ``linear_fc1``; defaults to
                ``config.hidden_size`` when ``None``.
        """
        super().__init__(config=config)

        self.config: TransformerConfig = config

        self.input_size = input_size if input_size is not None else self.config.hidden_size

        # If this is a gated linear unit we double the output width, see https://arxiv.org/pdf/2002.05202.pdf
        ffn_hidden_size = self.config.ffn_hidden_size
        if self.config.gated_linear_unit:
            ffn_hidden_size *= 2

        # fc1 output stays partitioned across TP ranks (gather_output=False) so
        # that fc2 can consume it directly (input_is_parallel=True below).
        self.linear_fc1 = build_module(
            submodules.linear_fc1,
            self.input_size,
            ffn_hidden_size,
            config=self.config,
            init_method=self.config.init_method,
            gather_output=False,
            bias=self.config.add_bias_linear,
            skip_bias_add=True,
            is_expert=is_expert,
            tp_comm_buffer_name='fc1',
        )

        self.activation_func = self.config.activation_func

        # Note: fc2 projects to config.hidden_size even when input_size differs.
        self.linear_fc2 = build_module(
            submodules.linear_fc2,
            self.config.ffn_hidden_size,
            self.config.hidden_size,
            config=self.config,
            init_method=self.config.output_layer_init_method,
            bias=self.config.add_bias_linear,
            input_is_parallel=True,
            skip_bias_add=True,
            is_expert=is_expert,
            tp_comm_buffer_name='fc2',
        )

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run fc1 -> (bias +) activation -> fc2.

        Args:
            hidden_states: input activations; presumably laid out [s, b, input_size]
                to match the shape notes below (TODO confirm against callers).

        Returns:
            ``(output, output_bias)`` from ``linear_fc2``. fc2 is built with
            ``skip_bias_add=True``, so the bias is returned for the caller to add.

        Raises:
            ValueError: if bias/activation fusion is enabled for an activation
                other than gelu (plain or gated) or gated silu.
        """
        # [s, b, ffn_hidden_size/p], last dim doubled for gated linear units
        intermediate_parallel, bias_parallel = self.linear_fc1(hidden_states)

        if self.config.bias_activation_fusion:
            if self.activation_func == F.gelu:
                if self.config.gated_linear_unit:
                    intermediate_parallel = bias_geglu_impl(intermediate_parallel, bias_parallel)
                else:
                    assert (
                        self.config.add_bias_linear is True
                    ), "fused bias-gelu requires config.add_bias_linear=True"
                    intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
            elif self.activation_func == F.silu and self.config.gated_linear_unit:
                intermediate_parallel = bias_swiglu_impl(
                    intermediate_parallel,
                    bias_parallel,
                    self.config.activation_func_fp8_input_store,
                )
            else:
                raise ValueError("Only support fusion of gelu and swiglu")
        else:
            if bias_parallel is not None:
                intermediate_parallel = intermediate_parallel + bias_parallel
            if self.config.gated_linear_unit:
                # Split the doubled fc1 output along the last dim: the first half
                # goes through the activation and gates the second half.
                gate, linear = torch.chunk(intermediate_parallel, 2, dim=-1)
                intermediate_parallel = self.activation_func(gate) * linear
            else:
                intermediate_parallel = self.activation_func(intermediate_parallel)

        # [s, b, h]
        output, output_bias = self.linear_fc2(intermediate_parallel)

        return output, output_bias

    def sharded_state_dict(
        self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
    ) -> ShardedStateDict:
        """Collect the sharded state dicts of all child modules under ``prefix``.

        For gated linear units, ``linear_fc1``'s weight (and bias, if present)
        holds two halves stacked along dim 0. Each is wrapped by
        ``apply_swiglu_sharded_factory`` so that the halves are sharded separately.
        """
        sharded_state_dict = {}
        for name, module in self._modules.items():
            sub_sd = module.sharded_state_dict(f'{prefix}{name}.', sharded_offsets, metadata)
            if self.config.gated_linear_unit and name == 'linear_fc1':
                weight_key = f'{prefix}{name}.weight'
                assert weight_key in sub_sd, sub_sd.keys()
                for key in (weight_key, f'{prefix}{name}.bias'):
                    if key in sub_sd:
                        sub_sd[key] = apply_swiglu_sharded_factory(sub_sd[key], sharded_offsets)
            sharded_state_dict.update(sub_sd)
        return sharded_state_dict


def apply_swiglu_sharded_factory(original_sh_ten, sharded_offsets):
    """Wrap a gated-linear-unit fc1 ShardedTensor in a chunk-on-save / cat-on-load factory.

    The local tensor holds two halves (``w`` and ``v``) stacked along dim 0.
    On save, each half becomes its own ShardedTensor along the same global axis:
    ``w`` uses fragment ``tp_rank`` and ``v`` uses fragment ``tp_size + tp_rank``,
    both out of ``2 * tp_size`` fragments. On load, the halves are concatenated back.

    Args:
        original_sh_ten: the ShardedTensor to wrap (its key, data, local_shape and
            replica_id are used).
        sharded_offsets: offsets from enclosing modules; ``len(sharded_offsets)``
            is used as the number of prepended axes.

    Returns:
        A ShardedTensorFactory with the original key, data and replica_id.
    """
    tp_rank = parallel_state.get_tensor_model_parallel_rank()
    tp_size = parallel_state.get_tensor_model_parallel_world_size()
    swiglu_shard_axis = 0
    prepend_axis_num = len(sharded_offsets)
    original_shape = original_sh_ten.local_shape
    original_numel = int(np.prod(original_shape))

    # Fragment offsets for the two halves depend only on the TP layout.
    global_axis = swiglu_shard_axis + prepend_axis_num
    offset_w = (global_axis, tp_rank, tp_size * 2)
    offset_v = (global_axis, tp_size + tp_rank, tp_size * 2)

    @torch.no_grad()
    def sh_ten_build_fn(
        key: str, t: torch.Tensor, replica_id: ReplicaId, flattened_range: Optional[slice]
    ):
        if flattened_range is None:
            # Unpacking (rather than zipping) keeps the "exactly two chunks" check.
            tensor_w, tensor_v = torch.chunk(t, 2, dim=swiglu_shard_axis)
            return [
                ShardedTensor.from_rank_offsets(
                    key,
                    half,
                    *sharded_offsets,
                    half_offset,
                    replica_id=replica_id,
                    prepend_axis_num=prepend_axis_num,
                )
                for half, half_offset in ((tensor_w, offset_w), (tensor_v, offset_v))
            ]

        # Flattened case: `t` is the slice [start, stop) of the *original*
        # flattened tensor, and must be mapped onto slices of the flattened
        # `w` and `v` halves.
        # Example: original shape (16, 5) and flattened_range slice(8, 64) gives
        # `t` of shape (56,), which splits into
        #   w: first 32 elements of `t`, flattened_range slice(8, 40)
        #   v: last 24 elements of `t`, flattened_range slice(0, 24)
        # Global offsets are the same as in the non-flattened case.
        assert t.ndim == 1, (key, t.shape)
        start, stop = flattened_range.start, flattened_range.stop
        half_local_shape = (original_shape[0] // 2, *original_shape[1:])
        half_numel = original_numel // 2

        def flat_sh_ten(data, offset, data_range):
            # One flattened ShardedTensor for a single half.
            return ShardedTensor.from_rank_offsets_flat(
                key,
                data,
                half_local_shape,
                *sharded_offsets,
                offset,
                replica_id=replica_id,
                prepend_axis_num=prepend_axis_num,
                flattened_range=data_range,
            )

        result = []
        if start < half_numel:
            # `t` overlaps the `w` half.
            range_w = slice(start, min(half_numel, stop))
            tensor_w = t[: half_numel - start]
            assert len(tensor_w) == range_w.stop - range_w.start
            result.append(flat_sh_ten(tensor_w, offset_w, range_w))
        if stop > half_numel:
            # `t` overlaps the `v` half.
            range_v = slice(max(half_numel, start) - half_numel, stop - half_numel)
            tensor_v = t[-(stop - half_numel) :]
            assert len(tensor_v) == range_v.stop - range_v.start, (
                len(tensor_v),
                range_v,
            )
            result.append(flat_sh_ten(tensor_v, offset_v, range_v))
        assert sum(sh_ten.data.numel() for sh_ten in result) == t.numel(), (result, t.shape)
        return result

    @torch.no_grad()
    def sh_ten_merge_fn(sub_state_dict):
        # Inverse of the split: re-join the halves along dim 0.
        return torch.cat(sub_state_dict)

    return ShardedTensorFactory(
        original_sh_ten.key,
        original_sh_ten.data,
        sh_ten_build_fn,
        sh_ten_merge_fn,
        original_sh_ten.replica_id,
    )