# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Multiple choice model."""

import torch

7
from megatron import get_args, print_rank_last
8
from megatron import mpu
9
from megatron.model.enums import AttnMaskType
10
from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
11
12
13
14
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
15
from .module import MegatronModule
16
17


18
class MultipleChoice(MegatronModule):
    """Multiple-choice model built on a pooled Megatron language model.

    Each (example, choice) pair is run through the language model as an
    independent sequence. On the last stage (``post_process``), the pooled
    output of each sequence goes through dropout and a linear layer to a
    single score. The scores are then regrouped into one row per example.
    """

    def __init__(self,
                 num_tokentypes=2,
                 pre_process=True,
                 post_process=True):
        """Build the language model and, when ``post_process``, the head.

        Args:
            num_tokentypes: number of token types, forwarded to
                ``get_language_model``.
            pre_process: forwarded to the language model (presumably True
                on the first pipeline stage -- TODO confirm).
            post_process: when True, the dropout + linear multiple-choice
                head is created and ``forward`` returns per-choice logits.
        """
        super(MultipleChoice, self).__init__(share_word_embeddings=False)
        args = get_args()

        init_method = init_method_normal(args.init_method_std)

        self.pre_process = pre_process
        self.post_process = post_process

        # add_pooler=True: forward() consumes the pooled output on the last
        # stage; padding attention mask type matches the bidirectional mask
        # built by bert_extended_attention_mask in forward().
        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=True,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                         args.num_layers),
            pre_process=self.pre_process,
            post_process=self.post_process)

        # Multi-choice head: projects each pooled vector to one scalar score.
        if self.post_process:
            self.multichoice_dropout = torch.nn.Dropout(args.hidden_dropout)
            self.multichoice_head = get_linear_layer(args.hidden_size, 1,
                                                     init_method)
            self._multichoice_head_key = 'multichoice_head'

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        # Simply delegates to the wrapped language model.
        self.language_model.set_input_tensor(input_tensor)

    def forward(self, model_input, attention_mask, tokentype_ids=None):
        """Score every choice of every example.

        The choice dimension is folded into the batch dimension:
        [batch, choices, sequence] --> [batch * choices, sequence]
        --> transformer --> [batch * choices, 1] --> [batch, choices].
        No softmax is applied; the returned scores are raw logits.

        Args:
            model_input: token ids, rank 3: [batch, choices, sequence].
            attention_mask: rank 3: [batch, choices, sequence].
            tokentype_ids: optional, rank 3 when given. When None it is
                passed to the language model as None.

        Returns:
            With ``post_process``: logits of shape [batch, choices].
            Otherwise: the language model output, unchanged.
        """
        # Ensure the shape is [batch-size, choices, sequence]
        assert len(attention_mask.shape) == 3
        num_choices = attention_mask.shape[1]

        # Reshape and treat choice dimension the same as batch.
        # reshape (not view) also accepts non-contiguous inputs.
        attention_mask = attention_mask.reshape(-1, attention_mask.size(-1))
        extended_attention_mask = bert_extended_attention_mask(attention_mask)

        input_ids = model_input
        # Do the same as attention_mask for input_ids, tokentype_ids
        assert len(input_ids.shape) == 3
        input_ids = input_ids.reshape(-1, input_ids.size(-1))
        # tokentype_ids defaults to None; only flatten it when provided
        # (previously the default crashed on `.shape`).
        if tokentype_ids is not None:
            assert len(tokentype_ids.shape) == 3
            tokentype_ids = tokentype_ids.reshape(-1, tokentype_ids.size(-1))
        position_ids = bert_position_ids(input_ids)

        lm_output = self.language_model(
            input_ids,
            position_ids,
            extended_attention_mask,
            tokentype_ids=tokentype_ids
        )
        if not self.post_process:
            return lm_output

        # With add_pooler=True the output unpacks to (sequence, pooled).
        _, pooled_output = lm_output
        multichoice_output = self.multichoice_dropout(pooled_output)
        multichoice_logits = self.multichoice_head(multichoice_output)

        # Reshape back to separate choices.
        return multichoice_logits.view(-1, num_choices)

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key.

        Returns a dict keyed by the language-model key and, when
        ``post_process``, by 'multichoice_head'.
        """
        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(prefix=prefix,
                                                                 keep_vars=keep_vars)
        if self.post_process:
            state_dict_[self._multichoice_head_key] \
                = self.multichoice_head.state_dict(prefix=prefix, keep_vars=keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.

        Loads the language model from its sub-dict. On the last stage, loads
        the head if present; otherwise a warning is printed and the head
        keeps its random initialization (e.g. when fine-tuning from a
        checkpoint without this head).
        """
        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        if self.post_process:
            if self._multichoice_head_key in state_dict:
                self.multichoice_head.load_state_dict(
                    state_dict[self._multichoice_head_key], strict=strict)
            else:
                print_rank_last('***WARNING*** could not find {} in the checkpoint, '
                                'initializing to random'.format(
                                    self._multichoice_head_key))