run_language_modeling.py 33.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
16
"""
Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa).

GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned
using a masked language modeling (MLM) loss.
"""
21
22
23
24
25
26


import argparse
import glob
import logging
import os
27
import pickle
28
import random
jinoobaek-qz's avatar
jinoobaek-qz committed
29
30
import re
import shutil
31
from typing import Dict, List, Tuple
32
33
34

import numpy as np
import torch
35
from torch.nn.utils.rnn import pad_sequence
Aymeric Augustin's avatar
Aymeric Augustin committed
36
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
37
38
39
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

40
from transformers import (
41
    MODEL_WITH_LM_HEAD_MAPPING,
42
43
    WEIGHTS_NAME,
    AdamW,
44
45
46
    AutoConfig,
    AutoModelWithLMHead,
    AutoTokenizer,
47
    PreTrainedModel,
48
    PreTrainedTokenizer,
Aymeric Augustin's avatar
Aymeric Augustin committed
49
    get_linear_schedule_with_warmup,
50
)
51

52

Aymeric Augustin's avatar
Aymeric Augustin committed
53
54
try:
    from torch.utils.tensorboard import SummaryWriter
55
except ImportError:
Aymeric Augustin's avatar
Aymeric Augustin committed
56
57
58
    from tensorboardX import SummaryWriter


59
# Module-level logger; its handlers and level are configured via logging.basicConfig in main().
logger = logging.getLogger(__name__)


# Config classes that have a registered "model with LM head", and their model_type identifiers.
MODEL_CONFIG_CLASSES = list(MODEL_WITH_LM_HEAD_MAPPING.keys())
MODEL_TYPES = tuple(config_class.model_type for config_class in MODEL_CONFIG_CLASSES)
64
65


66
class TextDataset(Dataset):
    """Language-modeling dataset that chops one text file into fixed-size token blocks.

    The whole file is tokenized once and split into consecutive, non-overlapping blocks of
    token ids; each block is then wrapped with the tokenizer's special tokens. The result is
    pickled next to the input file so later runs can skip tokenization.

    Args:
        tokenizer: tokenizer used to tokenize the text and add special tokens.
        args: namespace providing ``model_type`` (part of the cache file name) and
            ``overwrite_cache`` (re-tokenize even if a cache file exists).
        file_path: path to the UTF-8 encoded text file.
        block_size: requested example length, before the special-token adjustment below.

    Raises:
        FileNotFoundError: if ``file_path`` is not an existing regular file.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, args, file_path: str, block_size=512):
        # Explicit check instead of `assert`, which is stripped when Python runs with -O.
        if not os.path.isfile(file_path):
            raise FileNotFoundError("Input text file not found: {}".format(file_path))

        # Shrink the block by the gap between the tokenizer's full and single-sentence maximum
        # lengths (presumably the room taken by the special tokens added below).
        block_size = block_size - (tokenizer.max_len - tokenizer.max_len_single_sentence)

        directory, filename = os.path.split(file_path)
        cached_features_file = os.path.join(
            directory, args.model_type + "_cached_lm_" + str(block_size) + "_" + filename
        )

        if os.path.exists(cached_features_file) and not args.overwrite_cache:
            logger.info("Loading features from cached file %s", cached_features_file)
            # SECURITY NOTE: pickle.load can execute arbitrary code. Only load cache files that
            # this script produced itself; never point it at caches from untrusted sources.
            with open(cached_features_file, "rb") as handle:
                self.examples = pickle.load(handle)
        else:
            logger.info("Creating features from dataset file at %s", file_path)

            with open(file_path, encoding="utf-8") as f:
                text = f.read()

            tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))

            # Consecutive, non-overlapping blocks of exactly block_size ids. The trailing partial
            # block is dropped for simplicity (no padding). If your dataset is small, first look
            # for a bigger one, or change this behavior by adding (model-specific) padding.
            self.examples = [
                tokenizer.build_inputs_with_special_tokens(tokenized_text[i : i + block_size])
                for i in range(0, len(tokenized_text) - block_size + 1, block_size)
            ]

            logger.info("Saving features into cached file %s", cached_features_file)
            self._write_cache(cached_features_file, self.examples)

    @staticmethod
    def _write_cache(path: str, examples) -> None:
        """Pickle ``examples`` to ``path`` without ever exposing a partially written file.

        The data goes to a process-specific temporary file in the same directory first and is
        then moved over ``path`` with ``os.replace``. An interrupted run therefore never leaves
        a truncated cache behind for the next run to unpickle.
        """
        tmp_path = "{}.{}.tmp".format(path, os.getpid())
        try:
            with open(tmp_path, "wb") as handle:
                pickle.dump(examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
            os.replace(tmp_path, path)
        except BaseException:
            # Best-effort cleanup of the temporary file, then re-raise the original error.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, item):
        return torch.tensor(self.examples[item], dtype=torch.long)
105
106


107
108
109
110
111
112
113
114
115
class LineByLineTextDataset(Dataset):
    """Language-modeling dataset with one example per non-blank line of a text file.

    Lines that are empty or whitespace-only are skipped. Every remaining line is encoded
    separately (with special tokens) through ``tokenizer.batch_encode_plus``.

    Args:
        tokenizer: tokenizer whose ``batch_encode_plus`` produces the ``input_ids``.
        args: not used here; accepted so both dataset classes share one constructor signature.
        file_path: path to the UTF-8 encoded text file.
        block_size: forwarded to ``batch_encode_plus`` as ``max_length``.

    Raises:
        FileNotFoundError: if ``file_path`` is not an existing regular file.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, args, file_path: str, block_size=512):
        # Explicit check instead of `assert`, which is stripped when Python runs with -O.
        if not os.path.isfile(file_path):
            raise FileNotFoundError("Input text file not found: {}".format(file_path))
        # Here, we do not cache the features, operating under the assumption
        # that we will soon use fast multithreaded tokenizers from the
        # `tokenizers` repo everywhere =)
        logger.info("Creating features from dataset file at %s", file_path)

        with open(file_path, encoding="utf-8") as f:
            lines = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]

        self.examples = tokenizer.batch_encode_plus(lines, add_special_tokens=True, max_length=block_size)["input_ids"]

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        return torch.tensor(self.examples[i], dtype=torch.long)
125
126


127
def load_and_cache_examples(args, tokenizer, evaluate=False):
    """Instantiate the training or evaluation dataset selected by ``args``.

    Args:
        args: namespace with ``train_data_file``, ``eval_data_file``, ``line_by_line`` and
            ``block_size``; it is also forwarded to the dataset constructor.
        tokenizer: tokenizer forwarded to the dataset.
        evaluate: when true, read ``args.eval_data_file`` instead of ``args.train_data_file``.

    Returns:
        A ``LineByLineTextDataset`` if ``args.line_by_line`` is set, otherwise a ``TextDataset``.
    """
    source_path = args.eval_data_file if evaluate else args.train_data_file
    dataset_class = LineByLineTextDataset if args.line_by_line else TextDataset
    return dataset_class(tokenizer, args, file_path=source_path, block_size=args.block_size)
133
134


135
136
137
138
139
140
141
def set_seed(args):
    """Seed Python's, NumPy's and PyTorch's RNGs with ``args.seed`` for reproducibility.

    When ``args.n_gpu > 0``, the CUDA generators are also seeded via ``torch.cuda.manual_seed_all``.
    """
    seed = args.seed
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)

142

143
144
def _sorted_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> List[str]:
    ordering_and_checkpoint_path = []
145

146
    glob_checkpoints = glob.glob(os.path.join(args.output_dir, "{}-*".format(checkpoint_prefix)))
jinoobaek-qz's avatar
jinoobaek-qz committed
147
148

    for path in glob_checkpoints:
149
150
151
        if use_mtime:
            ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
        else:
152
            regex_match = re.match(".*{}-([0-9]+)".format(checkpoint_prefix), path)
153
154
155
156
            if regex_match and regex_match.groups():
                ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

    checkpoints_sorted = sorted(ordering_and_checkpoint_path)
jinoobaek-qz's avatar
jinoobaek-qz committed
157
    checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
158
159
160
161
162
163
164
165
166
167
168
169
170
171
    return checkpoints_sorted


def _rotate_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> None:
    """Delete the oldest checkpoints so that at most ``args.save_total_limit`` remain.

    Does nothing when ``args.save_total_limit`` is unset/zero or negative. The ordering of
    checkpoints (oldest first) comes from ``_sorted_checkpoints``.
    """
    limit = args.save_total_limit
    if not limit or limit <= 0:
        return

    existing = _sorted_checkpoints(args, checkpoint_prefix, use_mtime)
    excess = len(existing) - limit
    if excess <= 0:
        return

    for checkpoint in existing[:excess]:
        logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
        shutil.rmtree(checkpoint)
jinoobaek-qz's avatar
jinoobaek-qz committed
177
178


179
def mask_tokens(inputs: torch.Tensor, tokenizer: PreTrainedTokenizer, args) -> Tuple[torch.Tensor, torch.Tensor]:
    """Corrupt ``inputs`` for masked language modeling and build the matching labels.

    Each position that is neither a special token nor padding is selected with probability
    ``args.mlm_probability``. Of the selected positions, about 80% become the mask token,
    about 10% become a random vocabulary id, and about 10% keep their original id. Labels
    hold the original ids at selected positions and -100 everywhere else, so the loss is
    only computed on selected positions.

    Note: ``inputs`` is modified in place and also returned.

    Raises:
        ValueError: if the tokenizer has no mask token.
    """
    if tokenizer.mask_token is None:
        raise ValueError(
            "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
        )

    labels = inputs.clone()
    shape = labels.shape

    # Per-position selection probability; special tokens and padding are never selected.
    selection_prob = torch.full(shape, args.mlm_probability)
    special_positions = torch.tensor(
        [tokenizer.get_special_tokens_mask(row, already_has_special_tokens=True) for row in labels.tolist()],
        dtype=torch.bool,
    )
    selection_prob.masked_fill_(special_positions, value=0.0)
    if tokenizer._pad_token is not None:
        selection_prob.masked_fill_(labels.eq(tokenizer.pad_token_id), value=0.0)

    selected = torch.bernoulli(selection_prob).bool()
    labels[~selected] = -100  # loss is only computed on the selected positions

    # ~80% of the selected positions -> mask token.
    to_mask = selected & torch.bernoulli(torch.full(shape, 0.8)).bool()
    inputs[to_mask] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

    # Half of the remaining selected positions (~10% overall) -> random token id.
    to_randomize = selected & ~to_mask & torch.bernoulli(torch.full(shape, 0.5)).bool()
    random_ids = torch.randint(len(tokenizer), shape, dtype=torch.long)
    inputs[to_randomize] = random_ids[to_randomize]

    # The final ~10% of selected positions keep their original token.
    return inputs, labels
211

212

213
def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
    """Train ``model`` on ``train_dataset``.

    What this function does:
      - Runs gradient-accumulated training with AdamW and a linear warmup/decay schedule.
      - Optionally uses apex fp16, ``DataParallel`` (``args.n_gpu > 1``) or
        ``DistributedDataParallel`` (``args.local_rank != -1``).
      - Logging (main process only): every ``args.logging_steps`` optimizer steps it writes
        loss and learning rate to TensorBoard, plus evaluation metrics when
        ``args.evaluate_during_training`` is set on a single process.
      - Checkpointing: every ``args.save_steps`` optimizer steps it saves the model,
        tokenizer, args, optimizer and scheduler, rotating old checkpoints.
      - Resuming: when ``args.model_name_or_path`` ends in ``-<step>``, the step counter is
        restored from it. Optimizer and scheduler states are restored if present.

    Args:
        args: parsed command-line namespace; ``train_batch_size`` (and ``num_train_epochs``
            when ``max_steps > 0``) are written back to it.
        train_dataset: dataset yielding 1-D ``torch.long`` tensors (e.g. ``TextDataset``).
        model: model with an LM head (presumably already on ``args.device``; verify in main).
        tokenizer: tokenizer used for padding, masking and checkpoint saving.

    Returns:
        ``(global_step, mean_loss)``. ``mean_loss`` is the loss accumulated in this call
        divided by ``global_step``, or ``0.0`` if no optimizer step was ever taken.
    """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    # Resize the embeddings BEFORE building the optimizer. resize_token_embeddings may swap in
    # new embedding parameters when the vocabulary size changes, and an optimizer created
    # earlier would keep updating the stale tensors.
    model_to_resize = model.module if hasattr(model, "module") else model  # Take care of distributed/parallel training
    model_to_resize.resize_token_embeddings(len(tokenizer))

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    def collate(examples: List[torch.Tensor]):
        # Pad every example to the longest one in the batch (with pad_token_id when defined).
        if tokenizer._pad_token is None:
            return pad_sequence(examples, batch_first=True)
        return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)

    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(
        train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=collate
    )

    # Optimizer updates per epoch, clamped to 1 so that datasets with fewer batches than
    # gradient_accumulation_steps do not cause a ZeroDivisionError below.
    updates_per_epoch = max(1, len(train_dataloader) // args.gradient_accumulation_steps)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // updates_per_epoch + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay); biases and LayerNorm weights
    # are excluded from weight decay.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )

    # Check if saved optimizer or scheduler states exist
    if (
        args.model_name_or_path
        and os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt"))
        and os.path.isfile(os.path.join(args.model_name_or_path, "scheduler.pt"))
    ):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
        try:
            # Recover global_step from a path like ".../checkpoint-<step>[/]". A path whose last
            # "-"-separated chunk is not an integer falls through to plain fine-tuning.
            # NOTE(review): a non-checkpoint path ending in "-<digits>" (e.g. "my-model-2") is
            # also parsed as a step count.
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // updates_per_epoch
            steps_trained_in_current_epoch = global_step % updates_per_epoch

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d", global_step)
            logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Starting fine-tuning.")

    # Each optimizer step consumed gradient_accumulation_steps batches, so that many batches per
    # already-trained step must be skipped when resuming.
    batches_to_skip = steps_trained_in_current_epoch * args.gradient_accumulation_steps

    tr_loss, logging_loss = 0.0, 0.0

    train_iterator = trange(
        epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
    )
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        # Drop gradients left by an incomplete accumulation at the end of the previous epoch, so
        # every update covers exactly gradient_accumulation_steps batches. This matches how
        # t_total is computed.
        model.zero_grad()
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained batches if resuming training
            if batches_to_skip > 0:
                batches_to_skip -= 1
                continue

            inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
            inputs = inputs.to(args.device)
            labels = labels.to(args.device)
            model.train()
            outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if (
                        args.local_rank == -1 and args.evaluate_during_training
                    ):  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value, global_step)
                    # get_last_lr() is the non-deprecated accessor on newer PyTorch; fall back to
                    # get_lr() on versions that predate it.
                    if hasattr(scheduler, "get_last_lr"):
                        current_lr = scheduler.get_last_lr()[0]
                    else:
                        current_lr = scheduler.get_lr()[0]
                    tb_writer.add_scalar("lr", current_lr, global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    checkpoint_prefix = "checkpoint"
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, "{}-{}".format(checkpoint_prefix, global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    _rotate_checkpoints(args, checkpoint_prefix)

                    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s", output_dir)

            # Stop as soon as exactly max_steps optimizer updates have been made; t_total is
            # max_steps, so any further step would run with a fully decayed (zero) lr.
            if args.max_steps > 0 and global_step >= args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step >= args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    # Guard against ZeroDivisionError when no optimizer step ever ran.
    return global_step, (tr_loss / global_step if global_step > 0 else 0.0)


403
def evaluate(args, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prefix="") -> Dict:
    """Evaluate ``model`` on ``args.eval_data_file`` and report perplexity.

    Perplexity is ``exp`` of the mean of the per-batch losses. With ``args.mlm``, inputs are
    masked by ``mask_tokens``, which samples positions with torch's RNG, so the result
    depends on the RNG state at call time. Results are logged and written to
    ``<args.output_dir>/<prefix>/eval_results.txt``.

    Args:
        args: parsed command-line namespace; ``eval_batch_size`` is written back to it.
        model: model with an LM head; may already be wrapped in ``DataParallel`` when called
            from ``train``.
        tokenizer: tokenizer used to build the dataset, pad batches and mask tokens.
        prefix: label used in log messages and as a sub-directory of ``args.output_dir`` for
            the results file.

    Returns:
        ``{"perplexity": <0-d torch tensor>}``.

    Raises:
        ValueError: if the evaluation dataloader yields no batches.
    """
    eval_output_dir = args.output_dir

    eval_dataset = load_and_cache_examples(args, tokenizer, evaluate=True)

    if args.local_rank in [-1, 0]:
        os.makedirs(eval_output_dir, exist_ok=True)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)

    def collate(examples: List[torch.Tensor]):
        # Pad every example to the longest one in the batch (with pad_token_id when defined).
        if tokenizer._pad_token is None:
            return pad_sequence(examples, batch_first=True)
        return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)

    # Evaluation always walks the whole dataset in order, even in distributed runs.
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(
        eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate
    )

    # multi-gpu evaluate. When called from train() with n_gpu > 1, the model is already a
    # DataParallel; wrapping it a second time would nest the parallel wrappers.
    if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Num examples = %d", len(eval_dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)
    eval_loss = 0.0
    nb_eval_steps = 0
    model.eval()

    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
        inputs = inputs.to(args.device)
        labels = labels.to(args.device)

        with torch.no_grad():
            outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels)
            lm_loss = outputs[0]
            # mean() collapses the per-device losses returned under DataParallel.
            eval_loss += lm_loss.mean().item()
        nb_eval_steps += 1

    if nb_eval_steps == 0:
        raise ValueError("The evaluation dataset is empty; cannot compute perplexity.")

    eval_loss = eval_loss / nb_eval_steps
    perplexity = torch.exp(torch.tensor(eval_loss))

    result = {"perplexity": perplexity}

    output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
    # The prefix sub-directory (or output_dir on non-main ranks) may not exist yet.
    os.makedirs(os.path.dirname(output_eval_file), exist_ok=True)
    with open(output_eval_file, "w") as writer:
        logger.info("***** Eval results {} *****".format(prefix))
        for key in sorted(result.keys()):
            logger.info("  %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))

    return result
461
462
463
464
465


def main():
    parser = argparse.ArgumentParser()

466
    # Required parameters
467
468
469
470
471
472
473
474
475
    parser.add_argument(
        "--train_data_file", default=None, type=str, required=True, help="The input training data file (a text file)."
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
476
477
478
    parser.add_argument(
        "--model_type", type=str, required=True, help="The model architecture to be trained or fine-tuned.",
    )
479

480
    # Other parameters
481
482
483
484
485
486
    parser.add_argument(
        "--eval_data_file",
        default=None,
        type=str,
        help="An optional input evaluation data file to evaluate the perplexity on (a text file).",
    )
487
488
489
490
491
    parser.add_argument(
        "--line_by_line",
        action="store_true",
        help="Whether distinct lines of text in the dataset are to be handled as distinct sequences.",
    )
Julien Chaumond's avatar
Julien Chaumond committed
492
493
494
    parser.add_argument(
        "--should_continue", action="store_true", help="Whether to continue from latest checkpoint in output_dir"
    )
495
496
    parser.add_argument(
        "--model_name_or_path",
497
        default=None,
498
        type=str,
499
        help="The model checkpoint for weights initialization. Leave None if you want to train a model from scratch.",
500
501
502
503
504
505
506
507
508
509
510
    )

    parser.add_argument(
        "--mlm", action="store_true", help="Train with masked-language modeling loss instead of language modeling."
    )
    parser.add_argument(
        "--mlm_probability", type=float, default=0.15, help="Ratio of tokens to mask for masked language modeling loss"
    )

    parser.add_argument(
        "--config_name",
511
        default=None,
512
        type=str,
513
        help="Optional pretrained config name or path if not the same as model_name_or_path. If both are None, initialize a new config.",
514
515
516
    )
    parser.add_argument(
        "--tokenizer_name",
517
518
519
520
        default=None,
        type=str,
        help="Optional pretrained tokenizer name or path if not the same as model_name_or_path. If both are None, initialize a new tokenizer.",
    )
521
522
    parser.add_argument(
        "--cache_dir",
523
        default=None,
524
        type=str,
Oren Amsalem's avatar
Oren Amsalem committed
525
        help="Optional directory to store the pre-trained models downloaded from s3 (instead of the default one)",
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
    )
    parser.add_argument(
        "--block_size",
        default=-1,
        type=int,
        help="Optional input sequence length after tokenization."
        "The training dataset will be truncated in block of this size for training."
        "Default to the model max input length for single sentence inputs (take into account special tokens).",
    )
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step."
    )

    parser.add_argument("--per_gpu_train_batch_size", default=4, type=int, help="Batch size per GPU/CPU for training.")
    parser.add_argument(
        "--per_gpu_eval_batch_size", default=4, type=int, help="Batch size per GPU/CPU for evaluation."
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--num_train_epochs", default=1.0, type=float, help="Total number of training epochs to perform."
    )
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
    )
    parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")

566
567
    parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
    parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
    parser.add_argument(
        "--save_total_limit",
        type=int,
        default=None,
        help="Limit the total amount of checkpoints, delete the older checkpoints in the output_dir, does not delete by default",
    )
    parser.add_argument(
        "--eval_all_checkpoints",
        action="store_true",
        help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number",
    )
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    parser.add_argument(
        "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
    )
    parser.add_argument(
        "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
    )
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")

    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
    parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
603
604
    args = parser.parse_args()

maxvidal's avatar
maxvidal committed
605
    if args.model_type in ["bert", "roberta", "distilbert", "camembert"] and not args.mlm:
606
        raise ValueError(
607
            "BERT and RoBERTa-like models do not have LM heads but masked LM heads. They must be run using the --mlm "
608
609
            "flag (masked language modeling)."
        )
610
    if args.eval_data_file is None and args.do_eval:
611
612
613
614
        raise ValueError(
            "Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
            "or remove the --do_eval argument."
        )
615
616
617
    if args.should_continue:
        sorted_checkpoints = _sorted_checkpoints(args)
        if len(sorted_checkpoints) == 0:
Julien Chaumond's avatar
Julien Chaumond committed
618
            raise ValueError("Used --should_continue but no checkpoint was found in --output_dir.")
619
620
        else:
            args.model_name_or_path = sorted_checkpoints[-1]
621
622
623
624
625
626

    if (
        os.path.exists(args.output_dir)
        and os.listdir(args.output_dir)
        and args.do_train
        and not args.overwrite_output_dir
627
        and not args.should_continue
628
629
630
631
632
633
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                args.output_dir
            )
        )
634
635
636
637
638

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
639

640
641
642
643
644
645
646
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
647
        args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
648
649
650
    else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
651
        torch.distributed.init_process_group(backend="nccl")
652
653
654
655
        args.n_gpu = 1
    args.device = device

    # Setup logging
656
657
658
659
660
661
662
663
664
665
666
667
668
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank,
        device,
        args.n_gpu,
        bool(args.local_rank != -1),
        args.fp16,
    )
669
670
671
672
673
674

    # Set seed
    set_seed(args)

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
675
676
        torch.distributed.barrier()  # Barrier to make sure only the first process in distributed training download model & vocab

677
    if args.config_name:
678
        config = AutoConfig.from_pretrained(args.config_name, cache_dir=args.cache_dir)
679
    elif args.model_name_or_path:
680
        config = AutoConfig.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
681
    else:
682
683
684
685
686
687
        # When we release a pip version exposing CONFIG_MAPPING,
        # we can do `config = CONFIG_MAPPING[args.model_type]()`.
        raise ValueError(
            "You are instantiating a new config instance from scratch. This is not supported, but you can do it from another script, save it,"
            "and load it from here, using --config_name"
        )
688
689

    if args.tokenizer_name:
690
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir)
691
    elif args.model_name_or_path:
692
        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
693
    else:
694
        raise ValueError(
695
696
            "You are instantiating a new tokenizer from scratch. This is not supported, but you can do it from another script, save it,"
            "and load it from here, using --tokenizer_name"
697
698
        )

699
    if args.block_size <= 0:
700
        args.block_size = tokenizer.max_len
701
702
        # Our input block size will be the max possible for the model
    else:
703
        args.block_size = min(args.block_size, tokenizer.max_len)
704
705

    if args.model_name_or_path:
706
        model = AutoModelWithLMHead.from_pretrained(
707
708
709
710
711
712
713
            args.model_name_or_path,
            from_tf=bool(".ckpt" in args.model_name_or_path),
            config=config,
            cache_dir=args.cache_dir,
        )
    else:
        logger.info("Training new model from scratch")
714
        model = AutoModelWithLMHead.from_config(config)
715

716
    model.to(args.device)
717
718

    if args.local_rank == 0:
719
        torch.distributed.barrier()  # End of barrier to make sure only the first process in distributed training download model & vocab
720
721
722
723
724

    logger.info("Training/evaluation parameters %s", args)

    # Training
    if args.do_train:
725
726
727
        if args.local_rank not in [-1, 0]:
            torch.distributed.barrier()  # Barrier to make sure only the first process in distributed training process the dataset, and the others will use the cache

728
        train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False)
729
730
731
732

        if args.local_rank == 0:
            torch.distributed.barrier()

733
734
735
        global_step, tr_loss = train(args, train_dataset, model, tokenizer)
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

736
    # Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
737
738
    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        # Create output directory if needed
739
740
        if args.local_rank in [-1, 0]:
            os.makedirs(args.output_dir, exist_ok=True)
741
742
743
744

        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
745
746
747
        model_to_save = (
            model.module if hasattr(model, "module") else model
        )  # Take care of distributed/parallel training
748
749
750
751
        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

        # Good practice: save your training arguments together with the trained model
752
        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
753
754

        # Load a trained model and vocabulary that you have fine-tuned
755
756
        model = AutoModelWithLMHead.from_pretrained(args.output_dir)
        tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
757
758
759
760
761
762
763
        model.to(args.device)

    # Evaluation
    results = {}
    if args.do_eval and args.local_rank in [-1, 0]:
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints:
764
765
766
            checkpoints = list(
                os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
            )
767
            logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce logging
768
769
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
770
771
772
            global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
            prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""

773
            model = AutoModelWithLMHead.from_pretrained(checkpoint)
774
            model.to(args.device)
775
            result = evaluate(args, model, tokenizer, prefix=prefix)
776
            result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
777
778
779
780
781
782
            results.update(result)

    return results


if __name__ == "__main__":
    # Run the training/evaluation entry point only when this file is executed
    # as a script; importing the module must not trigger a run.
    main()