"vscode:/vscode.git/clone" did not exist on "3d620f9462b800b880d51e0eb6e51a91182f79db"
run_swag.py 8.38 KB
Newer Older
Grégory Châtel's avatar
Grégory Châtel committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT finetuning runner."""

17
18
import pandas as pd

19
20
21
22
23
24
25
26
27
import logging

from pytorch_pretrained_bert.tokenization import BertTokenizer

# Configure root logging once at import time so the example dumps emitted by
# `convert_examples_to_features` are visible, and create this module's logger.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
)
logger = logging.getLogger(__name__)

28

Grégory Châtel's avatar
Grégory Châtel committed
29
30
31
32
33
34
35
36
37
38
39
40
41
42
class SwagExample(object):
    """One multiple-choice item read from a SWAG csv file.

    Stores the context sentence, the beginning shared by every candidate
    ending, the four candidate endings (as the list `endings`) and an
    optional label (None when not provided).
    """

    def __init__(self,
                 swag_id,
                 context_sentence,
                 start_ending,
                 ending_0,
                 ending_1,
                 ending_2,
                 ending_3,
                 label=None):
        self.swag_id = swag_id
        self.context_sentence = context_sentence
        self.start_ending = start_ending
        # Candidate endings, indexed by choice number 0..3.
        self.endings = [ending_0, ending_1, ending_2, ending_3]
        self.label = label

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        parts = [
            f"swag_id: {self.swag_id}",
            f"context_sentence: {self.context_sentence}",
            f"start_ending: {self.start_ending}",
        ]
        for choice_index in range(4):
            parts.append(f"ending_{choice_index}: {self.endings[choice_index]}")
        # The label is only shown when one was given (0 is a valid label).
        if self.label is not None:
            parts.append(f"label: {self.label}")
        return ", ".join(parts)


class InputFeatures(object):
    """Encoded form of one SWAG example: one feature dict per answer choice."""

    def __init__(self,
                 example_id,
                 choices_features,
                 label):
        """
        Args:
            example_id: identifier of the source example.
            choices_features: iterable of 4-tuples
                (tokens, input_ids, input_mask, segment_ids), one per choice.
                The tokens are discarded; the other three are kept.
            label: gold label copied as-is (may be None).
        """
        self.example_id = example_id
        encoded_choices = []
        for _tokens, input_ids, input_mask, segment_ids in choices_features:
            encoded_choices.append({
                'input_ids': input_ids,
                'input_mask': input_mask,
                'segment_ids': segment_ids,
            })
        self.choices_features = encoded_choices
        self.label = label
Grégory Châtel's avatar
Grégory Châtel committed
88

89
90
91
92
93
94
def read_swag_examples(input_file, is_training):
    """Parse a SWAG csv file into a list of `SwagExample` objects.

    Args:
        input_file: anything `pandas.read_csv` accepts. The data must have
            the columns `fold-ind`, `sent1`, `sent2`, `ending0` .. `ending3`,
            plus `label` when `is_training` is true.
        is_training: when true, each example's label is read from the
            `label` column; otherwise every example gets `label=None`.

    Returns:
        A list with one `SwagExample` per csv row, in file order.

    Raises:
        ValueError: if `is_training` is true and there is no `label` column.
    """
    input_df = pd.read_csv(input_file)

    labels_available = (not is_training) or ('label' in input_df.columns)
    if not labels_available:
        raise ValueError(
            "For training, the input file must contain a label column.")

    examples = []
    for _, row in input_df.iterrows():
        examples.append(
            SwagExample(
                swag_id=row['fold-ind'],
                context_sentence=row['sent1'],
                # In the swag dataset, the beginning shared by every
                # choice is stored in "sent2".
                start_ending=row['sent2'],
                ending_0=row['ending0'],
                ending_1=row['ending1'],
                ending_2=row['ending2'],
                ending_3=row['ending3'],
                label=row['label'] if is_training else None,
            )
        )

    return examples


114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
def convert_examples_to_features(examples, tokenizer, max_seq_length,
                                 is_training):
    """Convert `SwagExample`s into a list of `InputFeatures`.

    Args:
        examples: iterable of `SwagExample` (uses `swag_id`,
            `context_sentence`, `start_ending`, `endings` and `label`).
        tokenizer: object providing `tokenize(text)` and
            `convert_tokens_to_ids(tokens)`, e.g. a `BertTokenizer`.
        max_seq_length: every encoded choice is truncated and zero-padded
            to exactly this many positions. Must be at least 3 so that
            [CLS] and the two [SEP] tokens fit.
        is_training: only affects logging; when true, the label of each of
            the first 5 logged examples is logged too.

    Returns:
        A list with one `InputFeatures` per example, in input order. Each
        carries one (input_ids, input_mask, segment_ids) encoding per
        ending, and the example's label copied unchanged.

    Raises:
        ValueError: if `max_seq_length` is smaller than 3.
    """

    # Swag is a multiple choice task. To perform this task using Bert,
    # we will use the formatting proposed in "Improving Language
    # Understanding by Generative Pre-Training" and suggested by
    # @jacobdevlin-google in this issue
    # https://github.com/google-research/bert/issues/38.
    #
    # Each choice will correspond to a sample on which we run the
    # inference. For a given Swag example, we will create the 4
    # following inputs (a "choice" is start_ending + one ending):
    # - [CLS] context [SEP] choice_1 [SEP]
    # - [CLS] context [SEP] choice_2 [SEP]
    # - [CLS] context [SEP] choice_3 [SEP]
    # - [CLS] context [SEP] choice_4 [SEP]
    # The model will output a single value for each input. To get the
    # final decision of the model, we will run a softmax over these 4
    # outputs.

    # [CLS], [SEP], [SEP] are added to every choice.
    num_special_tokens = 3
    if max_seq_length < num_special_tokens:
        # Otherwise `_truncate_seq_pair` would get a negative budget, empty
        # both lists and then fail with an opaque IndexError.
        raise ValueError(
            f"max_seq_length must be at least {num_special_tokens} to fit "
            f"[CLS] and two [SEP] tokens, got {max_seq_length}.")

    features = []
    for example_index, example in enumerate(examples):
        context_tokens = tokenizer.tokenize(example.context_sentence)
        start_ending_tokens = tokenizer.tokenize(example.start_ending)

        choices_features = []
        for ending in example.endings:
            # Copy the context tokens so each choice can shrink its own
            # version according to its ending_tokens.
            context_tokens_choice = context_tokens[:]
            ending_tokens = start_ending_tokens + tokenizer.tokenize(ending)
            # Modifies `context_tokens_choice` and `ending_tokens` in
            # place so that, with the special tokens, the total length is
            # at most `max_seq_length`.
            _truncate_seq_pair(context_tokens_choice, ending_tokens,
                               max_seq_length - num_special_tokens)

            tokens = ["[CLS]"] + context_tokens_choice + ["[SEP]"] + ending_tokens + ["[SEP]"]
            # Segment 0 covers [CLS] + context + first [SEP]; segment 1
            # covers the choice and the final [SEP].
            segment_ids = [0] * (len(context_tokens_choice) + 2) + [1] * (len(ending_tokens) + 1)

            input_ids = tokenizer.convert_tokens_to_ids(tokens)
            input_mask = [1] * len(input_ids)

            # Zero-pad up to the sequence length.
            padding = [0] * (max_seq_length - len(input_ids))
            input_ids += padding
            input_mask += padding
            segment_ids += padding

            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length

            choices_features.append((tokens, input_ids, input_mask, segment_ids))

        label = example.label
        # Dump the first few encodings for manual inspection.
        if example_index < 5:
            logger.info("*** Example ***")
            logger.info(f"swag_id: {example.swag_id}")
            for choice_idx, (tokens, input_ids, input_mask, segment_ids) in enumerate(choices_features):
                logger.info(f"choice: {choice_idx}")
                logger.info(f"tokens: {' '.join(tokens)}")
                logger.info(f"input_ids: {' '.join(map(str, input_ids))}")
                logger.info(f"input_mask: {' '.join(map(str, input_mask))}")
                logger.info(f"segment_ids: {' '.join(map(str, segment_ids))}")
            if is_training:
                logger.info(f"label: {label}")

        features.append(
            InputFeatures(
                example_id = example.swag_id,
                choices_features = choices_features,
                label = label
            )
        )

    return features
191

192

193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length."""

    # This is a simple heuristic which will always truncate the longer sequence
    # one token at a time. This makes more sense than truncating an equal percent
    # of tokens from each, since if one sequence is very short then each token
    # that's truncated likely contains more information than a longer sequence.
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()


210
if __name__ == "__main__":
    # Manual smoke test: parse the training csv, print a few parsed examples,
    # then encode a subset and print the encodings of the first features.
    is_training = True
    max_seq_length = 80
    examples = read_swag_examples('data/train.csv', is_training)

    print(len(examples))
    for example in examples[:5]:
        print("###########################")
        print(example)

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    features = convert_examples_to_features(examples[:500], tokenizer,
                                            max_seq_length, is_training)
    # Slice rather than index `range(10)` so that a file with fewer than
    # 10 examples does not raise IndexError.
    for feature in features[:10]:
        for choice_idx, choice_feature in enumerate(feature.choices_features):
            print(f'choice_idx: {choice_idx}')
            print(f'input_ids: {" ".join(map(str, choice_feature["input_ids"]))}')
            print(f'input_mask: {" ".join(map(str, choice_feature["input_mask"]))}')
            print(f'segment_ids: {" ".join(map(str, choice_feature["segment_ids"]))}')