run_language_modeling.py 13.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
16
"""
17
18
19
Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, CTRL, BERT, RoBERTa, XLNet).
GPT, GPT-2 and CTRL are fine-tuned using a causal language modeling (CLM) loss. BERT and RoBERTa are fine-tuned
using a masked language modeling (MLM) loss. XLNet is fine-tuned using a permutation language modeling (PLM) loss.
20
"""
21
22
23


import logging
Julien Chaumond's avatar
Julien Chaumond committed
24
import math
25
import os
Julien Chaumond's avatar
Julien Chaumond committed
26
from dataclasses import dataclass, field
27
from glob import glob
Julien Chaumond's avatar
Julien Chaumond committed
28
from typing import Optional
29

30
31
from torch.utils.data import ConcatDataset

32
import transformers
33
from transformers import (
Julien Chaumond's avatar
Julien Chaumond committed
34
    CONFIG_MAPPING,
35
36
37
38
    MODEL_WITH_LM_HEAD_MAPPING,
    AutoConfig,
    AutoModelWithLMHead,
    AutoTokenizer,
Julien Chaumond's avatar
Julien Chaumond committed
39
    DataCollatorForLanguageModeling,
40
    DataCollatorForPermutationLanguageModeling,
41
    DataCollatorForWholeWordMask,
Julien Chaumond's avatar
Julien Chaumond committed
42
43
    HfArgumentParser,
    LineByLineTextDataset,
44
    LineByLineWithRefDataset,
45
    PreTrainedTokenizer,
Julien Chaumond's avatar
Julien Chaumond committed
46
47
48
49
    TextDataset,
    Trainer,
    TrainingArguments,
    set_seed,
50
)
51
from transformers.trainer_utils import is_main_process
52

53

54
# Module-level logger for this script.
logger = logging.getLogger(__name__)


# Config classes that have a model-with-LM-head registered, and their ``model_type`` strings
# (the latter are listed in the ``--model_type`` help text).
MODEL_CONFIG_CLASSES = list(MODEL_WITH_LM_HEAD_MAPPING.keys())
MODEL_TYPES = tuple(config_cls.model_type for config_cls in MODEL_CONFIG_CLASSES)
59
60


Julien Chaumond's avatar
Julien Chaumond committed
61
62
63
64
65
@dataclass
class ModelArguments:
    """
    Command-line options choosing the model, config and tokenizer to fine-tune (or to train from scratch).

    Every field defaults to ``None``; the ``help`` metadata is the text shown for each option on ``--help``.
    """

    # Checkpoint used to initialise weights; leaving it unset trains a fresh model.
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "The model checkpoint for weights initialization. Leave None if you want to train a model from scratch."
        },
    )
    # Only consulted by main() when neither a config name nor a checkpoint is given.
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": f"If training from scratch, pass a model type from the list: {', '.join(MODEL_TYPES)}"},
    )
    # Overrides for when the config / tokenizer live somewhere other than the checkpoint.
    config_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained config name or path if not the same as model_name"},
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},
    )
    # Passed as ``cache_dir`` to every ``from_pretrained`` call.
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
87
88


Julien Chaumond's avatar
Julien Chaumond committed
89
90
91
92
93
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    File-path fields default to ``None``. ``get_dataset`` decides which of ``train_data_file`` /
    ``train_data_files`` is used and how the ``*_ref_file`` fields are consumed.
    """

    train_data_file: Optional[str] = field(
        default=None, metadata={"help": "The input training data file (a text file)."}
    )
    # Glob pattern; when set it takes precedence over ``train_data_file`` for training.
    train_data_files: Optional[str] = field(
        default=None,
        metadata={
            "help": "The input training data files (multiple files in glob format). "
            "Very often splitting large files to smaller files can prevent tokenizer going out of memory"
        },
    )
    eval_data_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    # Reference files are only read in line-by-line mode and require ``whole_word_mask`` and ``mlm``.
    train_ref_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input train ref data file for whole word mask in Chinese."},
    )
    eval_ref_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input eval ref data file for whole word mask in Chinese."},
    )
    line_by_line: bool = field(
        default=False,
        metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
    )

    mlm: bool = field(
        default=False, metadata={"help": "Train with masked-language modeling loss instead of language modeling."}
    )
    whole_word_mask: bool = field(default=False, metadata={"help": "Whether or not to use whole word mask."})
    mlm_probability: float = field(
        default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
    )
    # The two permutation-LM settings below are only used for XLNet models.
    plm_probability: float = field(
        default=1 / 6,
        metadata={
            "help": "Ratio of length of a span of masked tokens to surrounding context length for permutation language modeling."
        },
    )
    max_span_length: int = field(
        default=5, metadata={"help": "Maximum length of a span of masked tokens for permutation language modeling."}
    )

    # A value <= 0 means "use the tokenizer's maximum"; main() also caps larger values at that maximum.
    # Each sentence ends with a space so the concatenated help text reads correctly.
    block_size: int = field(
        default=-1,
        metadata={
            "help": "Optional input sequence length after tokenization. "
            "The training dataset will be truncated in blocks of this size for training. "
            "Defaults to the model max input length for single sentence inputs (take into account special tokens)."
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )


152
153
154
155
156
157
def get_dataset(
    args: DataTrainingArguments,
    tokenizer: PreTrainedTokenizer,
    evaluate: bool = False,
    cache_dir: Optional[str] = None,
):
    """
    Build the training or evaluation dataset described by ``args``.

    Args:
        args: data arguments; ``line_by_line``, ``whole_word_mask``, ``mlm``, ``block_size`` and
            ``overwrite_cache`` control how each file is read.
        tokenizer: tokenizer handed to the dataset classes.
        evaluate: if True, build the evaluation set from ``eval_data_file`` / ``eval_ref_file``
            instead of the training set.
        cache_dir: forwarded to ``TextDataset`` (not used in line-by-line mode).

    Returns:
        One dataset, or, for training with ``train_data_files`` (a glob pattern), a ``ConcatDataset``
        with one dataset per matching file in sorted path order. Reference files are not used in the
        glob case.

    Raises:
        ValueError: if, in line-by-line mode, a reference file is given without both
            ``whole_word_mask`` and ``mlm``; or if the ``train_data_files`` pattern matches no files.
    """

    def _dataset(file_path, ref_path=None):
        # Build the dataset for a single text file.
        if not args.line_by_line:
            return TextDataset(
                tokenizer=tokenizer,
                file_path=file_path,
                block_size=args.block_size,
                overwrite_cache=args.overwrite_cache,
                cache_dir=cache_dir,
            )
        if ref_path is None:
            return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
        if not args.whole_word_mask or not args.mlm:
            raise ValueError("You need to set whole word masking and mlm to True for Chinese Whole Word Mask")
        return LineByLineWithRefDataset(
            tokenizer=tokenizer,
            file_path=file_path,
            block_size=args.block_size,
            ref_path=ref_path,
        )

    if evaluate:
        return _dataset(args.eval_data_file, args.eval_ref_file)

    if args.train_data_files:
        # glob() returns matches in arbitrary, filesystem-dependent order; sorting keeps the
        # ConcatDataset index -> example mapping identical across runs and machines.
        file_paths = sorted(glob(args.train_data_files))
        if not file_paths:
            # ConcatDataset([]) would otherwise fail with an opaque assertion.
            raise ValueError(f"No training files match the pattern {args.train_data_files!r} (--train_data_files).")
        return ConcatDataset([_dataset(path) for path in file_paths])

    return _dataset(args.train_data_file, args.train_ref_file)
186

187

Julien Chaumond's avatar
Julien Chaumond committed
188
189
190
191
192
193
194
195
196
def main():
    """
    Parse command-line arguments, then fine-tune and/or evaluate a language model.

    Returns:
        dict: ``{"perplexity": ...}`` when ``--do_eval`` is set, otherwise an empty dict.

    Raises:
        ValueError: for inconsistent arguments, namely:
            - missing train or eval data;
            - a non-empty output directory without ``--overwrite_output_dir``;
            - no way to build a config or tokenizer;
            - a masked-LM-only model type run without ``--mlm``.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if data_args.eval_data_file is None and training_args.do_eval:
        raise ValueError(
            "Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
            "or remove the --do_eval argument."
        )
    # Fail fast instead of handing a None path to the dataset classes in get_dataset().
    if training_args.do_train and data_args.train_data_file is None and not data_args.train_data_files:
        raise ValueError(
            "Cannot do training without training data. Either supply a file to --train_data_file, "
            "a glob pattern to --train_data_files, or remove the --do_train argument."
        )
    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.enable_default_handler()
        transformers.utils.logging.enable_explicit_format()
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed
    set_seed(training_args.seed)

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
    elif model_args.model_type is None or model_args.model_type not in CONFIG_MAPPING:
        # Without this check CONFIG_MAPPING[...] fails with a bare KeyError (e.g. KeyError: None).
        raise ValueError(
            f"Cannot create a config from scratch without a valid --model_type (got {model_args.model_type!r}). "
            f"Pass --config_name, --model_name_or_path, or one of: {', '.join(MODEL_TYPES)}"
        )
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, cache_dir=model_args.cache_dir)
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported, but you can do it from another script, save it, "
            "and load it from here, using --tokenizer_name"
        )

    if model_args.model_name_or_path:
        model = AutoModelWithLMHead.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
        )
    else:
        logger.info("Training new model from scratch")
        model = AutoModelWithLMHead.from_config(config)

    # Keep the embedding matrix in sync with the tokenizer's vocabulary size.
    model.resize_token_embeddings(len(tokenizer))

    if config.model_type in ["bert", "roberta", "distilbert", "camembert"] and not data_args.mlm:
        raise ValueError(
            "BERT and RoBERTa-like models do not have LM heads but masked LM heads. They must be run using the "
            "--mlm flag (masked language modeling)."
        )

    # block_size <= 0 means "use the largest input the tokenizer allows"; larger values are capped to it.
    # ``model_max_length`` is the current name of the deprecated ``max_len`` attribute.
    if data_args.block_size <= 0:
        data_args.block_size = tokenizer.model_max_length
    else:
        data_args.block_size = min(data_args.block_size, tokenizer.model_max_length)

    # Get datasets
    train_dataset = (
        get_dataset(data_args, tokenizer=tokenizer, cache_dir=model_args.cache_dir) if training_args.do_train else None
    )
    eval_dataset = (
        get_dataset(data_args, tokenizer=tokenizer, evaluate=True, cache_dir=model_args.cache_dir)
        if training_args.do_eval
        else None
    )

    # Pick the collator matching the training objective: PLM for XLNet, whole-word MLM, or plain MLM/CLM.
    if config.model_type == "xlnet":
        data_collator = DataCollatorForPermutationLanguageModeling(
            tokenizer=tokenizer,
            plm_probability=data_args.plm_probability,
            max_span_length=data_args.max_span_length,
        )
    elif data_args.mlm and data_args.whole_word_mask:
        data_collator = DataCollatorForWholeWordMask(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
    else:
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer, mlm=data_args.mlm, mlm_probability=data_args.mlm_probability
        )

    # Initialize our Trainer
    # NOTE(review): recent transformers versions deprecate the `prediction_loss_only` Trainer kwarg in favour of
    # `TrainingArguments.prediction_loss_only` — confirm against the pinned version before changing.
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        prediction_loss_only=True,
    )

    # Training
    if training_args.do_train:
        # Resume from the checkpoint directory if a local one was given.
        model_path = (
            model_args.model_name_or_path
            if model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path)
            else None
        )
        trainer.train(model_path=model_path)
        trainer.save_model()
        # For convenience, we also re-save the tokenizer to the same directory,
        # so that you can share your model easily on huggingface.co/models =)
        if trainer.is_world_process_zero():
            tokenizer.save_pretrained(training_args.output_dir)

    # Evaluation
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        eval_output = trainer.evaluate()

        # math.exp raises OverflowError for very large losses (e.g. an untrained model); report inf instead.
        try:
            perplexity = math.exp(eval_output["eval_loss"])
        except OverflowError:
            perplexity = float("inf")
        result = {"perplexity": perplexity}

        output_eval_file = os.path.join(training_args.output_dir, "eval_results_lm.txt")
        if trainer.is_world_process_zero():
            with open(output_eval_file, "w", encoding="utf-8") as writer:
                logger.info("***** Eval results *****")
                for key in sorted(result.keys()):
                    logger.info("  %s = %s", key, str(result[key]))
                    writer.write("%s = %s\n" % (key, str(result[key])))

        results.update(result)

    return results


357
358
359
360
361
def _mp_fn(index):
    """Entry point for ``xla_spawn`` on TPUs.

    ``index`` is supplied by the launcher (presumably the process ordinal) and is ignored;
    every spawned process simply runs ``main()``.
    """
    main()


362
if __name__ == "__main__":
    # Script entry point: parse CLI arguments, then train and/or evaluate.
    main()