run_mmimdb.py 23.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# coding=utf-8
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning the library models for multimodal multiclass prediction on MM-IMDB dataset."""


import argparse
import glob
Aymeric Augustin's avatar
Aymeric Augustin committed
21
import json
22
23
24
25
26
27
import logging
import os
import random

import numpy as np
import torch
Aymeric Augustin's avatar
Aymeric Augustin committed
28
from sklearn.metrics import f1_score
29
from torch import nn
30
31
32
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
33
from utils_mmimdb import ImageEncoder, JsonlDataset, collate_fn, get_image_transforms, get_mmimdb_labels
34

35
import transformers
36
37
from transformers import (
    WEIGHTS_NAME,
Aymeric Augustin's avatar
Aymeric Augustin committed
38
    AdamW,
39
40
41
    AutoConfig,
    AutoModel,
    AutoTokenizer,
Aymeric Augustin's avatar
Aymeric Augustin committed
42
43
44
    MMBTConfig,
    MMBTForClassification,
    get_linear_schedule_with_warmup,
45
)
46
from transformers.trainer_utils import is_main_process
Aymeric Augustin's avatar
Aymeric Augustin committed
47
48
49
50


# Prefer the TensorBoard writer bundled with PyTorch; fall back to the standalone
# tensorboardX package when the former cannot be imported.
try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    from tensorboardX import SummaryWriter


# Module-level logger; its format and level are configured in main() via logging.basicConfig.
logger = logging.getLogger(__name__)


def set_seed(args):
    """Seed the Python, NumPy and PyTorch RNGs from ``args.seed``.

    When ``args.n_gpu`` is positive, every visible CUDA device is seeded as well.
    """
    seed = args.seed
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)


def train(args, train_dataset, model, tokenizer, criterion):
    """Train the model, with per-epoch early stopping on dev-set micro-F1.

    Args:
        args: parsed command-line namespace. ``args.train_batch_size`` is written here,
            and ``args.num_train_epochs`` is overwritten when ``args.max_steps > 0``.
        train_dataset: training dataset, batched with ``collate_fn``.
        model: the classification model, already moved to ``args.device``.
        tokenizer: forwarded to ``evaluate`` (which loads the dev split with it).
        criterion: loss called as ``criterion(logits, labels)``.

    Returns:
        ``(global_step, tr_loss / max(global_step, 1))``: the number of optimizer steps
        taken and the accumulated training loss averaged over those steps.
    """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(
        train_dataset,
        sampler=train_sampler,
        batch_size=args.train_batch_size,
        collate_fn=collate_fn,
        num_workers=args.num_workers,
    )

    # Optimizer updates per epoch; clamped to 1 so a loader shorter than the
    # accumulation window cannot cause a division by zero below.
    updates_per_epoch = max(1, len(train_dataloader) // args.gradient_accumulation_steps)
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // updates_per_epoch + 1
    else:
        t_total = updates_per_epoch * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay).
    # Biases and LayerNorm weights are excluded from weight decay.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )
    if args.fp16:
        try:
            from apex import amp
        except ImportError as err:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") from err
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    best_f1, n_no_improve = 0, 0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            labels = batch[5]
            inputs = {
                "input_ids": batch[0],
                "input_modal": batch[2],
                "attention_mask": batch[1],
                "modal_start_tokens": batch[3],
                "modal_end_tokens": batch[4],
            }
            outputs = model(**inputs)
            logits = outputs[0]  # model outputs are always tuple in transformers (see doc)
            loss = criterion(logits, labels)

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logs = {}
                    # Only evaluate when single GPU otherwise metrics may not average well
                    if args.local_rank == -1 and args.evaluate_during_training:
                        results = evaluate(args, model, tokenizer, criterion)
                        for key, value in results.items():
                            logs["eval_{}".format(key)] = value

                    loss_scalar = (tr_loss - logging_loss) / args.logging_steps
                    # get_last_lr() is the supported accessor; get_lr() is deprecated
                    # outside scheduler.step() and warns on recent PyTorch versions.
                    learning_rate_scalar = scheduler.get_last_lr()[0]
                    logs["learning_rate"] = learning_rate_scalar
                    logs["loss"] = loss_scalar
                    logging_loss = tr_loss

                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    print(json.dumps({**logs, **{"step": global_step}}))

                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    # Take care of distributed/parallel training
                    model_to_save = model.module if hasattr(model, "module") else model
                    torch.save(model_to_save.state_dict(), os.path.join(output_dir, WEIGHTS_NAME))
                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

            # Stop once exactly max_steps optimizer updates (== t_total) have been taken.
            if args.max_steps > 0 and global_step >= args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step >= args.max_steps:
            train_iterator.close()
            break

        # Early stopping on dev micro-F1 (single-process runs only).
        if args.local_rank == -1:
            results = evaluate(args, model, tokenizer, criterion)
            if results["micro_f1"] > best_f1:
                best_f1 = results["micro_f1"]
                n_no_improve = 0
            else:
                n_no_improve += 1

            if n_no_improve > args.patience:
                train_iterator.close()
                break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    # Guard against global_step == 0 (no accumulation window ever completed).
    return global_step, tr_loss / max(global_step, 1)


def evaluate(args, model, tokenizer, criterion, prefix=""):
    """Evaluate ``model`` on the MM-IMDB dev split and write the metrics to disk.

    A label counts as predicted when ``sigmoid(logit) > 0.5``; F1 is computed with
    scikit-learn over all dev examples.

    Args:
        args: parsed command-line namespace; ``args.eval_batch_size`` is written here.
        model: model to evaluate (wrapped in ``nn.DataParallel`` when ``args.n_gpu > 1``).
        tokenizer: passed to ``load_examples`` to build the dev dataset.
        criterion: loss called as ``criterion(logits, labels)``.
        prefix: sub-directory of ``args.output_dir`` receiving ``eval_results.txt``;
            also shown in the log headers.

    Returns:
        dict with ``loss`` (mean per-batch loss), ``macro_f1`` and ``micro_f1``.

    Raises:
        ValueError: if the dev dataset yields no batches.
    """
    eval_output_dir = args.output_dir
    eval_dataset = load_examples(args, tokenizer, evaluate=True)

    if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
        os.makedirs(eval_output_dir)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    # Evaluation walks the whole dev set in order; it is never sharded across processes.
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(
        eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate_fn
    )

    # multi-gpu eval
    if args.n_gpu > 1 and not isinstance(model, nn.DataParallel):
        model = nn.DataParallel(model)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Num examples = %d", len(eval_dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)
    eval_loss = 0.0
    nb_eval_steps = 0
    # Collect per-batch arrays and concatenate once at the end; repeated
    # np.append copies the whole accumulated array every batch (quadratic).
    batch_preds = []
    batch_labels = []
    model.eval()
    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        with torch.no_grad():
            batch = tuple(t.to(args.device) for t in batch)
            labels = batch[5]
            inputs = {
                "input_ids": batch[0],
                "input_modal": batch[2],
                "attention_mask": batch[1],
                "modal_start_tokens": batch[3],
                "modal_end_tokens": batch[4],
            }
            outputs = model(**inputs)
            logits = outputs[0]  # model outputs are always tuple in transformers (see doc)
            tmp_eval_loss = criterion(logits, labels)
            eval_loss += tmp_eval_loss.mean().item()
        nb_eval_steps += 1
        batch_preds.append(torch.sigmoid(logits).detach().cpu().numpy() > 0.5)
        batch_labels.append(labels.detach().cpu().numpy())

    if nb_eval_steps == 0:
        raise ValueError("Evaluation dataset is empty; cannot compute metrics.")

    eval_loss = eval_loss / nb_eval_steps
    preds = np.concatenate(batch_preds, axis=0)
    out_label_ids = np.concatenate(batch_labels, axis=0)
    result = {
        "loss": eval_loss,
        "macro_f1": f1_score(out_label_ids, preds, average="macro"),
        "micro_f1": f1_score(out_label_ids, preds, average="micro"),
    }

    output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
    # The prefix sub-directory may not exist yet.
    os.makedirs(os.path.dirname(output_eval_file), exist_ok=True)
    with open(output_eval_file, "w") as writer:
        logger.info("***** Eval results {} *****".format(prefix))
        for key in sorted(result.keys()):
            logger.info("  %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))

    return result


def load_examples(args, tokenizer, evaluate=False):
    """Build the MM-IMDB ``JsonlDataset`` for one split.

    Args:
        args: namespace providing ``data_dir``, ``max_seq_length`` and ``num_image_embeds``.
        tokenizer: text tokenizer handed to the dataset.
        evaluate: read ``dev.jsonl`` when True, otherwise ``train.jsonl``.

    Returns:
        the constructed ``JsonlDataset``.
    """
    split_file = "dev.jsonl" if evaluate else "train.jsonl"
    # The text budget is max_seq_length minus the image-embedding slots and two more
    # positions (presumably special/modal tokens — confirm against JsonlDataset).
    text_budget = args.max_seq_length - args.num_image_embeds - 2
    return JsonlDataset(
        os.path.join(args.data_dir, split_file),
        tokenizer,
        get_image_transforms(),
        get_mmimdb_labels(),
        text_budget,
    )


def _parse_args():
    """Define this script's command-line interface and parse ``sys.argv``."""
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain the .jsonl files for MMIMDB.",
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )

    # Other parameters
    parser.add_argument(
        "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
    )
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--cache_dir",
        default=None,
        type=str,
        help="Where do you want to store the pre-trained models downloaded from huggingface.co",
    )
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=(
            "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        ),
    )
    parser.add_argument(
        "--num_image_embeds", default=1, type=int, help="Number of Image Embeddings from the Image Encoder"
    )
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step."
    )
    parser.add_argument(
        "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
    )

    parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
    parser.add_argument(
        "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
    )
    parser.add_argument("--patience", default=5, type=int, help="Patience for Early Stopping.")
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
    )
    parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")

    parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
    parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
    parser.add_argument(
        "--eval_all_checkpoints",
        action="store_true",
        help="Evaluate all checkpoints starting with the same prefix as model_name and ending with step number",
    )
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    parser.add_argument("--num_workers", type=int, default=8, help="number of worker threads for dataloading")
    parser.add_argument(
        "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
    )
    parser.add_argument(
        "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
    )
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")

    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help=(
            "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
            "See details at https://nvidia.github.io/apex/amp.html"
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
    parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
    return parser.parse_args()


def main():
    """Fine-tune and/or evaluate MMBT on MM-IMDB according to the command-line flags.

    Returns:
        dict mapping ``"<metric>_<step>"`` to its value for every evaluated checkpoint;
        empty unless ``--do_eval`` is set and this is the main process.

    Raises:
        ValueError: if ``--do_train`` would write into a non-empty output directory
            without ``--overwrite_output_dir``.
    """
    args = _parse_args()

    if (
        os.path.exists(args.output_dir)
        and os.listdir(args.output_dir)
        and args.do_train
        and not args.overwrite_output_dir
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                args.output_dir
            )
        )

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd

        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        args.n_gpu = 1

    args.device = device

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank,
        device,
        args.n_gpu,
        bool(args.local_rank != -1),
        args.fp16,
    )
    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(args.local_rank):
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.enable_default_handler()
        transformers.utils.logging.enable_explicit_format()

    # Set seed
    set_seed(args)

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab

    # Setup model
    labels = get_mmimdb_labels()
    num_labels = len(labels)
    transformer_config = AutoConfig.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir,
    )
    transformer = AutoModel.from_pretrained(
        args.model_name_or_path, config=transformer_config, cache_dir=args.cache_dir
    )
    img_encoder = ImageEncoder(args)
    config = MMBTConfig(transformer_config, num_labels=num_labels)
    model = MMBTForClassification(config, transformer, img_encoder)

    if args.local_rank == 0:
        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab

    model.to(args.device)

    logger.info("Training/evaluation parameters %s", args)

    # Unweighted fallback so --do_eval works without --do_train: the weighted loss
    # below needs label frequencies from the training split.
    criterion = nn.BCEWithLogitsLoss()

    # Training
    if args.do_train:
        train_dataset = load_examples(args, tokenizer, evaluate=False)
        label_frequencies = train_dataset.get_label_frequencies()
        label_frequencies = [label_frequencies[label] for label in labels]
        # pos_weight = inverse of each label's relative frequency in the training set.
        label_weights = (
            torch.tensor(label_frequencies, device=args.device, dtype=torch.float) / len(train_dataset)
        ) ** -1
        criterion = nn.BCEWithLogitsLoss(pos_weight=label_weights)
        global_step, tr_loss = train(args, train_dataset, model, tokenizer, criterion)
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

    # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        # train() only creates output_dir indirectly (via checkpoints/evaluation), which may not
        # have happened, e.g. in distributed runs.
        os.makedirs(args.output_dir, exist_ok=True)
        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        model_to_save = model.module if hasattr(model, "module") else model  # Take care of distributed/parallel training
        torch.save(model_to_save.state_dict(), os.path.join(args.output_dir, WEIGHTS_NAME))
        tokenizer.save_pretrained(args.output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))

        # Load a trained model and vocabulary that you have fine-tuned
        model = MMBTForClassification(config, transformer, img_encoder)
        model.load_state_dict(torch.load(os.path.join(args.output_dir, WEIGHTS_NAME)))
        tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
        model.to(args.device)

    # Evaluation
    results = {}
    if args.do_eval and args.local_rank in [-1, 0]:
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints:
            checkpoints = [
                os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
            ]

        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
            prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
            model = MMBTForClassification(config, transformer, img_encoder)
            # `checkpoint` is a directory; the weights live in WEIGHTS_NAME inside it.
            model.load_state_dict(torch.load(os.path.join(checkpoint, WEIGHTS_NAME)))
            model.to(args.device)
            result = evaluate(args, model, tokenizer, criterion, prefix=prefix)
            result = {k + "_{}".format(global_step): v for k, v in result.items()}
            results.update(result)

    return results


if __name__ == "__main__":
    # Script entry point: parse command-line arguments, then train and/or evaluate.
    main()