run_squad.py 32.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
16
"""Fine-tune the library models for question answering on SQuAD (BERT, RoBERTa, XLNet, XLM, DistilBERT, ALBERT)."""
17
18
19


import argparse
Aymeric Augustin's avatar
Aymeric Augustin committed
20
import glob
21
22
23
import logging
import os
import random
24
import timeit
Aymeric Augustin's avatar
Aymeric Augustin committed
25

26
27
import numpy as np
import torch
28
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
29
from torch.utils.data.distributed import DistributedSampler
30
from tqdm import tqdm, trange
31

32
33
from transformers import (
    WEIGHTS_NAME,
Aymeric Augustin's avatar
Aymeric Augustin committed
34
35
36
37
    AdamW,
    AlbertConfig,
    AlbertForQuestionAnswering,
    AlbertTokenizer,
38
39
40
    BertConfig,
    BertForQuestionAnswering,
    BertTokenizer,
Aymeric Augustin's avatar
Aymeric Augustin committed
41
42
43
44
    DistilBertConfig,
    DistilBertForQuestionAnswering,
    DistilBertTokenizer,
    RobertaConfig,
45
46
47
48
49
50
51
52
    RobertaForQuestionAnswering,
    RobertaTokenizer,
    XLMConfig,
    XLMForQuestionAnswering,
    XLMTokenizer,
    XLNetConfig,
    XLNetForQuestionAnswering,
    XLNetTokenizer,
Aymeric Augustin's avatar
Aymeric Augustin committed
53
54
    get_linear_schedule_with_warmup,
    squad_convert_examples_to_features,
55
)
Aymeric Augustin's avatar
Aymeric Augustin committed
56
57
58
59
60
61
62
63
64
65
from transformers.data.metrics.squad_metrics import (
    compute_predictions_log_probs,
    compute_predictions_logits,
    squad_evaluate,
)
from transformers.data.processors.squad import SquadResult, SquadV1Processor, SquadV2Processor


try:
    from torch.utils.tensorboard import SummaryWriter
66
except ImportError:
Aymeric Augustin's avatar
Aymeric Augustin committed
67
    from tensorboardX import SummaryWriter
thomwolf's avatar
thomwolf committed
68

69
70
71

logger = logging.getLogger(__name__)

# Shortcut names of the pretrained configs, shown in the --model_name_or_path help text.
# NOTE(review): DistilBERT and ALBERT shortcut names are not collected here even though both
# model types are accepted through MODEL_CLASSES below.
ALL_MODELS = tuple(
    shortcut_name
    for config_class in (BertConfig, RobertaConfig, XLNetConfig, XLMConfig)
    for shortcut_name in config_class.pretrained_config_archive_map.keys()
)

# --model_type value -> (config class, question-answering model class, tokenizer class)
MODEL_CLASSES = {
    "bert": (BertConfig, BertForQuestionAnswering, BertTokenizer),
    "roberta": (RobertaConfig, RobertaForQuestionAnswering, RobertaTokenizer),
    "xlnet": (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
    "xlm": (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
    "distilbert": (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer),
    "albert": (AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer),
}

86

thomwolf's avatar
thomwolf committed
87
88
89
90
91
92
93
def set_seed(args):
    """Seed Python's, NumPy's and PyTorch's RNGs with ``args.seed``.

    All visible CUDA devices are seeded as well when ``args.n_gpu > 0``.
    """
    seed = args.seed
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)

94

95
96
def to_list(tensor):
    """Return ``tensor`` as plain Python values, detached from autograd and moved to the CPU."""
    host_tensor = tensor.detach().cpu()
    return host_tensor.tolist()
thomwolf's avatar
thomwolf committed
97

98

99
def train(args, train_dataset, model, tokenizer):
    """Fine-tune ``model`` on ``train_dataset``.

    This builds the dataloader, the AdamW optimizer and the linear warmup/decay schedule. It then
    wraps the model for apex fp16 / DataParallel / DistributedDataParallel as requested. When
    ``args.model_name_or_path`` holds saved optimizer/scheduler states, or is named like
    ``checkpoint-<step>``, it resumes from them. Finally it runs the training loop with periodic
    TensorBoard logging, optional evaluation and checkpoint saving.

    Args:
        args: parsed command-line namespace. ``args.train_batch_size`` is written back onto it,
            and so is ``args.num_train_epochs`` when ``args.max_steps > 0``.
        train_dataset: dataset whose batches are read as (input_ids, attention_mask,
            token_type_ids, start_positions, end_positions, cls_index, p_mask, is_impossible).
            cls_index and p_mask are only read for xlnet/xlm; is_impossible only additionally
            with ``--version_2_with_negative``.
        model: the question-answering model to train.
        tokenizer: saved with every checkpoint and passed to ``evaluate``.

    Returns:
        ``(global_step, tr_loss / global_step)``.
        NOTE(review): ``global_step`` starts at 1, so the average is taken over one more step
        than the number of optimizer updates. When resuming, ``global_step`` also includes
        pre-resume steps that ``tr_loss`` does not.
    """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    # Optimizer updates per pass over the dataloader. The max(1, ...) variant is only used as a
    # divisor, so a dataset with fewer batches than accumulation steps cannot raise ZeroDivisionError.
    updates_per_epoch = len(train_dataloader) // args.gradient_accumulation_steps
    safe_updates_per_epoch = max(1, updates_per_epoch)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // safe_updates_per_epoch + 1
    else:
        t_total = updates_per_epoch * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay). Biases and LayerNorm weights are
    # excluded from weight decay.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )

    # Check if saved optimizer or scheduler states exist, and load them if both do.
    optimizer_path = os.path.join(args.model_name_or_path, "optimizer.pt")
    scheduler_path = os.path.join(args.model_name_or_path, "scheduler.pt")
    if os.path.isfile(optimizer_path) and os.path.isfile(scheduler_path):
        optimizer.load_state_dict(torch.load(optimizer_path))
        scheduler.load_state_dict(torch.load(scheduler_path))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")

        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 1
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint directory named like ".../checkpoint-<step>".
    if os.path.exists(args.model_name_or_path):
        checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
        try:
            # set global_step to global_step of last saved checkpoint from model path
            global_step = int(checkpoint_suffix)
        except ValueError:
            # The path exists but is not a checkpoint directory (e.g. a local pretrained model),
            # so there is no step to resume from: start fine-tuning instead of crashing.
            logger.info("  Starting fine-tuning.")
        else:
            epochs_trained = global_step // safe_updates_per_epoch
            steps_trained_in_current_epoch = global_step % safe_updates_per_epoch

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d", global_step)
            logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)

    # The resume counter above is in optimizer updates, but the loop below iterates over batches,
    # and every update consumes gradient_accumulation_steps batches.
    batches_to_skip = steps_trained_in_current_epoch * args.gradient_accumulation_steps

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
    )
    # Added here for reproducibility (even between python 2 and 3)
    set_seed(args)

    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained batches if resuming training
            if batches_to_skip > 0:
                batches_to_skip -= 1
                continue

            model.train()
            batch = tuple(t.to(args.device) for t in batch)

            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "token_type_ids": None if args.model_type in ["xlm", "roberta", "distilbert"] else batch[2],
                "start_positions": batch[3],
                "end_positions": batch[4],
            }

            if args.model_type in ["xlnet", "xlm"]:
                inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
                if args.version_2_with_negative:
                    inputs.update({"is_impossible": batch[7]})
            outputs = model(**inputs)
            # model outputs are always tuple in transformers (see doc)
            loss = outputs[0]

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel (not distributed) training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                # Log metrics
                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Only evaluate when single GPU otherwise metrics may not average well
                    if args.local_rank == -1 and args.evaluate_during_training:
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value, global_step)
                    tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
                    logging_loss = tr_loss

                # Save model checkpoint
                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    # Take care of distributed/parallel training
                    model_to_save = model.module if hasattr(model, "module") else model
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step


def evaluate(args, model, tokenizer, prefix=""):
    """Run ``model`` over the SQuAD dev set, write prediction files and return the SQuAD scores.

    Predictions and n-best predictions are written to ``args.output_dir``. With
    ``--version_2_with_negative``, null-odds are written there too. ``prefix`` is embedded in
    each file name.

    Args:
        args: parsed command-line namespace. ``args.eval_batch_size`` is written back onto it.
        model: the question-answering model. It is wrapped in ``DataParallel`` here when
            ``args.n_gpu > 1`` and it is not already wrapped.
        tokenizer: used to build the features and to post-process the predictions.
        prefix: string inserted into the output file names and the log header.

    Returns:
        The result of ``squad_evaluate(examples, predictions)``.
    """
    dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True)

    if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
        os.makedirs(args.output_dir)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)

    # Note that DistributedSampler samples randomly
    eval_sampler = SequentialSampler(dataset)
    eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

    # multi-gpu evaluate
    if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Num examples = %d", len(dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)

    all_results = []
    start_time = timeit.default_timer()

    # Switch to eval mode once: nothing inside the loop switches the model back to train mode.
    model.eval()
    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        batch = tuple(t.to(args.device) for t in batch)

        with torch.no_grad():
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "token_type_ids": None if args.model_type in ["xlm", "roberta", "distilbert"] else batch[2],
            }
            # Position of each feature in `features`, used to map outputs back to unique ids.
            example_indices = batch[3]

            # XLNet and XLM use more arguments for their predictions
            if args.model_type in ["xlnet", "xlm"]:
                inputs.update({"cls_index": batch[4], "p_mask": batch[5]})

            outputs = model(**inputs)

        for i, example_index in enumerate(example_indices):
            eval_feature = features[example_index.item()]
            unique_id = int(eval_feature.unique_id)

            # Row i of every model output, converted to Python lists.
            output = [to_list(model_output[i]) for model_output in outputs]

            # Some models (XLNet, XLM) use 5 arguments for their predictions, while the other "simpler"
            # models only use two.
            if len(output) >= 5:
                start_logits = output[0]
                start_top_index = output[1]
                end_logits = output[2]
                end_top_index = output[3]
                cls_logits = output[4]

                result = SquadResult(
                    unique_id,
                    start_logits,
                    end_logits,
                    start_top_index=start_top_index,
                    end_top_index=end_top_index,
                    cls_logits=cls_logits,
                )

            else:
                start_logits, end_logits = output
                result = SquadResult(unique_id, start_logits, end_logits)

            all_results.append(result)

    eval_time = timeit.default_timer() - start_time
    logger.info("  Evaluation done in total %f secs (%f sec per example)", eval_time, eval_time / len(dataset))

    # Compute predictions
    output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix))
    output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix))

    if args.version_2_with_negative:
        output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix))
    else:
        output_null_log_odds_file = None

    # XLNet and XLM use a more complex post-processing procedure
    if args.model_type in ["xlnet", "xlm"]:
        # A DataParallel wrapper has no `config`; read it from the wrapped module instead.
        start_n_top = model.config.start_n_top if hasattr(model, "config") else model.module.config.start_n_top
        end_n_top = model.config.end_n_top if hasattr(model, "config") else model.module.config.end_n_top

        predictions = compute_predictions_log_probs(
            examples,
            features,
            all_results,
            args.n_best_size,
            args.max_answer_length,
            output_prediction_file,
            output_nbest_file,
            output_null_log_odds_file,
            start_n_top,
            end_n_top,
            args.version_2_with_negative,
            tokenizer,
            args.verbose_logging,
        )
    else:
        predictions = compute_predictions_logits(
            examples,
            features,
            all_results,
            args.n_best_size,
            args.max_answer_length,
            args.do_lower_case,
            output_prediction_file,
            output_nbest_file,
            output_null_log_odds_file,
            args.verbose_logging,
            args.version_2_with_negative,
            args.null_score_diff_threshold,
            tokenizer,
        )

    # Compute the F1 and exact scores.
    results = squad_evaluate(examples, predictions)
    return results

408

409
def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
    """Build, or load from the on-disk cache, the SQuAD features and tensor dataset.

    Features are cached in ``args.data_dir`` (or ``"."``) under
    ``cached_{dev|train}_{last path component of model_name_or_path}_{max_seq_length}``. Only
    features and dataset are saved, not the raw examples. For that reason the cache is not read
    when ``output_examples`` is True.

    During distributed training (``evaluate=False``), ranks other than 0 wait at a barrier while
    rank 0 builds and saves the features. Rank 0 then releases them with a second barrier.

    Args:
        args: parsed command-line namespace.
        tokenizer: tokenizer used to convert examples into features.
        evaluate: build the dev set instead of the training set.
        output_examples: also return the raw examples and the features.

    Returns:
        ``dataset``, or ``(dataset, examples, features)`` when ``output_examples`` is True.

    Raises:
        ImportError: no data file is given and ``tensorflow_datasets`` is not installed.
    """
    if args.local_rank not in [-1, 0] and not evaluate:
        # Make sure only the first process in distributed training process the dataset, and the others will use the cache
        torch.distributed.barrier()

    # Load data features from cache or dataset file
    input_dir = args.data_dir if args.data_dir else "."
    cached_features_file = os.path.join(
        input_dir,
        "cached_{}_{}_{}".format(
            "dev" if evaluate else "train",
            list(filter(None, args.model_name_or_path.split("/"))).pop(),
            str(args.max_seq_length),
        ),
    )

    # Init features and dataset from cache if it exists
    if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
        logger.info("Loading features from cached file %s", cached_features_file)
        # NOTE(security): torch.load unpickles this file; only point data_dir at trusted caches.
        features_and_dataset = torch.load(cached_features_file)
        features, dataset = features_and_dataset["features"], features_and_dataset["dataset"]
    else:
        logger.info("Creating features from dataset file at %s", input_dir)

        if not args.data_dir and ((evaluate and not args.predict_file) or (not evaluate and not args.train_file)):
            try:
                import tensorflow_datasets as tfds
            except ImportError as err:
                raise ImportError("If not data_dir is specified, tensorflow_datasets needs to be installed.") from err

            if args.version_2_with_negative:
                logger.warning("tensorflow_datasets does not handle version 2 of SQuAD.")

            tfds_examples = tfds.load("squad")
            examples = SquadV1Processor().get_examples_from_dataset(tfds_examples, evaluate=evaluate)
        else:
            processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
            if evaluate:
                examples = processor.get_dev_examples(args.data_dir, filename=args.predict_file)
            else:
                examples = processor.get_train_examples(args.data_dir, filename=args.train_file)

        features, dataset = squad_convert_examples_to_features(
            examples=examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=not evaluate,
            return_dataset="pt",
            threads=args.threads,
        )

        if args.local_rank in [-1, 0]:
            logger.info("Saving features into cached file %s", cached_features_file)
            torch.save({"features": features, "dataset": dataset}, cached_features_file)

    if args.local_rank == 0 and not evaluate:
        # Make sure only the first process in distributed training process the dataset, and the others will use the cache
        torch.distributed.barrier()

    if output_examples:
        return dataset, examples, features
    return dataset

474
475
476
477

def main():
    parser = argparse.ArgumentParser()

478
    # Required parameters
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        required=True,
        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model checkpoints and predictions will be written.",
    )
500

501
    # Other parameters
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        help="The input data dir. Should contain the .json files for the task."
        + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
    )
    parser.add_argument(
        "--train_file",
        default=None,
        type=str,
        help="The input training file. If a data dir is specified, will look for the file there"
        + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
    )
    parser.add_argument(
        "--predict_file",
        default=None,
        type=str,
        help="The input evaluation file. If a data dir is specified, will look for the file there"
        + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
    )
    parser.add_argument(
        "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
    )
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help="Where do you want to store the pre-trained models downloaded from s3",
    )

    parser.add_argument(
        "--version_2_with_negative",
        action="store_true",
        help="If true, the SQuAD examples contain some that do not have an answer.",
    )
    parser.add_argument(
        "--null_score_diff_threshold",
        type=float,
        default=0.0,
        help="If null_score - best_non_null is greater than the threshold predict null.",
    )

    parser.add_argument(
        "--max_seq_length",
        default=384,
        type=int,
        help="The maximum total input sequence length after WordPiece tokenization. Sequences "
        "longer than this will be truncated, and sequences shorter than this will be padded.",
    )
    parser.add_argument(
        "--doc_stride",
        default=128,
        type=int,
        help="When splitting up a long document into chunks, how much stride to take between chunks.",
    )
    parser.add_argument(
        "--max_query_length",
        default=64,
        type=int,
        help="The maximum number of tokens for the question. Questions longer than this will "
        "be truncated to this length.",
    )
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--evaluate_during_training", action="store_true", help="Rul evaluation during training at each logging step."
    )
    parser.add_argument(
        "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
    )

    parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
    parser.add_argument(
        "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
    )
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
    )
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
    )
    parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
    parser.add_argument(
        "--n_best_size",
        default=20,
        type=int,
        help="The total number of n-best predictions to generate in the nbest_predictions.json output file.",
    )
    parser.add_argument(
        "--max_answer_length",
        default=30,
        type=int,
        help="The maximum length of an answer that can be generated. This is needed because the start "
        "and end predictions are not conditioned on one another.",
    )
    parser.add_argument(
        "--verbose_logging",
        action="store_true",
        help="If true, all of the warnings related to data processing will be printed. "
        "A number of warnings are expected for a normal SQuAD evaluation.",
    )

    parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
    parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
    parser.add_argument(
        "--eval_all_checkpoints",
        action="store_true",
        help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
    )
    parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")
    parser.add_argument(
        "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
    )
    parser.add_argument(
        "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
    )
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")

    parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
    parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")

    parser.add_argument("--threads", type=int, default=1, help="multiple threads for converting example to features")
657
658
    args = parser.parse_args()

659
660
661
662
663
664
    if (
        os.path.exists(args.output_dir)
        and os.listdir(args.output_dir)
        and args.do_train
        and not args.overwrite_output_dir
    ):
665
        raise ValueError(
666
667
668
669
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                args.output_dir
            )
        )
thomwolf's avatar
thomwolf committed
670

671
    # Setup distant debugging if needed
672
673
674
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
675

676
        print("Waiting for debugger attach")
677
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
678
679
        ptvsd.wait_for_attach()

thomwolf's avatar
thomwolf committed
680
    # Setup CUDA, GPU & distributed training
681
    if args.local_rank == -1 or args.no_cuda:
682
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
thomwolf's avatar
thomwolf committed
683
684
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
685
686
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
687
        torch.distributed.init_process_group(backend="nccl")
thomwolf's avatar
thomwolf committed
688
689
        args.n_gpu = 1
    args.device = device
690

thomwolf's avatar
thomwolf committed
691
    # Setup logging
692
693
694
695
696
697
698
699
700
701
702
703
704
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank,
        device,
        args.n_gpu,
        bool(args.local_rank != -1),
        args.fp16,
    )
705

706
707
    # Set seed
    set_seed(args)
708

thomwolf's avatar
thomwolf committed
709
    # Load pretrained model and tokenizer
710
    if args.local_rank not in [-1, 0]:
711
712
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()
713

714
    args.model_type = args.model_type.lower()
715
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
    config = config_class.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
    model = model_class.from_pretrained(
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
731
732

    if args.local_rank == 0:
733
734
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()
735

thomwolf's avatar
thomwolf committed
736
    model.to(args.device)
737

738
739
    logger.info("Training/evaluation parameters %s", args)

Simon Layton's avatar
Simon Layton committed
740
741
742
743
744
745
    # Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is set.
    # Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` will
    # remove the need for this code, but it is still valid.
    if args.fp16:
        try:
            import apex
746
747

            apex.amp.register_half_function(torch, "einsum")
Simon Layton's avatar
Simon Layton committed
748
        except ImportError:
749
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
Simon Layton's avatar
Simon Layton committed
750

thomwolf's avatar
thomwolf committed
751
    # Training
752
    if args.do_train:
753
        train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False)
754
        global_step, tr_loss = train(args, train_dataset, model, tokenizer)
755
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
756

thomwolf's avatar
thomwolf committed
757
    # Save the trained model and the tokenizer
Peng Qi's avatar
Peng Qi committed
758
    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
759
760
761
762
763
764
765
        # Create output directory if needed
        if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(args.output_dir)

        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
766
        # Take care of distributed/parallel training
767
        model_to_save = model.module if hasattr(model, "module") else model
768
769
        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)
770
771

        # Good practice: save your training arguments together with the trained model
772
        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
773

774
        # Load a trained model and vocabulary that you have fine-tuned
775
776
        model = model_class.from_pretrained(args.output_dir, force_download=True)
        tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
777
778
        model.to(args.device)

thomwolf's avatar
thomwolf committed
779
    # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
780
781
    results = {}
    if args.do_eval and args.local_rank in [-1, 0]:
782
783
784
785
        if args.do_train:
            logger.info("Loading checkpoints saved during training for evaluation")
            checkpoints = [args.output_dir]
            if args.eval_all_checkpoints:
786
787
788
789
                checkpoints = list(
                    os.path.dirname(c)
                    for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
                )
790
791
792
793
                logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce model loading logs
        else:
            logger.info("Loading checkpoint %s for evaluation", args.model_name_or_path)
            checkpoints = [args.model_name_or_path]
thomwolf's avatar
thomwolf committed
794

795
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
thomwolf's avatar
thomwolf committed
796

797
        for checkpoint in checkpoints:
thomwolf's avatar
thomwolf committed
798
            # Reload the model
799
800
            global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
            model = model_class.from_pretrained(checkpoint, force_download=True)
801
            model.to(args.device)
thomwolf's avatar
thomwolf committed
802
803

            # Evaluate
804
            result = evaluate(args, model, tokenizer, prefix=global_step)
thomwolf's avatar
thomwolf committed
805

806
            result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
807
            results.update(result)
thomwolf's avatar
thomwolf committed
808

809
    logger.info("Results: {}".format(results))
thomwolf's avatar
thomwolf committed
810

811
    return results
812
813
814
815


if __name__ == "__main__":
    # Entry point when the file is executed as a script: parse the command line
    # and run the training / evaluation flow in `main()`. Importing this module
    # does not trigger it.
    main()