pretrain_vision_dino.py
# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.

import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import torch.distributed as dist
from functools import partial
from megatron import get_args, get_timers, mpu, print_rank_0
from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model.vision.dino import DINOPretrainModel
from megatron.model.vision.knn_monitor import knn_predict, get_feature_bank
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group, unwrap_model
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.model import ModelType

def model_provider(pre_process=True, post_process=True):
    """Build the model."""
    return DINOPretrainModel(pre_process=pre_process, post_process=post_process)

def get_batch(data_iterator):
    """Build the batch."""
    data = next(data_iterator)

    # only data parallelism; no need for broadcast
    if isinstance(data[0], list):
        images = [aug.cuda() for aug in data[0]]
    else:
        images = data[0].cuda()
    labels = data[1].cuda()

    return images, labels


def loss_func(model, labels, output_tensor, collect_data=False):
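    # How the loss/metric split works (added comment): in training mode the
    # output tensor is the (student_output, teacher_output) pair and the DINO
    # self-distillation loss is computed by model.dino_loss(), which is given
    # the current iteration for its internal schedules. In evaluation mode the
    # teacher features are scored with a weighted k-nearest-neighbor classifier
    # (k in {10, 20, 100, 200}, temperature 0.07) against the feature bank
    # cached by the kNN monitor, and the accuracies are averaged across the
    # data-parallel group; the 0 returned in place of a loss is a placeholder.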
    args = get_args()
    
    model = unwrap_model(
        model,
        (torchDDP, LocalDDP, Float16Module)
    )
    if model.training:
        student_output, teacher_output = output_tensor
        loss = model.dino_loss(student_output, teacher_output, args.curr_iteration)
        averaged_loss = average_losses_across_data_parallel_group([loss])
        return loss, {"loss": averaged_loss[0]}
    else:
        _, teacher_feature = output_tensor
        feature_bank, feature_labels, classes = get_feature_bank()
        feature = F.normalize(teacher_feature.float(), dim=1)

        knn_accs = []
        for k in [10, 20, 100, 200]:
            pred_labels = knn_predict(feature, feature_bank,
                                      feature_labels, classes, k, 0.07)
            knn_acc = (pred_labels[:, 0] == labels).float().mean()
            knn_accs.append(knn_acc)

        averaged_loss = average_losses_across_data_parallel_group(knn_accs)
        return 0, {"knn_acc_10": averaged_loss[0],
                   "knn_acc_20": averaged_loss[1],
                   "knn_acc_100": averaged_loss[2],
                   "knn_acc_200": averaged_loss[3]}


def forward_step(data_iterator, model):
    """Forward step."""
    timers = get_timers()

    # Get the batch.
    timers("batch-generator", log_level=2).start()
    (
        images,
        labels,
    ) = get_batch(data_iterator)
    timers("batch-generator").stop()

    return model(images), partial(loss_func, model, labels)


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid, and test datasets."""
    args = get_args()

    print_rank_0(
        "> building train, validation, and test datasets " "for VIT ..."
    )
    train_ds, valid_ds = build_train_valid_datasets(
        data_path=args.data_path,
        image_size=(args.img_h, args.img_w)
    )
    print_rank_0("> finished creating VIT datasets ...")

    return train_ds, valid_ds, None


if __name__ == "__main__":
    pretrain(
        train_valid_test_datasets_provider,
        model_provider,
        ModelType.encoder_or_decoder,
        forward_step,
        args_defaults={'dataloader_type': 'cyclic', 'vision_pretraining': True}
    )
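
# Rough launch sketch (added comment; illustrative only). The flag names and
# values below are assumptions that depend on your Megatron-LM version and the
# ViT/DINO configuration you want:
#
#   torchrun --nproc_per_node=8 pretrain_vision_dino.py \
#       --num-layers 12 --hidden-size 384 --num-attention-heads 6 \
#       --img-h 224 --img-w 224 --micro-batch-size 32 \
#       --lr 0.0005 --train-iters 500000 \
#       --data-path /path/to/image/folder
#
# The args_defaults passed to pretrain() above already select the cyclic
# dataloader and enable the vision-pretraining code path, so those do not
# need to be set on the command line.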