Unverified commit a9f2acf3, authored by Hongzhi (Steve) Chen, committed by GitHub

[Misc] Black auto fix. (#4641)



* [Misc] Black auto fix.

* sort
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 08c50eb7
import argparse
import time
import mxnet as mx
import networkx as nx
import numpy as np
from mxnet import gluon, nd
from mxnet.gluon import nn
import dgl
from dgl.data import register_data_args
from dgl.data import (CiteseerGraphDataset, CoraGraphDataset,
                      PubmedGraphDataset, register_data_args)
from dgl.nn.mxnet.conv import GMMConv
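# MoNet (Monti et al., CVPR 2017): graph convolution with Gaussian
# mixture-model kernels (GMMConv) evaluated on per-edge pseudo-coordinates.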
class MoNet(nn.Block):
    def __init__(
        self,
        g,
        in_feats,
        n_hidden,
        out_feats,
        n_layers,
        dim,
        n_kernels,
        dropout,
    ):
super(MoNet, self).__init__()
self.g = g
with self.name_scope():
@@ -28,18 +32,19 @@ class MoNet(nn.Block):
self.pseudo_proj = nn.Sequential()
# Input layer
            self.layers.add(GMMConv(in_feats, n_hidden, dim, n_kernels))
            self.pseudo_proj.add(nn.Dense(dim, in_units=2, activation="tanh"))
# Hidden layer
for _ in range(n_layers - 1):
self.layers.add(GMMConv(n_hidden, n_hidden, dim, n_kernels))
                self.pseudo_proj.add(
                    nn.Dense(dim, in_units=2, activation="tanh")
                )
# Output layer
self.layers.add(GMMConv(n_hidden, out_feats, dim, n_kernels))
            self.pseudo_proj.add(nn.Dense(dim, in_units=2, activation="tanh"))
self.dropout = nn.Dropout(dropout)
@@ -48,8 +53,7 @@ class MoNet(nn.Block):
for i in range(len(self.layers)):
if i > 0:
h = self.dropout(h)
            h = self.layers[i](self.g, h, self.pseudo_proj[i](pseudo))
return h
@@ -58,16 +62,17 @@ def evaluate(model, features, pseudo, labels, mask):
accuracy = ((pred == labels) * mask).sum() / mask.sum().asscalar()
return accuracy.asscalar()
def main(args):
# load and preprocess dataset
    if args.dataset == "cora":
        data = CoraGraphDataset()
    elif args.dataset == "citeseer":
        data = CiteseerGraphDataset()
    elif args.dataset == "pubmed":
        data = PubmedGraphDataset()
    else:
        raise ValueError("Unknown dataset: {}".format(args.dataset))
g = data[0]
if args.gpu < 0:
@@ -78,24 +83,29 @@ def main(args):
ctx = mx.gpu(args.gpu)
g = g.to(ctx)
    features = g.ndata["feat"]
    labels = mx.nd.array(g.ndata["label"], dtype="float32", ctx=ctx)
    train_mask = g.ndata["train_mask"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
print("""----Data statistics------'
print(
"""----Data statistics------'
#Edges %d
#Classes %d
#Train samples %d
#Val samples %d
#Test samples %d""" %
(n_edges, n_classes,
train_mask.sum().asscalar(),
val_mask.sum().asscalar(),
test_mask.sum().asscalar()))
#Test samples %d"""
% (
n_edges,
n_classes,
train_mask.sum().asscalar(),
val_mask.sum().asscalar(),
test_mask.sum().asscalar(),
)
)
# add self loop
g = dgl.remove_self_loop(g)
@@ -107,30 +117,32 @@ def main(args):
vs = vs.asnumpy()
pseudo = []
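    # MoNet pseudo-coordinates: edge (u, v) is described by
    # (1/sqrt(deg(u)), 1/sqrt(deg(v))), the degree-based coordinates
    # used in the MoNet paper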
for i in range(g.number_of_edges()):
        pseudo.append(
            [1 / np.sqrt(g.in_degree(us[i])), 1 / np.sqrt(g.in_degree(vs[i]))]
        )
pseudo = nd.array(pseudo, ctx=ctx)
    # create MoNet model
    model = MoNet(
        g,
        in_feats,
        args.n_hidden,
        n_classes,
        args.n_layers,
        args.pseudo_dim,
        args.n_kernels,
        args.dropout,
    )
model.initialize(ctx=ctx)
n_train_samples = train_mask.sum().asscalar()
loss_fcn = gluon.loss.SoftmaxCELoss()
print(model.collect_params())
    trainer = gluon.Trainer(
        model.collect_params(),
        "adam",
        {"learning_rate": args.lr, "wd": args.weight_decay},
    )
# initialize graph
dur = []
@@ -150,37 +162,55 @@ def main(args):
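        # asscalar() blocks until the forward/backward pass finishes,
        # so the duration recorded below measures real compute time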
loss.asscalar()
dur.append(time.time() - t0)
acc = evaluate(model, features, pseudo, labels, val_mask)
print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
"ETputs(KTEPS) {:.2f}". format(
epoch, np.mean(dur), loss.asscalar(), acc, n_edges / np.mean(dur) / 1000))
print(
"Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
"ETputs(KTEPS) {:.2f}".format(
epoch,
np.mean(dur),
loss.asscalar(),
acc,
n_edges / np.mean(dur) / 1000,
)
)
# test set accuracy
acc = evaluate(model, features, pseudo, labels, test_mask)
print("Test accuracy {:.2%}".format(acc))
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="MoNet on citation network")
register_data_args(parser)
parser.add_argument("--dropout", type=float, default=0.5,
help="dropout probability")
parser.add_argument("--gpu", type=int, default=-1,
help="gpu")
parser.add_argument("--lr", type=float, default=1e-2,
help="learning rate")
parser.add_argument("--n-epochs", type=int, default=200,
help="number of training epochs")
parser.add_argument("--n-hidden", type=int, default=16,
help="number of hidden gcn units")
parser.add_argument("--n-layers", type=int, default=1,
help="number of hidden gcn layers")
parser.add_argument("--pseudo-dim", type=int, default=2,
help="Pseudo coordinate dimensions in GMMConv, 2 for cora and 3 for pubmed")
parser.add_argument("--n-kernels", type=int, default=3,
help="Number of kernels in GMMConv layer")
parser.add_argument("--weight-decay", type=float, default=5e-5,
help="Weight for L2 loss")
parser.add_argument(
"--dropout", type=float, default=0.5, help="dropout probability"
)
parser.add_argument("--gpu", type=int, default=-1, help="gpu")
parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
parser.add_argument(
"--n-epochs", type=int, default=200, help="number of training epochs"
)
parser.add_argument(
"--n-hidden", type=int, default=16, help="number of hidden gcn units"
)
parser.add_argument(
"--n-layers", type=int, default=1, help="number of hidden gcn layers"
)
parser.add_argument(
"--pseudo-dim",
type=int,
default=2,
help="Pseudo coordinate dimensions in GMMConv, 2 for cora and 3 for pubmed",
)
parser.add_argument(
"--n-kernels",
type=int,
default=3,
help="Number of kernels in GMMConv layer",
)
parser.add_argument(
"--weight-decay", type=float, default=5e-5, help="Weight for L2 loss"
)
args = parser.parse_args()
print(args)
    main(args)
import mxnet as mx
from mxnet import gluon
class BaseRGCN(gluon.Block):
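    # note: num_bases=-1 conventionally falls back to one weight matrix
    # per relation (no basis decomposition) in DGL's RGCN examples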
    def __init__(
        self,
        num_nodes,
        h_dim,
        out_dim,
        num_rels,
        num_bases=-1,
        num_hidden_layers=1,
        dropout=0,
        use_self_loop=False,
        gpu_id=-1,
    ):
super(BaseRGCN, self).__init__()
self.num_nodes = num_nodes
self.h_dim = h_dim
...
from .dataloader import *
from .object import *
from .relation import *
"""DataLoader utils."""
import dgl
from gluoncv.data.batchify import Pad
from mxnet import nd
def dgl_mp_batchify_fn(data):
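    # recursively batchify: tuples of samples are transposed and batched
    # field by field; DGLGraph elements are handled below (truncated here)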
if isinstance(data[0], tuple):
data = zip(*data)
return [dgl_mp_batchify_fn(i) for i in data]
for dt in data:
if dt is not None:
if isinstance(dt, dgl.DGLGraph):
...
"""Pascal VOC object detection dataset."""
from __future__ import absolute_import, division

import json
import logging
import os
import pickle
import warnings
from collections import Counter

import mxnet as mx
import numpy as np
from gluoncv.data import COCODetection
class VGObject(COCODetection):
CLASSES = ["airplane", "animal", "arm", "bag", "banana", "basket", "beach",
"bear", "bed", "bench", "bike", "bird", "board", "boat", "book",
"boot", "bottle", "bowl", "box", "boy", "branch", "building", "bus",
"cabinet", "cap", "car", "cat", "chair", "child", "clock", "coat",
"counter", "cow", "cup", "curtain", "desk", "dog", "door", "drawer",
"ear", "elephant", "engine", "eye", "face", "fence", "finger", "flag",
"flower", "food", "fork", "fruit", "giraffe", "girl", "glass", "glove",
"guy", "hair", "hand", "handle", "hat", "head", "helmet", "hill",
"horse", "house", "jacket", "jean", "kid", "kite", "lady", "lamp",
"laptop", "leaf", "leg", "letter", "light", "logo", "man", "men",
"motorcycle", "mountain", "mouth", "neck", "nose", "number", "orange",
"pant", "paper", "paw", "people", "person", "phone", "pillow", "pizza",
"plane", "plant", "plate", "player", "pole", "post", "pot", "racket",
"railing", "rock", "roof", "room", "screen", "seat", "sheep", "shelf",
"shirt", "shoe", "short", "sidewalk", "sign", "sink", "skateboard",
"ski", "skier", "sneaker", "snow", "sock", "stand", "street",
"surfboard", "table", "tail", "tie", "tile", "tire", "toilet",
"towel", "tower", "track", "train", "tree", "truck", "trunk",
"umbrella", "vase", "vegetable", "vehicle", "wave", "wheel",
"window", "windshield", "wing", "wire", "woman", "zebra"]
CLASSES = [
"airplane",
"animal",
"arm",
"bag",
"banana",
"basket",
"beach",
"bear",
"bed",
"bench",
"bike",
"bird",
"board",
"boat",
"book",
"boot",
"bottle",
"bowl",
"box",
"boy",
"branch",
"building",
"bus",
"cabinet",
"cap",
"car",
"cat",
"chair",
"child",
"clock",
"coat",
"counter",
"cow",
"cup",
"curtain",
"desk",
"dog",
"door",
"drawer",
"ear",
"elephant",
"engine",
"eye",
"face",
"fence",
"finger",
"flag",
"flower",
"food",
"fork",
"fruit",
"giraffe",
"girl",
"glass",
"glove",
"guy",
"hair",
"hand",
"handle",
"hat",
"head",
"helmet",
"hill",
"horse",
"house",
"jacket",
"jean",
"kid",
"kite",
"lady",
"lamp",
"laptop",
"leaf",
"leg",
"letter",
"light",
"logo",
"man",
"men",
"motorcycle",
"mountain",
"mouth",
"neck",
"nose",
"number",
"orange",
"pant",
"paper",
"paw",
"people",
"person",
"phone",
"pillow",
"pizza",
"plane",
"plant",
"plate",
"player",
"pole",
"post",
"pot",
"racket",
"railing",
"rock",
"roof",
"room",
"screen",
"seat",
"sheep",
"shelf",
"shirt",
"shoe",
"short",
"sidewalk",
"sign",
"sink",
"skateboard",
"ski",
"skier",
"sneaker",
"snow",
"sock",
"stand",
"street",
"surfboard",
"table",
"tail",
"tie",
"tile",
"tire",
"toilet",
"towel",
"tower",
"track",
"train",
"tree",
"truck",
"trunk",
"umbrella",
"vase",
"vegetable",
"vehicle",
"wave",
"wheel",
"window",
"windshield",
"wing",
"wire",
"woman",
"zebra",
]
def __init__(self, **kwargs):
super(VGObject, self).__init__(**kwargs)
@property
def annotation_dir(self):
        return ""
def _parse_image_path(self, entry):
        dirname = "VG_100K"
        filename = entry["file_name"]
abs_path = os.path.join(self._root, dirname, filename)
return abs_path
"""Prepare Visual Genome datasets"""
import argparse
import json
import os
import pickle
import random
import shutil
import zipfile

import tqdm
from gluoncv.utils import download, makedirs
_TARGET_DIR = os.path.expanduser("~/.mxnet/datasets/visualgenome")
def parse_args():
parser = argparse.ArgumentParser(
        description="Initialize Visual Genome dataset.",
        epilog="Example: python visualgenome.py --download-dir ~/visualgenome",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--download-dir",
        type=str,
        default="~/visualgenome/",
        help="dataset directory on disk",
    )
    parser.add_argument(
        "--no-download",
        action="store_true",
        help="disable automatic download if set",
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="overwrite downloaded files if set, in case they are corrupted",
    )
args = parser.parse_args()
return args
def download_vg(path, overwrite=False):
_DOWNLOAD_URLS = [
        (
            "https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip",
            "a055367f675dd5476220e9b93e4ca9957b024b94",
        ),
        (
            "https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip",
            "2add3aab77623549e92b7f15cda0308f50b64ecf",
        ),
]
makedirs(path)
for url, checksum in _DOWNLOAD_URLS:
        filename = download(
            url, path=path, overwrite=overwrite, sha1_hash=checksum
        )
# extract
        if filename.endswith("zip"):
with zipfile.ZipFile(filename) as zf:
zf.extractall(path=path)
# move all images into folder `VG_100K`
    vg_100k_path = os.path.join(path, "VG_100K")
    vg_100k_2_path = os.path.join(path, "VG_100K_2")
files_2 = os.listdir(vg_100k_2_path)
for fl in files_2:
        shutil.move(
            os.path.join(vg_100k_2_path, fl), os.path.join(vg_100k_path, fl)
        )
def download_json(path, overwrite=False):
    url = "https://data.dgl.ai/dataset/vg.zip"
    output = "vg.zip"
    download(url, path=path, overwrite=overwrite)
with zipfile.ZipFile(output) as zf:
zf.extractall(path=path)
    json_path = os.path.join(path, "vg")
json_files = os.listdir(json_path)
for fl in json_files:
        shutil.move(os.path.join(json_path, fl), os.path.join(path, fl))
os.rmdir(json_path)
if __name__ == "__main__":
args = parse_args()
path = os.path.expanduser(args.download_dir)
if not os.path.isdir(path):
if args.no_download:
            raise ValueError(
                "{} is not a valid directory, make sure it is present,"
                " or drop '--no-download' to let the script fetch it.".format(
                    path
                )
            )
else:
download_vg(path, overwrite=args.overwrite)
download_json(path, overwrite=args.overwrite)
# make symlink
    makedirs(os.path.expanduser("~/.mxnet/datasets"))
if os.path.isdir(_TARGET_DIR):
os.rmdir(_TARGET_DIR)
os.symlink(path, _TARGET_DIR)
"""Pascal VOC object detection dataset."""
from __future__ import absolute_import, division

import json
import logging
import os
import pickle
import warnings
from collections import Counter

import dgl
import mxnet as mx
import numpy as np
from gluoncv.data.base import VisionDataset
from gluoncv.data.transforms.presets.rcnn import (
    FasterRCNNDefaultTrainTransform, FasterRCNNDefaultValTransform)
class VGRelation(VisionDataset):
    def __init__(
        self,
        root=os.path.join("~", ".mxnet", "datasets", "visualgenome"),
        split="train",
    ):
super(VGRelation, self).__init__(root)
self._root = os.path.expanduser(root)
        self._img_path = os.path.join(self._root, "VG_100K", "{}")
        if split == "train":
            self._dict_path = os.path.join(
                self._root, "rel_annotations_train.json"
            )
        elif split == "val":
            self._dict_path = os.path.join(
                self._root, "rel_annotations_val.json"
            )
else:
raise NotImplementedError
with open(self._dict_path) as f:
tmp = f.read()
self._dict = json.loads(tmp)
        self._predicates_path = os.path.join(self._root, "predicates.json")
        with open(self._predicates_path, "r") as f:
tmp = f.read()
self.rel_classes = json.loads(tmp)
self.num_rel_classes = len(self.rel_classes) + 1
        self._objects_path = os.path.join(self._root, "objects.json")
        with open(self._objects_path, "r") as f:
tmp = f.read()
self.obj_classes = json.loads(tmp)
self.num_obj_classes = len(self.obj_classes)
        if split == "val":
            self.img_transform = FasterRCNNDefaultValTransform(
                short=600, max_size=1000
            )
        else:
            self.img_transform = FasterRCNNDefaultTrainTransform(
                short=600, max_size=1000
            )
self.split = split
def __len__(self):
return len(self._dict)
def _hash_bbox(self, object):
        num_list = [object["category"]] + object["bbox"]
        return "_".join([str(num) for num in num_list])
def __getitem__(self, idx):
img_id = list(self._dict)[idx]
@@ -66,8 +82,8 @@ class VGRelation(VisionDataset):
sub_node_hash = []
ob_node_hash = []
for i, it in enumerate(item):
            sub_node_hash.append(self._hash_bbox(it["subject"]))
            ob_node_hash.append(self._hash_bbox(it["object"]))
node_set = sorted(list(set(sub_node_hash + ob_node_hash)))
n_nodes = len(node_set)
node_to_id = {}
@@ -86,35 +102,37 @@ class VGRelation(VisionDataset):
for i, it in enumerate(item):
if not node_visited[sub_id[i]]:
ind = sub_id[i]
                sub = it["subject"]
                node_class_ids[ind] = sub["category"]
                # y1y2x1x2 to x1y1x2y2
                bbox[ind, 0] = sub["bbox"][2]
                bbox[ind, 1] = sub["bbox"][0]
                bbox[ind, 2] = sub["bbox"][3]
                bbox[ind, 3] = sub["bbox"][1]
node_visited[ind] = True
if not node_visited[ob_id[i]]:
ind = ob_id[i]
                ob = it["object"]
                node_class_ids[ind] = ob["category"]
                # y1y2x1x2 to x1y1x2y2
                bbox[ind, 0] = ob["bbox"][2]
                bbox[ind, 1] = ob["bbox"][0]
                bbox[ind, 2] = ob["bbox"][3]
                bbox[ind, 3] = ob["bbox"][1]
node_visited[ind] = True
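        # label smoothing: keep 1 - eta on the true class and spread eta
        # uniformly over all object classes in the one-hot encoding below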
eta = 0.1
        node_class_vec = node_class_ids[:, 0].one_hot(
            self.num_obj_classes,
            on_value=1 - eta + eta / self.num_obj_classes,
            off_value=eta / self.num_obj_classes,
        )
# augmentation
        if self.split == "val":
img, bbox, _ = self.img_transform(img, bbox)
else:
img, bbox = self.img_transform(img, bbox)
@@ -126,9 +144,9 @@ class VGRelation(VisionDataset):
predicate = []
for i, it in enumerate(item):
adjmat[sub_id[i], ob_id[i]] = 1
            predicate.append(it["predicate"])
predicate = mx.nd.array(predicate).expand_dims(1)
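        # rel_class 0 is reserved for "no relation" (see the zero-filled
        # edges added below), so real predicate ids are shifted up by 1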
        g.add_edges(sub_id, ob_id, {"rel_class": mx.nd.array(predicate) + 1})
empty_edge_list = []
for i in range(n_nodes):
for j in range(n_nodes):
@@ -136,11 +154,13 @@ class VGRelation(VisionDataset):
empty_edge_list.append((i, j))
if len(empty_edge_list) > 0:
src, dst = tuple(zip(*empty_edge_list))
            g.add_edges(
                src, dst, {"rel_class": mx.nd.zeros((len(empty_edge_list), 1))}
            )
# assign features
        g.ndata["bbox"] = bbox
        g.ndata["node_class"] = node_class_ids
        g.ndata["node_class_vec"] = node_class_vec
return g, img
import argparse

import gluoncv as gcv
import mxnet as mx
from gluoncv.data.transforms import presets
from gluoncv.utils import download

import dgl
from data import *
from model import RelDN, faster_rcnn_resnet101_v1d_custom
from utils import *
def parse_args():
    parser = argparse.ArgumentParser(
        description="Demo of Scene Graph Extraction."
    )
    parser.add_argument(
        "--image",
        type=str,
        default="",
        help="The image for scene graph extraction.",
    )
    parser.add_argument(
        "--gpu",
        type=str,
        default="",
        help="GPU id to use for inference, default is not using GPU.",
    )
    parser.add_argument(
        "--pretrained-faster-rcnn-params",
        type=str,
        default="",
        help="Path to saved Faster R-CNN model parameters.",
    )
    parser.add_argument(
        "--reldn-params",
        type=str,
        default="",
        help="Path to saved RelDN model parameters.",
    )
    parser.add_argument(
        "--faster-rcnn-params",
        type=str,
        default="",
        help="Path to saved Faster R-CNN feature-extractor parameters.",
    )
    parser.add_argument(
        "--freq-prior",
        type=str,
        default="freq_prior.pkl",
        help="Path to saved frequency prior data.",
    )
args = parser.parse_args()
return args
args = parse_args()
if args.gpu:
ctx = mx.gpu(int(args.gpu))
@@ -32,31 +62,47 @@ else:
ctx = mx.cpu()
net = RelDN(n_classes=50, prior_pkl=args.freq_prior, semantic_only=False)
if args.reldn_params == "":
    download("http://data.dgl.ai/models/SceneGraph/reldn.params")
    net.load_parameters("reldn.params", ctx=ctx)
else:
net.load_parameters(args.reldn_params, ctx=ctx)
# dataset and dataloader
vg_val = VGRelation(split="val")
detector = faster_rcnn_resnet101_v1d_custom(
    classes=vg_val.obj_classes,
    pretrained_base=False,
    pretrained=False,
    additional_output=True,
)
if args.pretrained_faster_rcnn_params == "":
    download(
        "http://data.dgl.ai/models/SceneGraph/faster_rcnn_resnet101_v1d_visualgenome.params"
    )
    params_path = "faster_rcnn_resnet101_v1d_visualgenome.params"
else:
params_path = args.pretrained_faster_rcnn_params
detector.load_parameters(
    params_path, ctx=ctx, ignore_extra=True, allow_missing=True
)
detector_feat = faster_rcnn_resnet101_v1d_custom(
    classes=vg_val.obj_classes,
    pretrained_base=False,
    pretrained=False,
    additional_output=True,
)
detector_feat.load_parameters(
    params_path, ctx=ctx, ignore_extra=True, allow_missing=True
)
if args.faster_rcnn_params == "":
    download(
        "http://data.dgl.ai/models/SceneGraph/faster_rcnn_resnet101_v1d_visualgenome.params"
    )
    detector_feat.features.load_parameters(
        "faster_rcnn_resnet101_v1d_visualgenome.params", ctx=ctx
    )
else:
detector_feat.features.load_parameters(args.faster_rcnn_params, ctx=ctx)
@@ -64,24 +110,37 @@ else:
if args.image:
image_path = args.image
else:
    gcv.utils.download(
        "https://raw.githubusercontent.com/dmlc/web-data/master/"
        + "dgl/examples/mxnet/scenegraph/old-couple.png",
        "old-couple.png",
    )
    image_path = "old-couple.png"
x, img = presets.rcnn.load_test(
    image_path, short=detector.short, max_size=detector.max_size
)
x = x.as_in_context(ctx)
# detector prediction
ids, scores, bboxes, feat, feat_ind, spatial_feat = detector(x)
# build graph, extract edge features
g = build_graph_validate_pred(
    x,
    ids,
    scores,
    bboxes,
    feat_ind,
    spatial_feat,
    bbox_improvement=True,
    scores_top_k=75,
    overlap=False,
)
rel_bbox = g.edata["rel_bbox"].expand_dims(0).as_in_context(ctx)
_, _, _, spatial_feat_rel = detector_feat(x, None, None, rel_bbox)
g.edata["edge_feat"] = spatial_feat_rel[0]
# graph prediction
g = net(g)
_, preds = extract_pred(g, joint_preds=True)
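# sort rows by column 1 (the joint score) in descending order so that
# plot_sg draws the highest-scoring relations first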
preds = preds[preds[:, 1].argsort()[::-1]]
plot_sg(img, preds, detector.classes, vg_val.rel_classes, 10)
import pickle

import gluoncv as gcv
import mxnet as mx
import numpy as np
from mxnet import nd
from mxnet.gluon import nn

import dgl
from dgl.nn.mxnet import GraphConv
from dgl.utils import toindex

__all__ = ["RelDN"]
class EdgeConfMLP(nn.Block):
    """compute the confidence for edges"""
def __init__(self):
super(EdgeConfMLP, self).__init__()
def forward(self, edges):
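        # score_pred: best non-background relation log-probability
        # (index 0 is the "no relation" class); score_phr additionally
        # sums in the subject and object class logits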
        score_pred = nd.log_softmax(edges.data["preds"])[:, 1:].max(axis=1)
        score_phr = (
            score_pred
            + edges.src["node_class_logit"]
            + edges.dst["node_class_logit"]
        )
        return {"score_pred": score_pred, "score_phr": score_phr}
class EdgeBBoxExtend(nn.Block):
    """encode the bounding boxes"""
def __init__(self):
super(EdgeBBoxExtend, self).__init__()
def bbox_delta(self, bbox_a, bbox_b):
n = bbox_a.shape[0]
result = nd.zeros((n, 4), ctx=bbox_a.context)
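        # delta encoding of box pairs: x1/y1 corner offsets plus
        # log width/height ratios; 1e-8 guards against zero-sized boxes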
        result[:, 0] = bbox_a[:, 0] - bbox_b[:, 0]
        result[:, 1] = bbox_a[:, 1] - bbox_b[:, 1]
        result[:, 2] = nd.log(
            (bbox_a[:, 2] - bbox_a[:, 0] + 1e-8)
            / (bbox_b[:, 2] - bbox_b[:, 0] + 1e-8)
        )
        result[:, 3] = nd.log(
            (bbox_a[:, 3] - bbox_a[:, 1] + 1e-8)
            / (bbox_b[:, 3] - bbox_b[:, 1] + 1e-8)
        )
return result
def forward(self, edges):
        ctx = edges.src["pred_bbox"].context
        n = edges.src["pred_bbox"].shape[0]
        delta_src_obj = self.bbox_delta(
            edges.src["pred_bbox"], edges.dst["pred_bbox"]
        )
        delta_src_rel = self.bbox_delta(
            edges.src["pred_bbox"], edges.data["rel_bbox"]
        )
        delta_rel_obj = self.bbox_delta(
            edges.data["rel_bbox"], edges.dst["pred_bbox"]
        )
result = nd.zeros((n, 12), ctx=ctx)
        result[:, 0:4] = delta_src_obj
        result[:, 4:8] = delta_src_rel
        result[:, 8:12] = delta_rel_obj
        return {"pred_bbox_additional": result}
class EdgeFreqPrior(nn.Block):
    """make use of the pre-trained frequency prior"""
def __init__(self, prior_pkl):
super(EdgeFreqPrior, self).__init__()
        with open(prior_pkl, "rb") as f:
freq_prior = pickle.load(f)
self.freq_prior = freq_prior
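        # freq_prior[i, j] holds the predicate distribution observed for
        # (subject class i, object class j) pairs, precomputed offline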
def forward(self, edges):
        ctx = edges.src["node_class_pred"].context
        src_ind = edges.src["node_class_pred"].asnumpy().astype(int)
        dst_ind = edges.dst["node_class_pred"].asnumpy().astype(int)
prob = self.freq_prior[src_ind, dst_ind]
out = nd.array(prob, ctx=ctx)
        return {"freq_prior": out}
class EdgeSpatial(nn.Block):
    """spatial feature branch"""
def __init__(self, n_classes):
super(EdgeSpatial, self).__init__()
self.mlp = nn.Sequential()
@@ -76,14 +100,20 @@ class EdgeSpatial(nn.Block):
self.mlp.add(nn.Dense(n_classes))
def forward(self, edges):
        feat = nd.concat(
            edges.src["pred_bbox"],
            edges.dst["pred_bbox"],
            edges.data["rel_bbox"],
            edges.data["pred_bbox_additional"],
        )
out = self.mlp(feat)
        return {"spatial": out}
class EdgeVisual(nn.Block):
    """visual feature branch"""

    def __init__(self, n_classes, vis_feat_dim=7 * 7 * 3):
super(EdgeVisual, self).__init__()
self.dim_in = vis_feat_dim
self.mlp_joint = nn.Sequential()
@@ -97,15 +127,21 @@ class EdgeVisual(nn.Block):
self.mlp_ob = nn.Dense(n_classes)
def forward(self, edges):
        feat = nd.concat(
            edges.src["node_feat"],
            edges.dst["node_feat"],
            edges.data["edge_feat"],
        )
out_joint = self.mlp_joint(feat)
        out_sub = self.mlp_sub(edges.src["node_feat"])
        out_ob = self.mlp_ob(edges.dst["node_feat"])
out = out_joint + out_sub + out_ob
        return {"visual": out}
class RelDN(nn.Block):
    """The RelDN Model"""
def __init__(self, n_classes, prior_pkl, semantic_only=False):
super(RelDN, self).__init__()
# output layers
@@ -121,19 +157,21 @@ class RelDN(nn.Block):
self.edge_conf_mlp = EdgeConfMLP()
self.semantic_only = semantic_only
    def forward(self, g):
if g is None or g.number_of_nodes() == 0:
return g
# predictions
g.apply_edges(self.freq_prior)
if self.semantic_only:
            g.edata["preds"] = g.edata["freq_prior"]
else:
# bbox extension
g.apply_edges(self.edge_bbox_extend)
g.apply_edges(self.spatial)
g.apply_edges(self.visual)
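            # fuse the semantic (frequency prior), spatial, and visual
            # branches by summing their per-edge relation logits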
            g.edata["preds"] = (
                g.edata["freq_prior"] + g.edata["spatial"] + g.edata["visual"]
            )
# subgraph for gconv
g.apply_edges(self.edge_conf_mlp)
return g
@@ -3,29 +3,29 @@ import argparse
import os
# disable autotune
os.environ["MXNET_CUDNN_AUTOTUNE_DEFAULT"] = "0"
import logging
import time

import gluoncv as gcv
import mxnet as mx
import numpy as np
from gluoncv import data as gdata
from gluoncv import utils as gutils
from gluoncv.data.batchify import Append, FasterRCNNTrainBatchify, Tuple
from gluoncv.data.transforms.presets.rcnn import (
    FasterRCNNDefaultTrainTransform, FasterRCNNDefaultValTransform)
from gluoncv.model_zoo import get_model
from gluoncv.utils.metrics.coco_detection import COCODetectionMetric
from gluoncv.utils.metrics.rcnn import (RCNNAccMetric, RCNNL1LossMetric,
                                        RPNAccMetric, RPNL1LossMetric)
from gluoncv.utils.metrics.voc_detection import VOC07MApMetric
from gluoncv.utils.parallel import Parallel, Parallelizable
from mxnet import autograd, gluon
from mxnet.contrib import amp

from data import *
from model import (faster_rcnn_resnet50_v1b_custom,
                   faster_rcnn_resnet101_v1d_custom)
try:
import horovod.mxnet as hvd
@@ -34,111 +34,229 @@ except ImportError:
def parse_args():
parser = argparse.ArgumentParser(
description="Train Faster-RCNN networks e2e."
)
parser.add_argument(
"--network",
type=str,
default="resnet101_v1d",
help="Base network name which serves as feature extraction base.",
)
parser.add_argument(
"--dataset",
type=str,
default="visualgenome",
help="Training dataset. Now support voc and coco.",
)
parser.add_argument(
"--num-workers",
"-j",
dest="num_workers",
type=int,
default=8,
help="Number of data workers, you can use larger "
"number to accelerate data loading, "
"if your CPU and GPUs are powerful.",
)
parser.add_argument(
"--batch-size", type=int, default=8, help="Training mini-batch size."
)
parser.add_argument(
"--gpus",
type=str,
default="0",
help="Training with GPUs, you can specify 1,3 for example.",
)
parser.add_argument(
"--epochs", type=str, default="", help="Training epochs."
)
parser.add_argument(
"--resume",
type=str,
default="",
help="Resume from previously saved parameters if not None. "
"For example, you can resume from ./faster_rcnn_xxx_0123.params",
)
parser.add_argument(
"--start-epoch",
type=int,
default=0,
help="Starting epoch for resuming, default is 0 for new training."
"You can specify it to 100 for example to start from 100 epoch.",
)
parser.add_argument(
"--lr",
type=str,
default="",
help="Learning rate, default is 0.001 for voc single gpu training.",
)
parser.add_argument(
"--lr-decay",
type=float,
default=0.1,
help="decay rate of learning rate. default is 0.1.",
)
parser.add_argument(
"--lr-decay-epoch",
type=str,
default="",
help="epochs at which learning rate decays. default is 14,20 for voc.",
)
parser.add_argument(
"--lr-warmup",
type=str,
default="",
help="warmup iterations to adjust learning rate, default is 0 for voc.",
)
parser.add_argument(
"--lr-warmup-factor",
type=float,
default=1.0 / 3.0,
help="warmup factor of base lr.",
)
parser.add_argument(
"--momentum",
type=float,
default=0.9,
help="SGD momentum, default is 0.9",
)
parser.add_argument(
"--wd",
type=str,
default="",
help="Weight decay, default is 5e-4 for voc",
)
parser.add_argument(
"--log-interval",
type=int,
default=100,
help="Logging mini-batch interval. Default is 100.",
)
parser.add_argument(
"--save-prefix", type=str, default="", help="Saving parameter prefix"
)
parser.add_argument(
"--save-interval",
type=int,
default=1,
help="Saving parameters epoch interval, best model will always be saved.",
)
parser.add_argument(
"--val-interval",
type=int,
default=1,
help="Epoch interval for validation, increase the number will reduce the "
"training time if validation is slow.",
)
parser.add_argument(
"--seed", type=int, default=233, help="Random seed to be fixed."
)
parser.add_argument(
"--verbose",
dest="verbose",
action="store_true",
help="Print helpful debugging info once set.",
)
parser.add_argument(
"--mixup", action="store_true", help="Use mixup training."
)
parser.add_argument(
"--no-mixup-epochs",
type=int,
default=20,
help="Disable mixup training if enabled in the last N epochs.",
)
# Norm layer options
parser.add_argument(
"--norm-layer",
type=str,
default=None,
help="Type of normalization layer to use. "
"If set to None, backbone normalization layer will be fixed,"
" and no normalization layer will be used. "
"Currently supports 'bn', and None, default is None."
"Note that if horovod is enabled, sync bn will not work correctly.",
)
# FPN options
parser.add_argument(
"--use-fpn",
action="store_true",
help="Whether to use feature pyramid network.",
)
# Performance options
parser.add_argument(
"--disable-hybridization",
action="store_true",
help="Whether to disable hybridize the model. "
"Memory usage and speed will decrese.",
)
parser.add_argument(
"--static-alloc",
action="store_true",
help="Whether to use static memory allocation. Memory usage will increase.",
)
parser.add_argument(
"--amp",
action="store_true",
help="Use MXNet AMP for mixed precision training.",
)
parser.add_argument(
"--horovod",
action="store_true",
help="Use MXNet Horovod for distributed training. Must be run with OpenMPI. "
"--gpus is ignored when using --horovod.",
)
parser.add_argument(
"--executor-threads",
type=int,
default=1,
help="Number of threads for executor for scheduling ops. "
"More threads may incur higher GPU memory footprint, "
"but may speed up throughput. Note that when horovod is used, "
"it is set to 1.",
)
parser.add_argument(
"--kv-store",
type=str,
default="nccl",
help="KV store options. local, device, nccl, dist_sync, dist_device_sync, "
"dist_async are available.",
)
args = parser.parse_args()
if args.horovod:
if hvd is None:
raise SystemExit("Horovod not found, please check if you installed it correctly.")
raise SystemExit(
"Horovod not found, please check if you installed it correctly."
)
hvd.init()
    if args.dataset == "voc":
args.epochs = int(args.epochs) if args.epochs else 20
        args.lr_decay_epoch = (
            args.lr_decay_epoch if args.lr_decay_epoch else "14,20"
        )
args.lr = float(args.lr) if args.lr else 0.001
args.lr_warmup = args.lr_warmup if args.lr_warmup else -1
args.wd = float(args.wd) if args.wd else 5e-4
    elif args.dataset == "visualgenome":
args.epochs = int(args.epochs) if args.epochs else 20
        args.lr_decay_epoch = (
            args.lr_decay_epoch if args.lr_decay_epoch else "14,20"
        )
args.lr = float(args.lr) if args.lr else 0.001
args.lr_warmup = args.lr_warmup if args.lr_warmup else -1
args.wd = float(args.wd) if args.wd else 5e-4
    elif args.dataset == "coco":
args.epochs = int(args.epochs) if args.epochs else 26
        args.lr_decay_epoch = (
            args.lr_decay_epoch if args.lr_decay_epoch else "17,23"
        )
args.lr = float(args.lr) if args.lr else 0.01
args.lr_warmup = args.lr_warmup if args.lr_warmup else 1000
args.wd = float(args.wd) if args.wd else 1e-4
@@ -146,71 +264,129 @@ def parse_args():
def get_dataset(dataset, args):
    if dataset.lower() == "voc":
train_dataset = gdata.VOCDetection(
splits=[(2007, "trainval"), (2012, "trainval")]
)
val_dataset = gdata.VOCDetection(splits=[(2007, "test")])
val_metric = VOC07MApMetric(
iou_thresh=0.5, class_names=val_dataset.classes
)
elif dataset.lower() == "coco":
train_dataset = gdata.COCODetection(
splits="instances_train2017", use_crowd=False
)
val_dataset = gdata.COCODetection(
splits="instances_val2017", skip_empty=False
)
val_metric = COCODetectionMetric(
val_dataset, args.save_prefix + "_eval", cleanup=True
)
elif dataset.lower() == "visualgenome":
train_dataset = VGObject(
root=os.path.join("~", ".mxnet", "datasets", "visualgenome"),
splits="detections_train",
use_crowd=False,
)
val_dataset = VGObject(
root=os.path.join("~", ".mxnet", "datasets", "visualgenome"),
splits="detections_val",
skip_empty=False,
)
val_metric = COCODetectionMetric(
val_dataset, args.save_prefix + "_eval", cleanup=True
)
else:
        raise NotImplementedError(
            "Dataset: {} not implemented.".format(dataset)
        )
if args.mixup:
from gluoncv.data.mixup import detection
train_dataset = detection.MixupDetection(train_dataset)
return train_dataset, val_dataset, val_metric
def get_dataloader(
    net,
    train_dataset,
    val_dataset,
    train_transform,
    val_transform,
    batch_size,
    num_shards,
    args,
):
"""Get dataloader."""
train_bfn = FasterRCNNTrainBatchify(net, num_shards)
    if hasattr(train_dataset, "get_im_aspect_ratio"):
im_aspect_ratio = train_dataset.get_im_aspect_ratio()
else:
        im_aspect_ratio = [1.0] * len(train_dataset)
    train_sampler = gcv.nn.sampler.SplitSortedBucketSampler(
        im_aspect_ratio,
        batch_size,
        num_parts=hvd.size() if args.horovod else 1,
        part_index=hvd.rank() if args.horovod else 0,
        shuffle=True,
    )
    train_loader = mx.gluon.data.DataLoader(
        train_dataset.transform(
            train_transform(
                net.short,
                net.max_size,
                net,
                ashape=net.ashape,
                multi_stage=args.use_fpn,
            )
        ),
        batch_sampler=train_sampler,
        batchify_fn=train_bfn,
        num_workers=args.num_workers,
    )
if val_dataset is None:
val_loader = None
else:
val_bfn = Tuple(*[Append() for _ in range(3)])
        short = (
            net.short[-1] if isinstance(net.short, (tuple, list)) else net.short
        )
# validation use 1 sample per device
val_loader = mx.gluon.data.DataLoader(
            val_dataset.transform(val_transform(short, net.max_size)),
            num_shards,
            False,
            batchify_fn=val_bfn,
            last_batch="keep",
            num_workers=args.num_workers,
        )
return train_loader, val_loader
def save_params(
    net, logger, best_map, current_map, epoch, save_interval, prefix
):
current_map = float(current_map)
if current_map > best_map[0]:
        logger.info(
            "[Epoch {}] mAP {} higher than current best {} saving to {}".format(
                epoch, current_map, best_map, "{:s}_best.params".format(prefix)
            )
        )
best_map[0] = current_map
        net.save_parameters("{:s}_best.params".format(prefix))
        with open(prefix + "_best_map.log", "a") as f:
            f.write("{:04d}:\t{:.4f}\n".format(epoch, current_map))
if save_interval and (epoch + 1) % save_interval == 0:
        logger.info(
            "[Epoch {}] Saving parameters to {}".format(
                epoch,
                "{:s}_{:04d}_{:.4f}.params".format(prefix, epoch, current_map),
            )
        )
        net.save_parameters(
            "{:s}_{:04d}_{:.4f}.params".format(prefix, epoch, current_map)
        )
def split_and_load(batch, ctx_list):
@@ -254,23 +430,37 @@ def validate(net, val_data, ctx, eval_metric, args):
gt_ids.append(y.slice_axis(axis=-1, begin=4, end=5))
gt_bboxes.append(y.slice_axis(axis=-1, begin=0, end=4))
gt_bboxes[-1] *= im_scale
            gt_difficults.append(
                y.slice_axis(axis=-1, begin=5, end=6)
                if y.shape[-1] > 5
                else None
            )
# update metric
        for det_bbox, det_id, det_score, gt_bbox, gt_id, gt_diff in zip(
            det_bboxes, det_ids, det_scores, gt_bboxes, gt_ids, gt_difficults
        ):
            eval_metric.update(
                det_bbox, det_id, det_score, gt_bbox, gt_id, gt_diff
            )
return eval_metric.get()
def get_lr_at_iter(alpha, lr_warmup_factor=1.0 / 3.0):
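    # linear warmup: interpolates from lr_warmup_factor at alpha=0
    # to 1.0 at alpha=1; the result is multiplied into the base lr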
return lr_warmup_factor * (1 - alpha) + alpha
class ForwardBackwardTask(Parallelizable):
    def __init__(
        self,
        net,
        optimizer,
        rpn_cls_loss,
        rpn_box_loss,
        rcnn_cls_loss,
        rcnn_box_loss,
        mix_ratio,
    ):
super(ForwardBackwardTask, self).__init__()
self.net = net
self._optimizer = optimizer
@@ -285,96 +475,159 @@ class ForwardBackwardTask(Parallelizable):
with autograd.record():
gt_label = label[:, :, 4:5]
gt_box = label[:, :, :4]
            (
                cls_pred,
                box_pred,
                roi,
                samples,
                matches,
                rpn_score,
                rpn_box,
                anchors,
                cls_targets,
                box_targets,
                box_masks,
                _,
            ) = net(data, gt_box, gt_label)
# losses of rpn
rpn_score = rpn_score.squeeze(axis=-1)
num_rpn_pos = (rpn_cls_targets >= 0).sum()
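            # normalize the RPN losses by the number of non-ignored
            # anchors (targets marked < 0 are excluded by the mask)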
            rpn_loss1 = (
                self.rpn_cls_loss(
                    rpn_score, rpn_cls_targets, rpn_cls_targets >= 0
                )
                * rpn_cls_targets.size
                / num_rpn_pos
            )
            rpn_loss2 = (
                self.rpn_box_loss(rpn_box, rpn_box_targets, rpn_box_masks)
                * rpn_box.size
                / num_rpn_pos
            )
# rpn overall loss, use sum rather than average
rpn_loss = rpn_loss1 + rpn_loss2
# losses of rcnn
num_rcnn_pos = (cls_targets >= 0).sum()
            rcnn_loss1 = (
                self.rcnn_cls_loss(
                    cls_pred, cls_targets, cls_targets.expand_dims(-1) >= 0
                )
                * cls_targets.size
                / num_rcnn_pos
            )
            rcnn_loss2 = (
                self.rcnn_box_loss(box_pred, box_targets, box_masks)
                * box_pred.size
                / num_rcnn_pos
            )
rcnn_loss = rcnn_loss1 + rcnn_loss2
# overall losses
            total_loss = (
                rpn_loss.sum() * self.mix_ratio
                + rcnn_loss.sum() * self.mix_ratio
            )
rpn_loss1_metric = rpn_loss1.mean() * self.mix_ratio
rpn_loss2_metric = rpn_loss2.mean() * self.mix_ratio
rcnn_loss1_metric = rcnn_loss1.mean() * self.mix_ratio
rcnn_loss2_metric = rcnn_loss2.mean() * self.mix_ratio
            rpn_acc_metric = [
                [rpn_cls_targets, rpn_cls_targets >= 0],
                [rpn_score],
            ]
rpn_l1_loss_metric = [[rpn_box_targets, rpn_box_masks], [rpn_box]]
rcnn_acc_metric = [[cls_targets], [cls_pred]]
rcnn_l1_loss_metric = [[box_targets, box_masks], [box_pred]]
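            # with AMP, scale the loss before backward so float16
            # gradients do not underflow; the trainer unscales at update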
if args.amp:
                with amp.scale_loss(
                    total_loss, self._optimizer
                ) as scaled_losses:
autograd.backward(scaled_losses)
else:
total_loss.backward()
        return (
            rpn_loss1_metric,
            rpn_loss2_metric,
            rcnn_loss1_metric,
            rcnn_loss2_metric,
            rpn_acc_metric,
            rpn_l1_loss_metric,
            rcnn_acc_metric,
            rcnn_l1_loss_metric,
        )
def train(net, train_data, val_data, eval_metric, batch_size, ctx, args):
"""Training pipeline"""
    args.kv_store = (
        "device" if (args.amp and "nccl" in args.kv_store) else args.kv_store
    )
kv = mx.kvstore.create(args.kv_store)
    net.collect_params().setattr("grad_req", "null")
    net.collect_train_params().setattr("grad_req", "write")
    optimizer_params = {
        "learning_rate": args.lr,
        "wd": args.wd,
        "momentum": args.momentum,
    }
if args.horovod:
hvd.broadcast_parameters(net.collect_params(), root_rank=0)
trainer = hvd.DistributedTrainer(
net.collect_train_params(), # fix batchnorm, fix first stage, etc...
            "sgd",
            optimizer_params,
        )
else:
trainer = gluon.Trainer(
net.collect_train_params(), # fix batchnorm, fix first stage, etc...
            "sgd",
            optimizer_params,
            update_on_kvstore=(False if args.amp else None),
            kvstore=kv,
        )
if args.amp:
amp.init_trainer(trainer)
# lr decay policy
lr_decay = float(args.lr_decay)
    lr_steps = sorted(
        [float(ls) for ls in args.lr_decay_epoch.split(",") if ls.strip()]
    )
lr_warmup = float(args.lr_warmup) # avoid int division
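get_lr_at_iter is defined elsewhere in this script; judging from the call sites below (an alpha in [0, 1] plus a warmup factor), it is presumably the usual linear ramp. A sketch under that assumption:

def get_lr_at_iter(alpha, lr_warmup_factor=1.0 / 3.0):
    # linearly ramp the LR multiplier from lr_warmup_factor up to 1.0
    return lr_warmup_factor * (1 - alpha) + alpha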
# TODO(zhreshold) losses?
rpn_cls_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
rpn_box_loss = mx.gluon.loss.HuberLoss(rho=1 / 9.) # == smoothl1
rpn_cls_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(
from_sigmoid=False
)
rpn_box_loss = mx.gluon.loss.HuberLoss(rho=1 / 9.0) # == smoothl1
rcnn_cls_loss = mx.gluon.loss.SoftmaxCrossEntropyLoss()
rcnn_box_loss = mx.gluon.loss.HuberLoss() # == smoothl1
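The `== smoothl1` comments follow from the usual parameterization: smooth-L1 with scale sigma switches from quadratic to linear at 1/sigma^2, and Gluon's HuberLoss switches at rho, so rho=1/9 matches the sigma=3 convention for RPN boxes while the default rho=1 matches sigma=1. A quick numeric check with toy values:

import mxnet as mx

box_loss = mx.gluon.loss.HuberLoss(rho=1 / 9.0)
pred = mx.nd.array([[0.0]])
label = mx.nd.array([[0.05]])  # |error| < 1/9, so the quadratic branch applies
print(box_loss(pred, label).asscalar())  # 0.5 / rho * 0.05**2 = 0.01125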
metrics = [mx.metric.Loss('RPN_Conf'),
mx.metric.Loss('RPN_SmoothL1'),
mx.metric.Loss('RCNN_CrossEntropy'),
mx.metric.Loss('RCNN_SmoothL1'), ]
metrics = [
mx.metric.Loss("RPN_Conf"),
mx.metric.Loss("RPN_SmoothL1"),
mx.metric.Loss("RCNN_CrossEntropy"),
mx.metric.Loss("RCNN_SmoothL1"),
]
rpn_acc_metric = RPNAccMetric()
rpn_bbox_metric = RPNL1LossMetric()
rcnn_acc_metric = RCNNAccMetric()
rcnn_bbox_metric = RCNNL1LossMetric()
metrics2 = [rpn_acc_metric, rpn_bbox_metric, rcnn_acc_metric, rcnn_bbox_metric]
metrics2 = [
rpn_acc_metric,
rpn_bbox_metric,
rcnn_acc_metric,
rcnn_bbox_metric,
]
# set up logger
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)
log_file_path = args.save_prefix + '_train.log'
log_file_path = args.save_prefix + "_train.log"
log_dir = os.path.dirname(log_file_path)
if log_dir and not os.path.exists(log_dir):
os.makedirs(log_dir)
......@@ -382,17 +635,28 @@ def train(net, train_data, val_data, eval_metric, batch_size, ctx, args):
logger.addHandler(fh)
logger.info(args)
if args.verbose:
logger.info('Trainable parameters:')
logger.info("Trainable parameters:")
logger.info(net.collect_train_params().keys())
logger.info('Start training from [Epoch {}]'.format(args.start_epoch))
logger.info("Start training from [Epoch {}]".format(args.start_epoch))
best_map = [0]
for epoch in range(args.start_epoch, args.epochs):
mix_ratio = 1.0
if not args.disable_hybridization:
net.hybridize(static_alloc=args.static_alloc)
rcnn_task = ForwardBackwardTask(net, trainer, rpn_cls_loss, rpn_box_loss, rcnn_cls_loss,
rcnn_box_loss, mix_ratio=1.0)
executor = Parallel(args.executor_threads, rcnn_task) if not args.horovod else None
rcnn_task = ForwardBackwardTask(
net,
trainer,
rpn_cls_loss,
rpn_box_loss,
rcnn_cls_loss,
rcnn_box_loss,
mix_ratio=1.0,
)
executor = (
Parallel(args.executor_threads, rcnn_task)
if not args.horovod
else None
)
if args.mixup:
# TODO(zhreshold) only support evenly mixup now, target generator needs to be modified otherwise
train_data._dataset._data.set_mixup(np.random.uniform, 0.5, 0.5)
......@@ -404,22 +668,29 @@ def train(net, train_data, val_data, eval_metric, batch_size, ctx, args):
new_lr = trainer.learning_rate * lr_decay
lr_steps.pop(0)
trainer.set_learning_rate(new_lr)
logger.info("[Epoch {}] Set learning rate to {}".format(epoch, new_lr))
logger.info(
"[Epoch {}] Set learning rate to {}".format(epoch, new_lr)
)
for metric in metrics:
metric.reset()
tic = time.time()
btic = time.time()
base_lr = trainer.learning_rate
rcnn_task.mix_ratio = mix_ratio
logger.info('Total Num of Batches: %d'%(len(train_data)))
logger.info("Total Num of Batches: %d" % (len(train_data)))
for i, batch in enumerate(train_data):
if epoch == 0 and i <= lr_warmup:
# adjust based on real percentage
new_lr = base_lr * get_lr_at_iter(i / lr_warmup, args.lr_warmup_factor)
new_lr = base_lr * get_lr_at_iter(
i / lr_warmup, args.lr_warmup_factor
)
if new_lr != trainer.learning_rate:
if i % args.log_interval == 0:
logger.info(
'[Epoch 0 Iteration {}] Set learning rate to {}'.format(i, new_lr))
"[Epoch 0 Iteration {}] Set learning rate to {}".format(
i, new_lr
)
)
trainer.set_learning_rate(new_lr)
batch = split_and_load(batch, ctx_list=ctx)
metric_losses = [[] for _ in metrics]
......@@ -445,34 +716,70 @@ def train(net, train_data, val_data, eval_metric, batch_size, ctx, args):
trainer.step(batch_size)
# update metrics
if (not args.horovod or hvd.rank() == 0) and args.log_interval \
and not (i + 1) % args.log_interval:
msg = ','.join(
['{}={:.3f}'.format(*metric.get()) for metric in metrics + metrics2])
logger.info('[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}'.format(
epoch, i, args.log_interval * args.batch_size / (time.time() - btic), msg))
if (
(not args.horovod or hvd.rank() == 0)
and args.log_interval
and not (i + 1) % args.log_interval
):
msg = ",".join(
[
"{}={:.3f}".format(*metric.get())
for metric in metrics + metrics2
]
)
logger.info(
"[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}".format(
epoch,
i,
args.log_interval
* args.batch_size
/ (time.time() - btic),
msg,
)
)
btic = time.time()
if (not args.horovod) or hvd.rank() == 0:
msg = ','.join(['{}={:.3f}'.format(*metric.get()) for metric in metrics])
logger.info('[Epoch {}] Training cost: {:.3f}, {}'.format(
epoch, (time.time() - tic), msg))
msg = ",".join(
["{}={:.3f}".format(*metric.get()) for metric in metrics]
)
logger.info(
"[Epoch {}] Training cost: {:.3f}, {}".format(
epoch, (time.time() - tic), msg
)
)
if not (epoch + 1) % args.val_interval:
# consider reduce the frequency of validation to save time
if val_data is not None:
map_name, mean_ap = validate(net, val_data, ctx, eval_metric, args)
val_msg = '\n'.join(['{}={}'.format(k, v) for k, v in zip(map_name, mean_ap)])
logger.info('[Epoch {}] Validation: \n{}'.format(epoch, val_msg))
map_name, mean_ap = validate(
net, val_data, ctx, eval_metric, args
)
val_msg = "\n".join(
[
"{}={}".format(k, v)
for k, v in zip(map_name, mean_ap)
]
)
logger.info(
"[Epoch {}] Validation: \n{}".format(epoch, val_msg)
)
current_map = float(mean_ap[-1])
else:
current_map = 0
else:
current_map = 0.
save_params(net, logger, best_map, current_map, epoch, args.save_interval,
args.save_prefix)
if __name__ == '__main__':
current_map = 0.0
save_params(
net,
logger,
best_map,
current_map,
epoch,
args.save_interval,
args.save_prefix,
)
if __name__ == "__main__":
import sys
sys.setrecursionlimit(1100)
......@@ -487,26 +794,31 @@ if __name__ == '__main__':
if args.horovod:
ctx = [mx.gpu(hvd.local_rank())]
else:
ctx = [mx.gpu(int(i)) for i in args.gpus.split(',') if i.strip()]
ctx = [mx.gpu(int(i)) for i in args.gpus.split(",") if i.strip()]
ctx = ctx if ctx else [mx.cpu()]
# network
kwargs = {}
module_list = []
if args.use_fpn:
module_list.append('fpn')
module_list.append("fpn")
if args.norm_layer is not None:
module_list.append(args.norm_layer)
if args.norm_layer == 'bn':
kwargs['num_devices'] = len(args.gpus.split(','))
if args.norm_layer == "bn":
kwargs["num_devices"] = len(args.gpus.split(","))
net_name = '_'.join(('faster_rcnn', *module_list, args.network, 'custom'))
net_name = "_".join(("faster_rcnn", *module_list, args.network, "custom"))
args.save_prefix += net_name
gutils.makedirs(args.save_prefix)
train_dataset, val_dataset, eval_metric = get_dataset(args.dataset, args)
net = faster_rcnn_resnet101_v1d_custom(classes=train_dataset.classes, transfer='coco',
pretrained_base=False, additional_output=False,
per_device_batch_size=args.batch_size // len(ctx), **kwargs)
net = faster_rcnn_resnet101_v1d_custom(
classes=train_dataset.classes,
transfer="coco",
pretrained_base=False,
additional_output=False,
per_device_batch_size=args.batch_size // len(ctx),
**kwargs
)
if args.resume.strip():
net.load_parameters(args.resume.strip())
else:
......@@ -517,10 +829,19 @@ if __name__ == '__main__':
net.collect_params().reset_ctx(ctx)
# training data
batch_size = args.batch_size // len(ctx) if args.horovod else args.batch_size
batch_size = (
args.batch_size // len(ctx) if args.horovod else args.batch_size
)
train_data, val_data = get_dataloader(
net, train_dataset, val_dataset, FasterRCNNDefaultTrainTransform,
FasterRCNNDefaultValTransform, batch_size, len(ctx), args)
net,
train_dataset,
val_dataset,
FasterRCNNDefaultTrainTransform,
FasterRCNNDefaultValTransform,
batch_size,
len(ctx),
args,
)
# training
train(net, train_data, val_data, eval_metric, batch_size, ctx, args)
import argparse
import json
import os
import pickle
import numpy as np
import json, pickle, os, argparse
def parse_args():
    parser = argparse.ArgumentParser(description='Train the Frequency Prior For RelDN.')
parser.add_argument('--overlap', action='store_true',
help="Only count overlap boxes.")
    parser.add_argument('--json-path', type=str, default='~/.mxnet/datasets/visualgenome',
                        help="Path to the Visual Genome annotation JSON files.")
parser = argparse.ArgumentParser(
description="Train the Frequenct Prior For RelDN."
)
parser.add_argument(
"--overlap", action="store_true", help="Only count overlap boxes."
)
parser.add_argument(
"--json-path",
type=str,
default="~/.mxnet/datasets/visualgenome",
help="Only count overlap boxes.",
)
args = parser.parse_args()
return args
args = parse_args()
use_overlap = args.overlap
PATH_TO_DATASETS = os.path.expanduser(args.json_path)
path_to_json = os.path.join(PATH_TO_DATASETS, 'rel_annotations_train.json')
path_to_json = os.path.join(PATH_TO_DATASETS, "rel_annotations_train.json")
# format in y1y2x1x2
def with_overlap(boxA, boxB):
......@@ -29,17 +42,19 @@ def with_overlap(boxA, boxB):
return 0
def box_ious(boxes):
n = len(boxes)
res = np.zeros((n, n))
for i in range(n-1):
for j in range(i+1, n):
for i in range(n - 1):
for j in range(i + 1, n):
iou_val = with_overlap(boxes[i], boxes[j])
res[i, j] = iou_val
res[j, i] = iou_val
return res
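Assuming the elided with_overlap returns the IoU of two overlapping boxes and 0 otherwise (boxes in the y1y2x1x2 layout noted above), box_ious fills a symmetric matrix with a zero diagonal. A hypothetical call:

boxes = [
    [0, 10, 0, 10],    # y1, y2, x1, x2
    [5, 15, 5, 15],    # overlaps the first box
    [20, 30, 20, 30],  # disjoint from both
]
ious = box_ious(boxes)
# ious[0, 1] == ious[1, 0] > 0; ious[0, 2] == 0; the diagonal stays 0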
with open(path_to_json, 'r') as f:
with open(path_to_json, "r") as f:
tmp = f.read()
train_data = json.loads(tmp)
......@@ -49,11 +64,11 @@ bg_matrix = np.zeros((150, 150), dtype=np.int64)
for _, item in train_data.items():
gt_box_to_label = {}
for rel in item:
sub_bbox = rel['subject']['bbox']
ob_bbox = rel['object']['bbox']
sub_class = rel['subject']['category']
ob_class = rel['object']['category']
rel_class = rel['predicate']
sub_bbox = rel["subject"]["bbox"]
ob_bbox = rel["object"]["bbox"]
sub_class = rel["subject"]["category"]
ob_class = rel["object"]["category"]
rel_class = rel["predicate"]
sub_node = tuple(sub_bbox)
ob_node = tuple(ob_bbox)
......@@ -93,8 +108,8 @@ pred_dist = np.log(fg_matrix / (fg_matrix.sum(2)[:, :, None] + eps) + eps)
if use_overlap:
with open('freq_prior_overlap.pkl', 'wb') as f:
with open("freq_prior_overlap.pkl", "wb") as f:
pickle.dump(pred_dist, f)
else:
with open('freq_prior.pkl', 'wb') as f:
with open("freq_prior.pkl", "wb") as f:
pickle.dump(pred_dist, f)
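For orientation, the pickled array is indexed by (subject class, object class) with a trailing predicate dimension and holds log-probabilities. A consumer would load it roughly like this (sketch; class indices are made up, file name as produced by the non-overlap branch):

import pickle

import numpy as np

with open("freq_prior.pkl", "rb") as f:
    pred_dist = pickle.load(f)  # shape (150, 150, n_predicates)

sub_class, ob_class = 3, 17
log_prior = pred_dist[sub_class, ob_class]
# for a populated class pair this is ~1, up to the eps smoothing
print(np.exp(log_prior).sum())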
import dgl
import argparse
import logging
import time
import mxnet as mx
import numpy as np
import logging, time, argparse
from mxnet import nd, gluon
from data import *
from gluoncv.data.batchify import Pad
from gluoncv.utils import makedirs
from model import faster_rcnn_resnet101_v1d_custom, RelDN
from model import RelDN, faster_rcnn_resnet101_v1d_custom
from mxnet import gluon, nd
from utils import *
from data import *
import dgl
def parse_args():
parser = argparse.ArgumentParser(description='Train RelDN Model.')
    parser.add_argument('--gpus', type=str, default='0',
                        help="GPU ids to train with; specify e.g. 1,3.")
parser.add_argument('--batch-size', type=int, default=8,
help="Total batch-size for training.")
parser.add_argument('--epochs', type=int, default=9,
help="Training epochs.")
parser.add_argument('--lr-reldn', type=float, default=0.01,
help="Learning rate for RelDN module.")
parser.add_argument('--wd-reldn', type=float, default=0.0001,
help="Weight decay for RelDN module.")
parser.add_argument('--lr-faster-rcnn', type=float, default=0.01,
help="Learning rate for Faster R-CNN module.")
parser.add_argument('--wd-faster-rcnn', type=float, default=0.0001,
help="Weight decay for RelDN module.")
parser.add_argument('--lr-decay-epochs', type=str, default='5,8',
help="Learning rate decay points.")
parser.add_argument('--lr-warmup-iters', type=int, default=4000,
help="Learning rate warm-up iterations.")
parser.add_argument('--save-dir', type=str, default='params_resnet101_v1d_reldn',
help="Path to save model parameters.")
parser.add_argument('--log-dir', type=str, default='reldn_output.log',
help="Path to save training logs.")
parser.add_argument('--pretrained-faster-rcnn-params', type=str, required=True,
help="Path to saved Faster R-CNN model parameters.")
parser.add_argument('--freq-prior', type=str, default='freq_prior.pkl',
help="Path to saved frequency prior data.")
parser.add_argument('--verbose-freq', type=int, default=100,
help="Frequency of log printing in number of iterations.")
parser = argparse.ArgumentParser(description="Train RelDN Model.")
parser.add_argument(
"--gpus",
type=str,
default="0",
help="Training with GPUs, you can specify 1,3 for example.",
)
parser.add_argument(
"--batch-size",
type=int,
default=8,
help="Total batch-size for training.",
)
parser.add_argument(
"--epochs", type=int, default=9, help="Training epochs."
)
parser.add_argument(
"--lr-reldn",
type=float,
default=0.01,
help="Learning rate for RelDN module.",
)
parser.add_argument(
"--wd-reldn",
type=float,
default=0.0001,
help="Weight decay for RelDN module.",
)
parser.add_argument(
"--lr-faster-rcnn",
type=float,
default=0.01,
help="Learning rate for Faster R-CNN module.",
)
parser.add_argument(
"--wd-faster-rcnn",
type=float,
default=0.0001,
help="Weight decay for RelDN module.",
)
parser.add_argument(
"--lr-decay-epochs",
type=str,
default="5,8",
help="Learning rate decay points.",
)
parser.add_argument(
"--lr-warmup-iters",
type=int,
default=4000,
help="Learning rate warm-up iterations.",
)
parser.add_argument(
"--save-dir",
type=str,
default="params_resnet101_v1d_reldn",
help="Path to save model parameters.",
)
parser.add_argument(
"--log-dir",
type=str,
default="reldn_output.log",
help="Path to save training logs.",
)
parser.add_argument(
"--pretrained-faster-rcnn-params",
type=str,
required=True,
help="Path to saved Faster R-CNN model parameters.",
)
parser.add_argument(
"--freq-prior",
type=str,
default="freq_prior.pkl",
help="Path to saved frequency prior data.",
)
parser.add_argument(
"--verbose-freq",
type=int,
default=100,
help="Frequency of log printing in number of iterations.",
)
args = parser.parse_args()
return args
args = parse_args()
filehandler = logging.FileHandler(args.log_dir)
streamhandler = logging.StreamHandler()
logger = logging.getLogger('')
logger = logging.getLogger("")
logger.setLevel(logging.INFO)
logger.addHandler(filehandler)
logger.addHandler(streamhandler)
# Hyperparams
ctx = [mx.gpu(int(i)) for i in args.gpus.split(',') if i.strip()]
ctx = [mx.gpu(int(i)) for i in args.gpus.split(",") if i.strip()]
if ctx:
num_gpus = len(ctx)
assert args.batch_size % num_gpus == 0
......@@ -71,13 +129,18 @@ N_objects = 150
save_dir = args.save_dir
makedirs(save_dir)
batch_verbose_freq = args.verbose_freq
lr_decay_epochs = [int(i) for i in args.lr_decay_epochs.split(',')]
lr_decay_epochs = [int(i) for i in args.lr_decay_epochs.split(",")]
# Dataset and dataloader
vg_train = VGRelation(split='train')
logger.info('data loaded!')
train_data = gluon.data.DataLoader(vg_train, batch_size=len(ctx), shuffle=True, num_workers=8*num_gpus,
batchify_fn=dgl_mp_batchify_fn)
vg_train = VGRelation(split="train")
logger.info("data loaded!")
train_data = gluon.data.DataLoader(
vg_train,
batch_size=len(ctx),
shuffle=True,
num_workers=8 * num_gpus,
batchify_fn=dgl_mp_batchify_fn,
)
n_batches = len(train_data)
# Network definition
......@@ -85,30 +148,47 @@ net = RelDN(n_classes=N_relations, prior_pkl=args.freq_prior)
net.spatial.initialize(mx.init.Normal(1e-4), ctx=ctx)
net.visual.initialize(mx.init.Normal(1e-4), ctx=ctx)
for k, v in net.collect_params().items():
v.grad_req = 'add' if aggregate_grad else 'write'
v.grad_req = "add" if aggregate_grad else "write"
net_params = net.collect_params()
net_trainer = gluon.Trainer(net.collect_params(), 'adam',
{'learning_rate': args.lr_reldn, 'wd': args.wd_reldn})
net_trainer = gluon.Trainer(
net.collect_params(),
"adam",
{"learning_rate": args.lr_reldn, "wd": args.wd_reldn},
)
det_params_path = args.pretrained_faster_rcnn_params
detector = faster_rcnn_resnet101_v1d_custom(classes=vg_train.obj_classes,
pretrained_base=False, pretrained=False,
additional_output=True)
detector.load_parameters(det_params_path, ctx=ctx, ignore_extra=True, allow_missing=True)
detector = faster_rcnn_resnet101_v1d_custom(
classes=vg_train.obj_classes,
pretrained_base=False,
pretrained=False,
additional_output=True,
)
detector.load_parameters(
det_params_path, ctx=ctx, ignore_extra=True, allow_missing=True
)
for k, v in detector.collect_params().items():
v.grad_req = 'null'
v.grad_req = "null"
detector_feat = faster_rcnn_resnet101_v1d_custom(classes=vg_train.obj_classes,
pretrained_base=False, pretrained=False,
additional_output=True)
detector_feat.load_parameters(det_params_path, ctx=ctx, ignore_extra=True, allow_missing=True)
detector_feat = faster_rcnn_resnet101_v1d_custom(
classes=vg_train.obj_classes,
pretrained_base=False,
pretrained=False,
additional_output=True,
)
detector_feat.load_parameters(
det_params_path, ctx=ctx, ignore_extra=True, allow_missing=True
)
for k, v in detector_feat.collect_params().items():
v.grad_req = 'null'
v.grad_req = "null"
for k, v in detector_feat.features.collect_params().items():
v.grad_req = 'add' if aggregate_grad else 'write'
v.grad_req = "add" if aggregate_grad else "write"
det_params = detector_feat.features.collect_params()
det_trainer = gluon.Trainer(detector_feat.features.collect_params(), 'adam',
{'learning_rate': args.lr_faster_rcnn, 'wd': args.wd_faster_rcnn})
det_trainer = gluon.Trainer(
detector_feat.features.collect_params(),
"adam",
{"learning_rate": args.lr_faster_rcnn, "wd": args.wd_faster_rcnn},
)
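Two detector copies are kept deliberately: detector stays fully frozen (every grad_req is "null") and only produces proposals and object scores, while detector_feat re-extracts ROI features with a trainable backbone, so only its features receive gradients through det_trainer. A quick sanity sketch of that split:

frozen = [k for k, v in detector.collect_params().items() if v.grad_req != "null"]
trainable = [
    k
    for k, v in detector_feat.features.collect_params().items()
    if v.grad_req != "null"
]
assert not frozen and trainable  # proposals frozen, edge features trainable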
def get_data_batch(g_list, img_list, ctx_list):
if g_list is None or len(g_list) == 0:
......@@ -118,37 +198,58 @@ def get_data_batch(g_list, img_list, ctx_list):
if size < n_gpu:
raise Exception("too small batch")
step = size // n_gpu
G_list = [g_list[i*step:(i+1)*step] if i < n_gpu - 1 else g_list[i*step:size] for i in range(n_gpu)]
img_list = [img_list[i*step:(i+1)*step] if i < n_gpu - 1 else img_list[i*step:size] for i in range(n_gpu)]
G_list = [
g_list[i * step : (i + 1) * step]
if i < n_gpu - 1
else g_list[i * step : size]
for i in range(n_gpu)
]
img_list = [
img_list[i * step : (i + 1) * step]
if i < n_gpu - 1
else img_list[i * step : size]
for i in range(n_gpu)
]
for G_slice, ctx in zip(G_list, ctx_list):
for G in G_slice:
G.ndata['bbox'] = G.ndata['bbox'].as_in_context(ctx)
G.ndata['node_class'] = G.ndata['node_class'].as_in_context(ctx)
G.ndata['node_class_vec'] = G.ndata['node_class_vec'].as_in_context(ctx)
G.edata['rel_class'] = G.edata['rel_class'].as_in_context(ctx)
G.ndata["bbox"] = G.ndata["bbox"].as_in_context(ctx)
G.ndata["node_class"] = G.ndata["node_class"].as_in_context(ctx)
G.ndata["node_class_vec"] = G.ndata["node_class_vec"].as_in_context(
ctx
)
G.edata["rel_class"] = G.edata["rel_class"].as_in_context(ctx)
img_list = [img.as_in_context(ctx) for img in img_list]
return G_list, img_list
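The slicing above gives the first n_gpu - 1 devices size // n_gpu graphs each and lets the last device absorb the remainder. The same arithmetic in a standalone sketch (helper name hypothetical):

def split_uneven(items, n_gpu):
    # first n_gpu - 1 shards get size // n_gpu items; the last takes the rest
    step = len(items) // n_gpu
    return [
        items[i * step : (i + 1) * step] if i < n_gpu - 1 else items[i * step :]
        for i in range(n_gpu)
    ]

print(split_uneven(list(range(5)), 2))  # [[0, 1], [2, 3, 4]]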
L_rel = gluon.loss.SoftmaxCELoss()
train_metric = mx.metric.Accuracy(name='rel_acc')
train_metric_top5 = mx.metric.TopKAccuracy(5, name='rel_acc_top5')
train_metric = mx.metric.Accuracy(name="rel_acc")
train_metric_top5 = mx.metric.TopKAccuracy(5, name="rel_acc_top5")
metric_list = [train_metric, train_metric_top5]
def batch_print(epoch, i, batch_verbose_freq, n_batches, btic, loss_rel_val, metric_list):
if (i+1) % batch_verbose_freq == 0:
print_txt = 'Epoch[%d] Batch[%d/%d], time: %d, loss_rel=%.4f '%\
(epoch, i, n_batches, int(time.time() - btic),
loss_rel_val / (i+1), )
def batch_print(
epoch, i, batch_verbose_freq, n_batches, btic, loss_rel_val, metric_list
):
if (i + 1) % batch_verbose_freq == 0:
print_txt = "Epoch[%d] Batch[%d/%d], time: %d, loss_rel=%.4f " % (
epoch,
i,
n_batches,
int(time.time() - btic),
loss_rel_val / (i + 1),
)
for metric in metric_list:
metric_name, metric_val = metric.get()
print_txt += '%s=%.4f '%(metric_name, metric_val)
print_txt += "%s=%.4f " % (metric_name, metric_val)
logger.info(print_txt)
btic = time.time()
loss_rel_val = 0
return btic, loss_rel_val
for epoch in range(nepoch):
loss_rel_val = 0
tic = time.time()
......@@ -159,17 +260,25 @@ for epoch in range(nepoch):
net_trainer_base_lr = net_trainer.learning_rate
det_trainer_base_lr = det_trainer.learning_rate
if epoch == 5 or epoch == 8:
net_trainer.set_learning_rate(net_trainer.learning_rate*0.1)
det_trainer.set_learning_rate(det_trainer.learning_rate*0.1)
net_trainer.set_learning_rate(net_trainer.learning_rate * 0.1)
det_trainer.set_learning_rate(det_trainer.learning_rate * 0.1)
for i, (G_list, img_list) in enumerate(train_data):
if epoch == 0 and i < args.lr_warmup_iters:
alpha = i / args.lr_warmup_iters
warmup_factor = 1/3 * (1 - alpha) + alpha
net_trainer.set_learning_rate(net_trainer_base_lr*warmup_factor)
det_trainer.set_learning_rate(det_trainer_base_lr*warmup_factor)
warmup_factor = 1 / 3 * (1 - alpha) + alpha
net_trainer.set_learning_rate(net_trainer_base_lr * warmup_factor)
det_trainer.set_learning_rate(det_trainer_base_lr * warmup_factor)
G_list, img_list = get_data_batch(G_list, img_list, ctx)
if G_list is None or img_list is None:
btic, loss_rel_val = batch_print(epoch, i, batch_verbose_freq, n_batches, btic, loss_rel_val, metric_list)
btic, loss_rel_val = batch_print(
epoch,
i,
batch_verbose_freq,
n_batches,
btic,
loss_rel_val,
metric_list,
)
continue
loss = []
......@@ -179,17 +288,29 @@ for epoch in range(nepoch):
with mx.autograd.record():
for G_slice, img in zip(G_list, img_list):
cur_ctx = img.context
bbox_list = [G.ndata['bbox'] for G in G_slice]
bbox_list = [G.ndata["bbox"] for G in G_slice]
bbox_stack = bbox_pad(bbox_list).as_in_context(cur_ctx)
with mx.autograd.pause():
ids, scores, bbox, feat, feat_ind, spatial_feat = detector(img)
g_pred_batch = build_graph_train(G_slice, bbox_stack, img, ids, scores, bbox, feat_ind,
spatial_feat, scores_top_k=300, overlap=False)
ids, scores, bbox, feat, feat_ind, spatial_feat = detector(
img
)
g_pred_batch = build_graph_train(
G_slice,
bbox_stack,
img,
ids,
scores,
bbox,
feat_ind,
spatial_feat,
scores_top_k=300,
overlap=False,
)
g_batch = l0_sample(g_pred_batch)
if g_batch is None:
continue
rel_bbox = g_batch.edata['rel_bbox']
batch_id = g_batch.edata['batch_id'].asnumpy()
rel_bbox = g_batch.edata["rel_bbox"]
batch_id = g_batch.edata["batch_id"].asnumpy()
n_sample_edges = g_batch.number_of_edges()
n_graph = len(G_slice)
bbox_rel_list = []
......@@ -203,13 +324,19 @@ for epoch in range(nepoch):
bbox_rel_stack[:, :, 1] *= img_size[0]
bbox_rel_stack[:, :, 2] *= img_size[1]
bbox_rel_stack[:, :, 3] *= img_size[0]
_, _, _, spatial_feat_rel = detector_feat(img, None, None, bbox_rel_stack)
_, _, _, spatial_feat_rel = detector_feat(
img, None, None, bbox_rel_stack
)
spatial_feat_rel_list = []
for j in range(n_graph):
eids = np.where(batch_id == j)[0]
if len(eids) > 0:
spatial_feat_rel_list.append(spatial_feat_rel[j, 0:len(eids)])
g_batch.edata['edge_feat'] = nd.concat(*spatial_feat_rel_list, dim=0)
spatial_feat_rel_list.append(
spatial_feat_rel[j, 0 : len(eids)]
)
g_batch.edata["edge_feat"] = nd.concat(
*spatial_feat_rel_list, dim=0
)
G_batch.append(g_batch)
......@@ -218,17 +345,28 @@ for epoch in range(nepoch):
for G_pred, img in zip(G_batch, img_list):
if G_pred is None or G_pred.number_of_nodes() == 0:
continue
loss_rel = L_rel(G_pred.edata['preds'], G_pred.edata['rel_class'],
G_pred.edata['sample_weights'])
loss_rel = L_rel(
G_pred.edata["preds"],
G_pred.edata["rel_class"],
G_pred.edata["sample_weights"],
)
loss.append(loss_rel.sum())
loss_rel_val += loss_rel.mean().asscalar() / num_gpus
if len(loss) == 0:
btic, loss_rel_val = batch_print(epoch, i, batch_verbose_freq, n_batches, btic, loss_rel_val, metric_list)
btic, loss_rel_val = batch_print(
epoch,
i,
batch_verbose_freq,
n_batches,
btic,
loss_rel_val,
metric_list,
)
continue
for l in loss:
l.backward()
if (i+1) % per_device_batch_size == 0 or i == n_batches - 1:
if (i + 1) % per_device_batch_size == 0 or i == n_batches - 1:
net_trainer.step(args.batch_size)
det_trainer.step(args.batch_size)
if aggregate_grad:
......@@ -239,23 +377,41 @@ for epoch in range(nepoch):
for G_pred, img_slice in zip(G_batch, img_list):
if G_pred is None or G_pred.number_of_nodes() == 0:
continue
link_ind = np.where(G_pred.edata['rel_class'].asnumpy() > 0)[0]
link_ind = np.where(G_pred.edata["rel_class"].asnumpy() > 0)[0]
if len(link_ind) == 0:
continue
train_metric.update([G_pred.edata['rel_class'][link_ind]],
[G_pred.edata['preds'][link_ind]])
train_metric_top5.update([G_pred.edata['rel_class'][link_ind]],
[G_pred.edata['preds'][link_ind]])
btic, loss_rel_val = batch_print(epoch, i, batch_verbose_freq, n_batches, btic, loss_rel_val, metric_list)
if (i+1) % batch_verbose_freq == 0:
net.save_parameters('%s/model-%d.params'%(save_dir, epoch))
detector_feat.features.save_parameters('%s/detector_feat.features-%d.params'%(save_dir, epoch))
print_txt = 'Epoch[%d], time: %d, loss_rel=%.4f,'%\
(epoch, int(time.time() - tic),
loss_rel_val / (i+1))
train_metric.update(
[G_pred.edata["rel_class"][link_ind]],
[G_pred.edata["preds"][link_ind]],
)
train_metric_top5.update(
[G_pred.edata["rel_class"][link_ind]],
[G_pred.edata["preds"][link_ind]],
)
btic, loss_rel_val = batch_print(
epoch,
i,
batch_verbose_freq,
n_batches,
btic,
loss_rel_val,
metric_list,
)
if (i + 1) % batch_verbose_freq == 0:
net.save_parameters("%s/model-%d.params" % (save_dir, epoch))
detector_feat.features.save_parameters(
"%s/detector_feat.features-%d.params" % (save_dir, epoch)
)
print_txt = "Epoch[%d], time: %d, loss_rel=%.4f," % (
epoch,
int(time.time() - tic),
loss_rel_val / (i + 1),
)
for metric in metric_list:
metric_name, metric_val = metric.get()
print_txt += '%s=%.4f '%(metric_name, metric_val)
print_txt += "%s=%.4f " % (metric_name, metric_val)
logger.info(print_txt)
net.save_parameters('%s/model-%d.params'%(save_dir, epoch))
detector_feat.features.save_parameters('%s/detector_feat.features-%d.params'%(save_dir, epoch))
net.save_parameters("%s/model-%d.params" % (save_dir, epoch))
detector_feat.features.save_parameters(
"%s/detector_feat.features-%d.params" % (save_dir, epoch)
)
from .metric import *
from .build_graph import *
from .metric import *
from .sampling import *
from .viz import *
import dgl
from mxnet import nd
import numpy as np
from mxnet import nd
import dgl
def bbox_improve(bbox):
'''bbox encoding'''
area = (bbox[:,2] - bbox[:,0]) * (bbox[:,3] - bbox[:,1])
"""bbox encoding"""
area = (bbox[:, 2] - bbox[:, 0]) * (bbox[:, 3] - bbox[:, 1])
return nd.concat(bbox, area.expand_dims(1))
def extract_edge_bbox(g):
    '''union box enclosing an edge's source and destination boxes'''
src, dst = g.edges(order='eid')
"""bbox encoding"""
src, dst = g.edges(order="eid")
n = g.number_of_edges()
src_bbox = g.ndata['pred_bbox'][src.asnumpy()]
dst_bbox = g.ndata['pred_bbox'][dst.asnumpy()]
edge_bbox = nd.zeros((n, 4), ctx=g.ndata['pred_bbox'].context)
edge_bbox[:,0] = nd.stack(src_bbox[:,0], dst_bbox[:,0]).min(axis=0)
edge_bbox[:,1] = nd.stack(src_bbox[:,1], dst_bbox[:,1]).min(axis=0)
edge_bbox[:,2] = nd.stack(src_bbox[:,2], dst_bbox[:,2]).max(axis=0)
edge_bbox[:,3] = nd.stack(src_bbox[:,3], dst_bbox[:,3]).max(axis=0)
src_bbox = g.ndata["pred_bbox"][src.asnumpy()]
dst_bbox = g.ndata["pred_bbox"][dst.asnumpy()]
edge_bbox = nd.zeros((n, 4), ctx=g.ndata["pred_bbox"].context)
edge_bbox[:, 0] = nd.stack(src_bbox[:, 0], dst_bbox[:, 0]).min(axis=0)
edge_bbox[:, 1] = nd.stack(src_bbox[:, 1], dst_bbox[:, 1]).min(axis=0)
edge_bbox[:, 2] = nd.stack(src_bbox[:, 2], dst_bbox[:, 2]).max(axis=0)
edge_bbox[:, 3] = nd.stack(src_bbox[:, 3], dst_bbox[:, 3]).max(axis=0)
return edge_bbox
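Together the two helpers turn node boxes into 5-dim features (x1, y1, x2, y2, area) and give every edge the tightest box enclosing both endpoint boxes. The enclosing-box arithmetic, restated in plain NumPy for a single edge:

import numpy as np

src = np.array([0.0, 0.0, 2.0, 2.0])  # x1, y1, x2, y2
dst = np.array([1.0, 1.0, 4.0, 3.0])
union = np.concatenate(
    [np.minimum(src[:2], dst[:2]), np.maximum(src[2:], dst[2:])]
)
print(union)  # [0. 0. 4. 3.]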
def build_graph_train(g_slice, gt_bbox, img, ids, scores, bbox, feat_ind,
spatial_feat, iou_thresh=0.5,
bbox_improvement=True, scores_top_k=50, overlap=False):
'''given ground truth and predicted bboxes, assign the label to the predicted w.r.t iou_thresh'''
def build_graph_train(
g_slice,
gt_bbox,
img,
ids,
scores,
bbox,
feat_ind,
spatial_feat,
iou_thresh=0.5,
bbox_improvement=True,
scores_top_k=50,
overlap=False,
):
"""given ground truth and predicted bboxes, assign the label to the predicted w.r.t iou_thresh"""
# match and re-factor the graph
img_size = img.shape[2:4]
gt_bbox[:, :, 0] /= img_size[1]
......@@ -39,24 +54,33 @@ def build_graph_train(g_slice, gt_bbox, img, ids, scores, bbox, feat_ind,
g_pred_batch = []
for gi in range(n_graph):
g = g_slice[gi]
ctx = g.ndata['bbox'].context
ctx = g.ndata["bbox"].context
inds = np.where(scores[gi, :, 0].asnumpy() > 0)[0].tolist()
if len(inds) == 0:
return None
if len(inds) > scores_top_k:
top_score_inds = scores[gi, inds, 0].asnumpy().argsort()[::-1][0:scores_top_k]
top_score_inds = (
scores[gi, inds, 0].asnumpy().argsort()[::-1][0:scores_top_k]
)
inds = np.array(inds)[top_score_inds].tolist()
n_nodes = len(inds)
roi_ind = feat_ind[gi, inds].squeeze(axis=1)
g_pred = dgl.DGLGraph()
g_pred.add_nodes(n_nodes, {'pred_bbox': bbox[gi, inds],
'node_feat': spatial_feat[gi, roi_ind],
'node_class_pred': ids[gi, inds, 0],
'node_class_logit': nd.log(scores[gi, inds, 0] + 1e-7)})
g_pred.add_nodes(
n_nodes,
{
"pred_bbox": bbox[gi, inds],
"node_feat": spatial_feat[gi, roi_ind],
"node_class_pred": ids[gi, inds, 0],
"node_class_logit": nd.log(scores[gi, inds, 0] + 1e-7),
},
)
# iou matching
ious = nd.contrib.box_iou(gt_bbox[gi], g_pred.ndata['pred_bbox']).asnumpy()
ious = nd.contrib.box_iou(
gt_bbox[gi], g_pred.ndata["pred_bbox"]
).asnumpy()
H, W = ious.shape
h = H
w = W
......@@ -70,8 +94,8 @@ def build_graph_train(g_slice, gt_bbox, img, ids, scores, bbox, feat_ind,
if ious[row_ind, col_ind] < iou_thresh:
break
pred_to_gt_ind[col_ind] = row_ind
gt_node_class = g.ndata['node_class'][row_ind]
pred_node_class = g_pred.ndata['node_class_pred'][col_ind]
gt_node_class = g.ndata["node_class"][row_ind]
pred_node_class = g_pred.ndata["node_class_pred"][col_ind]
if gt_node_class == pred_node_class:
pred_to_gt_class_match[col_ind] = 1
pred_to_gt_class_match_id[col_ind] = row_ind
......@@ -84,7 +108,7 @@ def build_graph_train(g_slice, gt_bbox, img, ids, scores, bbox, feat_ind,
triplet = []
adjmat = np.zeros((n_nodes, n_nodes))
src, dst = g.all_edges(order='eid')
src, dst = g.all_edges(order="eid")
eid_keys = np.column_stack([src.asnumpy(), dst.asnumpy()])
eid_dict = {}
for i, key in enumerate(eid_keys):
......@@ -93,7 +117,7 @@ def build_graph_train(g_slice, gt_bbox, img, ids, scores, bbox, feat_ind,
eid_dict[k] = [i]
else:
eid_dict[k].append(i)
ori_rel_class = g.edata['rel_class'].asnumpy()
ori_rel_class = g.edata["rel_class"].asnumpy()
for i in range(n_nodes):
for j in range(n_nodes):
if i != j:
......@@ -105,25 +129,27 @@ def build_graph_train(g_slice, gt_bbox, img, ids, scores, bbox, feat_ind,
n_edges_between = len(rel_cls)
for ii in range(n_edges_between):
triplet.append((i, j, rel_cls[ii]))
adjmat[i,j] = 1
adjmat[i, j] = 1
else:
triplet.append((i, j, 0))
src, dst, rel_class = tuple(zip(*triplet))
rel_class = nd.array(rel_class, ctx=ctx).expand_dims(1)
g_pred.add_edges(src, dst, data={'rel_class': rel_class})
g_pred.add_edges(src, dst, data={"rel_class": rel_class})
# other operations
n_nodes = g_pred.number_of_nodes()
n_edges = g_pred.number_of_edges()
if bbox_improvement:
g_pred.ndata['pred_bbox'] = bbox_improve(g_pred.ndata['pred_bbox'])
g_pred.edata['rel_bbox'] = extract_edge_bbox(g_pred)
g_pred.edata['batch_id'] = nd.zeros((n_edges, 1), ctx = ctx) + gi
g_pred.ndata["pred_bbox"] = bbox_improve(g_pred.ndata["pred_bbox"])
g_pred.edata["rel_bbox"] = extract_edge_bbox(g_pred)
g_pred.edata["batch_id"] = nd.zeros((n_edges, 1), ctx=ctx) + gi
# remove non-overlapping edges
if overlap:
overlap_ious = nd.contrib.box_iou(g_pred.ndata['pred_bbox'][:,0:4],
g_pred.ndata['pred_bbox'][:,0:4]).asnumpy()
overlap_ious = nd.contrib.box_iou(
g_pred.ndata["pred_bbox"][:, 0:4],
g_pred.ndata["pred_bbox"][:, 0:4],
).asnumpy()
cols, rows = np.where(overlap_ious <= 1e-7)
if cols.shape[0] > 0:
eids = g_pred.edge_ids(cols, rows)[2].asnumpy().tolist()
......@@ -132,15 +158,17 @@ def build_graph_train(g_slice, gt_bbox, img, ids, scores, bbox, feat_ind,
if g_pred.number_of_edges() == 0:
g_pred = None
g_pred_batch.append(g_pred)
if n_graph > 1:
return dgl.batch(g_pred_batch)
else:
return g_pred_batch[0]
def build_graph_validate_gt_obj(img, gt_ids, bbox, spatial_feat,
bbox_improvement=True, overlap=False):
'''given ground truth bbox and label, build graph for validation'''
def build_graph_validate_gt_obj(
img, gt_ids, bbox, spatial_feat, bbox_improvement=True, overlap=False
):
"""given ground truth bbox and label, build graph for validation"""
n_batch = img.shape[0]
img_size = img.shape[2:4]
bbox[:, :, 0] /= img_size[1]
......@@ -156,10 +184,17 @@ def build_graph_validate_gt_obj(img, gt_ids, bbox, spatial_feat,
continue
n_nodes = len(inds)
g_pred = dgl.DGLGraph()
g_pred.add_nodes(n_nodes, {'pred_bbox': bbox[btc, inds],
'node_feat': spatial_feat[btc, inds],
'node_class_pred': gt_ids[btc, inds, 0],
'node_class_logit': nd.zeros_like(gt_ids[btc, inds, 0], ctx=ctx)})
g_pred.add_nodes(
n_nodes,
{
"pred_bbox": bbox[btc, inds],
"node_feat": spatial_feat[btc, inds],
"node_class_pred": gt_ids[btc, inds, 0],
"node_class_logit": nd.zeros_like(
gt_ids[btc, inds, 0], ctx=ctx
),
},
)
edge_list = []
for i in range(n_nodes - 1):
......@@ -172,21 +207,30 @@ def build_graph_validate_gt_obj(img, gt_ids, bbox, spatial_feat,
n_nodes = g_pred.number_of_nodes()
n_edges = g_pred.number_of_edges()
if bbox_improvement:
g_pred.ndata['pred_bbox'] = bbox_improve(g_pred.ndata['pred_bbox'])
g_pred.edata['rel_bbox'] = extract_edge_bbox(g_pred)
g_pred.edata['batch_id'] = nd.zeros((n_edges, 1), ctx = ctx) + btc
g_pred.ndata["pred_bbox"] = bbox_improve(g_pred.ndata["pred_bbox"])
g_pred.edata["rel_bbox"] = extract_edge_bbox(g_pred)
g_pred.edata["batch_id"] = nd.zeros((n_edges, 1), ctx=ctx) + btc
g_batch.append(g_pred)
if len(g_batch) == 0:
        return None
if len(g_batch) > 1:
return dgl.batch(g_batch)
return g_batch[0]
def build_graph_validate_gt_bbox(img, ids, scores, bbox, spatial_feat, gt_ids=None,
bbox_improvement=True, overlap=False):
'''given ground truth bbox, build graph for validation'''
def build_graph_validate_gt_bbox(
img,
ids,
scores,
bbox,
spatial_feat,
gt_ids=None,
bbox_improvement=True,
overlap=False,
):
"""given ground truth bbox, build graph for validation"""
n_batch = img.shape[0]
img_size = img.shape[2:4]
bbox[:, :, 0] /= img_size[1]
......@@ -197,17 +241,22 @@ def build_graph_validate_gt_bbox(img, ids, scores, bbox, spatial_feat, gt_ids=No
g_batch = []
for btc in range(n_batch):
id_btc = scores[btc][:,:,0].argmax(0)
score_btc = scores[btc][:,:,0].max(0)
id_btc = scores[btc][:, :, 0].argmax(0)
score_btc = scores[btc][:, :, 0].max(0)
inds = np.where(bbox[btc].sum(1).asnumpy() > 0)[0].tolist()
if len(inds) == 0:
continue
n_nodes = len(inds)
g_pred = dgl.DGLGraph()
g_pred.add_nodes(n_nodes, {'pred_bbox': bbox[btc, inds],
'node_feat': spatial_feat[btc, inds],
'node_class_pred': id_btc,
'node_class_logit': nd.log(score_btc + 1e-7)})
g_pred.add_nodes(
n_nodes,
{
"pred_bbox": bbox[btc, inds],
"node_feat": spatial_feat[btc, inds],
"node_class_pred": id_btc,
"node_class_logit": nd.log(score_btc + 1e-7),
},
)
edge_list = []
for i in range(n_nodes - 1):
......@@ -220,21 +269,31 @@ def build_graph_validate_gt_bbox(img, ids, scores, bbox, spatial_feat, gt_ids=No
n_nodes = g_pred.number_of_nodes()
n_edges = g_pred.number_of_edges()
if bbox_improvement:
g_pred.ndata['pred_bbox'] = bbox_improve(g_pred.ndata['pred_bbox'])
g_pred.edata['rel_bbox'] = extract_edge_bbox(g_pred)
g_pred.edata['batch_id'] = nd.zeros((n_edges, 1), ctx = ctx) + btc
g_pred.ndata["pred_bbox"] = bbox_improve(g_pred.ndata["pred_bbox"])
g_pred.edata["rel_bbox"] = extract_edge_bbox(g_pred)
g_pred.edata["batch_id"] = nd.zeros((n_edges, 1), ctx=ctx) + btc
g_batch.append(g_pred)
if len(g_batch) == 0:
        return None
if len(g_batch) > 1:
return dgl.batch(g_batch)
return g_batch[0]
def build_graph_validate_pred(img, ids, scores, bbox, feat_ind, spatial_feat,
bbox_improvement=True, scores_top_k=50, overlap=False):
'''given predicted bbox, build graph for validation'''
def build_graph_validate_pred(
img,
ids,
scores,
bbox,
feat_ind,
spatial_feat,
bbox_improvement=True,
scores_top_k=50,
overlap=False,
):
"""given predicted bbox, build graph for validation"""
n_batch = img.shape[0]
img_size = img.shape[2:4]
bbox[:, :, 0] /= img_size[1]
......@@ -249,16 +308,23 @@ def build_graph_validate_pred(img, ids, scores, bbox, feat_ind, spatial_feat,
if len(inds) == 0:
continue
if len(inds) > scores_top_k:
top_score_inds = scores[btc, inds, 0].asnumpy().argsort()[::-1][0:scores_top_k]
top_score_inds = (
scores[btc, inds, 0].asnumpy().argsort()[::-1][0:scores_top_k]
)
inds = np.array(inds)[top_score_inds].tolist()
n_nodes = len(inds)
roi_ind = feat_ind[btc, inds].squeeze(axis=1)
g_pred = dgl.DGLGraph()
g_pred.add_nodes(n_nodes, {'pred_bbox': bbox[btc, inds],
'node_feat': spatial_feat[btc, roi_ind],
'node_class_pred': ids[btc, inds, 0],
'node_class_logit': nd.log(scores[btc, inds, 0] + 1e-7)})
g_pred.add_nodes(
n_nodes,
{
"pred_bbox": bbox[btc, inds],
"node_feat": spatial_feat[btc, roi_ind],
"node_class_pred": ids[btc, inds, 0],
"node_class_logit": nd.log(scores[btc, inds, 0] + 1e-7),
},
)
edge_list = []
for i in range(n_nodes - 1):
......@@ -271,14 +337,14 @@ def build_graph_validate_pred(img, ids, scores, bbox, feat_ind, spatial_feat,
n_nodes = g_pred.number_of_nodes()
n_edges = g_pred.number_of_edges()
if bbox_improvement:
g_pred.ndata['pred_bbox'] = bbox_improve(g_pred.ndata['pred_bbox'])
g_pred.edata['rel_bbox'] = extract_edge_bbox(g_pred)
g_pred.edata['batch_id'] = nd.zeros((n_edges, 1), ctx = ctx) + btc
g_pred.ndata["pred_bbox"] = bbox_improve(g_pred.ndata["pred_bbox"])
g_pred.edata["rel_bbox"] = extract_edge_bbox(g_pred)
g_pred.edata["batch_id"] = nd.zeros((n_edges, 1), ctx=ctx) + btc
g_batch.append(g_pred)
if len(g_batch) == 0:
        return None
if len(g_batch) > 1:
return dgl.batch(g_batch)
return g_batch[0]
import dgl
import logging
import time
from operator import attrgetter, itemgetter
import mxnet as mx
import numpy as np
import logging, time
from operator import attrgetter, itemgetter
from mxnet import nd, gluon
from gluoncv.data.batchify import Pad
from gluoncv.model_zoo import get_model
from mxnet import gluon, nd
from mxnet.gluon import nn
from dgl.utils import toindex
import dgl
from dgl.nn.mxnet import GraphConv
from gluoncv.model_zoo import get_model
from gluoncv.data.batchify import Pad
from dgl.utils import toindex
def iou(boxA, boxB):
# determine the (x, y)-coordinates of the intersection rectangle
......@@ -16,9 +20,9 @@ def iou(boxA, boxB):
yA = max(boxA[1], boxB[1])
xB = min(boxA[2], boxB[2])
yB = min(boxA[3], boxB[3])
interArea = max(0, xB - xA) * max(0, yB - yA)
if interArea < 1e-7 :
if interArea < 1e-7:
return 0
boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
......@@ -29,12 +33,14 @@ def iou(boxA, boxB):
iou_val = interArea / float(boxAArea + boxBArea - interArea)
return iou_val
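Boxes here are (x1, y1, x2, y2); two 2x2 boxes offset by one unit in each direction overlap in a 1x1 patch, which makes a handy check. A hypothetical call:

boxA = [0.0, 0.0, 2.0, 2.0]  # x1, y1, x2, y2
boxB = [1.0, 1.0, 3.0, 3.0]  # shifted by one unit in both directions
print(iou(boxA, boxB))  # 1 / (4 + 4 - 1) ~= 0.1429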
def object_iou_thresh(gt_object, pred_object, iou_thresh=0.5):
obj_iou = iou(gt_object[1:5], pred_object[1:5])
if obj_iou >= iou_thresh:
return True
return False
def triplet_iou_thresh(pred_triplet, gt_triplet, iou_thresh=0.5):
sub_iou = iou(gt_triplet[5:9], pred_triplet[5:9])
if sub_iou >= iou_thresh:
......@@ -43,10 +49,11 @@ def triplet_iou_thresh(pred_triplet, gt_triplet, iou_thresh=0.5):
return True
return False
@mx.metric.register
@mx.metric.alias('auc')
@mx.metric.alias("auc")
class AUCMetric(mx.metric.EvalMetric):
def __init__(self, name='auc', eps=1e-12):
def __init__(self, name="auc", eps=1e-12):
super(AUCMetric, self).__init__(name)
self.eps = eps
......@@ -78,12 +85,14 @@ class AUCMetric(mx.metric.EvalMetric):
self.sum_metric += area / total_area
self.num_inst += 1
@mx.metric.register
@mx.metric.alias('predcls')
@mx.metric.alias("predcls")
class PredCls(mx.metric.EvalMetric):
'''Metric with ground truth object location and label'''
"""Metric with ground truth object location and label"""
def __init__(self, topk=20, iou_thresh=0.99):
super(PredCls, self).__init__('predcls@%d'%(topk))
super(PredCls, self).__init__("predcls@%d" % (topk))
self.topk = topk
self.iou_thresh = iou_thresh
......@@ -91,7 +100,7 @@ class PredCls(mx.metric.EvalMetric):
if labels is None or preds is None:
self.num_inst += 1
return
preds = preds[preds[:,0].argsort()[::-1]]
preds = preds[preds[:, 0].argsort()[::-1]]
m = min(self.topk, preds.shape[0])
count = 0
gt_edge_num = labels.shape[0]
......@@ -102,8 +111,9 @@ class PredCls(mx.metric.EvalMetric):
if label_matched[j]:
continue
label = labels[j]
if int(label[2]) == int(pred[2]) and \
triplet_iou_thresh(pred, label, self.iou_thresh):
if int(label[2]) == int(pred[2]) and triplet_iou_thresh(
pred, label, self.iou_thresh
):
count += 1
label_matched[j] = True
......@@ -111,12 +121,14 @@ class PredCls(mx.metric.EvalMetric):
self.sum_metric += count / total
self.num_inst += 1
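Each update above computes one image's recall@K: the top-k triplets (sorted by score column 0) are greedily matched against still-unmatched ground truths on predicate class and box IoU, and the matched fraction is averaged over images by EvalMetric. The counting logic restated compactly (hypothetical arrays; pred_ok[i, j] says prediction i is compatible with ground truth j):

import numpy as np

def recall_at_k(pred_scores, pred_ok, k):
    order = np.argsort(-pred_scores)[:k]
    matched = np.zeros(pred_ok.shape[1], dtype=bool)
    for i in order:
        for j in np.where(pred_ok[i] & ~matched)[0][:1]:
            matched[j] = True  # greedily claim the first free ground truth
    return matched.mean()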
@mx.metric.register
@mx.metric.alias('phrcls')
@mx.metric.alias("phrcls")
class PhrCls(mx.metric.EvalMetric):
'''Metric with ground truth object location and predicted object label from detector'''
"""Metric with ground truth object location and predicted object label from detector"""
def __init__(self, topk=20, iou_thresh=0.99):
super(PhrCls, self).__init__('phrcls@%d'%(topk))
super(PhrCls, self).__init__("phrcls@%d" % (topk))
self.topk = topk
self.iou_thresh = iou_thresh
......@@ -124,7 +136,7 @@ class PhrCls(mx.metric.EvalMetric):
if labels is None or preds is None:
self.num_inst += 1
return
preds = preds[preds[:,1].argsort()[::-1]]
preds = preds[preds[:, 1].argsort()[::-1]]
m = min(self.topk, preds.shape[0])
count = 0
gt_edge_num = labels.shape[0]
......@@ -135,22 +147,26 @@ class PhrCls(mx.metric.EvalMetric):
if label_matched[j]:
continue
label = labels[j]
if int(label[2]) == int(pred[2]) and \
int(label[3]) == int(pred[3]) and \
int(label[4]) == int(pred[4]) and \
triplet_iou_thresh(pred, label, self.iou_thresh):
if (
int(label[2]) == int(pred[2])
and int(label[3]) == int(pred[3])
and int(label[4]) == int(pred[4])
and triplet_iou_thresh(pred, label, self.iou_thresh)
):
count += 1
label_matched[j] = True
total = labels.shape[0]
self.sum_metric += count / total
self.num_inst += 1
@mx.metric.register
@mx.metric.alias('sgdet')
@mx.metric.alias("sgdet")
class SGDet(mx.metric.EvalMetric):
'''Metric with predicted object information by the detector'''
"""Metric with predicted object information by the detector"""
def __init__(self, topk=20, iou_thresh=0.5):
super(SGDet, self).__init__('sgdet@%d'%(topk))
super(SGDet, self).__init__("sgdet@%d" % (topk))
self.topk = topk
self.iou_thresh = iou_thresh
......@@ -158,7 +174,7 @@ class SGDet(mx.metric.EvalMetric):
if labels is None or preds is None:
self.num_inst += 1
return
preds = preds[preds[:,1].argsort()[::-1]]
preds = preds[preds[:, 1].argsort()[::-1]]
m = min(self.topk, len(preds))
count = 0
gt_edge_num = labels.shape[0]
......@@ -169,22 +185,26 @@ class SGDet(mx.metric.EvalMetric):
if label_matched[j]:
continue
label = labels[j]
if int(label[2]) == int(pred[2]) and \
int(label[3]) == int(pred[3]) and \
int(label[4]) == int(pred[4]) and \
triplet_iou_thresh(pred, label, self.iou_thresh):
if (
int(label[2]) == int(pred[2])
and int(label[3]) == int(pred[3])
and int(label[4]) == int(pred[4])
and triplet_iou_thresh(pred, label, self.iou_thresh)
):
count += 1
label_matched[j] =True
label_matched[j] = True
total = labels.shape[0]
self.sum_metric += count / total
self.num_inst += 1
@mx.metric.register
@mx.metric.alias('sgdet+')
@mx.metric.alias("sgdet+")
class SGDetPlus(mx.metric.EvalMetric):
'''Metric proposed by `Graph R-CNN for Scene Graph Generation`'''
"""Metric proposed by `Graph R-CNN for Scene Graph Generation`"""
def __init__(self, topk=20, iou_thresh=0.5):
super(SGDetPlus, self).__init__('sgdet+@%d'%(topk))
super(SGDetPlus, self).__init__("sgdet+@%d" % (topk))
self.topk = topk
self.iou_thresh = iou_thresh
......@@ -205,13 +225,14 @@ class SGDetPlus(mx.metric.EvalMetric):
if object_matched[j]:
continue
label = label_objects[j]
if int(label[0]) == int(pred[0]) and \
object_iou_thresh(pred, label, self.iou_thresh):
if int(label[0]) == int(pred[0]) and object_iou_thresh(
pred, label, self.iou_thresh
):
count += 1
object_matched[j] = True
# count predicate and triplet
pred_triplets = pred_triplets[pred_triplets[:,1].argsort()[::-1]]
pred_triplets = pred_triplets[pred_triplets[:, 1].argsort()[::-1]]
m = min(self.topk, len(pred_triplets))
gt_triplet_num = label_triplets.shape[0]
triplet_matched = [False for label in label_triplets]
......@@ -221,15 +242,18 @@ class SGDetPlus(mx.metric.EvalMetric):
for j in range(gt_triplet_num):
label = label_triplets[j]
                if not predicate_matched[j]:
if int(label[2]) == int(pred[2]) and \
triplet_iou_thresh(pred, label, self.iou_thresh):
if int(label[2]) == int(pred[2]) and triplet_iou_thresh(
pred, label, self.iou_thresh
):
count += label[3]
predicate_matched[j] = True
if not triplet_matched[j]:
if int(label[2]) == int(pred[2]) and \
int(label[3]) == int(pred[3]) and \
int(label[4]) == int(pred[4]) and \
triplet_iou_thresh(pred, label, self.iou_thresh):
if (
int(label[2]) == int(pred[2])
and int(label[3]) == int(pred[3])
and int(label[4]) == int(pred[4])
and triplet_iou_thresh(pred, label, self.iou_thresh)
):
count += 1
triplet_matched[j] = True
# compute sum
......@@ -238,27 +262,28 @@ class SGDetPlus(mx.metric.EvalMetric):
self.sum_metric += count / N
self.num_inst += 1
def extract_gt(g, img_size):
'''extract prediction from ground truth graph'''
"""extract prediction from ground truth graph"""
if g is None or g.number_of_nodes() == 0:
return None, None
gt_eids = np.where(g.edata['rel_class'].asnumpy() > 0)[0]
gt_eids = np.where(g.edata["rel_class"].asnumpy() > 0)[0]
if len(gt_eids) == 0:
return None, None
gt_class = g.ndata['node_class'][:,0].asnumpy()
gt_bbox = g.ndata['bbox'].asnumpy()
gt_bbox[:, 0] /= img_size[1]
gt_bbox[:, 1] /= img_size[0]
gt_bbox[:, 2] /= img_size[1]
gt_bbox[:, 3] /= img_size[0]
gt_class = g.ndata["node_class"][:, 0].asnumpy()
gt_bbox = g.ndata["bbox"].asnumpy()
gt_bbox[:, 0] /= img_size[1]
gt_bbox[:, 1] /= img_size[0]
gt_bbox[:, 2] /= img_size[1]
gt_bbox[:, 3] /= img_size[0]
gt_objects = np.vstack([gt_class, gt_bbox.transpose(1, 0)]).transpose(1, 0)
gt_node_ids = g.find_edges(gt_eids)
gt_node_sub = gt_node_ids[0].asnumpy()
gt_node_ob = gt_node_ids[1].asnumpy()
gt_rel_class = g.edata['rel_class'][gt_eids,0].asnumpy() - 1
gt_rel_class = g.edata["rel_class"][gt_eids, 0].asnumpy() - 1
gt_sub_class = gt_class[gt_node_sub]
gt_ob_class = gt_class[gt_node_ob]
......@@ -266,32 +291,42 @@ def extract_gt(g, img_size):
gt_ob_bbox = gt_bbox[gt_node_ob]
n = len(gt_eids)
gt_triplets = np.vstack([np.ones(n), np.ones(n),
gt_rel_class, gt_sub_class, gt_ob_class,
gt_sub_bbox.transpose(1, 0),
gt_ob_bbox.transpose(1, 0)]).transpose(1, 0)
gt_triplets = np.vstack(
[
np.ones(n),
np.ones(n),
gt_rel_class,
gt_sub_class,
gt_ob_class,
gt_sub_bbox.transpose(1, 0),
gt_ob_bbox.transpose(1, 0),
]
).transpose(1, 0)
return gt_objects, gt_triplets
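The 13-column triplet layout built here is what every metric above indexes into: columns 0-1 are the predicate and phrase scores, 2-4 the predicate/subject/object classes, 5-8 the subject box, and 9-12 the object box (normalized x1y1x2y2). Named slices for reference (illustrative constants, not in the source):

SCORE_PRED, SCORE_PHR, REL_CLS, SUB_CLS, OB_CLS = 0, 1, 2, 3, 4
SUB_BBOX = slice(5, 9)   # what triplet_iou_thresh reads via pred[5:9]
OB_BBOX = slice(9, 13)   # and via pred[9:13]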
def extract_pred(g, topk=100, joint_preds=False):
'''extract prediction from prediction graph for validation and visualization'''
"""extract prediction from prediction graph for validation and visualization"""
if g is None or g.number_of_nodes() == 0:
return None, None
pred_class = g.ndata['node_class_pred'].asnumpy()
pred_class_prob = g.ndata['node_class_logit'].asnumpy()
pred_bbox = g.ndata['pred_bbox'][:,0:4].asnumpy()
pred_class = g.ndata["node_class_pred"].asnumpy()
pred_class_prob = g.ndata["node_class_logit"].asnumpy()
pred_bbox = g.ndata["pred_bbox"][:, 0:4].asnumpy()
pred_objects = np.vstack([pred_class, pred_bbox.transpose(1, 0)]).transpose(1, 0)
pred_objects = np.vstack([pred_class, pred_bbox.transpose(1, 0)]).transpose(
1, 0
)
score_pred = g.edata['score_pred'].asnumpy()
score_phr = g.edata['score_phr'].asnumpy()
score_pred = g.edata["score_pred"].asnumpy()
score_phr = g.edata["score_phr"].asnumpy()
score_pred_topk_eids = (-score_pred).argsort()[0:topk].tolist()
score_phr_topk_eids = (-score_phr).argsort()[0:topk].tolist()
topk_eids = sorted(list(set(score_pred_topk_eids + score_phr_topk_eids)))
pred_rel_prob = g.edata['preds'][topk_eids].asnumpy()
pred_rel_prob = g.edata["preds"][topk_eids].asnumpy()
if joint_preds:
pred_rel_class = pred_rel_prob[:,1:].argmax(axis=1)
pred_rel_class = pred_rel_prob[:, 1:].argmax(axis=1)
else:
pred_rel_class = pred_rel_prob.argmax(axis=1)
......@@ -307,8 +342,15 @@ def extract_pred(g, topk=100, joint_preds=False):
pred_ob_class_prob = pred_class_prob[pred_node_ob]
pred_ob_bbox = pred_bbox[pred_node_ob]
pred_triplets = np.vstack([score_pred[topk_eids], score_phr[topk_eids],
pred_rel_class, pred_sub_class, pred_ob_class,
pred_sub_bbox.transpose(1, 0),
pred_ob_bbox.transpose(1, 0)]).transpose(1, 0)
pred_triplets = np.vstack(
[
score_pred[topk_eids],
score_phr[topk_eids],
pred_rel_class,
pred_sub_class,
pred_ob_class,
pred_sub_bbox.transpose(1, 0),
pred_ob_bbox.transpose(1, 0),
]
).transpose(1, 0)
return pred_objects, pred_triplets
import dgl
from dgl.utils import toindex
import mxnet as mx
import numpy as np
import dgl
from dgl.utils import toindex
def l0_sample(g, positive_max=128, negative_ratio=3):
'''sampling positive and negative edges'''
"""sampling positive and negative edges"""
if g is None:
return None
n_eids = g.number_of_edges()
pos_eids = np.where(g.edata['rel_class'].asnumpy() > 0)[0]
neg_eids = np.where(g.edata['rel_class'].asnumpy() == 0)[0]
pos_eids = np.where(g.edata["rel_class"].asnumpy() > 0)[0]
neg_eids = np.where(g.edata["rel_class"].asnumpy() == 0)[0]
if len(pos_eids) == 0:
return None
......@@ -26,6 +28,7 @@ def l0_sample(g, positive_max=128, negative_ratio=3):
eids = np.where(weights > 0)[0]
sub_g = g.edge_subgraph(toindex(eids.tolist()))
sub_g.copy_from_parent()
sub_g.edata['sample_weights'] = mx.nd.array(weights[eids],
ctx=g.edata['rel_class'].context)
sub_g.edata["sample_weights"] = mx.nd.array(
weights[eids], ctx=g.edata["rel_class"].context
)
return sub_g
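With the weight computation elided above, the intent is the usual hard cap: keep at most positive_max positive edges plus roughly negative_ratio times as many negatives, and weight the survivors. A sketch of that sampling under those assumptions (function name hypothetical):

import numpy as np

def sample_edge_ids(rel_class, positive_max=128, negative_ratio=3, rng=np.random):
    pos = np.where(rel_class > 0)[0]
    neg = np.where(rel_class == 0)[0]
    if len(pos) > positive_max:
        pos = rng.choice(pos, positive_max, replace=False)
    n_neg = min(len(neg), negative_ratio * len(pos))
    neg = rng.choice(neg, n_neg, replace=False)
    return np.concatenate([pos, neg])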
import numpy as np
import gluoncv as gcv
import numpy as np
from matplotlib import pyplot as plt
def plot_sg(img, preds, obj_classes, rel_classes, topk=1):
'''visualization of generated scene graph'''
"""visualization of generated scene graph"""
size = img.shape[0:2]
box_scale = np.array([size[1], size[0], size[1], size[0]])
topk = min(topk, preds.shape[0])
......@@ -17,30 +18,51 @@ def plot_sg(img, preds, obj_classes, rel_classes, topk=1):
rel_name = rel_classes[rel]
src_bbox = preds[i, 5:9] * box_scale
dst_bbox = preds[i, 9:13] * box_scale
src_center = np.array([(src_bbox[0] + src_bbox[2]) / 2, (src_bbox[1] + src_bbox[3]) / 2])
dst_center = np.array([(dst_bbox[0] + dst_bbox[2]) / 2, (dst_bbox[1] + dst_bbox[3]) / 2])
src_center = np.array(
[(src_bbox[0] + src_bbox[2]) / 2, (src_bbox[1] + src_bbox[3]) / 2]
)
dst_center = np.array(
[(dst_bbox[0] + dst_bbox[2]) / 2, (dst_bbox[1] + dst_bbox[3]) / 2]
)
rel_center = (src_center + dst_center) / 2
line_x = np.array([(src_bbox[0] + src_bbox[2]) / 2, (dst_bbox[0] + dst_bbox[2]) / 2])
line_y = np.array([(src_bbox[1] + src_bbox[3]) / 2, (dst_bbox[1] + dst_bbox[3]) / 2])
ax.plot(line_x, line_y,
linewidth=3.0, alpha=0.7, color=plt.cm.cool(rel))
ax.text(src_center[0], src_center[1],
'{:s}'.format(src_name),
bbox=dict(alpha=0.5),
fontsize=12, color='white')
ax.text(dst_center[0], dst_center[1],
'{:s}'.format(dst_name),
bbox=dict(alpha=0.5),
fontsize=12, color='white')
ax.text(rel_center[0], rel_center[1],
'{:s}'.format(rel_name),
bbox=dict(alpha=0.5),
fontsize=12, color='white')
line_x = np.array(
[(src_bbox[0] + src_bbox[2]) / 2, (dst_bbox[0] + dst_bbox[2]) / 2]
)
line_y = np.array(
[(src_bbox[1] + src_bbox[3]) / 2, (dst_bbox[1] + dst_bbox[3]) / 2]
)
ax.plot(
line_x, line_y, linewidth=3.0, alpha=0.7, color=plt.cm.cool(rel)
)
ax.text(
src_center[0],
src_center[1],
"{:s}".format(src_name),
bbox=dict(alpha=0.5),
fontsize=12,
color="white",
)
ax.text(
dst_center[0],
dst_center[1],
"{:s}".format(dst_name),
bbox=dict(alpha=0.5),
fontsize=12,
color="white",
)
ax.text(
rel_center[0],
rel_center[1],
"{:s}".format(rel_name),
bbox=dict(alpha=0.5),
fontsize=12,
color="white",
)
return ax
# example call (hypothetical inputs): draw the top-2 relations of one image
# plot_sg(img, preds, obj_classes, rel_classes, topk=2)
import dgl
import argparse
import logging
import time
import mxnet as mx
import numpy as np
import logging, time, argparse
from mxnet import nd, gluon
from data import *
from gluoncv.data.batchify import Pad
from model import faster_rcnn_resnet101_v1d_custom, RelDN
from model import RelDN, faster_rcnn_resnet101_v1d_custom
from mxnet import gluon, nd
from utils import *
from data import *
import dgl
def parse_args():
parser = argparse.ArgumentParser(description='Validate Pre-trained RelDN Model.')
    parser.add_argument('--gpus', type=str, default='0',
                        help="GPU ids to run validation on; specify e.g. 1,3.")
parser.add_argument('--batch-size', type=int, default=8,
help="Total batch-size for training.")
parser.add_argument('--metric', type=str, default='sgdet',
help="Evaluation metric, could be 'predcls', 'phrcls', 'sgdet' or 'sgdet+'.")
parser.add_argument('--pretrained-faster-rcnn-params', type=str, required=True,
help="Path to saved Faster R-CNN model parameters.")
parser.add_argument('--reldn-params', type=str, required=True,
help="Path to saved Faster R-CNN model parameters.")
parser.add_argument('--faster-rcnn-params', type=str, required=True,
help="Path to saved Faster R-CNN model parameters.")
parser.add_argument('--log-dir', type=str, default='reldn_output.log',
help="Path to save training logs.")
parser.add_argument('--freq-prior', type=str, default='freq_prior.pkl',
help="Path to saved frequency prior data.")
parser.add_argument('--verbose-freq', type=int, default=100,
help="Frequency of log printing in number of iterations.")
parser = argparse.ArgumentParser(
description="Validate Pre-trained RelDN Model."
)
parser.add_argument(
"--gpus",
type=str,
default="0",
help="Training with GPUs, you can specify 1,3 for example.",
)
parser.add_argument(
"--batch-size",
type=int,
default=8,
help="Total batch-size for training.",
)
parser.add_argument(
"--metric",
type=str,
default="sgdet",
help="Evaluation metric, could be 'predcls', 'phrcls', 'sgdet' or 'sgdet+'.",
)
parser.add_argument(
"--pretrained-faster-rcnn-params",
type=str,
required=True,
help="Path to saved Faster R-CNN model parameters.",
)
parser.add_argument(
"--reldn-params",
type=str,
required=True,
help="Path to saved Faster R-CNN model parameters.",
)
parser.add_argument(
"--faster-rcnn-params",
type=str,
required=True,
help="Path to saved Faster R-CNN model parameters.",
)
parser.add_argument(
"--log-dir",
type=str,
default="reldn_output.log",
help="Path to save training logs.",
)
parser.add_argument(
"--freq-prior",
type=str,
default="freq_prior.pkl",
help="Path to saved frequency prior data.",
)
parser.add_argument(
"--verbose-freq",
type=int,
default=100,
help="Frequency of log printing in number of iterations.",
)
args = parser.parse_args()
return args
args = parse_args()
filehandler = logging.FileHandler(args.log_dir)
streamhandler = logging.StreamHandler()
logger = logging.getLogger('')
logger = logging.getLogger("")
logger.setLevel(logging.INFO)
logger.addHandler(filehandler)
logger.addHandler(streamhandler)
# Hyperparams
ctx = [mx.gpu(int(i)) for i in args.gpus.split(',') if i.strip()]
ctx = [mx.gpu(int(i)) for i in args.gpus.split(",") if i.strip()]
if ctx:
num_gpus = len(ctx)
assert args.batch_size % num_gpus == 0
......@@ -58,46 +101,65 @@ batch_verbose_freq = args.verbose_freq
mode = args.metric
metric_list = []
topk_list = [20, 50, 100]
if mode == 'predcls':
if mode == "predcls":
for topk in topk_list:
metric_list.append(PredCls(topk=topk))
if mode == 'phrcls':
if mode == "phrcls":
for topk in topk_list:
metric_list.append(PhrCls(topk=topk))
if mode == 'sgdet':
if mode == "sgdet":
for topk in topk_list:
metric_list.append(SGDet(topk=topk))
if mode == 'sgdet+':
if mode == "sgdet+":
for topk in topk_list:
metric_list.append(SGDetPlus(topk=topk))
for metric in metric_list:
metric.reset()
semantic_only = False
net = RelDN(n_classes=N_relations, prior_pkl=args.freq_prior,
semantic_only=semantic_only)
net = RelDN(
n_classes=N_relations,
prior_pkl=args.freq_prior,
semantic_only=semantic_only,
)
net.load_parameters(args.reldn_params, ctx=ctx)
# dataset and dataloader
vg_val = VGRelation(split='val')
logger.info('data loaded!')
val_data = gluon.data.DataLoader(vg_val, batch_size=len(ctx), shuffle=False, num_workers=16*num_gpus,
batchify_fn=dgl_mp_batchify_fn)
vg_val = VGRelation(split="val")
logger.info("data loaded!")
val_data = gluon.data.DataLoader(
vg_val,
batch_size=len(ctx),
shuffle=False,
num_workers=16 * num_gpus,
batchify_fn=dgl_mp_batchify_fn,
)
n_batches = len(val_data)
detector = faster_rcnn_resnet101_v1d_custom(classes=vg_val.obj_classes,
pretrained_base=False, pretrained=False,
additional_output=True)
detector = faster_rcnn_resnet101_v1d_custom(
classes=vg_val.obj_classes,
pretrained_base=False,
pretrained=False,
additional_output=True,
)
params_path = args.pretrained_faster_rcnn_params
detector.load_parameters(params_path, ctx=ctx, ignore_extra=True, allow_missing=True)
detector.load_parameters(
params_path, ctx=ctx, ignore_extra=True, allow_missing=True
)
detector_feat = faster_rcnn_resnet101_v1d_custom(classes=vg_val.obj_classes,
pretrained_base=False, pretrained=False,
additional_output=True)
detector_feat.load_parameters(params_path, ctx=ctx, ignore_extra=True, allow_missing=True)
detector_feat = faster_rcnn_resnet101_v1d_custom(
classes=vg_val.obj_classes,
pretrained_base=False,
pretrained=False,
additional_output=True,
)
detector_feat.load_parameters(
params_path, ctx=ctx, ignore_extra=True, allow_missing=True
)
detector_feat.features.load_parameters(args.faster_rcnn_params, ctx=ctx)
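# Two copies of the same Faster R-CNN are kept: `detector` predicts object
# classes, scores and boxes, while `detector_feat` (with its own feature
# weights) re-extracts ROI features for the relation (union) boxes. The
# helper below splits a batch of graphs and images across the GPUs in ctx.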
def get_data_batch(g_list, img_list, ctx_list):
if g_list is None or len(g_list) == 0:
return None, None
......@@ -106,27 +168,39 @@ def get_data_batch(g_list, img_list, ctx_list):
if size < n_gpu:
raise Exception("too small batch")
step = size // n_gpu
G_list = [g_list[i*step:(i+1)*step] if i < n_gpu - 1 else g_list[i*step:size] for i in range(n_gpu)]
img_list = [img_list[i*step:(i+1)*step] if i < n_gpu - 1 else img_list[i*step:size] for i in range(n_gpu)]
G_list = [
g_list[i * step : (i + 1) * step]
if i < n_gpu - 1
else g_list[i * step : size]
for i in range(n_gpu)
]
img_list = [
img_list[i * step : (i + 1) * step]
if i < n_gpu - 1
else img_list[i * step : size]
for i in range(n_gpu)
]
for G_slice, ctx in zip(G_list, ctx_list):
for G in G_slice:
G.ndata['bbox'] = G.ndata['bbox'].as_in_context(ctx)
G.ndata['node_class'] = G.ndata['node_class'].as_in_context(ctx)
G.ndata['node_class_vec'] = G.ndata['node_class_vec'].as_in_context(ctx)
G.edata['rel_class'] = G.edata['rel_class'].as_in_context(ctx)
G.ndata["bbox"] = G.ndata["bbox"].as_in_context(ctx)
G.ndata["node_class"] = G.ndata["node_class"].as_in_context(ctx)
G.ndata["node_class_vec"] = G.ndata["node_class_vec"].as_in_context(
ctx
)
G.edata["rel_class"] = G.edata["rel_class"].as_in_context(ctx)
img_list = [img.as_in_context(ctx) for img in img_list]
return G_list, img_list
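# Main evaluation loop: build a prediction graph per image according to
# the chosen metric, attach relation features when needed, then update
# every metric with the ground-truth and predicted triplets.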
for i, (G_list, img_list) in enumerate(val_data):
G_list, img_list = get_data_batch(G_list, img_list, ctx)
if G_list is None or img_list is None:
if (i+1) % batch_verbose_freq == 0:
print_txt = 'Batch[%d/%d] '%\
(i, n_batches)
if (i + 1) % batch_verbose_freq == 0:
print_txt = "Batch[%d/%d] " % (i, n_batches)
for metric in metric_list:
metric_name, metric_val = metric.get()
print_txt += '%s=%.4f '%(metric_name, metric_val)
print_txt += "%s=%.4f " % (metric_name, metric_val)
logger.info(print_txt)
continue
......@@ -136,31 +210,57 @@ for i, (G_list, img_list) in enumerate(val_data):
# loss_cls_val = 0
for G_slice, img in zip(G_list, img_list):
cur_ctx = img.context
if mode == 'predcls':
bbox_list = [G.ndata['bbox'] for G in G_slice]
if mode == "predcls":
bbox_list = [G.ndata["bbox"] for G in G_slice]
bbox_stack = bbox_pad(bbox_list).as_in_context(cur_ctx)
ids, scores, bbox, spatial_feat = detector(img, None, None, bbox_stack)
ids, scores, bbox, spatial_feat = detector(
img, None, None, bbox_stack
)
node_class_list = [G.ndata['node_class'] for G in G_slice]
node_class_list = [G.ndata["node_class"] for G in G_slice]
node_class_stack = bbox_pad(node_class_list).as_in_context(cur_ctx)
g_pred_batch = build_graph_validate_gt_obj(img, node_class_stack, bbox, spatial_feat,
bbox_improvement=True, overlap=False)
elif mode == 'phrcls':
g_pred_batch = build_graph_validate_gt_obj(
img,
node_class_stack,
bbox,
spatial_feat,
bbox_improvement=True,
overlap=False,
)
elif mode == "phrcls":
# use ground truth bbox
bbox_list = [G.ndata['bbox'] for G in G_slice]
bbox_list = [G.ndata["bbox"] for G in G_slice]
bbox_stack = bbox_pad(bbox_list).as_in_context(cur_ctx)
ids, scores, bbox, spatial_feat = detector(img, None, None, bbox_stack)
ids, scores, bbox, spatial_feat = detector(
img, None, None, bbox_stack
)
g_pred_batch = build_graph_validate_gt_bbox(img, ids, scores, bbox, spatial_feat,
bbox_improvement=True, overlap=False)
g_pred_batch = build_graph_validate_gt_bbox(
img,
ids,
scores,
bbox,
spatial_feat,
bbox_improvement=True,
overlap=False,
)
else:
# use predicted bbox
ids, scores, bbox, feat, feat_ind, spatial_feat = detector(img)
g_pred_batch = build_graph_validate_pred(img, ids, scores, bbox, feat_ind, spatial_feat,
bbox_improvement=True, scores_top_k=75, overlap=False)
g_pred_batch = build_graph_validate_pred(
img,
ids,
scores,
bbox,
feat_ind,
spatial_feat,
bbox_improvement=True,
scores_top_k=75,
overlap=False,
)
if not semantic_only:
rel_bbox = g_pred_batch.edata['rel_bbox']
batch_id = g_pred_batch.edata['batch_id'].asnumpy()
rel_bbox = g_pred_batch.edata["rel_bbox"]
batch_id = g_pred_batch.edata["batch_id"].asnumpy()
n_sample_edges = g_pred_batch.number_of_edges()
# g_pred_batch.edata['edge_feat'] = mx.nd.zeros((n_sample_edges, 49), ctx=cur_ctx)
n_graph = len(G_slice)
......@@ -170,13 +270,19 @@ for i, (G_list, img_list) in enumerate(val_data):
if len(eids) > 0:
bbox_rel_list.append(rel_bbox[eids])
bbox_rel_stack = bbox_pad(bbox_rel_list).as_in_context(cur_ctx)
_, _, _, spatial_feat_rel = detector_feat(img, None, None, bbox_rel_stack)
_, _, _, spatial_feat_rel = detector_feat(
img, None, None, bbox_rel_stack
)
spatial_feat_rel_list = []
for j in range(n_graph):
eids = np.where(batch_id == j)[0]
if len(eids) > 0:
spatial_feat_rel_list.append(spatial_feat_rel[j, 0:len(eids)])
g_pred_batch.edata['edge_feat'] = nd.concat(*spatial_feat_rel_list, dim=0)
spatial_feat_rel_list.append(
spatial_feat_rel[j, 0 : len(eids)]
)
g_pred_batch.edata["edge_feat"] = nd.concat(
*spatial_feat_rel_list, dim=0
)
G_batch.append(g_pred_batch)
......@@ -189,23 +295,25 @@ for i, (G_list, img_list) in enumerate(val_data):
gt_objects, gt_triplet = extract_gt(G_gt, img_slice.shape[2:4])
pred_objects, pred_triplet = extract_pred(G_pred, joint_preds=True)
for metric in metric_list:
if isinstance(metric, PredCls) or \
isinstance(metric, PhrCls) or \
isinstance(metric, SGDet):
if (
isinstance(metric, PredCls)
or isinstance(metric, PhrCls)
or isinstance(metric, SGDet)
):
metric.update(gt_triplet, pred_triplet)
else:
metric.update((gt_objects, gt_triplet), (pred_objects, pred_triplet))
if (i+1) % batch_verbose_freq == 0:
print_txt = 'Batch[%d/%d] '%\
(i, n_batches)
metric.update(
(gt_objects, gt_triplet), (pred_objects, pred_triplet)
)
if (i + 1) % batch_verbose_freq == 0:
print_txt = "Batch[%d/%d] " % (i, n_batches)
for metric in metric_list:
metric_name, metric_val = metric.get()
print_txt += '%s=%.4f '%(metric_name, metric_val)
print_txt += "%s=%.4f " % (metric_name, metric_val)
logger.info(print_txt)
print_txt = 'Batch[%d/%d] '%\
(n_batches, n_batches)
print_txt = "Batch[%d/%d] " % (n_batches, n_batches)
for metric in metric_list:
metric_name, metric_val = metric.get()
print_txt += '%s=%.4f '%(metric_name, metric_val)
print_txt += "%s=%.4f " % (metric_name, metric_val)
logger.info(print_txt)
......@@ -5,14 +5,18 @@ Paper: https://arxiv.org/abs/1902.07153
Code: https://github.com/Tiiiger/SGC
SGC implementation in DGL.
"""
import argparse, time, math
import numpy as np
import argparse
import math
import time
import mxnet as mx
from mxnet import nd, gluon
import numpy as np
from mxnet import gluon, nd
from mxnet.gluon import nn
import dgl
from dgl.data import register_data_args
from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset
from dgl.data import (CiteseerGraphDataset, CoraGraphDataset,
PubmedGraphDataset, register_data_args)
from dgl.nn.mxnet.conv import SGConv
......@@ -21,16 +25,17 @@ def evaluate(model, g, features, labels, mask):
accuracy = ((pred == labels) * mask).sum() / mask.sum().asscalar()
return accuracy.asscalar()
def main(args):
# load and preprocess dataset
if args.dataset == 'cora':
if args.dataset == "cora":
data = CoraGraphDataset()
elif args.dataset == 'citeseer':
elif args.dataset == "citeseer":
data = CiteseerGraphDataset()
elif args.dataset == 'pubmed':
elif args.dataset == "pubmed":
data = PubmedGraphDataset()
else:
raise ValueError('Unknown dataset: {}'.format(args.dataset))
raise ValueError("Unknown dataset: {}".format(args.dataset))
g = data[0]
if args.gpu < 0:
......@@ -41,35 +46,36 @@ def main(args):
ctx = mx.gpu(args.gpu)
g = g.int().to(ctx)
features = g.ndata['feat']
labels = mx.nd.array(g.ndata['label'], dtype="float32", ctx=ctx)
train_mask = g.ndata['train_mask']
val_mask = g.ndata['val_mask']
test_mask = g.ndata['test_mask']
features = g.ndata["feat"]
labels = mx.nd.array(g.ndata["label"], dtype="float32", ctx=ctx)
train_mask = g.ndata["train_mask"]
val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"]
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
print("""----Data statistics------'
print(
"""----Data statistics------'
#Edges %d
#Classes %d
#Train samples %d
#Val samples %d
#Test samples %d""" %
(n_edges, n_classes,
train_mask.sum().asscalar(),
val_mask.sum().asscalar(),
test_mask.sum().asscalar()))
#Test samples %d"""
% (
n_edges,
n_classes,
train_mask.sum().asscalar(),
val_mask.sum().asscalar(),
test_mask.sum().asscalar(),
)
)
# add self loop
g = dgl.remove_self_loop(g)
g = dgl.add_self_loop(g)
# create SGC model
model = SGConv(in_feats,
n_classes,
k=2,
cached=True,
bias=args.bias)
model = SGConv(in_feats, n_classes, k=2, cached=True, bias=args.bias)
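# With cached=True, SGConv computes the propagated features S^k X once on
# the first forward pass and reuses them afterwards, so training reduces
# to fitting a linear classifier on fixed features.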
model.initialize(ctx=ctx)
n_train_samples = train_mask.sum().asscalar()
......@@ -77,8 +83,11 @@ def main(args):
# use optimizer
print(model.collect_params())
trainer = gluon.Trainer(model.collect_params(), 'adam',
{'learning_rate': args.lr, 'wd': args.weight_decay})
trainer = gluon.Trainer(
model.collect_params(),
"adam",
{"learning_rate": args.lr, "wd": args.weight_decay},
)
# initialize graph
dur = []
......@@ -98,28 +107,36 @@ def main(args):
loss.asscalar()
dur.append(time.time() - t0)
acc = evaluate(model, g, features, labels, val_mask)
print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
"ETputs(KTEPS) {:.2f}". format(
epoch, np.mean(dur), loss.asscalar(), acc, n_edges / np.mean(dur) / 1000))
print(
"Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
"ETputs(KTEPS) {:.2f}".format(
epoch,
np.mean(dur),
loss.asscalar(),
acc,
n_edges / np.mean(dur) / 1000,
)
)
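# ETputs reports throughput in thousands of traversed edges per second,
# i.e. n_edges / mean epoch time / 1000.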
# test set accuracy
acc = evaluate(model, g, features, labels, test_mask)
print("Test accuracy {:.2%}".format(acc))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='SGC')
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="SGC")
register_data_args(parser)
parser.add_argument("--gpu", type=int, default=-1,
help="gpu")
parser.add_argument("--lr", type=float, default=0.2,
help="learning rate")
parser.add_argument("--bias", action='store_true', default=False,
help="flag to use bias")
parser.add_argument("--n-epochs", type=int, default=100,
help="number of training epochs")
parser.add_argument("--weight-decay", type=float, default=5e-6,
help="Weight for L2 loss")
parser.add_argument("--gpu", type=int, default=-1, help="gpu")
parser.add_argument("--lr", type=float, default=0.2, help="learning rate")
parser.add_argument(
"--bias", action="store_true", default=False, help="flag to use bias"
)
parser.add_argument(
"--n-epochs", type=int, default=100, help="number of training epochs"
)
parser.add_argument(
"--weight-decay", type=float, default=5e-6, help="Weight for L2 loss"
)
args = parser.parse_args()
print(args)
......
......@@ -6,18 +6,15 @@ References:
"""
import mxnet as mx
from mxnet import gluon
import dgl
from dgl.nn.mxnet import TAGConv
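# TAGConv mixes the k-hop propagated features for k = 0..K with learnable
# per-hop weights, so a single layer already aggregates a K-hop
# neighborhood of each node.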
class TAGCN(gluon.Block):
def __init__(self,
g,
in_feats,
n_hidden,
n_classes,
n_layers,
activation,
dropout):
def __init__(
self, g, in_feats, n_hidden, n_classes, n_layers, activation, dropout
):
super(TAGCN, self).__init__()
self.g = g
self.layers = gluon.nn.Sequential()
......@@ -27,7 +24,7 @@ class TAGCN(gluon.Block):
for i in range(n_layers - 1):
self.layers.add(TAGConv(n_hidden, n_hidden, activation=activation))
# output layer
self.layers.add(TAGConv(n_hidden, n_classes)) #activation=None
self.layers.add(TAGConv(n_hidden, n_classes)) # activation=None
self.dropout = gluon.nn.Dropout(rate=dropout)
def forward(self, features):
......