run_squad.py 34.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
16
""" Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet)."""
17
18
19


import argparse
Aymeric Augustin's avatar
Aymeric Augustin committed
20
import glob
21
22
23
import logging
import os
import random
24
import timeit
Aymeric Augustin's avatar
Aymeric Augustin committed
25

26
27
import numpy as np
import torch
28
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
29
from torch.utils.data.distributed import DistributedSampler
30
from tqdm import tqdm, trange
31

32
import transformers
33
from transformers import (
34
    MODEL_FOR_QUESTION_ANSWERING_MAPPING,
35
    WEIGHTS_NAME,
Aymeric Augustin's avatar
Aymeric Augustin committed
36
    AdamW,
37
38
39
    AutoConfig,
    AutoModelForQuestionAnswering,
    AutoTokenizer,
Aymeric Augustin's avatar
Aymeric Augustin committed
40
41
    get_linear_schedule_with_warmup,
    squad_convert_examples_to_features,
42
)
Aymeric Augustin's avatar
Aymeric Augustin committed
43
44
45
46
47
48
from transformers.data.metrics.squad_metrics import (
    compute_predictions_log_probs,
    compute_predictions_logits,
    squad_evaluate,
)
from transformers.data.processors.squad import SquadResult, SquadV1Processor, SquadV2Processor
49
from transformers.trainer_utils import is_main_process
Aymeric Augustin's avatar
Aymeric Augustin committed
50
51
52
53


try:
    from torch.utils.tensorboard import SummaryWriter
54
except ImportError:
Aymeric Augustin's avatar
Aymeric Augustin committed
55
    from tensorboardX import SummaryWriter
thomwolf's avatar
thomwolf committed
56

57
58
59

logger = logging.getLogger(__name__)

# Config classes that have a question-answering head registered in the auto-mapping, and their
# short ``model_type`` identifiers (listed in the ``--model_type`` help text below).
MODEL_CONFIG_CLASSES = list(MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys())
MODEL_TYPES = tuple(config_class.model_type for config_class in MODEL_CONFIG_CLASSES)
thomwolf's avatar
thomwolf committed
62

63

thomwolf's avatar
thomwolf committed
64
65
66
67
68
69
70
def set_seed(args):
    """Seed Python's, NumPy's and PyTorch's RNGs with ``args.seed``.

    CUDA generators on every device are also seeded when ``args.n_gpu`` is positive.
    """
    seed = args.seed
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)

71

72
73
def to_list(tensor):
    """Convert ``tensor`` to a (nested) Python list, detached from autograd and moved to the CPU."""
    host_tensor = tensor.detach().cpu()
    return host_tensor.tolist()
thomwolf's avatar
thomwolf committed
74

75

76
def train(args, train_dataset, model, tokenizer):
    """Train the model.

    Builds a random (or, under ``args.local_rank != -1``, distributed) data loader. Then runs
    AdamW with a linear warmup/decay schedule, optionally wrapped for apex fp16,
    ``DataParallel`` or ``DistributedDataParallel``. The loop applies gradient accumulation and
    gradient clipping, logs to TensorBoard every ``args.logging_steps`` updates, and saves a
    checkpoint every ``args.save_steps`` updates.

    Resuming: when ``args.model_name_or_path`` is an existing local path whose last
    ``-``-separated component parses as an integer (e.g. ``.../checkpoint-500``), that integer
    is taken as the number of optimization steps already done. The corresponding epochs and
    batches are skipped. ``optimizer.pt`` / ``scheduler.pt`` in that directory are reloaded
    when both exist.

    Args:
        args: parsed command-line namespace. ``args.train_batch_size`` is written back, and so
            is ``args.num_train_epochs`` when ``args.max_steps > 0``.
        train_dataset: dataset whose batches are indexed positionally below. The assumed
            layout is (input_ids, attention_mask, token_type_ids, start_positions,
            end_positions[, cls_index, p_mask[, is_impossible]]); TODO confirm against
            ``squad_convert_examples_to_features``.
        model: question-answering model to fine-tune.
        tokenizer: saved alongside every checkpoint and passed to ``evaluate``.

    Returns:
        ``(global_step, average_loss)``. ``global_step`` is the number of optimization steps
        reached. ``average_loss`` is the loss accumulated during this call divided by
        ``global_step`` (clamped to at least 1).
    """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    # Optimizer updates per epoch. Clamped to 1 so the divisions below cannot hit zero when the
    # loader yields fewer batches than one accumulation window.
    num_update_steps_per_epoch = max(1, len(train_dataloader) // args.gradient_accumulation_steps)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // num_update_steps_per_epoch + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )

    # Check if saved optimizer or scheduler states exist
    optimizer_path = os.path.join(args.model_name_or_path, "optimizer.pt")
    scheduler_path = os.path.join(args.model_name_or_path, "scheduler.pt")
    if os.path.isfile(optimizer_path) and os.path.isfile(scheduler_path):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(torch.load(optimizer_path))
        scheduler.load_state_dict(torch.load(scheduler_path))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")

        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
        )

    # DataParallel / DistributedDataParallel keep the wrapped model in ``.module`` and do not
    # expose ``.config`` themselves. Look through the wrapper so language-specific XLM models
    # still receive ``langs`` on multi-GPU runs.
    unwrapped_model = model.module if hasattr(model, "module") else model
    model_config = getattr(unwrapped_model, "config", None)
    uses_langs = model_config is not None and hasattr(model_config, "lang2id")

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    # global_step counts completed optimization steps. Starting at 0 keeps checkpoint names,
    # logging windows and resume offsets aligned with the number of updates actually done.
    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        try:
            # set global_step to global_step of last saved checkpoint from model path
            # (".../checkpoint-500" -> 500); a non-integer suffix means a fresh start.
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // num_update_steps_per_epoch
            steps_trained_in_current_epoch = global_step % num_update_steps_per_epoch

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d", global_step)
            logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Starting fine-tuning.")

    # The skip loop below counts dataloader batches, and each optimization step consumed
    # ``gradient_accumulation_steps`` of them.
    batches_to_skip = steps_trained_in_current_epoch * args.gradient_accumulation_steps

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
    )
    # Added here for reproducibility
    set_seed(args)

    for epoch in train_iterator:
        if isinstance(train_sampler, DistributedSampler):
            # Without this every epoch replays the same shuffled order on every process.
            train_sampler.set_epoch(epoch)
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained batches if resuming training
            if batches_to_skip > 0:
                batches_to_skip -= 1
                continue

            model.train()
            batch = tuple(t.to(args.device) for t in batch)

            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "token_type_ids": batch[2],
                "start_positions": batch[3],
                "end_positions": batch[4],
            }

            if args.model_type in ["xlm", "roberta", "distilbert", "camembert", "bart", "longformer"]:
                del inputs["token_type_ids"]

            if args.model_type in ["xlnet", "xlm"]:
                inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
                if args.version_2_with_negative:
                    inputs.update({"is_impossible": batch[7]})
                if uses_langs:
                    inputs.update(
                        {"langs": (torch.ones(batch[0].shape, dtype=torch.int64) * args.lang_id).to(args.device)}
                    )

            outputs = model(**inputs)
            # model outputs are always tuple in transformers (see doc)
            loss = outputs[0]

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel (not distributed) training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                # Log metrics
                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Only evaluate when single GPU otherwise metrics may not average well
                    if args.local_rank == -1 and args.evaluate_during_training:
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value, global_step)
                    # get_last_lr() (PyTorch >= 1.4) is the supported accessor. Calling get_lr()
                    # outside scheduler.step() warns, so it is only a fallback for old releases.
                    read_lr = getattr(scheduler, "get_last_lr", scheduler.get_lr)
                    tb_writer.add_scalar("lr", read_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
                    logging_loss = tr_loss

                # Save model checkpoint
                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
                    # Take care of distributed/parallel training
                    unwrapped_model.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s", output_dir)

            # Stop once exactly max_steps optimization steps have been performed.
            if args.max_steps > 0 and global_step >= args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step >= args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    # Clamp the divisor so a run that performed no updates does not divide by zero.
    return global_step, tr_loss / max(global_step, 1)


def evaluate(args, model, tokenizer, prefix=""):
    """Evaluate ``model`` on the SQuAD dev/predict set and return the metric mapping.

    Loads (or builds and caches) the evaluation features and runs the model over them without
    gradients. The per-feature outputs are turned into answer predictions, which are written
    to ``args.output_dir``: ``predictions_<prefix>.json``, ``nbest_predictions_<prefix>.json``
    and, with ``args.version_2_with_negative``, ``null_odds_<prefix>.json``. The predictions
    are then scored with ``squad_evaluate``.

    Args:
        args: parsed command-line namespace; ``args.eval_batch_size`` is written back.
        model: question-answering model, possibly already wrapped in ``DataParallel``.
        tokenizer: tokenizer handed to the prediction post-processing.
        prefix: tag inserted into the log header and the prediction file names.

    Returns:
        Whatever ``squad_evaluate`` returns. Callers iterate it with ``.items()`` as
        metric-name / value pairs.
    """
    dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True)

    if args.local_rank in [-1, 0]:
        # exist_ok avoids the race between a separate existence check and the create.
        os.makedirs(args.output_dir, exist_ok=True)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)

    # Note that DistributedSampler samples randomly; a sequential sampler keeps feature order.
    eval_sampler = SequentialSampler(dataset)
    eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

    # multi-gpu evaluate
    if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model)

    # DataParallel keeps the real model (and its config) in ``.module``. Looking through the
    # wrapper keeps ``langs`` for language-specific XLM models on multi-GPU runs.
    model_config = getattr(model.module if hasattr(model, "module") else model, "config", None)
    uses_langs = model_config is not None and hasattr(model_config, "lang2id")

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Num examples = %d", len(dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)

    all_results = []
    start_time = timeit.default_timer()
    # Nothing in the loop switches back to training mode, so set eval mode once.
    model.eval()

    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        batch = tuple(t.to(args.device) for t in batch)

        with torch.no_grad():
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "token_type_ids": batch[2],
            }

            if args.model_type in ["xlm", "roberta", "distilbert", "camembert", "bart", "longformer"]:
                del inputs["token_type_ids"]

            # Position of each row's feature in ``features``.
            feature_indices = batch[3]

            # XLNet and XLM use more arguments for their predictions
            if args.model_type in ["xlnet", "xlm"]:
                inputs.update({"cls_index": batch[4], "p_mask": batch[5]})
                # for lang_id-sensitive xlm models
                if uses_langs:
                    inputs.update(
                        {"langs": (torch.ones(batch[0].shape, dtype=torch.int64) * args.lang_id).to(args.device)}
                    )
            outputs = model(**inputs)

        for i, feature_index in enumerate(feature_indices):
            eval_feature = features[feature_index.item()]
            unique_id = int(eval_feature.unique_id)

            output = [to_list(output[i]) for output in outputs.to_tuple()]

            # Some models (XLNet, XLM) use 5 arguments for their predictions, while the other "simpler"
            # models only use two.
            if len(output) >= 5:
                start_logits = output[0]
                start_top_index = output[1]
                end_logits = output[2]
                end_top_index = output[3]
                cls_logits = output[4]

                result = SquadResult(
                    unique_id,
                    start_logits,
                    end_logits,
                    start_top_index=start_top_index,
                    end_top_index=end_top_index,
                    cls_logits=cls_logits,
                )

            else:
                start_logits, end_logits = output
                result = SquadResult(unique_id, start_logits, end_logits)

            all_results.append(result)

    eval_time = timeit.default_timer() - start_time
    # Clamp the divisor so an empty evaluation set does not raise ZeroDivisionError.
    logger.info(
        "  Evaluation done in total %f secs (%f sec per example)", eval_time, eval_time / max(len(dataset), 1)
    )

    # Compute predictions
    output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix))
    output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix))

    if args.version_2_with_negative:
        output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix))
    else:
        output_null_log_odds_file = None

    # XLNet and XLM use a more complex post-processing procedure
    if args.model_type in ["xlnet", "xlm"]:
        start_n_top = model_config.start_n_top
        end_n_top = model_config.end_n_top

        predictions = compute_predictions_log_probs(
            examples,
            features,
            all_results,
            args.n_best_size,
            args.max_answer_length,
            output_prediction_file,
            output_nbest_file,
            output_null_log_odds_file,
            start_n_top,
            end_n_top,
            args.version_2_with_negative,
            tokenizer,
            args.verbose_logging,
        )
    else:
        predictions = compute_predictions_logits(
            examples,
            features,
            all_results,
            args.n_best_size,
            args.max_answer_length,
            args.do_lower_case,
            output_prediction_file,
            output_nbest_file,
            output_null_log_odds_file,
            args.verbose_logging,
            args.version_2_with_negative,
            args.null_score_diff_threshold,
            tokenizer,
        )

    # Compute the F1 and exact scores.
    results = squad_evaluate(examples, predictions)
    return results

402

403
def _torch_load_pickled(path):
    """Load an arbitrary pickled object written by ``torch.save`` (not just tensors).

    PyTorch 2.6 switched ``torch.load`` to ``weights_only=True`` by default. That mode refuses
    the Python objects (features, examples) stored in the feature cache. PyTorch releases
    older than 1.13 do not accept the keyword at all, hence the fallback.

    NOTE(review): this unpickles the file, so it must only be used on trusted, locally produced
    cache files.
    """
    try:
        return torch.load(path, weights_only=False)
    except TypeError:
        # Old PyTorch without the ``weights_only`` parameter.
        return torch.load(path)


def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
    """Return the SQuAD training or evaluation dataset, using an on-disk feature cache.

    Features are read from ``<data_dir or '.'>/cached_{dev|train}_{model name}_{max_seq_length}``
    unless that file is missing or ``args.overwrite_cache`` is set. Otherwise, examples are
    read with the SQuAD processors from ``args.predict_file`` / ``args.train_file`` (or from
    ``tensorflow_datasets`` when neither a data dir nor the relevant file is given). They are
    then converted with ``squad_convert_examples_to_features`` and, on the main process, written
    back to the cache. For training data in distributed runs, non-main ranks wait at a barrier
    until rank 0 has built the cache.

    Args:
        args: parsed command-line namespace.
        tokenizer: tokenizer used to convert examples into features.
        evaluate: load the dev/predict split instead of the training split.
        output_examples: also return the raw examples and features (needed for post-processing).

    Returns:
        ``dataset``, or ``(dataset, examples, features)`` when ``output_examples`` is true.
    """
    if args.local_rank not in [-1, 0] and not evaluate:
        # Make sure only the first process in distributed training process the dataset, and the others will use the cache
        torch.distributed.barrier()

    # Load data features from cache or dataset file
    input_dir = args.data_dir if args.data_dir else "."
    # NOTE(review): the cache name only encodes split, model name and max_seq_length. Changing
    # version_2_with_negative, doc_stride, max_query_length or the input file names silently
    # reuses a stale cache; pass --overwrite_cache after changing any of them.
    cached_features_file = os.path.join(
        input_dir,
        "cached_{}_{}_{}".format(
            "dev" if evaluate else "train",
            list(filter(None, args.model_name_or_path.split("/"))).pop(),
            str(args.max_seq_length),
        ),
    )

    # Init features and dataset from cache if it exists
    if os.path.exists(cached_features_file) and not args.overwrite_cache:
        logger.info("Loading features from cached file %s", cached_features_file)
        features_and_dataset = _torch_load_pickled(cached_features_file)
        features, dataset, examples = (
            features_and_dataset["features"],
            features_and_dataset["dataset"],
            features_and_dataset["examples"],
        )
    else:
        logger.info("Creating features from dataset file at %s", input_dir)

        if not args.data_dir and ((evaluate and not args.predict_file) or (not evaluate and not args.train_file)):
            try:
                import tensorflow_datasets as tfds
            except ImportError as err:
                raise ImportError("If not data_dir is specified, tensorflow_datasets needs to be installed.") from err

            if args.version_2_with_negative:
                logger.warning("tensorflow_datasets does not handle version 2 of SQuAD.")

            tfds_examples = tfds.load("squad")
            examples = SquadV1Processor().get_examples_from_dataset(tfds_examples, evaluate=evaluate)
        else:
            processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
            if evaluate:
                examples = processor.get_dev_examples(args.data_dir, filename=args.predict_file)
            else:
                examples = processor.get_train_examples(args.data_dir, filename=args.train_file)

        features, dataset = squad_convert_examples_to_features(
            examples=examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=not evaluate,
            return_dataset="pt",
            threads=args.threads,
        )

        if args.local_rank in [-1, 0]:
            logger.info("Saving features into cached file %s", cached_features_file)
            torch.save({"features": features, "dataset": dataset, "examples": examples}, cached_features_file)

    if args.local_rank == 0 and not evaluate:
        # Make sure only the first process in distributed training process the dataset, and the others will use the cache
        torch.distributed.barrier()

    if output_examples:
        return dataset, examples, features
    return dataset

472
473
474
475

def main():
    parser = argparse.ArgumentParser()

476
    # Required parameters
477
478
479
480
481
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        required=True,
482
        help="Model type selected in the list: " + ", ".join(MODEL_TYPES),
483
484
485
486
487
488
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
489
        help="Path to pretrained model or model identifier from huggingface.co/models",
490
491
492
493
494
495
496
497
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model checkpoints and predictions will be written.",
    )
498

499
    # Other parameters
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        help="The input data dir. Should contain the .json files for the task."
        + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
    )
    parser.add_argument(
        "--train_file",
        default=None,
        type=str,
        help="The input training file. If a data dir is specified, will look for the file there"
        + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
    )
    parser.add_argument(
        "--predict_file",
        default=None,
        type=str,
        help="The input evaluation file. If a data dir is specified, will look for the file there"
        + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
    )
    parser.add_argument(
        "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
    )
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
534
        help="Where do you want to store the pre-trained models downloaded from huggingface.co",
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
    )

    parser.add_argument(
        "--version_2_with_negative",
        action="store_true",
        help="If true, the SQuAD examples contain some that do not have an answer.",
    )
    parser.add_argument(
        "--null_score_diff_threshold",
        type=float,
        default=0.0,
        help="If null_score - best_non_null is greater than the threshold predict null.",
    )

    parser.add_argument(
        "--max_seq_length",
        default=384,
        type=int,
Sylvain Gugger's avatar
Sylvain Gugger committed
553
554
555
556
        help=(
            "The maximum total input sequence length after WordPiece tokenization. Sequences "
            "longer than this will be truncated, and sequences shorter than this will be padded."
        ),
557
558
559
560
561
562
563
564
565
566
567
    )
    parser.add_argument(
        "--doc_stride",
        default=128,
        type=int,
        help="When splitting up a long document into chunks, how much stride to take between chunks.",
    )
    parser.add_argument(
        "--max_query_length",
        default=64,
        type=int,
Sylvain Gugger's avatar
Sylvain Gugger committed
568
569
570
571
        help=(
            "The maximum number of tokens for the question. Questions longer than this will "
            "be truncated to this length."
        ),
572
573
574
575
    )
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
    parser.add_argument(
576
        "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step."
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
    )
    parser.add_argument(
        "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
    )

    parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
    parser.add_argument(
        "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
    )
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
    )
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
    )
    parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
    parser.add_argument(
        "--n_best_size",
        default=20,
        type=int,
        help="The total number of n-best predictions to generate in the nbest_predictions.json output file.",
    )
    parser.add_argument(
        "--max_answer_length",
        default=30,
        type=int,
Sylvain Gugger's avatar
Sylvain Gugger committed
616
617
618
619
        help=(
            "The maximum length of an answer that can be generated. This is needed because the start "
            "and end predictions are not conditioned on one another."
        ),
620
621
622
623
    )
    parser.add_argument(
        "--verbose_logging",
        action="store_true",
Sylvain Gugger's avatar
Sylvain Gugger committed
624
625
626
627
        help=(
            "If true, all of the warnings related to data processing will be printed. "
            "A number of warnings are expected for a normal SQuAD evaluation."
        ),
628
    )
629
630
631
632
    parser.add_argument(
        "--lang_id",
        default=0,
        type=int,
Sylvain Gugger's avatar
Sylvain Gugger committed
633
634
635
636
        help=(
            "language id of input for language-specific xlm models (see"
            " tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)"
        ),
637
    )
638

639
640
    parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
    parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
    parser.add_argument(
        "--eval_all_checkpoints",
        action="store_true",
        help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
    )
    parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")
    parser.add_argument(
        "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
    )
    parser.add_argument(
        "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
    )
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")

    parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
Sylvain Gugger's avatar
Sylvain Gugger committed
665
666
667
668
        help=(
            "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
            "See details at https://nvidia.github.io/apex/amp.html"
        ),
669
670
671
672
673
    )
    parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
    parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")

    parser.add_argument("--threads", type=int, default=1, help="multiple threads for converting example to features")
674
675
    args = parser.parse_args()

676
677
678
679
680
681
682
    if args.doc_stride >= args.max_seq_length - args.max_query_length:
        logger.warning(
            "WARNING - You've set a doc stride which may be superior to the document length in some "
            "examples. This could result in errors when building features from the examples. Please reduce the doc "
            "stride or increase the maximum length to ensure the features are correctly built."
        )

683
684
685
686
687
688
    if (
        os.path.exists(args.output_dir)
        and os.listdir(args.output_dir)
        and args.do_train
        and not args.overwrite_output_dir
    ):
689
        raise ValueError(
690
691
692
693
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                args.output_dir
            )
        )
thomwolf's avatar
thomwolf committed
694

695
    # Setup distant debugging if needed
696
697
698
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
699

700
        print("Waiting for debugger attach")
701
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
702
703
        ptvsd.wait_for_attach()

thomwolf's avatar
thomwolf committed
704
    # Setup CUDA, GPU & distributed training
705
    if args.local_rank == -1 or args.no_cuda:
706
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
707
        args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
thomwolf's avatar
thomwolf committed
708
    else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
709
710
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
711
        torch.distributed.init_process_group(backend="nccl")
thomwolf's avatar
thomwolf committed
712
713
        args.n_gpu = 1
    args.device = device
714

thomwolf's avatar
thomwolf committed
715
    # Setup logging
716
    logging.basicConfig(
717
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
718
719
720
721
722
723
724
725
726
727
728
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank,
        device,
        args.n_gpu,
        bool(args.local_rank != -1),
        args.fp16,
    )
729
730
731
732
733
    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(args.local_rank):
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.enable_default_handler()
        transformers.utils.logging.enable_explicit_format()
734
735
    # Set seed
    set_seed(args)
736

thomwolf's avatar
thomwolf committed
737
    # Load pretrained model and tokenizer
738
    if args.local_rank not in [-1, 0]:
739
740
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()
741

742
    args.model_type = args.model_type.lower()
743
    config = AutoConfig.from_pretrained(
744
745
746
        args.config_name if args.config_name else args.model_name_or_path,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
747
    tokenizer = AutoTokenizer.from_pretrained(
748
749
750
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None,
751
        use_fast=False,  # SquadDataset is not compatible with Fast tokenizers which have a smarter overflow handeling
752
    )
753
    model = AutoModelForQuestionAnswering.from_pretrained(
754
755
756
757
758
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
759
760

    if args.local_rank == 0:
761
762
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()
763

thomwolf's avatar
thomwolf committed
764
    model.to(args.device)
765

766
767
    logger.info("Training/evaluation parameters %s", args)

Simon Layton's avatar
Simon Layton committed
768
769
770
771
772
773
    # Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is set.
    # Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` will
    # remove the need for this code, but it is still valid.
    if args.fp16:
        try:
            import apex
774
775

            apex.amp.register_half_function(torch, "einsum")
Simon Layton's avatar
Simon Layton committed
776
        except ImportError:
777
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
Simon Layton's avatar
Simon Layton committed
778

thomwolf's avatar
thomwolf committed
779
    # Training
780
    if args.do_train:
781
        train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False)
782
        global_step, tr_loss = train(args, train_dataset, model, tokenizer)
783
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
784

thomwolf's avatar
thomwolf committed
785
    # Save the trained model and the tokenizer
Peng Qi's avatar
Peng Qi committed
786
    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
787
788
789
        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
790
        # Take care of distributed/parallel training
791
        model_to_save = model.module if hasattr(model, "module") else model
792
793
        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)
794
795

        # Good practice: save your training arguments together with the trained model
796
        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
797

798
        # Load a trained model and vocabulary that you have fine-tuned
799
        model = AutoModelForQuestionAnswering.from_pretrained(args.output_dir)  # , force_download=True)
800
801
802
803

        # SquadDataset is not compatible with Fast tokenizers which have a smarter overflow handeling
        # So we use use_fast=False here for now until Fast-tokenizer-compatible-examples are out
        tokenizer = AutoTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case, use_fast=False)
804
805
        model.to(args.device)

thomwolf's avatar
thomwolf committed
806
    # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
807
808
    results = {}
    if args.do_eval and args.local_rank in [-1, 0]:
809
810
811
812
        if args.do_train:
            logger.info("Loading checkpoints saved during training for evaluation")
            checkpoints = [args.output_dir]
            if args.eval_all_checkpoints:
813
                checkpoints = [
814
815
                    os.path.dirname(c)
                    for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
816
                ]
817

818
819
820
        else:
            logger.info("Loading checkpoint %s for evaluation", args.model_name_or_path)
            checkpoints = [args.model_name_or_path]
thomwolf's avatar
thomwolf committed
821

822
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
thomwolf's avatar
thomwolf committed
823

824
        for checkpoint in checkpoints:
thomwolf's avatar
thomwolf committed
825
            # Reload the model
826
            global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
827
            model = AutoModelForQuestionAnswering.from_pretrained(checkpoint)  # , force_download=True)
828
            model.to(args.device)
thomwolf's avatar
thomwolf committed
829
830

            # Evaluate
831
            result = evaluate(args, model, tokenizer, prefix=global_step)
thomwolf's avatar
thomwolf committed
832

833
            result = {k + ("_{}".format(global_step) if global_step else ""): v for k, v in result.items()}
834
            results.update(result)
thomwolf's avatar
thomwolf committed
835

836
    logger.info("Results: {}".format(results))
thomwolf's avatar
thomwolf committed
837

838
    return results
839
840
841
842


if __name__ == "__main__":
    # Script entry point: run main() only when this file is executed directly
    # (e.g. `python run_squad.py ...`), not when it is imported as a module.
    main()