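# Standalone DETR evaluation script: runs inference on the COCO val split,
# draws detections above a score threshold, and reports COCO mAP.
# Example invocation (paths below are placeholders, adjust to your setup):
#   python test.py --coco_path /home/datasets/COCO2017 --pre_trained_model checkpoint.pth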
import os
import cv2
import torch
import argparse

from datasets import build_dataset, get_coco_api_from_dataset
from models import build_test_model
from datasets.coco_eval import CocoEvaluator

import util.misc as utils

def test_img(args, model, postprocessors, save_path):
    '''Run inference on the val split, save visualizations, and report COCO metrics.'''
    device = torch.device(args.device)
    dataset_test = build_dataset(image_set='val', args=args)
    sampler_test = torch.utils.data.SequentialSampler(dataset_test)
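    # Batch size 1 keeps a 1:1 mapping between images and postprocessed outputs.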
    data_loader_test = torch.utils.data.DataLoader(
            dataset_test,
            batch_size=1,
            sampler=sampler_test,
            drop_last=False,
            collate_fn=utils.collate_fn,
            num_workers=0,
            pin_memory=True,)

    base_ds = get_coco_api_from_dataset(dataset_test)
    # Build the evaluator once, outside the loop, so results accumulate across all images.
    iou_types = tuple(k for k in ("segm", "bbox") if k in postprocessors.keys())
    coco_evaluator = CocoEvaluator(base_ds, iou_types)
    for img_data, targets in data_loader_test:
        img_data = img_data.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        with torch.no_grad():
            outputs = model(img_data)
        orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
        results = postprocessors['bbox'](outputs, orig_target_sizes)
        if 'segm' in postprocessors.keys():
            target_sizes = torch.stack([t["size"] for t in targets], dim=0)
            results = postprocessors['segm'](results, outputs, orig_target_sizes, target_sizes)

        res = {target['image_id'].item(): output for target, output in zip(targets, results)}
        if coco_evaluator is not None:
            coco_evaluator.update(res)

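        # Batch size is 1, so unwrap this image's predictions for visualization.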
        res = res[targets[0]['image_id'].item()]

        min_score = 0.65  # confidence threshold for drawing and saving detections
        img_name = dataset_test.coco.loadImgs(targets[0]['image_id'].item())[0]['file_name']
        img = cv2.imread(os.path.join(args.coco_path, 'val2017', img_name))

        draw_img = img.copy()
        save_status = False
        scores = res['scores']
        for i in range(len(scores)):  # one score per query slot
            if float(scores[i]) > min_score:
                save_status = True
                score = float(scores[i])
                label = int(res['labels'][i].cpu().numpy())
                bbox = res['boxes'][i].cpu().numpy().tolist()
                print("***", label, bbox)
                cv2.putText(draw_img, "{} | {:.2f}".format(label, score), (int(bbox[0]), int(bbox[1]) - 2),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
                cv2.rectangle(draw_img, (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3])), (0, 0, 255), 1)
        if save_status:
            cv2.imwrite(os.path.join(save_path, img_name), draw_img)

    if coco_evaluator is not None:
        coco_evaluator.synchronize_between_processes()
        coco_evaluator.accumulate()
        coco_evaluator.summarize()
        print(coco_evaluator)

def get_parser():
    parser = argparse.ArgumentParser("DETR Detector", add_help=False)
    parser.add_argument('--dataset_file', default='coco')
    parser.add_argument('--coco_path', default='/home/datasets/COCO2017', type=str)
    parser.add_argument('--coco_panoptic_path', type=str)
    parser.add_argument('--remove_difficult', action='store_true')
    parser.add_argument("--save_path", default="./result_img", type=str)
    parser.add_argument('--lr_backbone', default=1e-5, type=float)
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument("--pre_trained_model", default="")
    parser.add_argument('--seed', default=42, type=int)
    # * Segmentation
    parser.add_argument('--masks', action='store_true',
                        help="Train segmentation head if the flag is provided")
    # * Backbone
    parser.add_argument('--backbone', default='resnet50', type=str,
                        help="Name of the convolutional backbone to use")
    parser.add_argument('--dilation', action='store_true',
                        help="If true, we replace stride with dilation in the last convolutional block (DC5)")
    parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
                        help="Type of positional embedding to use on top of the image features")
    # * Transformer
    parser.add_argument('--enc_layers', default=6, type=int,
                        help="Number of encoding layers in the transformer")
    parser.add_argument('--dec_layers', default=6, type=int,
                        help="Number of decoding layers in the transformer")
    parser.add_argument('--dim_feedforward', default=2048, type=int,
                        help="Intermediate size of the feedforward layers in the transformer blocks")
    parser.add_argument('--hidden_dim', default=256, type=int,
                        help="Size of the embeddings (dimension of the transformer)")
    parser.add_argument('--dropout', default=0.1, type=float,
                        help="Dropout applied in the transformer")
    parser.add_argument('--nheads', default=8, type=int,
                        help="Number of attention heads inside the transformer's attentions")
    parser.add_argument('--num_queries', default=100, type=int,
                        help="Number of query slots")
    parser.add_argument('--pre_norm', action='store_true')
    # Loss
    parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false',
                        help="Disables auxiliary decoding losses (loss at each layer)")
    return parser


if __name__ == "__main__":
    args = get_parser().parse_args()
    device = torch.device(args.device)

    os.makedirs(args.save_path, exist_ok=True)

    model, postprocessors = build_test_model(args)
    model.to(device)
    checkpoint = torch.load(args.pre_trained_model, map_location='cpu')
    # strict=False tolerates missing/unexpected keys in the checkpoint.
    model.load_state_dict(checkpoint["model"], strict=False)
    model.eval()

    test_img(args, model, postprocessors, args.save_path)