resmlp.py 10.4 KB
Newer Older
yuguo960516's avatar
yuguo960516 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# --------------------------------------------------------
# ResMLP Model
# References:
# resmlp: https://github.com/facebookresearch/deit/blob/main/resmlp_models.py
# --------------------------------------------------------

import oneflow as flow
import oneflow.nn as nn
from flowvision.layers.weight_init import trunc_normal_

import libai.utils.distributed as dist
from libai.config import configurable
from libai.layers import MLP, DropPath, LayerNorm, Linear, PatchEmbedding


class Affine(nn.Module):
    """Learnable element-wise affine transform ``alpha * x + beta``.

    ``alpha`` starts as a vector of ones and ``beta`` as a vector of zeros, both of
    length ``dim``. Both are created as global tensors with broadcast SBP on the
    placement returned by ``dist.get_layer_placement(layer_idx)``.

    Args:
        dim (int): length of the ``alpha`` and ``beta`` vectors.
        layer_idx (int): layer index used to look up the placement of the
            parameters and of the input in ``forward``.
    """

    def __init__(self, dim, *, layer_idx=0):
        super().__init__()
        self.layer_idx = layer_idx

        # Both parameters live on the same placement with the same SBP signature.
        placement = dist.get_layer_placement(layer_idx)
        broadcast_sbp = dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast])

        self.alpha = nn.Parameter(flow.ones(dim, placement=placement, sbp=broadcast_sbp))
        self.beta = nn.Parameter(flow.zeros(dim, placement=placement, sbp=broadcast_sbp))

    def forward(self, x):
        """Move ``x`` to this layer's placement and return ``alpha * x + beta``."""
        on_layer = x.to_global(placement=dist.get_layer_placement(self.layer_idx))
        scaled = self.alpha * on_layer
        return scaled + self.beta


class layers_scale_mlp_blocks(nn.Module):
    """One residual ResMLP block with layer scale.

    The block applies two residual branches in sequence:

    1. ``Affine`` -> transpose -> ``Linear(num_patches, num_patches)`` -> transpose,
       scaled by ``gamma_1``;
    2. ``Affine`` -> ``MLP`` with a hidden size of ``4 * dim``, scaled by ``gamma_2``.

    Each branch goes through ``drop_path`` before being added back to its input.

    Args:
        dim (int): size of the ``Affine`` and layer-scale vectors, and hidden size
            of the ``MLP``.
        drop (float): dropout probability forwarded to the ``MLP`` as
            ``output_dropout_prob``.
        drop_path (float): drop-path rate; ``DropPath`` is only instantiated when
            it is greater than 0, otherwise an identity module is used.
        init_values (float): initial value of every entry of ``gamma_1`` and
            ``gamma_2``.
        num_patches (int): input and output feature size of the cross-patch
            ``Linear``.
        layer_idx (int): layer index used to look up the placement of this
            block's parameters and of its input.
    """

    def __init__(
        self, dim, drop=0.0, drop_path=0.0, init_values=1e-4, num_patches=196, *, layer_idx=0
    ):
        super().__init__()
        self.norm1 = Affine(dim, layer_idx=layer_idx)
        self.attn = Linear(num_patches, num_patches, layer_idx=layer_idx)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = Affine(dim, layer_idx=layer_idx)
        # `drop` used to be accepted and silently discarded; hand it to the MLP so a
        # non-zero dropout rate actually takes effect (the default 0.0 is a no-op).
        self.mlp = MLP(
            hidden_size=dim,
            ffn_hidden_size=int(4.0 * dim),
            output_dropout_prob=drop,
            layer_idx=layer_idx,
        )
        # Layer-scale vectors: one learnable factor per channel for each branch.
        self.gamma_1 = nn.Parameter(
            init_values
            * flow.ones(
                dim,
                sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
                placement=dist.get_layer_placement(layer_idx),
            ),
            requires_grad=True,
        )
        self.gamma_2 = nn.Parameter(
            init_values
            * flow.ones(
                dim,
                sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
                placement=dist.get_layer_placement(layer_idx),
            ),
            requires_grad=True,
        )

        self.layer_idx = layer_idx

    def forward(self, x):
        """Apply both residual branches to ``x`` and return the result.

        NOTE(review): assumes ``x`` is laid out as (batch, num_patches, dim) so that
        swapping axes 1 and 2 puts the patch axis last for ``self.attn`` -- confirm
        against the caller.
        """
        x = x.to_global(placement=dist.get_layer_placement(self.layer_idx))
        # Cross-patch branch: the linear layer acts on the last axis, so swap the
        # patch axis into last position and swap it back afterwards.
        x = x + self.drop_path(
            self.gamma_1 * self.attn(self.norm1(x).transpose(1, 2)).transpose(1, 2)
        )
        # Channel branch.
        x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x


class ResMLP(nn.Module):
    """ResMLP in LiBai.

    LiBai's implementation of:
    `ResMLP: Feedforward networks for image classification with data-efficient training
    <https://arxiv.org/abs/2105.03404>`_

    Args:
        img_size (int, tuple(int)): input image size
        patch_size (int, tuple(int)): patch size
        in_chans (int): number of input channels
        embed_dim (int): embedding dimension
        depth (int): number of ``layers_scale_mlp_blocks`` stacked in the model
        drop_rate (float): dropout rate, passed to every block as ``drop``
        drop_path_rate (float): stochastic depth rate, identical for every block
        init_scale (float): initial value of the layer scale parameters
        num_classes (int): number of classes for classification head; an identity
                           module is used as head when this is not positive
        loss_func (callable, optional): loss function for computing the total loss
                                        between logits and labels; defaults to
                                        ``nn.CrossEntropyLoss()``

    """

    @configurable
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        drop_rate=0.0,
        drop_path_rate=0.0,
        init_scale=1e-4,
        num_classes=1000,
        loss_func=None,
    ):
        super().__init__()

        self.num_classes = num_classes
        self.embed_dim = embed_dim
        self.num_features = embed_dim

        self.patch_embed = PatchEmbedding(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches

        # Every block gets the same drop-path rate; block k is placed as layer k.
        block_drop_paths = [drop_path_rate] * depth
        self.blocks = nn.ModuleList(
            [
                layers_scale_mlp_blocks(
                    dim=embed_dim,
                    drop=drop_rate,
                    drop_path=rate,
                    init_values=init_scale,
                    num_patches=num_patches,
                    layer_idx=idx,
                )
                for idx, rate in enumerate(block_drop_paths)
            ]
        )

        # Final affine and classification head use layer index -1.
        self.norm = Affine(embed_dim, layer_idx=-1)
        if num_classes > 0:
            self.head = Linear(embed_dim, num_classes, layer_idx=-1)
        else:
            self.head = nn.Identity()

        # loss func
        self.loss_func = loss_func if loss_func is not None else nn.CrossEntropyLoss()

        # weight init
        self.apply(self._init_weights)

    @classmethod
    def from_config(cls, cfg):
        """Map the fields of ``cfg`` onto the keyword arguments of ``__init__``."""
        arg_names = (
            "img_size",
            "patch_size",
            "in_chans",
            "embed_dim",
            "depth",
            "drop_rate",
            "drop_path_rate",
            "init_scale",
            "num_classes",
            "loss_func",
        )
        return {name: getattr(cfg, name) for name in arg_names}

    def _init_weights(self, m):
        """Initialise ``m`` if it is a ``Linear`` or ``LayerNorm``; otherwise do nothing."""
        if isinstance(m, Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return

        if isinstance(m, LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x):
        """Embed ``x`` into patches and run it through every block in order."""
        tokens = self.patch_embed(x)

        # layer scale mlp blocks
        for block in self.blocks:
            tokens = block(tokens)

        return tokens

    def forward_head(self, x):
        """Apply the final affine, average over axis 1 and feed the result to the head."""
        batch_size = x.shape[0]
        normed = self.norm(x)
        pooled = normed.mean(dim=1).reshape(batch_size, 1, -1)
        return self.head(pooled[:, 0])

    def forward(self, images, labels=None):
        """

        Args:
            images (flow.Tensor): training samples.
            labels (flow.LongTensor, optional): training targets

        Returns:
            dict:
                A dict containing :code:`loss_value` or :code:`logits`
                depending on training or evaluation mode.
                :code:`{"losses": loss_value}` when training,
                :code:`{"prediction_scores": logits}` when evaluating.
        """
        logits = self.forward_head(self.forward_features(images))

        # The loss is only computed in training mode and when labels are supplied.
        if labels is None or not self.training:
            return {"prediction_scores": logits}

        return {"losses": self.loss_func(logits, labels)}

    @staticmethod
    def set_pipeline_stage_id(model):
        """Assign a pipeline stage and placement to the sub-modules of ``model``.

        The patch embedding goes to the stage of layer 0, every
        ``layers_scale_mlp_blocks`` to the stage of its own ``layer_idx``, and the
        final affine, head and loss function to the stage of layer -1.
        """
        dist_utils = dist.get_dist_util()

        def stage_args(layer_idx):
            # (stage id, placement) pair expected by ``set_stage``.
            return dist_utils.get_layer_stage_id(layer_idx), dist.get_layer_placement(layer_idx)

        # Set pipeline parallelism stage_id
        if hasattr(model.loss_func, "config"):
            # Old API in OneFlow 0.8
            for module_block in model.modules():
                # module.origin can get the original module
                origin = module_block.origin
                if isinstance(origin, PatchEmbedding):
                    module_block.config.set_stage(*stage_args(0))
                elif isinstance(origin, layers_scale_mlp_blocks):
                    module_block.config.set_stage(*stage_args(module_block.layer_idx))

            # Set norm and head stage id
            for tail in (model.norm, model.head, model.loss_func):
                tail.config.set_stage(*stage_args(-1))
            return

        for module_block in model.modules():
            wrapped = module_block.to(nn.Module)
            if isinstance(wrapped, PatchEmbedding):
                module_block.to(nn.graph.GraphModule).set_stage(*stage_args(0))
            elif isinstance(wrapped, layers_scale_mlp_blocks):
                module_block.to(nn.graph.GraphModule).set_stage(
                    *stage_args(module_block.layer_idx)
                )

        # Set norm and head stage id
        for tail in (model.norm, model.head, model.loss_func):
            tail.to(nn.graph.GraphModule).set_stage(*stage_args(-1))

    @staticmethod
    def set_activation_checkpoint(model):
        """Turn on activation checkpointing for every ``layers_scale_mlp_blocks`` in ``model``."""
        for module_block in model.modules():
            # Old API in OneFlow 0.8 exposes the wrapped module as ``origin``.
            uses_old_api = hasattr(module_block, "origin")
            wrapped = module_block.origin if uses_old_api else module_block.to(nn.Module)
            if not isinstance(wrapped, layers_scale_mlp_blocks):
                continue

            if uses_old_api:
                module_block.config.activation_checkpointing = True
            else:
                module_block.to(nn.graph.GraphModule).activation_checkpointing = True