# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Classification model."""

import torch

from megatron import get_args, print_rank_0
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from megatron.module import MegatronModule


class Classification(MegatronModule):
    """Language model with a linear classification head on its pooled output.

    The language model is built with ``add_pooler=True``. Its pooled output
    goes through dropout and a linear layer that produces ``num_classes``
    logits.
    """

    def __init__(self, num_classes, num_tokentypes=2):
        """Build the language model and the classification head.

        Args:
            num_classes: number of output classes (width of the logits).
            num_tokentypes: number of token-type embeddings requested from
                the language model.
        """
        super(Classification, self).__init__()
        args = get_args()

        self.num_classes = num_classes
        init_method = init_method_normal(args.init_method_std)

        self.language_model, self._language_model_key = get_language_model(
            attention_mask_func=bert_attention_mask_func,
            num_tokentypes=num_tokentypes,
            add_pooler=True,
            init_method=init_method,
            scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                         args.num_layers))

        # Classification head: dropout, then a linear projection from
        # hidden_size to num_classes.
        self.classification_dropout = torch.nn.Dropout(args.hidden_dropout)
        self.classification_head = get_linear_layer(args.hidden_size,
                                                    self.num_classes,
                                                    init_method)
        # Key under which the head's weights are stored in checkpoints.
        self._classification_head_key = 'classification_head'

    def forward(self, input_ids, attention_mask, tokentype_ids):
        """Compute classification logits.

        Args:
            input_ids: token ids for the language model.
            attention_mask: mask that ``bert_extended_attention_mask``
                converts to the dtype of the language model's parameters.
            tokentype_ids: token-type ids forwarded to the language model.

        Returns:
            Logits viewed as ``(-1, num_classes)``.
        """
        # Cast the extended mask to the model's parameter dtype (e.g. fp16).
        extended_attention_mask = bert_extended_attention_mask(
            attention_mask, next(self.language_model.parameters()).dtype)
        position_ids = bert_position_ids(input_ids)

        # Only the pooled output feeds the head; the first output is unused.
        _, pooled_output = self.language_model(input_ids,
                                               position_ids,
                                               extended_attention_mask,
                                               tokentype_ids=tokentype_ids)

        classification_output = self.classification_dropout(pooled_output)
        classification_logits = self.classification_head(classification_output)

        # Collapse any leading dimensions so the result is (-1, num_classes).
        classification_logits = classification_logits.view(-1, self.num_classes)

        return classification_logits

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """Return a checkpoint dict with one entry per sub-module.

        The language model and the classification head are stored under
        separate keys. This makes it easy to load the model when it is
        combined with other heads.
        """
        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        # Keyword arguments: positional args to Module.state_dict are
        # deprecated in PyTorch and emit a warning.
        state_dict_[self._classification_head_key] \
            = self.classification_head.state_dict(
                destination=destination, prefix=prefix, keep_vars=keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Load a dict produced by ``state_dict_for_save_checkpoint``.

        The language model entry is required. If the classification head
        entry is missing (for example, when fine-tuning from a pretrained
        checkpoint), a warning is printed and the head keeps its random
        initialization.
        """
        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        if self._classification_head_key in state_dict:
            self.classification_head.load_state_dict(
                state_dict[self._classification_head_key], strict=strict)
        else:
            print_rank_0('***WARNING*** could not find {} in the checkpoint, '
                         'initializing to random'.format(
                             self._classification_head_key))