run_language_modeling.py 33.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
16
"""
Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa, DistilBERT,
CamemBERT).

GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss, while BERT, RoBERTa, DistilBERT and
CamemBERT are fine-tuned using a masked language modeling (MLM) loss.
"""
21
22
23
24
25
26


import argparse
import glob
import logging
import os
27
import pickle
28
import random
jinoobaek-qz's avatar
jinoobaek-qz committed
29
30
import re
import shutil
31
from typing import Dict, List, Tuple
32
33
34

import numpy as np
import torch
35
from torch.nn.utils.rnn import pad_sequence
Aymeric Augustin's avatar
Aymeric Augustin committed
36
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
37
38
39
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

40
41
42
43
44
45
from transformers import (
    WEIGHTS_NAME,
    AdamW,
    BertConfig,
    BertForMaskedLM,
    BertTokenizer,
Aymeric Augustin's avatar
Aymeric Augustin committed
46
47
48
49
50
51
    CamembertConfig,
    CamembertForMaskedLM,
    CamembertTokenizer,
    DistilBertConfig,
    DistilBertForMaskedLM,
    DistilBertTokenizer,
52
53
54
55
56
57
    GPT2Config,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    OpenAIGPTConfig,
    OpenAIGPTLMHeadModel,
    OpenAIGPTTokenizer,
58
    PreTrainedModel,
59
    PreTrainedTokenizer,
60
61
62
    RobertaConfig,
    RobertaForMaskedLM,
    RobertaTokenizer,
Aymeric Augustin's avatar
Aymeric Augustin committed
63
    get_linear_schedule_with_warmup,
64
)
65

66

Aymeric Augustin's avatar
Aymeric Augustin committed
67
68
try:
    from torch.utils.tensorboard import SummaryWriter
69
except ImportError:
Aymeric Augustin's avatar
Aymeric Augustin committed
70
71
72
    from tensorboardX import SummaryWriter


73
logger = logging.getLogger(__name__)


# Maps each accepted ``--model_type`` value to its
# (config class, model class with an LM head, tokenizer class) triple.
# "gpt2" and "openai-gpt" use causal LM heads (``*LMHeadModel``); the remaining
# entries use masked-LM heads (``*ForMaskedLM``) and must be run with ``--mlm``.
MODEL_CLASSES = dict(
    [
        ("gpt2", (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer)),
        ("openai-gpt", (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer)),
        ("bert", (BertConfig, BertForMaskedLM, BertTokenizer)),
        ("roberta", (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)),
        ("distilbert", (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer)),
        ("camembert", (CamembertConfig, CamembertForMaskedLM, CamembertTokenizer)),
    ]
)


86
class TextDataset(Dataset):
    """Dataset of fixed-length token blocks cut from one contiguous text file.

    The whole file is tokenized once and split into consecutive, non-overlapping
    chunks; each chunk is then wrapped with the tokenizer's special tokens. The
    resulting examples are pickled next to the source file so later runs can
    reuse them (unless ``args.overwrite_cache`` is set).
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, args, file_path: str, block_size=512):
        """
        Args:
            tokenizer: used to tokenize the text and to add special tokens to each block.
            args: parsed CLI namespace; ``model_type`` and ``overwrite_cache`` are read.
            file_path: path to the raw text file.
            block_size: length budget per example; it is reduced below by the
                tokenizer's special-token overhead before the text is chunked.
        """
        assert os.path.isfile(file_path)

        # Shrink the block by the gap between the tokenizer's full and
        # single-sentence max lengths, i.e. the room taken by the special
        # tokens that build_inputs_with_special_tokens adds to every block.
        special_token_overhead = tokenizer.max_len - tokenizer.max_len_single_sentence
        block_size = block_size - special_token_overhead

        data_dir, data_filename = os.path.split(file_path)
        cached_features_file = os.path.join(
            data_dir, args.model_type + "_cached_lm_" + str(block_size) + "_" + data_filename
        )

        if os.path.exists(cached_features_file) and not args.overwrite_cache:
            logger.info("Loading features from cached file %s", cached_features_file)
            # NOTE(review): pickle.load can execute arbitrary code; only reuse cache
            # files that this script wrote itself, in a directory you trust.
            with open(cached_features_file, "rb") as handle:
                self.examples = pickle.load(handle)
            return

        logger.info("Creating features from dataset file at %s", data_dir)
        with open(file_path, encoding="utf-8") as f:
            text = f.read()
        token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))

        # Step through the token stream in strides of block_size. The trailing
        # remainder (fewer than block_size tokens) is dropped instead of padded,
        # so a file shorter than one block yields no examples at all.
        stop = len(token_ids) - block_size + 1
        self.examples = [
            tokenizer.build_inputs_with_special_tokens(token_ids[start : start + block_size])
            for start in range(0, stop, block_size)
        ]

        logger.info("Saving features into cached file %s", cached_features_file)
        with open(cached_features_file, "wb") as handle:
            pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def __len__(self):
        """Number of token blocks."""
        return len(self.examples)

    def __getitem__(self, item):
        """Return block ``item`` as a ``torch.long`` tensor."""
        return torch.tensor(self.examples[item], dtype=torch.long)
125
126


127
128
129
130
131
132
133
134
135
class LineByLineTextDataset(Dataset):
    """Dataset in which every non-blank line of a text file is one example.

    Lines are encoded independently with special tokens and
    ``max_length=block_size``. Nothing is cached to disk.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, args, file_path: str, block_size=512):
        """
        Args:
            tokenizer: encodes the lines in one ``batch_encode_plus`` call.
            args: unused; accepted so the signature matches ``TextDataset``.
            file_path: path to the raw text file.
            block_size: forwarded to the tokenizer as ``max_length``.
        """
        assert os.path.isfile(file_path)
        # Features are deliberately not cached: the expectation is that fast
        # multithreaded tokenizers from the `tokenizers` project make
        # re-encoding cheap enough.
        logger.info("Creating features from dataset file at %s", file_path)

        with open(file_path, encoding="utf-8") as f:
            non_blank_lines = [line for line in f.read().splitlines() if line and not line.isspace()]

        encoded = tokenizer.batch_encode_plus(non_blank_lines, add_special_tokens=True, max_length=block_size)
        self.examples = encoded["input_ids"]

    def __len__(self):
        """Number of encoded lines."""
        return len(self.examples)

    def __getitem__(self, i):
        """Return line ``i``'s token ids as a ``torch.long`` tensor."""
        return torch.tensor(self.examples[i], dtype=torch.long)
145
146


147
def load_and_cache_examples(args, tokenizer, evaluate=False):
    """Build the training (default) or evaluation dataset described by ``args``.

    ``args.line_by_line`` selects ``LineByLineTextDataset``; otherwise
    ``TextDataset`` (which caches its features next to the data file) is used.
    Both receive ``args.block_size``.
    """
    if evaluate:
        file_path = args.eval_data_file
    else:
        file_path = args.train_data_file
    dataset_cls = LineByLineTextDataset if args.line_by_line else TextDataset
    return dataset_cls(tokenizer, args, file_path=file_path, block_size=args.block_size)
153
154


155
156
157
158
159
160
161
def set_seed(args):
    """Seed Python's ``random``, NumPy and PyTorch with ``args.seed``.

    All CUDA devices are seeded as well when ``args.n_gpu > 0``.
    """
    seed = args.seed
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)

162

163
164
def _sorted_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> List[str]:
    ordering_and_checkpoint_path = []
165

166
    glob_checkpoints = glob.glob(os.path.join(args.output_dir, "{}-*".format(checkpoint_prefix)))
jinoobaek-qz's avatar
jinoobaek-qz committed
167
168

    for path in glob_checkpoints:
169
170
171
        if use_mtime:
            ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
        else:
172
            regex_match = re.match(".*{}-([0-9]+)".format(checkpoint_prefix), path)
173
174
175
176
            if regex_match and regex_match.groups():
                ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

    checkpoints_sorted = sorted(ordering_and_checkpoint_path)
jinoobaek-qz's avatar
jinoobaek-qz committed
177
    checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
178
179
180
181
182
183
184
185
186
187
188
189
190
191
    return checkpoints_sorted


def _rotate_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> None:
    if not args.save_total_limit:
        return
    if args.save_total_limit <= 0:
        return

    # Check if we should delete older checkpoint(s)
    checkpoints_sorted = _sorted_checkpoints(args, checkpoint_prefix, use_mtime)
    if len(checkpoints_sorted) <= args.save_total_limit:
        return

jinoobaek-qz's avatar
jinoobaek-qz committed
192
193
194
195
196
    number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - args.save_total_limit)
    checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
    for checkpoint in checkpoints_to_be_deleted:
        logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
        shutil.rmtree(checkpoint)
jinoobaek-qz's avatar
jinoobaek-qz committed
197
198


199
def mask_tokens(inputs: torch.Tensor, tokenizer: PreTrainedTokenizer, args) -> Tuple[torch.Tensor, torch.Tensor]:
    """Corrupt a batch for masked-LM training and build the matching labels.

    Each position that is neither a special token nor padding is selected with
    probability ``args.mlm_probability``. Of the selected positions, about 80%
    become the mask token, about 10% become a uniformly random vocabulary id,
    and the remaining ~10% keep their original id.

    Args:
        inputs: batch of token ids; it is modified IN PLACE and also returned.
        tokenizer: must define a mask token.
        args: namespace providing ``mlm_probability``.

    Returns:
        ``(inputs, labels)``, where ``labels`` holds the original ids at the
        selected positions and ``-100`` everywhere else, so only selected
        positions contribute to the loss.

    Raises:
        ValueError: if the tokenizer has no mask token.
    """
    if tokenizer.mask_token is None:
        raise ValueError(
            "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
        )

    labels = inputs.clone()
    shape = labels.shape

    # Special tokens and padding are never selected.
    special_mask = torch.tensor(
        [tokenizer.get_special_tokens_mask(row, already_has_special_tokens=True) for row in labels.tolist()],
        dtype=torch.bool,
    )
    select_prob = torch.full(shape, args.mlm_probability)
    select_prob.masked_fill_(special_mask, value=0.0)
    if tokenizer._pad_token is not None:
        select_prob.masked_fill_(labels.eq(tokenizer.pad_token_id), value=0.0)

    selected = torch.bernoulli(select_prob).bool()
    labels[~selected] = -100  # unselected positions are ignored by the loss

    # ~80% of the selected positions -> mask token.
    to_mask = selected & torch.bernoulli(torch.full(shape, 0.8)).bool()
    inputs[to_mask] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

    # Half of the still-untouched selected positions (~10% overall) -> random id.
    to_randomize = selected & ~to_mask & torch.bernoulli(torch.full(shape, 0.5)).bool()
    random_ids = torch.randint(len(tokenizer), shape, dtype=torch.long)
    inputs[to_randomize] = random_ids[to_randomize]

    # The final ~10% of selected positions keep their original token.
    return inputs, labels
231

232

233
def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
    """Fine-tune ``model`` on ``train_dataset``.

    Covers gradient accumulation, optional apex fp16, DataParallel /
    DistributedDataParallel, periodic TensorBoard logging with optional
    evaluation, checkpoint saving/rotation, and resuming from a
    ``...-<global_step>`` directory given as ``args.model_name_or_path``.

    Args:
        args: parsed CLI namespace; ``train_batch_size`` and, with
            ``max_steps > 0``, ``num_train_epochs`` are written back onto it.
        train_dataset: dataset of token-id tensors.
        model: unwrapped model with an LM head.
        tokenizer: used for padding, masking (``--mlm``) and saved with checkpoints.

    Returns:
        ``(global_step, average_loss)``. ``average_loss`` is the mean training
        loss per optimizer step taken during this call (0.0 if no step was taken).
    """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    def collate(examples: List[torch.Tensor]):
        # Right-pad variable-length examples into one batch-first tensor.
        if tokenizer._pad_token is None:
            return pad_sequence(examples, batch_first=True)
        return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)

    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(
        train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=collate
    )

    updates_per_epoch = len(train_dataloader) // args.gradient_accumulation_steps
    # Divisor guarded against 0, which happens with fewer batches than accumulation steps.
    safe_updates_per_epoch = max(updates_per_epoch, 1)
    if updates_per_epoch == 0:
        logger.warning(
            "Only %d batches per epoch but gradient_accumulation_steps=%d: no optimizer step will be taken.",
            len(train_dataloader),
            args.gradient_accumulation_steps,
        )

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // safe_updates_per_epoch + 1
    else:
        t_total = updates_per_epoch * args.num_train_epochs

    # Resize the embeddings to the tokenizer's vocabulary BEFORE the optimizer,
    # apex and the (Distributed)DataParallel wrappers capture the parameters:
    # if resizing swaps in a new embedding matrix, it must be the one they see.
    model_to_resize = model.module if hasattr(model, "module") else model  # Take care of distributed/parallel training
    model_to_resize.resize_token_embeddings(len(tokenizer))

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )

    # Restore optimizer and scheduler states when resuming from a saved checkpoint.
    if (
        args.model_name_or_path
        and os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt"))
        and os.path.isfile(os.path.join(args.model_name_or_path, "scheduler.pt"))
    ):
        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    batches_to_skip = 0
    # Check if continuing training from a checkpoint
    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
        try:
            # Recover global_step from the "...-<global_step>" suffix of the model path.
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // safe_updates_per_epoch
            # The loop skips *batches*, and each optimizer step consumed
            # gradient_accumulation_steps batches, so convert steps to batches.
            batches_to_skip = (global_step % safe_updates_per_epoch) * args.gradient_accumulation_steps

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d", global_step)
            logger.info("  Will skip the first %d batches in the first epoch", batches_to_skip)
        except ValueError:
            logger.info("  Starting fine-tuning.")
    start_global_step = global_step

    tr_loss, logging_loss = 0.0, 0.0

    model.zero_grad()
    train_iterator = trange(
        epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
    )
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained batches if resuming training
            if batches_to_skip > 0:
                batches_to_skip -= 1
                continue

            inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
            inputs = inputs.to(args.device)
            labels = labels.to(args.device)
            model.train()
            outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Only evaluate when single GPU otherwise metrics may not average well
                    if args.local_rank == -1 and args.evaluate_during_training:
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value, global_step)
                    # get_last_lr() (torch >= 1.4) is the supported accessor; fall back on older torch.
                    if hasattr(scheduler, "get_last_lr"):
                        current_lr = scheduler.get_last_lr()[0]
                    else:
                        current_lr = scheduler.get_lr()[0]
                    tb_writer.add_scalar("lr", current_lr, global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    checkpoint_prefix = "checkpoint"
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, "{}-{}".format(checkpoint_prefix, global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)
                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s", output_dir)

                    # Rotate only once the new checkpoint is complete, so writes
                    # never target a directory that rotation just removed.
                    _rotate_checkpoints(args, checkpoint_prefix)

            # ">=" stops after exactly max_steps optimizer steps.
            if args.max_steps > 0 and global_step >= args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step >= args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    # tr_loss only covers steps taken in this call, so average over those
    # (which also avoids dividing by zero when no step was taken).
    steps_this_run = global_step - start_global_step
    average_loss = tr_loss / steps_this_run if steps_this_run > 0 else 0.0
    return global_step, average_loss


423
def evaluate(args, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prefix="") -> Dict:
    """Compute the perplexity of ``model`` on ``args.eval_data_file``.

    The mean per-batch loss over the evaluation set is exponentiated into a
    perplexity, logged, and written to ``<output_dir>/<prefix>/eval_results.txt``.

    Args:
        args: parsed CLI namespace; ``eval_batch_size`` is written back onto it.
        model: model to evaluate; may already be wrapped in ``DataParallel``
            (as when ``train`` calls this during training with several GPUs).
        tokenizer: builds the dataset, pads batches and, with ``--mlm``, masks tokens.
        prefix: sub-directory of ``args.output_dir`` for the results file; also
            shown in the log headers.

    Returns:
        ``{"perplexity": tensor}``.

    Raises:
        ValueError: if the evaluation dataset yields no batches.
    """
    eval_output_dir = args.output_dir

    eval_dataset = load_and_cache_examples(args, tokenizer, evaluate=True)

    if args.local_rank in [-1, 0]:
        os.makedirs(eval_output_dir, exist_ok=True)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)

    def collate(examples: List[torch.Tensor]):
        # Right-pad variable-length examples into one batch-first tensor.
        if tokenizer._pad_token is None:
            return pad_sequence(examples, batch_first=True)
        return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)

    # Evaluation always walks the data in order.
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(
        eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate
    )

    # multi-gpu evaluate; don't wrap a second time if train() already did.
    if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Num examples = %d", len(eval_dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)
    eval_loss = 0.0
    nb_eval_steps = 0
    model.eval()

    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
        inputs = inputs.to(args.device)
        labels = labels.to(args.device)

        with torch.no_grad():
            outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels)
            lm_loss = outputs[0]
            # mean() averages the per-device losses returned under DataParallel.
            eval_loss += lm_loss.mean().item()
        nb_eval_steps += 1

    if nb_eval_steps == 0:
        raise ValueError(
            "The evaluation dataset built from {} is empty; cannot compute perplexity.".format(args.eval_data_file)
        )
    eval_loss = eval_loss / nb_eval_steps
    perplexity = torch.exp(torch.tensor(eval_loss))

    result = {"perplexity": perplexity}

    output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
    # The ``prefix`` sub-directory may not exist yet.
    os.makedirs(os.path.dirname(output_eval_file), exist_ok=True)
    with open(output_eval_file, "w") as writer:
        logger.info("***** Eval results {} *****".format(prefix))
        for key in sorted(result.keys()):
            logger.info("  %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))

    return result
481
482
483
484
485


def main():
    parser = argparse.ArgumentParser()

486
    # Required parameters
487
488
489
490
491
492
493
494
495
    parser.add_argument(
        "--train_data_file", default=None, type=str, required=True, help="The input training data file (a text file)."
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
496
497
498
    parser.add_argument(
        "--model_type", type=str, required=True, help="The model architecture to be trained or fine-tuned.",
    )
499

500
    # Other parameters
501
502
503
504
505
506
    parser.add_argument(
        "--eval_data_file",
        default=None,
        type=str,
        help="An optional input evaluation data file to evaluate the perplexity on (a text file).",
    )
507
508
509
510
511
    parser.add_argument(
        "--line_by_line",
        action="store_true",
        help="Whether distinct lines of text in the dataset are to be handled as distinct sequences.",
    )
Julien Chaumond's avatar
Julien Chaumond committed
512
513
514
    parser.add_argument(
        "--should_continue", action="store_true", help="Whether to continue from latest checkpoint in output_dir"
    )
515
516
    parser.add_argument(
        "--model_name_or_path",
517
        default=None,
518
        type=str,
519
        help="The model checkpoint for weights initialization. Leave None if you want to train a model from scratch.",
520
521
522
523
524
525
526
527
528
529
530
    )

    parser.add_argument(
        "--mlm", action="store_true", help="Train with masked-language modeling loss instead of language modeling."
    )
    parser.add_argument(
        "--mlm_probability", type=float, default=0.15, help="Ratio of tokens to mask for masked language modeling loss"
    )

    parser.add_argument(
        "--config_name",
531
        default=None,
532
        type=str,
533
        help="Optional pretrained config name or path if not the same as model_name_or_path. If both are None, initialize a new config.",
534
535
536
    )
    parser.add_argument(
        "--tokenizer_name",
537
538
539
540
        default=None,
        type=str,
        help="Optional pretrained tokenizer name or path if not the same as model_name_or_path. If both are None, initialize a new tokenizer.",
    )
541
542
    parser.add_argument(
        "--cache_dir",
543
        default=None,
544
        type=str,
Oren Amsalem's avatar
Oren Amsalem committed
545
        help="Optional directory to store the pre-trained models downloaded from s3 (instead of the default one)",
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
    )
    parser.add_argument(
        "--block_size",
        default=-1,
        type=int,
        help="Optional input sequence length after tokenization."
        "The training dataset will be truncated in block of this size for training."
        "Default to the model max input length for single sentence inputs (take into account special tokens).",
    )
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step."
    )

    parser.add_argument("--per_gpu_train_batch_size", default=4, type=int, help="Batch size per GPU/CPU for training.")
    parser.add_argument(
        "--per_gpu_eval_batch_size", default=4, type=int, help="Batch size per GPU/CPU for evaluation."
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--num_train_epochs", default=1.0, type=float, help="Total number of training epochs to perform."
    )
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
    )
    parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")

586
587
    parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
    parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
    parser.add_argument(
        "--save_total_limit",
        type=int,
        default=None,
        help="Limit the total amount of checkpoints, delete the older checkpoints in the output_dir, does not delete by default",
    )
    parser.add_argument(
        "--eval_all_checkpoints",
        action="store_true",
        help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number",
    )
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    parser.add_argument(
        "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
    )
    parser.add_argument(
        "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
    )
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")

    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
    parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
623
624
    args = parser.parse_args()

maxvidal's avatar
maxvidal committed
625
    if args.model_type in ["bert", "roberta", "distilbert", "camembert"] and not args.mlm:
626
        raise ValueError(
627
            "BERT and RoBERTa-like models do not have LM heads but masked LM heads. They must be run using the --mlm "
628
629
            "flag (masked language modeling)."
        )
630
    if args.eval_data_file is None and args.do_eval:
631
632
633
634
        raise ValueError(
            "Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
            "or remove the --do_eval argument."
        )
635
636
637
    if args.should_continue:
        sorted_checkpoints = _sorted_checkpoints(args)
        if len(sorted_checkpoints) == 0:
Julien Chaumond's avatar
Julien Chaumond committed
638
            raise ValueError("Used --should_continue but no checkpoint was found in --output_dir.")
639
640
        else:
            args.model_name_or_path = sorted_checkpoints[-1]
641
642
643
644
645
646
647
648
649
650
651
652

    if (
        os.path.exists(args.output_dir)
        and os.listdir(args.output_dir)
        and args.do_train
        and not args.overwrite_output_dir
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                args.output_dir
            )
        )
653
654
655
656
657

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
658

659
660
661
662
663
664
665
666
667
668
669
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
670
        torch.distributed.init_process_group(backend="nccl")
671
672
673
674
        args.n_gpu = 1
    args.device = device

    # Setup logging
675
676
677
678
679
680
681
682
683
684
685
686
687
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank,
        device,
        args.n_gpu,
        bool(args.local_rank != -1),
        args.fp16,
    )
688
689
690
691
692
693

    # Set seed
    set_seed(args)

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
694
695
696
        torch.distributed.barrier()  # Barrier to make sure only the first process in distributed training download model & vocab

    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
697
698
699
700
701
702
703
704
705
706
707
708
709

    if args.config_name:
        config = config_class.from_pretrained(args.config_name, cache_dir=args.cache_dir)
    elif args.model_name_or_path:
        config = config_class.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
    else:
        config = config_class()

    if args.tokenizer_name:
        tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir)
    elif args.model_name_or_path:
        tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
    else:
710
711
712
        raise ValueError(
            "You are instantiating a new {} tokenizer. This is not supported, but you can do it from another script, save it,"
            "and load it from here, using --tokenizer_name".format(tokenizer_class.__name__)
713
714
        )

715
    if args.block_size <= 0:
716
        args.block_size = tokenizer.max_len
717
718
        # Our input block size will be the max possible for the model
    else:
719
        args.block_size = min(args.block_size, tokenizer.max_len)
720
721
722
723
724
725
726
727
728
729
730
731

    if args.model_name_or_path:
        model = model_class.from_pretrained(
            args.model_name_or_path,
            from_tf=bool(".ckpt" in args.model_name_or_path),
            config=config,
            cache_dir=args.cache_dir,
        )
    else:
        logger.info("Training new model from scratch")
        model = model_class(config=config)

732
    model.to(args.device)
733
734

    if args.local_rank == 0:
735
        torch.distributed.barrier()  # End of barrier to make sure only the first process in distributed training download model & vocab
736
737
738
739
740

    logger.info("Training/evaluation parameters %s", args)

    # Training
    if args.do_train:
741
742
743
        if args.local_rank not in [-1, 0]:
            torch.distributed.barrier()  # Barrier to make sure only the first process in distributed training process the dataset, and the others will use the cache

744
        train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False)
745
746
747
748

        if args.local_rank == 0:
            torch.distributed.barrier()

749
750
751
        global_step, tr_loss = train(args, train_dataset, model, tokenizer)
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

752
    # Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
753
754
    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        # Create output directory if needed
755
756
        if args.local_rank in [-1, 0]:
            os.makedirs(args.output_dir, exist_ok=True)
757
758
759
760

        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
761
762
763
        model_to_save = (
            model.module if hasattr(model, "module") else model
        )  # Take care of distributed/parallel training
764
765
766
767
        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

        # Good practice: save your training arguments together with the trained model
768
        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
769
770
771

        # Load a trained model and vocabulary that you have fine-tuned
        model = model_class.from_pretrained(args.output_dir)
772
        tokenizer = tokenizer_class.from_pretrained(args.output_dir)
773
774
775
776
777
778
779
        model.to(args.device)

    # Evaluation
    results = {}
    if args.do_eval and args.local_rank in [-1, 0]:
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints:
780
781
782
            checkpoints = list(
                os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
            )
783
            logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce logging
784
785
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
786
787
788
            global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
            prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""

789
790
            model = model_class.from_pretrained(checkpoint)
            model.to(args.device)
791
            result = evaluate(args, model, tokenizer, prefix=prefix)
792
            result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
793
794
795
796
797
798
            results.update(result)

    return results


if __name__ == "__main__":
    # Run the training/evaluation entry point only when this file is executed
    # as a script; importing the module must not start a run.
    main()