classification.py 4.13 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
3
4
5
6

"""Classification model."""

import torch

7
from megatron import get_args, print_rank_last
8
from megatron.model.enums import AttnMaskType
9
from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
10
11
12
13
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
14
from .module import MegatronModule
15
16


17
18
class Classification(MegatronModule):
    """Classification model: a language model with a pooler, followed by a
    dropout + linear classification head applied to the pooled output.

    ``post_process`` controls whether this instance builds and applies the
    classification head. Both ``pre_process`` and ``post_process`` are
    forwarded to ``get_language_model``. Presumably they mark the first/last
    pipeline stage; confirm against ``get_language_model``.
    """

    def __init__(self,
                 num_classes,
                 num_tokentypes=2,
                 pre_process=True,
                 post_process=True):
        """Build the language model and, if ``post_process``, the head.

        Args:
            num_classes: number of output classes (width of the logits).
            num_tokentypes: number of token-type embeddings, forwarded to
                ``get_language_model``.
            pre_process: forwarded to ``get_language_model``.
            post_process: if True, create the dropout + linear head and
                return logits from ``forward``.
        """
        super(Classification, self).__init__(share_word_embeddings=False)
        args = get_args()

        self.num_classes = num_classes
        self.pre_process = pre_process
        self.post_process = post_process
        init_method = init_method_normal(args.init_method_std)

        # add_pooler=True: forward() unpacks a pooled output from the
        # language model when post_process is set.
        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=True,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                         args.num_layers),
            pre_process=self.pre_process,
            post_process=self.post_process)

        # Classification head: dropout followed by a linear projection from
        # hidden_size to num_classes. Only built when post_process is set.
        if self.post_process:
            self.classification_dropout = torch.nn.Dropout(args.hidden_dropout)
            self.classification_head = get_linear_layer(args.hidden_size,
                                                        self.num_classes,
                                                        init_method)
            self._classification_head_key = 'classification_head'

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)

    def forward(self, model_input, attention_mask, tokentype_ids=None):
        """Run the language model and, if ``post_process``, the head.

        Args:
            model_input: token ids fed to the language model.
            attention_mask: mask converted with
                ``bert_extended_attention_mask`` before use.
            tokentype_ids: optional token-type ids for the language model.

        Returns:
            If ``post_process``: logits viewed as ``(-1, num_classes)``.
            Otherwise: the raw language-model output, unchanged.
        """
        extended_attention_mask = bert_extended_attention_mask(attention_mask)
        input_ids = model_input
        position_ids = bert_position_ids(input_ids)

        lm_output = self.language_model(
            input_ids,
            position_ids,
            extended_attention_mask,
            tokentype_ids=tokentype_ids
        )

        if self.post_process:
            # With the pooler enabled, the output is unpacked as a pair whose
            # second element is the pooled representation.
            _, pooled_output = lm_output
            classification_output = self.classification_dropout(pooled_output)
            classification_logits = self.classification_head(classification_output)

            # Flatten any leading dimensions into one: (-1, num_classes).
            classification_logits = classification_logits.view(-1, self.num_classes)

            return classification_logits
        return lm_output

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key.

        The language-model state is stored under ``self._language_model_key``.
        When ``post_process`` is set, the head state is stored under
        ``'classification_head'``.
        """
        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(prefix=prefix,
                                                                 keep_vars=keep_vars)
        if self.post_process:
            state_dict_[self._classification_head_key] \
                = self.classification_head.state_dict(prefix=prefix, keep_vars=keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.

        Loads the language-model sub-dict (a missing key raises ``KeyError``).
        When ``post_process`` is set, the head is loaded if its key is
        present. Otherwise a warning is printed and the head keeps its
        freshly initialized weights.
        """
        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        if self.post_process:
            if self._classification_head_key in state_dict:
                self.classification_head.load_state_dict(
                    state_dict[self._classification_head_key], strict=strict)
            else:
                print_rank_last('***WARNING*** could not find {} in the checkpoint, '
                                'initializing to random'.format(
                                    self._classification_head_key))