# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from dataclasses import dataclass
from typing import Optional, Union

import numpy as np
import torch
import torch.nn.functional as F

from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.dist_checkpointing.mapping import (
    ReplicaId,
    ShardedStateDict,
    ShardedTensorFactory,
)
from megatron.core.fusions.fused_bias_geglu import bias_geglu_impl
from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl
from megatron.core.fusions.fused_bias_swiglu import bias_swiglu_impl
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig


# pylint: disable=missing-class-docstring
@dataclass
class MLPSubmodules:
    linear_fc1: Optional[Union[ModuleSpec, type]] = None
    linear_fc2: Optional[Union[ModuleSpec, type]] = None


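# A typical binding for these submodules (an illustrative sketch, not something this
# module constructs itself) pairs a column-parallel projection for linear_fc1 with a
# row-parallel projection for linear_fc2, e.g.:
#
#     from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
#
#     mlp_spec = ModuleSpec(  # `mlp_spec` is just an example name
#         module=MLP,
#         submodules=MLPSubmodules(
#             linear_fc1=ColumnParallelLinear,  # h -> ffn_hidden_size (column-parallel)
#             linear_fc2=RowParallelLinear,     # ffn_hidden_size -> h (row-parallel)
#         ),
#     )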
class MLP(MegatronModule):
    """
    MLP takes a hidden state of size h as input, projects it to a 4*h
    intermediate dimension, applies a nonlinear transformation, and
    projects the result back to the h hidden dimension.

    Returns an output and a bias to be added to the output.
    If config.add_bias_linear is False, the returned bias is None.

    We use the following notation:
     h: hidden size
     p: number of tensor model parallel partitions
     b: batch size
     s: sequence length
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: MLPSubmodules,
        is_expert: bool = False,
        input_size: Optional[int] = None,
    ):
        super().__init__(config=config)

        self.config: TransformerConfig = config

        self.input_size = input_size if input_size is not None else self.config.hidden_size

        # If this is a gated linear unit we double the output width
        # see https://arxiv.org/pdf/2002.05202.pdf
        ffn_hidden_size = self.config.ffn_hidden_size
        if self.config.gated_linear_unit:
            ffn_hidden_size *= 2

        self.linear_fc1 = build_module(
            submodules.linear_fc1,
            self.input_size,
            ffn_hidden_size,
            config=self.config,
            init_method=self.config.init_method,
            gather_output=False,
            bias=self.config.add_bias_linear,
            skip_bias_add=True,
            is_expert=is_expert,
            tp_comm_buffer_name='fc1',
        )

        self.activation_func = self.config.activation_func

        self.linear_fc2 = build_module(
            submodules.linear_fc2,
            self.config.ffn_hidden_size,
            self.config.hidden_size,
            config=self.config,
            init_method=self.config.output_layer_init_method,
            bias=self.config.add_bias_linear,
            input_is_parallel=True,
            skip_bias_add=True,
            is_expert=is_expert,
            tp_comm_buffer_name='fc2',
        )

    def forward(self, hidden_states):
        """Perform the forward pass through the MLP block."""
        # [s, b, 4 * h/p]
        intermediate_parallel, bias_parallel = self.linear_fc1(hidden_states)

        if self.config.bias_activation_fusion:
            if self.activation_func == F.gelu:
                if self.config.gated_linear_unit:
                    intermediate_parallel = bias_geglu_impl(intermediate_parallel, bias_parallel)
                else:
                    assert self.config.add_bias_linear is True
                    intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
            elif self.activation_func == F.silu and self.config.gated_linear_unit:
                intermediate_parallel = bias_swiglu_impl(
                    intermediate_parallel,
                    bias_parallel,
                    self.config.activation_func_fp8_input_store,
                )
            else:
                raise ValueError("Only bias fusion of GELU and SwiGLU is supported")
        else:
            if bias_parallel is not None:
                intermediate_parallel = intermediate_parallel + bias_parallel
            if self.config.gated_linear_unit:

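                # Unfused gated linear unit: split the doubled fc1 output into two
                # halves along the last dimension and gate one half with the
                # activation of the other, i.e. activation_func(x_w) * x_v
                # (SwiGLU when activation_func is SiLU).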
                def glu(x):
                    x = torch.chunk(x, 2, dim=-1)
                    return self.config.activation_func(x[0]) * x[1]

                intermediate_parallel = glu(intermediate_parallel)
            else:
                intermediate_parallel = self.activation_func(intermediate_parallel)

        # [s, b, h]
        output, output_bias = self.linear_fc2(intermediate_parallel)

        return output, output_bias

    # pylint: disable=missing-function-docstring
    def sharded_state_dict(
        self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
    ) -> ShardedStateDict:
        sharded_state_dict = {}
        for name, module in self._modules.items():
            sub_sd = module.sharded_state_dict(f'{prefix}{name}.', sharded_offsets, metadata)
            if self.config.gated_linear_unit and name == 'linear_fc1':
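                # The gated fc1 weight/bias store the `w` and `v` halves concatenated
                # along dim 0, so they are wrapped in the swiglu factory below to be
                # checkpointed as two independently sharded tensors.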
                # NOTE: In custom FSDP, we can have no weight in local.
                if not self.config.use_custom_fsdp:
                    assert f'{prefix}{name}.weight' in sub_sd, sub_sd.keys()
                for k, v in sub_sd.items():
                    if k in (f'{prefix}{name}.weight', f'{prefix}{name}.bias'):
                        sub_sd[k] = apply_swiglu_sharded_factory(v, sharded_offsets)
            sharded_state_dict.update(sub_sd)
        return sharded_state_dict


# pylint: disable=missing-function-docstring
def apply_swiglu_sharded_factory(original_sh_ten, sharded_offsets):
    # We must split the tensor into 2 parts, each sharded separately.
    # This requires a ShardedTensorFactory which `chunk`s during saving
    # and `cat`s during loading
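    #
    # For example (sizes illustrative): a gated fc1 weight with local shape
    # [2 * ffn_hidden_size / tp, hidden_size] is saved as two ShardedTensors of
    # local shape [ffn_hidden_size / tp, hidden_size] (the `w` and `v` halves),
    # so each half keeps a consistent global layout in the checkpoint.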

    swiglu_shard_axis = 0
    prepend_axis_num = len(sharded_offsets)
    original_shape = original_sh_ten.local_shape
    original_numel = int(np.prod(original_shape))
    local_axis_size = original_shape[swiglu_shard_axis]
    assert (
        original_sh_ten.global_offset[swiglu_shard_axis + prepend_axis_num] % local_axis_size == 0
    )
    rank_offset = (
        original_sh_ten.global_offset[swiglu_shard_axis + prepend_axis_num] // local_axis_size
    )
    axis_frag = original_sh_ten.axis_fragmentations[swiglu_shard_axis + prepend_axis_num]

    @torch.no_grad()
    def sh_ten_build_fn(
        key: str, t: torch.Tensor, replica_id: ReplicaId, flattened_range: Optional[slice]
    ):
        offset_w = (swiglu_shard_axis + prepend_axis_num, rank_offset, axis_frag * 2)
        offset_v = (swiglu_shard_axis + prepend_axis_num, rank_offset + axis_frag, axis_frag * 2)
        if flattened_range is None:
            tensor_w, tensor_v = torch.chunk(t, 2, dim=swiglu_shard_axis)
            return [
                ShardedTensor.from_rank_offsets(
                    key,
                    tensor_w,
                    *sharded_offsets,
                    offset_w,
                    replica_id=replica_id,
                    prepend_axis_num=prepend_axis_num,
                ),
                ShardedTensor.from_rank_offsets(
                    key,
                    tensor_v,
                    *sharded_offsets,
                    offset_v,
                    replica_id=replica_id,
                    prepend_axis_num=prepend_axis_num,
                ),
            ]
        else:
            # Here we need to map a slice `t` (`flattened_range` specifies slice start and stop)
            # of the *original* flattened tensor into slices `w` and `v` of chunked
            # and flattened tensor.
            # Example:
            # If original tensor has (16, 5) shape and flattened_range is `slice(8, 64)`,
            # then `t` has shape `(56,)` and we need to create 2 tensors:
            # w: first 32 elements of `t` with flattened_range slice(8, 40)
            # v: last 24 elements of `t` with flattened_range slice(0, 24)
            # Global offsets are the same as in the non-flattened case
            assert t.ndim == 1, (key, t.shape)
            non_flat_local_shape = (original_shape[0] // 2, *original_shape[1:])
            chunk_numel = original_numel // 2
            result = []
            if flattened_range.start < chunk_numel:
                # Non-empty `w` chunk
                tensor_w = t[: chunk_numel - flattened_range.start]
                flattened_range_w = slice(
                    flattened_range.start, min(chunk_numel, flattened_range.stop)
                )
                assert len(tensor_w) == flattened_range_w.stop - flattened_range_w.start
                result.append(
                    ShardedTensor.from_rank_offsets_flat(
                        key,
                        tensor_w,
                        non_flat_local_shape,
                        *sharded_offsets,
                        offset_w,
                        replica_id=replica_id,
                        prepend_axis_num=prepend_axis_num,
                        flattened_range=flattened_range_w,
                    )
                )
            if flattened_range.stop > chunk_numel:
                # Non-empty `v` chunk
                tensor_v = t[-(flattened_range.stop - chunk_numel) :]
                flattened_range_v = slice(
                    max(chunk_numel, flattened_range.start) - chunk_numel,
                    flattened_range.stop - chunk_numel,
                )
                assert len(tensor_v) == flattened_range_v.stop - flattened_range_v.start, (
                    len(tensor_v),
                    flattened_range_v,
                )

                result.append(
                    ShardedTensor.from_rank_offsets_flat(
                        key,
                        tensor_v,
                        non_flat_local_shape,
                        *sharded_offsets,
                        offset_v,
                        replica_id=replica_id,
                        prepend_axis_num=prepend_axis_num,
                        flattened_range=flattened_range_v,
                    )
                )
            assert sum(sh_ten.data.numel() for sh_ten in result) == t.numel(), (result, t.shape)
            return result

    def sh_ten_merge_fn(sub_state_dict):
        with torch.no_grad():
            return torch.cat(sub_state_dict)

    return ShardedTensorFactory(
        original_sh_ten.key,
        original_sh_ten.data,
        sh_ten_build_fn,
        sh_ten_merge_fn,
        original_sh_ten.replica_id,
        flattened_range=original_sh_ten.flattened_range,
    )