Unverified Commit a9f2acf3 authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Black auto fix. (#4641)



* [Misc] Black auto fix.

* sort
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 08c50eb7
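
The commit is a formatting-only pass over the MXNet examples: imports are grouped and sorted, single quotes become double quotes, and long call sites are split one argument per line. A minimal sketch of the kind of rewrite involved, assuming the `black` and `isort` packages with their default settings (the exact configuration used for this commit is not shown here):

import black
import isort

src = (
    "import numpy as np\n"
    "import mxnet as mx\n"
    "trainer_args = {'learning_rate': 1e-2,   'wd': 5e-5}\n"
)

# Illustrative only: isort orders the imports, Black normalizes quotes
# and spacing. Assumes black and isort are installed with defaults.
print(black.format_str(isort.code(src), mode=black.Mode()))
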
import argparse
import time

import mxnet as mx
import networkx as nx
import numpy as np
from mxnet import gluon, nd
from mxnet.gluon import nn

import dgl
from dgl.data import (CiteseerGraphDataset, CoraGraphDataset,
                      PubmedGraphDataset, register_data_args)
from dgl.nn.mxnet.conv import GMMConv


class MoNet(nn.Block):
    def __init__(
        self,
        g,
        in_feats,
        n_hidden,
        out_feats,
        n_layers,
        dim,
        n_kernels,
        dropout,
    ):
        super(MoNet, self).__init__()
        self.g = g
        with self.name_scope():

@@ -28,18 +32,19 @@ class MoNet(nn.Block):

            self.pseudo_proj = nn.Sequential()
            # Input layer
            self.layers.add(GMMConv(in_feats, n_hidden, dim, n_kernels))
            self.pseudo_proj.add(nn.Dense(dim, in_units=2, activation="tanh"))
            # Hidden layer
            for _ in range(n_layers - 1):
                self.layers.add(GMMConv(n_hidden, n_hidden, dim, n_kernels))
                self.pseudo_proj.add(
                    nn.Dense(dim, in_units=2, activation="tanh")
                )
            # Output layer
            self.layers.add(GMMConv(n_hidden, out_feats, dim, n_kernels))
            self.pseudo_proj.add(nn.Dense(dim, in_units=2, activation="tanh"))
            self.dropout = nn.Dropout(dropout)

@@ -48,8 +53,7 @@ class MoNet(nn.Block):

        for i in range(len(self.layers)):
            if i > 0:
                h = self.dropout(h)
            h = self.layers[i](self.g, h, self.pseudo_proj[i](pseudo))
        return h

@@ -58,16 +62,17 @@ def evaluate(model, features, pseudo, labels, mask):

    accuracy = ((pred == labels) * mask).sum() / mask.sum().asscalar()
    return accuracy.asscalar()


def main(args):
    # load and preprocess dataset
    if args.dataset == "cora":
        data = CoraGraphDataset()
    elif args.dataset == "citeseer":
        data = CiteseerGraphDataset()
    elif args.dataset == "pubmed":
        data = PubmedGraphDataset()
    else:
        raise ValueError("Unknown dataset: {}".format(args.dataset))

    g = data[0]
    if args.gpu < 0:

@@ -78,24 +83,29 @@ def main(args):

        ctx = mx.gpu(args.gpu)
        g = g.to(ctx)

    features = g.ndata["feat"]
    labels = mx.nd.array(g.ndata["label"], dtype="float32", ctx=ctx)
    train_mask = g.ndata["train_mask"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]
    in_feats = features.shape[1]
    n_classes = data.num_labels
    n_edges = data.graph.number_of_edges()
    print(
        """----Data statistics------'
      #Edges %d
      #Classes %d
      #Train samples %d
      #Val samples %d
      #Test samples %d"""
        % (
            n_edges,
            n_classes,
            train_mask.sum().asscalar(),
            val_mask.sum().asscalar(),
            test_mask.sum().asscalar(),
        )
    )

    # add self loop
    g = dgl.remove_self_loop(g)

@@ -107,30 +117,32 @@ def main(args):

    vs = vs.asnumpy()
    pseudo = []
    for i in range(g.number_of_edges()):
        pseudo.append(
            [1 / np.sqrt(g.in_degree(us[i])), 1 / np.sqrt(g.in_degree(vs[i]))]
        )
    pseudo = nd.array(pseudo, ctx=ctx)

    # create GraphSAGE model
    model = MoNet(
        g,
        in_feats,
        args.n_hidden,
        n_classes,
        args.n_layers,
        args.pseudo_dim,
        args.n_kernels,
        args.dropout,
    )

    model.initialize(ctx=ctx)
    n_train_samples = train_mask.sum().asscalar()
    loss_fcn = gluon.loss.SoftmaxCELoss()

    print(model.collect_params())
    trainer = gluon.Trainer(
        model.collect_params(),
        "adam",
        {"learning_rate": args.lr, "wd": args.weight_decay},
    )

    # initialize graph
    dur = []

@@ -150,37 +162,55 @@ def main(args):

        loss.asscalar()
        dur.append(time.time() - t0)
        acc = evaluate(model, features, pseudo, labels, val_mask)
        print(
            "Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
            "ETputs(KTEPS) {:.2f}".format(
                epoch,
                np.mean(dur),
                loss.asscalar(),
                acc,
                n_edges / np.mean(dur) / 1000,
            )
        )

    # test set accuracy
    acc = evaluate(model, features, pseudo, labels, test_mask)
    print("Test accuracy {:.2%}".format(acc))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="MoNet on citation network")
    register_data_args(parser)
    parser.add_argument(
        "--dropout", type=float, default=0.5, help="dropout probability"
    )
    parser.add_argument("--gpu", type=int, default=-1, help="gpu")
    parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
    parser.add_argument(
        "--n-epochs", type=int, default=200, help="number of training epochs"
    )
    parser.add_argument(
        "--n-hidden", type=int, default=16, help="number of hidden gcn units"
    )
    parser.add_argument(
        "--n-layers", type=int, default=1, help="number of hidden gcn layers"
    )
    parser.add_argument(
        "--pseudo-dim",
        type=int,
        default=2,
        help="Pseudo coordinate dimensions in GMMConv, 2 for cora and 3 for pubmed",
    )
    parser.add_argument(
        "--n-kernels",
        type=int,
        default=3,
        help="Number of kernels in GMMConv layer",
    )
    parser.add_argument(
        "--weight-decay", type=float, default=5e-5, help="Weight for L2 loss"
    )
    args = parser.parse_args()
    print(args)

    main(args)
import mxnet as mx
from mxnet import gluon


class BaseRGCN(gluon.Block):
    def __init__(
        self,
        num_nodes,
        h_dim,
        out_dim,
        num_rels,
        num_bases=-1,
        num_hidden_layers=1,
        dropout=0,
        use_self_loop=False,
        gpu_id=-1,
    ):
        super(BaseRGCN, self).__init__()
        self.num_nodes = num_nodes
        self.h_dim = h_dim
...
from .object import *
from .relation import *
from .dataloader import *
"""DataLoader utils.""" """DataLoader utils."""
import dgl
from mxnet import nd
from gluoncv.data.batchify import Pad from gluoncv.data.batchify import Pad
from mxnet import nd
import dgl
def dgl_mp_batchify_fn(data): def dgl_mp_batchify_fn(data):
if isinstance(data[0], tuple): if isinstance(data[0], tuple):
data = zip(*data) data = zip(*data)
return [dgl_mp_batchify_fn(i) for i in data] return [dgl_mp_batchify_fn(i) for i in data]
for dt in data: for dt in data:
if dt is not None: if dt is not None:
if isinstance(dt, dgl.DGLGraph): if isinstance(dt, dgl.DGLGraph):
......
"""Pascal VOC object detection dataset.""" """Pascal VOC object detection dataset."""
from __future__ import absolute_import from __future__ import absolute_import, division
from __future__ import division
import os
import logging
import warnings
import json import json
import logging
import os
import pickle import pickle
import numpy as np import warnings
from collections import Counter
import mxnet as mx import mxnet as mx
import numpy as np
from gluoncv.data import COCODetection from gluoncv.data import COCODetection
from collections import Counter
class VGObject(COCODetection): class VGObject(COCODetection):
CLASSES = ["airplane", "animal", "arm", "bag", "banana", "basket", "beach", CLASSES = [
"bear", "bed", "bench", "bike", "bird", "board", "boat", "book", "airplane",
"boot", "bottle", "bowl", "box", "boy", "branch", "building", "bus", "animal",
"cabinet", "cap", "car", "cat", "chair", "child", "clock", "coat", "arm",
"counter", "cow", "cup", "curtain", "desk", "dog", "door", "drawer", "bag",
"ear", "elephant", "engine", "eye", "face", "fence", "finger", "flag", "banana",
"flower", "food", "fork", "fruit", "giraffe", "girl", "glass", "glove", "basket",
"guy", "hair", "hand", "handle", "hat", "head", "helmet", "hill", "beach",
"horse", "house", "jacket", "jean", "kid", "kite", "lady", "lamp", "bear",
"laptop", "leaf", "leg", "letter", "light", "logo", "man", "men", "bed",
"motorcycle", "mountain", "mouth", "neck", "nose", "number", "orange", "bench",
"pant", "paper", "paw", "people", "person", "phone", "pillow", "pizza", "bike",
"plane", "plant", "plate", "player", "pole", "post", "pot", "racket", "bird",
"railing", "rock", "roof", "room", "screen", "seat", "sheep", "shelf", "board",
"shirt", "shoe", "short", "sidewalk", "sign", "sink", "skateboard", "boat",
"ski", "skier", "sneaker", "snow", "sock", "stand", "street", "book",
"surfboard", "table", "tail", "tie", "tile", "tire", "toilet", "boot",
"towel", "tower", "track", "train", "tree", "truck", "trunk", "bottle",
"umbrella", "vase", "vegetable", "vehicle", "wave", "wheel", "bowl",
"window", "windshield", "wing", "wire", "woman", "zebra"] "box",
"boy",
"branch",
"building",
"bus",
"cabinet",
"cap",
"car",
"cat",
"chair",
"child",
"clock",
"coat",
"counter",
"cow",
"cup",
"curtain",
"desk",
"dog",
"door",
"drawer",
"ear",
"elephant",
"engine",
"eye",
"face",
"fence",
"finger",
"flag",
"flower",
"food",
"fork",
"fruit",
"giraffe",
"girl",
"glass",
"glove",
"guy",
"hair",
"hand",
"handle",
"hat",
"head",
"helmet",
"hill",
"horse",
"house",
"jacket",
"jean",
"kid",
"kite",
"lady",
"lamp",
"laptop",
"leaf",
"leg",
"letter",
"light",
"logo",
"man",
"men",
"motorcycle",
"mountain",
"mouth",
"neck",
"nose",
"number",
"orange",
"pant",
"paper",
"paw",
"people",
"person",
"phone",
"pillow",
"pizza",
"plane",
"plant",
"plate",
"player",
"pole",
"post",
"pot",
"racket",
"railing",
"rock",
"roof",
"room",
"screen",
"seat",
"sheep",
"shelf",
"shirt",
"shoe",
"short",
"sidewalk",
"sign",
"sink",
"skateboard",
"ski",
"skier",
"sneaker",
"snow",
"sock",
"stand",
"street",
"surfboard",
"table",
"tail",
"tie",
"tile",
"tire",
"toilet",
"towel",
"tower",
"track",
"train",
"tree",
"truck",
"trunk",
"umbrella",
"vase",
"vegetable",
"vehicle",
"wave",
"wheel",
"window",
"windshield",
"wing",
"wire",
"woman",
"zebra",
]
def __init__(self, **kwargs): def __init__(self, **kwargs):
super(VGObject, self).__init__(**kwargs) super(VGObject, self).__init__(**kwargs)
@property @property
def annotation_dir(self): def annotation_dir(self):
return '' return ""
def _parse_image_path(self, entry): def _parse_image_path(self, entry):
dirname = 'VG_100K' dirname = "VG_100K"
filename = entry['file_name'] filename = entry["file_name"]
abs_path = os.path.join(self._root, dirname, filename) abs_path = os.path.join(self._root, dirname, filename)
return abs_path return abs_path
"""Prepare Visual Genome datasets""" """Prepare Visual Genome datasets"""
import argparse
import json
import os import os
import pickle
import random
import shutil import shutil
import argparse
import zipfile import zipfile
import random
import json
import tqdm import tqdm
import pickle
from gluoncv.utils import download, makedirs from gluoncv.utils import download, makedirs
_TARGET_DIR = os.path.expanduser('~/.mxnet/datasets/visualgenome') _TARGET_DIR = os.path.expanduser("~/.mxnet/datasets/visualgenome")
def parse_args(): def parse_args():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Initialize Visual Genome dataset.', description="Initialize Visual Genome dataset.",
epilog='Example: python visualgenome.py --download-dir ~/visualgenome', epilog="Example: python visualgenome.py --download-dir ~/visualgenome",
formatter_class=argparse.ArgumentDefaultsHelpFormatter) formatter_class=argparse.ArgumentDefaultsHelpFormatter,
parser.add_argument('--download-dir', type=str, default='~/visualgenome/', )
help='dataset directory on disk') parser.add_argument(
parser.add_argument('--no-download', action='store_true', help='disable automatic download if set') "--download-dir",
parser.add_argument('--overwrite', action='store_true', help='overwrite downloaded files if set, in case they are corrupted') type=str,
default="~/visualgenome/",
help="dataset directory on disk",
)
parser.add_argument(
"--no-download",
action="store_true",
help="disable automatic download if set",
)
parser.add_argument(
"--overwrite",
action="store_true",
help="overwrite downloaded files if set, in case they are corrupted",
)
args = parser.parse_args() args = parser.parse_args()
return args return args
def download_vg(path, overwrite=False): def download_vg(path, overwrite=False):
_DOWNLOAD_URLS = [ _DOWNLOAD_URLS = [
('https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip', (
'a055367f675dd5476220e9b93e4ca9957b024b94'), "https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip",
('https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip', "a055367f675dd5476220e9b93e4ca9957b024b94",
'2add3aab77623549e92b7f15cda0308f50b64ecf'), ),
(
"https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip",
"2add3aab77623549e92b7f15cda0308f50b64ecf",
),
] ]
makedirs(path) makedirs(path)
for url, checksum in _DOWNLOAD_URLS: for url, checksum in _DOWNLOAD_URLS:
filename = download(url, path=path, overwrite=overwrite, sha1_hash=checksum) filename = download(
url, path=path, overwrite=overwrite, sha1_hash=checksum
)
# extract # extract
if filename.endswith('zip'): if filename.endswith("zip"):
with zipfile.ZipFile(filename) as zf: with zipfile.ZipFile(filename) as zf:
zf.extractall(path=path) zf.extractall(path=path)
# move all images into folder `VG_100K` # move all images into folder `VG_100K`
vg_100k_path = os.path.join(path, 'VG_100K') vg_100k_path = os.path.join(path, "VG_100K")
vg_100k_2_path = os.path.join(path, 'VG_100K_2') vg_100k_2_path = os.path.join(path, "VG_100K_2")
files_2 = os.listdir(vg_100k_2_path) files_2 = os.listdir(vg_100k_2_path)
for fl in files_2: for fl in files_2:
shutil.move(os.path.join(vg_100k_2_path, fl), shutil.move(
os.path.join(vg_100k_path, fl)) os.path.join(vg_100k_2_path, fl), os.path.join(vg_100k_path, fl)
)
def download_json(path, overwrite=False): def download_json(path, overwrite=False):
url = 'https://data.dgl.ai/dataset/vg.zip' url = "https://data.dgl.ai/dataset/vg.zip"
output = 'vg.zip' output = "vg.zip"
download(url, path=path) download(url, path=path)
with zipfile.ZipFile(output) as zf: with zipfile.ZipFile(output) as zf:
zf.extractall(path=path) zf.extractall(path=path)
json_path = os.path.join(path, 'vg') json_path = os.path.join(path, "vg")
json_files = os.listdir(json_path) json_files = os.listdir(json_path)
for fl in json_files: for fl in json_files:
shutil.move(os.path.join(json_path, fl), shutil.move(os.path.join(json_path, fl), os.path.join(path, fl))
os.path.join(path, fl))
os.rmdir(json_path) os.rmdir(json_path)
if __name__ == '__main__':
if __name__ == "__main__":
args = parse_args() args = parse_args()
path = os.path.expanduser(args.download_dir) path = os.path.expanduser(args.download_dir)
if not os.path.isdir(path): if not os.path.isdir(path):
if args.no_download: if args.no_download:
raise ValueError(('{} is not a valid directory, make sure it is present.' raise ValueError(
' Or you should not disable "--no-download" to grab it'.format(path))) (
"{} is not a valid directory, make sure it is present."
' Or you should not disable "--no-download" to grab it'.format(
path
)
)
)
else: else:
download_vg(path, overwrite=args.overwrite) download_vg(path, overwrite=args.overwrite)
download_json(path, overwrite=args.overwrite) download_json(path, overwrite=args.overwrite)
# make symlink # make symlink
makedirs(os.path.expanduser('~/.mxnet/datasets')) makedirs(os.path.expanduser("~/.mxnet/datasets"))
if os.path.isdir(_TARGET_DIR): if os.path.isdir(_TARGET_DIR):
os.rmdir(_TARGET_DIR) os.rmdir(_TARGET_DIR)
os.symlink(path, _TARGET_DIR) os.symlink(path, _TARGET_DIR)
"""Pascal VOC object detection dataset.""" """Pascal VOC object detection dataset."""
from __future__ import absolute_import from __future__ import absolute_import, division
from __future__ import division
import os
import logging
import warnings
import json import json
import dgl import logging
import os
import pickle import pickle
import numpy as np import warnings
from collections import Counter
import mxnet as mx import mxnet as mx
import numpy as np
from gluoncv.data.base import VisionDataset from gluoncv.data.base import VisionDataset
from collections import Counter from gluoncv.data.transforms.presets.rcnn import (
from gluoncv.data.transforms.presets.rcnn import FasterRCNNDefaultTrainTransform, FasterRCNNDefaultValTransform FasterRCNNDefaultTrainTransform, FasterRCNNDefaultValTransform)
import dgl
class VGRelation(VisionDataset): class VGRelation(VisionDataset):
def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'visualgenome'), split='train'): def __init__(
self,
root=os.path.join("~", ".mxnet", "datasets", "visualgenome"),
split="train",
):
super(VGRelation, self).__init__(root) super(VGRelation, self).__init__(root)
self._root = os.path.expanduser(root) self._root = os.path.expanduser(root)
self._img_path = os.path.join(self._root, 'VG_100K', '{}') self._img_path = os.path.join(self._root, "VG_100K", "{}")
if split == 'train': if split == "train":
self._dict_path = os.path.join(self._root, 'rel_annotations_train.json') self._dict_path = os.path.join(
elif split == 'val': self._root, "rel_annotations_train.json"
self._dict_path = os.path.join(self._root, 'rel_annotations_val.json') )
elif split == "val":
self._dict_path = os.path.join(
self._root, "rel_annotations_val.json"
)
else: else:
raise NotImplementedError raise NotImplementedError
with open(self._dict_path) as f: with open(self._dict_path) as f:
tmp = f.read() tmp = f.read()
self._dict = json.loads(tmp) self._dict = json.loads(tmp)
self._predicates_path = os.path.join(self._root, 'predicates.json') self._predicates_path = os.path.join(self._root, "predicates.json")
with open(self._predicates_path, 'r') as f: with open(self._predicates_path, "r") as f:
tmp = f.read() tmp = f.read()
self.rel_classes = json.loads(tmp) self.rel_classes = json.loads(tmp)
self.num_rel_classes = len(self.rel_classes) + 1 self.num_rel_classes = len(self.rel_classes) + 1
self._objects_path = os.path.join(self._root, 'objects.json') self._objects_path = os.path.join(self._root, "objects.json")
with open(self._objects_path, 'r') as f: with open(self._objects_path, "r") as f:
tmp = f.read() tmp = f.read()
self.obj_classes = json.loads(tmp) self.obj_classes = json.loads(tmp)
self.num_obj_classes = len(self.obj_classes) self.num_obj_classes = len(self.obj_classes)
if split == 'val': if split == "val":
self.img_transform = FasterRCNNDefaultValTransform(short=600, max_size=1000) self.img_transform = FasterRCNNDefaultValTransform(
short=600, max_size=1000
)
else: else:
self.img_transform = FasterRCNNDefaultTrainTransform(short=600, max_size=1000) self.img_transform = FasterRCNNDefaultTrainTransform(
short=600, max_size=1000
)
self.split = split self.split = split
def __len__(self): def __len__(self):
return len(self._dict) return len(self._dict)
def _hash_bbox(self, object): def _hash_bbox(self, object):
num_list = [object['category']] + object['bbox'] num_list = [object["category"]] + object["bbox"]
return '_'.join([str(num) for num in num_list]) return "_".join([str(num) for num in num_list])
def __getitem__(self, idx): def __getitem__(self, idx):
img_id = list(self._dict)[idx] img_id = list(self._dict)[idx]
...@@ -66,8 +82,8 @@ class VGRelation(VisionDataset): ...@@ -66,8 +82,8 @@ class VGRelation(VisionDataset):
sub_node_hash = [] sub_node_hash = []
ob_node_hash = [] ob_node_hash = []
for i, it in enumerate(item): for i, it in enumerate(item):
sub_node_hash.append(self._hash_bbox(it['subject'])) sub_node_hash.append(self._hash_bbox(it["subject"]))
ob_node_hash.append(self._hash_bbox(it['object'])) ob_node_hash.append(self._hash_bbox(it["object"]))
node_set = sorted(list(set(sub_node_hash + ob_node_hash))) node_set = sorted(list(set(sub_node_hash + ob_node_hash)))
n_nodes = len(node_set) n_nodes = len(node_set)
node_to_id = {} node_to_id = {}
...@@ -86,35 +102,37 @@ class VGRelation(VisionDataset): ...@@ -86,35 +102,37 @@ class VGRelation(VisionDataset):
for i, it in enumerate(item): for i, it in enumerate(item):
if not node_visited[sub_id[i]]: if not node_visited[sub_id[i]]:
ind = sub_id[i] ind = sub_id[i]
sub = it['subject'] sub = it["subject"]
node_class_ids[ind] = sub['category'] node_class_ids[ind] = sub["category"]
# y1y2x1x2 to x1y1x2y2 # y1y2x1x2 to x1y1x2y2
bbox[ind,0] = sub['bbox'][2] bbox[ind, 0] = sub["bbox"][2]
bbox[ind,1] = sub['bbox'][0] bbox[ind, 1] = sub["bbox"][0]
bbox[ind,2] = sub['bbox'][3] bbox[ind, 2] = sub["bbox"][3]
bbox[ind,3] = sub['bbox'][1] bbox[ind, 3] = sub["bbox"][1]
node_visited[ind] = True node_visited[ind] = True
if not node_visited[ob_id[i]]: if not node_visited[ob_id[i]]:
ind = ob_id[i] ind = ob_id[i]
ob = it['object'] ob = it["object"]
node_class_ids[ind] = ob['category'] node_class_ids[ind] = ob["category"]
# y1y2x1x2 to x1y1x2y2 # y1y2x1x2 to x1y1x2y2
bbox[ind,0] = ob['bbox'][2] bbox[ind, 0] = ob["bbox"][2]
bbox[ind,1] = ob['bbox'][0] bbox[ind, 1] = ob["bbox"][0]
bbox[ind,2] = ob['bbox'][3] bbox[ind, 2] = ob["bbox"][3]
bbox[ind,3] = ob['bbox'][1] bbox[ind, 3] = ob["bbox"][1]
node_visited[ind] = True node_visited[ind] = True
eta = 0.1 eta = 0.1
node_class_vec = node_class_ids[:,0].one_hot(self.num_obj_classes, node_class_vec = node_class_ids[:, 0].one_hot(
on_value = 1 - eta + eta / self.num_obj_classes, self.num_obj_classes,
off_value = eta / self.num_obj_classes) on_value=1 - eta + eta / self.num_obj_classes,
off_value=eta / self.num_obj_classes,
)
# augmentation # augmentation
if self.split == 'val': if self.split == "val":
img, bbox, _ = self.img_transform(img, bbox) img, bbox, _ = self.img_transform(img, bbox)
else: else:
img, bbox = self.img_transform(img, bbox) img, bbox = self.img_transform(img, bbox)
...@@ -126,9 +144,9 @@ class VGRelation(VisionDataset): ...@@ -126,9 +144,9 @@ class VGRelation(VisionDataset):
predicate = [] predicate = []
for i, it in enumerate(item): for i, it in enumerate(item):
adjmat[sub_id[i], ob_id[i]] = 1 adjmat[sub_id[i], ob_id[i]] = 1
predicate.append(it['predicate']) predicate.append(it["predicate"])
predicate = mx.nd.array(predicate).expand_dims(1) predicate = mx.nd.array(predicate).expand_dims(1)
g.add_edges(sub_id, ob_id, {'rel_class': mx.nd.array(predicate) + 1}) g.add_edges(sub_id, ob_id, {"rel_class": mx.nd.array(predicate) + 1})
empty_edge_list = [] empty_edge_list = []
for i in range(n_nodes): for i in range(n_nodes):
for j in range(n_nodes): for j in range(n_nodes):
...@@ -136,11 +154,13 @@ class VGRelation(VisionDataset): ...@@ -136,11 +154,13 @@ class VGRelation(VisionDataset):
empty_edge_list.append((i, j)) empty_edge_list.append((i, j))
if len(empty_edge_list) > 0: if len(empty_edge_list) > 0:
src, dst = tuple(zip(*empty_edge_list)) src, dst = tuple(zip(*empty_edge_list))
g.add_edges(src, dst, {'rel_class': mx.nd.zeros((len(empty_edge_list), 1))}) g.add_edges(
src, dst, {"rel_class": mx.nd.zeros((len(empty_edge_list), 1))}
)
# assign features # assign features
g.ndata['bbox'] = bbox g.ndata["bbox"] = bbox
g.ndata['node_class'] = node_class_ids g.ndata["node_class"] = node_class_ids
g.ndata['node_class_vec'] = node_class_vec g.ndata["node_class_vec"] = node_class_vec
return g, img return g, img
import argparse

import gluoncv as gcv
import mxnet as mx
from gluoncv.data.transforms import presets
from gluoncv.utils import download
from model import RelDN, faster_rcnn_resnet101_v1d_custom
from utils import *

from data import *

import dgl


def parse_args():
    parser = argparse.ArgumentParser(
        description="Demo of Scene Graph Extraction."
    )
    parser.add_argument(
        "--image",
        type=str,
        default="",
        help="The image for scene graph extraction.",
    )
    parser.add_argument(
        "--gpu",
        type=str,
        default="",
        help="GPU id to use for inference, default is not using GPU.",
    )
    parser.add_argument(
        "--pretrained-faster-rcnn-params",
        type=str,
        default="",
        help="Path to saved Faster R-CNN model parameters.",
    )
    parser.add_argument(
        "--reldn-params",
        type=str,
        default="",
        help="Path to saved RelDN model parameters.",
    )
    parser.add_argument(
        "--faster-rcnn-params",
        type=str,
        default="",
        help="Path to saved Faster R-CNN model parameters.",
    )
    parser.add_argument(
        "--freq-prior",
        type=str,
        default="freq_prior.pkl",
        help="Path to saved frequency prior data.",
    )
    args = parser.parse_args()
    return args


args = parse_args()
if args.gpu:
    ctx = mx.gpu(int(args.gpu))

@@ -32,31 +62,47 @@ else:

    ctx = mx.cpu()

net = RelDN(n_classes=50, prior_pkl=args.freq_prior, semantic_only=False)
if args.reldn_params == "":
    download("http://data.dgl.ai/models/SceneGraph/reldn.params")
    net.load_parameters("reldn.params", ctx=ctx)
else:
    net.load_parameters(args.reldn_params, ctx=ctx)

# dataset and dataloader
vg_val = VGRelation(split="val")
detector = faster_rcnn_resnet101_v1d_custom(
    classes=vg_val.obj_classes,
    pretrained_base=False,
    pretrained=False,
    additional_output=True,
)
if args.pretrained_faster_rcnn_params == "":
    download(
        "http://data.dgl.ai/models/SceneGraph/faster_rcnn_resnet101_v1d_visualgenome.params"
    )
    params_path = "faster_rcnn_resnet101_v1d_visualgenome.params"
else:
    params_path = args.pretrained_faster_rcnn_params
detector.load_parameters(
    params_path, ctx=ctx, ignore_extra=True, allow_missing=True
)

detector_feat = faster_rcnn_resnet101_v1d_custom(
    classes=vg_val.obj_classes,
    pretrained_base=False,
    pretrained=False,
    additional_output=True,
)
detector_feat.load_parameters(
    params_path, ctx=ctx, ignore_extra=True, allow_missing=True
)
if args.faster_rcnn_params == "":
    download(
        "http://data.dgl.ai/models/SceneGraph/faster_rcnn_resnet101_v1d_visualgenome.params"
    )
    detector_feat.features.load_parameters(
        "faster_rcnn_resnet101_v1d_visualgenome.params", ctx=ctx
    )
else:
    detector_feat.features.load_parameters(args.faster_rcnn_params, ctx=ctx)

@@ -64,24 +110,37 @@ else:

if args.image:
    image_path = args.image
else:
    gcv.utils.download(
        "https://raw.githubusercontent.com/dmlc/web-data/master/"
        + "dgl/examples/mxnet/scenegraph/old-couple.png",
        "old-couple.png",
    )
    image_path = "old-couple.png"
x, img = presets.rcnn.load_test(
    image_path, short=detector.short, max_size=detector.max_size
)
x = x.as_in_context(ctx)

# detector prediction
ids, scores, bboxes, feat, feat_ind, spatial_feat = detector(x)

# build graph, extract edge features
g = build_graph_validate_pred(
    x,
    ids,
    scores,
    bboxes,
    feat_ind,
    spatial_feat,
    bbox_improvement=True,
    scores_top_k=75,
    overlap=False,
)
rel_bbox = g.edata["rel_bbox"].expand_dims(0).as_in_context(ctx)
_, _, _, spatial_feat_rel = detector_feat(x, None, None, rel_bbox)
g.edata["edge_feat"] = spatial_feat_rel[0]

# graph prediction
g = net(g)
_, preds = extract_pred(g, joint_preds=True)
preds = preds[preds[:, 1].argsort()[::-1]]
plot_sg(img, preds, detector.classes, vg_val.rel_classes, 10)
import pickle

import gluoncv as gcv
import mxnet as mx
import numpy as np
from mxnet import nd
from mxnet.gluon import nn

import dgl
from dgl.nn.mxnet import GraphConv
from dgl.utils import toindex

__all__ = ["RelDN"]


class EdgeConfMLP(nn.Block):
    """compute the confidence for edges"""

    def __init__(self):
        super(EdgeConfMLP, self).__init__()

    def forward(self, edges):
        score_pred = nd.log_softmax(edges.data["preds"])[:, 1:].max(axis=1)
        score_phr = (
            score_pred
            + edges.src["node_class_logit"]
            + edges.dst["node_class_logit"]
        )
        return {"score_pred": score_pred, "score_phr": score_phr}


class EdgeBBoxExtend(nn.Block):
    """encode the bounding boxes"""

    def __init__(self):
        super(EdgeBBoxExtend, self).__init__()

    def bbox_delta(self, bbox_a, bbox_b):
        n = bbox_a.shape[0]
        result = nd.zeros((n, 4), ctx=bbox_a.context)
        result[:, 0] = bbox_a[:, 0] - bbox_b[:, 0]
        result[:, 1] = bbox_a[:, 1] - bbox_b[:, 1]
        result[:, 2] = nd.log(
            (bbox_a[:, 2] - bbox_a[:, 0] + 1e-8)
            / (bbox_b[:, 2] - bbox_b[:, 0] + 1e-8)
        )
        result[:, 3] = nd.log(
            (bbox_a[:, 3] - bbox_a[:, 1] + 1e-8)
            / (bbox_b[:, 3] - bbox_b[:, 1] + 1e-8)
        )
        return result

    def forward(self, edges):
        ctx = edges.src["pred_bbox"].context
        n = edges.src["pred_bbox"].shape[0]
        delta_src_obj = self.bbox_delta(
            edges.src["pred_bbox"], edges.dst["pred_bbox"]
        )
        delta_src_rel = self.bbox_delta(
            edges.src["pred_bbox"], edges.data["rel_bbox"]
        )
        delta_rel_obj = self.bbox_delta(
            edges.data["rel_bbox"], edges.dst["pred_bbox"]
        )
        result = nd.zeros((n, 12), ctx=ctx)
        result[:, 0:4] = delta_src_obj
        result[:, 4:8] = delta_src_rel
        result[:, 8:12] = delta_rel_obj
        return {"pred_bbox_additional": result}


class EdgeFreqPrior(nn.Block):
    """make use of the pre-trained frequency prior"""

    def __init__(self, prior_pkl):
        super(EdgeFreqPrior, self).__init__()
        with open(prior_pkl, "rb") as f:
            freq_prior = pickle.load(f)
        self.freq_prior = freq_prior

    def forward(self, edges):
        ctx = edges.src["node_class_pred"].context
        src_ind = edges.src["node_class_pred"].asnumpy().astype(int)
        dst_ind = edges.dst["node_class_pred"].asnumpy().astype(int)
        prob = self.freq_prior[src_ind, dst_ind]
        out = nd.array(prob, ctx=ctx)
        return {"freq_prior": out}


class EdgeSpatial(nn.Block):
    """spatial feature branch"""

    def __init__(self, n_classes):
        super(EdgeSpatial, self).__init__()
        self.mlp = nn.Sequential()

@@ -76,14 +100,20 @@ class EdgeSpatial(nn.Block):

        self.mlp.add(nn.Dense(n_classes))

    def forward(self, edges):
        feat = nd.concat(
            edges.src["pred_bbox"],
            edges.dst["pred_bbox"],
            edges.data["rel_bbox"],
            edges.data["pred_bbox_additional"],
        )
        out = self.mlp(feat)
        return {"spatial": out}


class EdgeVisual(nn.Block):
    """visual feature branch"""

    def __init__(self, n_classes, vis_feat_dim=7 * 7 * 3):
        super(EdgeVisual, self).__init__()
        self.dim_in = vis_feat_dim
        self.mlp_joint = nn.Sequential()

@@ -97,15 +127,21 @@ class EdgeVisual(nn.Block):

        self.mlp_ob = nn.Dense(n_classes)

    def forward(self, edges):
        feat = nd.concat(
            edges.src["node_feat"],
            edges.dst["node_feat"],
            edges.data["edge_feat"],
        )
        out_joint = self.mlp_joint(feat)
        out_sub = self.mlp_sub(edges.src["node_feat"])
        out_ob = self.mlp_ob(edges.dst["node_feat"])
        out = out_joint + out_sub + out_ob
        return {"visual": out}


class RelDN(nn.Block):
    """The RelDN Model"""

    def __init__(self, n_classes, prior_pkl, semantic_only=False):
        super(RelDN, self).__init__()
        # output layers

@@ -121,19 +157,21 @@ class RelDN(nn.Block):

        self.edge_conf_mlp = EdgeConfMLP()
        self.semantic_only = semantic_only

    def forward(self, g):
        if g is None or g.number_of_nodes() == 0:
            return g
        # predictions
        g.apply_edges(self.freq_prior)
        if self.semantic_only:
            g.edata["preds"] = g.edata["freq_prior"]
        else:
            # bbox extension
            g.apply_edges(self.edge_bbox_extend)
            g.apply_edges(self.spatial)
            g.apply_edges(self.visual)
            g.edata["preds"] = (
                g.edata["freq_prior"] + g.edata["spatial"] + g.edata["visual"]
            )
        # subgraph for gconv
        g.apply_edges(self.edge_conf_mlp)
        return g
@@ -3,29 +3,29 @@ import argparse

import os

# disable autotune
os.environ["MXNET_CUDNN_AUTOTUNE_DEFAULT"] = "0"
import logging
import time

import gluoncv as gcv
import mxnet as mx
import numpy as np
from data import *
from gluoncv import data as gdata
from gluoncv import utils as gutils
from gluoncv.data.batchify import Append, FasterRCNNTrainBatchify, Tuple
from gluoncv.data.transforms.presets.rcnn import (
    FasterRCNNDefaultTrainTransform, FasterRCNNDefaultValTransform)
from gluoncv.model_zoo import get_model
from gluoncv.utils.metrics.coco_detection import COCODetectionMetric
from gluoncv.utils.metrics.rcnn import (RCNNAccMetric, RCNNL1LossMetric,
                                        RPNAccMetric, RPNL1LossMetric)
from gluoncv.utils.metrics.voc_detection import VOC07MApMetric
from gluoncv.utils.parallel import Parallel, Parallelizable
from model import (faster_rcnn_resnet50_v1b_custom,
                   faster_rcnn_resnet101_v1d_custom)
from mxnet import autograd, gluon
from mxnet.contrib import amp

try:
    import horovod.mxnet as hvd

@@ -34,111 +34,229 @@ except ImportError:
def parse_args():
    parser = argparse.ArgumentParser(
        description="Train Faster-RCNN networks e2e."
    )
    parser.add_argument(
        "--network",
        type=str,
        default="resnet101_v1d",
        help="Base network name which serves as feature extraction base.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="visualgenome",
        help="Training dataset. Now support voc and coco.",
    )
    parser.add_argument(
        "--num-workers",
        "-j",
        dest="num_workers",
        type=int,
        default=8,
        help="Number of data workers, you can use larger "
        "number to accelerate data loading, "
        "if your CPU and GPUs are powerful.",
    )
    parser.add_argument(
        "--batch-size", type=int, default=8, help="Training mini-batch size."
    )
    parser.add_argument(
        "--gpus",
        type=str,
        default="0",
        help="Training with GPUs, you can specify 1,3 for example.",
    )
    parser.add_argument(
        "--epochs", type=str, default="", help="Training epochs."
    )
    parser.add_argument(
        "--resume",
        type=str,
        default="",
        help="Resume from previously saved parameters if not None. "
        "For example, you can resume from ./faster_rcnn_xxx_0123.params",
    )
    parser.add_argument(
        "--start-epoch",
        type=int,
        default=0,
        help="Starting epoch for resuming, default is 0 for new training."
        "You can specify it to 100 for example to start from 100 epoch.",
)
parser.add_argument(
"--lr",
type=str,
default="",
help="Learning rate, default is 0.001 for voc single gpu training.",
)
parser.add_argument(
"--lr-decay",
type=float,
default=0.1,
help="decay rate of learning rate. default is 0.1.",
)
parser.add_argument(
"--lr-decay-epoch",
type=str,
default="",
help="epochs at which learning rate decays. default is 14,20 for voc.",
)
parser.add_argument(
"--lr-warmup",
type=str,
default="",
help="warmup iterations to adjust learning rate, default is 0 for voc.",
)
parser.add_argument(
"--lr-warmup-factor",
type=float,
default=1.0 / 3.0,
help="warmup factor of base lr.",
)
parser.add_argument(
"--momentum",
type=float,
default=0.9,
help="SGD momentum, default is 0.9",
)
parser.add_argument(
"--wd",
type=str,
default="",
help="Weight decay, default is 5e-4 for voc",
)
parser.add_argument(
"--log-interval",
type=int,
default=100,
help="Logging mini-batch interval. Default is 100.",
)
parser.add_argument(
"--save-prefix", type=str, default="", help="Saving parameter prefix"
)
parser.add_argument(
"--save-interval",
type=int,
default=1,
help="Saving parameters epoch interval, best model will always be saved.",
)
parser.add_argument(
"--val-interval",
type=int,
default=1,
help="Epoch interval for validation, increase the number will reduce the "
"training time if validation is slow.",
)
parser.add_argument(
"--seed", type=int, default=233, help="Random seed to be fixed."
)
parser.add_argument(
"--verbose",
dest="verbose",
action="store_true",
help="Print helpful debugging info once set.",
)
parser.add_argument(
"--mixup", action="store_true", help="Use mixup training."
)
parser.add_argument(
"--no-mixup-epochs",
type=int,
default=20,
help="Disable mixup training if enabled in the last N epochs.",
)
    # Norm layer options
    parser.add_argument(
        "--norm-layer",
        type=str,
        default=None,
        help="Type of normalization layer to use. "
        "If set to None, backbone normalization layer will be fixed,"
        " and no normalization layer will be used. "
        "Currently supports 'bn', and None, default is None."
        "Note that if horovod is enabled, sync bn will not work correctly.",
    )
    # FPN options
    parser.add_argument(
        "--use-fpn",
        action="store_true",
        help="Whether to use feature pyramid network.",
    )
    # Performance options
    parser.add_argument(
        "--disable-hybridization",
        action="store_true",
        help="Whether to disable hybridize the model. "
        "Memory usage and speed will decrese.",
    )
    parser.add_argument(
        "--static-alloc",
        action="store_true",
        help="Whether to use static memory allocation. Memory usage will increase.",
    )
    parser.add_argument(
        "--amp",
        action="store_true",
        help="Use MXNet AMP for mixed precision training.",
    )
    parser.add_argument(
        "--horovod",
action="store_true",
help="Use MXNet Horovod for distributed training. Must be run with OpenMPI. "
"--gpus is ignored when using --horovod.",
)
parser.add_argument(
"--executor-threads",
type=int,
default=1,
help="Number of threads for executor for scheduling ops. "
"More threads may incur higher GPU memory footprint, "
"but may speed up throughput. Note that when horovod is used, "
"it is set to 1.",
)
parser.add_argument(
"--kv-store",
type=str,
default="nccl",
help="KV store options. local, device, nccl, dist_sync, dist_device_sync, "
"dist_async are available.",
)
args = parser.parse_args() args = parser.parse_args()
if args.horovod: if args.horovod:
if hvd is None: if hvd is None:
raise SystemExit("Horovod not found, please check if you installed it correctly.") raise SystemExit(
"Horovod not found, please check if you installed it correctly."
)
hvd.init() hvd.init()
if args.dataset == 'voc': if args.dataset == "voc":
args.epochs = int(args.epochs) if args.epochs else 20 args.epochs = int(args.epochs) if args.epochs else 20
args.lr_decay_epoch = args.lr_decay_epoch if args.lr_decay_epoch else '14,20' args.lr_decay_epoch = (
args.lr_decay_epoch if args.lr_decay_epoch else "14,20"
)
args.lr = float(args.lr) if args.lr else 0.001 args.lr = float(args.lr) if args.lr else 0.001
args.lr_warmup = args.lr_warmup if args.lr_warmup else -1 args.lr_warmup = args.lr_warmup if args.lr_warmup else -1
args.wd = float(args.wd) if args.wd else 5e-4 args.wd = float(args.wd) if args.wd else 5e-4
elif args.dataset == 'visualgenome': elif args.dataset == "visualgenome":
args.epochs = int(args.epochs) if args.epochs else 20 args.epochs = int(args.epochs) if args.epochs else 20
args.lr_decay_epoch = args.lr_decay_epoch if args.lr_decay_epoch else '14,20' args.lr_decay_epoch = (
args.lr_decay_epoch if args.lr_decay_epoch else "14,20"
)
args.lr = float(args.lr) if args.lr else 0.001 args.lr = float(args.lr) if args.lr else 0.001
args.lr_warmup = args.lr_warmup if args.lr_warmup else -1 args.lr_warmup = args.lr_warmup if args.lr_warmup else -1
args.wd = float(args.wd) if args.wd else 5e-4 args.wd = float(args.wd) if args.wd else 5e-4
elif args.dataset == 'coco': elif args.dataset == "coco":
args.epochs = int(args.epochs) if args.epochs else 26 args.epochs = int(args.epochs) if args.epochs else 26
args.lr_decay_epoch = args.lr_decay_epoch if args.lr_decay_epoch else '17,23' args.lr_decay_epoch = (
args.lr_decay_epoch if args.lr_decay_epoch else "17,23"
)
args.lr = float(args.lr) if args.lr else 0.01 args.lr = float(args.lr) if args.lr else 0.01
args.lr_warmup = args.lr_warmup if args.lr_warmup else 1000 args.lr_warmup = args.lr_warmup if args.lr_warmup else 1000
args.wd = float(args.wd) if args.wd else 1e-4 args.wd = float(args.wd) if args.wd else 1e-4
...@@ -146,71 +264,129 @@ def parse_args(): ...@@ -146,71 +264,129 @@ def parse_args():
def get_dataset(dataset, args):
    if dataset.lower() == "voc":
        train_dataset = gdata.VOCDetection(
            splits=[(2007, "trainval"), (2012, "trainval")]
        )
        val_dataset = gdata.VOCDetection(splits=[(2007, "test")])
        val_metric = VOC07MApMetric(
            iou_thresh=0.5, class_names=val_dataset.classes
        )
    elif dataset.lower() == "coco":
        train_dataset = gdata.COCODetection(
            splits="instances_train2017", use_crowd=False
        )
        val_dataset = gdata.COCODetection(
            splits="instances_val2017", skip_empty=False
        )
        val_metric = COCODetectionMetric(
            val_dataset, args.save_prefix + "_eval", cleanup=True
        )
    elif dataset.lower() == "visualgenome":
        train_dataset = VGObject(
            root=os.path.join("~", ".mxnet", "datasets", "visualgenome"),
            splits="detections_train",
            use_crowd=False,
        )
        val_dataset = VGObject(
            root=os.path.join("~", ".mxnet", "datasets", "visualgenome"),
            splits="detections_val",
            skip_empty=False,
        )
        val_metric = COCODetectionMetric(
            val_dataset, args.save_prefix + "_eval", cleanup=True
        )
    else:
        raise NotImplementedError(
            "Dataset: {} not implemented.".format(dataset)
        )
    if args.mixup:
        from gluoncv.data.mixup import detection

        train_dataset = detection.MixupDetection(train_dataset)
    return train_dataset, val_dataset, val_metric


def get_dataloader(
    net,
    train_dataset,
    val_dataset,
    train_transform,
    val_transform,
    batch_size,
    num_shards,
    args,
):
    """Get dataloader."""
    train_bfn = FasterRCNNTrainBatchify(net, num_shards)
    if hasattr(train_dataset, "get_im_aspect_ratio"):
        im_aspect_ratio = train_dataset.get_im_aspect_ratio()
    else:
        im_aspect_ratio = [1.0] * len(train_dataset)
    train_sampler = gcv.nn.sampler.SplitSortedBucketSampler(
        im_aspect_ratio,
        batch_size,
        num_parts=hvd.size() if args.horovod else 1,
        part_index=hvd.rank() if args.horovod else 0,
        shuffle=True,
    )
    train_loader = mx.gluon.data.DataLoader(
        train_dataset.transform(
            train_transform(
                net.short,
                net.max_size,
                net,
                ashape=net.ashape,
                multi_stage=args.use_fpn,
            )
        ),
        batch_sampler=train_sampler,
        batchify_fn=train_bfn,
        num_workers=args.num_workers,
    )
    if val_dataset is None:
        val_loader = None
    else:
        val_bfn = Tuple(*[Append() for _ in range(3)])
        short = (
            net.short[-1] if isinstance(net.short, (tuple, list)) else net.short
        )
        # validation use 1 sample per device
        val_loader = mx.gluon.data.DataLoader(
            val_dataset.transform(val_transform(short, net.max_size)),
            num_shards,
            False,
            batchify_fn=val_bfn,
            last_batch="keep",
            num_workers=args.num_workers,
        )
    return train_loader, val_loader


def save_params(
    net, logger, best_map, current_map, epoch, save_interval, prefix
):
    current_map = float(current_map)
    if current_map > best_map[0]:
        logger.info(
            "[Epoch {}] mAP {} higher than current best {} saving to {}".format(
                epoch, current_map, best_map, "{:s}_best.params".format(prefix)
            )
        )
        best_map[0] = current_map
        net.save_parameters("{:s}_best.params".format(prefix))
        with open(prefix + "_best_map.log", "a") as f:
            f.write("{:04d}:\t{:.4f}\n".format(epoch, current_map))
    if save_interval and (epoch + 1) % save_interval == 0:
        logger.info(
            "[Epoch {}] Saving parameters to {}".format(
                epoch,
                "{:s}_{:04d}_{:.4f}.params".format(prefix, epoch, current_map),
            )
        )
        net.save_parameters(
            "{:s}_{:04d}_{:.4f}.params".format(prefix, epoch, current_map)
        )


def split_and_load(batch, ctx_list):
@@ -254,23 +430,37 @@ def validate(net, val_data, ctx, eval_metric, args):
            gt_ids.append(y.slice_axis(axis=-1, begin=4, end=5))
            gt_bboxes.append(y.slice_axis(axis=-1, begin=0, end=4))
            gt_bboxes[-1] *= im_scale
            gt_difficults.append(
                y.slice_axis(axis=-1, begin=5, end=6)
                if y.shape[-1] > 5
                else None
            )

        # update metric
        for det_bbox, det_id, det_score, gt_bbox, gt_id, gt_diff in zip(
            det_bboxes, det_ids, det_scores, gt_bboxes, gt_ids, gt_difficults
        ):
            eval_metric.update(
                det_bbox, det_id, det_score, gt_bbox, gt_id, gt_diff
            )
    return eval_metric.get()


def get_lr_at_iter(alpha, lr_warmup_factor=1.0 / 3.0):
    return lr_warmup_factor * (1 - alpha) + alpha
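

# Minimal usage sketch (illustrative values only, not part of the training
# loop): with the default lr_warmup_factor of 1/3 the multiplier ramps
# linearly from 1/3 at alpha = 0 to 1.0 at alpha = 1, so
# base_lr * get_lr_at_iter(i / lr_warmup) gives a linear learning-rate warmup
# over the first lr_warmup iterations.
def _demo_warmup_schedule(base_lr=0.01, lr_warmup=1000, num_points=5):
    return [
        base_lr * get_lr_at_iter(i / lr_warmup) for i in range(num_points)
    ]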
class ForwardBackwardTask(Parallelizable):
    def __init__(
        self,
        net,
        optimizer,
        rpn_cls_loss,
        rpn_box_loss,
        rcnn_cls_loss,
        rcnn_box_loss,
        mix_ratio,
    ):
        super(ForwardBackwardTask, self).__init__()
        self.net = net
        self._optimizer = optimizer
@@ -285,96 +475,159 @@ class ForwardBackwardTask(Parallelizable):
        with autograd.record():
            gt_label = label[:, :, 4:5]
            gt_box = label[:, :, :4]
            (
                cls_pred,
                box_pred,
                roi,
                samples,
                matches,
                rpn_score,
                rpn_box,
                anchors,
                cls_targets,
                box_targets,
                box_masks,
                _,
            ) = net(data, gt_box, gt_label)
            # losses of rpn
            rpn_score = rpn_score.squeeze(axis=-1)
            num_rpn_pos = (rpn_cls_targets >= 0).sum()
            rpn_loss1 = (
                self.rpn_cls_loss(
                    rpn_score, rpn_cls_targets, rpn_cls_targets >= 0
                )
                * rpn_cls_targets.size
                / num_rpn_pos
            )
            rpn_loss2 = (
                self.rpn_box_loss(rpn_box, rpn_box_targets, rpn_box_masks)
                * rpn_box.size
                / num_rpn_pos
            )
            # rpn overall loss, use sum rather than average
            rpn_loss = rpn_loss1 + rpn_loss2
            # losses of rcnn
            num_rcnn_pos = (cls_targets >= 0).sum()
            rcnn_loss1 = (
                self.rcnn_cls_loss(
                    cls_pred, cls_targets, cls_targets.expand_dims(-1) >= 0
                )
                * cls_targets.size
                / num_rcnn_pos
            )
            rcnn_loss2 = (
                self.rcnn_box_loss(box_pred, box_targets, box_masks)
                * box_pred.size
                / num_rcnn_pos
            )
            rcnn_loss = rcnn_loss1 + rcnn_loss2
            # overall losses
            total_loss = (
                rpn_loss.sum() * self.mix_ratio
                + rcnn_loss.sum() * self.mix_ratio
            )
            rpn_loss1_metric = rpn_loss1.mean() * self.mix_ratio
            rpn_loss2_metric = rpn_loss2.mean() * self.mix_ratio
            rcnn_loss1_metric = rcnn_loss1.mean() * self.mix_ratio
            rcnn_loss2_metric = rcnn_loss2.mean() * self.mix_ratio
            rpn_acc_metric = [
                [rpn_cls_targets, rpn_cls_targets >= 0],
                [rpn_score],
            ]
            rpn_l1_loss_metric = [[rpn_box_targets, rpn_box_masks], [rpn_box]]
            rcnn_acc_metric = [[cls_targets], [cls_pred]]
            rcnn_l1_loss_metric = [[box_targets, box_masks], [box_pred]]
            if args.amp:
                with amp.scale_loss(
                    total_loss, self._optimizer
                ) as scaled_losses:
                    autograd.backward(scaled_losses)
            else:
                total_loss.backward()

        return (
            rpn_loss1_metric,
            rpn_loss2_metric,
            rcnn_loss1_metric,
            rcnn_loss2_metric,
            rpn_acc_metric,
            rpn_l1_loss_metric,
            rcnn_acc_metric,
            rcnn_l1_loss_metric,
        )
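
    # Rescaling sketch (illustrative gloss of the expressions above, not an
    # extra code path): a masked loss that gluon averages over every element
    # is multiplied by the element count and divided by the positive count,
    #   loss_per_pos = masked_loss * targets.size / num_positive
    # so the summed loss behaves like an average over positive samples only,
    # which is the usual Faster R-CNN normalization.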
def train(net, train_data, val_data, eval_metric, batch_size, ctx, args):
    """Training pipeline"""
    args.kv_store = (
        "device" if (args.amp and "nccl" in args.kv_store) else args.kv_store
    )
    kv = mx.kvstore.create(args.kv_store)
    net.collect_params().setattr("grad_req", "null")
    net.collect_train_params().setattr("grad_req", "write")
    optimizer_params = {
        "learning_rate": args.lr,
        "wd": args.wd,
        "momentum": args.momentum,
    }
    if args.horovod:
        hvd.broadcast_parameters(net.collect_params(), root_rank=0)
        trainer = hvd.DistributedTrainer(
            net.collect_train_params(),  # fix batchnorm, fix first stage, etc...
            "sgd",
            optimizer_params,
        )
    else:
        trainer = gluon.Trainer(
            net.collect_train_params(),  # fix batchnorm, fix first stage, etc...
            "sgd",
            optimizer_params,
            update_on_kvstore=(False if args.amp else None),
            kvstore=kv,
        )

    if args.amp:
        amp.init_trainer(trainer)

    # lr decay policy
    lr_decay = float(args.lr_decay)
    lr_steps = sorted(
        [float(ls) for ls in args.lr_decay_epoch.split(",") if ls.strip()]
    )
    lr_warmup = float(args.lr_warmup)  # avoid int division

    # TODO(zhreshold) losses?
    rpn_cls_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(
        from_sigmoid=False
    )
    rpn_box_loss = mx.gluon.loss.HuberLoss(rho=1 / 9.0)  # == smoothl1
    rcnn_cls_loss = mx.gluon.loss.SoftmaxCrossEntropyLoss()
    rcnn_box_loss = mx.gluon.loss.HuberLoss()  # == smoothl1
    metrics = [
        mx.metric.Loss("RPN_Conf"),
        mx.metric.Loss("RPN_SmoothL1"),
        mx.metric.Loss("RCNN_CrossEntropy"),
        mx.metric.Loss("RCNN_SmoothL1"),
    ]

    rpn_acc_metric = RPNAccMetric()
    rpn_bbox_metric = RPNL1LossMetric()
    rcnn_acc_metric = RCNNAccMetric()
    rcnn_bbox_metric = RCNNL1LossMetric()
    metrics2 = [
        rpn_acc_metric,
        rpn_bbox_metric,
        rcnn_acc_metric,
        rcnn_bbox_metric,
    ]

    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    log_file_path = args.save_prefix + "_train.log"
    log_dir = os.path.dirname(log_file_path)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)
@@ -382,17 +635,28 @@ def train(net, train_data, val_data, eval_metric, batch_size, ctx, args):
    logger.addHandler(fh)
    logger.info(args)
    if args.verbose:
        logger.info("Trainable parameters:")
        logger.info(net.collect_train_params().keys())
    logger.info("Start training from [Epoch {}]".format(args.start_epoch))
    best_map = [0]
    for epoch in range(args.start_epoch, args.epochs):
        mix_ratio = 1.0
        if not args.disable_hybridization:
            net.hybridize(static_alloc=args.static_alloc)
        rcnn_task = ForwardBackwardTask(
            net,
            trainer,
            rpn_cls_loss,
            rpn_box_loss,
            rcnn_cls_loss,
            rcnn_box_loss,
            mix_ratio=1.0,
        )
        executor = (
            Parallel(args.executor_threads, rcnn_task)
            if not args.horovod
            else None
        )
        if args.mixup:
            # TODO(zhreshold) only support evenly mixup now, target generator needs to be modified otherwise
            train_data._dataset._data.set_mixup(np.random.uniform, 0.5, 0.5)
@@ -404,22 +668,29 @@ def train(net, train_data, val_data, eval_metric, batch_size, ctx, args):
            new_lr = trainer.learning_rate * lr_decay
            lr_steps.pop(0)
            trainer.set_learning_rate(new_lr)
            logger.info(
                "[Epoch {}] Set learning rate to {}".format(epoch, new_lr)
            )
        for metric in metrics:
            metric.reset()
        tic = time.time()
        btic = time.time()
        base_lr = trainer.learning_rate
        rcnn_task.mix_ratio = mix_ratio
        logger.info("Total Num of Batches: %d" % (len(train_data)))
        for i, batch in enumerate(train_data):
            if epoch == 0 and i <= lr_warmup:
                # adjust based on real percentage
                new_lr = base_lr * get_lr_at_iter(
                    i / lr_warmup, args.lr_warmup_factor
                )
                if new_lr != trainer.learning_rate:
                    if i % args.log_interval == 0:
                        logger.info(
                            "[Epoch 0 Iteration {}] Set learning rate to {}".format(
                                i, new_lr
                            )
                        )
                    trainer.set_learning_rate(new_lr)
            batch = split_and_load(batch, ctx_list=ctx)
            metric_losses = [[] for _ in metrics]
@@ -445,34 +716,70 @@ def train(net, train_data, val_data, eval_metric, batch_size, ctx, args):
            trainer.step(batch_size)
            # update metrics
            if (
                (not args.horovod or hvd.rank() == 0)
                and args.log_interval
                and not (i + 1) % args.log_interval
            ):
                msg = ",".join(
                    [
                        "{}={:.3f}".format(*metric.get())
                        for metric in metrics + metrics2
                    ]
                )
                logger.info(
                    "[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}".format(
                        epoch,
                        i,
                        args.log_interval
                        * args.batch_size
                        / (time.time() - btic),
                        msg,
                    )
                )
                btic = time.time()
        if (not args.horovod) or hvd.rank() == 0:
            msg = ",".join(
                ["{}={:.3f}".format(*metric.get()) for metric in metrics]
            )
            logger.info(
                "[Epoch {}] Training cost: {:.3f}, {}".format(
                    epoch, (time.time() - tic), msg
                )
            )
            if not (epoch + 1) % args.val_interval:
                # consider reduce the frequency of validation to save time
                if val_data is not None:
                    map_name, mean_ap = validate(
                        net, val_data, ctx, eval_metric, args
                    )
                    val_msg = "\n".join(
                        [
                            "{}={}".format(k, v)
                            for k, v in zip(map_name, mean_ap)
                        ]
                    )
                    logger.info(
                        "[Epoch {}] Validation: \n{}".format(epoch, val_msg)
                    )
                    current_map = float(mean_ap[-1])
                else:
                    current_map = 0
            else:
                current_map = 0.0
            save_params(
                net,
                logger,
                best_map,
                current_map,
                epoch,
                args.save_interval,
                args.save_prefix,
            )
if __name__ == "__main__":
    import sys

    sys.setrecursionlimit(1100)
@@ -487,26 +794,31 @@ if __name__ == '__main__':
    if args.horovod:
        ctx = [mx.gpu(hvd.local_rank())]
    else:
        ctx = [mx.gpu(int(i)) for i in args.gpus.split(",") if i.strip()]
        ctx = ctx if ctx else [mx.cpu()]

    # network
    kwargs = {}
    module_list = []
    if args.use_fpn:
        module_list.append("fpn")
    if args.norm_layer is not None:
        module_list.append(args.norm_layer)
        if args.norm_layer == "bn":
            kwargs["num_devices"] = len(args.gpus.split(","))

    net_name = "_".join(("faster_rcnn", *module_list, args.network, "custom"))
    args.save_prefix += net_name
    gutils.makedirs(args.save_prefix)
    train_dataset, val_dataset, eval_metric = get_dataset(args.dataset, args)
    net = faster_rcnn_resnet101_v1d_custom(
        classes=train_dataset.classes,
        transfer="coco",
        pretrained_base=False,
        additional_output=False,
        per_device_batch_size=args.batch_size // len(ctx),
        **kwargs
    )
    if args.resume.strip():
        net.load_parameters(args.resume.strip())
    else:
@@ -517,10 +829,19 @@ if __name__ == '__main__':
    net.collect_params().reset_ctx(ctx)

    # training data
    batch_size = (
        args.batch_size // len(ctx) if args.horovod else args.batch_size
    )
    train_data, val_data = get_dataloader(
        net,
        train_dataset,
        val_dataset,
        FasterRCNNDefaultTrainTransform,
        FasterRCNNDefaultValTransform,
        batch_size,
        len(ctx),
        args,
    )

    # training
    train(net, train_data, val_data, eval_metric, batch_size, ctx, args)
import argparse
import json
import os
import pickle

import numpy as np


def parse_args():
    parser = argparse.ArgumentParser(
        description="Train the Frequent Prior For RelDN."
    )
    parser.add_argument(
        "--overlap", action="store_true", help="Only count overlap boxes."
    )
    parser.add_argument(
        "--json-path",
        type=str,
        default="~/.mxnet/datasets/visualgenome",
        help="Path to the Visual Genome relation annotation json files.",
    )
    args = parser.parse_args()
    return args


args = parse_args()
use_overlap = args.overlap
PATH_TO_DATASETS = os.path.expanduser(args.json_path)
path_to_json = os.path.join(PATH_TO_DATASETS, "rel_annotations_train.json")
# format in y1y2x1x2
def with_overlap(boxA, boxB):
@@ -29,17 +42,19 @@ def with_overlap(boxA, boxB):
    return 0


def box_ious(boxes):
    n = len(boxes)
    res = np.zeros((n, n))
    for i in range(n - 1):
        for j in range(i + 1, n):
            iou_val = with_overlap(boxes[i], boxes[j])
            res[i, j] = iou_val
            res[j, i] = iou_val
    return res
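

# Minimal usage sketch (hypothetical boxes, not part of the original script);
# boxes follow the y1y2x1x2 convention noted above, and the result is a
# symmetric matrix with zeros on the diagonal, used to decide which object
# pairs count when --overlap is set.
def _demo_box_ious():
    boxes = [
        np.array([0.0, 10.0, 0.0, 10.0]),
        np.array([5.0, 15.0, 5.0, 15.0]),
        np.array([20.0, 30.0, 20.0, 30.0]),
    ]
    return box_ious(boxes)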
with open(path_to_json, "r") as f:
    tmp = f.read()
    train_data = json.loads(tmp)
@@ -49,11 +64,11 @@ bg_matrix = np.zeros((150, 150), dtype=np.int64)
for _, item in train_data.items():
    gt_box_to_label = {}
    for rel in item:
        sub_bbox = rel["subject"]["bbox"]
        ob_bbox = rel["object"]["bbox"]
        sub_class = rel["subject"]["category"]
        ob_class = rel["object"]["category"]
        rel_class = rel["predicate"]
        sub_node = tuple(sub_bbox)
        ob_node = tuple(ob_bbox)
@@ -93,8 +108,8 @@ pred_dist = np.log(fg_matrix / (fg_matrix.sum(2)[:, :, None] + eps) + eps)
if use_overlap:
    with open("freq_prior_overlap.pkl", "wb") as f:
        pickle.dump(pred_dist, f)
else:
    with open("freq_prior.pkl", "wb") as f:
        pickle.dump(pred_dist, f)
import argparse
import logging
import time

import mxnet as mx
import numpy as np
from data import *
from gluoncv.data.batchify import Pad
from gluoncv.utils import makedirs
from model import RelDN, faster_rcnn_resnet101_v1d_custom
from mxnet import gluon, nd
from utils import *

import dgl
def parse_args():
    parser = argparse.ArgumentParser(description="Train RelDN Model.")
    parser.add_argument(
        "--gpus",
        type=str,
        default="0",
        help="Training with GPUs, you can specify 1,3 for example.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=8,
        help="Total batch-size for training.",
    )
    parser.add_argument(
        "--epochs", type=int, default=9, help="Training epochs."
    )
    parser.add_argument(
        "--lr-reldn",
        type=float,
        default=0.01,
        help="Learning rate for RelDN module.",
    )
    parser.add_argument(
        "--wd-reldn",
        type=float,
        default=0.0001,
        help="Weight decay for RelDN module.",
    )
    parser.add_argument(
        "--lr-faster-rcnn",
        type=float,
        default=0.01,
        help="Learning rate for Faster R-CNN module.",
    )
    parser.add_argument(
        "--wd-faster-rcnn",
        type=float,
        default=0.0001,
        help="Weight decay for Faster R-CNN module.",
    )
    parser.add_argument(
        "--lr-decay-epochs",
        type=str,
        default="5,8",
        help="Learning rate decay points.",
    )
    parser.add_argument(
        "--lr-warmup-iters",
        type=int,
        default=4000,
        help="Learning rate warm-up iterations.",
    )
    parser.add_argument(
        "--save-dir",
        type=str,
        default="params_resnet101_v1d_reldn",
        help="Path to save model parameters.",
    )
    parser.add_argument(
        "--log-dir",
        type=str,
        default="reldn_output.log",
        help="Path to save training logs.",
    )
    parser.add_argument(
        "--pretrained-faster-rcnn-params",
        type=str,
        required=True,
        help="Path to saved Faster R-CNN model parameters.",
    )
    parser.add_argument(
        "--freq-prior",
        type=str,
        default="freq_prior.pkl",
        help="Path to saved frequency prior data.",
    )
    parser.add_argument(
        "--verbose-freq",
        type=int,
        default=100,
        help="Frequency of log printing in number of iterations.",
    )
    args = parser.parse_args()
    return args
args = parse_args()
filehandler = logging.FileHandler(args.log_dir)
streamhandler = logging.StreamHandler()
logger = logging.getLogger("")
logger.setLevel(logging.INFO)
logger.addHandler(filehandler)
logger.addHandler(streamhandler)

# Hyperparams
ctx = [mx.gpu(int(i)) for i in args.gpus.split(",") if i.strip()]
if ctx:
    num_gpus = len(ctx)
    assert args.batch_size % num_gpus == 0
@@ -71,13 +129,18 @@ N_objects = 150
save_dir = args.save_dir
makedirs(save_dir)
batch_verbose_freq = args.verbose_freq
lr_decay_epochs = [int(i) for i in args.lr_decay_epochs.split(",")]

# Dataset and dataloader
vg_train = VGRelation(split="train")
logger.info("data loaded!")
train_data = gluon.data.DataLoader(
    vg_train,
    batch_size=len(ctx),
    shuffle=True,
    num_workers=8 * num_gpus,
    batchify_fn=dgl_mp_batchify_fn,
)
n_batches = len(train_data)

# Network definition
@@ -85,30 +148,47 @@ net = RelDN(n_classes=N_relations, prior_pkl=args.freq_prior)
net.spatial.initialize(mx.init.Normal(1e-4), ctx=ctx)
net.visual.initialize(mx.init.Normal(1e-4), ctx=ctx)
for k, v in net.collect_params().items():
    v.grad_req = "add" if aggregate_grad else "write"
net_params = net.collect_params()
net_trainer = gluon.Trainer(
    net.collect_params(),
    "adam",
    {"learning_rate": args.lr_reldn, "wd": args.wd_reldn},
)

det_params_path = args.pretrained_faster_rcnn_params
detector = faster_rcnn_resnet101_v1d_custom(
    classes=vg_train.obj_classes,
    pretrained_base=False,
    pretrained=False,
    additional_output=True,
)
detector.load_parameters(
    det_params_path, ctx=ctx, ignore_extra=True, allow_missing=True
)
for k, v in detector.collect_params().items():
    v.grad_req = "null"

detector_feat = faster_rcnn_resnet101_v1d_custom(
    classes=vg_train.obj_classes,
    pretrained_base=False,
    pretrained=False,
    additional_output=True,
)
detector_feat.load_parameters(
    det_params_path, ctx=ctx, ignore_extra=True, allow_missing=True
)
for k, v in detector_feat.collect_params().items():
    v.grad_req = "null"
for k, v in detector_feat.features.collect_params().items():
    v.grad_req = "add" if aggregate_grad else "write"
det_params = detector_feat.features.collect_params()
det_trainer = gluon.Trainer(
    detector_feat.features.collect_params(),
    "adam",
    {"learning_rate": args.lr_faster_rcnn, "wd": args.wd_faster_rcnn},
)
def get_data_batch(g_list, img_list, ctx_list):
    if g_list is None or len(g_list) == 0:
@@ -118,37 +198,58 @@ def get_data_batch(g_list, img_list, ctx_list):
    if size < n_gpu:
        raise Exception("too small batch")
    step = size // n_gpu
    G_list = [
        g_list[i * step : (i + 1) * step]
        if i < n_gpu - 1
        else g_list[i * step : size]
        for i in range(n_gpu)
    ]
    img_list = [
        img_list[i * step : (i + 1) * step]
        if i < n_gpu - 1
        else img_list[i * step : size]
        for i in range(n_gpu)
    ]

    for G_slice, ctx in zip(G_list, ctx_list):
        for G in G_slice:
            G.ndata["bbox"] = G.ndata["bbox"].as_in_context(ctx)
            G.ndata["node_class"] = G.ndata["node_class"].as_in_context(ctx)
            G.ndata["node_class_vec"] = G.ndata["node_class_vec"].as_in_context(
                ctx
            )
            G.edata["rel_class"] = G.edata["rel_class"].as_in_context(ctx)
    img_list = [img.as_in_context(ctx) for img in img_list]
    return G_list, img_list
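

# Sharding sketch for get_data_batch (illustrative numbers only, nothing is
# executed here): with 8 graphs and 4 GPUs, step = 2, so device k receives
# graphs [2 * k, 2 * k + 2) and the last device also absorbs any remainder
# before every slice is moved onto its context.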
L_rel = gluon.loss.SoftmaxCELoss()
train_metric = mx.metric.Accuracy(name="rel_acc")
train_metric_top5 = mx.metric.TopKAccuracy(5, name="rel_acc_top5")
metric_list = [train_metric, train_metric_top5]


def batch_print(
    epoch, i, batch_verbose_freq, n_batches, btic, loss_rel_val, metric_list
):
    if (i + 1) % batch_verbose_freq == 0:
        print_txt = "Epoch[%d] Batch[%d/%d], time: %d, loss_rel=%.4f " % (
            epoch,
            i,
            n_batches,
            int(time.time() - btic),
            loss_rel_val / (i + 1),
        )
        for metric in metric_list:
            metric_name, metric_val = metric.get()
            print_txt += "%s=%.4f " % (metric_name, metric_val)
        logger.info(print_txt)
        btic = time.time()
        loss_rel_val = 0
    return btic, loss_rel_val
for epoch in range(nepoch):
    loss_rel_val = 0
    tic = time.time()
@@ -159,17 +260,25 @@ for epoch in range(nepoch):
    net_trainer_base_lr = net_trainer.learning_rate
    det_trainer_base_lr = det_trainer.learning_rate
    if epoch == 5 or epoch == 8:
        net_trainer.set_learning_rate(net_trainer.learning_rate * 0.1)
        det_trainer.set_learning_rate(det_trainer.learning_rate * 0.1)
    for i, (G_list, img_list) in enumerate(train_data):
        if epoch == 0 and i < args.lr_warmup_iters:
            alpha = i / args.lr_warmup_iters
            warmup_factor = 1 / 3 * (1 - alpha) + alpha
            net_trainer.set_learning_rate(net_trainer_base_lr * warmup_factor)
            det_trainer.set_learning_rate(det_trainer_base_lr * warmup_factor)
        G_list, img_list = get_data_batch(G_list, img_list, ctx)
        if G_list is None or img_list is None:
            btic, loss_rel_val = batch_print(
                epoch,
                i,
                batch_verbose_freq,
                n_batches,
                btic,
                loss_rel_val,
                metric_list,
            )
            continue

        loss = []
@@ -179,17 +288,29 @@ for epoch in range(nepoch):
        with mx.autograd.record():
            for G_slice, img in zip(G_list, img_list):
                cur_ctx = img.context
                bbox_list = [G.ndata["bbox"] for G in G_slice]
                bbox_stack = bbox_pad(bbox_list).as_in_context(cur_ctx)
                with mx.autograd.pause():
                    ids, scores, bbox, feat, feat_ind, spatial_feat = detector(
                        img
                    )
                g_pred_batch = build_graph_train(
                    G_slice,
                    bbox_stack,
                    img,
                    ids,
                    scores,
                    bbox,
                    feat_ind,
                    spatial_feat,
                    scores_top_k=300,
                    overlap=False,
                )
                g_batch = l0_sample(g_pred_batch)
                if g_batch is None:
                    continue
                rel_bbox = g_batch.edata["rel_bbox"]
                batch_id = g_batch.edata["batch_id"].asnumpy()
                n_sample_edges = g_batch.number_of_edges()
                n_graph = len(G_slice)
                bbox_rel_list = []
@@ -203,13 +324,19 @@ for epoch in range(nepoch):
                bbox_rel_stack[:, :, 1] *= img_size[0]
                bbox_rel_stack[:, :, 2] *= img_size[1]
                bbox_rel_stack[:, :, 3] *= img_size[0]
                _, _, _, spatial_feat_rel = detector_feat(
                    img, None, None, bbox_rel_stack
                )
                spatial_feat_rel_list = []
                for j in range(n_graph):
                    eids = np.where(batch_id == j)[0]
                    if len(eids) > 0:
                        spatial_feat_rel_list.append(
                            spatial_feat_rel[j, 0 : len(eids)]
                        )
                g_batch.edata["edge_feat"] = nd.concat(
                    *spatial_feat_rel_list, dim=0
                )
                G_batch.append(g_batch)
@@ -218,17 +345,28 @@ for epoch in range(nepoch):
            for G_pred, img in zip(G_batch, img_list):
                if G_pred is None or G_pred.number_of_nodes() == 0:
                    continue
                loss_rel = L_rel(
                    G_pred.edata["preds"],
                    G_pred.edata["rel_class"],
                    G_pred.edata["sample_weights"],
                )
                loss.append(loss_rel.sum())
                loss_rel_val += loss_rel.mean().asscalar() / num_gpus
        if len(loss) == 0:
            btic, loss_rel_val = batch_print(
                epoch,
                i,
                batch_verbose_freq,
                n_batches,
                btic,
                loss_rel_val,
                metric_list,
            )
            continue
        for l in loss:
            l.backward()
        if (i + 1) % per_device_batch_size == 0 or i == n_batches - 1:
            net_trainer.step(args.batch_size)
            det_trainer.step(args.batch_size)
            if aggregate_grad:
@@ -239,23 +377,41 @@ for epoch in range(nepoch):
        for G_pred, img_slice in zip(G_batch, img_list):
            if G_pred is None or G_pred.number_of_nodes() == 0:
                continue
            link_ind = np.where(G_pred.edata["rel_class"].asnumpy() > 0)[0]
            if len(link_ind) == 0:
                continue
            train_metric.update(
                [G_pred.edata["rel_class"][link_ind]],
                [G_pred.edata["preds"][link_ind]],
            )
            train_metric_top5.update(
                [G_pred.edata["rel_class"][link_ind]],
                [G_pred.edata["preds"][link_ind]],
            )
        btic, loss_rel_val = batch_print(
            epoch,
            i,
            batch_verbose_freq,
            n_batches,
            btic,
            loss_rel_val,
            metric_list,
        )
        if (i + 1) % batch_verbose_freq == 0:
            net.save_parameters("%s/model-%d.params" % (save_dir, epoch))
            detector_feat.features.save_parameters(
                "%s/detector_feat.features-%d.params" % (save_dir, epoch)
            )
    print_txt = "Epoch[%d], time: %d, loss_rel=%.4f," % (
        epoch,
        int(time.time() - tic),
        loss_rel_val / (i + 1),
    )
    for metric in metric_list:
        metric_name, metric_val = metric.get()
        print_txt += "%s=%.4f " % (metric_name, metric_val)
    logger.info(print_txt)
    net.save_parameters("%s/model-%d.params" % (save_dir, epoch))
    detector_feat.features.save_parameters(
        "%s/detector_feat.features-%d.params" % (save_dir, epoch)
    )
from .build_graph import *
from .metric import *
from .sampling import *
from .viz import *
import numpy as np
from mxnet import nd

import dgl
def bbox_improve(bbox):
    """bbox encoding: append the box area as an extra feature column"""
    area = (bbox[:, 2] - bbox[:, 0]) * (bbox[:, 3] - bbox[:, 1])
    return nd.concat(bbox, area.expand_dims(1))
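

# Minimal usage sketch (hypothetical values, not part of the library code):
# an (N, 4) array of xmin, ymin, xmax, ymax boxes becomes (N, 5) with the box
# area appended as the last column (6.0 for the box below).
def _demo_bbox_improve():
    boxes = nd.array([[0.0, 0.0, 2.0, 3.0]])
    return bbox_improve(boxes)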
def extract_edge_bbox(g):
    """edge bbox encoding: the tightest box covering both endpoints of each edge"""
    src, dst = g.edges(order="eid")
    n = g.number_of_edges()
    src_bbox = g.ndata["pred_bbox"][src.asnumpy()]
    dst_bbox = g.ndata["pred_bbox"][dst.asnumpy()]
    edge_bbox = nd.zeros((n, 4), ctx=g.ndata["pred_bbox"].context)
    edge_bbox[:, 0] = nd.stack(src_bbox[:, 0], dst_bbox[:, 0]).min(axis=0)
    edge_bbox[:, 1] = nd.stack(src_bbox[:, 1], dst_bbox[:, 1]).min(axis=0)
    edge_bbox[:, 2] = nd.stack(src_bbox[:, 2], dst_bbox[:, 2]).max(axis=0)
    edge_bbox[:, 3] = nd.stack(src_bbox[:, 3], dst_bbox[:, 3]).max(axis=0)
    return edge_bbox
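

# Minimal sketch of the union-box rule above (hypothetical values, plain
# tuples instead of a DGLGraph, assuming the usual xmin, ymin, xmax, ymax
# layout): the edge box keeps the minima of the top-left coordinates and the
# maxima of the bottom-right coordinates of its two endpoint boxes.
def _demo_union_box(src_bbox=(0.0, 0.0, 2.0, 2.0), dst_bbox=(1.0, 1.0, 4.0, 3.0)):
    return (
        min(src_bbox[0], dst_bbox[0]),
        min(src_bbox[1], dst_bbox[1]),
        max(src_bbox[2], dst_bbox[2]),
        max(src_bbox[3], dst_bbox[3]),
    )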
def build_graph_train(
    g_slice,
    gt_bbox,
    img,
    ids,
    scores,
    bbox,
    feat_ind,
    spatial_feat,
    iou_thresh=0.5,
    bbox_improvement=True,
    scores_top_k=50,
    overlap=False,
):
    """given ground truth and predicted bboxes, assign the label to the predicted w.r.t iou_thresh"""
    # match and re-factor the graph
    img_size = img.shape[2:4]
    gt_bbox[:, :, 0] /= img_size[1]
@@ -39,24 +54,33 @@ def build_graph_train(g_slice, gt_bbox, img, ids, scores, bbox, feat_ind,
    g_pred_batch = []
    for gi in range(n_graph):
        g = g_slice[gi]
        ctx = g.ndata["bbox"].context
        inds = np.where(scores[gi, :, 0].asnumpy() > 0)[0].tolist()
        if len(inds) == 0:
            return None
        if len(inds) > scores_top_k:
            top_score_inds = (
                scores[gi, inds, 0].asnumpy().argsort()[::-1][0:scores_top_k]
            )
            inds = np.array(inds)[top_score_inds].tolist()
        n_nodes = len(inds)
        roi_ind = feat_ind[gi, inds].squeeze(axis=1)
        g_pred = dgl.DGLGraph()
        g_pred.add_nodes(
            n_nodes,
            {
                "pred_bbox": bbox[gi, inds],
                "node_feat": spatial_feat[gi, roi_ind],
                "node_class_pred": ids[gi, inds, 0],
                "node_class_logit": nd.log(scores[gi, inds, 0] + 1e-7),
            },
        )

        # iou matching
        ious = nd.contrib.box_iou(
            gt_bbox[gi], g_pred.ndata["pred_bbox"]
        ).asnumpy()
        H, W = ious.shape
        h = H
        w = W
@@ -70,8 +94,8 @@ def build_graph_train(g_slice, gt_bbox, img, ids, scores, bbox, feat_ind,
            if ious[row_ind, col_ind] < iou_thresh:
                break
            pred_to_gt_ind[col_ind] = row_ind
            gt_node_class = g.ndata["node_class"][row_ind]
            pred_node_class = g_pred.ndata["node_class_pred"][col_ind]
            if gt_node_class == pred_node_class:
                pred_to_gt_class_match[col_ind] = 1
                pred_to_gt_class_match_id[col_ind] = row_ind
@@ -84,7 +108,7 @@ def build_graph_train(g_slice, gt_bbox, img, ids, scores, bbox, feat_ind,
        triplet = []
        adjmat = np.zeros((n_nodes, n_nodes))
        src, dst = g.all_edges(order="eid")
        eid_keys = np.column_stack([src.asnumpy(), dst.asnumpy()])
        eid_dict = {}
        for i, key in enumerate(eid_keys):
@@ -93,7 +117,7 @@ def build_graph_train(g_slice, gt_bbox, img, ids, scores, bbox, feat_ind,
                eid_dict[k] = [i]
            else:
                eid_dict[k].append(i)
        ori_rel_class = g.edata["rel_class"].asnumpy()
        for i in range(n_nodes):
            for j in range(n_nodes):
                if i != j:
@@ -105,25 +129,27 @@ def build_graph_train(g_slice, gt_bbox, img, ids, scores, bbox, feat_ind,
                        n_edges_between = len(rel_cls)
                        for ii in range(n_edges_between):
                            triplet.append((i, j, rel_cls[ii]))
                            adjmat[i, j] = 1
                    else:
                        triplet.append((i, j, 0))
        src, dst, rel_class = tuple(zip(*triplet))
        rel_class = nd.array(rel_class, ctx=ctx).expand_dims(1)
        g_pred.add_edges(src, dst, data={"rel_class": rel_class})

        # other operations
        n_nodes = g_pred.number_of_nodes()
        n_edges = g_pred.number_of_edges()
        if bbox_improvement:
            g_pred.ndata["pred_bbox"] = bbox_improve(g_pred.ndata["pred_bbox"])
        g_pred.edata["rel_bbox"] = extract_edge_bbox(g_pred)
        g_pred.edata["batch_id"] = nd.zeros((n_edges, 1), ctx=ctx) + gi

        # remove non-overlapping edges
        if overlap:
            overlap_ious = nd.contrib.box_iou(
                g_pred.ndata["pred_bbox"][:, 0:4],
                g_pred.ndata["pred_bbox"][:, 0:4],
            ).asnumpy()
            cols, rows = np.where(overlap_ious <= 1e-7)
            if cols.shape[0] > 0:
                eids = g_pred.edge_ids(cols, rows)[2].asnumpy().tolist()
@@ -132,15 +158,17 @@ def build_graph_train(g_slice, gt_bbox, img, ids, scores, bbox, feat_ind,
            if g_pred.number_of_edges() == 0:
                g_pred = None
        g_pred_batch.append(g_pred)

    if n_graph > 1:
        return dgl.batch(g_pred_batch)
    else:
        return g_pred_batch[0]
def build_graph_validate_gt_obj(
    img, gt_ids, bbox, spatial_feat, bbox_improvement=True, overlap=False
):
    """given ground truth bbox and label, build graph for validation"""
    n_batch = img.shape[0]
    img_size = img.shape[2:4]
    bbox[:, :, 0] /= img_size[1]
@@ -156,10 +184,17 @@ def build_graph_validate_gt_obj(img, gt_ids, bbox, spatial_feat,
            continue
        n_nodes = len(inds)
        g_pred = dgl.DGLGraph()
        g_pred.add_nodes(
            n_nodes,
            {
                "pred_bbox": bbox[btc, inds],
                "node_feat": spatial_feat[btc, inds],
                "node_class_pred": gt_ids[btc, inds, 0],
                "node_class_logit": nd.zeros_like(
                    gt_ids[btc, inds, 0], ctx=ctx
                ),
            },
        )

        edge_list = []
        for i in range(n_nodes - 1):
@@ -172,21 +207,30 @@ def build_graph_validate_gt_obj(img, gt_ids, bbox, spatial_feat,
        n_nodes = g_pred.number_of_nodes()
        n_edges = g_pred.number_of_edges()
        if bbox_improvement:
            g_pred.ndata["pred_bbox"] = bbox_improve(g_pred.ndata["pred_bbox"])
        g_pred.edata["rel_bbox"] = extract_edge_bbox(g_pred)
        g_pred.edata["batch_id"] = nd.zeros((n_edges, 1), ctx=ctx) + btc
        g_batch.append(g_pred)

    if len(g_batch) == 0:
        return None
    if len(g_batch) > 1:
        return dgl.batch(g_batch)
    return g_batch[0]
def build_graph_validate_gt_bbox(
    img,
    ids,
    scores,
    bbox,
    spatial_feat,
    gt_ids=None,
    bbox_improvement=True,
    overlap=False,
):
    """given ground truth bbox, build graph for validation"""
    n_batch = img.shape[0]
    img_size = img.shape[2:4]
    bbox[:, :, 0] /= img_size[1]
@@ -197,17 +241,22 @@ def build_graph_validate_gt_bbox(img, ids, scores, bbox, spatial_feat, gt_ids=No
    g_batch = []
    for btc in range(n_batch):
        id_btc = scores[btc][:, :, 0].argmax(0)
        score_btc = scores[btc][:, :, 0].max(0)
        inds = np.where(bbox[btc].sum(1).asnumpy() > 0)[0].tolist()
        if len(inds) == 0:
            continue
        n_nodes = len(inds)
        g_pred = dgl.DGLGraph()
        g_pred.add_nodes(
            n_nodes,
            {
                "pred_bbox": bbox[btc, inds],
                "node_feat": spatial_feat[btc, inds],
                "node_class_pred": id_btc,
                "node_class_logit": nd.log(score_btc + 1e-7),
            },
        )

        edge_list = []
        for i in range(n_nodes - 1):
@@ -220,21 +269,31 @@ def build_graph_validate_gt_bbox(img, ids, scores, bbox, spatial_feat, gt_ids=No
        n_nodes = g_pred.number_of_nodes()
        n_edges = g_pred.number_of_edges()
        if bbox_improvement:
            g_pred.ndata["pred_bbox"] = bbox_improve(g_pred.ndata["pred_bbox"])
        g_pred.edata["rel_bbox"] = extract_edge_bbox(g_pred)
        g_pred.edata["batch_id"] = nd.zeros((n_edges, 1), ctx=ctx) + btc
        g_batch.append(g_pred)

    if len(g_batch) == 0:
        return None
    if len(g_batch) > 1:
        return dgl.batch(g_batch)
    return g_batch[0]
def build_graph_validate_pred(img, ids, scores, bbox, feat_ind, spatial_feat,
bbox_improvement=True, scores_top_k=50, overlap=False): def build_graph_validate_pred(
'''given predicted bbox, build graph for validation''' img,
ids,
scores,
bbox,
feat_ind,
spatial_feat,
bbox_improvement=True,
scores_top_k=50,
overlap=False,
):
"""given predicted bbox, build graph for validation"""
n_batch = img.shape[0] n_batch = img.shape[0]
img_size = img.shape[2:4] img_size = img.shape[2:4]
bbox[:, :, 0] /= img_size[1] bbox[:, :, 0] /= img_size[1]
...@@ -249,16 +308,23 @@ def build_graph_validate_pred(img, ids, scores, bbox, feat_ind, spatial_feat, ...@@ -249,16 +308,23 @@ def build_graph_validate_pred(img, ids, scores, bbox, feat_ind, spatial_feat,
if len(inds) == 0: if len(inds) == 0:
continue continue
if len(inds) > scores_top_k: if len(inds) > scores_top_k:
top_score_inds = scores[btc, inds, 0].asnumpy().argsort()[::-1][0:scores_top_k] top_score_inds = (
scores[btc, inds, 0].asnumpy().argsort()[::-1][0:scores_top_k]
)
inds = np.array(inds)[top_score_inds].tolist() inds = np.array(inds)[top_score_inds].tolist()
n_nodes = len(inds) n_nodes = len(inds)
roi_ind = feat_ind[btc, inds].squeeze(axis=1) roi_ind = feat_ind[btc, inds].squeeze(axis=1)
g_pred = dgl.DGLGraph() g_pred = dgl.DGLGraph()
g_pred.add_nodes(n_nodes, {'pred_bbox': bbox[btc, inds], g_pred.add_nodes(
'node_feat': spatial_feat[btc, roi_ind], n_nodes,
'node_class_pred': ids[btc, inds, 0], {
'node_class_logit': nd.log(scores[btc, inds, 0] + 1e-7)}) "pred_bbox": bbox[btc, inds],
"node_feat": spatial_feat[btc, roi_ind],
"node_class_pred": ids[btc, inds, 0],
"node_class_logit": nd.log(scores[btc, inds, 0] + 1e-7),
},
)
edge_list = [] edge_list = []
for i in range(n_nodes - 1): for i in range(n_nodes - 1):
...@@ -271,14 +337,14 @@ def build_graph_validate_pred(img, ids, scores, bbox, feat_ind, spatial_feat, ...@@ -271,14 +337,14 @@ def build_graph_validate_pred(img, ids, scores, bbox, feat_ind, spatial_feat,
n_nodes = g_pred.number_of_nodes() n_nodes = g_pred.number_of_nodes()
n_edges = g_pred.number_of_edges() n_edges = g_pred.number_of_edges()
if bbox_improvement: if bbox_improvement:
g_pred.ndata['pred_bbox'] = bbox_improve(g_pred.ndata['pred_bbox']) g_pred.ndata["pred_bbox"] = bbox_improve(g_pred.ndata["pred_bbox"])
g_pred.edata['rel_bbox'] = extract_edge_bbox(g_pred) g_pred.edata["rel_bbox"] = extract_edge_bbox(g_pred)
g_pred.edata['batch_id'] = nd.zeros((n_edges, 1), ctx = ctx) + btc g_pred.edata["batch_id"] = nd.zeros((n_edges, 1), ctx=ctx) + btc
g_batch.append(g_pred) g_batch.append(g_pred)
if len(g_batch) == 0: if len(g_batch) == 0:
return None return None
if len(g_batch) > 1: if len(g_batch) > 1:
return dgl.batch(g_batch) return dgl.batch(g_batch)
return g_batch[0] return g_batch[0]
import dgl import logging
import time
from operator import attrgetter, itemgetter
import mxnet as mx import mxnet as mx
import numpy as np import numpy as np
import logging, time from gluoncv.data.batchify import Pad
from operator import attrgetter, itemgetter from gluoncv.model_zoo import get_model
from mxnet import nd, gluon from mxnet import gluon, nd
from mxnet.gluon import nn from mxnet.gluon import nn
from dgl.utils import toindex
import dgl
from dgl.nn.mxnet import GraphConv from dgl.nn.mxnet import GraphConv
from gluoncv.model_zoo import get_model from dgl.utils import toindex
from gluoncv.data.batchify import Pad
def iou(boxA, boxB): def iou(boxA, boxB):
# determine the (x, y)-coordinates of the intersection rectangle # determine the (x, y)-coordinates of the intersection rectangle
...@@ -16,9 +20,9 @@ def iou(boxA, boxB): ...@@ -16,9 +20,9 @@ def iou(boxA, boxB):
yA = max(boxA[1], boxB[1]) yA = max(boxA[1], boxB[1])
xB = min(boxA[2], boxB[2]) xB = min(boxA[2], boxB[2])
yB = min(boxA[3], boxB[3]) yB = min(boxA[3], boxB[3])
interArea = max(0, xB - xA) * max(0, yB - yA) interArea = max(0, xB - xA) * max(0, yB - yA)
if interArea < 1e-7 : if interArea < 1e-7:
return 0 return 0
boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1]) boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
...@@ -29,12 +33,14 @@ def iou(boxA, boxB): ...@@ -29,12 +33,14 @@ def iou(boxA, boxB):
iou_val = interArea / float(boxAArea + boxBArea - interArea) iou_val = interArea / float(boxAArea + boxBArea - interArea)
return iou_val return iou_val
def object_iou_thresh(gt_object, pred_object, iou_thresh=0.5): def object_iou_thresh(gt_object, pred_object, iou_thresh=0.5):
obj_iou = iou(gt_object[1:5], pred_object[1:5]) obj_iou = iou(gt_object[1:5], pred_object[1:5])
if obj_iou >= iou_thresh: if obj_iou >= iou_thresh:
return True return True
return False return False
def triplet_iou_thresh(pred_triplet, gt_triplet, iou_thresh=0.5): def triplet_iou_thresh(pred_triplet, gt_triplet, iou_thresh=0.5):
sub_iou = iou(gt_triplet[5:9], pred_triplet[5:9]) sub_iou = iou(gt_triplet[5:9], pred_triplet[5:9])
if sub_iou >= iou_thresh: if sub_iou >= iou_thresh:
...@@ -43,10 +49,11 @@ def triplet_iou_thresh(pred_triplet, gt_triplet, iou_thresh=0.5): ...@@ -43,10 +49,11 @@ def triplet_iou_thresh(pred_triplet, gt_triplet, iou_thresh=0.5):
return True return True
return False return False
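For a quick sanity check of the IoU helper above, here is a minimal worked example on two made-up axis-aligned boxes in (x1, y1, x2, y2) form; the box values are illustrative only.
# hypothetical boxes, chosen only to illustrate iou() defined above
boxA = [0.0, 0.0, 2.0, 2.0]   # 2 x 2 box, area 4
boxB = [1.0, 1.0, 3.0, 3.0]   # overlaps boxA on a 1 x 1 patch
# intersection = 1, union = 4 + 4 - 1 = 7, so iou(boxA, boxB) ~ 0.143
print(iou(boxA, boxB))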
@mx.metric.register @mx.metric.register
@mx.metric.alias('auc') @mx.metric.alias("auc")
class AUCMetric(mx.metric.EvalMetric): class AUCMetric(mx.metric.EvalMetric):
def __init__(self, name='auc', eps=1e-12): def __init__(self, name="auc", eps=1e-12):
super(AUCMetric, self).__init__(name) super(AUCMetric, self).__init__(name)
self.eps = eps self.eps = eps
...@@ -78,12 +85,14 @@ class AUCMetric(mx.metric.EvalMetric): ...@@ -78,12 +85,14 @@ class AUCMetric(mx.metric.EvalMetric):
self.sum_metric += area / total_area self.sum_metric += area / total_area
self.num_inst += 1 self.num_inst += 1
@mx.metric.register @mx.metric.register
@mx.metric.alias('predcls') @mx.metric.alias("predcls")
class PredCls(mx.metric.EvalMetric): class PredCls(mx.metric.EvalMetric):
'''Metric with ground truth object location and label''' """Metric with ground truth object location and label"""
def __init__(self, topk=20, iou_thresh=0.99): def __init__(self, topk=20, iou_thresh=0.99):
super(PredCls, self).__init__('predcls@%d'%(topk)) super(PredCls, self).__init__("predcls@%d" % (topk))
self.topk = topk self.topk = topk
self.iou_thresh = iou_thresh self.iou_thresh = iou_thresh
...@@ -91,7 +100,7 @@ class PredCls(mx.metric.EvalMetric): ...@@ -91,7 +100,7 @@ class PredCls(mx.metric.EvalMetric):
if labels is None or preds is None: if labels is None or preds is None:
self.num_inst += 1 self.num_inst += 1
return return
preds = preds[preds[:,0].argsort()[::-1]] preds = preds[preds[:, 0].argsort()[::-1]]
m = min(self.topk, preds.shape[0]) m = min(self.topk, preds.shape[0])
count = 0 count = 0
gt_edge_num = labels.shape[0] gt_edge_num = labels.shape[0]
...@@ -102,8 +111,9 @@ class PredCls(mx.metric.EvalMetric): ...@@ -102,8 +111,9 @@ class PredCls(mx.metric.EvalMetric):
if label_matched[j]: if label_matched[j]:
continue continue
label = labels[j] label = labels[j]
if int(label[2]) == int(pred[2]) and \ if int(label[2]) == int(pred[2]) and triplet_iou_thresh(
triplet_iou_thresh(pred, label, self.iou_thresh): pred, label, self.iou_thresh
):
count += 1 count += 1
label_matched[j] = True label_matched[j] = True
...@@ -111,12 +121,14 @@ class PredCls(mx.metric.EvalMetric): ...@@ -111,12 +121,14 @@ class PredCls(mx.metric.EvalMetric):
self.sum_metric += count / total self.sum_metric += count / total
self.num_inst += 1 self.num_inst += 1
@mx.metric.register @mx.metric.register
@mx.metric.alias('phrcls') @mx.metric.alias("phrcls")
class PhrCls(mx.metric.EvalMetric): class PhrCls(mx.metric.EvalMetric):
'''Metric with ground truth object location and predicted object label from detector''' """Metric with ground truth object location and predicted object label from detector"""
def __init__(self, topk=20, iou_thresh=0.99): def __init__(self, topk=20, iou_thresh=0.99):
super(PhrCls, self).__init__('phrcls@%d'%(topk)) super(PhrCls, self).__init__("phrcls@%d" % (topk))
self.topk = topk self.topk = topk
self.iou_thresh = iou_thresh self.iou_thresh = iou_thresh
...@@ -124,7 +136,7 @@ class PhrCls(mx.metric.EvalMetric): ...@@ -124,7 +136,7 @@ class PhrCls(mx.metric.EvalMetric):
if labels is None or preds is None: if labels is None or preds is None:
self.num_inst += 1 self.num_inst += 1
return return
preds = preds[preds[:,1].argsort()[::-1]] preds = preds[preds[:, 1].argsort()[::-1]]
m = min(self.topk, preds.shape[0]) m = min(self.topk, preds.shape[0])
count = 0 count = 0
gt_edge_num = labels.shape[0] gt_edge_num = labels.shape[0]
...@@ -135,22 +147,26 @@ class PhrCls(mx.metric.EvalMetric): ...@@ -135,22 +147,26 @@ class PhrCls(mx.metric.EvalMetric):
if label_matched[j]: if label_matched[j]:
continue continue
label = labels[j] label = labels[j]
if int(label[2]) == int(pred[2]) and \ if (
int(label[3]) == int(pred[3]) and \ int(label[2]) == int(pred[2])
int(label[4]) == int(pred[4]) and \ and int(label[3]) == int(pred[3])
triplet_iou_thresh(pred, label, self.iou_thresh): and int(label[4]) == int(pred[4])
and triplet_iou_thresh(pred, label, self.iou_thresh)
):
count += 1 count += 1
label_matched[j] = True label_matched[j] = True
total = labels.shape[0] total = labels.shape[0]
self.sum_metric += count / total self.sum_metric += count / total
self.num_inst += 1 self.num_inst += 1
@mx.metric.register @mx.metric.register
@mx.metric.alias('sgdet') @mx.metric.alias("sgdet")
class SGDet(mx.metric.EvalMetric): class SGDet(mx.metric.EvalMetric):
'''Metric with predicted object information by the detector''' """Metric with predicted object information by the detector"""
def __init__(self, topk=20, iou_thresh=0.5): def __init__(self, topk=20, iou_thresh=0.5):
super(SGDet, self).__init__('sgdet@%d'%(topk)) super(SGDet, self).__init__("sgdet@%d" % (topk))
self.topk = topk self.topk = topk
self.iou_thresh = iou_thresh self.iou_thresh = iou_thresh
...@@ -158,7 +174,7 @@ class SGDet(mx.metric.EvalMetric): ...@@ -158,7 +174,7 @@ class SGDet(mx.metric.EvalMetric):
if labels is None or preds is None: if labels is None or preds is None:
self.num_inst += 1 self.num_inst += 1
return return
preds = preds[preds[:,1].argsort()[::-1]] preds = preds[preds[:, 1].argsort()[::-1]]
m = min(self.topk, len(preds)) m = min(self.topk, len(preds))
count = 0 count = 0
gt_edge_num = labels.shape[0] gt_edge_num = labels.shape[0]
...@@ -169,22 +185,26 @@ class SGDet(mx.metric.EvalMetric): ...@@ -169,22 +185,26 @@ class SGDet(mx.metric.EvalMetric):
if label_matched[j]: if label_matched[j]:
continue continue
label = labels[j] label = labels[j]
if int(label[2]) == int(pred[2]) and \ if (
int(label[3]) == int(pred[3]) and \ int(label[2]) == int(pred[2])
int(label[4]) == int(pred[4]) and \ and int(label[3]) == int(pred[3])
triplet_iou_thresh(pred, label, self.iou_thresh): and int(label[4]) == int(pred[4])
and triplet_iou_thresh(pred, label, self.iou_thresh)
):
count += 1 count += 1
label_matched[j] =True label_matched[j] = True
total = labels.shape[0] total = labels.shape[0]
self.sum_metric += count / total self.sum_metric += count / total
self.num_inst += 1 self.num_inst += 1
@mx.metric.register @mx.metric.register
@mx.metric.alias('sgdet+') @mx.metric.alias("sgdet+")
class SGDetPlus(mx.metric.EvalMetric): class SGDetPlus(mx.metric.EvalMetric):
'''Metric proposed by `Graph R-CNN for Scene Graph Generation`''' """Metric proposed by `Graph R-CNN for Scene Graph Generation`"""
def __init__(self, topk=20, iou_thresh=0.5): def __init__(self, topk=20, iou_thresh=0.5):
super(SGDetPlus, self).__init__('sgdet+@%d'%(topk)) super(SGDetPlus, self).__init__("sgdet+@%d" % (topk))
self.topk = topk self.topk = topk
self.iou_thresh = iou_thresh self.iou_thresh = iou_thresh
...@@ -205,13 +225,14 @@ class SGDetPlus(mx.metric.EvalMetric): ...@@ -205,13 +225,14 @@ class SGDetPlus(mx.metric.EvalMetric):
if object_matched[j]: if object_matched[j]:
continue continue
label = label_objects[j] label = label_objects[j]
if int(label[0]) == int(pred[0]) and \ if int(label[0]) == int(pred[0]) and object_iou_thresh(
object_iou_thresh(pred, label, self.iou_thresh): pred, label, self.iou_thresh
):
count += 1 count += 1
object_matched[j] = True object_matched[j] = True
# count predicate and triplet # count predicate and triplet
pred_triplets = pred_triplets[pred_triplets[:,1].argsort()[::-1]] pred_triplets = pred_triplets[pred_triplets[:, 1].argsort()[::-1]]
m = min(self.topk, len(pred_triplets)) m = min(self.topk, len(pred_triplets))
gt_triplet_num = label_triplets.shape[0] gt_triplet_num = label_triplets.shape[0]
triplet_matched = [False for label in label_triplets] triplet_matched = [False for label in label_triplets]
...@@ -221,15 +242,18 @@ class SGDetPlus(mx.metric.EvalMetric): ...@@ -221,15 +242,18 @@ class SGDetPlus(mx.metric.EvalMetric):
for j in range(gt_triplet_num): for j in range(gt_triplet_num):
label = label_triplets[j] label = label_triplets[j]
if not predicate_matched: if not predicate_matched:
if int(label[2]) == int(pred[2]) and \ if int(label[2]) == int(pred[2]) and triplet_iou_thresh(
triplet_iou_thresh(pred, label, self.iou_thresh): pred, label, self.iou_thresh
):
count += label[3] count += label[3]
predicate_matched[j] = True predicate_matched[j] = True
if not triplet_matched[j]: if not triplet_matched[j]:
if int(label[2]) == int(pred[2]) and \ if (
int(label[3]) == int(pred[3]) and \ int(label[2]) == int(pred[2])
int(label[4]) == int(pred[4]) and \ and int(label[3]) == int(pred[3])
triplet_iou_thresh(pred, label, self.iou_thresh): and int(label[4]) == int(pred[4])
and triplet_iou_thresh(pred, label, self.iou_thresh)
):
count += 1 count += 1
triplet_matched[j] = True triplet_matched[j] = True
# compute sum # compute sum
...@@ -238,27 +262,28 @@ class SGDetPlus(mx.metric.EvalMetric): ...@@ -238,27 +262,28 @@ class SGDetPlus(mx.metric.EvalMetric):
self.sum_metric += count / N self.sum_metric += count / N
self.num_inst += 1 self.num_inst += 1
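Because each metric above is registered with MXNet's metric registry through the decorators, it can be looked up by its alias once this module has been imported. A minimal sketch, assuming the standard mx.metric.create lookup and triplet arrays shaped as produced by extract_gt / extract_pred below:
import mxnet as mx
# the alias 'predcls' was registered by the decorator above
metric = mx.metric.create("predcls", topk=20)
metric.reset()
# metric.update(gt_triplets, pred_triplets)  # per-image triplet arrays
print(metric.get())                          # ('predcls@20', recall value)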
def extract_gt(g, img_size): def extract_gt(g, img_size):
'''extract prediction from ground truth graph''' """extract prediction from ground truth graph"""
if g is None or g.number_of_nodes() == 0: if g is None or g.number_of_nodes() == 0:
return None, None return None, None
gt_eids = np.where(g.edata['rel_class'].asnumpy() > 0)[0] gt_eids = np.where(g.edata["rel_class"].asnumpy() > 0)[0]
if len(gt_eids) == 0: if len(gt_eids) == 0:
return None, None return None, None
gt_class = g.ndata['node_class'][:,0].asnumpy() gt_class = g.ndata["node_class"][:, 0].asnumpy()
gt_bbox = g.ndata['bbox'].asnumpy() gt_bbox = g.ndata["bbox"].asnumpy()
gt_bbox[:, 0] /= img_size[1] gt_bbox[:, 0] /= img_size[1]
gt_bbox[:, 1] /= img_size[0] gt_bbox[:, 1] /= img_size[0]
gt_bbox[:, 2] /= img_size[1] gt_bbox[:, 2] /= img_size[1]
gt_bbox[:, 3] /= img_size[0] gt_bbox[:, 3] /= img_size[0]
gt_objects = np.vstack([gt_class, gt_bbox.transpose(1, 0)]).transpose(1, 0) gt_objects = np.vstack([gt_class, gt_bbox.transpose(1, 0)]).transpose(1, 0)
gt_node_ids = g.find_edges(gt_eids) gt_node_ids = g.find_edges(gt_eids)
gt_node_sub = gt_node_ids[0].asnumpy() gt_node_sub = gt_node_ids[0].asnumpy()
gt_node_ob = gt_node_ids[1].asnumpy() gt_node_ob = gt_node_ids[1].asnumpy()
gt_rel_class = g.edata['rel_class'][gt_eids,0].asnumpy() - 1 gt_rel_class = g.edata["rel_class"][gt_eids, 0].asnumpy() - 1
gt_sub_class = gt_class[gt_node_sub] gt_sub_class = gt_class[gt_node_sub]
gt_ob_class = gt_class[gt_node_ob] gt_ob_class = gt_class[gt_node_ob]
...@@ -266,32 +291,42 @@ def extract_gt(g, img_size): ...@@ -266,32 +291,42 @@ def extract_gt(g, img_size):
gt_ob_bbox = gt_bbox[gt_node_ob] gt_ob_bbox = gt_bbox[gt_node_ob]
n = len(gt_eids) n = len(gt_eids)
gt_triplets = np.vstack([np.ones(n), np.ones(n), gt_triplets = np.vstack(
gt_rel_class, gt_sub_class, gt_ob_class, [
gt_sub_bbox.transpose(1, 0), np.ones(n),
gt_ob_bbox.transpose(1, 0)]).transpose(1, 0) np.ones(n),
gt_rel_class,
gt_sub_class,
gt_ob_class,
gt_sub_bbox.transpose(1, 0),
gt_ob_bbox.transpose(1, 0),
]
).transpose(1, 0)
return gt_objects, gt_triplets return gt_objects, gt_triplets
def extract_pred(g, topk=100, joint_preds=False): def extract_pred(g, topk=100, joint_preds=False):
'''extract prediction from prediction graph for validation and visualization''' """extract prediction from prediction graph for validation and visualization"""
if g is None or g.number_of_nodes() == 0: if g is None or g.number_of_nodes() == 0:
return None, None return None, None
pred_class = g.ndata['node_class_pred'].asnumpy() pred_class = g.ndata["node_class_pred"].asnumpy()
pred_class_prob = g.ndata['node_class_logit'].asnumpy() pred_class_prob = g.ndata["node_class_logit"].asnumpy()
pred_bbox = g.ndata['pred_bbox'][:,0:4].asnumpy() pred_bbox = g.ndata["pred_bbox"][:, 0:4].asnumpy()
pred_objects = np.vstack([pred_class, pred_bbox.transpose(1, 0)]).transpose(1, 0) pred_objects = np.vstack([pred_class, pred_bbox.transpose(1, 0)]).transpose(
1, 0
)
score_pred = g.edata['score_pred'].asnumpy() score_pred = g.edata["score_pred"].asnumpy()
score_phr = g.edata['score_phr'].asnumpy() score_phr = g.edata["score_phr"].asnumpy()
score_pred_topk_eids = (-score_pred).argsort()[0:topk].tolist() score_pred_topk_eids = (-score_pred).argsort()[0:topk].tolist()
score_phr_topk_eids = (-score_phr).argsort()[0:topk].tolist() score_phr_topk_eids = (-score_phr).argsort()[0:topk].tolist()
topk_eids = sorted(list(set(score_pred_topk_eids + score_phr_topk_eids))) topk_eids = sorted(list(set(score_pred_topk_eids + score_phr_topk_eids)))
pred_rel_prob = g.edata['preds'][topk_eids].asnumpy() pred_rel_prob = g.edata["preds"][topk_eids].asnumpy()
if joint_preds: if joint_preds:
pred_rel_class = pred_rel_prob[:,1:].argmax(axis=1) pred_rel_class = pred_rel_prob[:, 1:].argmax(axis=1)
else: else:
pred_rel_class = pred_rel_prob.argmax(axis=1) pred_rel_class = pred_rel_prob.argmax(axis=1)
...@@ -307,8 +342,15 @@ def extract_pred(g, topk=100, joint_preds=False): ...@@ -307,8 +342,15 @@ def extract_pred(g, topk=100, joint_preds=False):
pred_ob_class_prob = pred_class_prob[pred_node_ob] pred_ob_class_prob = pred_class_prob[pred_node_ob]
pred_ob_bbox = pred_bbox[pred_node_ob] pred_ob_bbox = pred_bbox[pred_node_ob]
pred_triplets = np.vstack([score_pred[topk_eids], score_phr[topk_eids], pred_triplets = np.vstack(
pred_rel_class, pred_sub_class, pred_ob_class, [
pred_sub_bbox.transpose(1, 0), score_pred[topk_eids],
pred_ob_bbox.transpose(1, 0)]).transpose(1, 0) score_phr[topk_eids],
pred_rel_class,
pred_sub_class,
pred_ob_class,
pred_sub_bbox.transpose(1, 0),
pred_ob_bbox.transpose(1, 0),
]
).transpose(1, 0)
return pred_objects, pred_triplets return pred_objects, pred_triplets
import dgl
from dgl.utils import toindex
import mxnet as mx import mxnet as mx
import numpy as np import numpy as np
import dgl
from dgl.utils import toindex
def l0_sample(g, positive_max=128, negative_ratio=3): def l0_sample(g, positive_max=128, negative_ratio=3):
'''sampling positive and negative edges''' """sampling positive and negative edges"""
if g is None: if g is None:
return None return None
n_eids = g.number_of_edges() n_eids = g.number_of_edges()
pos_eids = np.where(g.edata['rel_class'].asnumpy() > 0)[0] pos_eids = np.where(g.edata["rel_class"].asnumpy() > 0)[0]
neg_eids = np.where(g.edata['rel_class'].asnumpy() == 0)[0] neg_eids = np.where(g.edata["rel_class"].asnumpy() == 0)[0]
if len(pos_eids) == 0: if len(pos_eids) == 0:
return None return None
...@@ -26,6 +28,7 @@ def l0_sample(g, positive_max=128, negative_ratio=3): ...@@ -26,6 +28,7 @@ def l0_sample(g, positive_max=128, negative_ratio=3):
eids = np.where(weights > 0)[0] eids = np.where(weights > 0)[0]
sub_g = g.edge_subgraph(toindex(eids.tolist())) sub_g = g.edge_subgraph(toindex(eids.tolist()))
sub_g.copy_from_parent() sub_g.copy_from_parent()
sub_g.edata['sample_weights'] = mx.nd.array(weights[eids], sub_g.edata["sample_weights"] = mx.nd.array(
ctx=g.edata['rel_class'].context) weights[eids], ctx=g.edata["rel_class"].context
)
return sub_g return sub_g
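The sampling idea behind l0_sample can be summarised with a small standalone sketch: positives are capped at positive_max and negatives are subsampled at roughly negative_ratio negatives per positive. This is an illustration with toy data under those assumptions, not the exact implementation.
import numpy as np

rng = np.random.default_rng(0)
rel_class = rng.integers(0, 2, size=1000)      # toy edge labels, 1 = positive
pos = np.where(rel_class > 0)[0]
neg = np.where(rel_class == 0)[0]
pos = rng.permutation(pos)[:128]               # positive_max
neg = rng.permutation(neg)[: 3 * len(pos)]     # negative_ratio * positives
keep_eids = np.concatenate([pos, neg])         # edge ids kept in the subgraph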
import numpy as np
import gluoncv as gcv import gluoncv as gcv
import numpy as np
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
def plot_sg(img, preds, obj_classes, rel_classes, topk=1): def plot_sg(img, preds, obj_classes, rel_classes, topk=1):
'''visualization of generated scene graph''' """visualization of generated scene graph"""
size = img.shape[0:2] size = img.shape[0:2]
box_scale = np.array([size[1], size[0], size[1], size[0]]) box_scale = np.array([size[1], size[0], size[1], size[0]])
topk = min(topk, preds.shape[0]) topk = min(topk, preds.shape[0])
...@@ -17,30 +18,51 @@ def plot_sg(img, preds, obj_classes, rel_classes, topk=1): ...@@ -17,30 +18,51 @@ def plot_sg(img, preds, obj_classes, rel_classes, topk=1):
rel_name = rel_classes[rel] rel_name = rel_classes[rel]
src_bbox = preds[i, 5:9] * box_scale src_bbox = preds[i, 5:9] * box_scale
dst_bbox = preds[i, 9:13] * box_scale dst_bbox = preds[i, 9:13] * box_scale
src_center = np.array([(src_bbox[0] + src_bbox[2]) / 2, (src_bbox[1] + src_bbox[3]) / 2]) src_center = np.array(
dst_center = np.array([(dst_bbox[0] + dst_bbox[2]) / 2, (dst_bbox[1] + dst_bbox[3]) / 2]) [(src_bbox[0] + src_bbox[2]) / 2, (src_bbox[1] + src_bbox[3]) / 2]
)
dst_center = np.array(
[(dst_bbox[0] + dst_bbox[2]) / 2, (dst_bbox[1] + dst_bbox[3]) / 2]
)
rel_center = (src_center + dst_center) / 2 rel_center = (src_center + dst_center) / 2
line_x = np.array([(src_bbox[0] + src_bbox[2]) / 2, (dst_bbox[0] + dst_bbox[2]) / 2]) line_x = np.array(
line_y = np.array([(src_bbox[1] + src_bbox[3]) / 2, (dst_bbox[1] + dst_bbox[3]) / 2]) [(src_bbox[0] + src_bbox[2]) / 2, (dst_bbox[0] + dst_bbox[2]) / 2]
)
ax.plot(line_x, line_y, line_y = np.array(
linewidth=3.0, alpha=0.7, color=plt.cm.cool(rel)) [(src_bbox[1] + src_bbox[3]) / 2, (dst_bbox[1] + dst_bbox[3]) / 2]
)
ax.text(src_center[0], src_center[1],
'{:s}'.format(src_name), ax.plot(
bbox=dict(alpha=0.5), line_x, line_y, linewidth=3.0, alpha=0.7, color=plt.cm.cool(rel)
fontsize=12, color='white') )
ax.text(dst_center[0], dst_center[1],
'{:s}'.format(dst_name), ax.text(
bbox=dict(alpha=0.5), src_center[0],
fontsize=12, color='white') src_center[1],
ax.text(rel_center[0], rel_center[1], "{:s}".format(src_name),
'{:s}'.format(rel_name), bbox=dict(alpha=0.5),
bbox=dict(alpha=0.5), fontsize=12,
fontsize=12, color='white') color="white",
)
ax.text(
dst_center[0],
dst_center[1],
"{:s}".format(dst_name),
bbox=dict(alpha=0.5),
fontsize=12,
color="white",
)
ax.text(
rel_center[0],
rel_center[1],
"{:s}".format(rel_name),
bbox=dict(alpha=0.5),
fontsize=12,
color="white",
)
return ax return ax
plot_sg(img, preds, 2) plot_sg(img, preds, 2)

import dgl import argparse
import logging
import time
import mxnet as mx import mxnet as mx
import numpy as np import numpy as np
import logging, time, argparse from data import *
from mxnet import nd, gluon
from gluoncv.data.batchify import Pad from gluoncv.data.batchify import Pad
from model import RelDN, faster_rcnn_resnet101_v1d_custom
from model import faster_rcnn_resnet101_v1d_custom, RelDN from mxnet import gluon, nd
from utils import * from utils import *
from data import *
import dgl
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description='Validate Pre-trained RelDN Model.') parser = argparse.ArgumentParser(
parser.add_argument('--gpus', type=str, default='0', description="Validate Pre-trained RelDN Model."
help="Training with GPUs, you can specify 1,3 for example.") )
parser.add_argument('--batch-size', type=int, default=8, parser.add_argument(
help="Total batch-size for training.") "--gpus",
parser.add_argument('--metric', type=str, default='sgdet', type=str,
help="Evaluation metric, could be 'predcls', 'phrcls', 'sgdet' or 'sgdet+'.") default="0",
parser.add_argument('--pretrained-faster-rcnn-params', type=str, required=True, help="Training with GPUs, you can specify 1,3 for example.",
help="Path to saved Faster R-CNN model parameters.") )
parser.add_argument('--reldn-params', type=str, required=True, parser.add_argument(
help="Path to saved Faster R-CNN model parameters.") "--batch-size",
parser.add_argument('--faster-rcnn-params', type=str, required=True, type=int,
help="Path to saved Faster R-CNN model parameters.") default=8,
parser.add_argument('--log-dir', type=str, default='reldn_output.log', help="Total batch-size for training.",
help="Path to save training logs.") )
parser.add_argument('--freq-prior', type=str, default='freq_prior.pkl', parser.add_argument(
help="Path to saved frequency prior data.") "--metric",
parser.add_argument('--verbose-freq', type=int, default=100, type=str,
help="Frequency of log printing in number of iterations.") default="sgdet",
help="Evaluation metric, could be 'predcls', 'phrcls', 'sgdet' or 'sgdet+'.",
)
parser.add_argument(
"--pretrained-faster-rcnn-params",
type=str,
required=True,
help="Path to saved Faster R-CNN model parameters.",
)
parser.add_argument(
"--reldn-params",
type=str,
required=True,
help="Path to saved Faster R-CNN model parameters.",
)
parser.add_argument(
"--faster-rcnn-params",
type=str,
required=True,
help="Path to saved Faster R-CNN model parameters.",
)
parser.add_argument(
"--log-dir",
type=str,
default="reldn_output.log",
help="Path to save training logs.",
)
parser.add_argument(
"--freq-prior",
type=str,
default="freq_prior.pkl",
help="Path to saved frequency prior data.",
)
parser.add_argument(
"--verbose-freq",
type=int,
default=100,
help="Frequency of log printing in number of iterations.",
)
args = parser.parse_args() args = parser.parse_args()
return args return args
args = parse_args() args = parse_args()
filehandler = logging.FileHandler(args.log_dir) filehandler = logging.FileHandler(args.log_dir)
streamhandler = logging.StreamHandler() streamhandler = logging.StreamHandler()
logger = logging.getLogger('') logger = logging.getLogger("")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
logger.addHandler(filehandler) logger.addHandler(filehandler)
logger.addHandler(streamhandler) logger.addHandler(streamhandler)
# Hyperparams # Hyperparams
ctx = [mx.gpu(int(i)) for i in args.gpus.split(',') if i.strip()] ctx = [mx.gpu(int(i)) for i in args.gpus.split(",") if i.strip()]
if ctx: if ctx:
num_gpus = len(ctx) num_gpus = len(ctx)
assert args.batch_size % num_gpus == 0 assert args.batch_size % num_gpus == 0
...@@ -58,46 +101,65 @@ batch_verbose_freq = args.verbose_freq ...@@ -58,46 +101,65 @@ batch_verbose_freq = args.verbose_freq
mode = args.metric mode = args.metric
metric_list = [] metric_list = []
topk_list = [20, 50, 100] topk_list = [20, 50, 100]
if mode == 'predcls': if mode == "predcls":
for topk in topk_list: for topk in topk_list:
metric_list.append(PredCls(topk=topk)) metric_list.append(PredCls(topk=topk))
if mode == 'phrcls': if mode == "phrcls":
for topk in topk_list: for topk in topk_list:
metric_list.append(PhrCls(topk=topk)) metric_list.append(PhrCls(topk=topk))
if mode == 'sgdet': if mode == "sgdet":
for topk in topk_list: for topk in topk_list:
metric_list.append(SGDet(topk=topk)) metric_list.append(SGDet(topk=topk))
if mode == 'sgdet+': if mode == "sgdet+":
for topk in topk_list: for topk in topk_list:
metric_list.append(SGDetPlus(topk=topk)) metric_list.append(SGDetPlus(topk=topk))
for metric in metric_list: for metric in metric_list:
metric.reset() metric.reset()
semantic_only = False semantic_only = False
net = RelDN(n_classes=N_relations, prior_pkl=args.freq_prior, net = RelDN(
semantic_only=semantic_only) n_classes=N_relations,
prior_pkl=args.freq_prior,
semantic_only=semantic_only,
)
net.load_parameters(args.reldn_params, ctx=ctx) net.load_parameters(args.reldn_params, ctx=ctx)
# dataset and dataloader # dataset and dataloader
vg_val = VGRelation(split='val') vg_val = VGRelation(split="val")
logger.info('data loaded!') logger.info("data loaded!")
val_data = gluon.data.DataLoader(vg_val, batch_size=len(ctx), shuffle=False, num_workers=16*num_gpus, val_data = gluon.data.DataLoader(
batchify_fn=dgl_mp_batchify_fn) vg_val,
batch_size=len(ctx),
shuffle=False,
num_workers=16 * num_gpus,
batchify_fn=dgl_mp_batchify_fn,
)
n_batches = len(val_data) n_batches = len(val_data)
detector = faster_rcnn_resnet101_v1d_custom(classes=vg_val.obj_classes, detector = faster_rcnn_resnet101_v1d_custom(
pretrained_base=False, pretrained=False, classes=vg_val.obj_classes,
additional_output=True) pretrained_base=False,
pretrained=False,
additional_output=True,
)
params_path = args.pretrained_faster_rcnn_params params_path = args.pretrained_faster_rcnn_params
detector.load_parameters(params_path, ctx=ctx, ignore_extra=True, allow_missing=True) detector.load_parameters(
params_path, ctx=ctx, ignore_extra=True, allow_missing=True
)
detector_feat = faster_rcnn_resnet101_v1d_custom(classes=vg_val.obj_classes, detector_feat = faster_rcnn_resnet101_v1d_custom(
pretrained_base=False, pretrained=False, classes=vg_val.obj_classes,
additional_output=True) pretrained_base=False,
detector_feat.load_parameters(params_path, ctx=ctx, ignore_extra=True, allow_missing=True) pretrained=False,
additional_output=True,
)
detector_feat.load_parameters(
params_path, ctx=ctx, ignore_extra=True, allow_missing=True
)
detector_feat.features.load_parameters(args.faster_rcnn_params, ctx=ctx) detector_feat.features.load_parameters(args.faster_rcnn_params, ctx=ctx)
def get_data_batch(g_list, img_list, ctx_list): def get_data_batch(g_list, img_list, ctx_list):
if g_list is None or len(g_list) == 0: if g_list is None or len(g_list) == 0:
return None, None return None, None
...@@ -106,27 +168,39 @@ def get_data_batch(g_list, img_list, ctx_list): ...@@ -106,27 +168,39 @@ def get_data_batch(g_list, img_list, ctx_list):
if size < n_gpu: if size < n_gpu:
raise Exception("too small batch") raise Exception("too small batch")
step = size // n_gpu step = size // n_gpu
G_list = [g_list[i*step:(i+1)*step] if i < n_gpu - 1 else g_list[i*step:size] for i in range(n_gpu)] G_list = [
img_list = [img_list[i*step:(i+1)*step] if i < n_gpu - 1 else img_list[i*step:size] for i in range(n_gpu)] g_list[i * step : (i + 1) * step]
if i < n_gpu - 1
else g_list[i * step : size]
for i in range(n_gpu)
]
img_list = [
img_list[i * step : (i + 1) * step]
if i < n_gpu - 1
else img_list[i * step : size]
for i in range(n_gpu)
]
for G_slice, ctx in zip(G_list, ctx_list): for G_slice, ctx in zip(G_list, ctx_list):
for G in G_slice: for G in G_slice:
G.ndata['bbox'] = G.ndata['bbox'].as_in_context(ctx) G.ndata["bbox"] = G.ndata["bbox"].as_in_context(ctx)
G.ndata['node_class'] = G.ndata['node_class'].as_in_context(ctx) G.ndata["node_class"] = G.ndata["node_class"].as_in_context(ctx)
G.ndata['node_class_vec'] = G.ndata['node_class_vec'].as_in_context(ctx) G.ndata["node_class_vec"] = G.ndata["node_class_vec"].as_in_context(
G.edata['rel_class'] = G.edata['rel_class'].as_in_context(ctx) ctx
)
G.edata["rel_class"] = G.edata["rel_class"].as_in_context(ctx)
img_list = [img.as_in_context(ctx) for img in img_list] img_list = [img.as_in_context(ctx) for img in img_list]
return G_list, img_list return G_list, img_list
for i, (G_list, img_list) in enumerate(val_data): for i, (G_list, img_list) in enumerate(val_data):
G_list, img_list = get_data_batch(G_list, img_list, ctx) G_list, img_list = get_data_batch(G_list, img_list, ctx)
if G_list is None or img_list is None: if G_list is None or img_list is None:
if (i+1) % batch_verbose_freq == 0: if (i + 1) % batch_verbose_freq == 0:
print_txt = 'Batch[%d/%d] '%\ print_txt = "Batch[%d/%d] " % (i, n_batches)
(i, n_batches)
for metric in metric_list: for metric in metric_list:
metric_name, metric_val = metric.get() metric_name, metric_val = metric.get()
print_txt += '%s=%.4f '%(metric_name, metric_val) print_txt += "%s=%.4f " % (metric_name, metric_val)
logger.info(print_txt) logger.info(print_txt)
continue continue
...@@ -136,31 +210,57 @@ for i, (G_list, img_list) in enumerate(val_data): ...@@ -136,31 +210,57 @@ for i, (G_list, img_list) in enumerate(val_data):
# loss_cls_val = 0 # loss_cls_val = 0
for G_slice, img in zip(G_list, img_list): for G_slice, img in zip(G_list, img_list):
cur_ctx = img.context cur_ctx = img.context
if mode == 'predcls': if mode == "predcls":
bbox_list = [G.ndata['bbox'] for G in G_slice] bbox_list = [G.ndata["bbox"] for G in G_slice]
bbox_stack = bbox_pad(bbox_list).as_in_context(cur_ctx) bbox_stack = bbox_pad(bbox_list).as_in_context(cur_ctx)
ids, scores, bbox, spatial_feat = detector(img, None, None, bbox_stack) ids, scores, bbox, spatial_feat = detector(
img, None, None, bbox_stack
)
node_class_list = [G.ndata['node_class'] for G in G_slice] node_class_list = [G.ndata["node_class"] for G in G_slice]
node_class_stack = bbox_pad(node_class_list).as_in_context(cur_ctx) node_class_stack = bbox_pad(node_class_list).as_in_context(cur_ctx)
g_pred_batch = build_graph_validate_gt_obj(img, node_class_stack, bbox, spatial_feat, g_pred_batch = build_graph_validate_gt_obj(
bbox_improvement=True, overlap=False) img,
elif mode == 'phrcls': node_class_stack,
bbox,
spatial_feat,
bbox_improvement=True,
overlap=False,
)
elif mode == "phrcls":
# use ground truth bbox # use ground truth bbox
bbox_list = [G.ndata['bbox'] for G in G_slice] bbox_list = [G.ndata["bbox"] for G in G_slice]
bbox_stack = bbox_pad(bbox_list).as_in_context(cur_ctx) bbox_stack = bbox_pad(bbox_list).as_in_context(cur_ctx)
ids, scores, bbox, spatial_feat = detector(img, None, None, bbox_stack) ids, scores, bbox, spatial_feat = detector(
img, None, None, bbox_stack
)
g_pred_batch = build_graph_validate_gt_bbox(img, ids, scores, bbox, spatial_feat, g_pred_batch = build_graph_validate_gt_bbox(
bbox_improvement=True, overlap=False) img,
ids,
scores,
bbox,
spatial_feat,
bbox_improvement=True,
overlap=False,
)
else: else:
# use predicted bbox # use predicted bbox
ids, scores, bbox, feat, feat_ind, spatial_feat = detector(img) ids, scores, bbox, feat, feat_ind, spatial_feat = detector(img)
g_pred_batch = build_graph_validate_pred(img, ids, scores, bbox, feat_ind, spatial_feat, g_pred_batch = build_graph_validate_pred(
bbox_improvement=True, scores_top_k=75, overlap=False) img,
ids,
scores,
bbox,
feat_ind,
spatial_feat,
bbox_improvement=True,
scores_top_k=75,
overlap=False,
)
if not semantic_only: if not semantic_only:
rel_bbox = g_pred_batch.edata['rel_bbox'] rel_bbox = g_pred_batch.edata["rel_bbox"]
batch_id = g_pred_batch.edata['batch_id'].asnumpy() batch_id = g_pred_batch.edata["batch_id"].asnumpy()
n_sample_edges = g_pred_batch.number_of_edges() n_sample_edges = g_pred_batch.number_of_edges()
# g_pred_batch.edata['edge_feat'] = mx.nd.zeros((n_sample_edges, 49), ctx=cur_ctx) # g_pred_batch.edata['edge_feat'] = mx.nd.zeros((n_sample_edges, 49), ctx=cur_ctx)
n_graph = len(G_slice) n_graph = len(G_slice)
...@@ -170,13 +270,19 @@ for i, (G_list, img_list) in enumerate(val_data): ...@@ -170,13 +270,19 @@ for i, (G_list, img_list) in enumerate(val_data):
if len(eids) > 0: if len(eids) > 0:
bbox_rel_list.append(rel_bbox[eids]) bbox_rel_list.append(rel_bbox[eids])
bbox_rel_stack = bbox_pad(bbox_rel_list).as_in_context(cur_ctx) bbox_rel_stack = bbox_pad(bbox_rel_list).as_in_context(cur_ctx)
_, _, _, spatial_feat_rel = detector_feat(img, None, None, bbox_rel_stack) _, _, _, spatial_feat_rel = detector_feat(
img, None, None, bbox_rel_stack
)
spatial_feat_rel_list = [] spatial_feat_rel_list = []
for j in range(n_graph): for j in range(n_graph):
eids = np.where(batch_id == j)[0] eids = np.where(batch_id == j)[0]
if len(eids) > 0: if len(eids) > 0:
spatial_feat_rel_list.append(spatial_feat_rel[j, 0:len(eids)]) spatial_feat_rel_list.append(
g_pred_batch.edata['edge_feat'] = nd.concat(*spatial_feat_rel_list, dim=0) spatial_feat_rel[j, 0 : len(eids)]
)
g_pred_batch.edata["edge_feat"] = nd.concat(
*spatial_feat_rel_list, dim=0
)
G_batch.append(g_pred_batch) G_batch.append(g_pred_batch)
...@@ -189,23 +295,25 @@ for i, (G_list, img_list) in enumerate(val_data): ...@@ -189,23 +295,25 @@ for i, (G_list, img_list) in enumerate(val_data):
gt_objects, gt_triplet = extract_gt(G_gt, img_slice.shape[2:4]) gt_objects, gt_triplet = extract_gt(G_gt, img_slice.shape[2:4])
pred_objects, pred_triplet = extract_pred(G_pred, joint_preds=True) pred_objects, pred_triplet = extract_pred(G_pred, joint_preds=True)
for metric in metric_list: for metric in metric_list:
if isinstance(metric, PredCls) or \ if (
isinstance(metric, PhrCls) or \ isinstance(metric, PredCls)
isinstance(metric, SGDet): or isinstance(metric, PhrCls)
or isinstance(metric, SGDet)
):
metric.update(gt_triplet, pred_triplet) metric.update(gt_triplet, pred_triplet)
else: else:
metric.update((gt_objects, gt_triplet), (pred_objects, pred_triplet)) metric.update(
if (i+1) % batch_verbose_freq == 0: (gt_objects, gt_triplet), (pred_objects, pred_triplet)
print_txt = 'Batch[%d/%d] '%\ )
(i, n_batches) if (i + 1) % batch_verbose_freq == 0:
print_txt = "Batch[%d/%d] " % (i, n_batches)
for metric in metric_list: for metric in metric_list:
metric_name, metric_val = metric.get() metric_name, metric_val = metric.get()
print_txt += '%s=%.4f '%(metric_name, metric_val) print_txt += "%s=%.4f " % (metric_name, metric_val)
logger.info(print_txt) logger.info(print_txt)
print_txt = 'Batch[%d/%d] '%\ print_txt = "Batch[%d/%d] " % (n_batches, n_batches)
(n_batches, n_batches)
for metric in metric_list: for metric in metric_list:
metric_name, metric_val = metric.get() metric_name, metric_val = metric.get()
print_txt += '%s=%.4f '%(metric_name, metric_val) print_txt += "%s=%.4f " % (metric_name, metric_val)
logger.info(print_txt) logger.info(print_txt)
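For context, a hedged example of how this validation script might be invoked; the script filename and parameter-file names are assumptions, while the flags are the ones defined in parse_args above.
python validate_reldn.py --gpus 0 --batch-size 8 --metric sgdet \
    --pretrained-faster-rcnn-params detector.params \
    --reldn-params reldn.params --faster-rcnn-params detector_feat.params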
...@@ -5,14 +5,18 @@ Paper: https://arxiv.org/abs/1902.07153 ...@@ -5,14 +5,18 @@ Paper: https://arxiv.org/abs/1902.07153
Code: https://github.com/Tiiiger/SGC Code: https://github.com/Tiiiger/SGC
SGC implementation in DGL. SGC implementation in DGL.
""" """
import argparse, time, math import argparse
import numpy as np import math
import time
import mxnet as mx import mxnet as mx
from mxnet import nd, gluon import numpy as np
from mxnet import gluon, nd
from mxnet.gluon import nn from mxnet.gluon import nn
import dgl import dgl
from dgl.data import register_data_args from dgl.data import (CiteseerGraphDataset, CoraGraphDataset,
from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset PubmedGraphDataset, register_data_args)
from dgl.nn.mxnet.conv import SGConv from dgl.nn.mxnet.conv import SGConv
...@@ -21,16 +25,17 @@ def evaluate(model, g, features, labels, mask): ...@@ -21,16 +25,17 @@ def evaluate(model, g, features, labels, mask):
accuracy = ((pred == labels) * mask).sum() / mask.sum().asscalar() accuracy = ((pred == labels) * mask).sum() / mask.sum().asscalar()
return accuracy.asscalar() return accuracy.asscalar()
def main(args): def main(args):
# load and preprocess dataset # load and preprocess dataset
if args.dataset == 'cora': if args.dataset == "cora":
data = CoraGraphDataset() data = CoraGraphDataset()
elif args.dataset == 'citeseer': elif args.dataset == "citeseer":
data = CiteseerGraphDataset() data = CiteseerGraphDataset()
elif args.dataset == 'pubmed': elif args.dataset == "pubmed":
data = PubmedGraphDataset() data = PubmedGraphDataset()
else: else:
raise ValueError('Unknown dataset: {}'.format(args.dataset)) raise ValueError("Unknown dataset: {}".format(args.dataset))
g = data[0] g = data[0]
if args.gpu < 0: if args.gpu < 0:
...@@ -41,35 +46,36 @@ def main(args): ...@@ -41,35 +46,36 @@ def main(args):
ctx = mx.gpu(args.gpu) ctx = mx.gpu(args.gpu)
g = g.int().to(ctx) g = g.int().to(ctx)
features = g.ndata['feat'] features = g.ndata["feat"]
labels = mx.nd.array(g.ndata['label'], dtype="float32", ctx=ctx) labels = mx.nd.array(g.ndata["label"], dtype="float32", ctx=ctx)
train_mask = g.ndata['train_mask'] train_mask = g.ndata["train_mask"]
val_mask = g.ndata['val_mask'] val_mask = g.ndata["val_mask"]
test_mask = g.ndata['test_mask'] test_mask = g.ndata["test_mask"]
in_feats = features.shape[1] in_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_labels
n_edges = data.graph.number_of_edges() n_edges = data.graph.number_of_edges()
print("""----Data statistics------' print(
"""----Data statistics------'
#Edges %d #Edges %d
#Classes %d #Classes %d
#Train samples %d #Train samples %d
#Val samples %d #Val samples %d
#Test samples %d""" % #Test samples %d"""
(n_edges, n_classes, % (
train_mask.sum().asscalar(), n_edges,
val_mask.sum().asscalar(), n_classes,
test_mask.sum().asscalar())) train_mask.sum().asscalar(),
val_mask.sum().asscalar(),
test_mask.sum().asscalar(),
)
)
# add self loop # add self loop
g = dgl.remove_self_loop(g) g = dgl.remove_self_loop(g)
g = dgl.add_self_loop(g) g = dgl.add_self_loop(g)
# create SGC model # create SGC model
model = SGConv(in_feats, model = SGConv(in_feats, n_classes, k=2, cached=True, bias=args.bias)
n_classes,
k=2,
cached=True,
bias=args.bias)
model.initialize(ctx=ctx) model.initialize(ctx=ctx)
n_train_samples = train_mask.sum().asscalar() n_train_samples = train_mask.sum().asscalar()
...@@ -77,8 +83,11 @@ def main(args): ...@@ -77,8 +83,11 @@ def main(args):
# use optimizer # use optimizer
print(model.collect_params()) print(model.collect_params())
trainer = gluon.Trainer(model.collect_params(), 'adam', trainer = gluon.Trainer(
{'learning_rate': args.lr, 'wd': args.weight_decay}) model.collect_params(),
"adam",
{"learning_rate": args.lr, "wd": args.weight_decay},
)
# initialize graph # initialize graph
dur = [] dur = []
...@@ -98,28 +107,36 @@ def main(args): ...@@ -98,28 +107,36 @@ def main(args):
loss.asscalar() loss.asscalar()
dur.append(time.time() - t0) dur.append(time.time() - t0)
acc = evaluate(model, g, features, labels, val_mask) acc = evaluate(model, g, features, labels, val_mask)
print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | " print(
"ETputs(KTEPS) {:.2f}". format( "Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
epoch, np.mean(dur), loss.asscalar(), acc, n_edges / np.mean(dur) / 1000)) "ETputs(KTEPS) {:.2f}".format(
epoch,
np.mean(dur),
loss.asscalar(),
acc,
n_edges / np.mean(dur) / 1000,
)
)
# test set accuracy # test set accuracy
acc = evaluate(model, g, features, labels, test_mask) acc = evaluate(model, g, features, labels, test_mask)
print("Test accuracy {:.2%}".format(acc)) print("Test accuracy {:.2%}".format(acc))
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser(description='SGC') parser = argparse.ArgumentParser(description="SGC")
register_data_args(parser) register_data_args(parser)
parser.add_argument("--gpu", type=int, default=-1, parser.add_argument("--gpu", type=int, default=-1, help="gpu")
help="gpu") parser.add_argument("--lr", type=float, default=0.2, help="learning rate")
parser.add_argument("--lr", type=float, default=0.2, parser.add_argument(
help="learning rate") "--bias", action="store_true", default=False, help="flag to use bias"
parser.add_argument("--bias", action='store_true', default=False, )
help="flag to use bias") parser.add_argument(
parser.add_argument("--n-epochs", type=int, default=100, "--n-epochs", type=int, default=100, help="number of training epochs"
help="number of training epochs") )
parser.add_argument("--weight-decay", type=float, default=5e-6, parser.add_argument(
help="Weight for L2 loss") "--weight-decay", type=float, default=5e-6, help="Weight for L2 loss"
)
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
......
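As a self-contained illustration of the SGConv layer used by this example, the following minimal sketch builds a toy graph and runs one forward pass; it assumes DGL is configured with the MXNet backend and the DGL >= 0.5 graph construction API.
import dgl
import mxnet as mx
from dgl.nn.mxnet.conv import SGConv

g = dgl.graph(([0, 1, 2], [1, 2, 0]))          # toy 3-node cycle
g = dgl.add_self_loop(g)                       # mirror the preprocessing above
feat = mx.nd.random.uniform(shape=(3, 5))      # 3 nodes, 5 input features
conv = SGConv(5, 2, k=2, cached=True)          # same construction pattern as the model above
conv.initialize()
logits = conv(g, feat)                         # shape (3, 2)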
...@@ -6,18 +6,15 @@ References: ...@@ -6,18 +6,15 @@ References:
""" """
import mxnet as mx import mxnet as mx
from mxnet import gluon from mxnet import gluon
import dgl import dgl
from dgl.nn.mxnet import TAGConv from dgl.nn.mxnet import TAGConv
class TAGCN(gluon.Block): class TAGCN(gluon.Block):
def __init__(self, def __init__(
g, self, g, in_feats, n_hidden, n_classes, n_layers, activation, dropout
in_feats, ):
n_hidden,
n_classes,
n_layers,
activation,
dropout):
super(TAGCN, self).__init__() super(TAGCN, self).__init__()
self.g = g self.g = g
self.layers = gluon.nn.Sequential() self.layers = gluon.nn.Sequential()
...@@ -27,7 +24,7 @@ class TAGCN(gluon.Block): ...@@ -27,7 +24,7 @@ class TAGCN(gluon.Block):
for i in range(n_layers - 1): for i in range(n_layers - 1):
self.layers.add(TAGConv(n_hidden, n_hidden, activation=activation)) self.layers.add(TAGConv(n_hidden, n_hidden, activation=activation))
# output layer # output layer
self.layers.add(TAGConv(n_hidden, n_classes)) #activation=None self.layers.add(TAGConv(n_hidden, n_classes)) # activation=None
self.dropout = gluon.nn.Dropout(rate=dropout) self.dropout = gluon.nn.Dropout(rate=dropout)
def forward(self, features): def forward(self, features):
......