# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT finetuning runner."""

import pandas as pd

import logging

from pytorch_pretrained_bert.tokenization import BertTokenizer

# Timestamped INFO-level logging for the script; `logger` is used below to
# print the first few converted examples.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
)
logger = logging.getLogger(__name__)

class SwagExample(object):
    """A single training/test example for the SWAG dataset.

    Args:
        swag_id: Identifier of the example (`read_swag_examples` fills it
            from the ``fold-ind`` csv column).
        context_sentence: The context sentence preceding the choices.
        start_ending: Beginning shared by every candidate ending.
        ending_0, ending_1, ending_2, ending_3: The four candidate endings.
        label: The example's label (the ``label`` csv column, presumably
            the index of the correct ending), or ``None`` when unlabeled.
    """

    def __init__(self,
                 swag_id,
                 context_sentence,
                 start_ending,
                 ending_0,
                 ending_1,
                 ending_2,
                 ending_3,
                 label=None):
        self.swag_id = swag_id
        self.context_sentence = context_sentence
        self.start_ending = start_ending
        # Endings are kept as a list so they can be iterated uniformly.
        self.endings = [
            ending_0,
            ending_1,
            ending_2,
            ending_3,
        ]
        self.label = label

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        fields = [
            f"swag_id: {self.swag_id}",
            f"context_sentence: {self.context_sentence}",
            f"start_ending: {self.start_ending}",
            f"ending_0: {self.endings[0]}",
            f"ending_1: {self.endings[1]}",
            f"ending_2: {self.endings[2]}",
            f"ending_3: {self.endings[3]}",
        ]

        # Unlabeled (test-time) examples omit the label entirely; a label
        # of 0 is still shown because the check is against None.
        if self.label is not None:
            fields.append(f"label: {self.label}")

        return ", ".join(fields)


class InputFeatures(object):
    """Model-ready features for one SWAG example.

    Args:
        example_id: Identifier of the source example
            (`convert_examples_to_features` passes its ``swag_id``).
        choices_features: One entry per candidate ending. In
            `convert_examples_to_features` each entry is a
            ``(tokens, input_ids, input_mask, segment_ids)`` tuple.
        label: The source example's label (``None`` for unlabeled
            examples).
    """

    def __init__(self,
                 example_id,
                 choices_features,
                 label):
        self.example_id = example_id
        self.choices_features = choices_features
        self.label = label


def read_swag_examples(input_file, is_training):
    """Read SWAG examples from a csv file.

    Args:
        input_file: Path (or anything `pandas.read_csv` accepts) of a csv
            with the columns ``fold-ind``, ``sent1``, ``sent2``,
            ``ending0`` to ``ending3`` and, for training, ``label``.
        is_training: When True, the ``label`` column is required and its
            value is stored on each example; otherwise labels are None.

    Returns:
        A list of `SwagExample`, one per csv row, in file order.

    Raises:
        ValueError: If `is_training` is True and the file has no
            ``label`` column.
    """
    input_df = pd.read_csv(input_file)

    if is_training and 'label' not in input_df.columns:
        raise ValueError(
            "For training, the input file must contain a label column.")

    examples = [
        SwagExample(
            swag_id=row['fold-ind'],
            context_sentence=row['sent1'],
            # In the SWAG dataset, the beginning shared by every choice
            # is stored in "sent2".
            start_ending=row['sent2'],
            ending_0=row['ending0'],
            ending_1=row['ending1'],
            ending_2=row['ending2'],
            ending_3=row['ending3'],
            label=row['label'] if is_training else None,
        ) for _, row in input_df.iterrows()
    ]

    return examples


def convert_examples_to_features(examples, tokenizer, max_seq_length,
                                 is_training):
    """Convert SWAG examples into a list of `InputFeatures`.

    Swag is a multiple choice task. To perform this task using Bert, we
    use the formatting proposed in "Improving Language Understanding by
    Generative Pre-Training" and suggested by @jacobdevlin-google in
    https://github.com/google-research/bert/issues/38.

    Each choice corresponds to a sample on which we run the inference.
    For a given Swag example, we create the 4 following inputs:
    - [CLS] context [SEP] choice_1 [SEP]
    - [CLS] context [SEP] choice_2 [SEP]
    - [CLS] context [SEP] choice_3 [SEP]
    - [CLS] context [SEP] choice_4 [SEP]
    The model outputs a single value for each input; the final decision
    is obtained by running a softmax over these 4 outputs.

    Here each "choice" is the example's `start_ending` followed by one of
    its endings.

    Args:
        examples: Iterable of `SwagExample`.
        tokenizer: Object providing `tokenize(text)` and
            `convert_tokens_to_ids(tokens)`, such as `BertTokenizer`.
        max_seq_length: Length that every choice sequence is truncated
            and zero-padded to, counting [CLS] and both [SEP] tokens.
        is_training: Only decides whether the label is logged for the
            first few examples.

    Returns:
        A list with one `InputFeatures` per example. Its
        `choices_features` holds one `(tokens, input_ids, input_mask,
        segment_ids)` tuple per ending; `tokens` is left unpadded.

    Raises:
        ValueError: If `max_seq_length` is smaller than 3, which leaves
            no room for the [CLS] and two [SEP] tokens.
    """
    if max_seq_length < 3:
        # Without this check, _truncate_seq_pair would pop from an empty
        # list and fail with an unhelpful IndexError.
        raise ValueError(
            f"max_seq_length must be at least 3, got {max_seq_length}.")

    features = []
    for example_index, example in enumerate(examples):
        context_tokens = tokenizer.tokenize(example.context_sentence)
        start_ending_tokens = tokenizer.tokenize(example.start_ending)

        choices_features = []
        for ending in example.endings:
            # Copy the context tokens so that each choice can shrink its
            # own copy according to the length of its ending.
            context_tokens_choice = context_tokens[:]
            ending_tokens = start_ending_tokens + tokenizer.tokenize(ending)
            # Modifies `context_tokens_choice` and `ending_tokens` in place
            # so that their total length is at most the specified length.
            # Account for [CLS], [SEP], [SEP] with "- 3".
            _truncate_seq_pair(context_tokens_choice, ending_tokens,
                               max_seq_length - 3)

            tokens = (["[CLS]"] + context_tokens_choice + ["[SEP]"]
                      + ending_tokens + ["[SEP]"])
            # Segment 0: [CLS] + context + first [SEP];
            # segment 1: ending + final [SEP].
            segment_ids = ([0] * (len(context_tokens_choice) + 2)
                           + [1] * (len(ending_tokens) + 1))

            input_ids = tokenizer.convert_tokens_to_ids(tokens)
            input_mask = [1] * len(input_ids)

            # Zero-pad up to the sequence length.
            padding = [0] * (max_seq_length - len(input_ids))
            input_ids += padding
            input_mask += padding
            segment_ids += padding

            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length

            choices_features.append((tokens, input_ids, input_mask, segment_ids))

        label = example.label
        # Log the first few examples so the preprocessing can be inspected.
        if example_index < 5:
            logger.info("*** Example ***")
            logger.info(f"swag_id: {example.swag_id}")
            for choice_idx, (tokens, input_ids, input_mask, segment_ids) in enumerate(choices_features):
                logger.info(f"choice: {choice_idx}")
                logger.info(f"tokens: {' '.join(tokens)}")
                logger.info(f"input_ids: {' '.join(map(str, input_ids))}")
                logger.info(f"input_mask: {' '.join(map(str, input_mask))}")
                logger.info(f"segment_ids: {' '.join(map(str, segment_ids))}")
            if is_training:
                logger.info(f"label: {label}")

        features.append(
            InputFeatures(
                example_id=example.swag_id,
                choices_features=choices_features,
                label=label
            )
        )

    return features

def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length."""

    # This is a simple heuristic which will always truncate the longer sequence
    # one token at a time. This makes more sense than truncating an equal percent
    # of tokens from each, since if one sequence is very short then each token
    # that's truncated likely contains more information than a longer sequence.
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()


if __name__ == "__main__":
    # Smoke run: load the csv at data/train.csv (presumably the SWAG
    # training split), print a few examples, then convert all of them
    # to features with the bert-base-uncased tokenizer.
    is_training = True
    max_seq_length = 80
    examples = read_swag_examples('data/train.csv', is_training)

    print(len(examples))
    for example in examples[:5]:
        print("###########################")
        print(example)

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    features = convert_examples_to_features(examples, tokenizer, max_seq_length, is_training)