Unverified Commit cbee4278 authored by Tong He, committed by GitHub

[Model] Scene Graph Extraction Model with GluonCV (#1260)



* add working scripts

* add frcnn training script

* remove redundant files

* refactor validation computation, will optimize sgdet and training

* validation finally finished

* f-rcnn training

* test reldn

* rm file

* update reldn training

* data preprocess to h5

* temp

* use coco json

* fix conflict

* new obj dataset for detection

* update training

* before cleanup

* remove abundant files

* add arg parse to train

* cleanup code file

* update

* fix

* add readme

* add ipynb as demo

* add demo pic

* update readme

* add demo script

* improve paths

* improve readme

* add docstrings

* fix args description

* update readme

* add models from s3

* update README
Co-authored-by: Minjie Wang <minjie.wang@nyu.edu>
parent ce93330e
@@ -147,3 +147,6 @@ cscope.*
*.swo
*.un~
*~
# parameters
*.params
# Scene Graph Extraction
Scene graph extraction aims not only to detect the objects in a given image, but also to classify the relationships between pairs of them.
This example reproduces [Graphical Contrastive Losses for Scene Graph Parsing](https://arxiv.org/abs/1903.02728); the authors' code can be found [here](https://github.com/NVIDIA/ContrastiveLosses4VRD).
![DEMO](https://raw.githubusercontent.com/dmlc/web-data/master/dgl/examples/mxnet/scenegraph/old-couple-pred.png)
## Results
**VisualGenome**
| Model | Backbone | mAP@50 | SGDET@20 | SGDET@50 | SGDET@100 | PHRCLS@20 | PHRCLS@50 |PHRCLS@100 | PREDCLS@20 | PREDCLS@50 | PREDCLS@100 |
| :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- |
| RelDN, L0 | ResNet101 | 29.5 | 22.65 | 30.02 | 35.04 | 32.84 | 35.60 | 36.26 | 60.58 | 65.53 | 66.51 |
## Preparation
This implementation is based on GluonCV. Install GluonCV with
```
pip install gluoncv --upgrade
```
The implementation contains the following files:
```
.
|-- data
| |-- dataloader.py
| |-- __init__.py
| |-- object.py
| |-- prepare_visualgenome.py
| `-- relation.py
|-- demo_reldn.py
|-- model
| |-- faster_rcnn.py
| |-- __init__.py
| `-- reldn.py
|-- README.md
|-- train_faster_rcnn.py
|-- train_faster_rcnn.sh
|-- train_freq_prior.py
|-- train_reldn.py
|-- train_reldn.sh
|-- utils
| |-- build_graph.py
| |-- __init__.py
| |-- metric.py
| |-- sampling.py
| `-- viz.py
|-- validate_reldn.py
`-- validate_reldn.sh
```
- The folder `data` contains the data preparation script and the dataset definitions for object detection and scene graph extraction.
- The folder `model` contains the model definitions.
- The folder `utils` contains helper functions for training, validation, and visualization.
- The script `train_faster_rcnn.py` trains a Faster R-CNN model on the VisualGenome dataset, and `train_faster_rcnn.sh` includes preset parameters.
- The script `train_freq_prior.py` trains the frequency counts for RelDN model training.
- The script `train_reldn.py` trains a RelDN model, and `train_reldn.sh` includes preset parameters.
- The script `validate_reldn.py` validates the trained Faster R-CNN and RelDN models, and `validate_reldn.sh` includes preset parameters.
- The script `demo_reldn.py` uses the trained parameters to extract a scene graph from an arbitrary input image.
Below are the steps for training your own models. We also provide pretrained model files for validation and the demo:
1. [Faster R-CNN Model for Object Detection](http://data.dgl.ai/models/SceneGraph/faster_rcnn_resnet101_v1d_visualgenome.params)
2. [RelDN Model](http://data.dgl.ai/models/SceneGraph/reldn.params)
3. [Faster R-CNN Model for Edge Feature](http://data.dgl.ai/models/SceneGraph/detector_feature.params)
## Data preparation
We provide a script to download and prepare the VisualGenome dataset. Run it with
```
python data/prepare_visualgenome.py
```
## Object Detector
First, train the object detection model on VisualGenome:
```
bash train_faster_rcnn.sh
```
It runs for about 20 hours on a machine with 64 CPU cores and 8 V100 GPUs.
## Training RelDN
With a trained Faster R-CNN model, one can start training the RelDN model with
```
bash train_reldn.sh
```
It runs for about 2 days on a single GPU with 8 CPU cores.
## Validate RelDN
After the training, one can evaluate the results with multiple commonly-used metrics:
```
bash validate_reldn.sh
```
## Demo
We provide a demo script for running the model on real-world pictures. Note that trained models are required to produce meaningful results; if no parameter files are specified, the script downloads the pre-trained models automatically.
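For example, with the pretrained parameter files listed above saved in the working directory (the file names below follow the defaults used elsewhere in this example; adjust the paths to your setup), a typical invocation is:
```
python demo_reldn.py --image old-couple.png --gpu 0 \
    --pretrained-faster-rcnn-params faster_rcnn_resnet101_v1d_visualgenome.params \
    --faster-rcnn-params detector_feature.params \
    --reldn-params reldn.params \
    --freq-prior freq_prior.pkl
```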
from .object import *
from .relation import *
from .dataloader import *
"""DataLoader utils."""
import dgl
from mxnet import nd
from gluoncv.data.batchify import Pad
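# Batchify function for the DataLoader: tuples are split per field, DGLGraphs are
# collected into a plain list, and image NDArrays are padded into a single batch.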
def dgl_mp_batchify_fn(data):
if isinstance(data[0], tuple):
data = zip(*data)
return [dgl_mp_batchify_fn(i) for i in data]
for dt in data:
if dt is not None:
if isinstance(dt, dgl.DGLGraph):
return [d for d in data if isinstance(d, dgl.DGLGraph)]
elif isinstance(dt, nd.NDArray):
pad = Pad(axis=(1, 2), num_shards=1, ret_length=False)
data_list = [dt for dt in data if dt is not None]
return pad(data_list)
"""Pascal VOC object detection dataset."""
from __future__ import absolute_import
from __future__ import division
import os
import logging
import warnings
import json
import pickle
import numpy as np
import mxnet as mx
from gluoncv.data import COCODetection
from collections import Counter
class VGObject(COCODetection):
CLASSES = ["airplane", "animal", "arm", "bag", "banana", "basket", "beach",
"bear", "bed", "bench", "bike", "bird", "board", "boat", "book",
"boot", "bottle", "bowl", "box", "boy", "branch", "building", "bus",
"cabinet", "cap", "car", "cat", "chair", "child", "clock", "coat",
"counter", "cow", "cup", "curtain", "desk", "dog", "door", "drawer",
"ear", "elephant", "engine", "eye", "face", "fence", "finger", "flag",
"flower", "food", "fork", "fruit", "giraffe", "girl", "glass", "glove",
"guy", "hair", "hand", "handle", "hat", "head", "helmet", "hill",
"horse", "house", "jacket", "jean", "kid", "kite", "lady", "lamp",
"laptop", "leaf", "leg", "letter", "light", "logo", "man", "men",
"motorcycle", "mountain", "mouth", "neck", "nose", "number", "orange",
"pant", "paper", "paw", "people", "person", "phone", "pillow", "pizza",
"plane", "plant", "plate", "player", "pole", "post", "pot", "racket",
"railing", "rock", "roof", "room", "screen", "seat", "sheep", "shelf",
"shirt", "shoe", "short", "sidewalk", "sign", "sink", "skateboard",
"ski", "skier", "sneaker", "snow", "sock", "stand", "street",
"surfboard", "table", "tail", "tie", "tile", "tire", "toilet",
"towel", "tower", "track", "train", "tree", "truck", "trunk",
"umbrella", "vase", "vegetable", "vehicle", "wave", "wheel",
"window", "windshield", "wing", "wire", "woman", "zebra"]
def __init__(self, **kwargs):
super(VGObject, self).__init__(**kwargs)
@property
def annotation_dir(self):
return ''
def _parse_image_path(self, entry):
dirname = 'VG_100K'
filename = entry['file_name']
abs_path = os.path.join(self._root, dirname, filename)
return abs_path
"""Prepare Visual Genome datasets"""
import os
import shutil
import argparse
import zipfile
import random
import json
import tqdm
import pickle
from gluoncv.utils import download, makedirs
_TARGET_DIR = os.path.expanduser('~/.mxnet/datasets/visualgenome')
def parse_args():
parser = argparse.ArgumentParser(
description='Initialize Visual Genome dataset.',
        epilog='Example: python data/prepare_visualgenome.py --download-dir ~/visualgenome',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--download-dir', type=str, default='~/visualgenome/',
help='dataset directory on disk')
parser.add_argument('--no-download', action='store_true', help='disable automatic download if set')
parser.add_argument('--overwrite', action='store_true', help='overwrite downloaded files if set, in case they are corrupted')
args = parser.parse_args()
return args
def download_vg(path, overwrite=False):
_DOWNLOAD_URLS = [
('https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip',
'a055367f675dd5476220e9b93e4ca9957b024b94'),
('https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip',
'2add3aab77623549e92b7f15cda0308f50b64ecf'),
]
makedirs(path)
for url, checksum in _DOWNLOAD_URLS:
filename = download(url, path=path, overwrite=overwrite, sha1_hash=checksum)
# extract
if filename.endswith('zip'):
with zipfile.ZipFile(filename) as zf:
zf.extractall(path=path)
# move all images into folder `VG_100K`
vg_100k_path = os.path.join(path, 'VG_100K')
vg_100k_2_path = os.path.join(path, 'VG_100K_2')
files_2 = os.listdir(vg_100k_2_path)
for fl in files_2:
shutil.move(os.path.join(vg_100k_2_path, fl),
os.path.join(vg_100k_path, fl))
def download_json(path, overwrite=False):
    url = 'https://data.dgl.ai/dataset/vg.zip'
    # download returns the full path of the saved archive under `path`
    filename = download(url, path=path, overwrite=overwrite)
    with zipfile.ZipFile(filename) as zf:
        zf.extractall(path=path)
json_path = os.path.join(path, 'vg')
json_files = os.listdir(json_path)
for fl in json_files:
shutil.move(os.path.join(json_path, fl),
os.path.join(path, fl))
os.rmdir(json_path)
if __name__ == '__main__':
args = parse_args()
path = os.path.expanduser(args.download_dir)
if not os.path.isdir(path):
if args.no_download:
            raise ValueError(('{} is not a valid directory, make sure it is present.'
                              ' Or do not pass "--no-download" so the script can download it.'.format(path)))
else:
download_vg(path, overwrite=args.overwrite)
download_json(path, overwrite=args.overwrite)
# make symlink
makedirs(os.path.expanduser('~/.mxnet/datasets'))
if os.path.isdir(_TARGET_DIR):
os.rmdir(_TARGET_DIR)
os.symlink(path, _TARGET_DIR)
"""Pascal VOC object detection dataset."""
from __future__ import absolute_import
from __future__ import division
import os
import logging
import warnings
import json
import dgl
import pickle
import numpy as np
import mxnet as mx
from gluoncv.data.base import VisionDataset
from collections import Counter
from gluoncv.data.transforms.presets.rcnn import FasterRCNNDefaultTrainTransform, FasterRCNNDefaultValTransform
class VGRelation(VisionDataset):
def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'visualgenome'), split='train'):
super(VGRelation, self).__init__(root)
self._root = os.path.expanduser(root)
self._img_path = os.path.join(self._root, 'VG_100K', '{}')
if split == 'train':
self._dict_path = os.path.join(self._root, 'rel_annotations_train.json')
elif split == 'val':
self._dict_path = os.path.join(self._root, 'rel_annotations_val.json')
else:
raise NotImplementedError
with open(self._dict_path) as f:
tmp = f.read()
self._dict = json.loads(tmp)
self._predicates_path = os.path.join(self._root, 'predicates.json')
with open(self._predicates_path, 'r') as f:
tmp = f.read()
self.rel_classes = json.loads(tmp)
self.num_rel_classes = len(self.rel_classes) + 1
self._objects_path = os.path.join(self._root, 'objects.json')
with open(self._objects_path, 'r') as f:
tmp = f.read()
self.obj_classes = json.loads(tmp)
self.num_obj_classes = len(self.obj_classes)
if split == 'val':
self.img_transform = FasterRCNNDefaultValTransform(short=600, max_size=1000)
else:
self.img_transform = FasterRCNNDefaultTrainTransform(short=600, max_size=1000)
self.split = split
def __len__(self):
return len(self._dict)
def _hash_bbox(self, object):
num_list = [object['category']] + object['bbox']
return '_'.join([str(num) for num in num_list])
def __getitem__(self, idx):
img_id = list(self._dict)[idx]
img_path = self._img_path.format(img_id)
img = mx.image.imread(img_path)
item = self._dict[img_id]
n_edges = len(item)
# edge to node ids
sub_node_hash = []
ob_node_hash = []
for i, it in enumerate(item):
sub_node_hash.append(self._hash_bbox(it['subject']))
ob_node_hash.append(self._hash_bbox(it['object']))
node_set = sorted(list(set(sub_node_hash + ob_node_hash)))
n_nodes = len(node_set)
node_to_id = {}
for i, node in enumerate(node_set):
node_to_id[node] = i
sub_id = []
ob_id = []
for i in range(n_edges):
sub_id.append(node_to_id[sub_node_hash[i]])
ob_id.append(node_to_id[ob_node_hash[i]])
# node features
bbox = mx.nd.zeros((n_nodes, 4))
node_class_ids = mx.nd.zeros((n_nodes, 1))
node_visited = [False for i in range(n_nodes)]
for i, it in enumerate(item):
if not node_visited[sub_id[i]]:
ind = sub_id[i]
sub = it['subject']
node_class_ids[ind] = sub['category']
# y1y2x1x2 to x1y1x2y2
bbox[ind,0] = sub['bbox'][2]
bbox[ind,1] = sub['bbox'][0]
bbox[ind,2] = sub['bbox'][3]
bbox[ind,3] = sub['bbox'][1]
node_visited[ind] = True
if not node_visited[ob_id[i]]:
ind = ob_id[i]
ob = it['object']
node_class_ids[ind] = ob['category']
# y1y2x1x2 to x1y1x2y2
bbox[ind,0] = ob['bbox'][2]
bbox[ind,1] = ob['bbox'][0]
bbox[ind,2] = ob['bbox'][3]
bbox[ind,3] = ob['bbox'][1]
node_visited[ind] = True
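        # label-smoothed one-hot encoding of the node classes (smoothing factor eta)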
eta = 0.1
node_class_vec = node_class_ids[:,0].one_hot(self.num_obj_classes,
on_value = 1 - eta + eta / self.num_obj_classes,
off_value = eta / self.num_obj_classes)
# augmentation
if self.split == 'val':
img, bbox, _ = self.img_transform(img, bbox)
else:
img, bbox = self.img_transform(img, bbox)
# build the graph
g = dgl.DGLGraph(multigraph=True)
g.add_nodes(n_nodes)
adjmat = np.zeros((n_nodes, n_nodes))
predicate = []
for i, it in enumerate(item):
adjmat[sub_id[i], ob_id[i]] = 1
predicate.append(it['predicate'])
predicate = mx.nd.array(predicate).expand_dims(1)
g.add_edges(sub_id, ob_id, {'rel_class': mx.nd.array(predicate) + 1})
empty_edge_list = []
for i in range(n_nodes):
for j in range(n_nodes):
if i != j and adjmat[i, j] == 0:
empty_edge_list.append((i, j))
if len(empty_edge_list) > 0:
src, dst = tuple(zip(*empty_edge_list))
g.add_edges(src, dst, {'rel_class': mx.nd.zeros((len(empty_edge_list), 1))})
# assign features
g.ndata['bbox'] = bbox
g.ndata['node_class'] = node_class_ids
g.ndata['node_class_vec'] = node_class_vec
return g, img
import dgl
import argparse
import mxnet as mx
import gluoncv as gcv
from gluoncv.utils import download
from gluoncv.data.transforms import presets
from model import faster_rcnn_resnet101_v1d_custom, RelDN
from utils import *
from data import *
def parse_args():
parser = argparse.ArgumentParser(description='Demo of Scene Graph Extraction.')
parser.add_argument('--image', type=str, default='',
help="The image for scene graph extraction.")
parser.add_argument('--gpu', type=str, default='',
help="GPU id to use for inference, default is not using GPU.")
parser.add_argument('--pretrained-faster-rcnn-params', type=str, default='',
help="Path to saved Faster R-CNN model parameters.")
    parser.add_argument('--reldn-params', type=str, default='',
                        help="Path to saved RelDN model parameters.")
    parser.add_argument('--faster-rcnn-params', type=str, default='',
                        help="Path to saved Faster R-CNN parameters for the edge feature extractor.")
parser.add_argument('--freq-prior', type=str, default='freq_prior.pkl',
help="Path to saved frequency prior data.")
args = parser.parse_args()
return args
args = parse_args()
if args.gpu:
ctx = mx.gpu(int(args.gpu))
else:
ctx = mx.cpu()
net = RelDN(n_classes=50, prior_pkl=args.freq_prior, semantic_only=False)
if args.reldn_params == '':
download('http://data.dgl.ai/models/SceneGraph/reldn.params')
    net.load_parameters('reldn.params', ctx=ctx)
else:
net.load_parameters(args.reldn_params, ctx=ctx)
# dataset and dataloader
vg_val = VGRelation(split='val')
detector = faster_rcnn_resnet101_v1d_custom(classes=vg_val.obj_classes,
pretrained_base=False, pretrained=False,
additional_output=True)
if args.pretrained_faster_rcnn_params == '':
download('http://data.dgl.ai/models/SceneGraph/faster_rcnn_resnet101_v1d_visualgenome.params')
params_path = 'faster_rcnn_resnet101_v1d_visualgenome.params'
else:
params_path = args.pretrained_faster_rcnn_params
detector.load_parameters(params_path, ctx=ctx, ignore_extra=True, allow_missing=True)
detector_feat = faster_rcnn_resnet101_v1d_custom(classes=vg_val.obj_classes,
pretrained_base=False, pretrained=False,
additional_output=True)
detector_feat.load_parameters(params_path, ctx=ctx, ignore_extra=True, allow_missing=True)
if args.faster_rcnn_params == '':
download('http://data.dgl.ai/models/SceneGraph/faster_rcnn_resnet101_v1d_visualgenome.params')
detector_feat.features.load_parameters('faster_rcnn_resnet101_v1d_visualgenome.params', ctx=ctx)
else:
detector_feat.features.load_parameters(args.faster_rcnn_params, ctx=ctx)
# image input
if args.image:
image_path = args.image
else:
gcv.utils.download('https://raw.githubusercontent.com/dmlc/web-data/master/' +
'dgl/examples/mxnet/scenegraph/old-couple.png',
'old-couple.png')
image_path = 'old-couple.png'
x, img = presets.rcnn.load_test(image_path, short=detector.short, max_size=detector.max_size)
x = x.as_in_context(ctx)
# detector prediction
ids, scores, bboxes, feat, feat_ind, spatial_feat = detector(x)
# build graph, extract edge features
g = build_graph_validate_pred(x, ids, scores, bboxes, feat_ind, spatial_feat, bbox_improvement=True, scores_top_k=75, overlap=False)
rel_bbox = g.edata['rel_bbox'].expand_dims(0).as_in_context(ctx)
_, _, _, spatial_feat_rel = detector_feat(x, None, None, rel_bbox)
g.edata['edge_feat'] = spatial_feat_rel[0]
# graph prediction
g = net(g)
_, preds = extract_pred(g, joint_preds=True)
preds = preds[preds[:,1].argsort()[::-1]]
plot_sg(img, preds, detector.classes, vg_val.rel_classes, 10)
from .faster_rcnn import *
from .reldn import *
This diff is collapsed.
import dgl
import gluoncv as gcv
import mxnet as mx
import numpy as np
from mxnet import nd
from mxnet.gluon import nn
from dgl.utils import toindex
import pickle
from dgl.nn.mxnet import GraphConv
__all__ = ['RelDN']
class EdgeConfMLP(nn.Block):
'''compute the confidence for edges'''
def __init__(self):
super(EdgeConfMLP, self).__init__()
def forward(self, edges):
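        # score_pred: log-probability of the most likely non-background predicate on the edge;
        # score_phr: score_pred plus the subject and object detection log-scores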
score_pred = nd.log_softmax(edges.data['preds'])[:,1:].max(axis=1)
score_phr = score_pred + edges.src['node_class_logit'] + edges.dst['node_class_logit']
return {'score_pred': score_pred,
'score_phr': score_phr}
class EdgeBBoxExtend(nn.Block):
'''encode the bounding boxes'''
def __init__(self):
super(EdgeBBoxExtend, self).__init__()
def bbox_delta(self, bbox_a, bbox_b):
n = bbox_a.shape[0]
result = nd.zeros((n, 4), ctx=bbox_a.context)
result[:,0] = bbox_a[:,0] - bbox_b[:,0]
result[:,1] = bbox_a[:,1] - bbox_b[:,1]
result[:,2] = nd.log((bbox_a[:,2] - bbox_a[:,0] + 1e-8) / (bbox_b[:,2] - bbox_b[:,0] + 1e-8))
result[:,3] = nd.log((bbox_a[:,3] - bbox_a[:,1] + 1e-8) / (bbox_b[:,3] - bbox_b[:,1] + 1e-8))
return result
def forward(self, edges):
ctx = edges.src['pred_bbox'].context
n = edges.src['pred_bbox'].shape[0]
delta_src_obj = self.bbox_delta(edges.src['pred_bbox'], edges.dst['pred_bbox'])
delta_src_rel = self.bbox_delta(edges.src['pred_bbox'], edges.data['rel_bbox'])
delta_rel_obj = self.bbox_delta(edges.data['rel_bbox'], edges.dst['pred_bbox'])
result = nd.zeros((n, 12), ctx=ctx)
result[:,0:4] = delta_src_obj
result[:,4:8] = delta_src_rel
result[:,8:12] = delta_rel_obj
return {'pred_bbox_additional': result}
class EdgeFreqPrior(nn.Block):
'''make use of the pre-trained frequency prior'''
def __init__(self, prior_pkl):
super(EdgeFreqPrior, self).__init__()
with open(prior_pkl, 'rb') as f:
freq_prior = pickle.load(f)
self.freq_prior = freq_prior
def forward(self, edges):
ctx = edges.src['node_class_pred'].context
src_ind = edges.src['node_class_pred'].asnumpy().astype(int)
dst_ind = edges.dst['node_class_pred'].asnumpy().astype(int)
prob = self.freq_prior[src_ind, dst_ind]
out = nd.array(prob, ctx=ctx)
return {'freq_prior': out}
class EdgeSpatial(nn.Block):
'''spatial feature branch'''
def __init__(self, n_classes):
super(EdgeSpatial, self).__init__()
self.mlp = nn.Sequential()
self.mlp.add(nn.Dense(64))
self.mlp.add(nn.LeakyReLU(0.1))
self.mlp.add(nn.Dense(64))
self.mlp.add(nn.LeakyReLU(0.1))
self.mlp.add(nn.Dense(n_classes))
def forward(self, edges):
feat = nd.concat(edges.src['pred_bbox'], edges.dst['pred_bbox'],
edges.data['rel_bbox'], edges.data['pred_bbox_additional'])
out = self.mlp(feat)
return {'spatial': out}
class EdgeVisual(nn.Block):
'''visual feature branch'''
def __init__(self, n_classes, vis_feat_dim=7*7*3):
super(EdgeVisual, self).__init__()
self.dim_in = vis_feat_dim
self.mlp_joint = nn.Sequential()
self.mlp_joint.add(nn.Dense(vis_feat_dim // 2))
self.mlp_joint.add(nn.LeakyReLU(0.1))
self.mlp_joint.add(nn.Dense(vis_feat_dim // 3))
self.mlp_joint.add(nn.LeakyReLU(0.1))
self.mlp_joint.add(nn.Dense(n_classes))
self.mlp_sub = nn.Dense(n_classes)
self.mlp_ob = nn.Dense(n_classes)
def forward(self, edges):
feat = nd.concat(edges.src['node_feat'], edges.dst['node_feat'], edges.data['edge_feat'])
out_joint = self.mlp_joint(feat)
out_sub = self.mlp_sub(edges.src['node_feat'])
out_ob = self.mlp_ob(edges.dst['node_feat'])
out = out_joint + out_sub + out_ob
return {'visual': out}
class RelDN(nn.Block):
'''The RelDN Model'''
def __init__(self, n_classes, prior_pkl, semantic_only=False):
super(RelDN, self).__init__()
# output layers
self.edge_bbox_extend = EdgeBBoxExtend()
# semantic through mlp encoding
if prior_pkl is not None:
self.freq_prior = EdgeFreqPrior(prior_pkl)
# with predicate class and a link class
self.spatial = EdgeSpatial(n_classes + 1)
# with visual features
self.visual = EdgeVisual(n_classes + 1)
self.edge_conf_mlp = EdgeConfMLP()
self.semantic_only = semantic_only
def forward(self, g):
if g is None or g.number_of_nodes() == 0:
return g
# predictions
g.apply_edges(self.freq_prior)
if self.semantic_only:
g.edata['preds'] = g.edata['freq_prior']
else:
# bbox extension
g.apply_edges(self.edge_bbox_extend)
g.apply_edges(self.spatial)
g.apply_edges(self.visual)
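            # final predicate logits sum the semantic (frequency prior), spatial, and visual branch outputs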
g.edata['preds'] = g.edata['freq_prior'] + g.edata['spatial'] + g.edata['visual']
# subgraph for gconv
g.apply_edges(self.edge_conf_mlp)
return g
This diff is collapsed.
MXNET_CUDNN_AUTOTUNE_DEFAULT=0 CUDNN_AUTOTUNE_DEFAULT=0 MXNET_GPU_MEM_POOL_TYPE=Round MXNET_GPU_MEM_POOL_ROUND_LINEAR_CUTOFF=28 python train_faster_rcnn.py \
--gpus 0,1,2,3,4,5,6,7 --dataset visualgenome -j 60 --batch-size 8 --val-interval 20 --save-prefix faster_rcnn_resnet101_v1d_visualgenome/
import numpy as np
import json, pickle, os, argparse
def parse_args():
    parser = argparse.ArgumentParser(description='Train the Frequency Prior for RelDN.')
parser.add_argument('--overlap', action='store_true',
help="Only count overlap boxes.")
    parser.add_argument('--json-path', type=str, default='~/.mxnet/datasets/visualgenome',
                        help="Path to the VisualGenome relation annotation json files.")
args = parser.parse_args()
return args
args = parse_args()
use_overlap = args.overlap
PATH_TO_DATASETS = os.path.expanduser(args.json_path)
path_to_json = os.path.join(PATH_TO_DATASETS, 'rel_annotations_train.json')
# format in y1y2x1x2
def with_overlap(boxA, boxB):
xA = max(boxA[2], boxB[2])
xB = min(boxA[3], boxB[3])
if xB > xA:
yA = max(boxA[0], boxB[0])
yB = min(boxA[1], boxB[1])
if yB > yA:
return 1
return 0
def box_ious(boxes):
n = len(boxes)
res = np.zeros((n, n))
for i in range(n-1):
for j in range(i+1, n):
iou_val = with_overlap(boxes[i], boxes[j])
res[i, j] = iou_val
res[j, i] = iou_val
return res
with open(path_to_json, 'r') as f:
tmp = f.read()
train_data = json.loads(tmp)
fg_matrix = np.zeros((150, 150, 51), dtype=np.int64)
bg_matrix = np.zeros((150, 150), dtype=np.int64)
for _, item in train_data.items():
gt_box_to_label = {}
for rel in item:
sub_bbox = rel['subject']['bbox']
ob_bbox = rel['object']['bbox']
sub_class = rel['subject']['category']
ob_class = rel['object']['category']
rel_class = rel['predicate']
sub_node = tuple(sub_bbox)
ob_node = tuple(ob_bbox)
if sub_node not in gt_box_to_label:
gt_box_to_label[sub_node] = sub_class
if ob_node not in gt_box_to_label:
gt_box_to_label[ob_node] = ob_class
fg_matrix[sub_class, ob_class, rel_class + 1] += 1
if use_overlap:
gt_boxes = [*gt_box_to_label]
gt_classes = np.array([*gt_box_to_label.values()])
iou_mat = box_ious(gt_boxes)
cols, rows = np.where(iou_mat)
if len(cols) and len(rows):
for col, row in zip(cols, rows):
bg_matrix[gt_classes[col], gt_classes[row]] += 1
else:
            all_possib = np.ones_like(iou_mat, dtype=bool)
np.fill_diagonal(all_possib, 0)
cols, rows = np.where(all_possib)
for col, row in zip(cols, rows):
bg_matrix[gt_classes[col], gt_classes[row]] += 1
else:
for b1, l1 in gt_box_to_label.items():
for b2, l2 in gt_box_to_label.items():
if b1 == b2:
continue
bg_matrix[l1, l2] += 1
eps = 1e-3
bg_matrix += 1
fg_matrix[:, :, 0] = bg_matrix
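# pred_dist holds log P(predicate | subject class, object class); channel 0 is
# reserved for the background (no-relation) counts accumulated in bg_matrix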
pred_dist = np.log(fg_matrix / (fg_matrix.sum(2)[:, :, None] + eps) + eps)
if use_overlap:
with open('freq_prior_overlap.pkl', 'wb') as f:
pickle.dump(pred_dist, f)
else:
with open('freq_prior.pkl', 'wb') as f:
pickle.dump(pred_dist, f)
import dgl
import mxnet as mx
import numpy as np
import logging, time, argparse
from mxnet import nd, gluon
from gluoncv.data.batchify import Pad
from gluoncv.utils import makedirs
from model import faster_rcnn_resnet101_v1d_custom, RelDN
from utils import *
from data import *
def parse_args():
parser = argparse.ArgumentParser(description='Train RelDN Model.')
parser.add_argument('--gpus', type=str, default='0',
help="Training with GPUs, you can specify 1,3 for example.")
parser.add_argument('--batch-size', type=int, default=8,
help="Total batch-size for training.")
parser.add_argument('--epochs', type=int, default=9,
help="Training epochs.")
parser.add_argument('--lr-reldn', type=float, default=0.01,
help="Learning rate for RelDN module.")
parser.add_argument('--wd-reldn', type=float, default=0.0001,
help="Weight decay for RelDN module.")
parser.add_argument('--lr-faster-rcnn', type=float, default=0.01,
help="Learning rate for Faster R-CNN module.")
    parser.add_argument('--wd-faster-rcnn', type=float, default=0.0001,
                        help="Weight decay for Faster R-CNN module.")
parser.add_argument('--lr-decay-epochs', type=str, default='5,8',
help="Learning rate decay points.")
parser.add_argument('--lr-warmup-iters', type=int, default=4000,
help="Learning rate warm-up iterations.")
parser.add_argument('--save-dir', type=str, default='params_resnet101_v1d_reldn',
help="Path to save model parameters.")
parser.add_argument('--log-dir', type=str, default='reldn_output.log',
help="Path to save training logs.")
parser.add_argument('--pretrained-faster-rcnn-params', type=str, required=True,
help="Path to saved Faster R-CNN model parameters.")
parser.add_argument('--freq-prior', type=str, default='freq_prior.pkl',
help="Path to saved frequency prior data.")
parser.add_argument('--verbose-freq', type=int, default=100,
help="Frequency of log printing in number of iterations.")
args = parser.parse_args()
return args
args = parse_args()
filehandler = logging.FileHandler(args.log_dir)
streamhandler = logging.StreamHandler()
logger = logging.getLogger('')
logger.setLevel(logging.INFO)
logger.addHandler(filehandler)
logger.addHandler(streamhandler)
# Hyperparams
ctx = [mx.gpu(int(i)) for i in args.gpus.split(',') if i.strip()]
if ctx:
    num_gpus = len(ctx)
    assert args.batch_size % num_gpus == 0
    per_device_batch_size = int(args.batch_size / num_gpus)
else:
    ctx = [mx.cpu()]
    num_gpus = 1
    per_device_batch_size = args.batch_size
aggregate_grad = per_device_batch_size > 1
nepoch = args.epochs
N_relations = 50
N_objects = 150
save_dir = args.save_dir
makedirs(save_dir)
batch_verbose_freq = args.verbose_freq
lr_decay_epochs = [int(i) for i in args.lr_decay_epochs.split(',')]
# Dataset and dataloader
vg_train = VGRelation(split='train')
logger.info('data loaded!')
train_data = gluon.data.DataLoader(vg_train, batch_size=len(ctx), shuffle=True, num_workers=8*num_gpus,
batchify_fn=dgl_mp_batchify_fn)
n_batches = len(train_data)
# Network definition
net = RelDN(n_classes=N_relations, prior_pkl=args.freq_prior)
net.spatial.initialize(mx.init.Normal(1e-4), ctx=ctx)
net.visual.initialize(mx.init.Normal(1e-4), ctx=ctx)
for k, v in net.collect_params().items():
v.grad_req = 'add' if aggregate_grad else 'write'
net_params = net.collect_params()
net_trainer = gluon.Trainer(net.collect_params(), 'adam',
{'learning_rate': args.lr_reldn, 'wd': args.wd_reldn})
det_params_path = args.pretrained_faster_rcnn_params
detector = faster_rcnn_resnet101_v1d_custom(classes=vg_train.obj_classes,
pretrained_base=False, pretrained=False,
additional_output=True)
detector.load_parameters(det_params_path, ctx=ctx, ignore_extra=True, allow_missing=True)
for k, v in detector.collect_params().items():
v.grad_req = 'null'
detector_feat = faster_rcnn_resnet101_v1d_custom(classes=vg_train.obj_classes,
pretrained_base=False, pretrained=False,
additional_output=True)
detector_feat.load_parameters(det_params_path, ctx=ctx, ignore_extra=True, allow_missing=True)
for k, v in detector_feat.collect_params().items():
v.grad_req = 'null'
for k, v in detector_feat.features.collect_params().items():
v.grad_req = 'add' if aggregate_grad else 'write'
det_params = detector_feat.features.collect_params()
det_trainer = gluon.Trainer(detector_feat.features.collect_params(), 'adam',
{'learning_rate': args.lr_faster_rcnn, 'wd': args.wd_faster_rcnn})
def get_data_batch(g_list, img_list, ctx_list):
if g_list is None or len(g_list) == 0:
return None, None
n_gpu = len(ctx_list)
size = len(g_list)
if size < n_gpu:
raise Exception("too small batch")
step = size // n_gpu
G_list = [g_list[i*step:(i+1)*step] if i < n_gpu - 1 else g_list[i*step:size] for i in range(n_gpu)]
img_list = [img_list[i*step:(i+1)*step] if i < n_gpu - 1 else img_list[i*step:size] for i in range(n_gpu)]
for G_slice, ctx in zip(G_list, ctx_list):
for G in G_slice:
G.ndata['bbox'] = G.ndata['bbox'].as_in_context(ctx)
G.ndata['node_class'] = G.ndata['node_class'].as_in_context(ctx)
G.ndata['node_class_vec'] = G.ndata['node_class_vec'].as_in_context(ctx)
G.edata['rel_class'] = G.edata['rel_class'].as_in_context(ctx)
img_list = [img.as_in_context(ctx) for img in img_list]
return G_list, img_list
L_rel = gluon.loss.SoftmaxCELoss()
train_metric = mx.metric.Accuracy(name='rel_acc')
train_metric_top5 = mx.metric.TopKAccuracy(5, name='rel_acc_top5')
metric_list = [train_metric, train_metric_top5]
def batch_print(epoch, i, batch_verbose_freq, n_batches, btic, loss_rel_val, metric_list):
if (i+1) % batch_verbose_freq == 0:
print_txt = 'Epoch[%d] Batch[%d/%d], time: %d, loss_rel=%.4f '%\
(epoch, i, n_batches, int(time.time() - btic),
loss_rel_val / (i+1), )
for metric in metric_list:
metric_name, metric_val = metric.get()
print_txt += '%s=%.4f '%(metric_name, metric_val)
logger.info(print_txt)
btic = time.time()
loss_rel_val = 0
return btic, loss_rel_val
for epoch in range(nepoch):
loss_rel_val = 0
tic = time.time()
btic = time.time()
for metric in metric_list:
metric.reset()
if epoch == 0:
net_trainer_base_lr = net_trainer.learning_rate
det_trainer_base_lr = det_trainer.learning_rate
    if epoch in lr_decay_epochs:
net_trainer.set_learning_rate(net_trainer.learning_rate*0.1)
det_trainer.set_learning_rate(det_trainer.learning_rate*0.1)
for i, (G_list, img_list) in enumerate(train_data):
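        # linear learning-rate warm-up from 1/3 of the base rate to the full rate
        # over the first lr_warmup_iters iterations of epoch 0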
if epoch == 0 and i < args.lr_warmup_iters:
alpha = i / args.lr_warmup_iters
warmup_factor = 1/3 * (1 - alpha) + alpha
net_trainer.set_learning_rate(net_trainer_base_lr*warmup_factor)
det_trainer.set_learning_rate(det_trainer_base_lr*warmup_factor)
G_list, img_list = get_data_batch(G_list, img_list, ctx)
if G_list is None or img_list is None:
btic, loss_rel_val = batch_print(epoch, i, batch_verbose_freq, n_batches, btic, loss_rel_val, metric_list)
continue
loss = []
detector_res_list = []
G_batch = []
bbox_pad = Pad(axis=(0))
with mx.autograd.record():
for G_slice, img in zip(G_list, img_list):
cur_ctx = img.context
bbox_list = [G.ndata['bbox'] for G in G_slice]
bbox_stack = bbox_pad(bbox_list).as_in_context(cur_ctx)
with mx.autograd.pause():
ids, scores, bbox, feat, feat_ind, spatial_feat = detector(img)
g_pred_batch = build_graph_train(G_slice, bbox_stack, img, ids, scores, bbox, feat_ind,
spatial_feat, scores_top_k=300, overlap=False)
g_batch = l0_sample(g_pred_batch)
if g_batch is None:
continue
rel_bbox = g_batch.edata['rel_bbox']
batch_id = g_batch.edata['batch_id'].asnumpy()
n_sample_edges = g_batch.number_of_edges()
n_graph = len(G_slice)
bbox_rel_list = []
for j in range(n_graph):
eids = np.where(batch_id == j)[0]
if len(eids) > 0:
bbox_rel_list.append(rel_bbox[eids])
bbox_rel_stack = bbox_pad(bbox_rel_list).as_in_context(cur_ctx)
img_size = img.shape[2:4]
bbox_rel_stack[:, :, 0] *= img_size[1]
bbox_rel_stack[:, :, 1] *= img_size[0]
bbox_rel_stack[:, :, 2] *= img_size[1]
bbox_rel_stack[:, :, 3] *= img_size[0]
_, _, _, spatial_feat_rel = detector_feat(img, None, None, bbox_rel_stack)
spatial_feat_rel_list = []
for j in range(n_graph):
eids = np.where(batch_id == j)[0]
if len(eids) > 0:
spatial_feat_rel_list.append(spatial_feat_rel[j, 0:len(eids)])
g_batch.edata['edge_feat'] = nd.concat(*spatial_feat_rel_list, dim=0)
G_batch.append(g_batch)
G_batch = [net(G) for G in G_batch]
for G_pred, img in zip(G_batch, img_list):
if G_pred is None or G_pred.number_of_nodes() == 0:
continue
loss_rel = L_rel(G_pred.edata['preds'], G_pred.edata['rel_class'],
G_pred.edata['sample_weights'])
loss.append(loss_rel.sum())
loss_rel_val += loss_rel.mean().asscalar() / num_gpus
if len(loss) == 0:
btic, loss_rel_val = batch_print(epoch, i, batch_verbose_freq, n_batches, btic, loss_rel_val, metric_list)
continue
for l in loss:
l.backward()
if (i+1) % per_device_batch_size == 0 or i == n_batches - 1:
net_trainer.step(args.batch_size)
det_trainer.step(args.batch_size)
if aggregate_grad:
for k, v in net_params.items():
v.zero_grad()
for k, v in det_params.items():
v.zero_grad()
for G_pred, img_slice in zip(G_batch, img_list):
if G_pred is None or G_pred.number_of_nodes() == 0:
continue
link_ind = np.where(G_pred.edata['rel_class'].asnumpy() > 0)[0]
if len(link_ind) == 0:
continue
train_metric.update([G_pred.edata['rel_class'][link_ind]],
[G_pred.edata['preds'][link_ind]])
train_metric_top5.update([G_pred.edata['rel_class'][link_ind]],
[G_pred.edata['preds'][link_ind]])
btic, loss_rel_val = batch_print(epoch, i, batch_verbose_freq, n_batches, btic, loss_rel_val, metric_list)
if (i+1) % batch_verbose_freq == 0:
net.save_parameters('%s/model-%d.params'%(save_dir, epoch))
detector_feat.features.save_parameters('%s/detector_feat.features-%d.params'%(save_dir, epoch))
print_txt = 'Epoch[%d], time: %d, loss_rel=%.4f,'%\
(epoch, int(time.time() - tic),
loss_rel_val / (i+1))
for metric in metric_list:
metric_name, metric_val = metric.get()
print_txt += '%s=%.4f '%(metric_name, metric_val)
logger.info(print_txt)
net.save_parameters('%s/model-%d.params'%(save_dir, epoch))
detector_feat.features.save_parameters('%s/detector_feat.features-%d.params'%(save_dir, epoch))
MXNET_CUDNN_AUTOTUNE_DEFAULT=0 python train_reldn.py \
--pretrained-faster-rcnn-params faster_rcnn_resnet101_v1d_visualgenome/faster_rcnn_resnet101_v1d_custom_best.params
from .metric import *
from .build_graph import *
from .sampling import *
from .viz import *
import dgl
from mxnet import nd
import numpy as np
def bbox_improve(bbox):
    '''append the box area as an extra bbox feature'''
area = (bbox[:,2] - bbox[:,0]) * (bbox[:,3] - bbox[:,1])
return nd.concat(bbox, area.expand_dims(1))
def extract_edge_bbox(g):
    '''compute the union box of each edge's source and destination boxes'''
src, dst = g.edges(order='eid')
n = g.number_of_edges()
src_bbox = g.ndata['pred_bbox'][src.asnumpy()]
dst_bbox = g.ndata['pred_bbox'][dst.asnumpy()]
edge_bbox = nd.zeros((n, 4), ctx=g.ndata['pred_bbox'].context)
edge_bbox[:,0] = nd.stack(src_bbox[:,0], dst_bbox[:,0]).min(axis=0)
edge_bbox[:,1] = nd.stack(src_bbox[:,1], dst_bbox[:,1]).min(axis=0)
edge_bbox[:,2] = nd.stack(src_bbox[:,2], dst_bbox[:,2]).max(axis=0)
edge_bbox[:,3] = nd.stack(src_bbox[:,3], dst_bbox[:,3]).max(axis=0)
return edge_bbox
def build_graph_train(g_slice, gt_bbox, img, ids, scores, bbox, feat_ind,
spatial_feat, iou_thresh=0.5,
bbox_improvement=True, scores_top_k=50, overlap=False):
'''given ground truth and predicted bboxes, assign the label to the predicted w.r.t iou_thresh'''
# match and re-factor the graph
img_size = img.shape[2:4]
gt_bbox[:, :, 0] /= img_size[1]
gt_bbox[:, :, 1] /= img_size[0]
gt_bbox[:, :, 2] /= img_size[1]
gt_bbox[:, :, 3] /= img_size[0]
bbox[:, :, 0] /= img_size[1]
bbox[:, :, 1] /= img_size[0]
bbox[:, :, 2] /= img_size[1]
bbox[:, :, 3] /= img_size[0]
n_graph = len(g_slice)
g_pred_batch = []
for gi in range(n_graph):
g = g_slice[gi]
ctx = g.ndata['bbox'].context
inds = np.where(scores[gi, :, 0].asnumpy() > 0)[0].tolist()
if len(inds) == 0:
return None
if len(inds) > scores_top_k:
top_score_inds = scores[gi, inds, 0].asnumpy().argsort()[::-1][0:scores_top_k]
inds = np.array(inds)[top_score_inds].tolist()
n_nodes = len(inds)
roi_ind = feat_ind[gi, inds].squeeze(axis=1)
g_pred = dgl.DGLGraph(multigraph=True)
g_pred.add_nodes(n_nodes, {'pred_bbox': bbox[gi, inds],
'node_feat': spatial_feat[gi, roi_ind],
'node_class_pred': ids[gi, inds, 0],
'node_class_logit': nd.log(scores[gi, inds, 0] + 1e-7)})
# iou matching
ious = nd.contrib.box_iou(gt_bbox[gi], g_pred.ndata['pred_bbox']).asnumpy()
H, W = ious.shape
h = H
w = W
pred_to_gt_ind = np.array([-1 for i in range(W)])
pred_to_gt_class_match = [0 for i in range(W)]
pred_to_gt_class_match_id = [0 for i in range(W)]
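        # greedily match predicted boxes to ground-truth boxes: repeatedly take the pair with the
        # highest remaining IoU until it falls below iou_thresh, marking matched rows/columns as used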
while h > 0 and w > 0:
ind = int(ious.argmax())
row_ind = ind // W
col_ind = ind % W
if ious[row_ind, col_ind] < iou_thresh:
break
pred_to_gt_ind[col_ind] = row_ind
gt_node_class = g.ndata['node_class'][row_ind]
pred_node_class = g_pred.ndata['node_class_pred'][col_ind]
if gt_node_class == pred_node_class:
pred_to_gt_class_match[col_ind] = 1
pred_to_gt_class_match_id[col_ind] = row_ind
ious[row_ind, :] = -1
ious[:, col_ind] = -1
h -= 1
w -= 1
n_nodes = g_pred.number_of_nodes()
triplet = []
adjmat = np.zeros((n_nodes, n_nodes))
src, dst = g.all_edges(order='eid')
eid_keys = np.column_stack([src.asnumpy(), dst.asnumpy()])
eid_dict = {}
for i, key in enumerate(eid_keys):
k = tuple(key)
if k not in eid_dict:
eid_dict[k] = [i]
else:
eid_dict[k].append(i)
ori_rel_class = g.edata['rel_class'].asnumpy()
for i in range(n_nodes):
for j in range(n_nodes):
if i != j:
if pred_to_gt_class_match[i] and pred_to_gt_class_match[j]:
sub_gt_id = pred_to_gt_class_match_id[i]
ob_gt_id = pred_to_gt_class_match_id[j]
eids = eid_dict[(sub_gt_id, ob_gt_id)]
rel_cls = ori_rel_class[eids]
n_edges_between = len(rel_cls)
for ii in range(n_edges_between):
triplet.append((i, j, rel_cls[ii]))
adjmat[i,j] = 1
else:
triplet.append((i, j, 0))
src, dst, rel_class = tuple(zip(*triplet))
rel_class = nd.array(rel_class, ctx=ctx).expand_dims(1)
g_pred.add_edges(src, dst, data={'rel_class': rel_class})
# other operations
n_nodes = g_pred.number_of_nodes()
n_edges = g_pred.number_of_edges()
if bbox_improvement:
g_pred.ndata['pred_bbox'] = bbox_improve(g_pred.ndata['pred_bbox'])
g_pred.edata['rel_bbox'] = extract_edge_bbox(g_pred)
g_pred.edata['batch_id'] = nd.zeros((n_edges, 1), ctx = ctx) + gi
# remove non-overlapping edges
if overlap:
overlap_ious = nd.contrib.box_iou(g_pred.ndata['pred_bbox'][:,0:4],
g_pred.ndata['pred_bbox'][:,0:4]).asnumpy()
cols, rows = np.where(overlap_ious <= 1e-7)
if cols.shape[0] > 0:
eids = g_pred.edge_ids(cols, rows)[2].asnumpy().tolist()
if len(eids):
g_pred.remove_edges(eids)
if g_pred.number_of_edges() == 0:
g_pred = None
g_pred_batch.append(g_pred)
if n_graph > 1:
return dgl.batch(g_pred_batch)
else:
return g_pred_batch[0]
def build_graph_validate_gt_obj(img, gt_ids, bbox, spatial_feat,
bbox_improvement=True, overlap=False):
'''given ground truth bbox and label, build graph for validation'''
n_batch = img.shape[0]
img_size = img.shape[2:4]
bbox[:, :, 0] /= img_size[1]
bbox[:, :, 1] /= img_size[0]
bbox[:, :, 2] /= img_size[1]
bbox[:, :, 3] /= img_size[0]
ctx = img.context
g_batch = []
for btc in range(n_batch):
inds = np.where(bbox[btc].sum(1).asnumpy() > 0)[0].tolist()
if len(inds) == 0:
continue
n_nodes = len(inds)
g_pred = dgl.DGLGraph()
g_pred.add_nodes(n_nodes, {'pred_bbox': bbox[btc, inds],
'node_feat': spatial_feat[btc, inds],
'node_class_pred': gt_ids[btc, inds, 0],
'node_class_logit': nd.zeros_like(gt_ids[btc, inds, 0], ctx=ctx)})
edge_list = []
for i in range(n_nodes - 1):
for j in range(i + 1, n_nodes):
edge_list.append((i, j))
src, dst = tuple(zip(*edge_list))
g_pred.add_edges(src, dst)
g_pred.add_edges(dst, src)
n_nodes = g_pred.number_of_nodes()
n_edges = g_pred.number_of_edges()
if bbox_improvement:
g_pred.ndata['pred_bbox'] = bbox_improve(g_pred.ndata['pred_bbox'])
g_pred.edata['rel_bbox'] = extract_edge_bbox(g_pred)
g_pred.edata['batch_id'] = nd.zeros((n_edges, 1), ctx = ctx) + btc
g_batch.append(g_pred)
if len(g_batch) == 0:
return None
if len(g_batch) > 1:
return dgl.batch(g_batch)
return g_batch[0]
def build_graph_validate_gt_bbox(img, ids, scores, bbox, spatial_feat, gt_ids=None,
bbox_improvement=True, overlap=False):
'''given ground truth bbox, build graph for validation'''
n_batch = img.shape[0]
img_size = img.shape[2:4]
bbox[:, :, 0] /= img_size[1]
bbox[:, :, 1] /= img_size[0]
bbox[:, :, 2] /= img_size[1]
bbox[:, :, 3] /= img_size[0]
ctx = img.context
g_batch = []
for btc in range(n_batch):
id_btc = scores[btc][:,:,0].argmax(0)
score_btc = scores[btc][:,:,0].max(0)
inds = np.where(bbox[btc].sum(1).asnumpy() > 0)[0].tolist()
if len(inds) == 0:
continue
n_nodes = len(inds)
g_pred = dgl.DGLGraph()
g_pred.add_nodes(n_nodes, {'pred_bbox': bbox[btc, inds],
'node_feat': spatial_feat[btc, inds],
'node_class_pred': id_btc,
'node_class_logit': nd.log(score_btc + 1e-7)})
edge_list = []
for i in range(n_nodes - 1):
for j in range(i + 1, n_nodes):
edge_list.append((i, j))
src, dst = tuple(zip(*edge_list))
g_pred.add_edges(src, dst)
g_pred.add_edges(dst, src)
n_nodes = g_pred.number_of_nodes()
n_edges = g_pred.number_of_edges()
if bbox_improvement:
g_pred.ndata['pred_bbox'] = bbox_improve(g_pred.ndata['pred_bbox'])
g_pred.edata['rel_bbox'] = extract_edge_bbox(g_pred)
g_pred.edata['batch_id'] = nd.zeros((n_edges, 1), ctx = ctx) + btc
g_batch.append(g_pred)
if len(g_batch) == 0:
return None
if len(g_batch) > 1:
return dgl.batch(g_batch)
return g_batch[0]
def build_graph_validate_pred(img, ids, scores, bbox, feat_ind, spatial_feat,
bbox_improvement=True, scores_top_k=50, overlap=False):
'''given predicted bbox, build graph for validation'''
n_batch = img.shape[0]
img_size = img.shape[2:4]
bbox[:, :, 0] /= img_size[1]
bbox[:, :, 1] /= img_size[0]
bbox[:, :, 2] /= img_size[1]
bbox[:, :, 3] /= img_size[0]
ctx = img.context
g_batch = []
for btc in range(n_batch):
inds = np.where(scores[btc, :, 0].asnumpy() > 0)[0].tolist()
if len(inds) == 0:
continue
if len(inds) > scores_top_k:
top_score_inds = scores[btc, inds, 0].asnumpy().argsort()[::-1][0:scores_top_k]
inds = np.array(inds)[top_score_inds].tolist()
n_nodes = len(inds)
roi_ind = feat_ind[btc, inds].squeeze(axis=1)
g_pred = dgl.DGLGraph()
g_pred.add_nodes(n_nodes, {'pred_bbox': bbox[btc, inds],
'node_feat': spatial_feat[btc, roi_ind],
'node_class_pred': ids[btc, inds, 0],
'node_class_logit': nd.log(scores[btc, inds, 0] + 1e-7)})
edge_list = []
for i in range(n_nodes - 1):
for j in range(i + 1, n_nodes):
edge_list.append((i, j))
src, dst = tuple(zip(*edge_list))
g_pred.add_edges(src, dst)
g_pred.add_edges(dst, src)
n_nodes = g_pred.number_of_nodes()
n_edges = g_pred.number_of_edges()
if bbox_improvement:
g_pred.ndata['pred_bbox'] = bbox_improve(g_pred.ndata['pred_bbox'])
g_pred.edata['rel_bbox'] = extract_edge_bbox(g_pred)
g_pred.edata['batch_id'] = nd.zeros((n_edges, 1), ctx = ctx) + btc
g_batch.append(g_pred)
if len(g_batch) == 0:
return None
if len(g_batch) > 1:
return dgl.batch(g_batch)
return g_batch[0]
import dgl
import mxnet as mx
import numpy as np
import logging, time
from operator import attrgetter, itemgetter
from mxnet import nd, gluon
from mxnet.gluon import nn
from dgl.utils import toindex
from dgl.nn.mxnet import GraphConv
from gluoncv.model_zoo import get_model
from gluoncv.data.batchify import Pad
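# Triplet rows produced by extract_gt/extract_pred below have 13 columns:
# [score_pred, score_phr, rel_class, sub_class, ob_class, sub_bbox(4), ob_bbox(4)];
# the metrics therefore compare classes at indices 2-4 and boxes at 5:9 / 9:13.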
def iou(boxA, boxB):
# determine the (x, y)-coordinates of the intersection rectangle
xA = max(boxA[0], boxB[0])
yA = max(boxA[1], boxB[1])
xB = min(boxA[2], boxB[2])
yB = min(boxA[3], boxB[3])
interArea = max(0, xB - xA) * max(0, yB - yA)
if interArea < 1e-7 :
return 0
boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
if boxAArea + boxBArea - interArea < 1e-7:
return 0
iou_val = interArea / float(boxAArea + boxBArea - interArea)
return iou_val
def object_iou_thresh(gt_object, pred_object, iou_thresh=0.5):
obj_iou = iou(gt_object[1:5], pred_object[1:5])
if obj_iou >= iou_thresh:
return True
return False
def triplet_iou_thresh(pred_triplet, gt_triplet, iou_thresh=0.5):
sub_iou = iou(gt_triplet[5:9], pred_triplet[5:9])
if sub_iou >= iou_thresh:
ob_iou = iou(gt_triplet[9:13], pred_triplet[9:13])
if ob_iou >= iou_thresh:
return True
return False
@mx.metric.register
@mx.metric.alias('auc')
class AUCMetric(mx.metric.EvalMetric):
def __init__(self, name='auc', eps=1e-12):
super(AUCMetric, self).__init__(name)
self.eps = eps
def update(self, labels, preds):
mx.metric.check_label_shapes(labels, preds)
label_weight = labels[0].asnumpy()
preds = preds[0].asnumpy()
tmp = []
for i in range(preds.shape[0]):
tmp.append((label_weight[i], preds[i][1]))
tmp = sorted(tmp, key=itemgetter(1), reverse=True)
label_sum = label_weight.sum()
if label_sum == 0 or label_sum == label_weight.size:
return
label_one_num = np.count_nonzero(label_weight)
label_zero_num = len(label_weight) - label_one_num
total_area = label_zero_num * label_one_num
height = 0
width = 0
area = 0
for a, _ in tmp:
if a == 1.0:
height += 1.0
else:
width += 1.0
area += height
self.sum_metric += area / total_area
self.num_inst += 1
@mx.metric.register
@mx.metric.alias('predcls')
class PredCls(mx.metric.EvalMetric):
'''Metric with ground truth object location and label'''
def __init__(self, topk=20, iou_thresh=0.99):
super(PredCls, self).__init__('predcls@%d'%(topk))
self.topk = topk
self.iou_thresh = iou_thresh
def update(self, labels, preds):
if labels is None or preds is None:
self.num_inst += 1
return
preds = preds[preds[:,0].argsort()[::-1]]
m = min(self.topk, preds.shape[0])
count = 0
gt_edge_num = labels.shape[0]
label_matched = [False for label in labels]
for i in range(m):
pred = preds[i]
for j in range(gt_edge_num):
if label_matched[j]:
continue
label = labels[j]
if int(label[2]) == int(pred[2]) and \
triplet_iou_thresh(pred, label, self.iou_thresh):
count += 1
label_matched[j] = True
total = labels.shape[0]
self.sum_metric += count / total
self.num_inst += 1
@mx.metric.register
@mx.metric.alias('phrcls')
class PhrCls(mx.metric.EvalMetric):
'''Metric with ground truth object location and predicted object label from detector'''
def __init__(self, topk=20, iou_thresh=0.99):
super(PhrCls, self).__init__('phrcls@%d'%(topk))
self.topk = topk
self.iou_thresh = iou_thresh
def update(self, labels, preds):
if labels is None or preds is None:
self.num_inst += 1
return
preds = preds[preds[:,1].argsort()[::-1]]
m = min(self.topk, preds.shape[0])
count = 0
gt_edge_num = labels.shape[0]
label_matched = [False for label in labels]
for i in range(m):
pred = preds[i]
for j in range(gt_edge_num):
if label_matched[j]:
continue
label = labels[j]
if int(label[2]) == int(pred[2]) and \
int(label[3]) == int(pred[3]) and \
int(label[4]) == int(pred[4]) and \
triplet_iou_thresh(pred, label, self.iou_thresh):
count += 1
label_matched[j] = True
total = labels.shape[0]
self.sum_metric += count / total
self.num_inst += 1
@mx.metric.register
@mx.metric.alias('sgdet')
class SGDet(mx.metric.EvalMetric):
'''Metric with predicted object information by the detector'''
def __init__(self, topk=20, iou_thresh=0.5):
super(SGDet, self).__init__('sgdet@%d'%(topk))
self.topk = topk
self.iou_thresh = iou_thresh
def update(self, labels, preds):
if labels is None or preds is None:
self.num_inst += 1
return
preds = preds[preds[:,1].argsort()[::-1]]
m = min(self.topk, len(preds))
count = 0
gt_edge_num = labels.shape[0]
label_matched = [False for label in labels]
for i in range(m):
pred = preds[i]
for j in range(gt_edge_num):
if label_matched[j]:
continue
label = labels[j]
if int(label[2]) == int(pred[2]) and \
int(label[3]) == int(pred[3]) and \
int(label[4]) == int(pred[4]) and \
triplet_iou_thresh(pred, label, self.iou_thresh):
count += 1
                    label_matched[j] = True
total = labels.shape[0]
self.sum_metric += count / total
self.num_inst += 1
@mx.metric.register
@mx.metric.alias('sgdet+')
class SGDetPlus(mx.metric.EvalMetric):
'''Metric proposed by `Graph R-CNN for Scene Graph Generation`'''
def __init__(self, topk=20, iou_thresh=0.5):
super(SGDetPlus, self).__init__('sgdet+@%d'%(topk))
self.topk = topk
self.iou_thresh = iou_thresh
def update(self, labels, preds):
label_objects, label_triplets = labels
pred_objects, pred_triplets = preds
if label_objects is None or pred_objects is None:
self.num_inst += 1
return
count = 0
# count objects
object_matched = [False for obj in label_objects]
m = len(pred_objects)
gt_obj_num = label_objects.shape[0]
for i in range(m):
pred = pred_objects[i]
for j in range(gt_obj_num):
if object_matched[j]:
continue
label = label_objects[j]
if int(label[0]) == int(pred[0]) and \
object_iou_thresh(pred, label, self.iou_thresh):
count += 1
object_matched[j] = True
# count predicate and triplet
pred_triplets = pred_triplets[pred_triplets[:,1].argsort()[::-1]]
m = min(self.topk, len(pred_triplets))
gt_triplet_num = label_triplets.shape[0]
triplet_matched = [False for label in label_triplets]
predicate_matched = [False for label in label_triplets]
for i in range(m):
pred = pred_triplets[i]
for j in range(gt_triplet_num):
label = label_triplets[j]
                if not predicate_matched[j]:
if int(label[2]) == int(pred[2]) and \
triplet_iou_thresh(pred, label, self.iou_thresh):
                        count += 1
predicate_matched[j] = True
if not triplet_matched[j]:
if int(label[2]) == int(pred[2]) and \
int(label[3]) == int(pred[3]) and \
int(label[4]) == int(pred[4]) and \
triplet_iou_thresh(pred, label, self.iou_thresh):
count += 1
triplet_matched[j] = True
# compute sum
        total = label_triplets.shape[0]
N = gt_obj_num + 2 * total
self.sum_metric += count / N
self.num_inst += 1
def extract_gt(g, img_size):
'''extract prediction from ground truth graph'''
if g is None or g.number_of_nodes() == 0:
return None, None
gt_eids = np.where(g.edata['rel_class'].asnumpy() > 0)[0]
if len(gt_eids) == 0:
return None, None
gt_class = g.ndata['node_class'][:,0].asnumpy()
gt_bbox = g.ndata['bbox'].asnumpy()
gt_bbox[:, 0] /= img_size[1]
gt_bbox[:, 1] /= img_size[0]
gt_bbox[:, 2] /= img_size[1]
gt_bbox[:, 3] /= img_size[0]
gt_objects = np.vstack([gt_class, gt_bbox.transpose(1, 0)]).transpose(1, 0)
gt_node_ids = g.find_edges(gt_eids)
gt_node_sub = gt_node_ids[0].asnumpy()
gt_node_ob = gt_node_ids[1].asnumpy()
gt_rel_class = g.edata['rel_class'][gt_eids,0].asnumpy() - 1
gt_sub_class = gt_class[gt_node_sub]
gt_ob_class = gt_class[gt_node_ob]
gt_sub_bbox = gt_bbox[gt_node_sub]
gt_ob_bbox = gt_bbox[gt_node_ob]
n = len(gt_eids)
gt_triplets = np.vstack([np.ones(n), np.ones(n),
gt_rel_class, gt_sub_class, gt_ob_class,
gt_sub_bbox.transpose(1, 0),
gt_ob_bbox.transpose(1, 0)]).transpose(1, 0)
return gt_objects, gt_triplets
def extract_pred(g, topk=100, joint_preds=False):
'''extract prediction from prediction graph for validation and visualization'''
if g is None or g.number_of_nodes() == 0:
return None, None
pred_class = g.ndata['node_class_pred'].asnumpy()
pred_class_prob = g.ndata['node_class_logit'].asnumpy()
pred_bbox = g.ndata['pred_bbox'][:,0:4].asnumpy()
pred_objects = np.vstack([pred_class, pred_bbox.transpose(1, 0)]).transpose(1, 0)
score_pred = g.edata['score_pred'].asnumpy()
score_phr = g.edata['score_phr'].asnumpy()
score_pred_topk_eids = (-score_pred).argsort()[0:topk].tolist()
score_phr_topk_eids = (-score_phr).argsort()[0:topk].tolist()
topk_eids = sorted(list(set(score_pred_topk_eids + score_phr_topk_eids)))
pred_rel_prob = g.edata['preds'][topk_eids].asnumpy()
if joint_preds:
pred_rel_class = pred_rel_prob[:,1:].argmax(axis=1)
else:
pred_rel_class = pred_rel_prob.argmax(axis=1)
pred_node_ids = g.find_edges(topk_eids)
pred_node_sub = pred_node_ids[0].asnumpy()
pred_node_ob = pred_node_ids[1].asnumpy()
pred_sub_class = pred_class[pred_node_sub]
pred_sub_class_prob = pred_class_prob[pred_node_sub]
pred_sub_bbox = pred_bbox[pred_node_sub]
pred_ob_class = pred_class[pred_node_ob]
pred_ob_class_prob = pred_class_prob[pred_node_ob]
pred_ob_bbox = pred_bbox[pred_node_ob]
pred_triplets = np.vstack([score_pred[topk_eids], score_phr[topk_eids],
pred_rel_class, pred_sub_class, pred_ob_class,
pred_sub_bbox.transpose(1, 0),
pred_ob_bbox.transpose(1, 0)]).transpose(1, 0)
return pred_objects, pred_triplets
import dgl
from dgl.utils import toindex
import mxnet as mx
import numpy as np
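# L0 sampling: keep at most `positive_max` edges that carry a ground-truth relation and
# draw `negative_ratio` times as many no-relation edges, then return the induced edge subgraph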
def l0_sample(g, positive_max=128, negative_ratio=3):
'''sampling positive and negative edges'''
if g is None:
return None
n_eids = g.number_of_edges()
pos_eids = np.where(g.edata['rel_class'].asnumpy() > 0)[0]
neg_eids = np.where(g.edata['rel_class'].asnumpy() == 0)[0]
if len(pos_eids) == 0:
return None
positive_num = min(len(pos_eids), positive_max)
negative_num = min(len(neg_eids), positive_num * negative_ratio)
pos_sample = np.random.choice(pos_eids, positive_num, replace=False)
neg_sample = np.random.choice(neg_eids, negative_num, replace=False)
weights = np.zeros(n_eids)
# np.add.at(weights, pos_sample, 1)
weights[pos_sample] = 1
weights[neg_sample] = 1
# g.edata['sample_weights'] = mx.nd.array(weights, ctx=g.edata['rel_class'].context)
# return g
eids = np.where(weights > 0)[0]
sub_g = g.edge_subgraph(toindex(eids.tolist()))
sub_g.copy_from_parent()
sub_g.edata['sample_weights'] = mx.nd.array(weights[eids],
ctx=g.edata['rel_class'].context)
return sub_g