# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Multiple choice model."""

import torch

20
from megatron import get_args, print_rank_last
21
from megatron import mpu
22
from megatron.model.enums import AttnMaskType
23
from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
24
25
26
27
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
28
from .module import MegatronModule
29
30


31
class MultipleChoice(MegatronModule):
    """Multiple-choice model built on top of the shared language model.

    The language model is created via ``get_language_model`` with a pooler and
    a padding attention mask. When ``post_process`` is True, a dropout layer
    plus a linear head (``hidden_size -> 1``) is added that scores each pooled
    choice representation.

    Inputs arrive as ``[batch, choices, sequence]``; the choice dimension is
    folded into the batch dimension before running the language model, and the
    per-choice scores are reshaped back to ``[batch, choices]`` afterwards.
    """

    def __init__(self,
                 num_tokentypes=2,
                 pre_process=True,
                 post_process=True):
        """Create the language model and (optionally) the multi-choice head.

        Args:
            num_tokentypes: number of token-type embeddings, forwarded to
                ``get_language_model``.
            pre_process: forwarded to ``get_language_model``.
            post_process: forwarded to ``get_language_model``; when True this
                module also builds the multi-choice head and ``forward``
                returns per-choice logits.
        """
        super(MultipleChoice, self).__init__(share_word_embeddings=False)
        args = get_args()

        init_method = init_method_normal(args.init_method_std)

        self.pre_process = pre_process
        self.post_process = post_process

        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=True,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                         args.num_layers),
            pre_process=self.pre_process,
            post_process=self.post_process)

        # Multi-choice head: maps each pooled representation to one scalar.
        if self.post_process:
            self.multichoice_dropout = torch.nn.Dropout(args.hidden_dropout)
            self.multichoice_head = get_linear_layer(args.hidden_size, 1,
                                                     init_method)
            self._multichoice_head_key = 'multichoice_head'

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)

    def forward(self, model_input, attention_mask, tokentype_ids=None):
        """Run the language model over every choice and score each one.

        Args:
            model_input: token ids of shape ``[batch, choices, sequence]``.
            attention_mask: mask of shape ``[batch, choices, sequence]``.
            tokentype_ids: optional token-type ids of shape
                ``[batch, choices, sequence]``. When ``None``, no token-type
                ids are passed to the language model.

        Returns:
            If ``post_process`` is True, logits of shape ``[batch, choices]``;
            otherwise the language model's output unchanged.
        """
        # [batch, choices, sequence] --> [batch * choices, sequence] -->
        #    transformer --> [batch, choices] logits

        # Ensure the shape is [batch-size, choices, sequence]
        assert attention_mask.dim() == 3, \
            'attention_mask must be [batch, choices, sequence]'
        num_choices = attention_mask.shape[1]

        # Reshape and treat choice dimension the same as batch.
        attention_mask = attention_mask.view(-1, attention_mask.size(-1))
        extended_attention_mask = bert_extended_attention_mask(attention_mask)

        # Do the same as attention_mask for input_ids (and tokentype_ids).
        input_ids = model_input
        assert input_ids.dim() == 3, \
            'model_input must be [batch, choices, sequence]'
        input_ids = input_ids.view(-1, input_ids.size(-1))

        # tokentype_ids is optional (defaults to None); only flatten it when
        # it was provided, otherwise `.shape` / `.view` would raise.
        if tokentype_ids is not None:
            assert tokentype_ids.dim() == 3, \
                'tokentype_ids must be [batch, choices, sequence]'
            tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1))

        position_ids = bert_position_ids(input_ids)

        lm_output = self.language_model(
            input_ids,
            position_ids,
            extended_attention_mask,
            tokentype_ids=tokentype_ids
        )

        if not self.post_process:
            return lm_output

        _, pooled_output = lm_output
        multichoice_output = self.multichoice_dropout(pooled_output)
        multichoice_logits = self.multichoice_head(multichoice_output)

        # Reshape back to separate choices: [batch * choices, 1] -> [batch, choices].
        return multichoice_logits.view(-1, num_choices)

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key.

        Returns a dict keyed by the language-model key and, when
        ``post_process`` is True, by ``'multichoice_head'``.
        """
        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        if self.post_process:
            # Keyword arguments: positional args to torch.nn.Module.state_dict
            # are deprecated in PyTorch.
            state_dict_[self._multichoice_head_key] \
                = self.multichoice_head.state_dict(
                    destination=destination, prefix=prefix,
                    keep_vars=keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.

        Loads the language model from its own key. When ``post_process`` is
        True, loads the multi-choice head if present in ``state_dict``;
        otherwise a warning is printed and the head keeps its random init.
        """
        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        if self.post_process:
            if self._multichoice_head_key in state_dict:
                self.multichoice_head.load_state_dict(
                    state_dict[self._multichoice_head_key], strict=strict)
            else:
                print_rank_last('***WARNING*** could not find {} in the checkpoint, '
                                'initializing to random'.format(
                                    self._multichoice_head_key))