Unverified Commit a9f2acf3 authored by Hongzhi (Steve) Chen and committed by GitHub

[Misc] Black auto fix. (#4641)



* [Misc] Black auto fix.

* sort
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 08c50eb7
import argparse
import time
import mxnet as mx
import networkx as nx
import numpy as np
from mxnet import gluon, nd
from mxnet.gluon import nn
import dgl
from dgl.data import (CiteseerGraphDataset, CoraGraphDataset,
                      PubmedGraphDataset, register_data_args)
from dgl.nn.mxnet.conv import GMMConv
class MoNet(nn.Block):
    def __init__(
        self,
g,
in_feats,
n_hidden,
@@ -20,7 +23,8 @@ class MoNet(nn.Block):
n_layers,
dim,
n_kernels,
        dropout,
    ):
super(MoNet, self).__init__()
self.g = g
with self.name_scope():
@@ -28,18 +32,19 @@ class MoNet(nn.Block):
self.pseudo_proj = nn.Sequential()
# Input layer
            self.layers.add(GMMConv(in_feats, n_hidden, dim, n_kernels))
            self.pseudo_proj.add(nn.Dense(dim, in_units=2, activation="tanh"))
# Hidden layer
for _ in range(n_layers - 1):
self.layers.add(GMMConv(n_hidden, n_hidden, dim, n_kernels))
                self.pseudo_proj.add(
                    nn.Dense(dim, in_units=2, activation="tanh")
                )
# Output layer
self.layers.add(GMMConv(n_hidden, out_feats, dim, n_kernels))
            self.pseudo_proj.add(nn.Dense(dim, in_units=2, activation="tanh"))
self.dropout = nn.Dropout(dropout)
@@ -48,8 +53,7 @@ class MoNet(nn.Block):
for i in range(len(self.layers)):
if i > 0:
h = self.dropout(h)
            h = self.layers[i](self.g, h, self.pseudo_proj[i](pseudo))
return h
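# Hedged usage sketch (not part of the original commit): builds a tiny MoNet
# and runs one forward pass. The toy graph, the random inputs, and the
# positional argument order (g, in_feats, n_hidden, out_feats, n_layers, dim,
# n_kernels, dropout) are assumptions for illustration only.
def _monet_smoke_test():
    g = dgl.DGLGraph()
    g.add_nodes(3)
    g.add_edges([0, 1, 2], [1, 2, 0])
    feats = nd.random.uniform(shape=(3, 5))
    # one 2-d pseudo-coordinate per edge, as computed in main() below
    pseudo = nd.random.uniform(shape=(g.number_of_edges(), 2))
    net = MoNet(g, 5, 8, 3, 1, 2, 3, 0.5)
    net.initialize()
    return net(feats, pseudo)  # logits of shape (3, 3)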
@@ -58,16 +62,17 @@ def evaluate(model, features, pseudo, labels, mask):
accuracy = ((pred == labels) * mask).sum() / mask.sum().asscalar()
return accuracy.asscalar()
def main(args):
# load and preprocess dataset
    if args.dataset == "cora":
data = CoraGraphDataset()
    elif args.dataset == "citeseer":
data = CiteseerGraphDataset()
    elif args.dataset == "pubmed":
data = PubmedGraphDataset()
else:
        raise ValueError("Unknown dataset: {}".format(args.dataset))
g = data[0]
if args.gpu < 0:
@@ -78,24 +83,29 @@ def main(args):
ctx = mx.gpu(args.gpu)
g = g.to(ctx)
    features = g.ndata["feat"]
    labels = mx.nd.array(g.ndata["label"], dtype="float32", ctx=ctx)
    train_mask = g.ndata["train_mask"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
print("""----Data statistics------'
print(
"""----Data statistics------'
#Edges %d
#Classes %d
#Train samples %d
#Val samples %d
#Test samples %d""" %
(n_edges, n_classes,
#Test samples %d"""
% (
n_edges,
n_classes,
train_mask.sum().asscalar(),
val_mask.sum().asscalar(),
test_mask.sum().asscalar()))
test_mask.sum().asscalar(),
)
)
# add self loop
g = dgl.remove_self_loop(g)
@@ -107,30 +117,32 @@ def main(args):
vs = vs.asnumpy()
pseudo = []
for i in range(g.number_of_edges()):
        pseudo.append(
            [1 / np.sqrt(g.in_degree(us[i])), 1 / np.sqrt(g.in_degree(vs[i]))]
        )
pseudo = nd.array(pseudo, ctx=ctx)
    # create MoNet model
    model = MoNet(
        g,
in_feats,
args.n_hidden,
n_classes,
args.n_layers,
args.pseudo_dim,
args.n_kernels,
        args.dropout,
)
model.initialize(ctx=ctx)
n_train_samples = train_mask.sum().asscalar()
loss_fcn = gluon.loss.SoftmaxCELoss()
print(model.collect_params())
    trainer = gluon.Trainer(
        model.collect_params(),
        "adam",
        {"learning_rate": args.lr, "wd": args.weight_decay},
    )
# initialize graph
dur = []
@@ -150,36 +162,54 @@ def main(args):
loss.asscalar()
dur.append(time.time() - t0)
acc = evaluate(model, features, pseudo, labels, val_mask)
print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
"ETputs(KTEPS) {:.2f}". format(
epoch, np.mean(dur), loss.asscalar(), acc, n_edges / np.mean(dur) / 1000))
print(
"Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
"ETputs(KTEPS) {:.2f}".format(
epoch,
np.mean(dur),
loss.asscalar(),
acc,
n_edges / np.mean(dur) / 1000,
)
)
# test set accuracy
acc = evaluate(model, features, pseudo, labels, test_mask)
print("Test accuracy {:.2%}".format(acc))
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="MoNet on citation network")
register_data_args(parser)
parser.add_argument("--dropout", type=float, default=0.5,
help="dropout probability")
parser.add_argument("--gpu", type=int, default=-1,
help="gpu")
parser.add_argument("--lr", type=float, default=1e-2,
help="learning rate")
parser.add_argument("--n-epochs", type=int, default=200,
help="number of training epochs")
parser.add_argument("--n-hidden", type=int, default=16,
help="number of hidden gcn units")
parser.add_argument("--n-layers", type=int, default=1,
help="number of hidden gcn layers")
parser.add_argument("--pseudo-dim", type=int, default=2,
help="Pseudo coordinate dimensions in GMMConv, 2 for cora and 3 for pubmed")
parser.add_argument("--n-kernels", type=int, default=3,
help="Number of kernels in GMMConv layer")
parser.add_argument("--weight-decay", type=float, default=5e-5,
help="Weight for L2 loss")
parser.add_argument(
"--dropout", type=float, default=0.5, help="dropout probability"
)
parser.add_argument("--gpu", type=int, default=-1, help="gpu")
parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
parser.add_argument(
"--n-epochs", type=int, default=200, help="number of training epochs"
)
parser.add_argument(
"--n-hidden", type=int, default=16, help="number of hidden gcn units"
)
parser.add_argument(
"--n-layers", type=int, default=1, help="number of hidden gcn layers"
)
parser.add_argument(
"--pseudo-dim",
type=int,
default=2,
help="Pseudo coordinate dimensions in GMMConv, 2 for cora and 3 for pubmed",
)
parser.add_argument(
"--n-kernels",
type=int,
default=3,
help="Number of kernels in GMMConv layer",
)
parser.add_argument(
"--weight-decay", type=float, default=5e-5, help="Weight for L2 loss"
)
args = parser.parse_args()
print(args)
......
import mxnet as mx
from mxnet import gluon
class BaseRGCN(gluon.Block):
def __init__(
self,
num_nodes,
h_dim,
out_dim,
num_rels,
num_bases=-1,
num_hidden_layers=1,
dropout=0,
use_self_loop=False,
gpu_id=-1,
):
super(BaseRGCN, self).__init__()
self.num_nodes = num_nodes
self.h_dim = h_dim
......
from .dataloader import *
from .object import *
from .relation import *
"""DataLoader utils."""
from gluoncv.data.batchify import Pad
from mxnet import nd
import dgl
def dgl_mp_batchify_fn(data):
if isinstance(data[0], tuple):
......
"""Pascal VOC object detection dataset."""
from __future__ import absolute_import, division
import json
import logging
import os
import pickle
import warnings
from collections import Counter
import mxnet as mx
import numpy as np
from gluoncv.data import COCODetection
class VGObject(COCODetection):
CLASSES = ["airplane", "animal", "arm", "bag", "banana", "basket", "beach",
"bear", "bed", "bench", "bike", "bird", "board", "boat", "book",
"boot", "bottle", "bowl", "box", "boy", "branch", "building", "bus",
"cabinet", "cap", "car", "cat", "chair", "child", "clock", "coat",
"counter", "cow", "cup", "curtain", "desk", "dog", "door", "drawer",
"ear", "elephant", "engine", "eye", "face", "fence", "finger", "flag",
"flower", "food", "fork", "fruit", "giraffe", "girl", "glass", "glove",
"guy", "hair", "hand", "handle", "hat", "head", "helmet", "hill",
"horse", "house", "jacket", "jean", "kid", "kite", "lady", "lamp",
"laptop", "leaf", "leg", "letter", "light", "logo", "man", "men",
"motorcycle", "mountain", "mouth", "neck", "nose", "number", "orange",
"pant", "paper", "paw", "people", "person", "phone", "pillow", "pizza",
"plane", "plant", "plate", "player", "pole", "post", "pot", "racket",
"railing", "rock", "roof", "room", "screen", "seat", "sheep", "shelf",
"shirt", "shoe", "short", "sidewalk", "sign", "sink", "skateboard",
"ski", "skier", "sneaker", "snow", "sock", "stand", "street",
"surfboard", "table", "tail", "tie", "tile", "tire", "toilet",
"towel", "tower", "track", "train", "tree", "truck", "trunk",
"umbrella", "vase", "vegetable", "vehicle", "wave", "wheel",
"window", "windshield", "wing", "wire", "woman", "zebra"]
CLASSES = [
"airplane",
"animal",
"arm",
"bag",
"banana",
"basket",
"beach",
"bear",
"bed",
"bench",
"bike",
"bird",
"board",
"boat",
"book",
"boot",
"bottle",
"bowl",
"box",
"boy",
"branch",
"building",
"bus",
"cabinet",
"cap",
"car",
"cat",
"chair",
"child",
"clock",
"coat",
"counter",
"cow",
"cup",
"curtain",
"desk",
"dog",
"door",
"drawer",
"ear",
"elephant",
"engine",
"eye",
"face",
"fence",
"finger",
"flag",
"flower",
"food",
"fork",
"fruit",
"giraffe",
"girl",
"glass",
"glove",
"guy",
"hair",
"hand",
"handle",
"hat",
"head",
"helmet",
"hill",
"horse",
"house",
"jacket",
"jean",
"kid",
"kite",
"lady",
"lamp",
"laptop",
"leaf",
"leg",
"letter",
"light",
"logo",
"man",
"men",
"motorcycle",
"mountain",
"mouth",
"neck",
"nose",
"number",
"orange",
"pant",
"paper",
"paw",
"people",
"person",
"phone",
"pillow",
"pizza",
"plane",
"plant",
"plate",
"player",
"pole",
"post",
"pot",
"racket",
"railing",
"rock",
"roof",
"room",
"screen",
"seat",
"sheep",
"shelf",
"shirt",
"shoe",
"short",
"sidewalk",
"sign",
"sink",
"skateboard",
"ski",
"skier",
"sneaker",
"snow",
"sock",
"stand",
"street",
"surfboard",
"table",
"tail",
"tie",
"tile",
"tire",
"toilet",
"towel",
"tower",
"track",
"train",
"tree",
"truck",
"trunk",
"umbrella",
"vase",
"vegetable",
"vehicle",
"wave",
"wheel",
"window",
"windshield",
"wing",
"wire",
"woman",
"zebra",
]
def __init__(self, **kwargs):
super(VGObject, self).__init__(**kwargs)
@property
def annotation_dir(self):
        return ""
def _parse_image_path(self, entry):
        dirname = "VG_100K"
        filename = entry["file_name"]
abs_path = os.path.join(self._root, dirname, filename)
return abs_path
"""Prepare Visual Genome datasets"""
import argparse
import json
import os
import pickle
import random
import shutil
import zipfile
import tqdm
from gluoncv.utils import download, makedirs
_TARGET_DIR = os.path.expanduser("~/.mxnet/datasets/visualgenome")
def parse_args():
parser = argparse.ArgumentParser(
description="Initialize Visual Genome dataset.",
epilog="Example: python visualgenome.py --download-dir ~/visualgenome",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--download-dir",
type=str,
default="~/visualgenome/",
help="dataset directory on disk",
)
parser.add_argument(
"--no-download",
action="store_true",
help="disable automatic download if set",
)
parser.add_argument(
"--overwrite",
action="store_true",
help="overwrite downloaded files if set, in case they are corrupted",
)
args = parser.parse_args()
return args
def download_vg(path, overwrite=False):
_DOWNLOAD_URLS = [
(
"https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip",
"a055367f675dd5476220e9b93e4ca9957b024b94",
),
(
"https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip",
"2add3aab77623549e92b7f15cda0308f50b64ecf",
),
]
makedirs(path)
for url, checksum in _DOWNLOAD_URLS:
        filename = download(
            url, path=path, overwrite=overwrite, sha1_hash=checksum
        )
# extract
        if filename.endswith("zip"):
with zipfile.ZipFile(filename) as zf:
zf.extractall(path=path)
# move all images into folder `VG_100K`
    vg_100k_path = os.path.join(path, "VG_100K")
    vg_100k_2_path = os.path.join(path, "VG_100K_2")
files_2 = os.listdir(vg_100k_2_path)
for fl in files_2:
        shutil.move(
            os.path.join(vg_100k_2_path, fl), os.path.join(vg_100k_path, fl)
        )
def download_json(path, overwrite=False):
    url = "https://data.dgl.ai/dataset/vg.zip"
    output = "vg.zip"
    download(url, path=path, overwrite=overwrite)
with zipfile.ZipFile(output) as zf:
zf.extractall(path=path)
    json_path = os.path.join(path, "vg")
json_files = os.listdir(json_path)
for fl in json_files:
        shutil.move(os.path.join(json_path, fl), os.path.join(path, fl))
os.rmdir(json_path)
if __name__ == "__main__":
args = parse_args()
path = os.path.expanduser(args.download_dir)
if not os.path.isdir(path):
if args.no_download:
            raise ValueError(
                "{} is not a valid directory, make sure it is present."
                ' Or do not set "--no-download" so it can be fetched'.format(
                    path
                )
            )
else:
download_vg(path, overwrite=args.overwrite)
download_json(path, overwrite=args.overwrite)
# make symlink
    makedirs(os.path.expanduser("~/.mxnet/datasets"))
if os.path.isdir(_TARGET_DIR):
os.rmdir(_TARGET_DIR)
os.symlink(path, _TARGET_DIR)
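# Hedged note (not part of the original commit): os.rmdir above only removes
# an *empty* directory at _TARGET_DIR (and fails otherwise), so an existing,
# populated dataset directory there has to be moved aside before the symlink
# can be created.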
"""Pascal VOC object detection dataset."""
from __future__ import absolute_import, division
import json
import logging
import os
import pickle
import warnings
from collections import Counter
import mxnet as mx
import numpy as np
from gluoncv.data.base import VisionDataset
from gluoncv.data.transforms.presets.rcnn import (
    FasterRCNNDefaultTrainTransform, FasterRCNNDefaultValTransform)
import dgl
class VGRelation(VisionDataset):
    def __init__(
        self,
        root=os.path.join("~", ".mxnet", "datasets", "visualgenome"),
        split="train",
    ):
super(VGRelation, self).__init__(root)
self._root = os.path.expanduser(root)
        self._img_path = os.path.join(self._root, "VG_100K", "{}")
        if split == "train":
            self._dict_path = os.path.join(
                self._root, "rel_annotations_train.json"
            )
        elif split == "val":
            self._dict_path = os.path.join(
                self._root, "rel_annotations_val.json"
            )
else:
raise NotImplementedError
with open(self._dict_path) as f:
tmp = f.read()
self._dict = json.loads(tmp)
        self._predicates_path = os.path.join(self._root, "predicates.json")
        with open(self._predicates_path, "r") as f:
tmp = f.read()
self.rel_classes = json.loads(tmp)
self.num_rel_classes = len(self.rel_classes) + 1
        self._objects_path = os.path.join(self._root, "objects.json")
        with open(self._objects_path, "r") as f:
tmp = f.read()
self.obj_classes = json.loads(tmp)
self.num_obj_classes = len(self.obj_classes)
        if split == "val":
            self.img_transform = FasterRCNNDefaultValTransform(
                short=600, max_size=1000
            )
        else:
            self.img_transform = FasterRCNNDefaultTrainTransform(
                short=600, max_size=1000
            )
self.split = split
def __len__(self):
return len(self._dict)
    def _hash_bbox(self, obj):
        num_list = [obj["category"]] + obj["bbox"]
        return "_".join([str(num) for num in num_list])
def __getitem__(self, idx):
img_id = list(self._dict)[idx]
@@ -66,8 +82,8 @@ class VGRelation(VisionDataset):
sub_node_hash = []
ob_node_hash = []
for i, it in enumerate(item):
            sub_node_hash.append(self._hash_bbox(it["subject"]))
            ob_node_hash.append(self._hash_bbox(it["object"]))
node_set = sorted(list(set(sub_node_hash + ob_node_hash)))
n_nodes = len(node_set)
node_to_id = {}
@@ -86,35 +102,37 @@ class VGRelation(VisionDataset):
for i, it in enumerate(item):
if not node_visited[sub_id[i]]:
ind = sub_id[i]
                sub = it["subject"]
                node_class_ids[ind] = sub["category"]
                # y1y2x1x2 to x1y1x2y2
                bbox[ind, 0] = sub["bbox"][2]
                bbox[ind, 1] = sub["bbox"][0]
                bbox[ind, 2] = sub["bbox"][3]
                bbox[ind, 3] = sub["bbox"][1]
node_visited[ind] = True
if not node_visited[ob_id[i]]:
ind = ob_id[i]
                ob = it["object"]
                node_class_ids[ind] = ob["category"]
                # y1y2x1x2 to x1y1x2y2
                bbox[ind, 0] = ob["bbox"][2]
                bbox[ind, 1] = ob["bbox"][0]
                bbox[ind, 2] = ob["bbox"][3]
                bbox[ind, 3] = ob["bbox"][1]
node_visited[ind] = True
eta = 0.1
        node_class_vec = node_class_ids[:, 0].one_hot(
            self.num_obj_classes,
            on_value=1 - eta + eta / self.num_obj_classes,
            off_value=eta / self.num_obj_classes,
        )
# augmentation
        if self.split == "val":
img, bbox, _ = self.img_transform(img, bbox)
else:
img, bbox = self.img_transform(img, bbox)
@@ -126,9 +144,9 @@ class VGRelation(VisionDataset):
predicate = []
for i, it in enumerate(item):
adjmat[sub_id[i], ob_id[i]] = 1
            predicate.append(it["predicate"])
predicate = mx.nd.array(predicate).expand_dims(1)
        g.add_edges(sub_id, ob_id, {"rel_class": mx.nd.array(predicate) + 1})
empty_edge_list = []
for i in range(n_nodes):
for j in range(n_nodes):
@@ -136,11 +154,13 @@ class VGRelation(VisionDataset):
empty_edge_list.append((i, j))
if len(empty_edge_list) > 0:
src, dst = tuple(zip(*empty_edge_list))
            g.add_edges(
                src, dst, {"rel_class": mx.nd.zeros((len(empty_edge_list), 1))}
            )
# assign features
        g.ndata["bbox"] = bbox
        g.ndata["node_class"] = node_class_ids
        g.ndata["node_class_vec"] = node_class_vec
return g, img
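# Hedged usage sketch (not part of the original commit): each item is a
# (graph, transformed image) pair; the split name and index below are
# illustrative, and the prepared dataset must already be on disk.
def _vg_relation_example():
    dataset = VGRelation(split="val")
    g, img = dataset[0]
    # node data: "bbox", "node_class", "node_class_vec"; edge data: "rel_class"
    return g.number_of_nodes(), g.number_of_edges(), img.shape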
import argparse
import gluoncv as gcv
import mxnet as mx
from gluoncv.data.transforms import presets
from gluoncv.utils import download
import dgl
from data import *
from model import RelDN, faster_rcnn_resnet101_v1d_custom
from utils import *
def parse_args():
parser = argparse.ArgumentParser(
description="Demo of Scene Graph Extraction."
)
parser.add_argument(
"--image",
type=str,
default="",
help="The image for scene graph extraction.",
)
parser.add_argument(
"--gpu",
type=str,
default="",
help="GPU id to use for inference, default is not using GPU.",
)
parser.add_argument(
"--pretrained-faster-rcnn-params",
type=str,
default="",
help="Path to saved Faster R-CNN model parameters.",
)
    parser.add_argument(
        "--reldn-params",
        type=str,
        default="",
        help="Path to saved RelDN model parameters.",
    )
parser.add_argument(
"--faster-rcnn-params",
type=str,
default="",
help="Path to saved Faster R-CNN model parameters.",
)
parser.add_argument(
"--freq-prior",
type=str,
default="freq_prior.pkl",
help="Path to saved frequency prior data.",
)
args = parser.parse_args()
return args
args = parse_args()
if args.gpu:
ctx = mx.gpu(int(args.gpu))
@@ -32,31 +62,47 @@ else:
ctx = mx.cpu()
net = RelDN(n_classes=50, prior_pkl=args.freq_prior, semantic_only=False)
if args.reldn_params == "":
    download("http://data.dgl.ai/models/SceneGraph/reldn.params")
    net.load_parameters("reldn.params", ctx=ctx)
else:
net.load_parameters(args.reldn_params, ctx=ctx)
# dataset and dataloader
vg_val = VGRelation(split="val")
detector = faster_rcnn_resnet101_v1d_custom(
classes=vg_val.obj_classes,
pretrained_base=False,
pretrained=False,
additional_output=True,
)
if args.pretrained_faster_rcnn_params == "":
download(
"http://data.dgl.ai/models/SceneGraph/faster_rcnn_resnet101_v1d_visualgenome.params"
)
params_path = "faster_rcnn_resnet101_v1d_visualgenome.params"
else:
params_path = args.pretrained_faster_rcnn_params
detector.load_parameters(
    params_path, ctx=ctx, ignore_extra=True, allow_missing=True
)
detector_feat = faster_rcnn_resnet101_v1d_custom(
classes=vg_val.obj_classes,
pretrained_base=False,
pretrained=False,
additional_output=True,
)
detector_feat.load_parameters(
params_path, ctx=ctx, ignore_extra=True, allow_missing=True
)
if args.faster_rcnn_params == "":
download(
"http://data.dgl.ai/models/SceneGraph/faster_rcnn_resnet101_v1d_visualgenome.params"
)
detector_feat.features.load_parameters(
"faster_rcnn_resnet101_v1d_visualgenome.params", ctx=ctx
)
else:
detector_feat.features.load_parameters(args.faster_rcnn_params, ctx=ctx)
@@ -64,24 +110,37 @@ else:
if args.image:
image_path = args.image
else:
gcv.utils.download(
"https://raw.githubusercontent.com/dmlc/web-data/master/"
+ "dgl/examples/mxnet/scenegraph/old-couple.png",
"old-couple.png",
)
image_path = "old-couple.png"
x, img = presets.rcnn.load_test(
    image_path, short=detector.short, max_size=detector.max_size
)
x = x.as_in_context(ctx)
# detector prediction
ids, scores, bboxes, feat, feat_ind, spatial_feat = detector(x)
# build graph, extract edge features
g = build_graph_validate_pred(
x,
ids,
scores,
bboxes,
feat_ind,
spatial_feat,
bbox_improvement=True,
scores_top_k=75,
overlap=False,
)
rel_bbox = g.edata["rel_bbox"].expand_dims(0).as_in_context(ctx)
_, _, _, spatial_feat_rel = detector_feat(x, None, None, rel_bbox)
g.edata["edge_feat"] = spatial_feat_rel[0]
# graph prediction
g = net(g)
_, preds = extract_pred(g, joint_preds=True)
preds = preds[preds[:, 1].argsort()[::-1]]
plot_sg(img, preds, detector.classes, vg_val.rel_classes, 10)
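# Hedged usage note (not part of the original commit): with no --image given,
# the script falls back to the downloaded old-couple.png sample, runs the
# detector, builds the edge graph, scores relations with RelDN, and plots up
# to 10 predictions (the final argument to plot_sg above).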
import pickle
import gluoncv as gcv
import mxnet as mx
import numpy as np
from mxnet import nd
from mxnet.gluon import nn
import dgl
from dgl.nn.mxnet import GraphConv
from dgl.utils import toindex
__all__ = ["RelDN"]
class EdgeConfMLP(nn.Block):
    """compute the confidence for edges"""
def __init__(self):
super(EdgeConfMLP, self).__init__()
def forward(self, edges):
        score_pred = nd.log_softmax(edges.data["preds"])[:, 1:].max(axis=1)
        score_phr = (
            score_pred
            + edges.src["node_class_logit"]
            + edges.dst["node_class_logit"]
        )
        return {"score_pred": score_pred, "score_phr": score_phr}
class EdgeBBoxExtend(nn.Block):
    """encode the bounding boxes"""
def __init__(self):
super(EdgeBBoxExtend, self).__init__()
def bbox_delta(self, bbox_a, bbox_b):
n = bbox_a.shape[0]
result = nd.zeros((n, 4), ctx=bbox_a.context)
        result[:, 0] = bbox_a[:, 0] - bbox_b[:, 0]
        result[:, 1] = bbox_a[:, 1] - bbox_b[:, 1]
        result[:, 2] = nd.log(
            (bbox_a[:, 2] - bbox_a[:, 0] + 1e-8)
            / (bbox_b[:, 2] - bbox_b[:, 0] + 1e-8)
        )
        result[:, 3] = nd.log(
            (bbox_a[:, 3] - bbox_a[:, 1] + 1e-8)
            / (bbox_b[:, 3] - bbox_b[:, 1] + 1e-8)
        )
return result
def forward(self, edges):
        ctx = edges.src["pred_bbox"].context
        n = edges.src["pred_bbox"].shape[0]
        delta_src_obj = self.bbox_delta(
            edges.src["pred_bbox"], edges.dst["pred_bbox"]
        )
        delta_src_rel = self.bbox_delta(
            edges.src["pred_bbox"], edges.data["rel_bbox"]
        )
        delta_rel_obj = self.bbox_delta(
            edges.data["rel_bbox"], edges.dst["pred_bbox"]
        )
result = nd.zeros((n, 12), ctx=ctx)
result[:, 0:4] = delta_src_obj
result[:, 4:8] = delta_src_rel
result[:, 8:12] = delta_rel_obj
return {"pred_bbox_additional": result}
class EdgeFreqPrior(nn.Block):
    """make use of the pre-trained frequency prior"""
def __init__(self, prior_pkl):
super(EdgeFreqPrior, self).__init__()
        with open(prior_pkl, "rb") as f:
freq_prior = pickle.load(f)
self.freq_prior = freq_prior
def forward(self, edges):
ctx = edges.src["node_class_pred"].context
src_ind = edges.src["node_class_pred"].asnumpy().astype(int)
dst_ind = edges.dst["node_class_pred"].asnumpy().astype(int)
prob = self.freq_prior[src_ind, dst_ind]
out = nd.array(prob, ctx=ctx)
        return {"freq_prior": out}
class EdgeSpatial(nn.Block):
    """spatial feature branch"""
def __init__(self, n_classes):
super(EdgeSpatial, self).__init__()
self.mlp = nn.Sequential()
@@ -76,14 +100,20 @@ class EdgeSpatial(nn.Block):
self.mlp.add(nn.Dense(n_classes))
def forward(self, edges):
feat = nd.concat(
edges.src["pred_bbox"],
edges.dst["pred_bbox"],
edges.data["rel_bbox"],
edges.data["pred_bbox_additional"],
)
out = self.mlp(feat)
        return {"spatial": out}
class EdgeVisual(nn.Block):
    """visual feature branch"""
    def __init__(self, n_classes, vis_feat_dim=7 * 7 * 3):
super(EdgeVisual, self).__init__()
self.dim_in = vis_feat_dim
self.mlp_joint = nn.Sequential()
@@ -97,15 +127,21 @@ class EdgeVisual(nn.Block):
self.mlp_ob = nn.Dense(n_classes)
def forward(self, edges):
feat = nd.concat(
edges.src["node_feat"],
edges.dst["node_feat"],
edges.data["edge_feat"],
)
out_joint = self.mlp_joint(feat)
out_sub = self.mlp_sub(edges.src["node_feat"])
out_ob = self.mlp_ob(edges.dst["node_feat"])
out = out_joint + out_sub + out_ob
        return {"visual": out}
class RelDN(nn.Block):
    """The RelDN Model"""
def __init__(self, n_classes, prior_pkl, semantic_only=False):
super(RelDN, self).__init__()
# output layers
@@ -127,13 +163,15 @@ class RelDN(nn.Block):
# predictions
g.apply_edges(self.freq_prior)
if self.semantic_only:
            g.edata["preds"] = g.edata["freq_prior"]
else:
# bbox extension
g.apply_edges(self.edge_bbox_extend)
g.apply_edges(self.spatial)
g.apply_edges(self.visual)
g.edata["preds"] = (
g.edata["freq_prior"] + g.edata["spatial"] + g.edata["visual"]
)
# subgraph for gconv
g.apply_edges(self.edge_conf_mlp)
return g
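# Hedged note (not part of the original commit): RelDN expects graphs whose
# nodes carry "pred_bbox", "node_feat", "node_class_pred" and
# "node_class_logit", and whose edges carry "rel_bbox" (plus "edge_feat"
# unless semantic_only=True); the build_graph_* helpers later in this diff
# assemble such graphs.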
import argparse
import json
import os
import pickle
import numpy as np
def parse_args():
    parser = argparse.ArgumentParser(
        description="Train the Frequency Prior For RelDN."
    )
    parser.add_argument(
        "--overlap", action="store_true", help="Only count overlap boxes."
    )
    parser.add_argument(
        "--json-path",
        type=str,
        default="~/.mxnet/datasets/visualgenome",
        help="Path to the Visual Genome annotation json files.",
    )
args = parser.parse_args()
return args
args = parse_args()
use_overlap = args.overlap
PATH_TO_DATASETS = os.path.expanduser(args.json_path)
path_to_json = os.path.join(PATH_TO_DATASETS, "rel_annotations_train.json")
# format in y1y2x1x2
def with_overlap(boxA, boxB):
@@ -29,17 +42,19 @@ def with_overlap(boxA, boxB):
return 0
def box_ious(boxes):
n = len(boxes)
res = np.zeros((n, n))
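    # only the strict upper triangle is computed, then mirrored: the overlap
    # test is symmetric in its two boxes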
for i in range(n - 1):
for j in range(i + 1, n):
iou_val = with_overlap(boxes[i], boxes[j])
res[i, j] = iou_val
res[j, i] = iou_val
return res
with open(path_to_json, "r") as f:
tmp = f.read()
train_data = json.loads(tmp)
@@ -49,11 +64,11 @@ bg_matrix = np.zeros((150, 150), dtype=np.int64)
for _, item in train_data.items():
gt_box_to_label = {}
for rel in item:
sub_bbox = rel["subject"]["bbox"]
ob_bbox = rel["object"]["bbox"]
sub_class = rel["subject"]["category"]
ob_class = rel["object"]["category"]
rel_class = rel["predicate"]
sub_node = tuple(sub_bbox)
ob_node = tuple(ob_bbox)
@@ -93,8 +108,8 @@ pred_dist = np.log(fg_matrix / (fg_matrix.sum(2)[:, :, None] + eps) + eps)
if use_overlap:
    with open("freq_prior_overlap.pkl", "wb") as f:
pickle.dump(pred_dist, f)
else:
    with open("freq_prior.pkl", "wb") as f:
pickle.dump(pred_dist, f)
from .build_graph import *
from .metric import *
from .sampling import *
from .viz import *
import numpy as np
from mxnet import nd
import dgl
def bbox_improve(bbox):
    """append the box area as an extra bbox coordinate"""
    area = (bbox[:, 2] - bbox[:, 0]) * (bbox[:, 3] - bbox[:, 1])
return nd.concat(bbox, area.expand_dims(1))
def extract_edge_bbox(g):
    """compute the union bbox covering both endpoint boxes of each edge"""
    src, dst = g.edges(order="eid")
n = g.number_of_edges()
src_bbox = g.ndata["pred_bbox"][src.asnumpy()]
dst_bbox = g.ndata["pred_bbox"][dst.asnumpy()]
edge_bbox = nd.zeros((n, 4), ctx=g.ndata["pred_bbox"].context)
edge_bbox[:, 0] = nd.stack(src_bbox[:, 0], dst_bbox[:, 0]).min(axis=0)
edge_bbox[:, 1] = nd.stack(src_bbox[:, 1], dst_bbox[:, 1]).min(axis=0)
edge_bbox[:, 2] = nd.stack(src_bbox[:, 2], dst_bbox[:, 2]).max(axis=0)
edge_bbox[:, 3] = nd.stack(src_bbox[:, 3], dst_bbox[:, 3]).max(axis=0)
return edge_bbox
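# Hedged illustration (not part of the original commit): the union box of an
# edge spans the extreme corners of its two endpoint boxes. Coordinates below
# are made up for the example.
def _edge_bbox_example():
    g = dgl.DGLGraph()
    g.add_nodes(2)
    g.add_edges([0], [1])
    g.ndata["pred_bbox"] = nd.array(
        [[0.0, 0.0, 0.2, 0.2], [0.5, 0.5, 0.9, 0.9]]
    )
    return extract_edge_bbox(g)  # [[0.0, 0.0, 0.9, 0.9]]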
def build_graph_train(
g_slice,
gt_bbox,
img,
ids,
scores,
bbox,
feat_ind,
spatial_feat,
iou_thresh=0.5,
bbox_improvement=True,
scores_top_k=50,
overlap=False,
):
"""given ground truth and predicted bboxes, assign the label to the predicted w.r.t iou_thresh"""
# match and re-factor the graph
img_size = img.shape[2:4]
gt_bbox[:, :, 0] /= img_size[1]
@@ -39,24 +54,33 @@ def build_graph_train(g_slice, gt_bbox, img, ids, scores, bbox, feat_ind,
g_pred_batch = []
for gi in range(n_graph):
g = g_slice[gi]
        ctx = g.ndata["bbox"].context
inds = np.where(scores[gi, :, 0].asnumpy() > 0)[0].tolist()
if len(inds) == 0:
return None
if len(inds) > scores_top_k:
top_score_inds = (
scores[gi, inds, 0].asnumpy().argsort()[::-1][0:scores_top_k]
)
inds = np.array(inds)[top_score_inds].tolist()
n_nodes = len(inds)
roi_ind = feat_ind[gi, inds].squeeze(axis=1)
g_pred = dgl.DGLGraph()
g_pred.add_nodes(
n_nodes,
{
"pred_bbox": bbox[gi, inds],
"node_feat": spatial_feat[gi, roi_ind],
"node_class_pred": ids[gi, inds, 0],
"node_class_logit": nd.log(scores[gi, inds, 0] + 1e-7),
},
)
# iou matching
ious = nd.contrib.box_iou(
gt_bbox[gi], g_pred.ndata["pred_bbox"]
).asnumpy()
H, W = ious.shape
h = H
w = W
@@ -70,8 +94,8 @@ def build_graph_train(g_slice, gt_bbox, img, ids, scores, bbox, feat_ind,
if ious[row_ind, col_ind] < iou_thresh:
break
pred_to_gt_ind[col_ind] = row_ind
gt_node_class = g.ndata["node_class"][row_ind]
pred_node_class = g_pred.ndata["node_class_pred"][col_ind]
if gt_node_class == pred_node_class:
pred_to_gt_class_match[col_ind] = 1
pred_to_gt_class_match_id[col_ind] = row_ind
@@ -84,7 +108,7 @@ def build_graph_train(g_slice, gt_bbox, img, ids, scores, bbox, feat_ind,
triplet = []
adjmat = np.zeros((n_nodes, n_nodes))
        src, dst = g.all_edges(order="eid")
eid_keys = np.column_stack([src.asnumpy(), dst.asnumpy()])
eid_dict = {}
for i, key in enumerate(eid_keys):
@@ -93,7 +117,7 @@ def build_graph_train(g_slice, gt_bbox, img, ids, scores, bbox, feat_ind,
eid_dict[k] = [i]
else:
eid_dict[k].append(i)
        ori_rel_class = g.edata["rel_class"].asnumpy()
for i in range(n_nodes):
for j in range(n_nodes):
if i != j:
@@ -105,25 +129,27 @@ def build_graph_train(g_slice, gt_bbox, img, ids, scores, bbox, feat_ind,
n_edges_between = len(rel_cls)
for ii in range(n_edges_between):
triplet.append((i, j, rel_cls[ii]))
                        adjmat[i, j] = 1
else:
triplet.append((i, j, 0))
src, dst, rel_class = tuple(zip(*triplet))
rel_class = nd.array(rel_class, ctx=ctx).expand_dims(1)
        g_pred.add_edges(src, dst, data={"rel_class": rel_class})
# other operations
n_nodes = g_pred.number_of_nodes()
n_edges = g_pred.number_of_edges()
if bbox_improvement:
g_pred.ndata["pred_bbox"] = bbox_improve(g_pred.ndata["pred_bbox"])
g_pred.edata["rel_bbox"] = extract_edge_bbox(g_pred)
g_pred.edata["batch_id"] = nd.zeros((n_edges, 1), ctx=ctx) + gi
# remove non-overlapping edges
if overlap:
overlap_ious = nd.contrib.box_iou(
g_pred.ndata["pred_bbox"][:, 0:4],
g_pred.ndata["pred_bbox"][:, 0:4],
).asnumpy()
cols, rows = np.where(overlap_ious <= 1e-7)
if cols.shape[0] > 0:
eids = g_pred.edge_ids(cols, rows)[2].asnumpy().tolist()
@@ -138,9 +164,11 @@ def build_graph_train(g_slice, gt_bbox, img, ids, scores, bbox, feat_ind,
else:
return g_pred_batch[0]
def build_graph_validate_gt_obj(
img, gt_ids, bbox, spatial_feat, bbox_improvement=True, overlap=False
):
"""given ground truth bbox and label, build graph for validation"""
n_batch = img.shape[0]
img_size = img.shape[2:4]
bbox[:, :, 0] /= img_size[1]
@@ -156,10 +184,17 @@ def build_graph_validate_gt_obj(img, gt_ids, bbox, spatial_feat,
continue
n_nodes = len(inds)
g_pred = dgl.DGLGraph()
g_pred.add_nodes(
n_nodes,
{
"pred_bbox": bbox[btc, inds],
"node_feat": spatial_feat[btc, inds],
"node_class_pred": gt_ids[btc, inds, 0],
"node_class_logit": nd.zeros_like(
gt_ids[btc, inds, 0], ctx=ctx
),
},
)
edge_list = []
for i in range(n_nodes - 1):
@@ -172,9 +207,9 @@ def build_graph_validate_gt_obj(img, gt_ids, bbox, spatial_feat,
n_nodes = g_pred.number_of_nodes()
n_edges = g_pred.number_of_edges()
if bbox_improvement:
g_pred.ndata["pred_bbox"] = bbox_improve(g_pred.ndata["pred_bbox"])
g_pred.edata["rel_bbox"] = extract_edge_bbox(g_pred)
g_pred.edata["batch_id"] = nd.zeros((n_edges, 1), ctx=ctx) + btc
g_batch.append(g_pred)
@@ -184,9 +219,18 @@ def build_graph_validate_gt_obj(img, gt_ids, bbox, spatial_feat,
return dgl.batch(g_batch)
return g_batch[0]
def build_graph_validate_gt_bbox(
img,
ids,
scores,
bbox,
spatial_feat,
gt_ids=None,
bbox_improvement=True,
overlap=False,
):
"""given ground truth bbox, build graph for validation"""
n_batch = img.shape[0]
img_size = img.shape[2:4]
bbox[:, :, 0] /= img_size[1]
@@ -197,17 +241,22 @@ def build_graph_validate_gt_bbox(img, ids, scores, bbox, spatial_feat, gt_ids=No
g_batch = []
for btc in range(n_batch):
id_btc = scores[btc][:, :, 0].argmax(0)
score_btc = scores[btc][:, :, 0].max(0)
inds = np.where(bbox[btc].sum(1).asnumpy() > 0)[0].tolist()
if len(inds) == 0:
continue
n_nodes = len(inds)
g_pred = dgl.DGLGraph()
g_pred.add_nodes(
n_nodes,
{
"pred_bbox": bbox[btc, inds],
"node_feat": spatial_feat[btc, inds],
"node_class_pred": id_btc,
"node_class_logit": nd.log(score_btc + 1e-7),
},
)
edge_list = []
for i in range(n_nodes - 1):
@@ -220,9 +269,9 @@ def build_graph_validate_gt_bbox(img, ids, scores, bbox, spatial_feat, gt_ids=No
n_nodes = g_pred.number_of_nodes()
n_edges = g_pred.number_of_edges()
if bbox_improvement:
g_pred.ndata["pred_bbox"] = bbox_improve(g_pred.ndata["pred_bbox"])
g_pred.edata["rel_bbox"] = extract_edge_bbox(g_pred)
g_pred.edata["batch_id"] = nd.zeros((n_edges, 1), ctx=ctx) + btc
g_batch.append(g_pred)
@@ -232,9 +281,19 @@ def build_graph_validate_gt_bbox(img, ids, scores, bbox, spatial_feat, gt_ids=No
return dgl.batch(g_batch)
return g_batch[0]
def build_graph_validate_pred(
img,
ids,
scores,
bbox,
feat_ind,
spatial_feat,
bbox_improvement=True,
scores_top_k=50,
overlap=False,
):
"""given predicted bbox, build graph for validation"""
n_batch = img.shape[0]
img_size = img.shape[2:4]
bbox[:, :, 0] /= img_size[1]
@@ -249,16 +308,23 @@ def build_graph_validate_pred(img, ids, scores, bbox, feat_ind, spatial_feat,
if len(inds) == 0:
continue
if len(inds) > scores_top_k:
top_score_inds = (
scores[btc, inds, 0].asnumpy().argsort()[::-1][0:scores_top_k]
)
inds = np.array(inds)[top_score_inds].tolist()
n_nodes = len(inds)
roi_ind = feat_ind[btc, inds].squeeze(axis=1)
g_pred = dgl.DGLGraph()
g_pred.add_nodes(
n_nodes,
{
"pred_bbox": bbox[btc, inds],
"node_feat": spatial_feat[btc, roi_ind],
"node_class_pred": ids[btc, inds, 0],
"node_class_logit": nd.log(scores[btc, inds, 0] + 1e-7),
},
)
edge_list = []
for i in range(n_nodes - 1):
@@ -271,9 +337,9 @@ def build_graph_validate_pred(img, ids, scores, bbox, feat_ind, spatial_feat,
n_nodes = g_pred.number_of_nodes()
n_edges = g_pred.number_of_edges()
if bbox_improvement:
g_pred.ndata["pred_bbox"] = bbox_improve(g_pred.ndata["pred_bbox"])
g_pred.edata["rel_bbox"] = extract_edge_bbox(g_pred)
g_pred.edata["batch_id"] = nd.zeros((n_edges, 1), ctx=ctx) + btc
g_batch.append(g_pred)
......
import mxnet as mx
import numpy as np
import dgl
from dgl.utils import toindex
def l0_sample(g, positive_max=128, negative_ratio=3):
    """sampling positive and negative edges"""
if g is None:
return None
n_eids = g.number_of_edges()
pos_eids = np.where(g.edata["rel_class"].asnumpy() > 0)[0]
neg_eids = np.where(g.edata["rel_class"].asnumpy() == 0)[0]
if len(pos_eids) == 0:
return None
@@ -26,6 +28,7 @@ def l0_sample(g, positive_max=128, negative_ratio=3):
eids = np.where(weights > 0)[0]
sub_g = g.edge_subgraph(toindex(eids.tolist()))
sub_g.copy_from_parent()
sub_g.edata["sample_weights"] = mx.nd.array(
weights[eids], ctx=g.edata["rel_class"].context
)
return sub_g
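# Hedged usage sketch (not part of the original commit): l0_sample expects an
# integer "rel_class" edge label with 0 meaning background; the tiny graph and
# the cap/ratio values below are illustrative only.
def _l0_sample_example():
    g = dgl.DGLGraph()
    g.add_nodes(3)
    g.add_edges([0, 0, 1, 2], [1, 2, 2, 0])
    g.edata["rel_class"] = mx.nd.array([[1], [0], [0], [2]])
    # keeps the positive edges and samples negatives at the given ratio
    return l0_sample(g, positive_max=2, negative_ratio=1)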
@@ -6,18 +6,15 @@ References:
"""
import mxnet as mx
from mxnet import gluon
import dgl
from dgl.nn.mxnet import TAGConv
class TAGCN(gluon.Block):
def __init__(
self, g, in_feats, n_hidden, n_classes, n_layers, activation, dropout
):
super(TAGCN, self).__init__()
self.g = g
self.layers = gluon.nn.Sequential()
@@ -27,7 +24,7 @@ class TAGCN(gluon.Block):
for i in range(n_layers - 1):
self.layers.add(TAGConv(n_hidden, n_hidden, activation=activation))
# output layer
        self.layers.add(TAGConv(n_hidden, n_classes))  # activation=None
self.dropout = gluon.nn.Dropout(rate=dropout)
def forward(self, features):
......