# reldn.py
1
2
import pickle

3
4
5
6
7
8
import gluoncv as gcv
import mxnet as mx
import numpy as np
from mxnet import nd
from mxnet.gluon import nn

9
import dgl
10
from dgl.nn.mxnet import GraphConv
11
12
13
from dgl.utils import toindex

__all__ = ["RelDN"]
14
15
16


class EdgeConfMLP(nn.Block):
    """Compute confidence scores for every edge.

    Two outputs are produced per edge:

    * ``score_pred`` -- the maximum of the log-softmax of the edge's
      ``preds`` logits, taken over every column except column 0.
    * ``score_phr`` -- ``score_pred`` plus the ``node_class_logit`` of the
      edge's source node and of its destination node.
    """

    def __init__(self):
        super(EdgeConfMLP, self).__init__()

    def forward(self, edges):
        log_probs = nd.log_softmax(edges.data["preds"])
        # Column 0 is excluded from the max; presumably a background /
        # "no relation" class -- NOTE(review): confirm the label layout.
        best_predicate = log_probs[:, 1:].max(axis=1)
        subject_logit = edges.src["node_class_logit"]
        object_logit = edges.dst["node_class_logit"]
        phrase_score = best_predicate + subject_logit + object_logit
        return {"score_pred": best_predicate, "score_phr": phrase_score}

31
32

class EdgeBBoxExtend(nn.Block):
    """Encode pairwise box geometry for every edge.

    Three 4-column deltas are written side by side into a 12-column
    ``pred_bbox_additional`` feature, in this order:
    source->destination box, source->relation box, relation->destination box.
    """

    def __init__(self):
        super(EdgeBBoxExtend, self).__init__()

    def bbox_delta(self, bbox_a, bbox_b):
        """Return an ``(n, 4)`` delta between two row-aligned box arrays.

        Columns: difference of column 0, difference of column 1, log ratio
        of the (col2 - col0) extents, log ratio of the (col3 - col1) extents.
        ``1e-8`` is added to every extent before the ratio is taken.
        Boxes are presumably ``(x1, y1, x2, y2)`` -- NOTE(review): confirm
        with the data pipeline.
        """
        eps = 1e-8
        delta = nd.zeros((bbox_a.shape[0], 4), ctx=bbox_a.context)
        delta[:, 0] = bbox_a[:, 0] - bbox_b[:, 0]
        delta[:, 1] = bbox_a[:, 1] - bbox_b[:, 1]
        width_a = bbox_a[:, 2] - bbox_a[:, 0] + eps
        width_b = bbox_b[:, 2] - bbox_b[:, 0] + eps
        delta[:, 2] = nd.log(width_a / width_b)
        height_a = bbox_a[:, 3] - bbox_a[:, 1] + eps
        height_b = bbox_b[:, 3] - bbox_b[:, 1] + eps
        delta[:, 3] = nd.log(height_a / height_b)
        return delta

    def forward(self, edges):
        sub_box = edges.src["pred_bbox"]
        obj_box = edges.dst["pred_bbox"]
        rel_box = edges.data["rel_bbox"]
        # (first, second) pairs in the column order of the output feature
        pairs = (
            (sub_box, obj_box),
            (sub_box, rel_box),
            (rel_box, obj_box),
        )
        extended = nd.zeros((sub_box.shape[0], 12), ctx=sub_box.context)
        for slot, (first, second) in enumerate(pairs):
            start = 4 * slot
            extended[:, start : start + 4] = self.bbox_delta(first, second)
        return {"pred_bbox_additional": extended}

71
72

class EdgeFreqPrior(nn.Block):
    """Look up a pre-computed frequency prior for every edge.

    The prior table is unpickled from ``prior_pkl`` once, at construction,
    and indexed with the ``(source class, destination class)`` pair of
    each edge.
    """

    def __init__(self, prior_pkl):
        super(EdgeFreqPrior, self).__init__()
        # NOTE(review): pickle.load can execute arbitrary code while
        # loading; only pass prior files from a trusted source.
        with open(prior_pkl, "rb") as handle:
            self.freq_prior = pickle.load(handle)

    def forward(self, edges):
        src_pred = edges.src["node_class_pred"]
        dst_pred = edges.dst["node_class_pred"]
        # Class predictions become integer row/column indices into the table.
        rows = src_pred.asnumpy().astype(int)
        cols = dst_pred.asnumpy().astype(int)
        prior = nd.array(self.freq_prior[rows, cols], ctx=src_pred.context)
        return {"freq_prior": prior}

89
90

class EdgeSpatial(nn.Block):
    """Spatial branch: an MLP over the concatenated box features of an edge.

    Input is the source box, destination box, relation box and the
    ``pred_bbox_additional`` deltas, concatenated; output has ``n_classes``
    columns and is stored under ``spatial``.
    """

    def __init__(self, n_classes):
        super(EdgeSpatial, self).__init__()
        self.mlp = nn.Sequential()
        # Two 64-unit hidden layers with LeakyReLU(0.1), then a linear head.
        for _ in range(2):
            self.mlp.add(nn.Dense(64))
            self.mlp.add(nn.LeakyReLU(0.1))
        self.mlp.add(nn.Dense(n_classes))

    def forward(self, edges):
        box_inputs = [
            edges.src["pred_bbox"],
            edges.dst["pred_bbox"],
            edges.data["rel_bbox"],
            edges.data["pred_bbox_additional"],
        ]
        return {"spatial": self.mlp(nd.concat(*box_inputs))}

112
113

class EdgeVisual(nn.Block):
    """Visual branch: a joint MLP plus subject-only and object-only heads.

    The joint MLP sees the concatenation of the source node, destination
    node and edge features. Two single Dense layers see the source and the
    destination node features alone. The three outputs are summed into
    ``visual``.
    """

    def __init__(self, n_classes, vis_feat_dim=7 * 7 * 3):
        super(EdgeVisual, self).__init__()
        self.dim_in = vis_feat_dim  # stored only; not read by this class
        self.mlp_joint = nn.Sequential()
        # hidden widths shrink to 1/2 then 1/3 of vis_feat_dim
        for hidden in (vis_feat_dim // 2, vis_feat_dim // 3):
            self.mlp_joint.add(nn.Dense(hidden))
            self.mlp_joint.add(nn.LeakyReLU(0.1))
        self.mlp_joint.add(nn.Dense(n_classes))

        # single-layer heads over the subject / object features alone
        self.mlp_sub = nn.Dense(n_classes)
        self.mlp_ob = nn.Dense(n_classes)

    def forward(self, edges):
        subject_feat = edges.src["node_feat"]
        object_feat = edges.dst["node_feat"]
        joint_input = nd.concat(subject_feat, object_feat, edges.data["edge_feat"])
        joint_logits = self.mlp_joint(joint_input)
        subject_logits = self.mlp_sub(subject_feat)
        object_logits = self.mlp_ob(object_feat)
        return {"visual": joint_logits + subject_logits + object_logits}

141
142

class RelDN(nn.Block):
    """The RelDN model.

    Predicate logits for each edge are the sum of up to three terms: a
    frequency prior, a spatial branch and a visual branch. Confidence
    scores are then derived from those logits by ``EdgeConfMLP``.

    Parameters
    ----------
    n_classes : int
        Number of predicate classes; the spatial and visual branches output
        ``n_classes + 1`` logits.
    prior_pkl : str or None
        Path to a pickled frequency prior. When ``None`` the prior term is
        left out of the predicate logits.
    semantic_only : bool
        If True, the frequency prior alone is used as the predicate logits;
        this requires ``prior_pkl``.

    Raises
    ------
    ValueError
        If ``semantic_only`` is True but ``prior_pkl`` is None.
    """

    def __init__(self, n_classes, prior_pkl, semantic_only=False):
        super(RelDN, self).__init__()
        if semantic_only and prior_pkl is None:
            # The semantic-only path has nothing to predict with but the prior;
            # fail at construction instead of with AttributeError in forward().
            raise ValueError("semantic_only=True requires prior_pkl")
        # geometric deltas between subject, object and relation boxes
        self.edge_bbox_extend = EdgeBBoxExtend()
        # Optional frequency prior. Always define the attribute so forward()
        # can test for it; previously it was missing when prior_pkl was None.
        if prior_pkl is not None:
            self.freq_prior = EdgeFreqPrior(prior_pkl)
        else:
            self.freq_prior = None

        # n_classes + 1 outputs: the predicate classes plus a link class
        self.spatial = EdgeSpatial(n_classes + 1)
        # visual-feature branch with the same output width
        self.visual = EdgeVisual(n_classes + 1)
        self.edge_conf_mlp = EdgeConfMLP()
        self.semantic_only = semantic_only

    def forward(self, g):
        """Write predicate logits and confidence scores onto ``g``'s edges.

        Returns ``g`` untouched when it is ``None`` or has no nodes.
        Otherwise sets ``g.edata["preds"]``, then the ``score_pred`` and
        ``score_phr`` outputs of ``EdgeConfMLP``, and returns ``g``.
        """
        if g is None or g.number_of_nodes() == 0:
            return g
        has_prior = self.freq_prior is not None
        if has_prior:
            g.apply_edges(self.freq_prior)
        if self.semantic_only:
            # __init__ guarantees a prior exists in this mode
            g.edata["preds"] = g.edata["freq_prior"]
        else:
            # box deltas are consumed by the spatial branch
            g.apply_edges(self.edge_bbox_extend)
            g.apply_edges(self.spatial)
            g.apply_edges(self.visual)
            if has_prior:
                # same summation order as before, so results are unchanged
                g.edata["preds"] = (
                    g.edata["freq_prior"] + g.edata["spatial"] + g.edata["visual"]
                )
            else:
                g.edata["preds"] = g.edata["spatial"] + g.edata["visual"]
        # per-edge confidence scores derived from the combined logits
        g.apply_edges(self.edge_conf_mlp)
        return g