# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018 Microsoft and The HuggingFace Inc.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning seq2seq models for sequence generation.

We use the procedure described in [1] to finetune models for sequence
generation. Let S1 and S2 be the source and target sequence respectively; we
pack them using the start of sequence [SOS] and end of sequence [EOS] token:

    [SOS] S1 [EOS] S2 [EOS]

We then mask a fixed percentage of tokens from S2 at random and learn to predict
the masked words. [EOS] can be masked during finetuning so the model learns to
terminate the generation process.

[1] Dong Li, Nan Yang, Wenhui Wang, Furu Wei, Xiaodong Liu, Yu Wang, Jianfeng
Gao, Ming Zhou, and Hsiao-Wuen Hon.  “Unified Language Model Pre-Training for
Natural Language Understanding and Generation.” (May 2019) ArXiv:1905.03197
"""

import argparse
import logging
import random

import numpy as np
import torch

from transformers import BertConfig, Bert2Rnd, BertTokenizer

logger = logging.getLogger(__name__)


def set_seed(args):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        args: object exposing a ``seed`` attribute (in this script, the
            namespace returned by ``argparse``).
    """
    seed = args.seed
    # Order is kept as stdlib `random`, then NumPy, then PyTorch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)


51
def load_and_cache_examples(args, tokenizer):
    """Load the training examples (placeholder, not implemented yet).

    Args:
        args: parsed command-line arguments, as passed by ``main``.
        tokenizer: tokenizer instance created in ``main``.

    Raises:
        NotImplementedError: always; the data loading logic has not been
            written yet.
    """
    raise NotImplementedError("load_and_cache_examples is not implemented yet.")


55
56
def train(args, train_dataset, model, tokenizer):
    """Fine-tune the pretrained model on the corpus (placeholder).

    Args:
        args: parsed command-line arguments, as passed by ``main``.
        train_dataset: value returned by ``load_and_cache_examples``.
        model: model instance created in ``main``.
        tokenizer: tokenizer instance created in ``main``.

    Returns:
        Nothing yet. ``main`` unpacks the result as a
        ``(global_step, tr_loss)`` pair, so an implementation must return
        a two-element sequence.

    Raises:
        NotImplementedError: always; the training loop has not been
            written yet.
    """
    raise NotImplementedError("train is not implemented yet.")


def main():
    """Parse the command-line arguments and run the finetuning pipeline.

    Steps, in order: parse the arguments, configure logging, seed the random
    number generators, load the BERT configuration, tokenizer and model from
    ``--model_name_or_path``, move the model to the CPU, build the training
    dataset and train.

    Raises:
        NotImplementedError: for as long as ``load_and_cache_examples`` and
            ``train`` are placeholders.
    """
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--train_data_file",
                        default=None,
                        type=str,
                        required=True,
                        help="The input training data file (a text file).")
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")

    # Optional parameters
    parser.add_argument("--model_name_or_path",
                        default="bert-base-cased",
                        type=str,
                        help="The model checkpoint for weights initialization.")
    parser.add_argument("--seed", default=42, type=int)
    args = parser.parse_args()

    # Without any logging configuration the effective level is WARNING and
    # the INFO messages below would be silently dropped. basicConfig is a
    # no-op when the root logger already has handlers.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        level=logging.INFO,
    )

    # Set up training device
    device = torch.device("cpu")

    # Set seed
    set_seed(args)

    # Load pretrained model and tokenizer
    config_class, model_class, tokenizer_class = BertConfig, Bert2Rnd, BertTokenizer
    config = config_class.from_pretrained(args.model_name_or_path)
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    model = model_class.from_pretrained(args.model_name_or_path, config=config)
    model.to(device)

    logger.info("Training/evaluation parameters %s", args)

    # Training
    train_dataset = load_and_cache_examples(args, tokenizer)
    global_step, tr_loss = train(args, train_dataset, model, tokenizer)
    logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)


104
# Run the finetuning pipeline only when the file is executed as a script.
if __name__ == "__main__":
    main()