run_language_modeling.py 10.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
16
"""
17
18
19
Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, CTRL, BERT, RoBERTa, XLNet).
GPT, GPT-2 and CTRL are fine-tuned using a causal language modeling (CLM) loss. BERT and RoBERTa are fine-tuned
using a masked language modeling (MLM) loss. XLNet is fine-tuned using a permutation language modeling (PLM) loss.
20
"""
21
22
23


import logging
Julien Chaumond's avatar
Julien Chaumond committed
24
import math
25
import os
Julien Chaumond's avatar
Julien Chaumond committed
26
27
from dataclasses import dataclass, field
from typing import Optional
28

29
from transformers import (
Julien Chaumond's avatar
Julien Chaumond committed
30
    CONFIG_MAPPING,
31
32
33
34
    MODEL_WITH_LM_HEAD_MAPPING,
    AutoConfig,
    AutoModelWithLMHead,
    AutoTokenizer,
Julien Chaumond's avatar
Julien Chaumond committed
35
    DataCollatorForLanguageModeling,
36
    DataCollatorForPermutationLanguageModeling,
Julien Chaumond's avatar
Julien Chaumond committed
37
38
    HfArgumentParser,
    LineByLineTextDataset,
39
    PreTrainedTokenizer,
Julien Chaumond's avatar
Julien Chaumond committed
40
41
42
43
    TextDataset,
    Trainer,
    TrainingArguments,
    set_seed,
44
)
45

46

47
logger = logging.getLogger(__name__)


# Config classes of every architecture that has an LM-head model, and the matching
# ``model_type`` strings (listed in the --model_type help text of ModelArguments).
MODEL_CONFIG_CLASSES = list(MODEL_WITH_LM_HEAD_MAPPING)
MODEL_TYPES = tuple(config_class.model_type for config_class in MODEL_CONFIG_CLASSES)
52
53


Julien Chaumond's avatar
Julien Chaumond committed
54
55
56
57
58
@dataclass
class ModelArguments:
    """
    Options selecting the model checkpoint, config and tokenizer to fine-tune,
    or the model type to use when training from scratch.

    All fields default to ``None``.
    """

    # Checkpoint name or local path used to initialise the weights (None => from scratch).
    model_name_or_path: Optional[str] = field(
        metadata=dict(
            help="The model checkpoint for weights initialization. Leave None if you want to train a model from scratch."
        ),
        default=None,
    )
    # Architecture key, only consulted when neither a config name nor a checkpoint is supplied.
    model_type: Optional[str] = field(
        metadata=dict(help="If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)),
        default=None,
    )
    # Optional overrides for where the config / tokenizer are loaded from.
    config_name: Optional[str] = field(
        metadata=dict(help="Pretrained config name or path if not the same as model_name"), default=None
    )
    tokenizer_name: Optional[str] = field(
        metadata=dict(help="Pretrained tokenizer name or path if not the same as model_name"), default=None
    )
    # Download cache directory passed to every ``from_pretrained`` call.
    cache_dir: Optional[str] = field(
        metadata=dict(help="Where do you want to store the pretrained models downloaded from s3"), default=None
    )
79
80


Julien Chaumond's avatar
Julien Chaumond committed
81
82
83
84
85
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Covers the input files, how they are split into examples, and the options
    forwarded to the masked / permutation language-modeling data collators.
    """

    # Path to the raw text file used for training.
    train_data_file: Optional[str] = field(
        default=None, metadata={"help": "The input training data file (a text file)."}
    )
    # Path to the raw text file used for evaluation; main() rejects --do_eval without it.
    eval_data_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    # get_dataset() builds a LineByLineTextDataset when True, a TextDataset otherwise.
    line_by_line: bool = field(
        default=False,
        metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
    )

    # Masked-LM options, forwarded to DataCollatorForLanguageModeling.
    mlm: bool = field(
        default=False, metadata={"help": "Train with masked-language modeling loss instead of language modeling."}
    )
    mlm_probability: float = field(
        default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
    )
    # Permutation-LM options, forwarded to DataCollatorForPermutationLanguageModeling (used for XLNet).
    plm_probability: float = field(
        default=1 / 6,
        metadata={
            "help": "Ratio of length of a span of masked tokens to surrounding context length for permutation language modeling."
        },
    )
    max_span_length: int = field(
        default=5, metadata={"help": "Maximum length of a span of masked tokens for permutation language modeling."}
    )

    # Values <= 0 are replaced by the tokenizer's max length in main(); larger values are capped at it.
    block_size: int = field(
        default=-1,
        metadata={
            # Adjacent literals are concatenated, so each one must end with a separating space.
            "help": "Optional input sequence length after tokenization. "
            "The training dataset will be truncated in blocks of this size for training. "
            "Defaults to the model max input length for single sentence inputs (taking into account special tokens)."
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )


Julien Chaumond's avatar
Julien Chaumond committed
128
def get_dataset(args: DataTrainingArguments, tokenizer: PreTrainedTokenizer, evaluate=False):
    """
    Build the training or evaluation dataset described by ``args``.

    Args:
        args: data arguments; reads ``train_data_file`` / ``eval_data_file``,
            ``line_by_line``, ``block_size`` and ``overwrite_cache``.
        tokenizer: tokenizer forwarded to the dataset constructor.
        evaluate: when True use ``args.eval_data_file``, otherwise ``args.train_data_file``.

    Returns:
        A ``LineByLineTextDataset`` if ``args.line_by_line`` is set, else a ``TextDataset``.
    """
    if evaluate:
        path = args.eval_data_file
    else:
        path = args.train_data_file

    if not args.line_by_line:
        # Only the TextDataset branch is given the overwrite_cache flag.
        return TextDataset(
            tokenizer=tokenizer, file_path=path, block_size=args.block_size, overwrite_cache=args.overwrite_cache
        )
    return LineByLineTextDataset(tokenizer=tokenizer, file_path=path, block_size=args.block_size)
136

137

Julien Chaumond's avatar
Julien Chaumond committed
138
139
140
141
142
143
144
145
146
def main():
    """
    Parse command-line arguments, build model/tokenizer/datasets, then train and/or evaluate.

    Arguments are parsed into ``ModelArguments``, ``DataTrainingArguments`` and
    ``TrainingArguments`` (pass ``--help`` to list them).

    Returns:
        dict: ``{"perplexity": ...}`` when ``--do_eval`` is set, otherwise an empty dict.

    Raises:
        ValueError: on inconsistent arguments (missing train/eval data file, non-empty
            output directory without ``--overwrite_output_dir``, no tokenizer source,
            or a masked-LM-only model run without ``--mlm``).
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Validate data files up front, before any (slow) model download happens.
    if data_args.train_data_file is None and training_args.do_train:
        raise ValueError(
            "Cannot do training without a training data file. Either supply a file to --train_data_file "
            "or remove the --do_train argument."
        )
    if data_args.eval_data_file is None and training_args.do_eval:
        raise ValueError(
            "Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
            "or remove the --do_eval argument."
        )

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    # Setup logging: full INFO output on the main process only (local_rank -1 or 0).
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed
    set_seed(training_args.seed)

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, cache_dir=model_args.cache_dir)
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported, but you can do it from another script, save it, "
            "and load it from here, using --tokenizer_name"
        )

    if model_args.model_name_or_path:
        model = AutoModelWithLMHead.from_pretrained(
            model_args.model_name_or_path,
            # A path containing ".ckpt" is treated as a TensorFlow checkpoint.
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
        )
    else:
        logger.info("Training new model from scratch")
        model = AutoModelWithLMHead.from_config(config)

    # Keep the embedding matrix in sync with the tokenizer's vocabulary size.
    model.resize_token_embeddings(len(tokenizer))

    if config.model_type in ["bert", "roberta", "distilbert", "camembert"] and not data_args.mlm:
        raise ValueError(
            "BERT and RoBERTa-like models do not have LM heads but masked LM heads. They must be run using the "
            "--mlm flag (masked language modeling)."
        )

    if data_args.block_size <= 0:
        # Our input block size will be the max possible for the model
        data_args.block_size = tokenizer.max_len
    else:
        data_args.block_size = min(data_args.block_size, tokenizer.max_len)

    # Get datasets
    train_dataset = get_dataset(data_args, tokenizer=tokenizer) if training_args.do_train else None
    eval_dataset = get_dataset(data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None

    # XLNet uses permutation LM; everything else uses causal or masked LM depending on --mlm.
    if config.model_type == "xlnet":
        data_collator = DataCollatorForPermutationLanguageModeling(
            tokenizer=tokenizer,
            plm_probability=data_args.plm_probability,
            max_span_length=data_args.max_span_length,
        )
    else:
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer, mlm=data_args.mlm, mlm_probability=data_args.mlm_probability
        )

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        prediction_loss_only=True,
    )

    # Training
    if training_args.do_train:
        # Only a local directory is passed as model_path to trainer.train().
        model_path = (
            model_args.model_name_or_path
            if model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path)
            else None
        )
        trainer.train(model_path=model_path)
        trainer.save_model()
        # For convenience, we also re-save the tokenizer to the same directory,
        # so that you can share your model easily on huggingface.co/models =)
        if trainer.is_world_master():
            tokenizer.save_pretrained(training_args.output_dir)

    # Evaluation
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        eval_output = trainer.evaluate()

        try:
            perplexity = math.exp(eval_output["eval_loss"])
        except OverflowError:
            # math.exp overflows for arguments above ~709; report an infinite
            # perplexity instead of crashing after a (possibly long) evaluation.
            perplexity = float("inf")
        result = {"perplexity": perplexity}

        output_eval_file = os.path.join(training_args.output_dir, "eval_results_lm.txt")
        # Only the main process writes the results file.
        if trainer.is_world_master():
            with open(output_eval_file, "w") as writer:
                logger.info("***** Eval results *****")
                for key in sorted(result.keys()):
                    logger.info("  %s = %s", key, str(result[key]))
                    writer.write("%s = %s\n" % (key, str(result[key])))

        results.update(result)

    return results


292
293
294
295
296
def _mp_fn(index):
    """Per-process entry point used by ``xla_spawn`` on TPUs; ``index`` is accepted but unused."""
    main()


297
if __name__ == "__main__":
altsoph's avatar
altsoph committed
298
    main()