Unverified commit cbee4278, authored by Tong He, committed by GitHub

[Model] Scene Graph Extraction Model with GluonCV (#1260)



* add working scripts

* add frcnn training script

* remove redundant files

* refactor validation computation, will optimize sgdet and training

* validation finally finished

* f-rcnn training

* test reldn

* rm file

* update reldn training

* data preprocess to h5

* temp

* use coco json

* fix conflict

* new obj dataset for detection

* update training

* before cleanup

* remove redundant files

* add argparse to train

* cleanup code file

* update

* fix

* add readme

* add ipynb as demo

* add demo pic

* update readme

* add demo script

* improve paths

* improve readme

* add docstrings

* fix args description

* update readme

* add models from s3

* update README
Co-authored-by: Minjie Wang <minjie.wang@nyu.edu>
parent ce93330e
import numpy as np
import gluoncv as gcv
from matplotlib import pyplot as plt
def plot_sg(img, preds, obj_classes, rel_classes, topk=1):
    '''Visualize a generated scene graph.

    Each row of `preds` is read as: column 2 the relation class id,
    columns 3 and 4 the subject and object class ids, and columns 5:9
    and 9:13 the normalized subject and object bounding boxes.
    '''
    size = img.shape[0:2]
    box_scale = np.array([size[1], size[0], size[1], size[0]])
    topk = min(topk, preds.shape[0])
    ax = gcv.utils.viz.plot_image(img)
    for i in range(topk):
        rel = int(preds[i, 2])
        src = int(preds[i, 3])
        dst = int(preds[i, 4])
        src_name = obj_classes[src]
        dst_name = obj_classes[dst]
        rel_name = rel_classes[rel]
        # boxes are normalized; scale back to pixel coordinates
        src_bbox = preds[i, 5:9] * box_scale
        dst_bbox = preds[i, 9:13] * box_scale
        src_center = np.array([(src_bbox[0] + src_bbox[2]) / 2, (src_bbox[1] + src_bbox[3]) / 2])
        dst_center = np.array([(dst_bbox[0] + dst_bbox[2]) / 2, (dst_bbox[1] + dst_bbox[3]) / 2])
        rel_center = (src_center + dst_center) / 2
        # connect the two box centers, colored by relation class
        line_x = np.array([src_center[0], dst_center[0]])
        line_y = np.array([src_center[1], dst_center[1]])
        ax.plot(line_x, line_y,
                linewidth=3.0, alpha=0.7, color=plt.cm.cool(rel))
        ax.text(src_center[0], src_center[1], '{:s}'.format(src_name),
                bbox=dict(alpha=0.5), fontsize=12, color='white')
        ax.text(dst_center[0], dst_center[1], '{:s}'.format(dst_name),
                bbox=dict(alpha=0.5), fontsize=12, color='white')
        ax.text(rel_center[0], rel_center[1], '{:s}'.format(rel_name),
                bbox=dict(alpha=0.5), fontsize=12, color='white')
    return ax
# `img`, `preds`, `obj_classes`, and `rel_classes` come from the demo pipeline above.
plot_sg(img, preds, obj_classes, rel_classes, topk=2)
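The validation script below builds two Faster R-CNN detectors on top of ResNet-101 (one for object detections, one for relation-box features), loads a pre-trained RelDN head, and evaluates it on the Visual Genome validation split under the protocol chosen with --metric: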
import dgl
import mxnet as mx
import numpy as np
import logging, time, argparse
from mxnet import nd, gluon
from gluoncv.data.batchify import Pad
from model import faster_rcnn_resnet101_v1d_custom, RelDN
from utils import *
from data import *
def parse_args():
    parser = argparse.ArgumentParser(description='Validate Pre-trained RelDN Model.')
    parser.add_argument('--gpus', type=str, default='0',
                        help="GPUs to run on; specify e.g. '1,3' to use multiple GPUs.")
    parser.add_argument('--batch-size', type=int, default=8,
                        help="Total batch size across all devices.")
    parser.add_argument('--metric', type=str, default='sgdet',
                        help="Evaluation metric, one of 'predcls', 'phrcls', 'sgdet' or 'sgdet+'.")
    parser.add_argument('--pretrained-faster-rcnn-params', type=str, required=True,
                        help="Path to saved Faster R-CNN detector parameters.")
    parser.add_argument('--reldn-params', type=str, required=True,
                        help="Path to saved RelDN model parameters.")
    parser.add_argument('--faster-rcnn-params', type=str, required=True,
                        help="Path to saved Faster R-CNN feature-network parameters, "
                             "loaded into the edge-feature detector.")
    parser.add_argument('--log-dir', type=str, default='reldn_output.log',
                        help="Path of the output log file.")
    parser.add_argument('--freq-prior', type=str, default='freq_prior.pkl',
                        help="Path to saved frequency prior data.")
    parser.add_argument('--verbose-freq', type=int, default=100,
                        help="Frequency of log printing, in number of iterations.")
    args = parser.parse_args()
    return args
args = parse_args()
filehandler = logging.FileHandler(args.log_dir)
streamhandler = logging.StreamHandler()
logger = logging.getLogger('')
logger.setLevel(logging.INFO)
logger.addHandler(filehandler)
logger.addHandler(streamhandler)
# Hyperparams
ctx = [mx.gpu(int(i)) for i in args.gpus.split(',') if i.strip()]
if ctx:
    num_gpus = len(ctx)
    assert args.batch_size % num_gpus == 0
    per_device_batch_size = int(args.batch_size / num_gpus)
else:
    ctx = [mx.cpu()]
    num_gpus = 1  # count the CPU context as a single device for the data loader
    per_device_batch_size = args.batch_size
batch_size = args.batch_size
N_relations = 50
N_objects = 150
batch_verbose_freq = args.verbose_freq
mode = args.metric
metric_list = []
topk_list = [20, 50, 100]
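# Metric protocols (standard Visual Genome evaluation; they match the branches
# in the main loop below): 'predcls' scores predicates given ground-truth boxes
# and object labels, 'phrcls' given ground-truth boxes with predicted labels,
# and 'sgdet'/'sgdet+' on fully detected boxes and labels.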
if mode == 'predcls':
    for topk in topk_list:
        metric_list.append(PredCls(topk=topk))
elif mode == 'phrcls':
    for topk in topk_list:
        metric_list.append(PhrCls(topk=topk))
elif mode == 'sgdet':
    for topk in topk_list:
        metric_list.append(SGDet(topk=topk))
elif mode == 'sgdet+':
    for topk in topk_list:
        metric_list.append(SGDetPlus(topk=topk))
for metric in metric_list:
    metric.reset()
semantic_only = False
net = RelDN(n_classes=N_relations, prior_pkl=args.freq_prior,
            semantic_only=semantic_only)
net.load_parameters(args.reldn_params, ctx=ctx)
# dataset and dataloader
vg_val = VGRelation(split='val')
logger.info('data loaded!')
val_data = gluon.data.DataLoader(vg_val, batch_size=len(ctx), shuffle=False,
                                 num_workers=16 * num_gpus,
                                 batchify_fn=dgl_mp_batchify_fn)
n_batches = len(val_data)
# two detectors: one produces object detections, the other extracts
# features for the relation (union) boxes
detector = faster_rcnn_resnet101_v1d_custom(classes=vg_val.obj_classes,
                                            pretrained_base=False, pretrained=False,
                                            additional_output=True)
params_path = args.pretrained_faster_rcnn_params
detector.load_parameters(params_path, ctx=ctx, ignore_extra=True, allow_missing=True)
detector_feat = faster_rcnn_resnet101_v1d_custom(classes=vg_val.obj_classes,
                                                 pretrained_base=False, pretrained=False,
                                                 additional_output=True)
detector_feat.load_parameters(params_path, ctx=ctx, ignore_extra=True, allow_missing=True)
detector_feat.features.load_parameters(args.faster_rcnn_params, ctx=ctx)
def get_data_batch(g_list, img_list, ctx_list):
    if g_list is None or len(g_list) == 0:
        return None, None
    n_gpu = len(ctx_list)
    size = len(g_list)
    if size < n_gpu:
        raise Exception("too small batch")
    step = size // n_gpu
    # split the batch into one slice per context; the last slice takes the remainder
    G_list = [g_list[i*step:(i+1)*step] if i < n_gpu - 1 else g_list[i*step:size]
              for i in range(n_gpu)]
    img_list = [img_list[i*step:(i+1)*step] if i < n_gpu - 1 else img_list[i*step:size]
                for i in range(n_gpu)]
    # move graph features and images to their target contexts
    for G_slice, ctx in zip(G_list, ctx_list):
        for G in G_slice:
            G.ndata['bbox'] = G.ndata['bbox'].as_in_context(ctx)
            G.ndata['node_class'] = G.ndata['node_class'].as_in_context(ctx)
            G.ndata['node_class_vec'] = G.ndata['node_class_vec'].as_in_context(ctx)
            G.edata['rel_class'] = G.edata['rel_class'].as_in_context(ctx)
    img_list = [img.as_in_context(context) for img, context in zip(img_list, ctx_list)]
    return G_list, img_list
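# Example of the slicing above (hypothetical sizes): 6 graphs over 4 contexts
# give step = 6 // 4 = 1 and slices [0:1], [1:2], [2:3], [3:6]; the last
# context absorbs the remainder.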
for i, (G_list, img_list) in enumerate(val_data):
    G_list, img_list = get_data_batch(G_list, img_list, ctx)
    if G_list is None or img_list is None:
        if (i+1) % batch_verbose_freq == 0:
            print_txt = 'Batch[%d/%d] ' % (i, n_batches)
            for metric in metric_list:
                metric_name, metric_val = metric.get()
                print_txt += '%s=%.4f ' % (metric_name, metric_val)
            logger.info(print_txt)
        continue

    G_batch = []
    bbox_pad = Pad(axis=0)
    for G_slice, img in zip(G_list, img_list):
        cur_ctx = img.context
        if mode == 'predcls':
            # ground-truth boxes and ground-truth object labels
            bbox_list = [G.ndata['bbox'] for G in G_slice]
            bbox_stack = bbox_pad(bbox_list).as_in_context(cur_ctx)
            ids, scores, bbox, spatial_feat = detector(img, None, None, bbox_stack)
            node_class_list = [G.ndata['node_class'] for G in G_slice]
            node_class_stack = bbox_pad(node_class_list).as_in_context(cur_ctx)
            g_pred_batch = build_graph_validate_gt_obj(img, node_class_stack, bbox, spatial_feat,
                                                       bbox_improvement=True, overlap=False)
        elif mode == 'phrcls':
            # ground-truth boxes, predicted object labels
            bbox_list = [G.ndata['bbox'] for G in G_slice]
            bbox_stack = bbox_pad(bbox_list).as_in_context(cur_ctx)
            ids, scores, bbox, spatial_feat = detector(img, None, None, bbox_stack)
            g_pred_batch = build_graph_validate_gt_bbox(img, ids, scores, bbox, spatial_feat,
                                                        bbox_improvement=True, overlap=False)
        else:
            # fully predicted boxes and labels
            ids, scores, bbox, feat, feat_ind, spatial_feat = detector(img)
            g_pred_batch = build_graph_validate_pred(img, ids, scores, bbox, feat_ind, spatial_feat,
                                                     bbox_improvement=True, scores_top_k=75, overlap=False)
        if not semantic_only:
            # extract visual features for the relation (union) boxes
            rel_bbox = g_pred_batch.edata['rel_bbox']
            batch_id = g_pred_batch.edata['batch_id'].asnumpy()
            n_graph = len(G_slice)
            bbox_rel_list = []
            for j in range(n_graph):
                eids = np.where(batch_id == j)[0]
                if len(eids) > 0:
                    bbox_rel_list.append(rel_bbox[eids])
            bbox_rel_stack = bbox_pad(bbox_rel_list).as_in_context(cur_ctx)
            _, _, _, spatial_feat_rel = detector_feat(img, None, None, bbox_rel_stack)
            spatial_feat_rel_list = []
            for j in range(n_graph):
                eids = np.where(batch_id == j)[0]
                if len(eids) > 0:
                    spatial_feat_rel_list.append(spatial_feat_rel[j, 0:len(eids)])
            g_pred_batch.edata['edge_feat'] = nd.concat(*spatial_feat_rel_list, dim=0)
        G_batch.append(g_pred_batch)

    # score relations with the RelDN head
    G_batch = [net(G) for G in G_batch]

    for G_slice, G_pred, img_slice in zip(G_list, G_batch, img_list):
        for G_gt, G_pred_one in zip(G_slice, [G_pred]):
            if G_pred_one is None or G_pred_one.number_of_nodes() == 0:
                continue
            gt_objects, gt_triplet = extract_gt(G_gt, img_slice.shape[2:4])
            pred_objects, pred_triplet = extract_pred(G_pred_one, joint_preds=True)
            for metric in metric_list:
                if isinstance(metric, (PredCls, PhrCls, SGDet)):
                    metric.update(gt_triplet, pred_triplet)
                else:
                    metric.update((gt_objects, gt_triplet), (pred_objects, pred_triplet))

    if (i+1) % batch_verbose_freq == 0:
        print_txt = 'Batch[%d/%d] ' % (i, n_batches)
        for metric in metric_list:
            metric_name, metric_val = metric.get()
            print_txt += '%s=%.4f ' % (metric_name, metric_val)
        logger.info(print_txt)

print_txt = 'Batch[%d/%d] ' % (n_batches, n_batches)
for metric in metric_list:
    metric_name, metric_val = metric.get()
    print_txt += '%s=%.4f ' % (metric_name, metric_val)
logger.info(print_txt)
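Example invocation, assuming the parameter files produced by the training scripts (or downloaded with the released models) are available locally: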
MXNET_CUDNN_AUTOTUNE_DEFAULT=0 python validate_reldn.py \
--pretrained-faster-rcnn-params faster_rcnn_resnet101_v1d_visualgenome/faster_rcnn_resnet101_v1d_custom_best.params \
--reldn-params params_resnet101_v1d_reldn/model-8.params \
--faster-rcnn-params params_resnet101_v1d_reldn/detector_feat.features-8.params
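Setting MXNET_CUDNN_AUTOTUNE_DEFAULT=0 disables cuDNN convolution auto-tuning, which MXNet would otherwise re-run for every new input shape; since the detector processes variable-sized images, this typically saves noticeable time during validation.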