classification.py 2.72 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
3
4
5

"""Vision Transformer(VIT) model."""

import torch
Vijay Korthikanti's avatar
Vijay Korthikanti committed
6
from torch.nn.init import trunc_normal_
7
8
9
from megatron import get_args
from megatron.model.utils import get_linear_layer
from megatron.model.vision.vit_backbone import VitBackbone, VitMlpHead
10
from megatron.model.vision.mit_backbone import mit_b3_avg
11
12
13
14
15
from megatron.model.module import MegatronModule

class VitClassificationModel(MegatronModule):
    """Vision Transformer classification model.

    Wraps a ``VitBackbone`` and, when ``post_process`` is true, adds a
    classification head mapping ``hidden_size`` features to
    ``num_classes`` outputs.

    Args:
        config: Passed through unchanged to ``VitBackbone``.
        num_classes: Output size of the classification head.
        finetune: Selects the head. When false the head is a
            ``VitMlpHead``; when true it is the layer returned by
            ``get_linear_layer`` called with ``torch.nn.init.zeros_``.
        pre_process: Forwarded to ``VitBackbone``.
        post_process: Forwarded to ``VitBackbone``; also decides whether
            a head is built in ``__init__`` and applied in ``forward``.
    """

    def __init__(self, config, num_classes, finetune=False,
                 pre_process=True, post_process=True):
        super().__init__()
        args = get_args()

        # The head input width comes from the global Megatron arguments.
        self.hidden_size = args.hidden_size
        self.num_classes = num_classes
        self.finetune = finetune
        self.pre_process = pre_process
        self.post_process = post_process
        self.backbone = VitBackbone(
            config=config,
            pre_process=self.pre_process,
            post_process=self.post_process,
            # NOTE(review): presumably makes the backbone emit a single
            # token per sample for the head to consume; confirm in
            # VitBackbone.
            single_token_output=True
        )

        # A head only exists when post-processing is enabled; ``forward``
        # applies the same check before using ``self.head``.
        if self.post_process:
            if not self.finetune:
                self.head = VitMlpHead(self.hidden_size, self.num_classes)
            else:
                # NOTE(review): the third argument is presumably the
                # weight init function (zeros here); confirm against
                # get_linear_layer.
                self.head = get_linear_layer(
                    self.hidden_size,
                    self.num_classes,
                    torch.nn.init.zeros_
                )

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.backbone.set_input_tensor(input_tensor)

    def forward(self, input):
        """Run the backbone and, if ``post_process`` is set, the head.

        Args:
            input: Passed directly to the backbone. The name shadows the
                builtin but is kept so keyword callers keep working.

        Returns:
            The head output when ``post_process`` is true, otherwise the
            backbone output unchanged.
        """
        hidden_states = self.backbone(input)

        if self.post_process:
            hidden_states = self.head(hidden_states)

        return hidden_states
54
55
56
57
58


class MitClassificationModel(MegatronModule):
    """Mix Vision Transformer (MiT) classification model.

    Combines the backbone returned by ``mit_b3_avg()`` with a
    ``torch.nn.Linear`` classification head.

    Args:
        num_classes: Output size of the classification head.
        pre_process: Accepted but not used by this class.
        post_process: Accepted but not used by this class; the head is
            always built and always applied.
    """

    def __init__(self, num_classes,
                 pre_process=True, post_process=True):
        super().__init__()
        args = get_args()

        # Stored for reference only; the head below does not read it.
        self.hidden_size = args.hidden_size
        self.num_classes = num_classes

        self.backbone = mit_b3_avg()
        # NOTE(review): 512 is hard-coded, presumably the feature width
        # produced by mit_b3_avg; confirm against the backbone.
        self.head = torch.nn.Linear(512, num_classes)
        # NOTE(review): assuming torch.nn.Module.apply semantics, this
        # visits every submodule, so Linear layers inside the backbone
        # are re-initialized as well as the head.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize a ``torch.nn.Linear`` module; ignore anything else.

        Weights get a truncated normal init with ``std=.02``; the bias,
        when present, is set to zero.
        """
        if not isinstance(m, torch.nn.Linear):
            return
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()

        Intentionally a no-op here: the argument is ignored.
        """

    def forward(self, input):
        """Run the backbone followed by the classification head.

        Args:
            input: Passed directly to the backbone. The name shadows the
                builtin but is kept so keyword callers keep working.

        Returns:
            The output of ``self.head`` applied to the backbone output.
        """
        hidden_states = self.backbone(input)
        hidden_states = self.head(hidden_states)

        return hidden_states