classification.py 3.13 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16
"""Vision-classification finetuning/evaluation."""
17

18
19
20
import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers
21
from megatron import print_rank_0
22
from megatron.model.vision.classification import VitClassificationModel
23
from megatron.data.vit_dataset import build_train_valid_datasets
24
from tasks.vision.classification.eval_utils import accuracy_func_provider
25
from tasks.vision.finetune_utils import finetune
26
from megatron.utils import average_losses_across_data_parallel_group
27
28
29
30
31
32
33
34
35


def classification():
    """Finetune/evaluate a vision classification model.

    Defines the dataset provider, the model provider and the forward step
    as local closures and hands them to ``finetune``, together with
    ``accuracy_func_provider`` as the end-of-epoch callback provider.
    """

    def train_valid_datasets_provider():
        """Build train and validation dataset."""
        args = get_args()

        train_ds, valid_ds = build_train_valid_datasets(
            data_path=args.data_path,
            image_size=(args.img_h, args.img_w),
        )
        return train_ds, valid_ds

    def model_provider(pre_process=True, post_process=True):
        """Build the model.

        ``pre_process`` and ``post_process`` are forwarded unchanged to
        ``VitClassificationModel``.
        """
        args = get_args()

        print_rank_0("building classification model for ImageNet ...")

        return VitClassificationModel(
            num_classes=args.num_classes,
            finetune=True,
            pre_process=pre_process,
            post_process=post_process,
        )

    def process_batch(batch):
        """Process batch and produce inputs for the model.

        Returns ``(images, labels)`` taken from ``batch[0]`` and
        ``batch[1]``, each moved with ``.cuda()`` and made contiguous.
        Assumes both entries are torch tensors -- TODO confirm against
        the dataset.
        """
        images = batch[0].cuda().contiguous()
        labels = batch[1].cuda().contiguous()
        return images, labels

    def cross_entropy_loss_func(labels, output_tensor):
        """Compute the cross-entropy loss of the model output vs. labels.

        Returns the loss together with a dict holding the loss averaged
        by ``average_losses_across_data_parallel_group`` for logging.
        """
        logits = output_tensor

        # Cross-entropy loss.
        loss = F.cross_entropy(logits.contiguous().float(), labels)

        # Reduce loss for logging.
        averaged_loss = average_losses_across_data_parallel_group([loss])

        return loss, {'lm loss': averaged_loss[0]}

    def _cross_entropy_forward_step(batch, model):
        """Simple forward step with cross-entropy loss.

        ``batch`` may be either an iterator that yields batches or an
        already materialised batch. Returns the model output and a
        partial of ``cross_entropy_loss_func`` bound to the labels.
        """
        timers = get_timers()

        # Get the batch.
        timers("batch generator").start()
        # Decide explicitly whether we were handed an iterator or a batch.
        # A blanket ``except BaseException`` around ``next`` would also
        # swallow KeyboardInterrupt/SystemExit and real data-loading errors
        # and then treat the iterator itself as a batch.
        if hasattr(batch, "__next__"):
            batch_ = next(batch)
        else:
            batch_ = batch
        images, labels = process_batch(batch_)
        timers("batch generator").stop()

        # Forward model.
        output_tensor = model(images)

        return output_tensor, partial(cross_entropy_loss_func, labels)

    # Finetune/evaluate.
    finetune(
        train_valid_datasets_provider,
        model_provider,
        forward_step=_cross_entropy_forward_step,
        end_of_epoch_callback_provider=accuracy_func_provider,
    )

def main():
    """Entry point: run the classification finetuning/evaluation task."""
    classification()
94