# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet)."""

from __future__ import absolute_import, division, print_function

import argparse
import glob
import logging
import os
import random
import timeit

import numpy as np
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)
from torch.utils.data.distributed import DistributedSampler

try:
    from torch.utils.tensorboard import SummaryWriter
except Exception:  # broad on purpose: importing torch.utils.tensorboard may fail with errors other than ImportError
    from tensorboardX import SummaryWriter

from tqdm import tqdm, trange

from transformers import (WEIGHTS_NAME, BertConfig,
                          BertForQuestionAnswering, BertTokenizer,
                          XLMConfig, XLMForQuestionAnswering,
                          XLMTokenizer, XLNetConfig,
                          XLNetForQuestionAnswering,
                          XLNetTokenizer,
                          DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer)

from transformers import AdamW, WarmupLinearSchedule

from utils_squad import (read_squad_examples, convert_examples_to_features,
                         RawResult, write_predictions,
                         RawResultExtended, write_predictions_extended)

# The following import is the official SQuAD evaluation script (2.0).
# You can remove it from the dependencies if you are using this script outside of the library
# We've added it here for automated tests (see examples/test_examples.py file)
from utils_squad_evaluate import EVAL_OPTS, main as evaluate_on_squad

logger = logging.getLogger(__name__)

# Every pretrained shortcut name known to the supported config classes. It is
# shown in the --model_name_or_path help text, so it must cover every model
# family listed in MODEL_CLASSES below (DistilBERT included).
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys())
                  for conf in (BertConfig, XLNetConfig, XLMConfig, DistilBertConfig)), ())

# Maps the --model_type value (lower-cased in main) to its
# (config class, question-answering model class, tokenizer class).
MODEL_CLASSES = {
    'bert': (BertConfig, BertForQuestionAnswering, BertTokenizer),
    'xlnet': (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
    'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
    'distilbert': (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer),
}


def set_seed(args):
    """Seed the Python, NumPy and PyTorch random generators with ``args.seed``.

    When ``args.n_gpu`` is positive, every CUDA device generator is seeded as well.
    """
    seed = args.seed
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)


def to_list(tensor):
    """Return ``tensor`` as a (possibly nested) Python list, detached from autograd and copied to CPU."""
    host_copy = tensor.detach().cpu()
    return host_copy.tolist()
def train(args, train_dataset, model, tokenizer):
    """ Train the model

    Optimizes ``model`` on ``train_dataset`` with AdamW and a linear
    warmup/decay schedule. It optionally uses apex fp16 (``args.fp16``),
    DataParallel (``args.n_gpu > 1``) or DistributedDataParallel
    (``args.local_rank != -1``). On the main process (local_rank -1 or 0) it
    logs to TensorBoard every ``args.logging_steps`` updates and saves a
    checkpoint under ``args.output_dir`` every ``args.save_steps`` updates.

    Args:
        args: parsed command-line namespace. This function overwrites
            ``args.train_batch_size`` and, when ``args.max_steps > 0``, also
            ``args.num_train_epochs``.
        train_dataset: dataset from ``load_and_cache_examples(..., evaluate=False)``.
        model: question-answering model (main() moves it to ``args.device`` first).
        tokenizer: forwarded to ``evaluate`` for evaluation during training.

    Returns:
        ``(global_step, average_loss)``. ``global_step`` is the number of
        optimizer updates performed. ``average_loss`` is the accumulated
        training loss divided by ``global_step``, or 0.0 when no update ran.
    """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay).
    # Biases and LayerNorm weights are excluded from weight decay.
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                          output_device=args.local_rank,
                                                          find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            # Training batch layout (see load_and_cache_examples with evaluate=False):
            # input_ids, input_mask, segment_ids, start_positions, end_positions, cls_index, p_mask.
            inputs = {'input_ids':       batch[0],
                      'attention_mask':  batch[1],
                      'start_positions': batch[3],
                      'end_positions':   batch[4]}
            if args.model_type != 'distilbert':
                inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2]
            if args.model_type in ['xlnet', 'xlm']:
                inputs.update({'cls_index': batch[5],
                               'p_mask':    batch[6]})
            outputs = model(**inputs)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel (not distributed) training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            # Only step the optimizer once every gradient_accumulation_steps micro-batches.
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
                    tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss) / args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                    logger.info("Saving model checkpoint to %s", output_dir)

            # Stop as soon as max_steps updates have been made. With `>`, one
            # extra update past t_total would be performed.
            if args.max_steps > 0 and global_step >= args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step >= args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    # global_step stays 0 when no update ran, e.g. when --num_train_epochs < 1
    # truncates to zero epochs; avoid dividing by zero in that case.
    average_loss = tr_loss / global_step if global_step > 0 else 0.0
    return global_step, average_loss


def evaluate(args, model, tokenizer, prefix=""):
    """Run ``model`` over ``args.predict_file`` and score it with the official SQuAD script.

    Writes ``predictions_{prefix}.json`` and ``nbest_predictions_{prefix}.json``
    to ``args.output_dir``. When ``args.version_2_with_negative`` is set, it also
    writes ``null_odds_{prefix}.json``.

    Args:
        args: parsed command-line namespace; ``args.eval_batch_size`` is overwritten.
        model: question-answering model. It may already be wrapped in
            ``torch.nn.DataParallel``; ``train()`` passes such a model when ``n_gpu > 1``.
        tokenizer: forwarded to feature building and XLNet/XLM post-processing.
        prefix: suffix for the log header and output file names.

    Returns:
        Whatever ``evaluate_on_squad`` returns. ``train()`` iterates it with
        ``.items()`` as metric name -> value.
    """
    dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True)

    if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
        os.makedirs(args.output_dir)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    # Note that DistributedSampler samples randomly
    eval_sampler = SequentialSampler(dataset) if args.local_rank == -1 else DistributedSampler(dataset)
    eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

    # multi-gpu evaluate. A model coming from train() is already wrapped in
    # DataParallel; wrapping it again would nest two DataParallel layers.
    if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Num examples = %d", len(dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)
    all_results = []
    start_time = timeit.default_timer()
    model.eval()
    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        batch = tuple(t.to(args.device) for t in batch)
        with torch.no_grad():
            # Evaluation batch layout (see load_and_cache_examples with evaluate=True):
            # input_ids, input_mask, segment_ids, feature index, cls_index, p_mask.
            inputs = {'input_ids':      batch[0],
                      'attention_mask': batch[1]}
            if args.model_type != 'distilbert':
                inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2]  # XLM don't use segment_ids
            feature_indices = batch[3]
            if args.model_type in ['xlnet', 'xlm']:
                inputs.update({'cls_index': batch[4],
                               'p_mask':    batch[5]})
            outputs = model(**inputs)

        for i, feature_index in enumerate(feature_indices):
            eval_feature = features[feature_index.item()]
            unique_id = int(eval_feature.unique_id)
            if args.model_type in ['xlnet', 'xlm']:
                # XLNet uses a more complex post-processing procedure
                result = RawResultExtended(unique_id            = unique_id,
                                           start_top_log_probs  = to_list(outputs[0][i]),
                                           start_top_index      = to_list(outputs[1][i]),
                                           end_top_log_probs    = to_list(outputs[2][i]),
                                           end_top_index        = to_list(outputs[3][i]),
                                           cls_logits           = to_list(outputs[4][i]))
            else:
                result = RawResult(unique_id    = unique_id,
                                   start_logits = to_list(outputs[0][i]),
                                   end_logits   = to_list(outputs[1][i]))
            all_results.append(result)

    eval_time = timeit.default_timer() - start_time
    logger.info("  Evaluation done in total %f secs (%f sec per example)", eval_time, eval_time / len(dataset))

    # Compute predictions
    output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix))
    output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix))
    if args.version_2_with_negative:
        output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix))
    else:
        output_null_log_odds_file = None

    if args.model_type in ['xlnet', 'xlm']:
        # XLNet uses a more complex post-processing procedure.
        # A DataParallel wrapper has no `config`; read it from the wrapped module
        # (same unwrapping convention as the checkpoint saving in train()).
        model_config = model.module.config if hasattr(model, 'module') else model.config
        write_predictions_extended(examples, features, all_results, args.n_best_size,
                                   args.max_answer_length, output_prediction_file,
                                   output_nbest_file, output_null_log_odds_file, args.predict_file,
                                   model_config.start_n_top, model_config.end_n_top,
                                   args.version_2_with_negative, tokenizer, args.verbose_logging)
    else:
        write_predictions(examples, features, all_results, args.n_best_size,
                          args.max_answer_length, args.do_lower_case, output_prediction_file,
                          output_nbest_file, output_null_log_odds_file, args.verbose_logging,
                          args.version_2_with_negative, args.null_score_diff_threshold)

    # Evaluate with the official SQuAD script
    evaluate_options = EVAL_OPTS(data_file=args.predict_file,
                                 pred_file=output_prediction_file,
                                 na_prob_file=output_null_log_odds_file)
    results = evaluate_on_squad(evaluate_options)
    return results


def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
    """Build the SQuAD ``TensorDataset`` for training or evaluation, caching features on disk.

    Features are cached next to the input file as
    ``cached_{dev|train}_{last path component of model_name_or_path}_{max_seq_length}``.
    The cache is reused unless ``args.overwrite_cache`` is set or
    ``output_examples`` is requested. The raw examples are not cached, so they
    must be re-read in that case.

    In distributed training, ranks other than 0 wait on a barrier while rank 0
    builds (and saves) the training features. They then load them from the
    cache, unless ``--overwrite_cache`` makes them rebuild.

    Args:
        args: parsed command-line namespace.
        tokenizer: tokenizer passed to ``convert_examples_to_features``.
        evaluate: read ``args.predict_file`` (True) or ``args.train_file`` (False).
        output_examples: also return the parsed examples and features.

    Returns:
        ``dataset``, or ``(dataset, examples, features)`` when ``output_examples``.
        Tensor order:
          - evaluate: input_ids, input_mask, segment_ids, feature index, cls_index, p_mask
          - train: input_ids, input_mask, segment_ids, start_positions, end_positions, cls_index, p_mask
    """
    if args.local_rank not in [-1, 0] and not evaluate:
        torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache

    # Load data features from cache or dataset file
    input_file = args.predict_file if evaluate else args.train_file
    cached_features_file = os.path.join(os.path.dirname(input_file), 'cached_{}_{}_{}'.format(
        'dev' if evaluate else 'train',
        list(filter(None, args.model_name_or_path.split('/'))).pop(),
        str(args.max_seq_length)))
    if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
        logger.info("Loading features from cached file %s", cached_features_file)
        features = torch.load(cached_features_file)
    else:
        logger.info("Creating features from dataset file at %s", input_file)
        examples = read_squad_examples(input_file=input_file,
                                       is_training=not evaluate,
                                       version_2_with_negative=args.version_2_with_negative)
        # XLNet puts the CLS token at the end, uses dedicated segment ids for
        # CLS/padding and places the document before the question.
        is_xlnet = args.model_type in ['xlnet']
        features = convert_examples_to_features(examples=examples,
                                                tokenizer=tokenizer,
                                                max_seq_length=args.max_seq_length,
                                                doc_stride=args.doc_stride,
                                                max_query_length=args.max_query_length,
                                                is_training=not evaluate,
                                                cls_token_segment_id=2 if is_xlnet else 0,
                                                pad_token_segment_id=3 if is_xlnet else 0,
                                                cls_token_at_end=True if is_xlnet else False,
                                                sequence_a_is_doc=True if is_xlnet else False)
        if args.local_rank in [-1, 0]:
            logger.info("Saving features into cached file %s", cached_features_file)
            torch.save(features, cached_features_file)

    if args.local_rank == 0 and not evaluate:
        torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache

    # Convert to Tensors and build dataset
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
    all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
    all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
    if evaluate:
        # Position of each feature in `features`; evaluate() uses it to map
        # model outputs back to their feature.
        all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
        dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                                all_example_index, all_cls_index, all_p_mask)
    else:
        all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
        all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
        dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                                all_start_positions, all_end_positions,
                                all_cls_index, all_p_mask)

    if output_examples:
        return dataset, examples, features
    return dataset


def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
thomwolf's avatar
thomwolf committed
355
356
357
358
    parser.add_argument("--train_file", default=None, type=str, required=True,
                        help="SQuAD json for training. E.g., train-v1.1.json")
    parser.add_argument("--predict_file", default=None, type=str, required=True,
                        help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
359
360
361
362
    parser.add_argument("--model_type", default=None, type=str, required=True,
                        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
    parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
                        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
363
364
365
366
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model checkpoints and predictions will be written.")

    ## Other parameters
367
368
369
370
371
372
373
    parser.add_argument("--config_name", default="", type=str,
                        help="Pretrained config name or path if not the same as model_name")
    parser.add_argument("--tokenizer_name", default="", type=str,
                        help="Pretrained tokenizer name or path if not the same as model_name")
    parser.add_argument("--cache_dir", default="", type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")

thomwolf's avatar
thomwolf committed
374
375
376
377
378
    parser.add_argument('--version_2_with_negative', action='store_true',
                        help='If true, the SQuAD examples contain some that do not have an answer.')
    parser.add_argument('--null_score_diff_threshold', type=float, default=0.0,
                        help="If null_score - best_non_null is greater than the threshold predict null.")

379
380
381
382
383
384
385
386
    parser.add_argument("--max_seq_length", default=384, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. Sequences "
                             "longer than this will be truncated, and sequences shorter than this will be padded.")
    parser.add_argument("--doc_stride", default=128, type=int,
                        help="When splitting up a long document into chunks, how much stride to take between chunks.")
    parser.add_argument("--max_query_length", default=64, type=int,
                        help="The maximum number of tokens for the question. Questions longer than this will "
                             "be truncated to this length.")
thomwolf's avatar
thomwolf committed
387
388
    parser.add_argument("--do_train", action='store_true',
                        help="Whether to run training.")
389
    parser.add_argument("--do_eval", action='store_true',
thomwolf's avatar
thomwolf committed
390
                        help="Whether to run eval on the dev set.")
391
392
    parser.add_argument("--evaluate_during_training", action='store_true',
                        help="Rul evaluation during training at each logging step.")
thomwolf's avatar
thomwolf committed
393
    parser.add_argument("--do_lower_case", action='store_true',
394
                        help="Set this flag if you are using an uncased model.")
thomwolf's avatar
thomwolf committed
395

396
397
398
399
    parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
                        help="Batch size per GPU/CPU for evaluation.")
thomwolf's avatar
thomwolf committed
400
401
402
403
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
404
405
406
407
408
409
    parser.add_argument("--weight_decay", default=0.0, type=float,
                        help="Weight deay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.")
410
411
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
412
413
414
415
    parser.add_argument("--max_steps", default=-1, type=int,
                        help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
    parser.add_argument("--warmup_steps", default=0, type=int,
                        help="Linear warmup over warmup_steps.")
416
    parser.add_argument("--n_best_size", default=20, type=int,
thomwolf's avatar
thomwolf committed
417
                        help="The total number of n-best predictions to generate in the nbest_predictions.json output file.")
418
419
420
421
422
423
    parser.add_argument("--max_answer_length", default=30, type=int,
                        help="The maximum length of an answer that can be generated. This is needed because the start "
                             "and end predictions are not conditioned on one another.")
    parser.add_argument("--verbose_logging", action='store_true',
                        help="If true, all of the warnings related to data processing will be printed. "
                             "A number of warnings are expected for a normal SQuAD evaluation.")
thomwolf's avatar
thomwolf committed
424

425
426
427
428
429
430
    parser.add_argument('--logging_steps', type=int, default=50,
                        help="Log every X updates steps.")
    parser.add_argument('--save_steps', type=int, default=50,
                        help="Save checkpoint every X updates steps.")
    parser.add_argument("--eval_all_checkpoints", action='store_true',
                        help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
thomwolf's avatar
thomwolf committed
431
    parser.add_argument("--no_cuda", action='store_true',
432
                        help="Whether not to use CUDA when available")
433
434
435
436
    parser.add_argument('--overwrite_output_dir', action='store_true',
                        help="Overwrite the content of the output directory")
    parser.add_argument('--overwrite_cache', action='store_true',
                        help="Overwrite the cached training and evaluation sets")
thomwolf's avatar
thomwolf committed
437
    parser.add_argument('--seed', type=int, default=42,
438
                        help="random seed for initialization")
439

thomwolf's avatar
thomwolf committed
440
    parser.add_argument("--local_rank", type=int, default=-1,
441
                        help="local_rank for distributed training on gpus")
thomwolf's avatar
thomwolf committed
442
443
444
445
446
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
    parser.add_argument('--fp16_opt_level', type=str, default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                             "See details at https://nvidia.github.io/apex/amp.html")
447
448
449
450
    parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
    args = parser.parse_args()

thomwolf's avatar
thomwolf committed
451
452
453
    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
        raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))

454
    # Setup distant debugging if needed
455
456
457
458
459
460
461
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

thomwolf's avatar
thomwolf committed
462
    # Setup CUDA, GPU & distributed training
463
464
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
thomwolf's avatar
thomwolf committed
465
466
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
467
468
469
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
thomwolf's avatar
thomwolf committed
470
471
        args.n_gpu = 1
    args.device = device
472

thomwolf's avatar
thomwolf committed
473
    # Setup logging
474
475
476
    logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                        datefmt = '%m/%d/%Y %H:%M:%S',
                        level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
thomwolf's avatar
thomwolf committed
477
    logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
478
                    args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
479

480
481
    # Set seed
    set_seed(args)
482

thomwolf's avatar
thomwolf committed
483
    # Load pretrained model and tokenizer
484
    if args.local_rank not in [-1, 0]:
485
486
        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab

487
    args.model_type = args.model_type.lower()
488
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
thomwolf's avatar
thomwolf committed
489
490
491
492
493
494
495
496
497
    config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
                                          cache_dir=args.cache_dir if args.cache_dir else None)
    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
                                                do_lower_case=args.do_lower_case,
                                                cache_dir=args.cache_dir if args.cache_dir else None)
    model = model_class.from_pretrained(args.model_name_or_path,
                                        from_tf=bool('.ckpt' in args.model_name_or_path),
                                        config=config,
                                        cache_dir=args.cache_dir if args.cache_dir else None)
498
499

    if args.local_rank == 0:
500
        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
501

thomwolf's avatar
thomwolf committed
502
    model.to(args.device)
503

504
505
    logger.info("Training/evaluation parameters %s", args)

Simon Layton's avatar
Simon Layton committed
506
507
508
509
510
511
512
513
514
515
    # Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is set.
    # Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` will
    # remove the need for this code, but it is still valid.
    if args.fp16:
        try:
            import apex
            apex.amp.register_half_function(torch, 'einsum')
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")

thomwolf's avatar
thomwolf committed
516
    # Training
517
    if args.do_train:
518
        train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False)
519
        global_step, tr_loss = train(args, train_dataset, model, tokenizer)
520
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
521

522

thomwolf's avatar
thomwolf committed
523
    # Save the trained model and the tokenizer
Peng Qi's avatar
Peng Qi committed
524
    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
525
526
527
528
529
530
531
532
533
534
        # Create output directory if needed
        if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(args.output_dir)

        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)
535
536

        # Good practice: save your training arguments together with the trained model
537
        torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
538

539
540
        # Load a trained model and vocabulary that you have fine-tuned
        model = model_class.from_pretrained(args.output_dir)
Peng Qi's avatar
Peng Qi committed
541
        tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
542
543
544
        model.to(args.device)


thomwolf's avatar
thomwolf committed
545
    # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
546
547
548
549
550
    results = {}
    if args.do_eval and args.local_rank in [-1, 0]:
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints:
            checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
551
            logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce model loading logs
thomwolf's avatar
thomwolf committed
552

553
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
thomwolf's avatar
thomwolf committed
554

555
        for checkpoint in checkpoints:
thomwolf's avatar
thomwolf committed
556
            # Reload the model
557
558
559
            global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
            model = model_class.from_pretrained(checkpoint)
            model.to(args.device)
thomwolf's avatar
thomwolf committed
560
561

            # Evaluate
562
            result = evaluate(args, model, tokenizer, prefix=global_step)
thomwolf's avatar
thomwolf committed
563

564
565
            result = dict((k + ('_{}'.format(global_step) if global_step else ''), v) for k, v in result.items())
            results.update(result)
thomwolf's avatar
thomwolf committed
566

567
    logger.info("Results: {}".format(results))
thomwolf's avatar
thomwolf committed
568

569
    return results
570
571
572
573


# Script entry point: run main() only when this file is executed directly, not when imported.
# main() returns the evaluation `results` dict; it is intentionally discarded here.
if __name__ == "__main__":
    main()