# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Classification model."""

import torch

xingjinliang's avatar
xingjinliang committed
7
8
9
10
11
12
13
from megatron.training import get_args, print_rank_last
from megatron.legacy.model.enums import AttnMaskType
from megatron.legacy.model.bert_model import bert_extended_attention_mask, bert_position_ids
from megatron.legacy.model.language_model import get_language_model
from megatron.legacy.model.utils import get_linear_layer
from megatron.legacy.model.utils import init_method_normal
from megatron.legacy.model.utils import scaled_init_method_normal
14
from .module import MegatronModule
15
16


17
18
class Classification(MegatronModule):
    """Language model with an optional classification head on the pooled output.

    The language model is created by ``get_language_model`` with a pooler and
    a padding attention mask.  When ``post_process`` is true, a dropout layer
    and a linear layer mapping ``args.hidden_size`` to ``num_classes`` are
    added on top of the pooled output.

    Args:
        config: Configuration object forwarded to ``MegatronModule`` and
            ``get_language_model``; its ``init_method`` initializes the
            classification head.
        num_classes: Width of the logits produced by the classification head.
        num_tokentypes: Forwarded to ``get_language_model``.
        pre_process: Forwarded to ``get_language_model``.
            NOTE(review): presumably marks the first pipeline stage - confirm.
        post_process: Forwarded to ``get_language_model``; also decides
            whether the classification head is built and applied.
    """

    def __init__(self,
                 config,
                 num_classes,
                 num_tokentypes=2,
                 pre_process=True,
                 post_process=True):
        super().__init__(config=config, share_embeddings_and_output_weights=False)
        args = get_args()

        self.num_classes = num_classes
        self.pre_process = pre_process
        self.post_process = post_process

        self.language_model, self._language_model_key = get_language_model(
            config=config,
            num_tokentypes=num_tokentypes,
            add_pooler=True,
            encoder_attn_mask_type=AttnMaskType.padding,
            pre_process=self.pre_process,
            post_process=self.post_process)

        # Classification head (only built when this module post-processes).
        if self.post_process:
            self.classification_dropout = torch.nn.Dropout(args.hidden_dropout)
            self.classification_head = get_linear_layer(args.hidden_size,
                                                        self.num_classes,
                                                        config.init_method)
            # Checkpoint key under which the head's weights are stored.
            self._classification_head_key = 'classification_head'

    def set_input_tensor(self, input_tensor):
        """See megatron.legacy.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)

    def forward(self, model_input, attention_mask, tokentype_ids=None):
        """Run the language model and, if post-processing, the head.

        Args:
            model_input: Passed to the language model as ``input_ids`` and
                used to derive position ids via ``bert_position_ids``.
            attention_mask: Converted with ``bert_extended_attention_mask``
                before being passed to the language model.
            tokentype_ids: Optional; forwarded to the language model.

        Returns:
            When ``post_process`` is true, the classification logits viewed
            as ``(-1, num_classes)``.  Otherwise the language model output,
            unchanged.
        """
        extended_attention_mask = bert_extended_attention_mask(attention_mask)
        input_ids = model_input
        position_ids = bert_position_ids(input_ids)

        lm_output = self.language_model(
            input_ids,
            position_ids,
            extended_attention_mask,
            tokentype_ids=tokentype_ids
        )

        if self.post_process:
            # The language model returns a pair here; only the second
            # element (the pooled output) feeds the head.
            _, pooled_output = lm_output
            classification_output = self.classification_dropout(pooled_output)
            classification_logits = self.classification_head(classification_output)

            # Flatten any leading dimensions so the logits are
            # (-1, num_classes).
            classification_logits = classification_logits.view(-1, self.num_classes)

            return classification_logits
        return lm_output

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key.

        Returns:
            A dict holding the language model's checkpoint state under
            ``self._language_model_key`` and, when ``post_process`` is true,
            the head's ``state_dict`` under ``self._classification_head_key``.
        """

        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(prefix=prefix,
                                                                 keep_vars=keep_vars)
        if self.post_process:
            state_dict_[self._classification_head_key] \
                = self.classification_head.state_dict(prefix=prefix, keep_vars=keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.

        Loads the language model from ``state_dict[self._language_model_key]``
        (a missing key raises ``KeyError``).  When ``post_process`` is true,
        the head is loaded from ``self._classification_head_key`` if that key
        is present; otherwise a warning is printed and the head keeps the
        weights it was constructed with.
        """

        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        if self.post_process:
            if self._classification_head_key in state_dict:
                self.classification_head.load_state_dict(
                    state_dict[self._classification_head_key], strict=strict)
            else:
                print_rank_last('***WARNING*** could not find {} in the checkpoint, '
                                'initializing to random'.format(
                                    self._classification_head_key))