"tests/python/vscode:/vscode.git/clone" did not exist on "a1051f0095c43218636f7be7d66d80b705439e6f"
Unverified Commit a9f2acf3 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] Black auto fix. (#4641)



* [Misc] Black auto fix.

* sort
Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 08c50eb7
import argparse import argparse
import time import time
import numpy as np
import networkx as nx
import mxnet as mx import mxnet as mx
import networkx as nx
import numpy as np
from mxnet import gluon, nd from mxnet import gluon, nd
from mxnet.gluon import nn from mxnet.gluon import nn
import dgl import dgl
from dgl.data import register_data_args from dgl.data import (CiteseerGraphDataset, CoraGraphDataset,
from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset PubmedGraphDataset, register_data_args)
from dgl.nn.mxnet.conv import GMMConv from dgl.nn.mxnet.conv import GMMConv
class MoNet(nn.Block): class MoNet(nn.Block):
def __init__(self, def __init__(
g, self,
in_feats, g,
n_hidden, in_feats,
out_feats, n_hidden,
n_layers, out_feats,
dim, n_layers,
n_kernels, dim,
dropout): n_kernels,
dropout,
):
super(MoNet, self).__init__() super(MoNet, self).__init__()
self.g = g self.g = g
with self.name_scope(): with self.name_scope():
...@@ -28,18 +32,19 @@ class MoNet(nn.Block): ...@@ -28,18 +32,19 @@ class MoNet(nn.Block):
self.pseudo_proj = nn.Sequential() self.pseudo_proj = nn.Sequential()
# Input layer # Input layer
self.layers.add( self.layers.add(GMMConv(in_feats, n_hidden, dim, n_kernels))
GMMConv(in_feats, n_hidden, dim, n_kernels)) self.pseudo_proj.add(nn.Dense(dim, in_units=2, activation="tanh"))
self.pseudo_proj.add(nn.Dense(dim, in_units=2, activation='tanh'))
# Hidden layer # Hidden layer
for _ in range(n_layers - 1): for _ in range(n_layers - 1):
self.layers.add(GMMConv(n_hidden, n_hidden, dim, n_kernels)) self.layers.add(GMMConv(n_hidden, n_hidden, dim, n_kernels))
self.pseudo_proj.add(nn.Dense(dim, in_units=2, activation='tanh')) self.pseudo_proj.add(
nn.Dense(dim, in_units=2, activation="tanh")
)
# Output layer # Output layer
self.layers.add(GMMConv(n_hidden, out_feats, dim, n_kernels)) self.layers.add(GMMConv(n_hidden, out_feats, dim, n_kernels))
self.pseudo_proj.add(nn.Dense(dim, in_units=2, activation='tanh')) self.pseudo_proj.add(nn.Dense(dim, in_units=2, activation="tanh"))
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
...@@ -48,8 +53,7 @@ class MoNet(nn.Block): ...@@ -48,8 +53,7 @@ class MoNet(nn.Block):
for i in range(len(self.layers)): for i in range(len(self.layers)):
if i > 0: if i > 0:
h = self.dropout(h) h = self.dropout(h)
h = self.layers[i]( h = self.layers[i](self.g, h, self.pseudo_proj[i](pseudo))
self.g, h, self.pseudo_proj[i](pseudo))
return h return h
...@@ -58,16 +62,17 @@ def evaluate(model, features, pseudo, labels, mask): ...@@ -58,16 +62,17 @@ def evaluate(model, features, pseudo, labels, mask):
accuracy = ((pred == labels) * mask).sum() / mask.sum().asscalar() accuracy = ((pred == labels) * mask).sum() / mask.sum().asscalar()
return accuracy.asscalar() return accuracy.asscalar()
def main(args): def main(args):
# load and preprocess dataset # load and preprocess dataset
if args.dataset == 'cora': if args.dataset == "cora":
data = CoraGraphDataset() data = CoraGraphDataset()
elif args.dataset == 'citeseer': elif args.dataset == "citeseer":
data = CiteseerGraphDataset() data = CiteseerGraphDataset()
elif args.dataset == 'pubmed': elif args.dataset == "pubmed":
data = PubmedGraphDataset() data = PubmedGraphDataset()
else: else:
raise ValueError('Unknown dataset: {}'.format(args.dataset)) raise ValueError("Unknown dataset: {}".format(args.dataset))
g = data[0] g = data[0]
if args.gpu < 0: if args.gpu < 0:
...@@ -78,24 +83,29 @@ def main(args): ...@@ -78,24 +83,29 @@ def main(args):
ctx = mx.gpu(args.gpu) ctx = mx.gpu(args.gpu)
g = g.to(ctx) g = g.to(ctx)
features = g.ndata['feat'] features = g.ndata["feat"]
labels = mx.nd.array(g.ndata['label'], dtype="float32", ctx=ctx) labels = mx.nd.array(g.ndata["label"], dtype="float32", ctx=ctx)
train_mask = g.ndata['train_mask'] train_mask = g.ndata["train_mask"]
val_mask = g.ndata['val_mask'] val_mask = g.ndata["val_mask"]
test_mask = g.ndata['test_mask'] test_mask = g.ndata["test_mask"]
in_feats = features.shape[1] in_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_labels
n_edges = data.graph.number_of_edges() n_edges = data.graph.number_of_edges()
print("""----Data statistics------' print(
"""----Data statistics------'
#Edges %d #Edges %d
#Classes %d #Classes %d
#Train samples %d #Train samples %d
#Val samples %d #Val samples %d
#Test samples %d""" % #Test samples %d"""
(n_edges, n_classes, % (
train_mask.sum().asscalar(), n_edges,
val_mask.sum().asscalar(), n_classes,
test_mask.sum().asscalar())) train_mask.sum().asscalar(),
val_mask.sum().asscalar(),
test_mask.sum().asscalar(),
)
)
# add self loop # add self loop
g = dgl.remove_self_loop(g) g = dgl.remove_self_loop(g)
...@@ -107,30 +117,32 @@ def main(args): ...@@ -107,30 +117,32 @@ def main(args):
vs = vs.asnumpy() vs = vs.asnumpy()
pseudo = [] pseudo = []
for i in range(g.number_of_edges()): for i in range(g.number_of_edges()):
pseudo.append([ pseudo.append(
1 / np.sqrt(g.in_degree(us[i])), [1 / np.sqrt(g.in_degree(us[i])), 1 / np.sqrt(g.in_degree(vs[i]))]
1 / np.sqrt(g.in_degree(vs[i])) )
])
pseudo = nd.array(pseudo, ctx=ctx) pseudo = nd.array(pseudo, ctx=ctx)
# create GraphSAGE model # create GraphSAGE model
model = MoNet(g, model = MoNet(
in_feats, g,
args.n_hidden, in_feats,
n_classes, args.n_hidden,
args.n_layers, n_classes,
args.pseudo_dim, args.n_layers,
args.n_kernels, args.pseudo_dim,
args.dropout args.n_kernels,
) args.dropout,
)
model.initialize(ctx=ctx) model.initialize(ctx=ctx)
n_train_samples = train_mask.sum().asscalar() n_train_samples = train_mask.sum().asscalar()
loss_fcn = gluon.loss.SoftmaxCELoss() loss_fcn = gluon.loss.SoftmaxCELoss()
print(model.collect_params()) print(model.collect_params())
trainer = gluon.Trainer(model.collect_params(), 'adam', trainer = gluon.Trainer(
{'learning_rate': args.lr, 'wd': args.weight_decay}) model.collect_params(),
"adam",
{"learning_rate": args.lr, "wd": args.weight_decay},
)
# initialize graph # initialize graph
dur = [] dur = []
...@@ -150,37 +162,55 @@ def main(args): ...@@ -150,37 +162,55 @@ def main(args):
loss.asscalar() loss.asscalar()
dur.append(time.time() - t0) dur.append(time.time() - t0)
acc = evaluate(model, features, pseudo, labels, val_mask) acc = evaluate(model, features, pseudo, labels, val_mask)
print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | " print(
"ETputs(KTEPS) {:.2f}". format( "Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
epoch, np.mean(dur), loss.asscalar(), acc, n_edges / np.mean(dur) / 1000)) "ETputs(KTEPS) {:.2f}".format(
epoch,
np.mean(dur),
loss.asscalar(),
acc,
n_edges / np.mean(dur) / 1000,
)
)
# test set accuracy # test set accuracy
acc = evaluate(model, features, pseudo, labels, test_mask) acc = evaluate(model, features, pseudo, labels, test_mask)
print("Test accuracy {:.2%}".format(acc)) print("Test accuracy {:.2%}".format(acc))
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser(description='MoNet on citation network') parser = argparse.ArgumentParser(description="MoNet on citation network")
register_data_args(parser) register_data_args(parser)
parser.add_argument("--dropout", type=float, default=0.5, parser.add_argument(
help="dropout probability") "--dropout", type=float, default=0.5, help="dropout probability"
parser.add_argument("--gpu", type=int, default=-1, )
help="gpu") parser.add_argument("--gpu", type=int, default=-1, help="gpu")
parser.add_argument("--lr", type=float, default=1e-2, parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
help="learning rate") parser.add_argument(
parser.add_argument("--n-epochs", type=int, default=200, "--n-epochs", type=int, default=200, help="number of training epochs"
help="number of training epochs") )
parser.add_argument("--n-hidden", type=int, default=16, parser.add_argument(
help="number of hidden gcn units") "--n-hidden", type=int, default=16, help="number of hidden gcn units"
parser.add_argument("--n-layers", type=int, default=1, )
help="number of hidden gcn layers") parser.add_argument(
parser.add_argument("--pseudo-dim", type=int, default=2, "--n-layers", type=int, default=1, help="number of hidden gcn layers"
help="Pseudo coordinate dimensions in GMMConv, 2 for cora and 3 for pubmed") )
parser.add_argument("--n-kernels", type=int, default=3, parser.add_argument(
help="Number of kernels in GMMConv layer") "--pseudo-dim",
parser.add_argument("--weight-decay", type=float, default=5e-5, type=int,
help="Weight for L2 loss") default=2,
help="Pseudo coordinate dimensions in GMMConv, 2 for cora and 3 for pubmed",
)
parser.add_argument(
"--n-kernels",
type=int,
default=3,
help="Number of kernels in GMMConv layer",
)
parser.add_argument(
"--weight-decay", type=float, default=5e-5, help="Weight for L2 loss"
)
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
main(args) main(args)
\ No newline at end of file
import mxnet as mx import mxnet as mx
from mxnet import gluon from mxnet import gluon
class BaseRGCN(gluon.Block): class BaseRGCN(gluon.Block):
def __init__(self, num_nodes, h_dim, out_dim, num_rels, num_bases=-1, def __init__(
num_hidden_layers=1, dropout=0, self,
use_self_loop=False, gpu_id=-1): num_nodes,
h_dim,
out_dim,
num_rels,
num_bases=-1,
num_hidden_layers=1,
dropout=0,
use_self_loop=False,
gpu_id=-1,
):
super(BaseRGCN, self).__init__() super(BaseRGCN, self).__init__()
self.num_nodes = num_nodes self.num_nodes = num_nodes
self.h_dim = h_dim self.h_dim = h_dim
......
from .dataloader import *
from .object import * from .object import *
from .relation import * from .relation import *
from .dataloader import *
"""DataLoader utils.""" """DataLoader utils."""
import dgl
from mxnet import nd
from gluoncv.data.batchify import Pad from gluoncv.data.batchify import Pad
from mxnet import nd
import dgl
def dgl_mp_batchify_fn(data): def dgl_mp_batchify_fn(data):
if isinstance(data[0], tuple): if isinstance(data[0], tuple):
data = zip(*data) data = zip(*data)
return [dgl_mp_batchify_fn(i) for i in data] return [dgl_mp_batchify_fn(i) for i in data]
for dt in data: for dt in data:
if dt is not None: if dt is not None:
if isinstance(dt, dgl.DGLGraph): if isinstance(dt, dgl.DGLGraph):
......
"""Pascal VOC object detection dataset.""" """Pascal VOC object detection dataset."""
from __future__ import absolute_import from __future__ import absolute_import, division
from __future__ import division
import os
import logging
import warnings
import json import json
import logging
import os
import pickle import pickle
import numpy as np import warnings
from collections import Counter
import mxnet as mx import mxnet as mx
import numpy as np
from gluoncv.data import COCODetection from gluoncv.data import COCODetection
from collections import Counter
class VGObject(COCODetection): class VGObject(COCODetection):
CLASSES = ["airplane", "animal", "arm", "bag", "banana", "basket", "beach", CLASSES = [
"bear", "bed", "bench", "bike", "bird", "board", "boat", "book", "airplane",
"boot", "bottle", "bowl", "box", "boy", "branch", "building", "bus", "animal",
"cabinet", "cap", "car", "cat", "chair", "child", "clock", "coat", "arm",
"counter", "cow", "cup", "curtain", "desk", "dog", "door", "drawer", "bag",
"ear", "elephant", "engine", "eye", "face", "fence", "finger", "flag", "banana",
"flower", "food", "fork", "fruit", "giraffe", "girl", "glass", "glove", "basket",
"guy", "hair", "hand", "handle", "hat", "head", "helmet", "hill", "beach",
"horse", "house", "jacket", "jean", "kid", "kite", "lady", "lamp", "bear",
"laptop", "leaf", "leg", "letter", "light", "logo", "man", "men", "bed",
"motorcycle", "mountain", "mouth", "neck", "nose", "number", "orange", "bench",
"pant", "paper", "paw", "people", "person", "phone", "pillow", "pizza", "bike",
"plane", "plant", "plate", "player", "pole", "post", "pot", "racket", "bird",
"railing", "rock", "roof", "room", "screen", "seat", "sheep", "shelf", "board",
"shirt", "shoe", "short", "sidewalk", "sign", "sink", "skateboard", "boat",
"ski", "skier", "sneaker", "snow", "sock", "stand", "street", "book",
"surfboard", "table", "tail", "tie", "tile", "tire", "toilet", "boot",
"towel", "tower", "track", "train", "tree", "truck", "trunk", "bottle",
"umbrella", "vase", "vegetable", "vehicle", "wave", "wheel", "bowl",
"window", "windshield", "wing", "wire", "woman", "zebra"] "box",
"boy",
"branch",
"building",
"bus",
"cabinet",
"cap",
"car",
"cat",
"chair",
"child",
"clock",
"coat",
"counter",
"cow",
"cup",
"curtain",
"desk",
"dog",
"door",
"drawer",
"ear",
"elephant",
"engine",
"eye",
"face",
"fence",
"finger",
"flag",
"flower",
"food",
"fork",
"fruit",
"giraffe",
"girl",
"glass",
"glove",
"guy",
"hair",
"hand",
"handle",
"hat",
"head",
"helmet",
"hill",
"horse",
"house",
"jacket",
"jean",
"kid",
"kite",
"lady",
"lamp",
"laptop",
"leaf",
"leg",
"letter",
"light",
"logo",
"man",
"men",
"motorcycle",
"mountain",
"mouth",
"neck",
"nose",
"number",
"orange",
"pant",
"paper",
"paw",
"people",
"person",
"phone",
"pillow",
"pizza",
"plane",
"plant",
"plate",
"player",
"pole",
"post",
"pot",
"racket",
"railing",
"rock",
"roof",
"room",
"screen",
"seat",
"sheep",
"shelf",
"shirt",
"shoe",
"short",
"sidewalk",
"sign",
"sink",
"skateboard",
"ski",
"skier",
"sneaker",
"snow",
"sock",
"stand",
"street",
"surfboard",
"table",
"tail",
"tie",
"tile",
"tire",
"toilet",
"towel",
"tower",
"track",
"train",
"tree",
"truck",
"trunk",
"umbrella",
"vase",
"vegetable",
"vehicle",
"wave",
"wheel",
"window",
"windshield",
"wing",
"wire",
"woman",
"zebra",
]
def __init__(self, **kwargs): def __init__(self, **kwargs):
super(VGObject, self).__init__(**kwargs) super(VGObject, self).__init__(**kwargs)
@property @property
def annotation_dir(self): def annotation_dir(self):
return '' return ""
def _parse_image_path(self, entry): def _parse_image_path(self, entry):
dirname = 'VG_100K' dirname = "VG_100K"
filename = entry['file_name'] filename = entry["file_name"]
abs_path = os.path.join(self._root, dirname, filename) abs_path = os.path.join(self._root, dirname, filename)
return abs_path return abs_path
"""Prepare Visual Genome datasets""" """Prepare Visual Genome datasets"""
import argparse
import json
import os import os
import pickle
import random
import shutil import shutil
import argparse
import zipfile import zipfile
import random
import json
import tqdm import tqdm
import pickle
from gluoncv.utils import download, makedirs from gluoncv.utils import download, makedirs
_TARGET_DIR = os.path.expanduser('~/.mxnet/datasets/visualgenome') _TARGET_DIR = os.path.expanduser("~/.mxnet/datasets/visualgenome")
def parse_args(): def parse_args():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Initialize Visual Genome dataset.', description="Initialize Visual Genome dataset.",
epilog='Example: python visualgenome.py --download-dir ~/visualgenome', epilog="Example: python visualgenome.py --download-dir ~/visualgenome",
formatter_class=argparse.ArgumentDefaultsHelpFormatter) formatter_class=argparse.ArgumentDefaultsHelpFormatter,
parser.add_argument('--download-dir', type=str, default='~/visualgenome/', )
help='dataset directory on disk') parser.add_argument(
parser.add_argument('--no-download', action='store_true', help='disable automatic download if set') "--download-dir",
parser.add_argument('--overwrite', action='store_true', help='overwrite downloaded files if set, in case they are corrupted') type=str,
default="~/visualgenome/",
help="dataset directory on disk",
)
parser.add_argument(
"--no-download",
action="store_true",
help="disable automatic download if set",
)
parser.add_argument(
"--overwrite",
action="store_true",
help="overwrite downloaded files if set, in case they are corrupted",
)
args = parser.parse_args() args = parser.parse_args()
return args return args
def download_vg(path, overwrite=False): def download_vg(path, overwrite=False):
_DOWNLOAD_URLS = [ _DOWNLOAD_URLS = [
('https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip', (
'a055367f675dd5476220e9b93e4ca9957b024b94'), "https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip",
('https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip', "a055367f675dd5476220e9b93e4ca9957b024b94",
'2add3aab77623549e92b7f15cda0308f50b64ecf'), ),
(
"https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip",
"2add3aab77623549e92b7f15cda0308f50b64ecf",
),
] ]
makedirs(path) makedirs(path)
for url, checksum in _DOWNLOAD_URLS: for url, checksum in _DOWNLOAD_URLS:
filename = download(url, path=path, overwrite=overwrite, sha1_hash=checksum) filename = download(
url, path=path, overwrite=overwrite, sha1_hash=checksum
)
# extract # extract
if filename.endswith('zip'): if filename.endswith("zip"):
with zipfile.ZipFile(filename) as zf: with zipfile.ZipFile(filename) as zf:
zf.extractall(path=path) zf.extractall(path=path)
# move all images into folder `VG_100K` # move all images into folder `VG_100K`
vg_100k_path = os.path.join(path, 'VG_100K') vg_100k_path = os.path.join(path, "VG_100K")
vg_100k_2_path = os.path.join(path, 'VG_100K_2') vg_100k_2_path = os.path.join(path, "VG_100K_2")
files_2 = os.listdir(vg_100k_2_path) files_2 = os.listdir(vg_100k_2_path)
for fl in files_2: for fl in files_2:
shutil.move(os.path.join(vg_100k_2_path, fl), shutil.move(
os.path.join(vg_100k_path, fl)) os.path.join(vg_100k_2_path, fl), os.path.join(vg_100k_path, fl)
)
def download_json(path, overwrite=False): def download_json(path, overwrite=False):
url = 'https://data.dgl.ai/dataset/vg.zip' url = "https://data.dgl.ai/dataset/vg.zip"
output = 'vg.zip' output = "vg.zip"
download(url, path=path) download(url, path=path)
with zipfile.ZipFile(output) as zf: with zipfile.ZipFile(output) as zf:
zf.extractall(path=path) zf.extractall(path=path)
json_path = os.path.join(path, 'vg') json_path = os.path.join(path, "vg")
json_files = os.listdir(json_path) json_files = os.listdir(json_path)
for fl in json_files: for fl in json_files:
shutil.move(os.path.join(json_path, fl), shutil.move(os.path.join(json_path, fl), os.path.join(path, fl))
os.path.join(path, fl))
os.rmdir(json_path) os.rmdir(json_path)
if __name__ == '__main__':
if __name__ == "__main__":
args = parse_args() args = parse_args()
path = os.path.expanduser(args.download_dir) path = os.path.expanduser(args.download_dir)
if not os.path.isdir(path): if not os.path.isdir(path):
if args.no_download: if args.no_download:
raise ValueError(('{} is not a valid directory, make sure it is present.' raise ValueError(
' Or you should not disable "--no-download" to grab it'.format(path))) (
"{} is not a valid directory, make sure it is present."
' Or you should not disable "--no-download" to grab it'.format(
path
)
)
)
else: else:
download_vg(path, overwrite=args.overwrite) download_vg(path, overwrite=args.overwrite)
download_json(path, overwrite=args.overwrite) download_json(path, overwrite=args.overwrite)
# make symlink # make symlink
makedirs(os.path.expanduser('~/.mxnet/datasets')) makedirs(os.path.expanduser("~/.mxnet/datasets"))
if os.path.isdir(_TARGET_DIR): if os.path.isdir(_TARGET_DIR):
os.rmdir(_TARGET_DIR) os.rmdir(_TARGET_DIR)
os.symlink(path, _TARGET_DIR) os.symlink(path, _TARGET_DIR)
"""Pascal VOC object detection dataset.""" """Pascal VOC object detection dataset."""
from __future__ import absolute_import from __future__ import absolute_import, division
from __future__ import division
import os
import logging
import warnings
import json import json
import dgl import logging
import os
import pickle import pickle
import numpy as np import warnings
from collections import Counter
import mxnet as mx import mxnet as mx
import numpy as np
from gluoncv.data.base import VisionDataset from gluoncv.data.base import VisionDataset
from collections import Counter from gluoncv.data.transforms.presets.rcnn import (
from gluoncv.data.transforms.presets.rcnn import FasterRCNNDefaultTrainTransform, FasterRCNNDefaultValTransform FasterRCNNDefaultTrainTransform, FasterRCNNDefaultValTransform)
import dgl
class VGRelation(VisionDataset): class VGRelation(VisionDataset):
def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'visualgenome'), split='train'): def __init__(
self,
root=os.path.join("~", ".mxnet", "datasets", "visualgenome"),
split="train",
):
super(VGRelation, self).__init__(root) super(VGRelation, self).__init__(root)
self._root = os.path.expanduser(root) self._root = os.path.expanduser(root)
self._img_path = os.path.join(self._root, 'VG_100K', '{}') self._img_path = os.path.join(self._root, "VG_100K", "{}")
if split == 'train': if split == "train":
self._dict_path = os.path.join(self._root, 'rel_annotations_train.json') self._dict_path = os.path.join(
elif split == 'val': self._root, "rel_annotations_train.json"
self._dict_path = os.path.join(self._root, 'rel_annotations_val.json') )
elif split == "val":
self._dict_path = os.path.join(
self._root, "rel_annotations_val.json"
)
else: else:
raise NotImplementedError raise NotImplementedError
with open(self._dict_path) as f: with open(self._dict_path) as f:
tmp = f.read() tmp = f.read()
self._dict = json.loads(tmp) self._dict = json.loads(tmp)
self._predicates_path = os.path.join(self._root, 'predicates.json') self._predicates_path = os.path.join(self._root, "predicates.json")
with open(self._predicates_path, 'r') as f: with open(self._predicates_path, "r") as f:
tmp = f.read() tmp = f.read()
self.rel_classes = json.loads(tmp) self.rel_classes = json.loads(tmp)
self.num_rel_classes = len(self.rel_classes) + 1 self.num_rel_classes = len(self.rel_classes) + 1
self._objects_path = os.path.join(self._root, 'objects.json') self._objects_path = os.path.join(self._root, "objects.json")
with open(self._objects_path, 'r') as f: with open(self._objects_path, "r") as f:
tmp = f.read() tmp = f.read()
self.obj_classes = json.loads(tmp) self.obj_classes = json.loads(tmp)
self.num_obj_classes = len(self.obj_classes) self.num_obj_classes = len(self.obj_classes)
if split == 'val': if split == "val":
self.img_transform = FasterRCNNDefaultValTransform(short=600, max_size=1000) self.img_transform = FasterRCNNDefaultValTransform(
short=600, max_size=1000
)
else: else:
self.img_transform = FasterRCNNDefaultTrainTransform(short=600, max_size=1000) self.img_transform = FasterRCNNDefaultTrainTransform(
short=600, max_size=1000
)
self.split = split self.split = split
def __len__(self): def __len__(self):
return len(self._dict) return len(self._dict)
def _hash_bbox(self, object): def _hash_bbox(self, object):
num_list = [object['category']] + object['bbox'] num_list = [object["category"]] + object["bbox"]
return '_'.join([str(num) for num in num_list]) return "_".join([str(num) for num in num_list])
def __getitem__(self, idx): def __getitem__(self, idx):
img_id = list(self._dict)[idx] img_id = list(self._dict)[idx]
...@@ -66,8 +82,8 @@ class VGRelation(VisionDataset): ...@@ -66,8 +82,8 @@ class VGRelation(VisionDataset):
sub_node_hash = [] sub_node_hash = []
ob_node_hash = [] ob_node_hash = []
for i, it in enumerate(item): for i, it in enumerate(item):
sub_node_hash.append(self._hash_bbox(it['subject'])) sub_node_hash.append(self._hash_bbox(it["subject"]))
ob_node_hash.append(self._hash_bbox(it['object'])) ob_node_hash.append(self._hash_bbox(it["object"]))
node_set = sorted(list(set(sub_node_hash + ob_node_hash))) node_set = sorted(list(set(sub_node_hash + ob_node_hash)))
n_nodes = len(node_set) n_nodes = len(node_set)
node_to_id = {} node_to_id = {}
...@@ -86,35 +102,37 @@ class VGRelation(VisionDataset): ...@@ -86,35 +102,37 @@ class VGRelation(VisionDataset):
for i, it in enumerate(item): for i, it in enumerate(item):
if not node_visited[sub_id[i]]: if not node_visited[sub_id[i]]:
ind = sub_id[i] ind = sub_id[i]
sub = it['subject'] sub = it["subject"]
node_class_ids[ind] = sub['category'] node_class_ids[ind] = sub["category"]
# y1y2x1x2 to x1y1x2y2 # y1y2x1x2 to x1y1x2y2
bbox[ind,0] = sub['bbox'][2] bbox[ind, 0] = sub["bbox"][2]
bbox[ind,1] = sub['bbox'][0] bbox[ind, 1] = sub["bbox"][0]
bbox[ind,2] = sub['bbox'][3] bbox[ind, 2] = sub["bbox"][3]
bbox[ind,3] = sub['bbox'][1] bbox[ind, 3] = sub["bbox"][1]
node_visited[ind] = True node_visited[ind] = True
if not node_visited[ob_id[i]]: if not node_visited[ob_id[i]]:
ind = ob_id[i] ind = ob_id[i]
ob = it['object'] ob = it["object"]
node_class_ids[ind] = ob['category'] node_class_ids[ind] = ob["category"]
# y1y2x1x2 to x1y1x2y2 # y1y2x1x2 to x1y1x2y2
bbox[ind,0] = ob['bbox'][2] bbox[ind, 0] = ob["bbox"][2]
bbox[ind,1] = ob['bbox'][0] bbox[ind, 1] = ob["bbox"][0]
bbox[ind,2] = ob['bbox'][3] bbox[ind, 2] = ob["bbox"][3]
bbox[ind,3] = ob['bbox'][1] bbox[ind, 3] = ob["bbox"][1]
node_visited[ind] = True node_visited[ind] = True
eta = 0.1 eta = 0.1
node_class_vec = node_class_ids[:,0].one_hot(self.num_obj_classes, node_class_vec = node_class_ids[:, 0].one_hot(
on_value = 1 - eta + eta / self.num_obj_classes, self.num_obj_classes,
off_value = eta / self.num_obj_classes) on_value=1 - eta + eta / self.num_obj_classes,
off_value=eta / self.num_obj_classes,
)
# augmentation # augmentation
if self.split == 'val': if self.split == "val":
img, bbox, _ = self.img_transform(img, bbox) img, bbox, _ = self.img_transform(img, bbox)
else: else:
img, bbox = self.img_transform(img, bbox) img, bbox = self.img_transform(img, bbox)
...@@ -126,9 +144,9 @@ class VGRelation(VisionDataset): ...@@ -126,9 +144,9 @@ class VGRelation(VisionDataset):
predicate = [] predicate = []
for i, it in enumerate(item): for i, it in enumerate(item):
adjmat[sub_id[i], ob_id[i]] = 1 adjmat[sub_id[i], ob_id[i]] = 1
predicate.append(it['predicate']) predicate.append(it["predicate"])
predicate = mx.nd.array(predicate).expand_dims(1) predicate = mx.nd.array(predicate).expand_dims(1)
g.add_edges(sub_id, ob_id, {'rel_class': mx.nd.array(predicate) + 1}) g.add_edges(sub_id, ob_id, {"rel_class": mx.nd.array(predicate) + 1})
empty_edge_list = [] empty_edge_list = []
for i in range(n_nodes): for i in range(n_nodes):
for j in range(n_nodes): for j in range(n_nodes):
...@@ -136,11 +154,13 @@ class VGRelation(VisionDataset): ...@@ -136,11 +154,13 @@ class VGRelation(VisionDataset):
empty_edge_list.append((i, j)) empty_edge_list.append((i, j))
if len(empty_edge_list) > 0: if len(empty_edge_list) > 0:
src, dst = tuple(zip(*empty_edge_list)) src, dst = tuple(zip(*empty_edge_list))
g.add_edges(src, dst, {'rel_class': mx.nd.zeros((len(empty_edge_list), 1))}) g.add_edges(
src, dst, {"rel_class": mx.nd.zeros((len(empty_edge_list), 1))}
)
# assign features # assign features
g.ndata['bbox'] = bbox g.ndata["bbox"] = bbox
g.ndata['node_class'] = node_class_ids g.ndata["node_class"] = node_class_ids
g.ndata['node_class_vec'] = node_class_vec g.ndata["node_class_vec"] = node_class_vec
return g, img return g, img
import dgl
import argparse import argparse
import mxnet as mx
import gluoncv as gcv import gluoncv as gcv
from gluoncv.utilz import download import mxnet as mx
from data import *
from gluoncv.data.transforms import presets from gluoncv.data.transforms import presets
from model import faster_rcnn_resnet101_v1d_custom, RelDN from gluoncv.utilz import download
from model import RelDN, faster_rcnn_resnet101_v1d_custom
from utils import * from utils import *
from data import *
import dgl
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description='Demo of Scene Graph Extraction.') parser = argparse.ArgumentParser(
parser.add_argument('--image', type=str, default='', description="Demo of Scene Graph Extraction."
help="The image for scene graph extraction.") )
parser.add_argument('--gpu', type=str, default='', parser.add_argument(
help="GPU id to use for inference, default is not using GPU.") "--image",
parser.add_argument('--pretrained-faster-rcnn-params', type=str, default='', type=str,
help="Path to saved Faster R-CNN model parameters.") default="",
parser.add_argument('--reldn-params', type=str, default='', help="The image for scene graph extraction.",
help="Path to saved Faster R-CNN model parameters.") )
parser.add_argument('--faster-rcnn-params', type=str, default='', parser.add_argument(
help="Path to saved Faster R-CNN model parameters.") "--gpu",
parser.add_argument('--freq-prior', type=str, default='freq_prior.pkl', type=str,
help="Path to saved frequency prior data.") default="",
help="GPU id to use for inference, default is not using GPU.",
)
parser.add_argument(
"--pretrained-faster-rcnn-params",
type=str,
default="",
help="Path to saved Faster R-CNN model parameters.",
)
parser.add_argument(
"--reldn-params",
type=str,
default="",
help="Path to saved Faster R-CNN model parameters.",
)
parser.add_argument(
"--faster-rcnn-params",
type=str,
default="",
help="Path to saved Faster R-CNN model parameters.",
)
parser.add_argument(
"--freq-prior",
type=str,
default="freq_prior.pkl",
help="Path to saved frequency prior data.",
)
args = parser.parse_args() args = parser.parse_args()
return args return args
args = parse_args() args = parse_args()
if args.gpu: if args.gpu:
ctx = mx.gpu(int(args.gpu)) ctx = mx.gpu(int(args.gpu))
...@@ -32,31 +62,47 @@ else: ...@@ -32,31 +62,47 @@ else:
ctx = mx.cpu() ctx = mx.cpu()
net = RelDN(n_classes=50, prior_pkl=args.freq_prior, semantic_only=False) net = RelDN(n_classes=50, prior_pkl=args.freq_prior, semantic_only=False)
if args.reldn_params == '': if args.reldn_params == "":
download('http://data.dgl.ai/models/SceneGraph/reldn.params') download("http://data.dgl.ai/models/SceneGraph/reldn.params")
net.load_parameters('rendl.params', ctx=ctx) net.load_parameters("rendl.params", ctx=ctx)
else: else:
net.load_parameters(args.reldn_params, ctx=ctx) net.load_parameters(args.reldn_params, ctx=ctx)
# dataset and dataloader # dataset and dataloader
vg_val = VGRelation(split='val') vg_val = VGRelation(split="val")
detector = faster_rcnn_resnet101_v1d_custom(classes=vg_val.obj_classes, detector = faster_rcnn_resnet101_v1d_custom(
pretrained_base=False, pretrained=False, classes=vg_val.obj_classes,
additional_output=True) pretrained_base=False,
if args.pretrained_faster_rcnn_params == '': pretrained=False,
download('http://data.dgl.ai/models/SceneGraph/faster_rcnn_resnet101_v1d_visualgenome.params') additional_output=True,
params_path = 'faster_rcnn_resnet101_v1d_visualgenome.params' )
if args.pretrained_faster_rcnn_params == "":
download(
"http://data.dgl.ai/models/SceneGraph/faster_rcnn_resnet101_v1d_visualgenome.params"
)
params_path = "faster_rcnn_resnet101_v1d_visualgenome.params"
else: else:
params_path = args.pretrained_faster_rcnn_params params_path = args.pretrained_faster_rcnn_params
detector.load_parameters(params_path, ctx=ctx, ignore_extra=True, allow_missing=True) detector.load_parameters(
params_path, ctx=ctx, ignore_extra=True, allow_missing=True
)
detector_feat = faster_rcnn_resnet101_v1d_custom(classes=vg_val.obj_classes, detector_feat = faster_rcnn_resnet101_v1d_custom(
pretrained_base=False, pretrained=False, classes=vg_val.obj_classes,
additional_output=True) pretrained_base=False,
detector_feat.load_parameters(params_path, ctx=ctx, ignore_extra=True, allow_missing=True) pretrained=False,
if args.faster_rcnn_params == '': additional_output=True,
download('http://data.dgl.ai/models/SceneGraph/faster_rcnn_resnet101_v1d_visualgenome.params') )
detector_feat.features.load_parameters('faster_rcnn_resnet101_v1d_visualgenome.params', ctx=ctx) detector_feat.load_parameters(
params_path, ctx=ctx, ignore_extra=True, allow_missing=True
)
if args.faster_rcnn_params == "":
download(
"http://data.dgl.ai/models/SceneGraph/faster_rcnn_resnet101_v1d_visualgenome.params"
)
detector_feat.features.load_parameters(
"faster_rcnn_resnet101_v1d_visualgenome.params", ctx=ctx
)
else: else:
detector_feat.features.load_parameters(args.faster_rcnn_params, ctx=ctx) detector_feat.features.load_parameters(args.faster_rcnn_params, ctx=ctx)
...@@ -64,24 +110,37 @@ else: ...@@ -64,24 +110,37 @@ else:
if args.image: if args.image:
image_path = args.image image_path = args.image
else: else:
gcv.utils.download('https://raw.githubusercontent.com/dmlc/web-data/master/' + gcv.utils.download(
'dgl/examples/mxnet/scenegraph/old-couple.png', "https://raw.githubusercontent.com/dmlc/web-data/master/"
'old-couple.png') + "dgl/examples/mxnet/scenegraph/old-couple.png",
image_path = 'old-couple.png' "old-couple.png",
x, img = presets.rcnn.load_test(args.image, short=detector.short, max_size=detector.max_size) )
image_path = "old-couple.png"
x, img = presets.rcnn.load_test(
args.image, short=detector.short, max_size=detector.max_size
)
x = x.as_in_context(ctx) x = x.as_in_context(ctx)
# detector prediction # detector prediction
ids, scores, bboxes, feat, feat_ind, spatial_feat = detector(x) ids, scores, bboxes, feat, feat_ind, spatial_feat = detector(x)
# build graph, extract edge features # build graph, extract edge features
g = build_graph_validate_pred(x, ids, scores, bboxes, feat_ind, spatial_feat, bbox_improvement=True, scores_top_k=75, overlap=False) g = build_graph_validate_pred(
rel_bbox = g.edata['rel_bbox'].expand_dims(0).as_in_context(ctx) x,
ids,
scores,
bboxes,
feat_ind,
spatial_feat,
bbox_improvement=True,
scores_top_k=75,
overlap=False,
)
rel_bbox = g.edata["rel_bbox"].expand_dims(0).as_in_context(ctx)
_, _, _, spatial_feat_rel = detector_feat(x, None, None, rel_bbox) _, _, _, spatial_feat_rel = detector_feat(x, None, None, rel_bbox)
g.edata['edge_feat'] = spatial_feat_rel[0] g.edata["edge_feat"] = spatial_feat_rel[0]
# graph prediction # graph prediction
g = net(g) g = net(g)
_, preds = extract_pred(g, joint_preds=True) _, preds = extract_pred(g, joint_preds=True)
preds = preds[preds[:,1].argsort()[::-1]] preds = preds[preds[:, 1].argsort()[::-1]]
plot_sg(img, preds, detector.classes, vg_val.rel_classes, 10) plot_sg(img, preds, detector.classes, vg_val.rel_classes, 10)
import dgl import pickle
import gluoncv as gcv import gluoncv as gcv
import mxnet as mx import mxnet as mx
import numpy as np import numpy as np
from mxnet import nd from mxnet import nd
from mxnet.gluon import nn from mxnet.gluon import nn
from dgl.utils import toindex
import pickle
import dgl
from dgl.nn.mxnet import GraphConv from dgl.nn.mxnet import GraphConv
from dgl.utils import toindex
__all__ = ["RelDN"]
__all__ = ['RelDN']
class EdgeConfMLP(nn.Block): class EdgeConfMLP(nn.Block):
'''compute the confidence for edges''' """compute the confidence for edges"""
def __init__(self): def __init__(self):
super(EdgeConfMLP, self).__init__() super(EdgeConfMLP, self).__init__()
def forward(self, edges): def forward(self, edges):
score_pred = nd.log_softmax(edges.data['preds'])[:,1:].max(axis=1) score_pred = nd.log_softmax(edges.data["preds"])[:, 1:].max(axis=1)
score_phr = score_pred + edges.src['node_class_logit'] + edges.dst['node_class_logit'] score_phr = (
return {'score_pred': score_pred, score_pred
'score_phr': score_phr} + edges.src["node_class_logit"]
+ edges.dst["node_class_logit"]
)
return {"score_pred": score_pred, "score_phr": score_phr}
class EdgeBBoxExtend(nn.Block): class EdgeBBoxExtend(nn.Block):
'''encode the bounding boxes''' """encode the bounding boxes"""
def __init__(self): def __init__(self):
super(EdgeBBoxExtend, self).__init__() super(EdgeBBoxExtend, self).__init__()
def bbox_delta(self, bbox_a, bbox_b): def bbox_delta(self, bbox_a, bbox_b):
n = bbox_a.shape[0] n = bbox_a.shape[0]
result = nd.zeros((n, 4), ctx=bbox_a.context) result = nd.zeros((n, 4), ctx=bbox_a.context)
result[:,0] = bbox_a[:,0] - bbox_b[:,0] result[:, 0] = bbox_a[:, 0] - bbox_b[:, 0]
result[:,1] = bbox_a[:,1] - bbox_b[:,1] result[:, 1] = bbox_a[:, 1] - bbox_b[:, 1]
result[:,2] = nd.log((bbox_a[:,2] - bbox_a[:,0] + 1e-8) / (bbox_b[:,2] - bbox_b[:,0] + 1e-8)) result[:, 2] = nd.log(
result[:,3] = nd.log((bbox_a[:,3] - bbox_a[:,1] + 1e-8) / (bbox_b[:,3] - bbox_b[:,1] + 1e-8)) (bbox_a[:, 2] - bbox_a[:, 0] + 1e-8)
/ (bbox_b[:, 2] - bbox_b[:, 0] + 1e-8)
)
result[:, 3] = nd.log(
(bbox_a[:, 3] - bbox_a[:, 1] + 1e-8)
/ (bbox_b[:, 3] - bbox_b[:, 1] + 1e-8)
)
return result return result
def forward(self, edges): def forward(self, edges):
ctx = edges.src['pred_bbox'].context ctx = edges.src["pred_bbox"].context
n = edges.src['pred_bbox'].shape[0] n = edges.src["pred_bbox"].shape[0]
delta_src_obj = self.bbox_delta(edges.src['pred_bbox'], edges.dst['pred_bbox']) delta_src_obj = self.bbox_delta(
delta_src_rel = self.bbox_delta(edges.src['pred_bbox'], edges.data['rel_bbox']) edges.src["pred_bbox"], edges.dst["pred_bbox"]
delta_rel_obj = self.bbox_delta(edges.data['rel_bbox'], edges.dst['pred_bbox']) )
delta_src_rel = self.bbox_delta(
edges.src["pred_bbox"], edges.data["rel_bbox"]
)
delta_rel_obj = self.bbox_delta(
edges.data["rel_bbox"], edges.dst["pred_bbox"]
)
result = nd.zeros((n, 12), ctx=ctx) result = nd.zeros((n, 12), ctx=ctx)
result[:,0:4] = delta_src_obj result[:, 0:4] = delta_src_obj
result[:,4:8] = delta_src_rel result[:, 4:8] = delta_src_rel
result[:,8:12] = delta_rel_obj result[:, 8:12] = delta_rel_obj
return {'pred_bbox_additional': result} return {"pred_bbox_additional": result}
class EdgeFreqPrior(nn.Block): class EdgeFreqPrior(nn.Block):
'''make use of the pre-trained frequency prior''' """make use of the pre-trained frequency prior"""
def __init__(self, prior_pkl): def __init__(self, prior_pkl):
super(EdgeFreqPrior, self).__init__() super(EdgeFreqPrior, self).__init__()
with open(prior_pkl, 'rb') as f: with open(prior_pkl, "rb") as f:
freq_prior = pickle.load(f) freq_prior = pickle.load(f)
self.freq_prior = freq_prior self.freq_prior = freq_prior
def forward(self, edges): def forward(self, edges):
ctx = edges.src['node_class_pred'].context ctx = edges.src["node_class_pred"].context
src_ind = edges.src['node_class_pred'].asnumpy().astype(int) src_ind = edges.src["node_class_pred"].asnumpy().astype(int)
dst_ind = edges.dst['node_class_pred'].asnumpy().astype(int) dst_ind = edges.dst["node_class_pred"].asnumpy().astype(int)
prob = self.freq_prior[src_ind, dst_ind] prob = self.freq_prior[src_ind, dst_ind]
out = nd.array(prob, ctx=ctx) out = nd.array(prob, ctx=ctx)
return {'freq_prior': out} return {"freq_prior": out}
class EdgeSpatial(nn.Block): class EdgeSpatial(nn.Block):
'''spatial feature branch''' """spatial feature branch"""
def __init__(self, n_classes): def __init__(self, n_classes):
super(EdgeSpatial, self).__init__() super(EdgeSpatial, self).__init__()
self.mlp = nn.Sequential() self.mlp = nn.Sequential()
...@@ -76,14 +100,20 @@ class EdgeSpatial(nn.Block): ...@@ -76,14 +100,20 @@ class EdgeSpatial(nn.Block):
self.mlp.add(nn.Dense(n_classes)) self.mlp.add(nn.Dense(n_classes))
def forward(self, edges): def forward(self, edges):
feat = nd.concat(edges.src['pred_bbox'], edges.dst['pred_bbox'], feat = nd.concat(
edges.data['rel_bbox'], edges.data['pred_bbox_additional']) edges.src["pred_bbox"],
edges.dst["pred_bbox"],
edges.data["rel_bbox"],
edges.data["pred_bbox_additional"],
)
out = self.mlp(feat) out = self.mlp(feat)
return {'spatial': out} return {"spatial": out}
class EdgeVisual(nn.Block): class EdgeVisual(nn.Block):
'''visual feature branch''' """visual feature branch"""
def __init__(self, n_classes, vis_feat_dim=7*7*3):
def __init__(self, n_classes, vis_feat_dim=7 * 7 * 3):
super(EdgeVisual, self).__init__() super(EdgeVisual, self).__init__()
self.dim_in = vis_feat_dim self.dim_in = vis_feat_dim
self.mlp_joint = nn.Sequential() self.mlp_joint = nn.Sequential()
...@@ -97,15 +127,21 @@ class EdgeVisual(nn.Block): ...@@ -97,15 +127,21 @@ class EdgeVisual(nn.Block):
self.mlp_ob = nn.Dense(n_classes) self.mlp_ob = nn.Dense(n_classes)
def forward(self, edges): def forward(self, edges):
feat = nd.concat(edges.src['node_feat'], edges.dst['node_feat'], edges.data['edge_feat']) feat = nd.concat(
edges.src["node_feat"],
edges.dst["node_feat"],
edges.data["edge_feat"],
)
out_joint = self.mlp_joint(feat) out_joint = self.mlp_joint(feat)
out_sub = self.mlp_sub(edges.src['node_feat']) out_sub = self.mlp_sub(edges.src["node_feat"])
out_ob = self.mlp_ob(edges.dst['node_feat']) out_ob = self.mlp_ob(edges.dst["node_feat"])
out = out_joint + out_sub + out_ob out = out_joint + out_sub + out_ob
return {'visual': out} return {"visual": out}
class RelDN(nn.Block): class RelDN(nn.Block):
'''The RelDN Model''' """The RelDN Model"""
def __init__(self, n_classes, prior_pkl, semantic_only=False): def __init__(self, n_classes, prior_pkl, semantic_only=False):
super(RelDN, self).__init__() super(RelDN, self).__init__()
# output layers # output layers
...@@ -121,19 +157,21 @@ class RelDN(nn.Block): ...@@ -121,19 +157,21 @@ class RelDN(nn.Block):
self.edge_conf_mlp = EdgeConfMLP() self.edge_conf_mlp = EdgeConfMLP()
self.semantic_only = semantic_only self.semantic_only = semantic_only
def forward(self, g): def forward(self, g):
if g is None or g.number_of_nodes() == 0: if g is None or g.number_of_nodes() == 0:
return g return g
# predictions # predictions
g.apply_edges(self.freq_prior) g.apply_edges(self.freq_prior)
if self.semantic_only: if self.semantic_only:
g.edata['preds'] = g.edata['freq_prior'] g.edata["preds"] = g.edata["freq_prior"]
else: else:
# bbox extension # bbox extension
g.apply_edges(self.edge_bbox_extend) g.apply_edges(self.edge_bbox_extend)
g.apply_edges(self.spatial) g.apply_edges(self.spatial)
g.apply_edges(self.visual) g.apply_edges(self.visual)
g.edata['preds'] = g.edata['freq_prior'] + g.edata['spatial'] + g.edata['visual'] g.edata["preds"] = (
g.edata["freq_prior"] + g.edata["spatial"] + g.edata["visual"]
)
# subgraph for gconv # subgraph for gconv
g.apply_edges(self.edge_conf_mlp) g.apply_edges(self.edge_conf_mlp)
return g return g
import argparse
import json
import os
import pickle
import numpy as np import numpy as np
import json, pickle, os, argparse
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description='Train the Frequenct Prior For RelDN.') parser = argparse.ArgumentParser(
parser.add_argument('--overlap', action='store_true', description="Train the Frequenct Prior For RelDN."
help="Only count overlap boxes.") )
parser.add_argument('--json-path', type=str, default='~/.mxnet/datasets/visualgenome', parser.add_argument(
help="Only count overlap boxes.") "--overlap", action="store_true", help="Only count overlap boxes."
)
parser.add_argument(
"--json-path",
type=str,
default="~/.mxnet/datasets/visualgenome",
help="Only count overlap boxes.",
)
args = parser.parse_args() args = parser.parse_args()
return args return args
args = parse_args() args = parse_args()
use_overlap = args.overlap use_overlap = args.overlap
PATH_TO_DATASETS = os.path.expanduser(args.json_path) PATH_TO_DATASETS = os.path.expanduser(args.json_path)
path_to_json = os.path.join(PATH_TO_DATASETS, 'rel_annotations_train.json') path_to_json = os.path.join(PATH_TO_DATASETS, "rel_annotations_train.json")
# format in y1y2x1x2 # format in y1y2x1x2
def with_overlap(boxA, boxB): def with_overlap(boxA, boxB):
...@@ -29,17 +42,19 @@ def with_overlap(boxA, boxB): ...@@ -29,17 +42,19 @@ def with_overlap(boxA, boxB):
return 0 return 0
def box_ious(boxes): def box_ious(boxes):
n = len(boxes) n = len(boxes)
res = np.zeros((n, n)) res = np.zeros((n, n))
for i in range(n-1): for i in range(n - 1):
for j in range(i+1, n): for j in range(i + 1, n):
iou_val = with_overlap(boxes[i], boxes[j]) iou_val = with_overlap(boxes[i], boxes[j])
res[i, j] = iou_val res[i, j] = iou_val
res[j, i] = iou_val res[j, i] = iou_val
return res return res
with open(path_to_json, 'r') as f:
with open(path_to_json, "r") as f:
tmp = f.read() tmp = f.read()
train_data = json.loads(tmp) train_data = json.loads(tmp)
...@@ -49,11 +64,11 @@ bg_matrix = np.zeros((150, 150), dtype=np.int64) ...@@ -49,11 +64,11 @@ bg_matrix = np.zeros((150, 150), dtype=np.int64)
for _, item in train_data.items(): for _, item in train_data.items():
gt_box_to_label = {} gt_box_to_label = {}
for rel in item: for rel in item:
sub_bbox = rel['subject']['bbox'] sub_bbox = rel["subject"]["bbox"]
ob_bbox = rel['object']['bbox'] ob_bbox = rel["object"]["bbox"]
sub_class = rel['subject']['category'] sub_class = rel["subject"]["category"]
ob_class = rel['object']['category'] ob_class = rel["object"]["category"]
rel_class = rel['predicate'] rel_class = rel["predicate"]
sub_node = tuple(sub_bbox) sub_node = tuple(sub_bbox)
ob_node = tuple(ob_bbox) ob_node = tuple(ob_bbox)
...@@ -93,8 +108,8 @@ pred_dist = np.log(fg_matrix / (fg_matrix.sum(2)[:, :, None] + eps) + eps) ...@@ -93,8 +108,8 @@ pred_dist = np.log(fg_matrix / (fg_matrix.sum(2)[:, :, None] + eps) + eps)
if use_overlap: if use_overlap:
with open('freq_prior_overlap.pkl', 'wb') as f: with open("freq_prior_overlap.pkl", "wb") as f:
pickle.dump(pred_dist, f) pickle.dump(pred_dist, f)
else: else:
with open('freq_prior.pkl', 'wb') as f: with open("freq_prior.pkl", "wb") as f:
pickle.dump(pred_dist, f) pickle.dump(pred_dist, f)
This diff is collapsed.
from .metric import *
from .build_graph import * from .build_graph import *
from .metric import *
from .sampling import * from .sampling import *
from .viz import * from .viz import *
This diff is collapsed.
import dgl
from dgl.utils import toindex
import mxnet as mx import mxnet as mx
import numpy as np import numpy as np
import dgl
from dgl.utils import toindex
def l0_sample(g, positive_max=128, negative_ratio=3): def l0_sample(g, positive_max=128, negative_ratio=3):
'''sampling positive and negative edges''' """sampling positive and negative edges"""
if g is None: if g is None:
return None return None
n_eids = g.number_of_edges() n_eids = g.number_of_edges()
pos_eids = np.where(g.edata['rel_class'].asnumpy() > 0)[0] pos_eids = np.where(g.edata["rel_class"].asnumpy() > 0)[0]
neg_eids = np.where(g.edata['rel_class'].asnumpy() == 0)[0] neg_eids = np.where(g.edata["rel_class"].asnumpy() == 0)[0]
if len(pos_eids) == 0: if len(pos_eids) == 0:
return None return None
...@@ -26,6 +28,7 @@ def l0_sample(g, positive_max=128, negative_ratio=3): ...@@ -26,6 +28,7 @@ def l0_sample(g, positive_max=128, negative_ratio=3):
eids = np.where(weights > 0)[0] eids = np.where(weights > 0)[0]
sub_g = g.edge_subgraph(toindex(eids.tolist())) sub_g = g.edge_subgraph(toindex(eids.tolist()))
sub_g.copy_from_parent() sub_g.copy_from_parent()
sub_g.edata['sample_weights'] = mx.nd.array(weights[eids], sub_g.edata["sample_weights"] = mx.nd.array(
ctx=g.edata['rel_class'].context) weights[eids], ctx=g.edata["rel_class"].context
)
return sub_g return sub_g
import numpy as np
import gluoncv as gcv import gluoncv as gcv
import numpy as np
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
def plot_sg(img, preds, obj_classes, rel_classes, topk=1): def plot_sg(img, preds, obj_classes, rel_classes, topk=1):
'''visualization of generated scene graph''' """visualization of generated scene graph"""
size = img.shape[0:2] size = img.shape[0:2]
box_scale = np.array([size[1], size[0], size[1], size[0]]) box_scale = np.array([size[1], size[0], size[1], size[0]])
topk = min(topk, preds.shape[0]) topk = min(topk, preds.shape[0])
...@@ -17,30 +18,51 @@ def plot_sg(img, preds, obj_classes, rel_classes, topk=1): ...@@ -17,30 +18,51 @@ def plot_sg(img, preds, obj_classes, rel_classes, topk=1):
rel_name = rel_classes[rel] rel_name = rel_classes[rel]
src_bbox = preds[i, 5:9] * box_scale src_bbox = preds[i, 5:9] * box_scale
dst_bbox = preds[i, 9:13] * box_scale dst_bbox = preds[i, 9:13] * box_scale
src_center = np.array([(src_bbox[0] + src_bbox[2]) / 2, (src_bbox[1] + src_bbox[3]) / 2]) src_center = np.array(
dst_center = np.array([(dst_bbox[0] + dst_bbox[2]) / 2, (dst_bbox[1] + dst_bbox[3]) / 2]) [(src_bbox[0] + src_bbox[2]) / 2, (src_bbox[1] + src_bbox[3]) / 2]
)
dst_center = np.array(
[(dst_bbox[0] + dst_bbox[2]) / 2, (dst_bbox[1] + dst_bbox[3]) / 2]
)
rel_center = (src_center + dst_center) / 2 rel_center = (src_center + dst_center) / 2
line_x = np.array([(src_bbox[0] + src_bbox[2]) / 2, (dst_bbox[0] + dst_bbox[2]) / 2]) line_x = np.array(
line_y = np.array([(src_bbox[1] + src_bbox[3]) / 2, (dst_bbox[1] + dst_bbox[3]) / 2]) [(src_bbox[0] + src_bbox[2]) / 2, (dst_bbox[0] + dst_bbox[2]) / 2]
)
ax.plot(line_x, line_y, line_y = np.array(
linewidth=3.0, alpha=0.7, color=plt.cm.cool(rel)) [(src_bbox[1] + src_bbox[3]) / 2, (dst_bbox[1] + dst_bbox[3]) / 2]
)
ax.text(src_center[0], src_center[1],
'{:s}'.format(src_name), ax.plot(
bbox=dict(alpha=0.5), line_x, line_y, linewidth=3.0, alpha=0.7, color=plt.cm.cool(rel)
fontsize=12, color='white') )
ax.text(dst_center[0], dst_center[1],
'{:s}'.format(dst_name), ax.text(
bbox=dict(alpha=0.5), src_center[0],
fontsize=12, color='white') src_center[1],
ax.text(rel_center[0], rel_center[1], "{:s}".format(src_name),
'{:s}'.format(rel_name), bbox=dict(alpha=0.5),
bbox=dict(alpha=0.5), fontsize=12,
fontsize=12, color='white') color="white",
)
ax.text(
dst_center[0],
dst_center[1],
"{:s}".format(dst_name),
bbox=dict(alpha=0.5),
fontsize=12,
color="white",
)
ax.text(
rel_center[0],
rel_center[1],
"{:s}".format(rel_name),
bbox=dict(alpha=0.5),
fontsize=12,
color="white",
)
return ax return ax
plot_sg(img, preds, 2)
plot_sg(img, preds, 2)
This diff is collapsed.
This diff is collapsed.
...@@ -6,18 +6,15 @@ References: ...@@ -6,18 +6,15 @@ References:
""" """
import mxnet as mx import mxnet as mx
from mxnet import gluon from mxnet import gluon
import dgl import dgl
from dgl.nn.mxnet import TAGConv from dgl.nn.mxnet import TAGConv
class TAGCN(gluon.Block): class TAGCN(gluon.Block):
def __init__(self, def __init__(
g, self, g, in_feats, n_hidden, n_classes, n_layers, activation, dropout
in_feats, ):
n_hidden,
n_classes,
n_layers,
activation,
dropout):
super(TAGCN, self).__init__() super(TAGCN, self).__init__()
self.g = g self.g = g
self.layers = gluon.nn.Sequential() self.layers = gluon.nn.Sequential()
...@@ -27,7 +24,7 @@ class TAGCN(gluon.Block): ...@@ -27,7 +24,7 @@ class TAGCN(gluon.Block):
for i in range(n_layers - 1): for i in range(n_layers - 1):
self.layers.add(TAGConv(n_hidden, n_hidden, activation=activation)) self.layers.add(TAGConv(n_hidden, n_hidden, activation=activation))
# output layer # output layer
self.layers.add(TAGConv(n_hidden, n_classes)) #activation=None self.layers.add(TAGConv(n_hidden, n_classes)) # activation=None
self.dropout = gluon.nn.Dropout(rate=dropout) self.dropout = gluon.nn.Dropout(rate=dropout)
def forward(self, features): def forward(self, features):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment