# pretrain_vision_dino.py
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pretrain VIT"""

import torch
import torch.nn.functional as F
20
21
22
import torch.nn as nn
import numpy as np
import torch.distributed as dist
Vijay Korthikanti's avatar
Vijay Korthikanti committed
23
from functools import partial
24
25
from megatron import get_args, get_timers, mpu, print_rank_0
from megatron.data.vit_dataset import build_train_valid_datasets
26
27
from megatron.model.vision.contrastive import DINOPretrainModel
from megatron.model.vision.knn_monitor import knn_predict
28
from megatron.training import pretrain
29
30
31
32
from megatron.utils import average_losses_across_data_parallel_group, unwrap_model
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
33

Vijay Korthikanti's avatar
Vijay Korthikanti committed
34
def model_provider(pre_process=True, post_process=True):
    """Build the model.

    Args:
        pre_process: Forwarded unchanged to ``DINOPretrainModel``.
        post_process: Forwarded unchanged to ``DINOPretrainModel``.

    Returns:
        A newly constructed ``DINOPretrainModel``.
    """
    print_rank_0("building VIT model ...")
    return DINOPretrainModel(pre_process=pre_process, post_process=post_process)
38
39
40

def get_batch(data_iterator):
    """Build the batch.

    Pulls the next item from ``data_iterator`` and moves its first two
    entries to the GPU with ``.cuda()``.

    Args:
        data_iterator: Iterator whose items are indexable, with the images at
            index 0 and the labels at index 1.  The images entry is either a
            single object exposing ``.cuda()`` or a list/tuple of such
            objects (presumably one tensor per augmented view -- confirm
            against the dataset's collate function).

    Returns:
        ``(images, labels)``.  ``images`` is a list when the images entry was
        a list or tuple, otherwise the single moved object.

    Raises:
        StopIteration: If ``data_iterator`` is exhausted.
    """
    data = next(data_iterator)

    # only data parallelism; no need for broadcast
    views = data[0]
    if isinstance(views, (list, tuple)):
        # Accept a tuple as well as a list so a custom collate function that
        # returns a tuple of views does not fail on ``tuple.cuda()``.
        images = [aug.cuda() for aug in views]
    else:
        images = views.cuda()
    labels = data[1].cuda()

    return images, labels

Vijay Korthikanti's avatar
Vijay Korthikanti committed
52

53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
def loss_func(model, labels, output_tensor, collect_data=False):
    """Compute the training loss or the evaluation k-NN accuracies.

    Args:
        model: Possibly wrapped model; the ``torchDDP``, ``LocalDDP`` and
            ``Float16Module`` wrappers are stripped before use.
        labels: Labels compared against the k-NN predictions in eval mode.
        output_tensor: Two-element model output.  In training mode it is
            unpacked as ``(student_output, teacher_output)``; in eval mode
            only the second element (the teacher feature) is used.
        collect_data: Accepted for interface compatibility; not read here.

    Returns:
        Training mode: ``(loss, {"loss": averaged_loss})``.
        Eval mode: ``(0, {...})`` with one data-parallel-averaged accuracy
        per neighbour count (10, 20, 100, 200).
    """
    args = get_args()
    core_model = unwrap_model(model, (torchDDP, LocalDDP, Float16Module))

    if not core_model.training:
        _, teacher_feature = output_tensor
        feature_bank, feature_labels, classes = args.knn_features
        feature = F.normalize(teacher_feature.float(), dim=1)

        # 0.07 is handed to knn_predict as its last argument (presumably the
        # softmax temperature -- confirm against knn_predict).
        accuracies = [
            (
                knn_predict(feature, feature_bank,
                            feature_labels, classes, num_neighbours, 0.07)[:, 0]
                == labels
            ).float().mean()
            for num_neighbours in (10, 20, 100, 200)
        ]

        # A single reduction call covers all four accuracies.
        reduced = average_losses_across_data_parallel_group(accuracies)
        metrics = {
            "knn_acc_10": reduced[0],
            "knn_acc_20": reduced[1],
            "knn_acc_100": reduced[2],
            "knn_acc_200": reduced[3],
        }
        return 0, metrics

    student_output, teacher_output = output_tensor
    loss = core_model.dino_loss(
        student_output, teacher_output, args.curr_iteration
    )
    reduced = average_losses_across_data_parallel_group([loss])
    return loss, {"loss": reduced[0]}
Vijay Korthikanti's avatar
Vijay Korthikanti committed
82
83
84


def forward_step(data_iterator, model):
    """Forward step.

    Args:
        data_iterator: Iterator handed to ``get_batch``.
        model: Callable model; invoked on the images of the batch.

    Returns:
        A pair ``(output, loss_fn)`` where ``output`` is ``model(images)``
        and ``loss_fn`` is ``loss_func`` with ``model`` and ``labels``
        already bound, so the caller only supplies the output tensor.
    """
    timers = get_timers()

    # Get the batch.
    timers("batch-generator").start()
    images, labels = get_batch(data_iterator)
    timers("batch-generator").stop()

    return model(images), partial(loss_func, model, labels)
97
98
99
100
101
102
103
104
105


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid, and test datasets.

    Args:
        train_val_test_num_samples: Not read by this function (presumably
            kept because the ``pretrain`` caller passes it -- confirm).

    Returns:
        ``(train_ds, valid_ds, None)``; no test dataset is built.
    """
    args = get_args()

    print_rank_0("> building train, validation, and test datasets for VIT ...")
    train_ds, valid_ds = build_train_valid_datasets(
        data_path=args.data_path,
        image_size=(args.img_h, args.img_w),
    )
    print_rank_0("> finished creating VIT datasets ...")

    return train_ds, valid_ds, None


if __name__ == "__main__":
    # Script entry point: hand the dataset provider, model provider and
    # forward step to pretrain(), passing 'cyclic' as the default
    # dataloader_type through args_defaults.
    pretrain(
        train_valid_test_datasets_provider,
        model_provider,
        forward_step,
        args_defaults={'dataloader_type': 'cyclic'}
    )
122