"scripts/dist_infer/run_wan_t2v_dist_ring.sh" did not exist on "8bc0da34b4d1ae737af9f4e3dd713e76a64de202"
test.py 6.38 KB
Newer Older
chenych's avatar
chenych committed
import os
import argparse

import cv2
import numpy as np
import torch

from datasets import build_dataset, get_coco_api_from_dataset
from datasets.coco_eval import CocoEvaluator
from models import build_test_model
import util.misc as utils

@torch.no_grad()
def test_img(args, model, postprocessors, save_path):
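    """Run inference over the validation set, visualise detections above a
    fixed score threshold, and report COCO evaluation metrics."""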

    # Resolve the device locally instead of relying on a module-level global.
    device = torch.device(args.device)

    dataset_test = build_dataset(
        image_set="val", args=args, eval_in_training_set=False,
    )
    sampler_test = torch.utils.data.SequentialSampler(dataset_test)
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=1,
        sampler=sampler_test,
        drop_last=False,
        collate_fn=utils.collate_fn,
        num_workers=0,
        pin_memory=True,
    )

    base_ds = get_coco_api_from_dataset(dataset_test)

    # Build the evaluator once, before the loop, so that it accumulates
    # predictions over the whole validation set; recreating it per image
    # would reset the statistics every iteration.
    iou_types = tuple(k for k in ("segm", "bbox") if k in postprocessors.keys())
    coco_evaluator = CocoEvaluator(base_ds, iou_types)

    for img_data, targets in data_loader_test:
        img_data = img_data.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        outputs = model(img_data)
        orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
        results = postprocessors['bbox'](outputs, orig_target_sizes)
        if 'segm' in postprocessors.keys():
            target_sizes = torch.stack([t["size"] for t in targets], dim=0)
            results = postprocessors['segm'](results, outputs, orig_target_sizes, target_sizes)

        res = {target['image_id'].item(): output for target, output in zip(targets, results)}
        if coco_evaluator is not None:
            coco_evaluator.update(res)

        # Batch size is 1, so keep the predictions for this single image.
        res = res[targets[0]['image_id'].item()]

        min_score = 0.65
        img_name = dataset_test.coco.loadImgs(targets[0]['image_id'].item())[0]['file_name']
        img = cv2.imread(os.path.join(args.coco_path, 'images/val2017', img_name))

        # Draw every detection above the score threshold, iterating over the
        # actual number of predictions instead of a hard-coded 100.
        draw_img = img.copy()
        save_status = False
        scores = res['scores']
        for i in range(len(scores)):
            score = float(scores[i])
            if score > min_score:
                save_status = True
                label = int(res['labels'][i].cpu().numpy())
                bbox = res['boxes'][i].cpu().numpy().tolist()
                print("***", label, bbox)
                cv2.putText(draw_img, "{} | {:.2f}".format(label, score),
                            (int(bbox[0]), int(bbox[1]) - 2),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
                cv2.rectangle(draw_img, (int(bbox[0]), int(bbox[1])),
                              (int(bbox[2]), int(bbox[3])), (0, 0, 255), 1)
        if save_status:
            cv2.imwrite(os.path.join(save_path, img_name), draw_img)

    if coco_evaluator is not None:
        coco_evaluator.synchronize_between_processes()
        coco_evaluator.accumulate()
        coco_evaluator.summarize()
        print(coco_evaluator)


def get_parser():
    parser = argparse.ArgumentParser("DETR Detector", add_help=False)
    parser.add_argument('--dataset_file', default='coco')
    parser.add_argument('--coco_path', default='/home/datasets/COCO2017', type=str)
    parser.add_argument('--coco_panoptic_path', type=str)
    parser.add_argument('--remove_difficult', action='store_true')
    parser.add_argument("--save_path", default="./result_img", type=str)
    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument("--pre_trained_model", default="")
    parser.add_argument('--seed', default=42, type=int)
    # * Segmentation
    parser.add_argument('--masks', action='store_true',
                        help="Train segmentation head if the flag is provided")
    # * Backbone
    parser.add_argument('--backbone', default='resnet50', type=str,
                        help="Name of the convolutional backbone to use")
    parser.add_argument('--dilation', action='store_true',
                        help="If true, we replace stride with dilation in the last convolutional block (DC5)")
    parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
                        help="Type of positional embedding to use on top of the image features")
    # * Transformer
    parser.add_argument('--enc_layers', default=6, type=int,
                        help="Number of encoding layers in the transformer")
    parser.add_argument('--dec_layers', default=6, type=int,
                        help="Number of decoding layers in the transformer")
    parser.add_argument('--dim_feedforward', default=2048, type=int,
                        help="Intermediate size of the feedforward layers in the transformer blocks")
    parser.add_argument('--hidden_dim', default=256, type=int,
                        help="Size of the embeddings (dimension of the transformer)")
    parser.add_argument('--dropout', default=0.1, type=float,
                        help="Dropout applied in the transformer")
    parser.add_argument('--nheads', default=8, type=int,
                        help="Number of attention heads inside the transformer's attentions")
    parser.add_argument('--num_queries', default=100, type=int,
                        help="Number of query slots")
    parser.add_argument('--pre_norm', action='store_true')
    # Loss
    parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false',
                        help="Disables auxiliary decoding losses (loss at each layer)")
    # * Evaluation options
    parser.add_argument("--eval", action="store_true")
    # evaluate on the training set instead of the validation set
    parser.add_argument("--eval_in_training_set", action="store_true")
    # number of top-scoring detections kept during evaluation
    parser.add_argument("--topk", default=100, type=int)
    # * Training options
    parser.add_argument("--use_fp16", action="store_true")
    parser.add_argument("--use_checkpoint", action="store_true")
    return parser

if __name__ == "__main__":
    args = get_parser().parse_args()
    device = torch.device(args.device)

    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)

    model, postprocessors = build_test_model(args)
    model.to(device)
    checkpoint = torch.load(args.pre_trained_model, map_location='cpu')
    model.load_state_dict(checkpoint["model"], strict=False)

    model.eval()

    test_img(args, model, postprocessors, args.save_path)
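
# Example invocation (all paths below are illustrative; adjust to your setup):
#   python test.py \
#       --coco_path /home/datasets/COCO2017 \
#       --pre_trained_model ./checkpoints/checkpoint.pth \
#       --save_path ./result_img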