# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT finetuning runner."""

from __future__ import absolute_import, division, print_function

import argparse
import glob
import logging
import os
import random

import numpy as np
import torch
from tensorboardX import SummaryWriter
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

from pytorch_transformers import WEIGHTS_NAME
from pytorch_transformers import (BertConfig, BertForSequenceClassification,
                                  BertTokenizer, XLMConfig,
                                  XLMForSequenceClassification, XLMTokenizer,
                                  XLNetConfig, XLNetForSequenceClassification,
                                  XLNetTokenizer)
from pytorch_transformers.optimization import BertAdam
from utils_glue import (compute_metrics, convert_examples_to_features,
                        output_modes, processors)

logger = logging.getLogger(__name__)

# Every pretrained shortcut name known to the supported config classes,
# flattened into one tuple (used only for the --model_name help text).
ALL_MODELS = tuple(
    shortcut_name
    for config_cls in (BertConfig, XLNetConfig, XLMConfig)
    for shortcut_name in config_cls.pretrained_config_archive_map
)

# Model-type key -> (config class, sequence-classification model class, tokenizer class).
# main() picks the first key that appears as a substring of the lower-cased --model_name,
# so the insertion order of these entries matters.
MODEL_CLASSES = dict(
    bert=(BertConfig, BertForSequenceClassification, BertTokenizer),
    xlnet=(XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),
    xlm=(XLMConfig, XLMForSequenceClassification, XLMTokenizer),
)
def train(args, train_dataset, model, tokenizer):
    """Fine-tune ``model`` on ``train_dataset`` and return training statistics.

    Args:
        args: Parsed command-line namespace built by ``main``. Besides reading the
            batch-size, schedule, optimizer, fp16, logging and checkpointing
            options, this function sets ``args.train_batch_size`` and, when
            ``args.max_steps > 0``, overwrites ``args.num_train_epochs``.
        train_dataset: Dataset whose items are indexed as
            ``(input_ids, attention_mask, token_type_ids, labels)``.
        model: Model to train (``main`` may already have wrapped it in
            DataParallel or DistributedDataParallel).
        tokenizer: Passed through to ``evaluate`` for periodic evaluation.

    Returns:
        ``(global_step, avg_loss)``: the number of optimizer updates performed and
        the accumulated training loss divided by that number (``0.0`` when no
        update was performed).
    """
    is_main_process = args.local_rank in [-1, 0]
    if is_main_process:
        tb_writer = SummaryWriter()

    # n_gpu is 0 when running on CPU; DataLoader rejects batch_size=0, so
    # always use at least one per-device batch.
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    if args.max_steps > 0:
        num_train_optimization_steps = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer: biases and LayerNorm weights get no weight decay.
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]
    optimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate,
                         t_total=num_train_optimization_steps, warmup=args.warmup_proportion)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", num_train_optimization_steps)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    optimizer.zero_grad()
    for epoch in trange(int(args.num_train_epochs), desc="Epoch", disable=not is_main_process):
        if isinstance(train_sampler, DistributedSampler):
            # Without set_epoch() a DistributedSampler reuses the same shuffle
            # order in every epoch.
            train_sampler.set_epoch(epoch)
        for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", disable=not is_main_process)):
            # evaluate() may have switched the model to eval mode; switch back.
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {'input_ids':      batch[0],
                      'attention_mask': batch[1],
                      'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None,  # XLM don't use segment_ids
                      'labels':         batch[3]}
            outputs = model(**inputs)
            loss = outputs[0]  # the first element of the model's output tuple is used as the loss

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
                global_step += 1

                if is_main_process and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if args.local_rank == -1:  # Only evaluate on single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer, prefix=global_step)
                        for key, value in results.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
                    tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss) / args.logging_steps, global_step)
                    logging_loss = tr_loss

                if is_main_process and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args, os.path.join(output_dir, 'training_args.bin'))

            # Stop exactly after max_steps optimizer updates (">" would run one extra update).
            if args.max_steps > 0 and global_step >= args.max_steps:
                break
        if args.max_steps > 0 and global_step >= args.max_steps:
            break

    if is_main_process:
        tb_writer.close()

    avg_loss = tr_loss / global_step if global_step > 0 else 0.0
    return global_step, avg_loss


def evaluate(args, model, tokenizer, prefix=""):
    """Evaluate the model on the dev set(s) of ``args.task_name``.

    For MNLI both the matched ("mnli") and mismatched ("mnli-mm") dev sets are
    evaluated; results for the mismatched set are written under
    ``args.output_dir + '-MM'``. Each task's metrics are also written to
    ``eval_results.txt`` in its output directory.

    Args:
        args: Parsed command-line namespace built by ``main``.
        model: Model to evaluate; it is put in eval mode.
        tokenizer: Used to build (or load cached) dev features.
        prefix: Label shown in the "Eval results" log header (e.g. a global step).

    Returns:
        dict mapping metric name to value, as returned by ``compute_metrics``.
        Metrics of later tasks are merged with ``dict.update`` and therefore
        overwrite same-named metrics of earlier tasks.

    Raises:
        ValueError: if a dev set yields no batches.
    """
    # Loop to handle MNLI double evaluation (matched, mis-matched)
    eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
    eval_outputs_dirs = (args.output_dir, args.output_dir + '-MM') if args.task_name == "mnli" else (args.output_dir,)

    results = {}
    for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
        eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)

        if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(eval_output_dir)

        # Note that DistributedSampler samples randomly
        eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
        eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

        # Eval!
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_dataset))
        logger.info("  Batch size = %d", args.eval_batch_size)
        model.eval()
        eval_loss = 0.0
        nb_eval_steps = 0
        # Collect per-batch arrays and concatenate once at the end instead of
        # re-copying the whole buffer with np.append on every batch.
        batch_logits = []
        batch_labels = []
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            batch = tuple(t.to(args.device) for t in batch)

            with torch.no_grad():
                inputs = {'input_ids':      batch[0],
                          'attention_mask': batch[1],
                          'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None,  # XLM don't use segment_ids
                          'labels':         batch[3]}
                outputs = model(**inputs)
                tmp_eval_loss, logits = outputs[:2]

            # mean() averages the per-GPU losses under multi-GPU parallel evaluation
            eval_loss += tmp_eval_loss.mean().item()
            nb_eval_steps += 1
            batch_logits.append(logits.detach().cpu().numpy())
            batch_labels.append(inputs['labels'].detach().cpu().numpy())

        if nb_eval_steps == 0:
            raise ValueError("No evaluation batches for task {} (empty dev set?)".format(eval_task))

        eval_loss = eval_loss / nb_eval_steps
        preds = np.concatenate(batch_logits, axis=0)
        out_label_ids = np.concatenate(batch_labels, axis=0)
        if args.output_mode == "classification":
            preds = np.argmax(preds, axis=1)
        elif args.output_mode == "regression":
            preds = np.squeeze(preds)
        result = compute_metrics(eval_task, preds, out_label_ids)
        results.update(result)

        output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results {} *****".format(prefix))
            logger.info("  eval_loss = %s", eval_loss)
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))

    return results


def load_and_cache_examples(args, task, tokenizer, evaluate=False, overwrite_cache=False):
    """Build, or load from the on-disk cache, the GLUE feature dataset for ``task``.

    Features are cached in ``args.data_dir`` in a file named from the split
    ('dev' if ``evaluate`` else 'train'), the last non-empty '/'-separated
    component of ``args.model_name``, ``args.max_seq_length`` and ``task``.

    Args:
        args: Parsed namespace; reads data_dir, model_name, model_type,
            max_seq_length, overwrite_cache and local_rank.
        task: GLUE task key into ``processors`` and ``output_modes``.
        tokenizer: Tokenizer used to convert examples into features.
        evaluate: Build the dev split instead of the train split.
        overwrite_cache: Rebuild features even if a cache file exists
            (``args.overwrite_cache`` has the same effect).

    Returns:
        TensorDataset of (input_ids, input_mask, segment_ids, label_ids).

    Raises:
        ValueError: if the task's output mode is neither "classification"
            nor "regression".
    """
    # In distributed training only the first process builds the training-set
    # cache; the others wait here, then read the finished file instead of
    # racing with the writer.
    if args.local_rank not in [-1, 0] and not evaluate:
        torch.distributed.barrier()

    processor = processors[task]()
    output_mode = output_modes[task]
    # Load data features from cache or dataset file
    cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}'.format(
        'dev' if evaluate else 'train',
        list(filter(None, args.model_name.split('/'))).pop(),
        str(args.max_seq_length),
        str(task)))
    if os.path.exists(cached_features_file) and not (overwrite_cache or args.overwrite_cache):
        logger.info("Loading features from cached file %s", cached_features_file)
        features = torch.load(cached_features_file)
    else:
        logger.info("Creating features from dataset file at %s", args.data_dir)
        label_list = processor.get_labels()
        examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
        features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode,
            cls_token_at_end=bool(args.model_type in ['xlnet']),            # xlnet has a cls token at the end
            cls_token=tokenizer.cls_token,
            sep_token=tokenizer.sep_token,
            cls_token_segment_id=2 if args.model_type in ['xlnet'] else 1,
            pad_on_left=bool(args.model_type in ['xlnet']),                 # pad on the left for xlnet
            pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0)
        if args.local_rank in [-1, 0]:
            logger.info("Saving features into cached file %s", cached_features_file)
            torch.save(features, cached_features_file)

    # The first process has written the cache: release the processes waiting above.
    if args.local_rank == 0 and not evaluate:
        torch.distributed.barrier()

    # Convert to Tensors and build dataset
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
    if output_mode == "classification":
        all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)
    elif output_mode == "regression":
        all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.float)
    else:
        raise ValueError("Unsupported output mode {!r} for task {}".format(output_mode, task))

    dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
    return dataset
def main():
    """Command-line entry point: fine-tune and/or evaluate a model on a GLUE task.

    Parses the arguments below, sets up devices / distributed training and RNG
    seeds, and loads the config, tokenizer and model selected by ``--model_name``.
    With ``--do_train`` it trains, then saves the model, tokenizer and arguments
    to ``--output_dir``. With ``--do_eval`` it evaluates ``--output_dir`` (or,
    with ``--eval_all_checkpoints``, every directory under it holding saved
    weights).

    Returns:
        dict of evaluation metrics. When several checkpoints are evaluated, each
        key is suffixed with ``_<step>``. The dict is empty when no evaluation ran.

    Raises:
        ValueError: if the output directory is non-empty without
            ``--overwrite_output_dir``, the task is unknown, or no model type
            matches ``--model_name``.
    """
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--data_dir", default=None, type=str, required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--model_name", default=None, type=str, required=True,
                        help="Bert/XLNet/XLM pre-trained model selected in the list: " + ", ".join(ALL_MODELS))
    parser.add_argument("--task_name", default=None, type=str, required=True,
                        help="The name of the task to train selected in the list: " + ", ".join(processors.keys()))
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")

    ## Other parameters
    parser.add_argument("--config_name", default="", type=str,
                        help="Pretrained config name or path if not the same as model_name")
    parser.add_argument("--tokenizer_name", default="", type=str,
                        help="Pretrained tokenizer name or path if not the same as model_name")
    parser.add_argument("--cache_dir", default="", type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument("--max_seq_length", default=128, type=int,
                        help="The maximum total input sequence length after tokenization. Sequences longer "
                             "than this will be truncated, sequences shorter will be padded.")
    parser.add_argument("--do_train", action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")

    parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
                        help="Batch size per GPU for training.")
    parser.add_argument("--eval_batch_size", default=8, type=int,
                        help="Total batch size for eval.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay", default=0.0, type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--max_steps", default=-1, type=int,
                        help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training with linear learning rate warmup (0.1 = 10%% of training).")

    parser.add_argument('--logging_steps', type=int, default=50,
                        help="Log every X updates steps.")
    parser.add_argument('--save_steps', type=int, default=50,
                        help="Save checkpoint every X updates steps.")
    parser.add_argument("--eval_all_checkpoints", action='store_true',
                        help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Avoid using CUDA when available")
    parser.add_argument('--overwrite_output_dir', action='store_true',
                        help="Overwrite the content of the output directory")
    parser.add_argument('--overwrite_cache', action='store_true',
                        help="Overwrite the cached training and evaluation sets")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")

    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
    parser.add_argument('--fp16_opt_level', type=str, default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                             "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="For distributed training: local_rank")
    parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
    args = parser.parse_args()

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
        raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        # With --no_cuda the model runs on CPU, so do not count (or DataParallel over) visible GPUs.
        args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        args.n_gpu = 1
    args.device = device

    # Setup logging
    logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
                   args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)

    # Setup seeds
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # Prepare GLUE task
    args.task_name = args.task_name.lower()
    if args.task_name not in processors:
        raise ValueError("Task not found: %s" % (args.task_name))
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    # Pick the model family from --model_name before any distributed barrier, so
    # an unknown name fails fast on every rank instead of leaving ranks waiting.
    args.model_type = ""
    for key in MODEL_CLASSES:
        if key in args.model_name.lower():
            args.model_type = key  # take the first match in model types
            break
    if not args.model_type:
        raise ValueError("Could not infer the model type from --model_name {!r}; it must contain one of: {}".format(
            args.model_name, ", ".join(MODEL_CLASSES)))
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()

    config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name, num_labels=num_labels, finetuning_task=args.task_name)
    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name, do_lower_case=args.do_lower_case)
    model = model_class.from_pretrained(args.model_name, from_tf=bool('.ckpt' in args.model_name), config=config)

    if args.local_rank == 0:
        # The first process has finished downloading: release the others.
        torch.distributed.barrier()

    # Distributed and parallel training
    # NOTE(review): Apex docs ask for amp.initialize() (called inside train()) to run
    # before the model is wrapped in DistributedDataParallel; here the wrap happens
    # first — confirm --fp16 together with multi-GPU training behaves as intended.
    model.to(args.device)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                          output_device=args.local_rank,
                                                          find_unused_parameters=True)
    elif args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Training
    if args.do_train:
        train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
        global_step, tr_loss = train(args, train_dataset, model, tokenizer)
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

    # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        # Create output directory if needed
        if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(args.output_dir)

        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))

        # Load a trained model and vocabulary that you have fine-tuned
        model = model_class.from_pretrained(args.output_dir)
        tokenizer = tokenizer_class.from_pretrained(args.output_dir)
        model.to(args.device)

    # Evaluation
    results = {}
    if args.do_eval and args.local_rank in [-1, 0]:
        # from_pretrained() is given directories, so every checkpoint here is a directory.
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints:
            checkpoints = [os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True))]
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            # train() names checkpoint directories 'checkpoint-<step>'; the step
            # suffix is only needed to tell several checkpoints apart.
            global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
            model = model_class.from_pretrained(checkpoint)
            model.to(args.device)
            result = evaluate(args, model, tokenizer, prefix=global_step)
            if global_step:
                result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
            results.update(result)

    return results


if __name__ == "__main__":
    # Run the command-line workflow only when executed as a script, not on import.
    main()