Commit 1c6361d8 authored by Qingyun's avatar Qingyun Committed by zhe chen
Browse files

feat: add zero shot instance seg of sam prompted by prediction boxes of...

feat: add zero shot instance seg of sam prompted by prediction boxes of detector based on InternImage
parent f7df4b3c
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import pickle
import shutil
import tempfile
import time
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
import mmcv
from mmcv.image import tensor2imgs
from mmcv.runner import get_dist_info
from mmdet.core import encode_mask_results
def prompt_sam_with_bboxes(sam_predictor, data, box_result):
    """Predict SAM instance masks prompted by detector boxes.

    Manually reproduces ``SamPredictor.set_image`` preprocessing (resize,
    pad, encode) on an already-normalized tensor, then prompts SAM with all
    detector boxes at once via ``predict_torch``.

    Args:
        sam_predictor (SamPredictor): predictor wrapping a loaded SAM model.
        data (dict): one mmdet test-pipeline batch; ``data['img']`` holds the
            normalized image tensor (NCHW) and ``data['img_metas']`` the meta.
        box_result (list[np.ndarray]): per-class detections of shape (n, 5),
            (x1, y1, x2, y2, score), rescaled to original image coordinates.

    Returns:
        list[list[np.ndarray]]: binary masks in original image space,
        regrouped per class so they align with ``box_result``.
    """
    num_classes = len(box_result)
    # Detector boxes in original image space; drop the score column.
    bboxes = np.concatenate(box_result, axis=0)[..., :4]
    if len(bboxes) == 0:
        # No detections: keep the per-class result structure.
        return [[] for _ in range(num_classes)]
    # Class index of each concatenated box, used to regroup masks afterwards.
    labels = np.concatenate(
        [[i] * len(boxes) for i, boxes in enumerate(box_result)])

    img_metas = data['img_metas'][0].data[0][0]
    original_size = img_metas['ori_shape'][:2]

    # Clear any image state cached by a previous call.
    sam_predictor.reset_image()
    # The image has already been normalized by the test pipeline
    # (NOTE: mmdet 2.x normalizes inside the pipeline).
    img = data['img'][0].to(sam_predictor.model.device)
    # Resize the longest side to the ViT input length, keeping aspect ratio
    # (image-encoder limitation, typically 1024).
    target_size = sam_predictor.transform.get_preprocess_shape(
        img.shape[2], img.shape[3], sam_predictor.transform.target_length)
    try:
        # `antialias=True` matches official SAM preprocessing but raises
        # TypeError on older PyTorch versions; fall back without it.
        transformed_img = F.interpolate(
            img, target_size, mode="bilinear",
            align_corners=False, antialias=True)
    except TypeError:
        transformed_img = F.interpolate(
            img, target_size, mode="bilinear", align_corners=False)
    # Bottom/right pad to the square encoder input (e.g. 1024 x 1024).
    h, w = transformed_img.shape[-2:]
    pad_h = sam_predictor.model.image_encoder.img_size - h
    pad_w = sam_predictor.model.image_encoder.img_size - w
    transformed_img = F.pad(transformed_img, (0, pad_w, 0, pad_h))

    # Encode the image once; all box prompts share these features.
    # (The encoder output is already on the model device; the previous
    # trailing `.to(device)` was a no-op and has been removed.)
    sam_predictor.features = sam_predictor.model.image_encoder(transformed_img)
    # Mirror SamPredictor.set_image() bookkeeping so predict_torch() works.
    sam_predictor.original_size = original_size
    sam_predictor.input_size = tuple(transformed_img.shape[-2:])
    sam_predictor.is_image_set = True

    # Map boxes from original image coordinates into the transformed space.
    bboxes_tensor = torch.from_numpy(bboxes).to(sam_predictor.model.device)
    transformed_boxes = sam_predictor.transform.apply_boxes_torch(
        bboxes_tensor, original_size)
    # One (best) mask per box.
    batch_masks, _, _ = sam_predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=False)
    batch_masks = batch_masks.squeeze(1).cpu().numpy()
    # Regroup masks by predicted class to match mmdet's result format.
    mask_results = [[*batch_masks[labels == i]] for i in range(num_classes)]
    return mask_results
def single_gpu_test(model,
                    sam_predictor,
                    data_loader,
                    show=False,
                    out_dir=None,
                    show_score_thr=0.3):
    """Run single-GPU inference, replacing detector masks with SAM masks.

    The detector runs normally; its predicted boxes are then used to prompt
    SAM (see ``prompt_sam_with_bboxes``) and the SAM masks replace the
    detector's own mask branch output.

    Args:
        model (nn.Module): detector wrapped in MMDataParallel.
        sam_predictor (SamPredictor): SAM predictor sharing the model device.
        data_loader (DataLoader): mmdet test loader (batch size 1 expected
            on the SAM path).
        show (bool): display painted results.
        out_dir (str | None): directory for painted result images.
        show_score_thr (float): score threshold for visualization.

    Returns:
        list: per-image ``(bbox_results, encoded_mask_results)`` tuples.
    """
    model.eval()
    results = []
    dataset = data_loader.dataset
    PALETTE = getattr(dataset, 'PALETTE', None)
    prog_bar = mmcv.ProgressBar(len(dataset))
    for i, data in enumerate(data_loader):
        with torch.no_grad():
            # For instance segmentor, only the box results is used in the
            # second stage (prompt sam with box). NOTE the mask_head is still
            # calculated, hence the FPS, FLOPS, maybe not accurate.
            result = model(return_loss=False, rescale=True, **data)
            if getattr(model.module, 'with_mask', False):
                box_result = result[0][0]  # simple_test supported
                mask_result = prompt_sam_with_bboxes(sam_predictor, data,
                                                     box_result)
                result = [(box_result, mask_result)]
            else:
                raise NotImplementedError('WIP!')
        batch_size = len(result)
        if show or out_dir:
            if batch_size == 1 and isinstance(data['img'][0], torch.Tensor):
                img_tensor = data['img'][0]
            else:
                img_tensor = data['img'][0].data[0]
            img_metas = data['img_metas'][0].data[0]
            imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
            assert len(imgs) == len(img_metas)
            # BUGFIX: use a distinct loop index; the original reused `i`,
            # shadowing the dataset-iteration index of the outer loop.
            for img_idx, (img, img_meta) in enumerate(zip(imgs, img_metas)):
                h, w, _ = img_meta['img_shape']
                img_show = img[:h, :w, :]
                ori_h, ori_w = img_meta['ori_shape'][:-1]
                img_show = mmcv.imresize(img_show, (ori_w, ori_h))
                if out_dir:
                    out_file = osp.join(out_dir, img_meta['ori_filename'])
                else:
                    out_file = None
                model.module.show_result(
                    img_show,
                    result[img_idx],
                    bbox_color=PALETTE,
                    text_color=PALETTE,
                    mask_color=PALETTE,
                    show=show,
                    out_file=out_file,
                    score_thr=show_score_thr)
        # encode mask results (RLE) to reduce memory footprint
        if isinstance(result[0], tuple):
            result = [(bbox_results, encode_mask_results(mask_results))
                      for bbox_results, mask_results in result]
        # This logic is only used in panoptic segmentation test.
        elif isinstance(result[0], dict) and 'ins_results' in result[0]:
            for j in range(len(result)):
                bbox_results, mask_results = result[j]['ins_results']
                result[j]['ins_results'] = (bbox_results,
                                            encode_mask_results(mask_results))
        results.extend(result)
        for _ in range(batch_size):
            prog_bar.update()
    return results
# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import argparse
import os
import os.path as osp
import time
import warnings
import mmcv
import torch
from mmcv import Config, DictAction
from mmcv.cnn import fuse_conv_bn
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
wrap_fp16_model)
from mmdet.datasets import (build_dataloader, build_dataset,
replace_ImageToTensor)
from mmdet.models import build_detector
from mmdet.apis import multi_gpu_test
import detection.mmdet_custom # noqa: F401,F403
import detection.mmcv_custom # noqa: F401,F403
from segment_anything import sam_model_registry, SamPredictor
try:
from .engine import single_gpu_test
except ImportError:
from sam.engine import single_gpu_test
def parse_args():
    """Parse CLI arguments for the SAM zero-shot instance-seg evaluation.

    Returns:
        argparse.Namespace: parsed arguments; ``--options`` is folded into
        ``--eval-options`` for backward compatibility.
    """
    parser = argparse.ArgumentParser(
        description='Zero-shot instance segmentation evaluation for '
        'SAM prompted by MMDet detector')
    parser.add_argument('detector_cfg_path',
                        help='test config file path of MMDet detector')
    parser.add_argument('detector_ckpt_path',
                        help='checkpoint file path of MMDet detector')
    # BUGFIX: dropped the dead `default='vit_b'` copy-pasted from --sam_type;
    # positional arguments are always required, so the default never applied.
    parser.add_argument('sam_ckpt_path',
                        help='checkpoint file path of SAM')
    # BUGFIX: help text previously duplicated the detector-config help.
    parser.add_argument('--sam_type', default='vit_b',
                        help='model type of SAM, e.g. vit_b, vit_l or vit_h')
    parser.add_argument('--data_type', default='test', choices=['val', 'test'],
                        help='run val set or test set')
    parser.add_argument(
        '--work-dir',
        help='the directory to save the file containing evaluation metrics')
    parser.add_argument('--out', help='output result file in pickle format')
    parser.add_argument(
        '--fuse-conv-bn',
        action='store_true',
        help='Whether to fuse conv and bn, this will slightly increase'
        'the inference speed')
    parser.add_argument('--gpu-ids',
                        type=int,
                        nargs='+',
                        help='ids of gpus to use '
                        '(only applicable to non-distributed testing)')
    parser.add_argument(
        '--format-only',
        action='store_true',
        help='Format the output results without perform evaluation. It is'
        'useful when you want to format the result to a specific format and '
        'submit it to the test server')
    parser.add_argument(
        '--eval',
        type=str,
        nargs='+',
        help='evaluation metrics, which depends on the dataset, e.g., "bbox",'
        ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC')
    parser.add_argument('--show', action='store_true', help='show results')
    parser.add_argument('--show-dir',
                        help='directory where painted images will be saved')
    parser.add_argument('--show-score-thr',
                        type=float,
                        default=0.3,
                        help='score threshold (default: 0.3)')
    parser.add_argument('--gpu-collect',
                        action='store_true',
                        help='whether to use gpu to collect results.')
    parser.add_argument(
        '--tmpdir',
        help='tmp directory used for collecting results from multiple '
        'workers, available when gpu-collect is not specified')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--options',
        nargs='+',
        action=DictAction,
        help='custom options for evaluation, the key-value pair in xxx=yyy '
        'format will be kwargs for dataset.evaluate() function (deprecate), '
        'change to --eval-options instead.')
    parser.add_argument(
        '--eval-options',
        nargs='+',
        action=DictAction,
        help='custom options for evaluation, the key-value pair in xxx=yyy '
        'format will be kwargs for dataset.evaluate() function')
    parser.add_argument('--launcher',
                        choices=['none', 'pytorch', 'slurm', 'mpi'],
                        default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    # Distributed launchers read LOCAL_RANK from the environment.
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    if args.options and args.eval_options:
        raise ValueError(
            '--options and --eval-options cannot be both '
            'specified, --options is deprecated in favor of --eval-options')
    if args.options:
        warnings.warn('--options is deprecated in favor of --eval-options')
        args.eval_options = args.options
    return args
def main():
    """Evaluate zero-shot instance segmentation: detector boxes prompt SAM.

    Builds the MMDet detector and test loader from the given config, loads
    both detector and SAM checkpoints, runs ``single_gpu_test``, then dumps
    and/or evaluates the results. Only non-distributed testing is supported.
    """
    args = parse_args()
    assert args.out or args.eval or args.format_only or args.show \
        or args.show_dir, \
        ('Please specify at least one operation (save/eval/format/show the '
         'results / save the results) with the argument "--out", "--eval"'
         ', "--format-only", "--show" or "--show-dir"')
    if args.eval and args.format_only:
        raise ValueError('--eval and --format_only cannot be both specified')
    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
        raise ValueError('The output file must be a pkl file.')

    cfg = Config.fromfile(args.detector_cfg_path)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    cfg.model.pretrained = None
    # Disable any rfp_backbone pretraining as well (recursive FPN configs).
    if cfg.model.get('neck'):
        if isinstance(cfg.model.neck, list):
            for neck_cfg in cfg.model.neck:
                if neck_cfg.get('rfp_backbone'):
                    if neck_cfg.rfp_backbone.get('pretrained'):
                        neck_cfg.rfp_backbone.pretrained = None
        elif cfg.model.neck.get('rfp_backbone'):
            if cfg.model.neck.rfp_backbone.get('pretrained'):
                cfg.model.neck.rfp_backbone.pretrained = None

    # in case the test dataset is concatenated
    samples_per_gpu = 1
    if isinstance(cfg.data.test, dict):
        cfg.data.test.test_mode = True
        samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)
        if samples_per_gpu > 1:
            # Replace 'ImageToTensor' to 'DefaultFormatBundle'
            cfg.data.test.pipeline = replace_ImageToTensor(
                cfg.data.test.pipeline)
    elif isinstance(cfg.data.test, list):
        for ds_cfg in cfg.data.test:
            ds_cfg.test_mode = True
        samples_per_gpu = max(
            [ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in cfg.data.test])
        if samples_per_gpu > 1:
            for ds_cfg in cfg.data.test:
                ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)

    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids
    else:
        cfg.gpu_ids = range(1)

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
        if len(cfg.gpu_ids) > 1:
            warnings.warn(
                f'We treat {cfg.gpu_ids} as gpu-ids, and reset to '
                f'{cfg.gpu_ids[0:1]} as gpu-ids to avoid potential error in '
                'non-distribute testing time.')
            cfg.gpu_ids = cfg.gpu_ids[0:1]
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    rank, _ = get_dist_info()
    # allows not to create
    if args.work_dir is not None and rank == 0:
        mmcv.mkdir_or_exist(osp.abspath(args.work_dir))
        timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
        json_file = osp.join(args.work_dir, f'eval_{timestamp}.json')

    # build the dataloader
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(dataset,
                                   samples_per_gpu=samples_per_gpu,
                                   workers_per_gpu=cfg.data.workers_per_gpu,
                                   dist=distributed,
                                   shuffle=False)

    # build the detector and load checkpoint
    cfg.model.train_cfg = None
    model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    checkpoint = load_checkpoint(model, args.detector_ckpt_path,
                                 map_location='cpu')
    if args.fuse_conv_bn:
        model = fuse_conv_bn(model)
    # old versions did not save class info in checkpoints, this workaround is
    # for backward compatibility
    if 'CLASSES' in checkpoint.get('meta', {}):
        model.CLASSES = checkpoint['meta']['CLASSES']
    else:
        model.CLASSES = dataset.CLASSES

    if not distributed:
        model = MMDataParallel(model, device_ids=cfg.gpu_ids)
        # The SamPredictor will be invalid if model is wrapped using
        # MMDistributedDataParallel. A better implementation would not use
        # the provided SamPredictor API.
        sam = sam_model_registry[args.sam_type](
            checkpoint=args.sam_ckpt_path).to(
                list(model.module.parameters())[0].device)
        sam_predictor = SamPredictor(sam)
        outputs = single_gpu_test(model, sam_predictor, data_loader,
                                  args.show, args.show_dir,
                                  args.show_score_thr)
    else:
        # Distributed testing is not wired up for the SAM prompting path yet;
        # the unreachable multi_gpu_test code that followed has been removed.
        raise NotImplementedError

    rank, _ = get_dist_info()
    if rank == 0:
        if args.out:
            print(f'\nwriting results to {args.out}')
            mmcv.dump(outputs, args.out)
        kwargs = {} if args.eval_options is None else args.eval_options
        if args.format_only:
            dataset.format_results(outputs, **kwargs)
        if args.eval:
            eval_kwargs = cfg.get('evaluation', {}).copy()
            # hard-code way to remove EvalHook args
            for key in [
                    'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
                    'rule', 'dynamic_intervals'
            ]:
                eval_kwargs.pop(key, None)
            eval_kwargs.update(dict(metric=args.eval, **kwargs))
            metric = dataset.evaluate(outputs, **eval_kwargs)
            print(metric)
            metric_dict = dict(config=args.detector_cfg_path, metric=metric)
            if args.work_dir is not None and rank == 0:
                mmcv.dump(metric_dict, json_file)


if __name__ == '__main__':
    main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment