run_squad.py 33.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
16
""" Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet)."""
17
18
19


import argparse
Aymeric Augustin's avatar
Aymeric Augustin committed
20
import glob
21
22
23
import logging
import os
import random
24
import timeit
Aymeric Augustin's avatar
Aymeric Augustin committed
25

26
27
import numpy as np
import torch
28
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
29
from torch.utils.data.distributed import DistributedSampler
30
from tqdm import tqdm, trange
31

32
from transformers import (
33
    MODEL_FOR_QUESTION_ANSWERING_MAPPING,
34
    WEIGHTS_NAME,
Aymeric Augustin's avatar
Aymeric Augustin committed
35
    AdamW,
36
37
38
    AutoConfig,
    AutoModelForQuestionAnswering,
    AutoTokenizer,
Aymeric Augustin's avatar
Aymeric Augustin committed
39
40
    get_linear_schedule_with_warmup,
    squad_convert_examples_to_features,
41
)
Aymeric Augustin's avatar
Aymeric Augustin committed
42
43
44
45
46
47
48
49
50
51
from transformers.data.metrics.squad_metrics import (
    compute_predictions_log_probs,
    compute_predictions_logits,
    squad_evaluate,
)
from transformers.data.processors.squad import SquadResult, SquadV1Processor, SquadV2Processor


try:
    from torch.utils.tensorboard import SummaryWriter
52
except ImportError:
Aymeric Augustin's avatar
Aymeric Augustin committed
53
    from tensorboardX import SummaryWriter
thomwolf's avatar
thomwolf committed
54

55
56
57

logger = logging.getLogger(__name__)

# Config classes registered in the question-answering auto mapping, and their
# `model_type` names (the latter are listed in the --model_type help text).
MODEL_CONFIG_CLASSES = list(MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys())
MODEL_TYPES = tuple(map(lambda config_class: config_class.model_type, MODEL_CONFIG_CLASSES))
thomwolf's avatar
thomwolf committed
60

61

thomwolf's avatar
thomwolf committed
62
63
64
65
66
67
68
def set_seed(args):
    """Seed the Python, NumPy and PyTorch random generators from ``args.seed``.

    When ``args.n_gpu`` is positive, the CUDA generators of every device are
    seeded as well.
    """
    seed = args.seed
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)

69

70
71
def to_list(tensor):
    """Return ``tensor`` as a (possibly nested) Python list, detached and moved to CPU first."""
    host_copy = tensor.detach().cpu()
    return host_copy.tolist()
thomwolf's avatar
thomwolf committed
72

73

74
def train(args, train_dataset, model, tokenizer):
    """Fine-tune ``model`` on ``train_dataset``.

    Builds the training dataloader, an AdamW optimizer with a linear
    warmup/decay schedule, optionally wraps the model for apex fp16,
    ``DataParallel`` or ``DistributedDataParallel``, and then runs the training
    loop. The loop uses gradient accumulation, TensorBoard logging every
    ``args.logging_steps`` updates (optionally evaluating as well) and
    checkpointing every ``args.save_steps`` updates.

    If ``args.model_name_or_path`` contains ``optimizer.pt`` and
    ``scheduler.pt``, their states are restored. If that path exists and the
    text after its last ``-`` (up to the next ``/``) parses as an integer
    (e.g. ``.../checkpoint-500``), training resumes from that global step.

    Returns:
        ``(global_step, tr_loss / global_step)``.
    """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    # One optimizer update happens every `gradient_accumulation_steps` batches.
    updates_per_epoch = len(train_dataloader) // args.gradient_accumulation_steps
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // updates_per_epoch + 1
    else:
        t_total = updates_per_epoch * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay).
    # Parameters whose names contain "bias" or "LayerNorm.weight" get no weight decay.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )

    # Restore optimizer and scheduler states only when both were saved next to the model.
    if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
        os.path.join(args.model_name_or_path, "scheduler.pt")
    ):
        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")

        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    # NOTE(review): global_step starts at 1, so it is always one ahead of the number of
    # optimizer updates performed. Step-based logging, checkpoint names and the `>` in the
    # max_steps check below all rely on this convention, so it is kept unchanged here.
    global_step = 1
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        try:
            # set global_step to global_step of last saved checkpoint from model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // updates_per_epoch
            steps_trained_in_current_epoch = global_step % updates_per_epoch

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d", global_step)
            logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Starting fine-tuning.")

    # The resume position is counted in optimizer updates, but the skip loop below advances
    # once per dataloader batch. Convert to batches so that, with gradient accumulation, the
    # right number of batches is skipped and the accumulation windows stay aligned with `step`.
    batches_to_skip = steps_trained_in_current_epoch * args.gradient_accumulation_steps

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
    )
    # Added here for reproducibility
    set_seed(args)

    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained batches if resuming training
            if batches_to_skip > 0:
                batches_to_skip -= 1
                continue

            model.train()
            batch = tuple(t.to(args.device) for t in batch)

            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "token_type_ids": batch[2],
                "start_positions": batch[3],
                "end_positions": batch[4],
            }

            if args.model_type in ["xlm", "roberta", "distilbert", "camembert"]:
                del inputs["token_type_ids"]

            if args.model_type in ["xlnet", "xlm"]:
                inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
                if args.version_2_with_negative:
                    inputs.update({"is_impossible": batch[7]})
                # for lang_id-sensitive xlm models
                if hasattr(model, "config") and hasattr(model.config, "lang2id"):
                    inputs.update(
                        {"langs": (torch.ones(batch[0].shape, dtype=torch.int64) * args.lang_id).to(args.device)}
                    )

            # Request plain tuple outputs when the model is wrapped in DataParallel.
            if isinstance(model, torch.nn.DataParallel):
                inputs["return_tuple"] = True

            outputs = model(**inputs)
            # model outputs are always tuple in transformers (see doc)
            loss = outputs[0]

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel (not distributed) training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                # Log metrics
                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Only evaluate when single GPU otherwise metrics may not average well
                    if args.local_rank == -1 and args.evaluate_during_training:
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value, global_step)
                    # get_last_lr() returns the rate set by the last scheduler.step();
                    # calling get_lr() outside step() is deprecated in PyTorch.
                    tb_writer.add_scalar("lr", scheduler.get_last_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
                    logging_loss = tr_loss

                # Save model checkpoint
                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
                    # Take care of distributed/parallel training
                    model_to_save = model.module if hasattr(model, "module") else model
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step


def evaluate(args, model, tokenizer, prefix=""):
    """Run ``model`` over the evaluation split and score it with ``squad_evaluate``.

    Writes ``predictions_{prefix}.json`` and ``nbest_predictions_{prefix}.json``
    to ``args.output_dir``. When ``args.version_2_with_negative`` is set, it also
    writes ``null_odds_{prefix}.json``.

    Args:
        args: parsed command-line arguments; ``args.eval_batch_size`` is set here.
        model: the model to evaluate; wrapped in ``DataParallel`` when
            ``args.n_gpu > 1`` and not already wrapped.
        tokenizer: tokenizer used for feature loading and answer post-processing.
        prefix: suffix inserted into the output file names and log header.

    Returns:
        The results returned by ``squad_evaluate(examples, predictions)``.
    """
    dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True)

    if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
        os.makedirs(args.output_dir)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)

    # Sequential sampling keeps batches in dataset order (a DistributedSampler would sample randomly).
    eval_sampler = SequentialSampler(dataset)
    eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

    # multi-gpu evaluate
    if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Num examples = %d", len(dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)

    all_results = []
    start_time = timeit.default_timer()

    # Nothing in the loop switches the model back to training mode, so set eval mode once.
    model.eval()
    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        batch = tuple(t.to(args.device) for t in batch)

        with torch.no_grad():
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "token_type_ids": batch[2],
            }

            if args.model_type in ["xlm", "roberta", "distilbert", "camembert"]:
                del inputs["token_type_ids"]

            feature_indices = batch[3]

            # XLNet and XLM use more arguments for their predictions
            if args.model_type in ["xlnet", "xlm"]:
                inputs.update({"cls_index": batch[4], "p_mask": batch[5]})
                # for lang_id-sensitive xlm models
                if hasattr(model, "config") and hasattr(model.config, "lang2id"):
                    inputs.update(
                        {"langs": (torch.ones(batch[0].shape, dtype=torch.int64) * args.lang_id).to(args.device)}
                    )

            # Mirror train(): request plain tuple outputs when the model is wrapped in
            # DataParallel (which happens above for multi-GPU, or is passed in from train()).
            # The per-feature indexing below assumes tuple-like outputs.
            if isinstance(model, torch.nn.DataParallel):
                inputs["return_tuple"] = True

            outputs = model(**inputs)

        for i, feature_index in enumerate(feature_indices):
            eval_feature = features[feature_index.item()]
            unique_id = int(eval_feature.unique_id)

            # Row i of every model output, as Python lists.
            output = [to_list(model_output[i]) for model_output in outputs]

            # Some models (XLNet, XLM) use 5 arguments for their predictions, while the other "simpler"
            # models only use two.
            if len(output) >= 5:
                start_logits = output[0]
                start_top_index = output[1]
                end_logits = output[2]
                end_top_index = output[3]
                cls_logits = output[4]

                result = SquadResult(
                    unique_id,
                    start_logits,
                    end_logits,
                    start_top_index=start_top_index,
                    end_top_index=end_top_index,
                    cls_logits=cls_logits,
                )

            else:
                start_logits, end_logits = output
                result = SquadResult(unique_id, start_logits, end_logits)

            all_results.append(result)

    eval_time = timeit.default_timer() - start_time
    logger.info("  Evaluation done in total %f secs (%f sec per example)", eval_time, eval_time / len(dataset))

    # Compute predictions
    output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix))
    output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix))

    if args.version_2_with_negative:
        output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix))
    else:
        output_null_log_odds_file = None

    # XLNet and XLM use a more complex post-processing procedure
    if args.model_type in ["xlnet", "xlm"]:
        start_n_top = model.config.start_n_top if hasattr(model, "config") else model.module.config.start_n_top
        end_n_top = model.config.end_n_top if hasattr(model, "config") else model.module.config.end_n_top

        predictions = compute_predictions_log_probs(
            examples,
            features,
            all_results,
            args.n_best_size,
            args.max_answer_length,
            output_prediction_file,
            output_nbest_file,
            output_null_log_odds_file,
            start_n_top,
            end_n_top,
            args.version_2_with_negative,
            tokenizer,
            args.verbose_logging,
        )
    else:
        predictions = compute_predictions_logits(
            examples,
            features,
            all_results,
            args.n_best_size,
            args.max_answer_length,
            args.do_lower_case,
            output_prediction_file,
            output_nbest_file,
            output_null_log_odds_file,
            args.verbose_logging,
            args.version_2_with_negative,
            args.null_score_diff_threshold,
            tokenizer,
        )

    # Compute the F1 and exact scores.
    results = squad_evaluate(examples, predictions)
    return results

405

406
def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
    """Load SQuAD features and their tensor dataset, building and caching them if needed.

    Features are read from a ``cached_{dev|train}_{model}_{max_seq_length}`` file in
    ``args.data_dir`` (or ``.``) unless it is missing or ``args.overwrite_cache`` is set.
    Otherwise examples are read from ``args.data_dir`` / ``args.train_file`` /
    ``args.predict_file``, or from ``tensorflow_datasets`` when none of these apply.
    They are then converted with ``squad_convert_examples_to_features``, and the
    result is saved by the main process.

    Args:
        args: parsed command-line arguments.
        tokenizer: tokenizer passed to ``squad_convert_examples_to_features``.
        evaluate: load the evaluation split when True, otherwise the training split.
        output_examples: also return the raw examples and features when True.

    Returns:
        ``dataset``, or ``(dataset, examples, features)`` when ``output_examples`` is True.
    """
    if args.local_rank not in [-1, 0] and not evaluate:
        # Make sure only the first process in distributed training process the dataset, and the others will use the cache
        torch.distributed.barrier()

    # Load data features from cache or dataset file
    input_dir = args.data_dir if args.data_dir else "."
    # NOTE(review): the cache key only encodes split, model name and max_seq_length. Changing
    # --version_2_with_negative, --doc_stride, --max_query_length or the train/predict file
    # reuses a stale cache unless --overwrite_cache is passed.
    cached_features_file = os.path.join(
        input_dir,
        "cached_{}_{}_{}".format(
            "dev" if evaluate else "train",
            list(filter(None, args.model_name_or_path.split("/"))).pop(),
            str(args.max_seq_length),
        ),
    )

    # Init features and dataset from cache if it exists
    if os.path.exists(cached_features_file) and not args.overwrite_cache:
        logger.info("Loading features from cached file %s", cached_features_file)
        # torch.load unpickles the file: only load caches produced by this script from trusted data.
        features_and_dataset = torch.load(cached_features_file)
        features, dataset, examples = (
            features_and_dataset["features"],
            features_and_dataset["dataset"],
            features_and_dataset["examples"],
        )
    else:
        logger.info("Creating features from dataset file at %s", input_dir)

        if not args.data_dir and ((evaluate and not args.predict_file) or (not evaluate and not args.train_file)):
            try:
                import tensorflow_datasets as tfds
            except ImportError:
                raise ImportError("If not data_dir is specified, tensorflow_datasets needs to be installed.")

            if args.version_2_with_negative:
                # Logger.warn is a deprecated alias of Logger.warning.
                logger.warning("tensorflow_datasets does not handle version 2 of SQuAD.")

            tfds_examples = tfds.load("squad")
            examples = SquadV1Processor().get_examples_from_dataset(tfds_examples, evaluate=evaluate)
        else:
            processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
            if evaluate:
                examples = processor.get_dev_examples(args.data_dir, filename=args.predict_file)
            else:
                examples = processor.get_train_examples(args.data_dir, filename=args.train_file)

        features, dataset = squad_convert_examples_to_features(
            examples=examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=not evaluate,
            return_dataset="pt",
            threads=args.threads,
        )

        if args.local_rank in [-1, 0]:
            logger.info("Saving features into cached file %s", cached_features_file)
            torch.save({"features": features, "dataset": dataset, "examples": examples}, cached_features_file)

    if args.local_rank == 0 and not evaluate:
        # Make sure only the first process in distributed training process the dataset, and the others will use the cache
        torch.distributed.barrier()

    if output_examples:
        return dataset, examples, features
    return dataset

475
476
477
478

def main():
    parser = argparse.ArgumentParser()

479
    # Required parameters
480
481
482
483
484
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        required=True,
485
        help="Model type selected in the list: " + ", ".join(MODEL_TYPES),
486
487
488
489
490
491
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
492
        help="Path to pretrained model or model identifier from huggingface.co/models",
493
494
495
496
497
498
499
500
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model checkpoints and predictions will be written.",
    )
501

502
    # Other parameters
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        help="The input data dir. Should contain the .json files for the task."
        + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
    )
    parser.add_argument(
        "--train_file",
        default=None,
        type=str,
        help="The input training file. If a data dir is specified, will look for the file there"
        + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
    )
    parser.add_argument(
        "--predict_file",
        default=None,
        type=str,
        help="The input evaluation file. If a data dir is specified, will look for the file there"
        + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
    )
    parser.add_argument(
        "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
    )
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help="Where do you want to store the pre-trained models downloaded from s3",
    )

    parser.add_argument(
        "--version_2_with_negative",
        action="store_true",
        help="If true, the SQuAD examples contain some that do not have an answer.",
    )
    parser.add_argument(
        "--null_score_diff_threshold",
        type=float,
        default=0.0,
        help="If null_score - best_non_null is greater than the threshold predict null.",
    )

    parser.add_argument(
        "--max_seq_length",
        default=384,
        type=int,
        help="The maximum total input sequence length after WordPiece tokenization. Sequences "
        "longer than this will be truncated, and sequences shorter than this will be padded.",
    )
    parser.add_argument(
        "--doc_stride",
        default=128,
        type=int,
        help="When splitting up a long document into chunks, how much stride to take between chunks.",
    )
    parser.add_argument(
        "--max_query_length",
        default=64,
        type=int,
        help="The maximum number of tokens for the question. Questions longer than this will "
        "be truncated to this length.",
    )
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
    parser.add_argument(
575
        "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step."
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
    )
    parser.add_argument(
        "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
    )

    parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
    parser.add_argument(
        "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
    )
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
    )
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
    )
    parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
    parser.add_argument(
        "--n_best_size",
        default=20,
        type=int,
        help="The total number of n-best predictions to generate in the nbest_predictions.json output file.",
    )
    parser.add_argument(
        "--max_answer_length",
        default=30,
        type=int,
        help="The maximum length of an answer that can be generated. This is needed because the start "
        "and end predictions are not conditioned on one another.",
    )
    parser.add_argument(
        "--verbose_logging",
        action="store_true",
        help="If true, all of the warnings related to data processing will be printed. "
        "A number of warnings are expected for a normal SQuAD evaluation.",
    )
624
625
626
627
628
629
    parser.add_argument(
        "--lang_id",
        default=0,
        type=int,
        help="language id of input for language-specific xlm models (see tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)",
    )
630

631
632
    parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
    parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
    parser.add_argument(
        "--eval_all_checkpoints",
        action="store_true",
        help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
    )
    parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")
    parser.add_argument(
        "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
    )
    parser.add_argument(
        "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
    )
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")

    parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
    parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")

    parser.add_argument("--threads", type=int, default=1, help="multiple threads for converting example to features")
664
665
    args = parser.parse_args()

666
667
668
669
670
671
672
    if args.doc_stride >= args.max_seq_length - args.max_query_length:
        logger.warning(
            "WARNING - You've set a doc stride which may be superior to the document length in some "
            "examples. This could result in errors when building features from the examples. Please reduce the doc "
            "stride or increase the maximum length to ensure the features are correctly built."
        )

673
674
675
676
677
678
    if (
        os.path.exists(args.output_dir)
        and os.listdir(args.output_dir)
        and args.do_train
        and not args.overwrite_output_dir
    ):
679
        raise ValueError(
680
681
682
683
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                args.output_dir
            )
        )
thomwolf's avatar
thomwolf committed
684

685
    # Setup distant debugging if needed
686
687
688
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
689

690
        print("Waiting for debugger attach")
691
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
692
693
        ptvsd.wait_for_attach()

thomwolf's avatar
thomwolf committed
694
    # Setup CUDA, GPU & distributed training
695
    if args.local_rank == -1 or args.no_cuda:
696
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
697
        args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
thomwolf's avatar
thomwolf committed
698
    else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
699
700
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
701
        torch.distributed.init_process_group(backend="nccl")
thomwolf's avatar
thomwolf committed
702
703
        args.n_gpu = 1
    args.device = device
704

thomwolf's avatar
thomwolf committed
705
    # Setup logging
706
707
708
709
710
711
712
713
714
715
716
717
718
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank,
        device,
        args.n_gpu,
        bool(args.local_rank != -1),
        args.fp16,
    )
719

720
721
    # Set seed
    set_seed(args)
722

thomwolf's avatar
thomwolf committed
723
    # Load pretrained model and tokenizer
724
    if args.local_rank not in [-1, 0]:
725
726
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()
727

728
    args.model_type = args.model_type.lower()
729
    config = AutoConfig.from_pretrained(
730
731
732
        args.config_name if args.config_name else args.model_name_or_path,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
733
    tokenizer = AutoTokenizer.from_pretrained(
734
735
736
737
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
738
    model = AutoModelForQuestionAnswering.from_pretrained(
739
740
741
742
743
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
744
745

    if args.local_rank == 0:
746
747
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()
748

thomwolf's avatar
thomwolf committed
749
    model.to(args.device)
750

751
752
    logger.info("Training/evaluation parameters %s", args)

Simon Layton's avatar
Simon Layton committed
753
754
755
756
757
758
    # Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is set.
    # Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` will
    # remove the need for this code, but it is still valid.
    if args.fp16:
        try:
            import apex
759
760

            apex.amp.register_half_function(torch, "einsum")
Simon Layton's avatar
Simon Layton committed
761
        except ImportError:
762
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
Simon Layton's avatar
Simon Layton committed
763

thomwolf's avatar
thomwolf committed
764
    # Training
765
    if args.do_train:
766
        train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False)
767
        global_step, tr_loss = train(args, train_dataset, model, tokenizer)
768
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
769

thomwolf's avatar
thomwolf committed
770
    # Save the trained model and the tokenizer
Peng Qi's avatar
Peng Qi committed
771
    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
772
773
774
        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
775
        # Take care of distributed/parallel training
776
        model_to_save = model.module if hasattr(model, "module") else model
777
778
        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)
779
780

        # Good practice: save your training arguments together with the trained model
781
        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
782

783
        # Load a trained model and vocabulary that you have fine-tuned
784
785
        model = AutoModelForQuestionAnswering.from_pretrained(args.output_dir)  # , force_download=True)
        tokenizer = AutoTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
786
787
        model.to(args.device)

thomwolf's avatar
thomwolf committed
788
    # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
789
790
    results = {}
    if args.do_eval and args.local_rank in [-1, 0]:
791
792
793
794
        if args.do_train:
            logger.info("Loading checkpoints saved during training for evaluation")
            checkpoints = [args.output_dir]
            if args.eval_all_checkpoints:
795
796
797
798
                checkpoints = list(
                    os.path.dirname(c)
                    for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
                )
799
800
801
802
                logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce model loading logs
        else:
            logger.info("Loading checkpoint %s for evaluation", args.model_name_or_path)
            checkpoints = [args.model_name_or_path]
thomwolf's avatar
thomwolf committed
803

804
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
thomwolf's avatar
thomwolf committed
805

806
        for checkpoint in checkpoints:
thomwolf's avatar
thomwolf committed
807
            # Reload the model
808
            global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
809
            model = AutoModelForQuestionAnswering.from_pretrained(checkpoint)  # , force_download=True)
810
            model.to(args.device)
thomwolf's avatar
thomwolf committed
811
812

            # Evaluate
813
            result = evaluate(args, model, tokenizer, prefix=global_step)
thomwolf's avatar
thomwolf committed
814

815
            result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
816
            results.update(result)
thomwolf's avatar
thomwolf committed
817

818
    logger.info("Results: {}".format(results))
thomwolf's avatar
thomwolf committed
819

820
    return results
821
822
823
824


if __name__ == "__main__":
    # Script entry point: only runs when this file is executed directly
    # (not when it is imported). main() parses the command-line arguments
    # and then trains and/or evaluates according to the given flags; its
    # returned results dict is intentionally discarded here.
    main()