# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.

# copied from https://github.com/facebookresearch/dino/blob/main/main_dino.py
# reworked/refactored some parts to make it run in Megatron.
import math
import apex
import einops
import torch
import numpy as np
import torch.nn.functional as F
from torch.nn.init import trunc_normal_
from megatron.training import get_args, print_rank_0
from megatron.legacy.model.utils import get_linear_layer
from megatron.legacy.model.vision.vit_backbone import VitBackbone
from megatron.legacy.model.module import MegatronModule
from megatron.legacy.model.vision.mit_backbone import mit_b5_avg
from megatron.legacy.model.vision.esvit_swin_backbone import get_swin


class DINOLoss(torch.nn.Module):
    def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp,
                 warmup_teacher_temp_epochs, nepochs, student_temp=0.1,
                 center_momentum=0.9):
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.ncrops = ncrops
        self.register_buffer("center", torch.zeros(1, out_dim))
        # we apply a warm-up to the teacher temperature because
        # too high a temperature makes training unstable at the beginning
        self.teacher_temp_schedule = np.concatenate((
            np.linspace(warmup_teacher_temp,
                        teacher_temp, warmup_teacher_temp_epochs),
            np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
        ))
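        # For example, with warmup_teacher_temp=0.04, teacher_temp=0.07 and
        # warmup_teacher_temp_epochs=30 (illustrative values), the schedule ramps
        # linearly from 0.04 to 0.07 over the first 30 epochs, then stays at 0.07.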
        self.teacher_temp = teacher_temp

    def forward(self, student_output, teacher_output, iteration):
        """
        Cross-entropy between softmax outputs of the teacher
        and student network.
        """
        args = get_args()
        student_out = student_output / self.student_temp
        student_out = student_out.chunk(self.ncrops)

        epoch = iteration // args.iter_per_epoch

        # teacher centering and sharpening
        temp = self.teacher_temp_schedule[epoch]
        teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)

        teacher_out = teacher_out.detach().chunk(2)

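        # Cross-entropy is accumulated over every (teacher view, student view)
        # pair with distinct indices:
        #   loss = mean over pairs of sum_d [ -P_t(q)[d] * log P_s(v)[d] ]
        # where P_t = softmax((teacher_output - center) / temp) and
        #       P_s = softmax(student_output / student_temp).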
        total_loss = 0
        n_loss_terms = 0
        for iq, q in enumerate(teacher_out):
            for v in range(len(student_out)):
                if v == iq:
                    # we skip cases where student and teacher operate on the same view
                    continue
                loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
                total_loss += loss.mean()
                n_loss_terms += 1
        total_loss /= n_loss_terms
        self.update_center(teacher_output)
        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_output):
        """
        Update center used for teacher output.
        """
        batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
        torch.distributed.all_reduce(batch_center)
        batch_center = batch_center / (len(teacher_output) * torch.distributed.get_world_size())
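        # Exponential moving average of the (globally averaged) batch mean:
        #   center <- center_momentum * center + (1 - center_momentum) * batch_center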
        self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)

class DINOHead(torch.nn.Module):
    def __init__(self, in_dim, out_dim, norm_last_layer=True, nlayers=3):
        super().__init__()
        args = get_args()
        hidden_dim = args.dino_head_hidden_size
        bottleneck_dim = args.dino_bottleneck_size
        nlayers = max(nlayers, 1)
        if nlayers == 1:
            self.mlp = torch.nn.Linear(in_dim, bottleneck_dim)
        else:
            layers = [torch.nn.Linear(in_dim, hidden_dim)]
            layers.append(torch.nn.GELU())
            for _ in range(nlayers - 2):
                layers.append(torch.nn.Linear(hidden_dim, hidden_dim))
                layers.append(torch.nn.GELU())
            layers.append(torch.nn.Linear(hidden_dim, bottleneck_dim))
            self.mlp = torch.nn.Sequential(*layers)
        self.apply(self._init_weights)
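        # Weight-normalized output layer, as in the original DINO recipe:
        # freezing weight_g (the norm component) fixes the scale of the last
        # layer, which helps stabilize training when norm_last_layer is True.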
        self.last_layer = torch.nn.utils.weight_norm(torch.nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            self.last_layer.weight_g.requires_grad = False

    def _init_weights(self, m):
        if isinstance(m, torch.nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        x = torch.nn.functional.normalize(x, dim=-1, p=2)
        x = self.last_layer(x)
        return x


class MultiCropWrapper(MegatronModule):

    """
    Perform forward pass separately on each resolution input.
    The inputs corresponding to a single resolution are clubbed and single
    forward is run on the same resolution inputs. Hence we do several
    forward passes = number of different resolutions used. We then
    concatenate all the output features and run the head forward on these
    concatenated features.
    """
    def __init__(self, backbone, head):
        super(MultiCropWrapper, self).__init__()
        # disable layers dedicated to ImageNet labels classification
        #backbone.fc, backbone.head = torch.nn.Identity(), torch.nn.Identity()
        self.backbone = backbone
        self.head = head

    def forward(self, x):
        # convert to list
        if not isinstance(x, list):
            x = [x]
        idx_crops = torch.cumsum(torch.unique_consecutive(
            torch.tensor([inp.shape[-1] for inp in x]),
            return_counts=True,
        )[1], 0)
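        # idx_crops holds cumulative counts of consecutive crops that share the
        # same spatial size, e.g. 2 global 224x224 crops followed by 8 local
        # 96x96 crops would give idx_crops = [2, 10].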

        start_idx = 0
        for end_idx in idx_crops:
            _out = self.backbone(torch.cat(x[start_idx: end_idx]))
            if start_idx == 0:
                output = _out
            else:
                output = torch.cat((output, _out))
            start_idx = end_idx
        # Run the head forward on the concatenated features.
        if self.training:
            return self.head(output)
        else:
            return output


def cosine_scheduler(base_value, final_value, epochs, niter_per_ep,
                     warmup_epochs=0, start_warmup_value=0):
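    """
    Build a per-iteration schedule: an optional linear warm-up from
    start_warmup_value to base_value over warmup_epochs, followed by a
    half-cosine decay from base_value to final_value over the remaining
    iterations (epochs * niter_per_ep values in total).
    """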
    warmup_schedule = np.array([])
    warmup_iters = warmup_epochs * niter_per_ep
    if warmup_epochs > 0:
        warmup_schedule = \
                np.linspace(start_warmup_value, base_value, warmup_iters)

    iters = np.arange(epochs * niter_per_ep - warmup_iters)
    schedule = final_value + 0.5 * (base_value - final_value) \
        * (1 + np.cos(np.pi * iters / len(iters)))

    schedule = np.concatenate((warmup_schedule, schedule))
    assert len(schedule) == epochs * niter_per_ep
    return schedule


def get_student_backbone_and_num_features(config, pre_process=True, post_process=True):
    args = get_args()

    if args.vision_backbone_type == 'vit':
        student = VitBackbone(config,
                              pre_process=pre_process,
                              post_process=post_process,
                              drop_path_rate=0.1,
                              single_token_output=True)
        num_features = args.hidden_size
    elif args.vision_backbone_type == 'mit':
        student = mit_b5_avg(drop_path_rate=0.1)
        num_features = 512
    elif args.vision_backbone_type == 'swin':
        student = get_swin()
        num_features = student.num_features
    else:
        raise Exception('{} vision backbone is not supported.'.format(
                              args.vision_backbone_type))

    return student, num_features

def get_teacher_backbone_and_num_features(config, pre_process=True, post_process=True):
    args = get_args()

    if args.vision_backbone_type == 'vit':
        teacher = VitBackbone(config,
                              pre_process=pre_process,
                              post_process=post_process,
                              single_token_output=True)
        num_features = args.hidden_size
    elif args.vision_backbone_type == 'mit':
        teacher = mit_b5_avg(drop_path_rate=0.0)
        num_features = 512
    elif args.vision_backbone_type == 'swin':
        teacher = get_swin(is_teacher=True)
        num_features = teacher.num_features
    else:
        raise Exception('{} vision backbone is not supported.'.format(
                              args.vision_backbone_type))
    return teacher, num_features


class DINOPretrainModel(MegatronModule):
    def __init__(self, config, pre_process=True, post_process=True):
        super(DINOPretrainModel, self).__init__()
        args = get_args()
        self.config = config
        self.out_dim = 65536  # dimensionality of the DINO projection-head output

        self.dino_loss = DINOLoss(
            self.out_dim,
            args.dino_local_crops_number + 2,
            args.dino_warmup_teacher_temp,
            args.dino_teacher_temp,
            args.dino_warmup_teacher_temp_epochs,
            300,  # number of epochs assumed by the teacher-temperature schedule
        )

        self.pre_process = pre_process
        self.post_process = post_process
        self.momentum_teacher = 0.996

        student_backbone, num_features = \
            get_student_backbone_and_num_features(config, pre_process, post_process)

        self.student = MultiCropWrapper(
            student_backbone,
            DINOHead(num_features, self.out_dim,
                     norm_last_layer=args.dino_norm_last_layer)
        )

        self.momentum_schedule = cosine_scheduler(
            self.momentum_teacher, 1,
            args.train_iters // args.iter_per_epoch,
            args.iter_per_epoch
        )

        teacher_backbone, num_features = \
            get_teacher_backbone_and_num_features(config, pre_process, post_process)
        self.teacher = MultiCropWrapper(
            teacher_backbone,
            DINOHead(num_features, self.out_dim)
        )
        self.teacher.load_state_dict(self.student.state_dict())

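        # The teacher is never updated by the optimizer; it only tracks the
        # student through the momentum (EMA) update in update_momentum().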
        for p in self.teacher.parameters():
            p.requires_grad = False

    def set_input_tensor(self, tensor):
        pass

    def forward(self, input):
        student_output = None
        if self.training:
            student_output = self.student(input)
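            # the teacher only sees the first two (global) crops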
            teacher_output = self.teacher(input[:2])
        else:
            teacher_output = self.teacher(input)
        return student_output, teacher_output

    def cancel_gradients_last_layer(self, iteration):
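        # Following the DINO recipe, gradients of the head's last layer are
        # cancelled during the first dino_freeze_last_layer epochs to improve
        # early-training stability.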
        args = get_args()
        epoch = iteration // args.iter_per_epoch
        if epoch < args.dino_freeze_last_layer:
            for n, p in self.student.named_parameters():
                if "last_layer" in n:
                    p.grad = None

    def update_momentum(self, iteration):
        with torch.no_grad():
            m = self.momentum_schedule[iteration]
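            # EMA update of the teacher parameters:
            #   teacher <- m * teacher + (1 - m) * student,
            # with m following a cosine schedule from momentum_teacher up to 1.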
            for param_q, param_k in zip(self.student.parameters(), self.teacher.parameters()):
                param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
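

# Illustrative sketch (not part of the original file): roughly how the pieces
# above fit together in one training step. The data iterator, optimizer and
# argument plumbing are assumed to come from the Megatron pretrain script.
#
#   model = DINOPretrainModel(config)
#   for iteration, crops in enumerate(train_data_iterator):  # crops: list of augmented views
#       student_out, teacher_out = model(crops)
#       loss = model.dino_loss(student_out, teacher_out, iteration)
#       loss.backward()
#       model.cancel_gradients_last_layer(iteration)
#       optimizer.step()
#       optimizer.zero_grad()
#       model.update_momentum(iteration)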