# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Vision-classification finetuning/evaluation."""

import torch.nn.functional as F
from functools import partial
from megatron.training import get_args, get_timers
from megatron.training import print_rank_0
from megatron.legacy.model.vision.classification import VitClassificationModel
from megatron.legacy.data.vit_dataset import build_train_valid_datasets
from tasks.vision.classification.eval_utils import accuracy_func_provider
from tasks.vision.finetune_utils import finetune
from megatron.training.utils import average_losses_across_data_parallel_group


def classification():
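    """Set up providers and run ViT classification finetuning/evaluation."""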
    def train_valid_datasets_provider():
        """Build train and validation dataset."""
        args = get_args()

        train_ds, valid_ds = build_train_valid_datasets(
            data_path=args.data_path,
            image_size=(args.img_h, args.img_w),
        )
        return train_ds, valid_ds

    def model_provider(pre_process=True, post_process=True):
        """Build the model."""
        args = get_args()

        print_rank_0("building classification model for ImageNet ...")

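        # pre_process/post_process select whether this pipeline-parallel stage
        # builds the model's input embedding and the classification head.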
        return VitClassificationModel(num_classes=args.num_classes, finetune=True,
                                      pre_process=pre_process, post_process=post_process)

    def process_batch(batch):
        """Process batch and produce inputs for the model."""
        images = batch[0].cuda().contiguous()
        labels = batch[1].cuda().contiguous()
        return images, labels

    def cross_entropy_loss_func(labels, output_tensor):
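        """Compute cross-entropy loss and its data-parallel average for logging."""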
        logits = output_tensor

        # Cross-entropy loss.
        loss = F.cross_entropy(logits.contiguous().float(), labels)

        # Reduce loss for logging.
        averaged_loss = average_losses_across_data_parallel_group([loss])

        return loss, {'lm loss': averaged_loss[0]}

    def _cross_entropy_forward_step(batch, model):
        """Simple forward step with cross-entropy loss."""
        timers = get_timers()

        # Get the batch.
        timers("batch generator", log_level=2).start()
        try:
            batch_ = next(batch)
        except Exception:
            batch_ = batch
        images, labels = process_batch(batch_)
        timers("batch generator").stop()

        # Forward model.
        output_tensor = model(images)

        return output_tensor, partial(cross_entropy_loss_func, labels)

    """Finetune/evaluate."""
    finetune(
        train_valid_datasets_provider,
        model_provider,
        forward_step=_cross_entropy_forward_step,
        end_of_epoch_callback_provider=accuracy_func_provider,
    )

def main():
    classification()