"projects/Panoptic-DeepLab/vscode:/vscode.git/clone" did not exist on "b634945d8ce3fbcbbcf2fc89e62cf7de03b17673"
Unverified commit 2570d412, authored by Jinjing Zhou, committed by GitHub

[NN] Add fast path for GateGCNConv when it has only one edge type (#2994)

* fix gatedgcn

* fix lint
parent 411bef54
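
As a rough usage sketch (not part of the patch itself; the toy graph and feature sizes below are made up), the change lets callers with a single edge type drop the `etypes` argument and take the fast path, while passing an all-zero `etypes` tensor still works and gives the same result:

```python
# Hypothetical usage sketch, assuming a DGL build that includes this change.
import dgl
import torch as th
from dgl.nn.pytorch import GatedGraphConv

g = dgl.graph(([0, 1, 2], [1, 2, 0]))      # toy homogeneous graph, one edge type
feat = th.randn(g.num_nodes(), 5)
conv = GatedGraphConv(in_feats=5, out_feats=10, n_steps=5, n_etypes=1)

h_fast = conv(g, feat)                      # etypes omitted -> fast path
h_full = conv(g, feat, th.zeros(g.num_edges(), dtype=th.long))  # explicit etypes
assert th.allclose(h_fast, h_full)
```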
@@ -61,6 +61,7 @@ class GatedGraphConv(nn.Module):
     [ 0.6393,  0.3447,  0.3893,  0.4279,  0.3342,  0.3809,  0.0406,  0.5030,
       0.1342,  0.0425]], grad_fn=<AddBackward0>)
     """
+
     def __init__(self,
                  in_feats,
                  out_feats,
@@ -110,7 +111,7 @@ class GatedGraphConv(nn.Module):
         """
         self._allow_zero_in_degree = set_value

-    def forward(self, graph, feat, etypes):
+    def forward(self, graph, feat, etypes=None):
         """

         Description
@@ -125,9 +126,10 @@ class GatedGraphConv(nn.Module):
             The input feature of shape :math:`(N, D_{in})` where :math:`N`
             is the number of nodes of the graph and :math:`D_{in}` is the
             input feature size.
-        etypes : torch.LongTensor
+        etypes : torch.LongTensor, or None
             The edge type tensor of shape :math:`(E,)` where :math:`E` is
-            the number of edges of the graph.
+            the number of edges of the graph. When there's only one edge type,
+            this argument can be skipped

         Returns
         -------
@@ -139,18 +141,30 @@ class GatedGraphConv(nn.Module):
         assert graph.is_homogeneous, \
             "not a homogeneous graph; convert it with to_homogeneous " \
             "and pass in the edge type as argument"
-        assert etypes.min() >= 0 and etypes.max() < self._n_etypes, \
-            "edge type indices out of range [0, {})".format(self._n_etypes)
-        zero_pad = feat.new_zeros((feat.shape[0], self._out_feats - feat.shape[1]))
+        if self._n_etypes != 1:
+            assert etypes.min() >= 0 and etypes.max() < self._n_etypes, \
+                "edge type indices out of range [0, {})".format(
+                    self._n_etypes)
+        zero_pad = feat.new_zeros(
+            (feat.shape[0], self._out_feats - feat.shape[1]))
         feat = th.cat([feat, zero_pad], -1)

         for _ in range(self._n_steps):
-            graph.ndata['h'] = feat
-            for i in range(self._n_etypes):
-                eids = th.nonzero(etypes == i, as_tuple=False).view(-1).type(graph.idtype)
-                if len(eids) > 0:
-                    graph.apply_edges(
-                        lambda edges: {'W_e*h': self.linears[i](edges.src['h'])},
-                        eids
-                    )
-            graph.update_all(fn.copy_e('W_e*h', 'm'), fn.sum('m', 'a'))
+            if self._n_etypes == 1 and etypes is None:
+                # Fast path when graph has only one edge type
+                graph.ndata['h'] = self.linears[0](feat)
+                graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'a'))
+                a = graph.ndata.pop('a')  # (N, D)
+            else:
+                graph.ndata['h'] = feat
+                for i in range(self._n_etypes):
+                    eids = th.nonzero(
+                        etypes == i, as_tuple=False).view(-1).type(graph.idtype)
+                    if len(eids) > 0:
+                        graph.apply_edges(
+                            lambda edges: {
+                                'W_e*h': self.linears[i](edges.src['h'])},
+                            eids
+                        )
+                graph.update_all(fn.copy_e('W_e*h', 'm'), fn.sum('m', 'a'))
...
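
The fast path works because with a single edge type there is only one weight matrix, so it can be applied once per node and the messages gathered with `copy_u`, instead of materializing `W_e*h` on every edge via `apply_edges` and `copy_e`. A rough standalone sketch of that equivalence (toy graph and layer, not part of the patch):

```python
import dgl
import dgl.function as fn
import torch as th

g = dgl.rand_graph(20, 60)              # toy homogeneous graph
lin = th.nn.Linear(10, 10)
feat = th.randn(g.num_nodes(), 10)

# General path: transform on every edge, then sum incoming messages.
g.ndata['h'] = feat
g.apply_edges(lambda edges: {'W_e*h': lin(edges.src['h'])})
g.update_all(fn.copy_e('W_e*h', 'm'), fn.sum('m', 'a'))
a_general = g.ndata.pop('a')

# Fast path: transform once per node, then copy along edges and sum.
g.ndata['h'] = lin(feat)
g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'a'))
a_fast = g.ndata.pop('a')

assert th.allclose(a_general, a_fast, atol=1e-6)
```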
@@ -725,6 +725,23 @@ def test_gated_graph_conv(g, idtype):
     # current we only do shape check
     assert h.shape[-1] == 10

+
+@parametrize_dtype
+@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
+def test_gated_graph_conv_one_etype(g, idtype):
+    ctx = F.ctx()
+    g = g.astype(idtype).to(ctx)
+    ggconv = nn.GatedGraphConv(5, 10, 5, 1)
+    etypes = th.zeros(g.number_of_edges())
+    feat = F.randn((g.number_of_nodes(), 5))
+    ggconv = ggconv.to(ctx)
+    etypes = etypes.to(ctx)
+    h = ggconv(g, feat, etypes)
+    h2 = ggconv(g, feat)
+    # current we only do shape check
+    assert F.allclose(h, h2)
+    assert h.shape[-1] == 10
+
 @parametrize_dtype
 @pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
 def test_nn_conv(g, idtype):
@@ -1113,6 +1130,7 @@ if __name__ == '__main__':
     test_gin_conv()
     test_agnn_conv()
     test_gated_graph_conv()
+    test_gated_graph_conv_one_etype()
     test_nn_conv()
     test_gmm_conv()
     test_dotgat_conv()
...