run_glue.py 29.5 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
16
""" Finetuning the library models for sequence classification on GLUE (Bert, XLM, XLNet, RoBERTa, Albert, XLM-RoBERTa)."""
thomwolf's avatar
thomwolf committed
17
18
19


import argparse
thomwolf's avatar
thomwolf committed
20
import glob
Aymeric Augustin's avatar
Aymeric Augustin committed
21
import json
thomwolf's avatar
thomwolf committed
22
23
24
25
26
27
import logging
import os
import random

import numpy as np
import torch
28
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
thomwolf's avatar
thomwolf committed
29
from torch.utils.data.distributed import DistributedSampler
thomwolf's avatar
thomwolf committed
30
from tqdm import tqdm, trange
thomwolf's avatar
thomwolf committed
31

32
33
from transformers import (
    WEIGHTS_NAME,
Aymeric Augustin's avatar
Aymeric Augustin committed
34
35
36
37
    AdamW,
    AlbertConfig,
    AlbertForSequenceClassification,
    AlbertTokenizer,
38
39
40
    BertConfig,
    BertForSequenceClassification,
    BertTokenizer,
Aymeric Augustin's avatar
Aymeric Augustin committed
41
42
43
    DistilBertConfig,
    DistilBertForSequenceClassification,
    DistilBertTokenizer,
Hang Le's avatar
Hang Le committed
44
45
46
    FlaubertConfig,
    FlaubertForSequenceClassification,
    FlaubertTokenizer,
47
48
49
50
51
    RobertaConfig,
    RobertaForSequenceClassification,
    RobertaTokenizer,
    XLMConfig,
    XLMForSequenceClassification,
Aymeric Augustin's avatar
Aymeric Augustin committed
52
53
54
    XLMRobertaConfig,
    XLMRobertaForSequenceClassification,
    XLMRobertaTokenizer,
55
56
57
58
    XLMTokenizer,
    XLNetConfig,
    XLNetForSequenceClassification,
    XLNetTokenizer,
Aymeric Augustin's avatar
Aymeric Augustin committed
59
    get_linear_schedule_with_warmup,
60
)
61
from transformers import glue_compute_metrics as compute_metrics
Aymeric Augustin's avatar
Aymeric Augustin committed
62
from transformers import glue_convert_examples_to_features as convert_examples_to_features
63
64
from transformers import glue_output_modes as output_modes
from transformers import glue_processors as processors
Aymeric Augustin's avatar
Aymeric Augustin committed
65
66
67
68


try:
    from torch.utils.tensorboard import SummaryWriter
69
except ImportError:
Aymeric Augustin's avatar
Aymeric Augustin committed
70
71
    from tensorboardX import SummaryWriter

thomwolf's avatar
thomwolf committed
72
73
74

logger = logging.getLogger(__name__)

# Every pretrained shortcut name exposed by the supported config classes
# (concatenated into one flat tuple); listed in the --model_name_or_path help.
ALL_MODELS = sum(
    (
        tuple(conf.pretrained_config_archive_map.keys())
        for conf in (
            BertConfig,
            XLNetConfig,
            XLMConfig,
            RobertaConfig,
            DistilBertConfig,
            AlbertConfig,
            XLMRobertaConfig,
            FlaubertConfig,
        )
    ),
    (),
)

# --model_type value -> (config class, sequence-classification model class, tokenizer class).
MODEL_CLASSES = {
    "bert": (BertConfig, BertForSequenceClassification, BertTokenizer),
    "xlnet": (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),
    "xlm": (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
    "roberta": (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer),
    "distilbert": (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer),
    "albert": (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer),
    "xlmroberta": (XLMRobertaConfig, XLMRobertaForSequenceClassification, XLMRobertaTokenizer),
    "flaubert": (FlaubertConfig, FlaubertForSequenceClassification, FlaubertTokenizer),
}
thomwolf's avatar
thomwolf committed
102

thomwolf's avatar
thomwolf committed
103
104
105
106
107
108
109
110
111

def set_seed(args):
    """Seed all random number generators used by this script.

    Python's ``random``, NumPy and torch's CPU generator are seeded from
    ``args.seed``; when ``args.n_gpu`` is positive every CUDA device is seeded
    too.
    """
    seed = args.seed
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)


thomwolf's avatar
thomwolf committed
112
def train(args, train_dataset, model, tokenizer):
    """Fine-tune ``model`` on ``train_dataset``.

    Builds a random (or distributed) dataloader and an AdamW optimizer with a
    linear warmup/decay schedule. The model is optionally wrapped for apex fp16,
    ``DataParallel`` or ``DistributedDataParallel``. Training then runs with
    gradient accumulation, and on the main process it also:

    * logs to TensorBoard and prints JSON to stdout every ``args.logging_steps``;
    * evaluates when ``args.evaluate_during_training`` is set (single process
      only);
    * saves model, tokenizer, args, optimizer and scheduler to
      ``args.output_dir/checkpoint-<global_step>`` every ``args.save_steps``.

    If ``args.model_name_or_path`` exists on disk, its trailing ``-<number>``
    (as produced by the checkpoint naming above) is taken as the global step
    to resume from. ``optimizer.pt`` / ``scheduler.pt`` found there are
    reloaded.

    Side effects on ``args``: sets ``train_batch_size``. When ``max_steps > 0``,
    it also overwrites ``num_train_epochs``.

    Returns:
        ``(global_step, average_loss)``. ``average_loss`` is the mean training
        loss per optimizer step taken during this call, or 0.0 if none was taken.

    Raises:
        ValueError: if the dataloader yields fewer batches than
            ``args.gradient_accumulation_steps``, so no optimizer step could
            ever be taken.
    """
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    # Optimizer updates per pass over the dataloader (one update every
    # ``gradient_accumulation_steps`` batches).
    num_update_steps_per_epoch = len(train_dataloader) // args.gradient_accumulation_steps
    if num_update_steps_per_epoch == 0:
        raise ValueError(
            "The training dataloader yields {} batch(es), fewer than gradient_accumulation_steps={}; "
            "no optimizer step would ever be taken.".format(len(train_dataloader), args.gradient_accumulation_steps)
        )

    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // num_update_steps_per_epoch + 1
    else:
        t_total = num_update_steps_per_epoch * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay).
    # Biases and LayerNorm weights are excluded from weight decay.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
        os.path.join(args.model_name_or_path, "scheduler.pt")
    ):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    batches_to_skip = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to global_step of last saved checkpoint from model path
        try:
            global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
        except ValueError:
            global_step = 0

        epochs_trained = global_step // num_update_steps_per_epoch
        # global_step counts optimizer updates, but the skip loop below advances
        # one *batch* at a time, so convert updates to batches.
        batches_to_skip = (global_step % num_update_steps_per_epoch) * args.gradient_accumulation_steps

        logger.info("  Continuing training from checkpoint, will skip to saved global_step")
        logger.info("  Continuing training from epoch %d", epochs_trained)
        logger.info("  Continuing training from global step %d", global_step)
        logger.info("  Will skip the first %d batches in the first epoch", batches_to_skip)

    # Used to average the loss over only the updates performed by this call.
    initial_global_step = global_step

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0],
    )
    set_seed(args)  # Added here for reproductibility
    for epoch in train_iterator:
        # Without set_epoch, DistributedSampler reuses the same shuffle every epoch.
        if isinstance(train_sampler, DistributedSampler):
            train_sampler.set_epoch(epoch)
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained batches if resuming training
            if batches_to_skip > 0:
                batches_to_skip -= 1
                continue

            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
            if args.model_type != "distilbert":
                inputs["token_type_ids"] = (
                    batch[2] if args.model_type in ["bert", "xlnet", "albert"] else None
                )  # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
            outputs = model(**inputs)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logs = {}
                    if (
                        args.local_rank == -1 and args.evaluate_during_training
                    ):  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            eval_key = "eval_{}".format(key)
                            logs[eval_key] = value

                    loss_scalar = (tr_loss - logging_loss) / args.logging_steps
                    # Prefer get_last_lr (newer torch; avoids the warning get_lr emits
                    # outside step()); fall back to get_lr on older versions.
                    if hasattr(scheduler, "get_last_lr"):
                        learning_rate_scalar = scheduler.get_last_lr()[0]
                    else:
                        learning_rate_scalar = scheduler.get_lr()[0]
                    logs["learning_rate"] = learning_rate_scalar
                    logs["loss"] = loss_scalar
                    logging_loss = tr_loss

                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    print(json.dumps({**logs, **{"step": global_step}}))

                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s", output_dir)

            # Stop once exactly max_steps updates are done; ``>`` would run one
            # extra step beyond the scheduled t_total.
            if args.max_steps > 0 and global_step >= args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step >= args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    steps_this_run = global_step - initial_global_step
    average_loss = tr_loss / steps_this_run if steps_this_run > 0 else 0.0
    return global_step, average_loss


thomwolf's avatar
thomwolf committed
298
def evaluate(args, model, tokenizer, prefix=""):
    """Evaluate ``model`` on the dev split of ``args.task_name``.

    For MNLI both the matched (``mnli``) and mismatched (``mnli-mm``) dev sets
    are evaluated. The mismatched results go under ``args.output_dir + "-MM"``.
    For every evaluated task, metrics from ``compute_metrics`` are logged and
    written to ``<eval_output_dir>/<prefix>/eval_results.txt``.

    Args:
        args: parsed command-line namespace (task, batch sizes, device, ...).
            ``args.eval_batch_size`` is set as a side effect.
        model: sequence-classification model. It is wrapped in
            ``torch.nn.DataParallel`` locally when ``args.n_gpu > 1``.
        tokenizer: tokenizer passed to ``load_and_cache_examples``.
        prefix: label used in log headers and joined as a sub-directory of the
            output directory (e.g. a checkpoint name); "" for none.

    Returns:
        dict of metrics from every evaluated task, merged with ``dict.update``.
        A key repeated by a later task replaces the earlier value.

    Raises:
        ValueError: if an evaluation dataset yields no batches.
    """
    # Loop to handle MNLI double evaluation (matched, mis-matched)
    eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
    eval_outputs_dirs = (args.output_dir, args.output_dir + "-MM") if args.task_name == "mnli" else (args.output_dir,)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)

    results = {}
    for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
        eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)

        if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(eval_output_dir)

        # Note that DistributedSampler samples randomly
        eval_sampler = SequentialSampler(eval_dataset)
        eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

        # multi-gpu eval
        if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
            model = torch.nn.DataParallel(model)

        # Eval!
        logger.info("***** Running evaluation {} *****".format(prefix))
        logger.info("  Num examples = %d", len(eval_dataset))
        logger.info("  Batch size = %d", args.eval_batch_size)

        model.eval()
        eval_loss = 0.0
        nb_eval_steps = 0
        # Collect per-batch arrays and concatenate once at the end; repeated
        # np.append would copy the accumulated array on every batch.
        pred_chunks = []
        label_chunks = []
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            batch = tuple(t.to(args.device) for t in batch)

            with torch.no_grad():
                inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
                if args.model_type != "distilbert":
                    inputs["token_type_ids"] = (
                        batch[2] if args.model_type in ["bert", "xlnet", "albert"] else None
                    )  # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
                outputs = model(**inputs)
                tmp_eval_loss, logits = outputs[:2]

                # .mean() reduces the per-device losses returned under DataParallel.
                eval_loss += tmp_eval_loss.mean().item()
            nb_eval_steps += 1
            pred_chunks.append(logits.detach().cpu().numpy())
            label_chunks.append(inputs["labels"].detach().cpu().numpy())

        if nb_eval_steps == 0:
            raise ValueError("Evaluation dataset for task {} is empty.".format(eval_task))

        eval_loss = eval_loss / nb_eval_steps
        preds = np.concatenate(pred_chunks, axis=0)
        out_label_ids = np.concatenate(label_chunks, axis=0)
        if args.output_mode == "classification":
            preds = np.argmax(preds, axis=1)
        elif args.output_mode == "regression":
            preds = np.squeeze(preds)
        result = compute_metrics(eval_task, preds, out_label_ids)
        results.update(result)

        output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
        # ``prefix`` is joined as a path component, and that sub-directory may
        # not exist yet (e.g. under the "-MM" directory created above), so
        # create it before opening the file.
        os.makedirs(os.path.dirname(output_eval_file), exist_ok=True)
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results {} *****".format(prefix))
            logger.info("  Eval loss = %s", str(eval_loss))
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))

    return results


thomwolf's avatar
thomwolf committed
367
def load_and_cache_examples(args, task, tokenizer, evaluate=False):
    """Load GLUE ``task`` features (dev split if ``evaluate`` else train) as a TensorDataset.

    Features are read from
    ``args.data_dir/cached_<dev|train>_<model basename>_<max_seq_length>_<task>``
    when that file exists and ``args.overwrite_cache`` is not set. Otherwise
    they are built from the task processor's examples and, on the main process
    (``local_rank`` -1 or 0), saved to that file.

    When loading the training split under distributed training, non-main ranks
    block on a ``torch.distributed.barrier()`` that the main rank only reaches
    after writing the cache. The other ranks therefore find the cache file when
    they continue.

    Args:
        args: parsed command-line namespace (data_dir, model_name_or_path,
            max_seq_length, model_type, local_rank, overwrite_cache).
        task: GLUE task key into ``processors`` / ``output_modes``.
        tokenizer: tokenizer used to convert examples to features.
        evaluate: load the dev split instead of the train split.

    Returns:
        TensorDataset of (input_ids, attention_mask, token_type_ids, labels).
        Labels are ``torch.long`` for classification and ``torch.float`` for
        regression.

    Raises:
        ValueError: if the task's output mode is neither "classification" nor
            "regression".
    """
    # Resolve the label dtype before any distributed barrier, so an unsupported
    # output mode fails on every rank instead of leaving ranks blocked.
    output_mode = output_modes[task]
    if output_mode == "classification":
        label_dtype = torch.long
    elif output_mode == "regression":
        label_dtype = torch.float
    else:
        raise ValueError("Unsupported output mode {!r} for task {!r}".format(output_mode, task))

    if args.local_rank not in [-1, 0] and not evaluate:
        torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache

    processor = processors[task]()
    # Load data features from cache or dataset file
    cached_features_file = os.path.join(
        args.data_dir,
        "cached_{}_{}_{}_{}".format(
            "dev" if evaluate else "train",
            list(filter(None, args.model_name_or_path.split("/"))).pop(),
            str(args.max_seq_length),
            str(task),
        ),
    )
    if os.path.exists(cached_features_file) and not args.overwrite_cache:
        logger.info("Loading features from cached file %s", cached_features_file)
        # NOTE: torch.load unpickles the file; only point data_dir at trusted data.
        features = torch.load(cached_features_file)
    else:
        logger.info("Creating features from dataset file at %s", args.data_dir)
        label_list = processor.get_labels()
        if task in ["mnli", "mnli-mm"] and args.model_type in ["roberta", "xlmroberta"]:
            # HACK(label indices are swapped in RoBERTa pretrained model)
            label_list[1], label_list[2] = label_list[2], label_list[1]
        examples = (
            processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
        )
        features = convert_examples_to_features(
            examples,
            tokenizer,
            label_list=label_list,
            max_length=args.max_seq_length,
            output_mode=output_mode,
            pad_on_left=bool(args.model_type in ["xlnet"]),  # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
            pad_token_segment_id=4 if args.model_type in ["xlnet"] else 0,
        )
        if args.local_rank in [-1, 0]:
            logger.info("Saving features into cached file %s", cached_features_file)
            torch.save(features, cached_features_file)

    if args.local_rank == 0 and not evaluate:
        torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache

    # Convert to Tensors and build dataset
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
    all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
    all_labels = torch.tensor([f.label for f in features], dtype=label_dtype)

    dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels)
    return dataset
thomwolf's avatar
thomwolf committed
423
424


thomwolf's avatar
thomwolf committed
425
426
427
def main():
    parser = argparse.ArgumentParser()

428
    # Required parameters
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
    )
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        required=True,
        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
    )
    parser.add_argument(
        "--task_name",
        default=None,
        type=str,
        required=True,
        help="The name of the task to train selected in the list: " + ", ".join(processors.keys()),
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
thomwolf's avatar
thomwolf committed
464

465
    # Other parameters
466
    parser.add_argument(
467
        "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name",
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
    )
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help="Where do you want to store the pre-trained models downloaded from s3",
    )
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help="The maximum total input sequence length after tokenization. Sequences longer "
        "than this will be truncated, sequences shorter will be padded.",
    )
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
    parser.add_argument(
Hang Le's avatar
Hang Le committed
491
        "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step.",
492
493
    )
    parser.add_argument(
494
        "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model.",
495
496
497
    )

    parser.add_argument(
498
499
500
501
        "--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.",
    )
    parser.add_argument(
        "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation.",
502
503
504
505
506
507
508
509
510
511
512
513
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
514
        "--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform.",
515
516
517
518
519
520
521
522
523
    )
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
    )
    parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")

524
525
    parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
    parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
526
527
528
529
530
531
532
    parser.add_argument(
        "--eval_all_checkpoints",
        action="store_true",
        help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
    )
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    parser.add_argument(
533
        "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory",
534
535
    )
    parser.add_argument(
536
        "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets",
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
    )
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")

    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
    parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
thomwolf's avatar
thomwolf committed
555
556
    args = parser.parse_args()

557
558
559
560
561
562
563
564
565
566
567
    if (
        os.path.exists(args.output_dir)
        and os.listdir(args.output_dir)
        and args.do_train
        and not args.overwrite_output_dir
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                args.output_dir
            )
        )
thomwolf's avatar
thomwolf committed
568

thomwolf's avatar
thomwolf committed
569
570
571
572
    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
573

thomwolf's avatar
thomwolf committed
574
575
576
577
578
579
580
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
581
        args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
thomwolf's avatar
thomwolf committed
582
583
584
    else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
585
        torch.distributed.init_process_group(backend="nccl")
thomwolf's avatar
thomwolf committed
586
        args.n_gpu = 1
thomwolf's avatar
thomwolf committed
587
588
589
    args.device = device

    # Setup logging
590
591
592
593
594
595
596
597
598
599
600
601
602
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank,
        device,
        args.n_gpu,
        bool(args.local_rank != -1),
        args.fp16,
    )
thomwolf's avatar
thomwolf committed
603

thomwolf's avatar
thomwolf committed
604
605
    # Set seed
    set_seed(args)
thomwolf's avatar
thomwolf committed
606
607

    # Prepare GLUE task
thomwolf's avatar
thomwolf committed
608
609
610
611
612
    args.task_name = args.task_name.lower()
    if args.task_name not in processors:
        raise ValueError("Task not found: %s" % (args.task_name))
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
thomwolf's avatar
thomwolf committed
613
614
615
616
617
    label_list = processor.get_labels()
    num_labels = len(label_list)

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
618
        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
thomwolf's avatar
thomwolf committed
619

620
    args.model_type = args.model_type.lower()
thomwolf's avatar
thomwolf committed
621
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
    config = config_class.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=args.task_name,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
    model = model_class.from_pretrained(
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
thomwolf's avatar
thomwolf committed
639
640

    if args.local_rank == 0:
641
        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
thomwolf's avatar
thomwolf committed
642

thomwolf's avatar
thomwolf committed
643
    model.to(args.device)
thomwolf's avatar
thomwolf committed
644

thomwolf's avatar
thomwolf committed
645
646
    logger.info("Training/evaluation parameters %s", args)

thomwolf's avatar
thomwolf committed
647
    # Training
thomwolf's avatar
thomwolf committed
648
    if args.do_train:
649
        train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
thomwolf's avatar
thomwolf committed
650
        global_step, tr_loss = train(args, train_dataset, model, tokenizer)
thomwolf's avatar
thomwolf committed
651
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
thomwolf's avatar
thomwolf committed
652

thomwolf's avatar
thomwolf committed
653
    # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
654
    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
655
656
657
658
        # Create output directory if needed
        if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(args.output_dir)

thomwolf's avatar
thomwolf committed
659
        logger.info("Saving model checkpoint to %s", args.output_dir)
660
661
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
662
663
664
        model_to_save = (
            model.module if hasattr(model, "module") else model
        )  # Take care of distributed/parallel training
thomwolf's avatar
thomwolf committed
665
        model_to_save.save_pretrained(args.output_dir)
666
        tokenizer.save_pretrained(args.output_dir)
thomwolf's avatar
thomwolf committed
667
668

        # Good practice: save your training arguments together with the trained model
669
        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
thomwolf's avatar
thomwolf committed
670

671
        # Load a trained model and vocabulary that you have fine-tuned
672
        model = model_class.from_pretrained(args.output_dir)
thomwolf's avatar
thomwolf committed
673
        tokenizer = tokenizer_class.from_pretrained(args.output_dir)
674
        model.to(args.device)
thomwolf's avatar
thomwolf committed
675

thomwolf's avatar
thomwolf committed
676
    # Evaluation
thomwolf's avatar
thomwolf committed
677
    results = {}
thomwolf's avatar
thomwolf committed
678
    if args.do_eval and args.local_rank in [-1, 0]:
679
        tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
thomwolf's avatar
thomwolf committed
680
        checkpoints = [args.output_dir]
thomwolf's avatar
thomwolf committed
681
        if args.eval_all_checkpoints:
682
683
684
            checkpoints = list(
                os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
            )
685
            logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce logging
thomwolf's avatar
thomwolf committed
686
687
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
688
689
690
            global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
            prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""

thomwolf's avatar
thomwolf committed
691
            model = model_class.from_pretrained(checkpoint)
thomwolf's avatar
thomwolf committed
692
            model.to(args.device)
693
            result = evaluate(args, model, tokenizer, prefix=prefix)
694
            result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
thomwolf's avatar
thomwolf committed
695
696
            results.update(result)

thomwolf's avatar
thomwolf committed
697
    return results
thomwolf's avatar
thomwolf committed
698
699
700
701


if __name__ == "__main__":
    # Script entry point: run the fine-tuning / evaluation pipeline in main().
    # main() returns the `results` dict built during evaluation; it is not used
    # here, so the process exits normally once main() finishes.
    main()