run_squad.py 33.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
16
""" Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet)."""
17
18
19


import argparse
Aymeric Augustin's avatar
Aymeric Augustin committed
20
import glob
21
22
23
import logging
import os
import random
24
import timeit
Aymeric Augustin's avatar
Aymeric Augustin committed
25

26
27
import numpy as np
import torch
28
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
29
from torch.utils.data.distributed import DistributedSampler
30
from tqdm import tqdm, trange
31

32
from transformers import (
33
    MODEL_FOR_QUESTION_ANSWERING_MAPPING,
34
    WEIGHTS_NAME,
Aymeric Augustin's avatar
Aymeric Augustin committed
35
    AdamW,
36
37
38
    AutoConfig,
    AutoModelForQuestionAnswering,
    AutoTokenizer,
Aymeric Augustin's avatar
Aymeric Augustin committed
39
40
    get_linear_schedule_with_warmup,
    squad_convert_examples_to_features,
41
)
Aymeric Augustin's avatar
Aymeric Augustin committed
42
43
44
45
46
47
48
49
50
51
from transformers.data.metrics.squad_metrics import (
    compute_predictions_log_probs,
    compute_predictions_logits,
    squad_evaluate,
)
from transformers.data.processors.squad import SquadResult, SquadV1Processor, SquadV2Processor


try:
    from torch.utils.tensorboard import SummaryWriter
52
except ImportError:
Aymeric Augustin's avatar
Aymeric Augustin committed
53
    from tensorboardX import SummaryWriter
thomwolf's avatar
thomwolf committed
54

55
56
57

logger = logging.getLogger(__name__)

58
59
# Keys of the question-answering model mapping; each key exposes `model_type` and
# `pretrained_config_archive_map` (presumably these are the config classes - confirm in transformers).
MODEL_CONFIG_CLASSES = list(MODEL_FOR_QUESTION_ANSWERING_MAPPING)
# Model-type identifiers, listed in the --model_type help text.
MODEL_TYPES = tuple(config_class.model_type for config_class in MODEL_CONFIG_CLASSES)

# Flat tuple of every name found in the archive maps, listed in the --model_name_or_path help text.
ALL_MODELS = tuple(
    shortcut_name
    for config_class in MODEL_CONFIG_CLASSES
    for shortcut_name in config_class.pretrained_config_archive_map.keys()
)
thomwolf's avatar
thomwolf committed
62

63

thomwolf's avatar
thomwolf committed
64
65
66
67
68
69
70
def set_seed(args):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        args: Namespace providing ``seed`` (the value given to every generator) and ``n_gpu``
            (the CUDA generators are only seeded when it is positive).
    """
    seed = args.seed
    # Same order as always: Python's `random`, then NumPy, then the PyTorch CPU generator.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)

71

72
73
def to_list(tensor):
    """Return the content of ``tensor`` as (nested) Python lists, detached and copied to the CPU."""
    host_tensor = tensor.detach().cpu()
    return host_tensor.tolist()
thomwolf's avatar
thomwolf committed
74

75

76
def train(args, train_dataset, model, tokenizer):
    """Fine-tune ``model`` on ``train_dataset``.

    Builds the data loader, the AdamW optimizer and a linear warmup/decay schedule, optionally
    restores optimizer/scheduler state and the step counters from a ``checkpoint-<step>``
    directory given as ``args.model_name_or_path``, then runs the training loop with gradient
    accumulation, gradient clipping, TensorBoard logging and periodic checkpointing.

    Args:
        args: Run configuration namespace. ``args.train_batch_size`` is set here, and
            ``args.num_train_epochs`` is overwritten when ``args.max_steps > 0``.
        train_dataset: Dataset wrapped in a ``DataLoader``. Batches are indexed positionally:
            0 input ids, 1 attention mask, 2 token type ids, 3 start positions, 4 end positions,
            and for XLNet/XLM also 5 cls index, 6 p mask and 7 is-impossible.
        model: Model to train; called as ``model(**inputs)`` and expected to return the loss as
            the first output.
        tokenizer: Saved next to every checkpoint and passed on to ``evaluate``.

    Returns:
        tuple: ``(global_step, tr_loss / global_step)``. ``global_step`` starts at 1, so it is
        one more than the number of optimizer updates performed in a fresh run.

    Raises:
        ImportError: If ``args.fp16`` is set and apex is not installed.
    """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    # Optimizer updates per pass over the data. This is 0 when the loader yields fewer batches
    # than gradient_accumulation_steps, so every division below uses the clamped value.
    updates_per_epoch = len(train_dataloader) // args.gradient_accumulation_steps
    safe_updates_per_epoch = max(1, updates_per_epoch)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // safe_updates_per_epoch + 1
    else:
        t_total = updates_per_epoch * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )

    # Check if saved optimizer or scheduler states exist
    optimizer_state_path = os.path.join(args.model_name_or_path, "optimizer.pt")
    scheduler_state_path = os.path.join(args.model_name_or_path, "scheduler.pt")
    if os.path.isfile(optimizer_state_path) and os.path.isfile(scheduler_state_path):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(torch.load(optimizer_state_path))
        scheduler.load_state_dict(torch.load(scheduler_state_path))

    if args.fp16:
        try:
            from apex import amp
        except ImportError as err:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            ) from err

        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 1
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        try:
            # set global_step to global_step of last saved checkpoint from model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // safe_updates_per_epoch
            # The skip counter below is decremented once per *batch*, while global_step counts
            # optimizer updates, so convert updates into batches before using it.
            steps_trained_in_current_epoch = (
                global_step % safe_updates_per_epoch
            ) * args.gradient_accumulation_steps

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d", global_step)
            logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
        except ValueError:
            # The path does not end in "-<number>": treat it as a plain pretrained model.
            logger.info("  Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
    )
    # Added here for reproducibility
    set_seed(args)

    for epoch in train_iterator:
        # DistributedSampler derives its shuffle from the epoch number; without this call every
        # epoch would iterate the data in the same order.
        if isinstance(train_sampler, DistributedSampler):
            train_sampler.set_epoch(epoch)

        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            batch = tuple(t.to(args.device) for t in batch)

            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "token_type_ids": batch[2],
                "start_positions": batch[3],
                "end_positions": batch[4],
            }

            # These model types are called without token type ids.
            if args.model_type in ["xlm", "roberta", "distilbert", "camembert"]:
                del inputs["token_type_ids"]

            if args.model_type in ["xlnet", "xlm"]:
                inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
                if args.version_2_with_negative:
                    inputs.update({"is_impossible": batch[7]})
                # Models whose config has `lang2id` also receive a language id for every token.
                if hasattr(model, "config") and hasattr(model.config, "lang2id"):
                    inputs.update(
                        {"langs": (torch.ones(batch[0].shape, dtype=torch.int64) * args.lang_id).to(args.device)}
                    )

            outputs = model(**inputs)
            # model outputs are always tuple in transformers (see doc)
            loss = outputs[0]

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel (not distributed) training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                # Log metrics
                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Only evaluate when single GPU otherwise metrics may not average well
                    if args.local_rank == -1 and args.evaluate_during_training:
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value, global_step)
                    tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
                    logging_loss = tr_loss

                # Save model checkpoint
                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    # Take care of distributed/parallel training
                    model_to_save = model.module if hasattr(model, "module") else model
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step


def evaluate(args, model, tokenizer, prefix=""):
    """Run the model over the evaluation set and score its answers.

    Args:
        args: Run configuration namespace; ``args.eval_batch_size`` is set here.
        model: Model to evaluate; wrapped in ``DataParallel`` when several GPUs are in use and
            it is not wrapped already.
        tokenizer: Used to load the evaluation features and passed to the prediction helpers.
        prefix: Label inserted into the log line and into the names of the prediction files.

    Returns:
        Whatever ``squad_evaluate`` returns for the computed predictions.
    """
    dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True)

    # Only the main process (or a non-distributed run) creates the output directory.
    if args.local_rank in [-1, 0] and not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)

    # Sequential order on purpose: DistributedSampler would sample randomly.
    eval_dataloader = DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=args.eval_batch_size)

    # multi-gpu evaluate
    needs_data_parallel = args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel)
    if needs_data_parallel:
        model = torch.nn.DataParallel(model)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Num examples = %d", len(dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)

    all_results = []
    start_time = timeit.default_timer()

    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        model.eval()
        batch = tuple(t.to(args.device) for t in batch)

        with torch.no_grad():
            inputs = {"input_ids": batch[0], "attention_mask": batch[1]}
            # These model types are called without token type ids.
            if args.model_type not in ["xlm", "roberta", "distilbert", "camembert"]:
                inputs["token_type_ids"] = batch[2]

            feature_indices = batch[3]

            # XLNet and XLM use more arguments for their predictions
            if args.model_type in ["xlnet", "xlm"]:
                inputs["cls_index"] = batch[4]
                inputs["p_mask"] = batch[5]
                # for lang_id-sensitive xlm models
                if hasattr(model, "config") and hasattr(model.config, "lang2id"):
                    inputs["langs"] = (torch.ones(batch[0].shape, dtype=torch.int64) * args.lang_id).to(
                        args.device
                    )

            outputs = model(**inputs)

        # `row` is the position inside this batch; `feature_index` is looked up in `features`.
        for row, feature_index in enumerate(feature_indices):
            unique_id = int(features[feature_index.item()].unique_id)
            per_example = [to_list(batch_output[row]) for batch_output in outputs]

            # Some models (XLNet, XLM) use 5 arguments for their predictions, while the other "simpler"
            # models only use two.
            if len(per_example) >= 5:
                result = SquadResult(
                    unique_id,
                    per_example[0],
                    per_example[2],
                    start_top_index=per_example[1],
                    end_top_index=per_example[3],
                    cls_logits=per_example[4],
                )
            else:
                start_logits, end_logits = per_example
                result = SquadResult(unique_id, start_logits, end_logits)

            all_results.append(result)

    elapsed = timeit.default_timer() - start_time
    logger.info("  Evaluation done in total %f secs (%f sec per example)", elapsed, elapsed / len(dataset))

    # Compute predictions
    output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix))
    output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix))
    output_null_log_odds_file = (
        os.path.join(args.output_dir, "null_odds_{}.json".format(prefix)) if args.version_2_with_negative else None
    )

    # XLNet and XLM use a more complex post-processing procedure
    if args.model_type in ["xlnet", "xlm"]:
        # A wrapped model (e.g. DataParallel) exposes the config through `.module`.
        model_config = model.config if hasattr(model, "config") else model.module.config

        predictions = compute_predictions_log_probs(
            examples,
            features,
            all_results,
            args.n_best_size,
            args.max_answer_length,
            output_prediction_file,
            output_nbest_file,
            output_null_log_odds_file,
            model_config.start_n_top,
            model_config.end_n_top,
            args.version_2_with_negative,
            tokenizer,
            args.verbose_logging,
        )
    else:
        predictions = compute_predictions_logits(
            examples,
            features,
            all_results,
            args.n_best_size,
            args.max_answer_length,
            args.do_lower_case,
            output_prediction_file,
            output_nbest_file,
            output_null_log_odds_file,
            args.verbose_logging,
            args.version_2_with_negative,
            args.null_score_diff_threshold,
            tokenizer,
        )

    # Compute the F1 and exact scores.
    return squad_evaluate(examples, predictions)

407

408
def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
    """Load SQuAD features, from the on-disk cache when possible, otherwise by building them.

    The cache file lives in ``args.data_dir`` (or ``.``) and is named from the split, the last
    path component of ``args.model_name_or_path`` and ``args.max_seq_length``.

    Args:
        args: Run configuration namespace.
        tokenizer: Tokenizer passed to ``squad_convert_examples_to_features``.
        evaluate: Load the dev split when true, the train split otherwise.
        output_examples: Also return the examples and features when true.

    Returns:
        ``dataset`` alone, or ``(dataset, examples, features)`` when ``output_examples`` is true.

    Raises:
        ImportError: If no data location is given and ``tensorflow_datasets`` is not installed.
    """
    if args.local_rank not in [-1, 0] and not evaluate:
        # Make sure only the first process in distributed training process the dataset, and the others will use the cache
        torch.distributed.barrier()

    # Load data features from cache or dataset file
    input_dir = args.data_dir if args.data_dir else "."
    cached_features_file = os.path.join(
        input_dir,
        "cached_{}_{}_{}".format(
            "dev" if evaluate else "train",
            list(filter(None, args.model_name_or_path.split("/"))).pop(),
            str(args.max_seq_length),
        ),
    )

    # Init features and dataset from cache if it exists
    if os.path.exists(cached_features_file) and not args.overwrite_cache:
        logger.info("Loading features from cached file %s", cached_features_file)
        # NOTE(security): torch.load unpickles the file; only point data_dir at trusted caches.
        features_and_dataset = torch.load(cached_features_file)
        features, dataset, examples = (
            features_and_dataset["features"],
            features_and_dataset["dataset"],
            features_and_dataset["examples"],
        )
    else:
        logger.info("Creating features from dataset file at %s", input_dir)

        if not args.data_dir and ((evaluate and not args.predict_file) or (not evaluate and not args.train_file)):
            # No local data was given for this split: fall back to tensorflow_datasets.
            try:
                import tensorflow_datasets as tfds
            except ImportError as err:
                raise ImportError(
                    "If not data_dir is specified, tensorflow_datasets needs to be installed."
                ) from err

            if args.version_2_with_negative:
                # Logger.warn is a deprecated alias; use warning() like the rest of the file.
                logger.warning("tensorflow_datasets does not handle version 2 of SQuAD.")

            tfds_examples = tfds.load("squad")
            examples = SquadV1Processor().get_examples_from_dataset(tfds_examples, evaluate=evaluate)
        else:
            processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
            if evaluate:
                examples = processor.get_dev_examples(args.data_dir, filename=args.predict_file)
            else:
                examples = processor.get_train_examples(args.data_dir, filename=args.train_file)

        features, dataset = squad_convert_examples_to_features(
            examples=examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=not evaluate,
            return_dataset="pt",
            threads=args.threads,
        )

        # Only the main process (or a non-distributed run) writes the cache.
        if args.local_rank in [-1, 0]:
            logger.info("Saving features into cached file %s", cached_features_file)
            torch.save({"features": features, "dataset": dataset, "examples": examples}, cached_features_file)

    if args.local_rank == 0 and not evaluate:
        # Make sure only the first process in distributed training process the dataset, and the others will use the cache
        torch.distributed.barrier()

    if output_examples:
        return dataset, examples, features
    return dataset

477
478
479
480

def main():
    parser = argparse.ArgumentParser()

481
    # Required parameters
482
483
484
485
486
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        required=True,
487
        help="Model type selected in the list: " + ", ".join(MODEL_TYPES),
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model checkpoints and predictions will be written.",
    )
503

504
    # Other parameters
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        help="The input data dir. Should contain the .json files for the task."
        + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
    )
    parser.add_argument(
        "--train_file",
        default=None,
        type=str,
        help="The input training file. If a data dir is specified, will look for the file there"
        + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
    )
    parser.add_argument(
        "--predict_file",
        default=None,
        type=str,
        help="The input evaluation file. If a data dir is specified, will look for the file there"
        + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
    )
    parser.add_argument(
        "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
    )
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help="Where do you want to store the pre-trained models downloaded from s3",
    )

    parser.add_argument(
        "--version_2_with_negative",
        action="store_true",
        help="If true, the SQuAD examples contain some that do not have an answer.",
    )
    parser.add_argument(
        "--null_score_diff_threshold",
        type=float,
        default=0.0,
        help="If null_score - best_non_null is greater than the threshold predict null.",
    )

    parser.add_argument(
        "--max_seq_length",
        default=384,
        type=int,
        help="The maximum total input sequence length after WordPiece tokenization. Sequences "
        "longer than this will be truncated, and sequences shorter than this will be padded.",
    )
    parser.add_argument(
        "--doc_stride",
        default=128,
        type=int,
        help="When splitting up a long document into chunks, how much stride to take between chunks.",
    )
    parser.add_argument(
        "--max_query_length",
        default=64,
        type=int,
        help="The maximum number of tokens for the question. Questions longer than this will "
        "be truncated to this length.",
    )
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
    parser.add_argument(
577
        "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step."
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
    )
    parser.add_argument(
        "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
    )

    parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
    parser.add_argument(
        "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
    )
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
    )
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
    )
    parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
    parser.add_argument(
        "--n_best_size",
        default=20,
        type=int,
        help="The total number of n-best predictions to generate in the nbest_predictions.json output file.",
    )
    parser.add_argument(
        "--max_answer_length",
        default=30,
        type=int,
        help="The maximum length of an answer that can be generated. This is needed because the start "
        "and end predictions are not conditioned on one another.",
    )
    parser.add_argument(
        "--verbose_logging",
        action="store_true",
        help="If true, all of the warnings related to data processing will be printed. "
        "A number of warnings are expected for a normal SQuAD evaluation.",
    )
626
627
628
629
630
631
    parser.add_argument(
        "--lang_id",
        default=0,
        type=int,
        help="language id of input for language-specific xlm models (see tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)",
    )
632

633
634
    parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
    parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
    parser.add_argument(
        "--eval_all_checkpoints",
        action="store_true",
        help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
    )
    parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")
    parser.add_argument(
        "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
    )
    parser.add_argument(
        "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
    )
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")

    parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
    parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")

    parser.add_argument("--threads", type=int, default=1, help="multiple threads for converting example to features")
666
667
    args = parser.parse_args()

668
669
670
671
672
673
674
    if args.doc_stride >= args.max_seq_length - args.max_query_length:
        logger.warning(
            "WARNING - You've set a doc stride which may be superior to the document length in some "
            "examples. This could result in errors when building features from the examples. Please reduce the doc "
            "stride or increase the maximum length to ensure the features are correctly built."
        )

675
676
677
678
679
680
    if (
        os.path.exists(args.output_dir)
        and os.listdir(args.output_dir)
        and args.do_train
        and not args.overwrite_output_dir
    ):
681
        raise ValueError(
682
683
684
685
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                args.output_dir
            )
        )
thomwolf's avatar
thomwolf committed
686

687
    # Setup distant debugging if needed
688
689
690
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
691

692
        print("Waiting for debugger attach")
693
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
694
695
        ptvsd.wait_for_attach()

thomwolf's avatar
thomwolf committed
696
    # Setup CUDA, GPU & distributed training
697
    if args.local_rank == -1 or args.no_cuda:
698
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
699
        args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
thomwolf's avatar
thomwolf committed
700
    else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
701
702
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
703
        torch.distributed.init_process_group(backend="nccl")
thomwolf's avatar
thomwolf committed
704
705
        args.n_gpu = 1
    args.device = device
706

thomwolf's avatar
thomwolf committed
707
    # Setup logging
708
709
710
711
712
713
714
715
716
717
718
719
720
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank,
        device,
        args.n_gpu,
        bool(args.local_rank != -1),
        args.fp16,
    )
721

722
723
    # Set seed
    set_seed(args)
724

thomwolf's avatar
thomwolf committed
725
    # Load pretrained model and tokenizer
726
    if args.local_rank not in [-1, 0]:
727
728
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()
729

730
    args.model_type = args.model_type.lower()
731
    config = AutoConfig.from_pretrained(
732
733
734
        args.config_name if args.config_name else args.model_name_or_path,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
735
    tokenizer = AutoTokenizer.from_pretrained(
736
737
738
739
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
740
    model = AutoModelForQuestionAnswering.from_pretrained(
741
742
743
744
745
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
746
747

    if args.local_rank == 0:
748
749
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()
750

thomwolf's avatar
thomwolf committed
751
    model.to(args.device)
752

753
754
    logger.info("Training/evaluation parameters %s", args)

Simon Layton's avatar
Simon Layton committed
755
756
757
758
759
760
    # Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is set.
    # Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` will
    # remove the need for this code, but it is still valid.
    if args.fp16:
        try:
            import apex
761
762

            apex.amp.register_half_function(torch, "einsum")
Simon Layton's avatar
Simon Layton committed
763
        except ImportError:
764
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
Simon Layton's avatar
Simon Layton committed
765

thomwolf's avatar
thomwolf committed
766
    # Training
767
    if args.do_train:
768
        train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False)
769
        global_step, tr_loss = train(args, train_dataset, model, tokenizer)
770
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
771

thomwolf's avatar
thomwolf committed
772
    # Save the trained model and the tokenizer
Peng Qi's avatar
Peng Qi committed
773
    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
774
775
776
777
778
779
780
        # Create output directory if needed
        if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(args.output_dir)

        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
781
        # Take care of distributed/parallel training
782
        model_to_save = model.module if hasattr(model, "module") else model
783
784
        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)
785
786

        # Good practice: save your training arguments together with the trained model
787
        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
788

789
        # Load a trained model and vocabulary that you have fine-tuned
790
791
        model = AutoModelForQuestionAnswering.from_pretrained(args.output_dir)  # , force_download=True)
        tokenizer = AutoTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
792
793
        model.to(args.device)

thomwolf's avatar
thomwolf committed
794
    # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
795
796
    results = {}
    if args.do_eval and args.local_rank in [-1, 0]:
797
798
799
800
        if args.do_train:
            logger.info("Loading checkpoints saved during training for evaluation")
            checkpoints = [args.output_dir]
            if args.eval_all_checkpoints:
801
802
803
804
                checkpoints = list(
                    os.path.dirname(c)
                    for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
                )
805
806
807
808
                logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce model loading logs
        else:
            logger.info("Loading checkpoint %s for evaluation", args.model_name_or_path)
            checkpoints = [args.model_name_or_path]
thomwolf's avatar
thomwolf committed
809

810
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
thomwolf's avatar
thomwolf committed
811

812
        for checkpoint in checkpoints:
thomwolf's avatar
thomwolf committed
813
            # Reload the model
814
            global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
815
            model = AutoModelForQuestionAnswering.from_pretrained(checkpoint)  # , force_download=True)
816
            model.to(args.device)
thomwolf's avatar
thomwolf committed
817
818

            # Evaluate
819
            result = evaluate(args, model, tokenizer, prefix=global_step)
thomwolf's avatar
thomwolf committed
820

821
            result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
822
            results.update(result)
thomwolf's avatar
thomwolf committed
823

824
    logger.info("Results: {}".format(results))
thomwolf's avatar
thomwolf committed
825

826
    return results
827
828
829
830


# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module. main() parses the
# command-line arguments itself; the results dict it returns is unused here.
if __name__ == "__main__":
    main()