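"""Standalone test script for a DETR-style detector.

Runs the model over the COCO 2017 val split, draws detections above a
confidence threshold onto the images, saves the annotated copies, and prints
COCO-style AP metrics via CocoEvaluator.
"""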
import os
import cv2
import torch
import argparse

from datasets import build_dataset, get_coco_api_from_dataset
from models import build_test_model
from datasets.coco_eval import CocoEvaluator

import util.misc as utils

def test_img(args, model, postprocessors, save_path):
    '''Run the detector over the val split, draw high-confidence detections, and report COCO metrics.'''
    device = torch.device(args.device)
    dataset_test = build_dataset(image_set='val', args=args)
    sampler_test = torch.utils.data.SequentialSampler(dataset_test)
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=1,
        sampler=sampler_test,
        drop_last=False,
        collate_fn=utils.collate_fn,
        num_workers=0,
        pin_memory=True,
    )

    base_ds = get_coco_api_from_dataset(dataset_test)
    # Build the evaluator once so it accumulates detections across the whole
    # dataset; recreating it inside the loop would reset it on every image.
    iou_types = tuple(k for k in ("segm", "bbox") if k in postprocessors.keys())
    coco_evaluator = CocoEvaluator(base_ds, iou_types)
    for img_data, targets in data_loader_test:
        img_data = img_data.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        with torch.no_grad():
            outputs = model(img_data)
        orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
        results = postprocessors['bbox'](outputs, orig_target_sizes)
        if 'segm' in postprocessors.keys():
            target_sizes = torch.stack([t["size"] for t in targets], dim=0)
            results = postprocessors['segm'](results, outputs, orig_target_sizes, target_sizes)

        res = {target['image_id'].item(): output for target, output in zip(targets, results)}
        coco_evaluator.update(res)

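        # Batch size is 1, so the dict holds exactly one entry: this image's predictions.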
        res = res[targets[0]['image_id'].item()]

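        # Only detections scoring above this confidence threshold are drawn and saved.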
        min_score = 0.65
        img_name = dataset_test.coco.loadImgs(targets[0]['image_id'].item())[0]['file_name']
        img = cv2.imread(os.path.join(args.coco_path, 'images/val2017', img_name))

        draw_img = img.copy()
        save_status = False
        scores = res['scores']
        for i in range(len(scores)):
            if float(scores[i]) > min_score:
                save_status = True
                score = float(scores[i])
                label = int(res['labels'][i].cpu().numpy())
                bbox = res['boxes'][i].cpu().numpy().tolist()
                print("***", label, bbox)
                cv2.putText(draw_img, "{} | {:.2f}".format(label, score),
                            (int(bbox[0]), int(bbox[1]) - 2), 0, 0.5, (255, 255, 255), 1)
                cv2.rectangle(draw_img, (int(bbox[0]), int(bbox[1])),
                              (int(bbox[2]), int(bbox[3])), (0, 0, 255), 1)
        if save_status:
            cv2.imwrite("{}/{}".format(save_path, img_name), draw_img)

    coco_evaluator.synchronize_between_processes()
    coco_evaluator.accumulate()
    coco_evaluator.summarize()
    print(coco_evaluator)

def get_parser():
    parser = argparse.ArgumentParser("DETR Detector")
    parser.add_argument('--dataset_file', default='coco')
    parser.add_argument('--coco_path', default='/home/datasets/COCO2017', type=str)
    parser.add_argument('--coco_panoptic_path', type=str)
    parser.add_argument('--remove_difficult', action='store_true')
    parser.add_argument("--save_path", default="./result_img", type=str)
    parser.add_argument('--lr_backbone', default=1e-5, type=float)
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument("--pre_trained_model", default="")
    parser.add_argument('--seed', default=42, type=int)
    # * Segmentation
    parser.add_argument('--masks', action='store_true',
                        help="Train segmentation head if the flag is provided")
    # * Backbone
    parser.add_argument('--backbone', default='resnet50', type=str,
                        help="Name of the convolutional backbone to use")
    parser.add_argument('--dilation', action='store_true',
                        help="If true, we replace stride with dilation in the last convolutional block (DC5)")
    parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
                        help="Type of positional embedding to use on top of the image features")
    # * Transformer
    parser.add_argument('--enc_layers', default=6, type=int,
                        help="Number of encoding layers in the transformer")
    parser.add_argument('--dec_layers', default=6, type=int,
                        help="Number of decoding layers in the transformer")
    parser.add_argument('--dim_feedforward', default=2048, type=int,
                        help="Intermediate size of the feedforward layers in the transformer blocks")
    parser.add_argument('--hidden_dim', default=256, type=int,
                        help="Size of the embeddings (dimension of the transformer)")
    parser.add_argument('--dropout', default=0.1, type=float,
                        help="Dropout applied in the transformer")
    parser.add_argument('--nheads', default=8, type=int,
                        help="Number of attention heads inside the transformer's attentions")
    parser.add_argument('--num_queries', default=100, type=int,
                        help="Number of query slots")
    parser.add_argument('--pre_norm', action='store_true')
    # Loss
    parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false',
                        help="Disables auxiliary decoding losses (loss at each layer)")
    return parser

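# Example invocation (paths are illustrative; the checkpoint name is hypothetical):
#   python test.py --coco_path /home/datasets/COCO2017 \
#       --pre_trained_model ./checkpoint.pth --save_path ./result_img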

if __name__ == "__main__":
    args = get_parser().parse_args()
    device = torch.device(args.device)

    os.makedirs(args.save_path, exist_ok=True)

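    # Build the detector and its post-processors, then restore pretrained weights for evaluation.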
    model, postprocessors = build_test_model(args)
    model.to(device)
    checkpoint = torch.load(args.pre_trained_model, map_location='cpu')
    # strict=False: tolerate checkpoint keys that the test model does not define.
    model.load_state_dict(checkpoint["model"], strict=False)
    model.eval()

    test_img(args, model, postprocessors, args.save_path)