pretrain_vision_dino.py
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pretrain VIT"""

import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import torch.distributed as dist
from functools import partial
from megatron import get_args, get_timers, mpu, print_rank_0
from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model.vision.dino import DINOPretrainModel
from megatron.model.vision.knn_monitor import knn_predict
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group, unwrap_model
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.model import ModelType

def model_provider(pre_process=True, post_process=True):
    """Build the model."""
    print_rank_0("building VIT model ...")
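    # pre_process / post_process flag the pipeline-parallel stage duties for this
    # rank; DINOPretrainModel wraps the ViT backbone in DINO's student / teacher setup.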
    return DINOPretrainModel(pre_process=pre_process, post_process=post_process)

def get_batch(data_iterator):
    """Build the batch."""
    data = next(data_iterator)

    # only data parallelism; no need for broadcast
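    # The DINO transform can return a list of augmented views (multi-crop);
    # move each view, or the single tensor, to the GPU.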
    if isinstance(data[0], list):
        images = [aug.cuda() for aug in data[0]]
    else:
        images = data[0].cuda()
    labels = data[1].cuda()

    return images, labels


def loss_func(model, labels, output_tensor, collect_data=False):
    args = get_args()
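    # Unwrap DDP / fp16 wrappers so we can reach the underlying model's
    # dino_loss and its training flag.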
    
    model = unwrap_model(
        model,
        (torchDDP, LocalDDP, Float16Module)
    )
    if model.training:
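        # Training: unpack the student / teacher head outputs and compute the DINO
        # self-distillation loss; the current iteration drives its internal schedules.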
        student_output, teacher_output = output_tensor
        loss = model.dino_loss(student_output, teacher_output, args.curr_iteration)
        averaged_loss = average_losses_across_data_parallel_group([loss])
        return loss, {"loss": averaged_loss[0]}
    else:
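        # Evaluation: score representation quality with kNN classification against
        # the feature bank cached by the kNN monitor.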
        _, teacher_feature = output_tensor
        feature_bank, feature_labels, classes = args.knn_features
        feature = F.normalize(teacher_feature.float(), dim=1)

        knn_accs = []
        for k in [10, 20, 100, 200]:
            pred_labels = knn_predict(feature, feature_bank,
                                      feature_labels, classes, k, 0.07)
            knn_acc = (pred_labels[:, 0] == labels).float().mean()
            knn_accs.append(knn_acc)

        averaged_loss = average_losses_across_data_parallel_group(knn_accs)
        return 0, {"knn_acc_10": averaged_loss[0],
                   "knn_acc_20": averaged_loss[1],
                   "knn_acc_100": averaged_loss[2],
                   "knn_acc_200": averaged_loss[3]}


def forward_step(data_iterator, model):
    """Forward step."""
    timers = get_timers()

    # Get the batch.
    timers("batch-generator").start()
    (
        images,
        labels,
    ) = get_batch(data_iterator)
    timers("batch-generator").stop()

    return model(images), partial(loss_func, model, labels)


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid, and test datasets."""
    args = get_args()

    print_rank_0(
        "> building train, validation, and test datasets " "for VIT ..."
    )
    train_ds, valid_ds = build_train_valid_datasets(
        data_path=args.data_path,
        image_size=(args.img_h, args.img_w)
    )
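    # DINO pretraining only uses train / valid splits; the test dataset is None.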
    print_rank_0("> finished creating VIT datasets ...")

    return train_ds, valid_ds, None


if __name__ == "__main__":
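    # Hand the dataset, model, and forward-step providers to Megatron's pretrain
    # loop; this vision script defaults to the 'cyclic' dataloader type.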
    pretrain(
        train_valid_test_datasets_provider,
        model_provider,
        ModelType.encoder_or_decoder,
        forward_step,
        args_defaults={'dataloader_type': 'cyclic'}
    )