# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Multiple choice model."""

import torch

20
from megatron import get_args, print_rank_last
21
from megatron import mpu
Neel Kant's avatar
Neel Kant committed
22
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
23
24
25
26
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
27
from megatron.module import PipelinedMegatronModule
28
29


30
class MultipleChoiceBase(PipelinedMegatronModule):
    """Language model with a multiple-choice head, usable at any pipeline stage.

    Every stage builds the language model returned by ``get_language_model``
    (with a pooler).  Only the last pipeline stage additionally owns the
    dropout layer and the single-output linear head that scores each choice.
    """

    def __init__(self, num_tokentypes=2):
        """Build the language model and, on the last stage, the choice head.

        Args:
            num_tokentypes: forwarded unchanged to ``get_language_model``.
        """
        super().__init__(share_word_embeddings=False)
        args = get_args()

        init_method = init_method_normal(args.init_method_std)

        self.language_model, self._language_model_key = get_language_model(
            attention_mask_func=bert_attention_mask_func,
            num_tokentypes=num_tokentypes,
            add_pooler=True,
            init_method=init_method,
            scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                         args.num_layers))

        # Multi-choice head: one logit per (sample, choice) pair.  It only
        # exists on the last pipeline stage, so every later access to these
        # attributes is guarded by the same stage check.
        if mpu.is_pipeline_last_stage():
            self.multichoice_dropout = torch.nn.Dropout(args.hidden_dropout)
            self.multichoice_head = get_linear_layer(args.hidden_size, 1,
                                                     init_method)
            self._multichoice_head_key = 'multichoice_head'

    def forward(self, model_input, attention_mask, tokentype_ids=None):
        """Run this pipeline stage.

        Data flow across the whole pipeline:
        [batch, choices, sequence] --> [batch * choices, sequence] -->
        transformer --> [batch, choices] logits.  No softmax is applied
        here; the logits are returned as-is.

        Args:
            model_input: on the first pipeline stage, the input ids as a
                3-D tensor [batch, choices, sequence]; on later stages it
                is handed to the language model untouched.
            attention_mask: 3-D tensor [batch, choices, sequence].
            tokentype_ids: optional 3-D tensor shaped like the input ids.
                Only consulted on the first pipeline stage.  ``None`` is
                passed through to the language model unchanged.

        Returns:
            On the last pipeline stage, the logits viewed as
            ``(-1, num_choices)``; otherwise the raw language model output
            for the next stage.
        """
        # Ensure the shape is [batch-size, choices, sequence]
        assert len(attention_mask.shape) == 3
        num_choices = attention_mask.shape[1]

        # Reshape and treat choice dimension the same as batch.
        attention_mask = attention_mask.view(-1, attention_mask.size(-1))
        extended_attention_mask = bert_extended_attention_mask(attention_mask)

        kwargs = {}
        if mpu.is_pipeline_first_stage():
            input_ids = model_input
            # Do the same as attention_mask for input_ids, tokentype_ids
            assert len(input_ids.shape) == 3
            # tokentype_ids is declared optional, so only validate and
            # flatten it when the caller actually supplied it.
            if tokentype_ids is not None:
                assert len(tokentype_ids.shape) == 3
                tokentype_ids = tokentype_ids.view(
                    -1, tokentype_ids.size(-1))
            input_ids = input_ids.view(-1, input_ids.size(-1))

            position_ids = bert_position_ids(input_ids)
            lm_args = [input_ids, position_ids, extended_attention_mask]
            kwargs['tokentype_ids'] = tokentype_ids
        else:
            lm_args = [model_input, extended_attention_mask]

        lm_output = self.language_model(*lm_args, **kwargs)
        if not mpu.is_pipeline_last_stage():
            return lm_output

        _, pooled_output = lm_output
        multichoice_output = self.multichoice_dropout(pooled_output)
        multichoice_logits = self.multichoice_head(multichoice_output)

        # Reshape back to separate choices.
        return multichoice_logits.view(-1, num_choices)

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key.

        Returns a dict holding the language model state under
        ``self._language_model_key`` and, on the last pipeline stage, the
        head state under ``self._multichoice_head_key``.
        """
        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        if mpu.is_pipeline_last_stage():
            # Keyword arguments: positional arguments to Module.state_dict
            # are deprecated in recent PyTorch releases.
            state_dict_[self._multichoice_head_key] \
                = self.multichoice_head.state_dict(
                    destination=destination, prefix=prefix,
                    keep_vars=keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.

        Loads the language model from ``state_dict[self._language_model_key]``.
        On the last pipeline stage the head is loaded too when its key is
        present; otherwise a warning is printed and the head keeps its
        freshly initialized weights.
        """
        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)

        if not mpu.is_pipeline_last_stage():
            return

        if self._multichoice_head_key in state_dict:
            self.multichoice_head.load_state_dict(
                state_dict[self._multichoice_head_key], strict=strict)
        else:
            print_rank_last('***WARNING*** could not find {} in the checkpoint, '
                            'initializing to random'.format(
                                self._multichoice_head_key))
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170

class MultipleChoice(MultipleChoiceBase):
    """Multiple-choice model taking token ids as input.

    Adds nothing to ``MultipleChoiceBase``; it only fixes the call
    signature to ``(input_ids, attention_mask, tokentype_ids=None)``.
    """

    def __init__(self, num_tokentypes=2):
        """Forward ``num_tokentypes`` to the base constructor."""
        super().__init__(num_tokentypes=num_tokentypes)

    def forward(self, input_ids, attention_mask,
                tokentype_ids=None):
        """Delegate to the base ``forward`` with ``input_ids`` as model input."""
        base_forward = super().forward
        return base_forward(input_ids, attention_mask,
                            tokentype_ids=tokentype_ids)


class MultipleChoiceFirstStage(MultipleChoiceBase):
    """Multiple-choice model variant whose ``forward`` accepts token ids.

    Behaviour is inherited entirely from ``MultipleChoiceBase``; only the
    parameter names of ``forward`` are specialised.
    """

    def __init__(self, num_tokentypes=2):
        """Forward ``num_tokentypes`` to the base constructor."""
        super().__init__(num_tokentypes=num_tokentypes)

    def forward(self, input_ids, attention_mask,
                tokentype_ids=None):
        """Delegate to the base ``forward`` with ``input_ids`` as model input."""
        base_forward = super().forward
        return base_forward(input_ids, attention_mask,
                            tokentype_ids=tokentype_ids)


class MultipleChoiceIntermediateStage(MultipleChoiceBase):
    """Multiple-choice model variant whose ``forward`` accepts a hidden state.

    Behaviour is inherited entirely from ``MultipleChoiceBase``; this
    wrapper exposes a ``forward`` without a ``tokentype_ids`` parameter.
    """

    def __init__(self, num_tokentypes=2):
        """Forward ``num_tokentypes`` to the base constructor."""
        super().__init__(num_tokentypes=num_tokentypes)

    def forward(self, hidden_state, attention_mask):
        """Delegate to the base ``forward`` with ``hidden_state`` as model input."""
        base_forward = super().forward
        return base_forward(hidden_state, attention_mask)


class MultipleChoiceLastStage(MultipleChoiceBase):
    """Multiple-choice model variant whose ``forward`` accepts a hidden state.

    Behaviour is inherited entirely from ``MultipleChoiceBase``; this
    wrapper exposes a ``forward`` without a ``tokentype_ids`` parameter.
    """

    def __init__(self, num_tokentypes=2):
        """Forward ``num_tokentypes`` to the base constructor."""
        super().__init__(num_tokentypes=num_tokentypes)

    def forward(self, hidden_state, attention_mask):
        """Delegate to the base ``forward`` with ``hidden_state`` as model input."""
        base_forward = super().forward
        return base_forward(hidden_state, attention_mask)