# coding=utf-8
# Copyright 2018 The Microsoft Research team and The HuggingFace Inc. team.
# Copyright (c) 2018 Microsoft and The HuggingFace Inc.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning seq2seq models for sequence generation.

We use the procedure described in [1] to finetune models for sequence
generation. Let S1 and S2 be the source and target sequences respectively; we
pack them using the start of sequence [SOS] and end of sequence [EOS] tokens:

    [SOS] S1 [EOS] S2 [EOS]

We then mask a fixed percentage of tokens from S2 at random and learn to
predict the masked words. [EOS] can be masked during finetuning so the model
learns to terminate the generation process (see the illustrative
`_mask_target_tokens` sketch below).

[1] Li Dong, Nan Yang, Wenhui Wang, Furu Wei, Xiaodong Liu, Yu Wang, Jianfeng
Gao, Ming Zhou, and Hsiao-Wuen Hon. “Unified Language Model Pre-Training for
Natural Language Understanding and Generation.” (May 2019) ArXiv:1905.03197
"""

import argparse
import logging
import os
import pickle
import random

import numpy as np
import torch
from torch.utils.data import Dataset

from transformers import BertConfig, Bert2Rnd, BertTokenizer


logger = logging.getLogger(__name__)


def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)


class TextDataset(Dataset):
    """ Abstracts a dataset used to train seq2seq models.

    A seq2seq dataset consists of two files:
    - The source file, which contains the source sequences, one line per sequence;
    - The target file, which contains the target sequences, one line per sequence.

    The matching between source and target sequences is made on the basis of line numbers.

    CNN/DailyMail:

    The CNN/DailyMail dataset downloaded from [1] consists of two files that
    respectively contain the stories and the associated summaries. Each line
    corresponds to a different story. The files contain WordPiece tokens.

    train.src: the longest story contains 6966 tokens, the shortest 12.
    Sentences are separated with `[SEP_i]` where i is an int between 0 and 9.

    train.tgt: the longest summary contains 2467 tokens, the shortest 4.
    Sentences are separated with `[X_SEP]` tokens.

    [1] https://github.com/microsoft/unilm
    """
    def __init__(self, tokenizer, src_path='train.src', target_path='train.tgt', block_size=512):
        assert os.path.isfile(src_path)
        assert os.path.isfile(target_path)
        directory, filename = os.path.split(src_path)

        cached_features_file = os.path.join(directory, "cached_lm_{}_{}".format(block_size, filename))
        if os.path.exists(cached_features_file):
            logger.info("Loading features from cached file %s", cached_features_file)
            with open(cached_features_file, "rb") as source:
                self.examples = pickle.load(source)
        else:
            logger.info("Creating features from dataset at %s", directory)

            self.examples = []
            with open(src_path, encoding="utf-8") as source, open(target_path, encoding="utf-8") as target:
                for line_src, line_tgt in zip(source, target):
                    src_sequence = line_src.strip()
                    tgt_sequence = line_tgt.strip()
                    example = _truncate_and_concatenate(src_sequence, tgt_sequence, block_size)
                    if example is not None:
                        example = tokenizer.convert_tokens_to_ids(example)
                        self.examples.append(example)

            logger.info("Saving features into cache file %s", cached_features_file)
            with open(cached_features_file, "wb") as sink:
                pickle.dump(self.examples, sink, protocol=pickle.HIGHEST_PROTOCOL)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, item):
        return torch.tensor(self.examples[item])
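

# Usage sketch (illustrative; assumes `train.src` and `train.tgt` exist in the
# working directory). Every example kept by `_truncate_and_concatenate` is
# exactly `block_size` ids long, so the default collate function can stack a
# batch directly:
#
#     dataset = TextDataset(tokenizer, src_path='train.src', target_path='train.tgt')
#     loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)
#     batch = next(iter(loader))  # LongTensor of shape (4, block_size)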


def _truncate_and_concatenate(src_sequence, tgt_sequence, block_size):
    """ Concatenate the sequences and adapt their lengths to the block size.

    Following [1] we perform the following transformations:
    - Add a [CLS] token at the beginning of the source sequence;
    - Add an [EOS] token at the end of the source and target sequences;
    - Concatenate the source and target sequences along with these special tokens. If the
      concatenated sequence is longer than 512 we follow the 75%/25% rule in [1]: the source
      sequence gets 384 tokens (special tokens included) and the target sequence 128.

    [1] Dong, Li, et al. "Unified Language Model Pre-training for Natural
    Language Understanding and Generation." arXiv preprint arXiv:1905.03197 (2019).
    """
    SRC_MAX_LENGTH = int(0.75 * block_size) - 2  # [CLS] and [EOS] tokens
    TGT_MAX_LENGTH = block_size - SRC_MAX_LENGTH - 3  # [CLS] and both [EOS] tokens

    # The dataset contains special separator tokens that we remove for now.
    # They are of the form `[SEP_i]` in the source file, and `[X_SEP]` in the
    # target file.
    src_tokens = list(filter(lambda t: "[SEP_" not in t, src_sequence.split(" ")))
    tgt_tokens = list(filter(lambda t: "[X_SEP]" not in t, tgt_sequence.split(" ")))

    # We drop the examples that are too short to fill the block for the
    # sake of simplicity. You can modify this by adding model-specific padding.
    if len(src_tokens) + len(tgt_tokens) + 3 < block_size:
        return None

    if len(src_tokens) > SRC_MAX_LENGTH:
        if len(tgt_tokens) > TGT_MAX_LENGTH:
            src_tokens = src_tokens[:SRC_MAX_LENGTH]
            tgt_tokens = tgt_tokens[:TGT_MAX_LENGTH]
        else:
            src_tokens = src_tokens[: block_size - len(tgt_tokens) - 3]
    else:
        if len(tgt_tokens) > TGT_MAX_LENGTH:
            tgt_tokens = tgt_tokens[: block_size - len(src_tokens) - 3]

    return ["[CLS]"] + src_tokens + ["[EOS]"] + tgt_tokens + ["[EOS]"]



def load_and_cache_examples(args, tokenizer):
    dataset = TextDataset(tokenizer, src_path=args.train_data_file)
    return dataset


def train(args, train_dataset, model, tokenizer):
    """ Fine-tune the pretrained model on the corpus. """
    raise NotImplementedError


def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--train_data_file",
                        default=None,
                        type=str,
                        required=True,
                        help="The input training data file (a text file).")
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")

    # Optional parameters
    parser.add_argument("--model_name_or_path",
                        default="bert-base-cased",
                        type=str,
                        help="The model checkpoint for weights initialization.")
    parser.add_argument("--seed", default=42, type=int)
    args = parser.parse_args()

    # Set up training device
    device = torch.device("cpu")

    # Set seed
    set_seed(args)

    # Load pretrained model and tokenizer
    config_class, model_class, tokenizer_class = BertConfig, Bert2Rnd, BertTokenizer
    config = config_class.from_pretrained(args.model_name_or_path)
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    model = model_class.from_pretrained(args.model_name_or_path, config=config)
    model.to(device)

    logger.info("Training/evaluation parameters %s", args)

    # Training
    train_dataset = load_and_cache_examples(args, tokenizer)
    global_step, tr_loss = train(args, train_dataset, model, tokenizer)
    logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)


if __name__ == "__main__":
    main()