#!/usr/bin/env python
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, CTRL, BERT, RoBERTa, XLNet).
GPT, GPT-2 and CTRL are fine-tuned using a causal language modeling (CLM) loss. BERT and RoBERTa are fine-tuned
using a masked language modeling (MLM) loss. XLNet is fine-tuned using a permutation language modeling (PLM) loss.
"""


import logging
import math
import os
from dataclasses import dataclass, field
from glob import glob
from typing import Optional

from torch.utils.data import ConcatDataset

import transformers
from transformers import (
    CONFIG_MAPPING,
    MODEL_WITH_LM_HEAD_MAPPING,
    AutoConfig,
    AutoModelWithLMHead,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    DataCollatorForPermutationLanguageModeling,
    DataCollatorForWholeWordMask,
    HfArgumentParser,
    LineByLineTextDataset,
    LineByLineWithRefDataset,
    PreTrainedTokenizer,
    TextDataset,
    Trainer,
    TrainingArguments,
    set_seed,
)
from transformers.trainer_utils import is_main_process


logger = logging.getLogger(__name__)


MODEL_CONFIG_CLASSES = list(MODEL_WITH_LM_HEAD_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "The model checkpoint for weights initialization. Leave None if you want to train a model from scratch."
        },
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    train_data_file: Optional[str] = field(
        default=None, metadata={"help": "The input training data file (a text file)."}
    )
    train_data_files: Optional[str] = field(
        default=None,
        metadata={
            "help": "The input training data files (multiple files in glob format). "
            "Very often splitting large files to smaller files can prevent tokenizer going out of memory"
        },
    )
    eval_data_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    train_ref_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input train ref data file for whole word mask in Chinese."},
    )
    eval_ref_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input eval ref data file for whole word mask in Chinese."},
    )
    line_by_line: bool = field(
        default=False,
        metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
    )

    mlm: bool = field(
        default=False, metadata={"help": "Train with masked-language modeling loss instead of language modeling."}
    )
    whole_word_mask: bool = field(default=False, metadata={"help": "Whether or not to use whole word masking."})
    mlm_probability: float = field(
        default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
    )
    plm_probability: float = field(
        default=1 / 6,
        metadata={
            "help": "Ratio of length of a span of masked tokens to surrounding context length for permutation language modeling."
        },
    )
    max_span_length: int = field(
        default=5, metadata={"help": "Maximum length of a span of masked tokens for permutation language modeling."}
    )

    block_size: int = field(
        default=-1,
        metadata={
            "help": "Optional input sequence length after tokenization."
            "The training dataset will be truncated in block of this size for training."
            "Default to the model max input length for single sentence inputs (take into account special tokens)."
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )


def get_dataset(
    args: DataTrainingArguments,
    tokenizer: PreTrainedTokenizer,
    evaluate: bool = False,
    cache_dir: Optional[str] = None,
):
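    # Builds the train or eval dataset from the data arguments: line-by-line mode treats each
    # line of the file as one example (optionally paired with a ref file for Chinese whole word
    # masking), while the default TextDataset concatenates the corpus into contiguous blocks of
    # block_size tokens. Multiple training files passed as a glob are wrapped in a ConcatDataset.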
    def _dataset(file_path, ref_path=None):
        if args.line_by_line:
            if ref_path is not None:
                if not args.whole_word_mask or not args.mlm:
                    raise ValueError("You need to set world whole masking and mlm to True for Chinese Whole Word Mask")
                return LineByLineWithRefDataset(
                    tokenizer=tokenizer,
                    file_path=file_path,
                    block_size=args.block_size,
                    ref_path=ref_path,
                )

            return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
        else:
            return TextDataset(
                tokenizer=tokenizer,
                file_path=file_path,
                block_size=args.block_size,
                overwrite_cache=args.overwrite_cache,
                cache_dir=cache_dir,
            )

    if evaluate:
        return _dataset(args.eval_data_file, args.eval_ref_file)
    elif args.train_data_files:
        return ConcatDataset([_dataset(f) for f in glob(args.train_data_files)])
    else:
        return _dataset(args.train_data_file, args.train_ref_file)


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if data_args.eval_data_file is None and training_args.do_eval:
        raise ValueError(
            "Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
            "or remove the --do_eval argument."
        )
    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    # Set the Transformers logger verbosity to info (on the main process only):
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.enable_default_handler()
        transformers.utils.logging.enable_explicit_format()
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed
    set_seed(training_args.seed)

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, cache_dir=model_args.cache_dir)
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported, but you can do it from another script, save it,"
            "and load it from here, using --tokenizer_name"
        )

    if model_args.model_name_or_path:
        model = AutoModelWithLMHead.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
        )
    else:
        logger.info("Training new model from scratch")
        model = AutoModelWithLMHead.from_config(config)

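    # Make sure the model's token embedding matrix matches the tokenizer's vocabulary size
    # (a no-op when the checkpoint and tokenizer already agree).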
    model.resize_token_embeddings(len(tokenizer))

    if config.model_type in ["bert", "roberta", "distilbert", "camembert"] and not data_args.mlm:
        raise ValueError(
            "BERT and RoBERTa-like models do not have LM heads but masked LM heads. They must be run using the"
            "--mlm flag (masked language modeling)."
        )

    if data_args.block_size <= 0:
        data_args.block_size = tokenizer.max_len
        # Our input block size will be the max possible for the model
    else:
        data_args.block_size = min(data_args.block_size, tokenizer.max_len)

    # Get datasets

    train_dataset = (
        get_dataset(data_args, tokenizer=tokenizer, cache_dir=model_args.cache_dir) if training_args.do_train else None
    )
    eval_dataset = (
        get_dataset(data_args, tokenizer=tokenizer, evaluate=True, cache_dir=model_args.cache_dir)
        if training_args.do_eval
        else None
    )
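    # Pick the data collator that matches the training objective: permutation LM for XLNet,
    # whole-word masking when both --mlm and --whole_word_mask are set, and otherwise the
    # standard collator, which masks tokens when --mlm is passed and builds causal LM labels
    # when it is not.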
    if config.model_type == "xlnet":
        data_collator = DataCollatorForPermutationLanguageModeling(
            tokenizer=tokenizer,
            plm_probability=data_args.plm_probability,
            max_span_length=data_args.max_span_length,
        )
    else:
        if data_args.mlm and data_args.whole_word_mask:
            data_collator = DataCollatorForWholeWordMask(
                tokenizer=tokenizer, mlm_probability=data_args.mlm_probability
            )
        else:
            data_collator = DataCollatorForLanguageModeling(
                tokenizer=tokenizer, mlm=data_args.mlm, mlm_probability=data_args.mlm_probability
            )

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        prediction_loss_only=True,
    )

    # Training
    if training_args.do_train:
        model_path = (
            model_args.model_name_or_path
            if model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path)
            else None
        )
        trainer.train(model_path=model_path)
        trainer.save_model()
        # For convenience, we also re-save the tokenizer to the same directory,
        # so that you can share your model easily on huggingface.co/models =)
        if trainer.is_world_master():
            tokenizer.save_pretrained(training_args.output_dir)

    # Evaluation
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        eval_output = trainer.evaluate()

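        # Perplexity is the exponential of the mean eval cross-entropy loss returned by the Trainer.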
        perplexity = math.exp(eval_output["eval_loss"])
        result = {"perplexity": perplexity}

        output_eval_file = os.path.join(training_args.output_dir, "eval_results_lm.txt")
        if trainer.is_world_master():
            with open(output_eval_file, "w") as writer:
                logger.info("***** Eval results *****")
                for key in sorted(result.keys()):
                    logger.info("  %s = %s", key, str(result[key]))
                    writer.write("%s = %s\n" % (key, str(result[key])))

        results.update(result)

    return results


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()