pretrain_vision_dino.py
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
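
"""Pretrain a DINO vision model (self-DIstillation with NO labels) with Megatron."""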

import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import torch.distributed as dist
from functools import partial
from megatron import get_args, get_timers, mpu, print_rank_0
from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model.vision.dino import DINOPretrainModel
from megatron.model.vision.knn_monitor import knn_predict, get_feature_bank
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group, unwrap_model
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.model import ModelType

def model_provider(pre_process=True, post_process=True):
    """Build the model."""
    return DINOPretrainModel(pre_process=pre_process, post_process=post_process)

def get_batch(data_iterator):
    """Build the batch."""
    data = next(data_iterator)

    # only data parallelism; no need for broadcast
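    # DINO-style multi-crop augmentation: data[0] may be a list of views
    # (e.g. global and local crops) rather than a single image tensor.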
    if isinstance(data[0], list):
        images = [aug.cuda() for aug in data[0]]
    else:
        images = data[0].cuda()
    labels = data[1].cuda()

    return images, labels


def loss_func(model, labels, output_tensor, collect_data=False):
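    """Return the DINO loss in training mode, or kNN-monitor accuracies in
    evaluation mode. The model is unwrapped from its DDP / fp16 wrappers so
    that the underlying DINOPretrainModel's dino_loss is reachable."""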
    args = get_args()
    
    model = unwrap_model(
        model,
        (torchDDP, LocalDDP, Float16Module)
    )
    if model.training:
        student_output, teacher_output = output_tensor
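        # curr_iteration drives the DINO schedules inside the loss
        # (e.g. the teacher-temperature warmup).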
        loss = model.dino_loss(student_output, teacher_output, args.curr_iteration)
        averaged_loss = average_losses_across_data_parallel_group([loss])
        return loss, {"loss": averaged_loss[0]}
    else:
        _, teacher_feature = output_tensor
        feature_bank, feature_labels, classes = get_feature_bank()
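        # kNN monitor: the bank caches (normalized) features of the training
        # set; validation samples are classified by weighted kNN against it,
        # and accuracy at several k values serves as a cheap proxy for
        # representation quality during self-supervised pretraining.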
        feature = F.normalize(teacher_feature.float(), dim=1)

        knn_accs = []
        for k in [10, 20, 100, 200]:
            pred_labels = knn_predict(feature, feature_bank,
                                      feature_labels, classes, k, 0.07)
            knn_acc = (pred_labels[:, 0] == labels).float().mean()
            knn_accs.append(knn_acc)

        averaged_loss = average_losses_across_data_parallel_group(knn_accs)
        return 0, {"knn_acc_10": averaged_loss[0],
                   "knn_acc_20": averaged_loss[1],
                   "knn_acc_100": averaged_loss[2],
                   "knn_acc_200": averaged_loss[3]}


def forward_step(data_iterator, model):
    """Forward step."""
    timers = get_timers()

    # Get the batch.
    timers("batch-generator").start()
    (
        images,
        labels,
    ) = get_batch(data_iterator)
    timers("batch-generator").stop()

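    # Megatron's training loop expects (output_tensor, loss_function); bind
    # model and labels now so the loss callback only needs the output tensor.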
    return model(images), partial(loss_func, model, labels)


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid, and test datasets."""
    args = get_args()

    print_rank_0(
        "> building train, validation, and test datasets for VIT ..."
    )
    train_ds, valid_ds = build_train_valid_datasets(
        data_path=args.data_path,
        image_size=(args.img_h, args.img_w)
    )
    print_rank_0("> finished creating VIT datasets ...")
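    # Vision pretraining uses no dedicated test split, so None is returned
    # for the test dataset.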

    return train_ds, valid_ds, None


if __name__ == "__main__":
    pretrain(
        train_valid_test_datasets_provider,
        model_provider,
        ModelType.encoder_or_decoder,
        forward_step,
        args_defaults={'dataloader_type': 'cyclic', 'vision_pretraining': True}
    )
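

# Minimal illustrative launch (a sketch, not a documented command; flag names
# mirror the args read above, e.g. --data-path / --img-h / --img-w, and every
# value is a placeholder):
#
#   torchrun --nproc_per_node=8 pretrain_vision_dino.py \
#       --data-path /path/to/imagenet \
#       --img-h 224 --img-w 224 \
#       --micro-batch-size 32 --train-iters 500000 --lr 5e-4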