# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.

"""Pretrain VIT"""

import torch
import torch.nn.functional as F
Vijay Korthikanti's avatar
Vijay Korthikanti committed
7
from functools import partial
xingjinliang's avatar
xingjinliang committed
8
from megatron.training import get_args, get_timers, print_rank_0
9
from megatron.core.enums import ModelType
xingjinliang's avatar
xingjinliang committed
10
11
12
from megatron.legacy.data.vit_dataset import build_train_valid_datasets
from megatron.legacy.model.vision.classification import VitClassificationModel
from megatron.legacy.model.vision.classification import MitClassificationModel
13
from megatron.training import pretrain
xingjinliang's avatar
xingjinliang committed
14
15
from megatron.training.utils import average_losses_across_data_parallel_group
from megatron.training.arguments import core_transformer_config_from_args
16

17

Vijay Korthikanti's avatar
Vijay Korthikanti committed
18
def model_provider(pre_process=True, post_process=True):
    """Build the vision classification model selected by the run arguments.

    The backbone is chosen from ``args.vision_backbone_type`` (as returned by
    ``get_args()``): ``'vit'`` builds a ``VitClassificationModel`` and
    ``'mit'`` builds a ``MitClassificationModel``.

    Args:
        pre_process: Forwarded unchanged to the model constructor.
        post_process: Forwarded unchanged to the model constructor.

    Returns:
        The constructed classification model.

    Raises:
        ValueError: If ``args.vision_backbone_type`` is neither ``'vit'``
            nor ``'mit'``.
    """
    args = get_args()
    config = core_transformer_config_from_args(args)

    if args.vision_backbone_type == 'vit':
        print_rank_0("building VIT model ...")
        model = VitClassificationModel(config=config,
                                       num_classes=args.num_classes,
                                       pre_process=pre_process,
                                       post_process=post_process)
    elif args.vision_backbone_type == 'mit':
        print_rank_0("building MIT model ...")
        # NOTE(review): unlike the VIT branch, `config` is not forwarded
        # here -- confirm MitClassificationModel does not accept it.
        model = MitClassificationModel(num_classes=args.num_classes,
                                       pre_process=pre_process,
                                       post_process=post_process)
    else:
        # ValueError is a subclass of Exception, so callers that catch
        # Exception keep working; the message text is unchanged.
        raise ValueError(
            f"{args.vision_backbone_type} vision backbone is not supported."
        )

    return model

39

40
41
def get_batch(data_iterator):
    """Fetch the next batch from ``data_iterator`` and move it to the GPU.

    Args:
        data_iterator: Iterator whose items are indexable; element 0 is
            used as the images and element 1 as the labels.

    Returns:
        Tuple ``(images, labels)``, each the result of calling ``.cuda()``
        on the corresponding element of the fetched item.
    """
    data = next(data_iterator)

    # only data parallelism; no need for broadcast
    images = data[0].cuda()
    labels = data[1].cuda()

    return images, labels

50

Vijay Korthikanti's avatar
Vijay Korthikanti committed
51
52
53
54
55
56
57
58
59
60
61
62
def loss_func(labels, output_tensor):
    """Compute the cross-entropy loss and accuracy for one batch.

    Args:
        labels: Target labels; passed to ``F.cross_entropy`` and compared
            element-wise against the argmax of the logits.
        output_tensor: Model output, used as the logits after being made
            contiguous and cast to float.

    Returns:
        Tuple ``(loss, stats)`` where ``loss`` is the local cross-entropy
        loss and ``stats`` is a dict with keys ``"loss"`` and
        ``"accuracy"`` holding the values returned by
        ``average_losses_across_data_parallel_group``.
    """
    logits = output_tensor.contiguous().float()
    loss = F.cross_entropy(logits, labels)

    # Fraction of samples whose highest-scoring class equals the label.
    predictions = logits.argmax(dim=-1)
    accuracy = (predictions == labels).float().mean()

    reduced = average_losses_across_data_parallel_group([loss, accuracy])
    stats = {"loss": reduced[0], "accuracy": reduced[1]}

    return loss, stats

63

Vijay Korthikanti's avatar
Vijay Korthikanti committed
64
def forward_step(data_iterator, model):
    """Run one forward pass and return the output with a loss closure.

    Args:
        data_iterator: Iterator handed to ``get_batch`` to obtain the
            images and labels for this step.
        model: Callable invoked as ``model(images)``.

    Returns:
        Tuple ``(output_tensor, loss_fn)`` where ``loss_fn`` is
        ``loss_func`` with ``labels`` already bound, so it only needs the
        output tensor to be called.
    """
    timers = get_timers()

    # Get the batch.
    timers("batch-generator", log_level=2).start()
    images, labels = get_batch(data_iterator)
    timers("batch-generator").stop()

    # Forward model.
    output_tensor = model(images)

    return output_tensor, partial(loss_func, labels)
80
81
82
83
84
85
86
87

def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid, and test datasets.

    Args:
        train_val_test_num_samples: Not used by this provider; kept as
            part of the function's signature.

    Returns:
        Tuple ``(train_ds, valid_ds, None)``. The train and validation
        datasets come from ``build_train_valid_datasets``; no test
        dataset is built, so the third element is always ``None``.
    """
    args = get_args()

    print_rank_0(
        "> building train, validation, and test datasets for VIT ..."
    )
    train_ds, valid_ds = build_train_valid_datasets(
        data_path=args.data_path,
        image_size=(args.img_h, args.img_w),
    )
    print_rank_0("> finished creating VIT datasets ...")

    return train_ds, valid_ds, None


if __name__ == "__main__":

    # Script entry point: pass the dataset provider, model provider and
    # forward step defined in this file to `pretrain`, with the default
    # argument overrides given in `args_defaults`.
    pretrain(
        train_valid_test_datasets_provider,
        model_provider,
        ModelType.encoder_or_decoder,
        forward_step,
        args_defaults={'dataloader_type': 'cyclic', 'vision_pretraining': True},
    )