# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from dataclasses import dataclass
from typing import Optional, Union

import numpy as np
import torch
import torch.nn.functional as F

from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.dist_checkpointing.mapping import (
    ReplicaId,
    ShardedStateDict,
    ShardedTensorFactory,
)
from megatron.core.fusions.fused_bias_geglu import bias_geglu_impl
from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl
from megatron.core.fusions.fused_bias_swiglu import bias_swiglu_impl
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig


@dataclass
class MLPSubmodules:
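    """Spec container for the two MLP projections: linear_fc1 is the
    up-projection and linear_fc2 is the down-projection."""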
    linear_fc1: Union[ModuleSpec, type] = None
    linear_fc2: Union[ModuleSpec, type] = None


class MLP(MegatronModule):
    """
    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension.


    Returns an output and a bias to be added to the output.
    If config.add_bias_linear is False, the bias returned is None.

    We use the following notation:
     h: hidden size
     p: number of tensor model parallel partitions
     b: batch size
     s: sequence length
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: MLPSubmodules,
        is_expert: bool = False,
        input_size: Optional[int] = None,
    ):
        super().__init__(config=config)

        self.config: TransformerConfig = config

        self.input_size = input_size if input_size is not None else self.config.hidden_size

        # If this is a gated linear unit we double the output width
        # see https://arxiv.org/pdf/2002.05202.pdf
        ffn_hidden_size = self.config.ffn_hidden_size
        if self.config.gated_linear_unit:
            ffn_hidden_size *= 2
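        # Illustrative numbers only: with hidden_size=4096 and
        # ffn_hidden_size=16384, a gated linear unit makes linear_fc1 produce
        # 2 * 16384 features, which forward() later splits into gate and
        # value halves.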

        self.linear_fc1 = build_module(
            submodules.linear_fc1,
            self.input_size,
            ffn_hidden_size,
            config=self.config,
            init_method=self.config.init_method,
            gather_output=False,
            bias=self.config.add_bias_linear,
            skip_bias_add=True,
            is_expert=is_expert,
            tp_comm_buffer_name='fc1',
        )

        self.activation_func = self.config.activation_func

        self.linear_fc2 = build_module(
            submodules.linear_fc2,
            self.config.ffn_hidden_size,
            self.config.hidden_size,
            config=self.config,
            init_method=self.config.output_layer_init_method,
            bias=self.config.add_bias_linear,
            input_is_parallel=True,
            skip_bias_add=True,
            is_expert=is_expert,
            tp_comm_buffer_name='fc2',
        )

    def forward(self, hidden_states):
        """Perform the forward pass through the MLP block."""
        # [s, b, 4 * h/p]
        intermediate_parallel, bias_parallel = self.linear_fc1(hidden_states)

        if self.config.bias_activation_fusion:
            if self.activation_func == F.gelu:
                if self.config.gated_linear_unit:
                    intermediate_parallel = bias_geglu_impl(intermediate_parallel, bias_parallel)
                else:
                    assert (
                        self.config.add_bias_linear is True
                    ), "bias_gelu fusion requires add_bias_linear to be True"
                    intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
            elif self.activation_func == F.silu and self.config.gated_linear_unit:
                intermediate_parallel = bias_swiglu_impl(
                    intermediate_parallel,
                    bias_parallel,
                    self.config.activation_func_fp8_input_store,
                )
            else:
                raise ValueError("Only fusion of gelu and swiglu is supported")
        else:
            if bias_parallel is not None:
                intermediate_parallel = intermediate_parallel + bias_parallel
            if self.config.gated_linear_unit:
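                # Unfused GLU: split the doubled fc1 output into gate and value
                # halves and combine them as activation_func(gate) * value
                # (SwiGLU when activation_func is F.silu); the fused
                # bias_swiglu_impl branch above is the optimized equivalent.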

                def glu(x):
                    x = torch.chunk(x, 2, dim=-1)
                    return self.config.activation_func(x[0]) * x[1]

                intermediate_parallel = glu(intermediate_parallel)
            else:
                intermediate_parallel = self.activation_func(intermediate_parallel)

        # [s, b, h]
        output, output_bias = self.linear_fc2(intermediate_parallel)

        return output, output_bias

    def sharded_state_dict(
        self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
    ) -> ShardedStateDict:
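        """Collects the sharded state dict from submodules; for gated linear
        units, the fc1 weight/bias entries are wrapped in a factory that shards
        the gate and value halves separately."""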
        sharded_state_dict = {}
        for name, module in self._modules.items():
            sub_sd = module.sharded_state_dict(f'{prefix}{name}.', sharded_offsets, metadata)
            if self.config.gated_linear_unit and name == 'linear_fc1':
                assert f'{prefix}{name}.weight' in sub_sd, sub_sd.keys()
                for k, v in sub_sd.items():
                    if k in (f'{prefix}{name}.weight', f'{prefix}{name}.bias'):
                        sub_sd[k] = apply_swiglu_sharded_factory(v, sharded_offsets)
            sharded_state_dict.update(sub_sd)
        return sharded_state_dict


def apply_swiglu_sharded_factory(original_sh_ten, sharded_offsets):
    # We must split the tensor into 2 parts, each sharded separately.
    # This requires a ShardedTensorFactory which `chunk`s during saving
    # and `cat`s during loading
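    #
    # Toy illustration (assumed shapes, not part of the public API): for a local
    # fc1 weight of shape (2 * ffn_per_partition, h), the build step does
    # torch.chunk(weight, 2, dim=0) and emits a ShardedTensor for each half with
    # its own global offset; the merge step restores the weight with
    # torch.cat([w, v], dim=0).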

    swiglu_shard_axis = 0
    prepend_axis_num = len(sharded_offsets)
    original_shape = original_sh_ten.local_shape
    original_numel = int(np.prod(original_shape))
    local_axis_size = original_shape[swiglu_shard_axis]
    assert (
        original_sh_ten.global_offset[swiglu_shard_axis + prepend_axis_num] % local_axis_size == 0
    )
    rank_offset = (
        original_sh_ten.global_offset[swiglu_shard_axis + prepend_axis_num] // local_axis_size
    )
    axis_frag = original_sh_ten.axis_fragmentations[swiglu_shard_axis + prepend_axis_num]

    @torch.no_grad()
    def sh_ten_build_fn(
        key: str, t: torch.Tensor, replica_id: ReplicaId, flattened_range: Optional[slice]
    ):
        offset_w = (swiglu_shard_axis + prepend_axis_num, rank_offset, axis_frag * 2)
        offset_v = (swiglu_shard_axis + prepend_axis_num, rank_offset + axis_frag, axis_frag * 2)
        if flattened_range is None:
            tensor_w, tensor_v = torch.chunk(t, 2, dim=swiglu_shard_axis)
            return [
                ShardedTensor.from_rank_offsets(
                    key,
                    tensor_w,
                    *sharded_offsets,
                    offset_w,
                    replica_id=replica_id,
                    prepend_axis_num=prepend_axis_num,
                ),
                ShardedTensor.from_rank_offsets(
                    key,
                    tensor_v,
                    *sharded_offsets,
                    offset_v,
                    replica_id=replica_id,
                    prepend_axis_num=prepend_axis_num,
                ),
            ]
        else:
            # Here we need to map a slice `t` (`flattened_range` specifies slice start and stop)
            # of the *original* flattened tensor into slices `w` and `v` of chunked
            # and flattened tensor.
            # Example:
            # If original tensor has (16, 5) shape and flattened_range is `slice(8, 64)`,
            # then `t` has shape `(56,)` and we need to create 2 tensors:
            # w: first 32 elements of `t` with flattened_range slice(8, 40)
            # v: last 24 elements of `t` with flattened_range slice(0, 24)
            # Global offsets are the same as in the non-flattened case
            assert t.ndim == 1, (key, t.shape)
            non_flat_local_shape = (original_shape[0] // 2, *original_shape[1:])
            chunk_numel = original_numel // 2
            result = []
            if flattened_range.start < chunk_numel:
                # Non-empty `w` chunk
                tensor_w = t[: chunk_numel - flattened_range.start]
                flattened_range_w = slice(
                    flattened_range.start, min(chunk_numel, flattened_range.stop)
                )
                assert len(tensor_w) == flattened_range_w.stop - flattened_range_w.start
                result.append(
                    ShardedTensor.from_rank_offsets_flat(
                        key,
                        tensor_w,
                        non_flat_local_shape,
                        *sharded_offsets,
                        offset_w,
                        replica_id=replica_id,
                        prepend_axis_num=prepend_axis_num,
                        flattened_range=flattened_range_w,
                    )
                )
            if flattened_range.stop > chunk_numel:
                # Non-empty `v` chunk
                tensor_v = t[-(flattened_range.stop - chunk_numel) :]
                flattened_range_v = slice(
                    max(chunk_numel, flattened_range.start) - chunk_numel,
                    flattened_range.stop - chunk_numel,
                )
                assert len(tensor_v) == flattened_range_v.stop - flattened_range_v.start, (
                    len(tensor_v),
                    flattened_range_v,
                )

                result.append(
                    ShardedTensor.from_rank_offsets_flat(
                        key,
                        tensor_v,
                        non_flat_local_shape,
                        *sharded_offsets,
                        offset_v,
                        replica_id=replica_id,
                        prepend_axis_num=prepend_axis_num,
                        flattened_range=flattened_range_v,
                    )
                )
            assert sum(sh_ten.data.numel() for sh_ten in result) == t.numel(), (result, t.shape)
            return result

    def sh_ten_merge_fn(sub_state_dict):
        with torch.no_grad():
            return torch.cat(sub_state_dict)

    return ShardedTensorFactory(
        original_sh_ten.key,
        original_sh_ten.data,
        sh_ten_build_fn,
        sh_ten_merge_fn,
        original_sh_ten.replica_id,
    )
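

# A rough sanity-check sketch (hypothetical standalone usage, assuming the
# factory exposes its fields as key/data/build_fn/merge_fn/replica_id; in
# practice it is built and consumed by the dist_checkpointing machinery via
# MLP.sharded_state_dict):
#
#   factory = apply_swiglu_sharded_factory(fc1_weight_sh_ten, sharded_offsets=())
#   parts = factory.build_fn(factory.key, factory.data, factory.replica_id, None)
#   assert len(parts) == 2                      # gate (`w`) and value (`v`) halves
#   merged = factory.merge_fn([p.data for p in parts])
#   assert torch.equal(merged, factory.data)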