run_glue.py 29.1 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
16
""" Finetuning the library models for sequence classification on GLUE (Bert, XLM, XLNet, RoBERTa)."""
thomwolf's avatar
thomwolf committed
17
18
19
20

from __future__ import absolute_import, division, print_function

import argparse
thomwolf's avatar
thomwolf committed
21
import glob
Aymeric Augustin's avatar
Aymeric Augustin committed
22
import json
thomwolf's avatar
thomwolf committed
23
24
25
26
27
28
import logging
import os
import random

import numpy as np
import torch
29
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
thomwolf's avatar
thomwolf committed
30
from torch.utils.data.distributed import DistributedSampler
thomwolf's avatar
thomwolf committed
31
from tqdm import tqdm, trange
thomwolf's avatar
thomwolf committed
32

33
34
from transformers import (
    WEIGHTS_NAME,
Aymeric Augustin's avatar
Aymeric Augustin committed
35
36
37
38
    AdamW,
    AlbertConfig,
    AlbertForSequenceClassification,
    AlbertTokenizer,
39
40
41
    BertConfig,
    BertForSequenceClassification,
    BertTokenizer,
Aymeric Augustin's avatar
Aymeric Augustin committed
42
43
44
    DistilBertConfig,
    DistilBertForSequenceClassification,
    DistilBertTokenizer,
45
46
47
48
49
    RobertaConfig,
    RobertaForSequenceClassification,
    RobertaTokenizer,
    XLMConfig,
    XLMForSequenceClassification,
Aymeric Augustin's avatar
Aymeric Augustin committed
50
51
52
    XLMRobertaConfig,
    XLMRobertaForSequenceClassification,
    XLMRobertaTokenizer,
53
54
55
56
    XLMTokenizer,
    XLNetConfig,
    XLNetForSequenceClassification,
    XLNetTokenizer,
Aymeric Augustin's avatar
Aymeric Augustin committed
57
    get_linear_schedule_with_warmup,
58
)
59
from transformers import glue_compute_metrics as compute_metrics
Aymeric Augustin's avatar
Aymeric Augustin committed
60
from transformers import glue_convert_examples_to_features as convert_examples_to_features
61
62
from transformers import glue_output_modes as output_modes
from transformers import glue_processors as processors
Aymeric Augustin's avatar
Aymeric Augustin committed
63
64
65
66


try:
    from torch.utils.tensorboard import SummaryWriter
67
except ImportError:
Aymeric Augustin's avatar
Aymeric Augustin committed
68
69
    from tensorboardX import SummaryWriter

thomwolf's avatar
thomwolf committed
70
71
72

logger = logging.getLogger(__name__)

73
74
75
76
77
78
79
# Shortcut names of every pretrained checkpoint registered on the supported config classes; used
# to build the ``--model_name_or_path`` help text. Covers the same architectures as MODEL_CLASSES
# (including ALBERT and XLM-RoBERTa) so the help text does not omit any supported model type.
ALL_MODELS = sum(
    (
        tuple(conf.pretrained_config_archive_map.keys())
        for conf in (
            BertConfig,
            XLNetConfig,
            XLMConfig,
            RobertaConfig,
            DistilBertConfig,
            AlbertConfig,
            XLMRobertaConfig,
        )
    ),
    (),
)
80
81

# Registry keyed by the ``--model_type`` command-line value. Each entry is a
# (config class, sequence-classification model class, tokenizer class) triple;
# the ``--model_type`` help text joins these keys in insertion order.
MODEL_CLASSES = {
    "bert": (
        BertConfig,
        BertForSequenceClassification,
        BertTokenizer,
    ),
    "xlnet": (
        XLNetConfig,
        XLNetForSequenceClassification,
        XLNetTokenizer,
    ),
    "xlm": (
        XLMConfig,
        XLMForSequenceClassification,
        XLMTokenizer,
    ),
    "roberta": (
        RobertaConfig,
        RobertaForSequenceClassification,
        RobertaTokenizer,
    ),
    "distilbert": (
        DistilBertConfig,
        DistilBertForSequenceClassification,
        DistilBertTokenizer,
    ),
    "albert": (
        AlbertConfig,
        AlbertForSequenceClassification,
        AlbertTokenizer,
    ),
    "xlmroberta": (
        XLMRobertaConfig,
        XLMRobertaForSequenceClassification,
        XLMRobertaTokenizer,
    ),
}
thomwolf's avatar
thomwolf committed
90

thomwolf's avatar
thomwolf committed
91
92
93
94
95
96
97
98
99

def set_seed(args):
    """Seed Python's ``random``, NumPy and PyTorch with ``args.seed``.

    All CUDA devices are also seeded when ``args.n_gpu > 0``.
    """
    seed = args.seed
    # Same order as always: Python, NumPy, then the PyTorch CPU generator.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)


thomwolf's avatar
thomwolf committed
100
def _global_step_from_checkpoint_path(path):
    """Return the step number of a ``.../checkpoint-<step>`` directory, or None if ``path`` has none."""
    # Only the last path component is inspected, so dashes in parent directories
    # (e.g. ``runs-2/my_model``) are not mistaken for a step number, and local model
    # directories without a numeric suffix yield None instead of raising ValueError.
    last_component = os.path.basename(os.path.normpath(path))
    try:
        return int(last_component.split("-")[-1])
    except ValueError:
        return None


def _save_checkpoint(args, model, tokenizer, optimizer, scheduler, global_step):
    """Save model, tokenizer, args, optimizer and scheduler state to ``<output_dir>/checkpoint-<global_step>``."""
    output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    # Take care of distributed/parallel training: wrappers keep the real model in ``.module``.
    model_to_save = model.module if hasattr(model, "module") else model
    model_to_save.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    torch.save(args, os.path.join(output_dir, "training_args.bin"))
    logger.info("Saving model checkpoint to %s", output_dir)

    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
    logger.info("Saving optimizer and scheduler states to %s", output_dir)


def train(args, train_dataset, model, tokenizer):
    """Train the model.

    Fine-tunes ``model`` with AdamW and a linear warmup/decay schedule. Optional modes:
    apex fp16 (``args.fp16``), ``DataParallel`` (``args.n_gpu > 1``), and
    ``DistributedDataParallel`` (``args.local_rank != -1``).

    Optimizer and scheduler state are restored from ``args.model_name_or_path`` when
    ``optimizer.pt`` and ``scheduler.pt`` exist there. If that path is an existing
    ``.../checkpoint-<step>`` directory, the step counter resumes from ``<step>``.

    Args:
        args: parsed command-line namespace. ``args.train_batch_size`` is set here, and
            ``args.num_train_epochs`` is overwritten when ``args.max_steps > 0``.
        train_dataset: dataset whose batches are indexed as
            ``(input_ids, attention_mask, token_type_ids, labels)``.
        model: the sequence-classification model to train.
        tokenizer: saved alongside every checkpoint.

    Returns:
        ``(global_step, average_loss)``. ``average_loss`` is the accumulated training loss
        divided by ``global_step``, or 0.0 if no optimizer update was performed.
    """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    # Number of optimizer updates performed in one pass over the dataloader.
    num_update_steps_per_epoch = len(train_dataloader) // args.gradient_accumulation_steps
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // num_update_steps_per_epoch + 1
    else:
        t_total = num_update_steps_per_epoch * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
        os.path.join(args.model_name_or_path, "scheduler.pt")
    ):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        checkpoint_step = _global_step_from_checkpoint_path(args.model_name_or_path)
        if checkpoint_step is None:
            # A local model directory that is not a ``checkpoint-<step>`` folder: start from step 0
            # instead of crashing on a non-numeric path component.
            logger.info("  Starting fine-tuning.")
        else:
            # set global_step to global_step of last saved checkpoint from model path
            global_step = checkpoint_step
            epochs_trained = global_step // num_update_steps_per_epoch
            steps_trained_in_current_epoch = global_step % num_update_steps_per_epoch

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d", global_step)
            logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)

    # The skip counter is decremented once per *batch*, whereas global_step counts optimizer
    # updates; every update consumes gradient_accumulation_steps batches.
    batches_to_skip = steps_trained_in_current_epoch * args.gradient_accumulation_steps

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
    )
    set_seed(args)  # Added here for reproductibility (even between python 2 and 3)
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained batches if resuming training
            if batches_to_skip > 0:
                batches_to_skip -= 1
                continue

            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
            if args.model_type != "distilbert":
                inputs["token_type_ids"] = (
                    batch[2] if args.model_type in ["bert", "xlnet"] else None
                )  # XLM, DistilBERT and RoBERTa don't use segment_ids
            outputs = model(**inputs)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logs = {}
                    if (
                        args.local_rank == -1 and args.evaluate_during_training
                    ):  # Only evaluate when not distributed, otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            eval_key = "eval_{}".format(key)
                            logs[eval_key] = value

                    loss_scalar = (tr_loss - logging_loss) / args.logging_steps
                    learning_rate_scalar = scheduler.get_lr()[0]
                    logs["learning_rate"] = learning_rate_scalar
                    logs["loss"] = loss_scalar
                    logging_loss = tr_loss

                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    print(json.dumps({**logs, **{"step": global_step}}))

                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    _save_checkpoint(args, model, tokenizer, optimizer, scheduler, global_step)

            # Stop as soon as max_steps updates have been performed (``>`` ran one extra update).
            if args.max_steps > 0 and global_step >= args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step >= args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    # Avoid ZeroDivisionError when no optimizer update happened (e.g. fewer batches than
    # gradient_accumulation_steps). NOTE(review): after resuming, global_step includes steps
    # taken before the checkpoint while tr_loss only covers this run, so the average is diluted.
    average_loss = tr_loss / global_step if global_step > 0 else 0.0
    return global_step, average_loss


thomwolf's avatar
thomwolf committed
283
def evaluate(args, model, tokenizer, prefix=""):
    """Evaluate ``model`` on the dev split of ``args.task_name`` and return the GLUE metrics.

    For MNLI, both the matched ("mnli") and mismatched ("mnli-mm") dev sets are evaluated. The
    mismatched results are written under ``args.output_dir + "-MM"``. Metrics from every evaluated
    set are merged into one dict; a later set overwrites keys with the same name.

    Args:
        args: parsed command-line namespace (task name, output dir, eval batch size, device,
            n_gpu, model_type, output_mode). ``args.eval_batch_size`` is set here.
        model: model to evaluate. When ``args.n_gpu > 1`` it is wrapped in
            ``torch.nn.DataParallel`` unless it already is one.
        tokenizer: passed to ``load_and_cache_examples`` to build the dev features.
        prefix: label used in log headers and as a sub-directory of the output directory
            for ``eval_results.txt``.

    Returns:
        dict of metric name -> value, as produced by ``compute_metrics``.
    """
    # Loop to handle MNLI double evaluation (matched, mis-matched)
    eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
    eval_outputs_dirs = (args.output_dir, args.output_dir + "-MM") if args.task_name == "mnli" else (args.output_dir,)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)

    # multi-gpu eval. Wrap at most once: during training the model may already be a DataParallel,
    # and wrapping inside the per-task loop (MNLI has two tasks) would nest DataParallel modules.
    if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model)

    results = {}
    for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
        eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)

        if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(eval_output_dir)

        # Note that DistributedSampler samples randomly
        eval_sampler = SequentialSampler(eval_dataset)
        eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

        # Eval!
        logger.info("***** Running evaluation {} *****".format(prefix))
        logger.info("  Num examples = %d", len(eval_dataset))
        logger.info("  Batch size = %d", args.eval_batch_size)
        eval_loss = 0.0
        nb_eval_steps = 0
        # Per-batch arrays, concatenated once after the loop (np.append per batch is quadratic).
        batch_preds = []
        batch_label_ids = []
        model.eval()
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            batch = tuple(t.to(args.device) for t in batch)

            with torch.no_grad():
                inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
                if args.model_type != "distilbert":
                    inputs["token_type_ids"] = (
                        batch[2] if args.model_type in ["bert", "xlnet"] else None
                    )  # XLM, DistilBERT and RoBERTa don't use segment_ids
                outputs = model(**inputs)
                tmp_eval_loss, logits = outputs[:2]

                # .mean() averages the per-GPU losses returned under DataParallel.
                eval_loss += tmp_eval_loss.mean().item()
            nb_eval_steps += 1
            batch_preds.append(logits.detach().cpu().numpy())
            batch_label_ids.append(inputs["labels"].detach().cpu().numpy())

        eval_loss = eval_loss / nb_eval_steps
        preds = np.concatenate(batch_preds, axis=0)
        out_label_ids = np.concatenate(batch_label_ids, axis=0)
        if args.output_mode == "classification":
            preds = np.argmax(preds, axis=1)
        elif args.output_mode == "regression":
            preds = np.squeeze(preds)
        result = compute_metrics(eval_task, preds, out_label_ids)
        results.update(result)

        output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
        # With a non-empty prefix the file lives in a sub-directory that may not exist yet
        # (e.g. "<output_dir>-MM/<prefix>" for MNLI-mm), so create it before opening.
        os.makedirs(os.path.dirname(output_eval_file), exist_ok=True)
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results {} *****".format(prefix))
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))

    return results


thomwolf's avatar
thomwolf committed
352
def load_and_cache_examples(args, task, tokenizer, evaluate=False):
    """Build (or load from cache) the GLUE features for ``task`` and wrap them in a TensorDataset.

    The cache file lives in ``args.data_dir`` and is named
    ``cached_<dev|train>_<last component of model_name_or_path>_<max_seq_length>_<task>``.
    It is reused unless ``args.overwrite_cache`` is set.

    Args:
        args: parsed command-line namespace.
        task: GLUE task name; used as a key into both ``processors`` and ``output_modes``.
        tokenizer: tokenizer used to turn raw examples into features.
        evaluate: load the dev split when True, the train split otherwise.

    Returns:
        TensorDataset of (input_ids, attention_mask, token_type_ids, labels). Labels are int64
        for "classification" tasks and float32 for "regression" tasks.
    """
    building_train_split = not evaluate

    # Distributed training: non-main ranks wait here until rank 0 reaches the matching barrier
    # below (after writing the cache), so only one process builds the training features.
    if args.local_rank not in [-1, 0] and building_train_split:
        torch.distributed.barrier()

    processor = processors[task]()
    output_mode = output_modes[task]

    # Load data features from cache or dataset file
    split_name = "dev" if evaluate else "train"
    model_basename = [part for part in args.model_name_or_path.split("/") if part][-1]
    cache_name = "cached_{}_{}_{}_{}".format(split_name, model_basename, str(args.max_seq_length), str(task))
    cached_features_file = os.path.join(args.data_dir, cache_name)

    use_cache = os.path.exists(cached_features_file) and not args.overwrite_cache
    if use_cache:
        logger.info("Loading features from cached file %s", cached_features_file)
        features = torch.load(cached_features_file)
    else:
        logger.info("Creating features from dataset file at %s", args.data_dir)
        label_list = processor.get_labels()
        if task in ["mnli", "mnli-mm"] and args.model_type in ["roberta", "xlmroberta"]:
            # HACK(label indices are swapped in RoBERTa pretrained model)
            label_list[1], label_list[2] = label_list[2], label_list[1]

        if evaluate:
            examples = processor.get_dev_examples(args.data_dir)
        else:
            examples = processor.get_train_examples(args.data_dir)

        is_xlnet = args.model_type in ["xlnet"]
        features = convert_examples_to_features(
            examples,
            tokenizer,
            label_list=label_list,
            max_length=args.max_seq_length,
            output_mode=output_mode,
            pad_on_left=is_xlnet,  # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
            pad_token_segment_id=4 if is_xlnet else 0,
        )
        # Only the main process writes the cache.
        if args.local_rank in [-1, 0]:
            logger.info("Saving features into cached file %s", cached_features_file)
            torch.save(features, cached_features_file)

    # Rank 0 releases the ranks waiting at the first barrier now that the cache exists.
    if args.local_rank == 0 and building_train_split:
        torch.distributed.barrier()

    # Convert to Tensors and build dataset
    def _long_tensor(field_name):
        return torch.tensor([getattr(feature, field_name) for feature in features], dtype=torch.long)

    all_input_ids = _long_tensor("input_ids")
    all_attention_mask = _long_tensor("attention_mask")
    all_token_type_ids = _long_tensor("token_type_ids")
    if output_mode == "classification":
        label_dtype = torch.long
    elif output_mode == "regression":
        label_dtype = torch.float
    all_labels = torch.tensor([feature.label for feature in features], dtype=label_dtype)

    return TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels)
thomwolf's avatar
thomwolf committed
408
409


thomwolf's avatar
thomwolf committed
410
411
412
def main():
    parser = argparse.ArgumentParser()

413
    # Required parameters
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
    )
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        required=True,
        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
    )
    parser.add_argument(
        "--task_name",
        default=None,
        type=str,
        required=True,
        help="The name of the task to train selected in the list: " + ", ".join(processors.keys()),
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
thomwolf's avatar
thomwolf committed
449

450
    # Other parameters
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
    parser.add_argument(
        "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
    )
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help="Where do you want to store the pre-trained models downloaded from s3",
    )
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help="The maximum total input sequence length after tokenization. Sequences longer "
        "than this will be truncated, sequences shorter will be padded.",
    )
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--evaluate_during_training", action="store_true", help="Rul evaluation during training at each logging step."
    )
    parser.add_argument(
        "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
    )

    parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
    parser.add_argument(
        "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
    )
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
    )
    parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")

    parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
    parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
    parser.add_argument(
        "--eval_all_checkpoints",
        action="store_true",
        help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
    )
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    parser.add_argument(
        "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
    )
    parser.add_argument(
        "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
    )
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")

    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
    parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
thomwolf's avatar
thomwolf committed
538
539
    args = parser.parse_args()

540
541
542
543
544
545
546
547
548
549
550
    if (
        os.path.exists(args.output_dir)
        and os.listdir(args.output_dir)
        and args.do_train
        and not args.overwrite_output_dir
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                args.output_dir
            )
        )
thomwolf's avatar
thomwolf committed
551

thomwolf's avatar
thomwolf committed
552
553
554
555
    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
556

thomwolf's avatar
thomwolf committed
557
558
559
560
561
562
563
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
thomwolf's avatar
thomwolf committed
564
        args.n_gpu = torch.cuda.device_count()
thomwolf's avatar
thomwolf committed
565
566
567
    else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
568
        torch.distributed.init_process_group(backend="nccl")
thomwolf's avatar
thomwolf committed
569
        args.n_gpu = 1
thomwolf's avatar
thomwolf committed
570
571
572
    args.device = device

    # Setup logging
573
574
575
576
577
578
579
580
581
582
583
584
585
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank,
        device,
        args.n_gpu,
        bool(args.local_rank != -1),
        args.fp16,
    )
thomwolf's avatar
thomwolf committed
586

thomwolf's avatar
thomwolf committed
587
588
    # Set seed
    set_seed(args)
thomwolf's avatar
thomwolf committed
589
590

    # Prepare GLUE task
thomwolf's avatar
thomwolf committed
591
592
593
594
595
    args.task_name = args.task_name.lower()
    if args.task_name not in processors:
        raise ValueError("Task not found: %s" % (args.task_name))
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
thomwolf's avatar
thomwolf committed
596
597
598
599
600
    label_list = processor.get_labels()
    num_labels = len(label_list)

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
601
        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
thomwolf's avatar
thomwolf committed
602

603
    args.model_type = args.model_type.lower()
thomwolf's avatar
thomwolf committed
604
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
    config = config_class.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=args.task_name,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
    model = model_class.from_pretrained(
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
thomwolf's avatar
thomwolf committed
622
623

    if args.local_rank == 0:
624
        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
thomwolf's avatar
thomwolf committed
625

thomwolf's avatar
thomwolf committed
626
    model.to(args.device)
thomwolf's avatar
thomwolf committed
627

thomwolf's avatar
thomwolf committed
628
629
    logger.info("Training/evaluation parameters %s", args)

thomwolf's avatar
thomwolf committed
630
    # Training
thomwolf's avatar
thomwolf committed
631
    if args.do_train:
632
        train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
thomwolf's avatar
thomwolf committed
633
        global_step, tr_loss = train(args, train_dataset, model, tokenizer)
thomwolf's avatar
thomwolf committed
634
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
thomwolf's avatar
thomwolf committed
635

thomwolf's avatar
thomwolf committed
636
    # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
637
    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
638
639
640
641
        # Create output directory if needed
        if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(args.output_dir)

thomwolf's avatar
thomwolf committed
642
        logger.info("Saving model checkpoint to %s", args.output_dir)
643
644
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
645
646
647
        model_to_save = (
            model.module if hasattr(model, "module") else model
        )  # Take care of distributed/parallel training
thomwolf's avatar
thomwolf committed
648
        model_to_save.save_pretrained(args.output_dir)
649
        tokenizer.save_pretrained(args.output_dir)
thomwolf's avatar
thomwolf committed
650
651

        # Good practice: save your training arguments together with the trained model
652
        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
thomwolf's avatar
thomwolf committed
653

654
        # Load a trained model and vocabulary that you have fine-tuned
655
        model = model_class.from_pretrained(args.output_dir)
thomwolf's avatar
thomwolf committed
656
        tokenizer = tokenizer_class.from_pretrained(args.output_dir)
657
        model.to(args.device)
thomwolf's avatar
thomwolf committed
658

thomwolf's avatar
thomwolf committed
659
    # Evaluation
thomwolf's avatar
thomwolf committed
660
    results = {}
thomwolf's avatar
thomwolf committed
661
    if args.do_eval and args.local_rank in [-1, 0]:
662
        tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
thomwolf's avatar
thomwolf committed
663
        checkpoints = [args.output_dir]
thomwolf's avatar
thomwolf committed
664
        if args.eval_all_checkpoints:
665
666
667
            checkpoints = list(
                os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
            )
668
            logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce logging
thomwolf's avatar
thomwolf committed
669
670
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
671
672
673
            global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
            prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""

thomwolf's avatar
thomwolf committed
674
            model = model_class.from_pretrained(checkpoint)
thomwolf's avatar
thomwolf committed
675
            model.to(args.device)
676
            result = evaluate(args, model, tokenizer, prefix=prefix)
677
            result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
thomwolf's avatar
thomwolf committed
678
679
            results.update(result)

thomwolf's avatar
thomwolf committed
680
    return results
thomwolf's avatar
thomwolf committed
681
682
683
684


# Entry point: execute the GLUE fine-tuning / evaluation pipeline only when this
# file is run as a script; importing the module does not trigger a run.
if __name__ == "__main__":
    main()