"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "f4a44b7707d9f481f4188155443279a85ba9e67d"
Unverified commit efd909e6, authored by Riju Mukherjee, committed by GitHub

[NN] Enhance EGATConv branch (#4062)



* enhance EGATConv: nfeats as tuples

* egatconv modified for bipartite graphs

* modified docstrings

* added/modified unittests for EGATConv

* Update egatconv.py

* rectified lint errors
Co-authored-by: rijulizer <riju.mukherjee@gmail.com>
Co-authored-by: Mufei Li <mufeili1996@gmail.com>
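In short, in_node_feats may now be a pair of ints and nfeats a matching pair of
tensors, so the layer also runs on unidirectional bipartite graphs. A minimal
sketch of the new call pattern (sizes are illustrative; the full example lives
in the updated docstring below):

import dgl
import torch as th
from dgl.nn import EGATConv

# bipartite graph with 2 source ('A') nodes, 3 destination ('B') nodes, 3 edges
g = dgl.heterograph({('A', 'r', 'B'): ([0, 1, 0], [0, 1, 2])})
egat = EGATConv(in_node_feats=(25, 30), in_edge_feats=15,
                out_node_feats=10, out_edge_feats=5, num_heads=3)
# node features are passed as a (source, destination) pair
h, f = egat(g, (th.rand(2, 25), th.rand(3, 30)), th.rand(3, 15))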
parent 92063d88
@@ -7,6 +7,7 @@ from torch.nn import init
from .... import function as fn
from ...functional import edge_softmax
from ....base import DGLError
from ....utils import expand_as_pair
# pylint: enable=W0235
class EGATConv(nn.Module):
@@ -27,8 +28,14 @@ class EGATConv(nn.Module):
Parameters
----------
in_node_feats : int
Input node feature size :math:`h_{i}`.
in_node_feats : int, or pair of ints
Input feature size; i.e., the number of dimensions of :math:`h_{i}`.
EGATConv can be applied on a homogeneous graph and a unidirectional
`bipartite graph <https://docs.dgl.ai/generated/dgl.bipartite.html?highlight=bipartite>`__.
If the layer is to be applied to a unidirectional bipartite graph, ``in_node_feats``
specifies the input feature size on both the source and destination nodes. If
a scalar is given, the source and destination node feature sizes take the
same value.
in_edge_feats : int
Input edge feature size :math:`f_{ij}`.
out_node_feats : int
@@ -46,10 +53,10 @@ class EGATConv(nn.Module):
>>> import torch as th
>>> from dgl.nn import EGATConv
>>> # Case 1: Homogeneous graph
>>> num_nodes, num_edges = 8, 30
>>> # generate a graph
>>> graph = dgl.rand_graph(num_nodes, num_edges)
>>> node_feats = th.rand((num_nodes, 20))
>>> edge_feats = th.rand((num_edges, 12))
>>> egat = EGATConv(in_node_feats=20,
@@ -61,8 +68,33 @@ class EGATConv(nn.Module):
>>> new_node_feats, new_edge_feats = egat(graph, node_feats, edge_feats)
>>> new_node_feats.shape, new_edge_feats.shape
(torch.Size([8, 3, 15]), torch.Size([30, 3, 10]))
"""
>>> # Case 2: Unidirectional bipartite graph
>>> u = [0, 1, 0, 0, 1]
>>> v = [0, 1, 2, 3, 2]
>>> g = dgl.heterograph({('A', 'r', 'B'): (u, v)})
>>> u_feat = th.rand((2, 25))
>>> v_feat = th.rand((4, 30))
>>> nfeats = (u_feat, v_feat)
>>> efeats = th.rand((5, 15))
>>> in_node_feats = (25, 30)
>>> in_edge_feats = 15
>>> out_node_feats = 10
>>> out_edge_feats = 5
>>> num_heads = 3
>>> egat_model = EGATConv(in_node_feats,
... in_edge_feats,
... out_node_feats,
... out_edge_feats,
... num_heads,
... bias=True)
>>> # forward pass
>>> new_node_feats, new_edge_feats, attentions = egat_model(g, nfeats, efeats, get_attention=True)
>>> new_node_feats.shape, new_edge_feats.shape, attentions.shape
(torch.Size([4, 3, 10]), torch.Size([5, 3, 5]), torch.Size([5, 3, 1]))
"""
def __init__(self,
in_node_feats,
in_edge_feats,
@@ -73,12 +105,25 @@ class EGATConv(nn.Module):
super().__init__()
self._num_heads = num_heads
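# expand_as_pair turns a scalar feature size n into the pair (n, n) and
# returns a (src, dst) tuple unchanged, so the homogeneous and bipartite
# cases can share the same source/destination projections below.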
self._in_src_node_feats, self._in_dst_node_feats = expand_as_pair(in_node_feats)
self._out_node_feats = out_node_feats
self._out_edge_feats = out_edge_feats
self.fc_node = nn.Linear(in_node_feats, out_node_feats*num_heads, bias=True)
self.fc_ni = nn.Linear(in_node_feats, out_edge_feats*num_heads, bias=False)
if isinstance(in_node_feats, tuple):
self.fc_node_src = nn.Linear(
self._in_src_node_feats, out_node_feats * num_heads, bias=False)
self.fc_ni = nn.Linear(
self._in_src_node_feats, out_edge_feats*num_heads, bias=False)
self.fc_nj = nn.Linear(
self._in_dst_node_feats, out_edge_feats*num_heads, bias=False)
else:
self.fc_node_src = nn.Linear(
self._in_src_node_feats, out_node_feats * num_heads, bias=False)
self.fc_ni = nn.Linear(
self._in_src_node_feats, out_edge_feats*num_heads, bias=False)
self.fc_nj = nn.Linear(
self._in_src_node_feats, out_edge_feats*num_heads, bias=False)
self.fc_fij = nn.Linear(in_edge_feats, out_edge_feats*num_heads, bias=False)
self.fc_nj = nn.Linear(in_node_feats, out_edge_feats*num_heads, bias=False)
self.attn = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_edge_feats)))
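# attn has shape (1, num_heads, out_edge_feats): it broadcasts against the
# per-edge features and, summed over the last axis in forward(), yields one
# attention logit per edge and head.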
if bias:
self.bias = nn.Parameter(th.FloatTensor(size=(num_heads * out_edge_feats,)))
@@ -91,7 +136,7 @@ class EGATConv(nn.Module):
Reinitialize learnable parameters.
"""
gain = init.calculate_gain('relu')
init.xavier_normal_(self.fc_node.weight, gain=gain)
init.xavier_normal_(self.fc_node_src.weight, gain=gain)
init.xavier_normal_(self.fc_ni.weight, gain=gain)
init.xavier_normal_(self.fc_fij.weight, gain=gain)
init.xavier_normal_(self.fc_nj.weight, gain=gain)
@@ -106,11 +151,14 @@ class EGATConv(nn.Module):
----------
graph : DGLGraph
The graph.
nfeats : torch.Tensor
The input node feature of shape :math:`(N, D_{in})`
nfeats : torch.Tensor or pair of torch.Tensor
If a torch.Tensor is given, the input node feature of shape :math:`(N, D_{in})`
where:
:math:`D_{in}` is the size of the input node feature,
:math:`N` is the number of nodes.
If a pair of torch.Tensor is given, the pair must contain two tensors of shape
:math:`(N_{src}, D_{in_{src}})` and
:math:`(N_{dst}, D_{in_{dst}})`.
efeats: torch.Tensor
The input edge feature of shape :math:`(E, F_{in})`
where:
@@ -144,13 +192,18 @@ class EGATConv(nn.Module):
'calling `g = dgl.add_self_loop(g)` will resolve '
'the issue.')
# TODO allow node src and dst feats
# calc edge attention
# same trick way as in dgl.nn.pytorch.GATConv, but also includes edge feats
# https://github.com/dmlc/dgl/blob/master/python/dgl/nn/pytorch/conv/gatconv.py
f_ni = self.fc_ni(nfeats)
f_nj = self.fc_nj(nfeats)
if isinstance(nfeats, tuple):
nfeats_src, nfeats_dst = nfeats
else:
nfeats_src = nfeats_dst = nfeats
f_ni = self.fc_ni(nfeats_src)
f_nj = self.fc_nj(nfeats_dst)
f_fij = self.fc_fij(efeats)
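# f_ni (source), f_nj (destination) and f_fij (edge) are all projected into a
# (num_heads * out_edge_feats)-dimensional space and summed edge-wise below
# to form the raw edge representation f_out.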
graph.srcdata.update({'f_ni': f_ni})
graph.dstdata.update({'f_nj': f_nj})
# add ni, nj factors
@@ -164,13 +217,13 @@ class EGATConv(nn.Module):
# compute attention factor
e = (f_out * self.attn).sum(dim=-1).unsqueeze(-1)
graph.edata['a'] = edge_softmax(graph, e)
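# edge_softmax normalizes the logits over the incoming edges of each
# destination node, giving attention weights of shape (E, num_heads, 1).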
graph.ndata['h_out'] = self.fc_node(nfeats).view(-1, self._num_heads,
graph.srcdata['h_out'] = self.fc_node_src(nfeats_src).view(-1, self._num_heads,
self._out_node_feats)
# calc weighted sum
graph.update_all(fn.u_mul_e('h_out', 'a', 'm'),
fn.sum('m', 'h_out'))
h_out = graph.ndata['h_out'].view(-1, self._num_heads, self._out_node_feats)
h_out = graph.dstdata['h_out'].view(-1, self._num_heads, self._out_node_feats)
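# update_all aggregated the attention-weighted messages into dstdata['h_out'],
# so the result is read from the destination side.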
if get_attention:
return h_out, f_out, graph.edata.pop('a')
else:
@@ -506,13 +506,41 @@ def test_egat_conv(g, idtype, out_node_feats, out_edge_feats, num_heads):
num_heads=num_heads)
nfeat = F.randn((g.number_of_nodes(), 10))
efeat = F.randn((g.number_of_edges(), 5))
egat = egat.to(ctx)
h, f = egat(g, nfeat, efeat)
h, f, attn = egat(g, nfeat, efeat, True)
th.save(egat, tmp_buffer)
assert h.shape == (g.number_of_nodes(), num_heads, out_node_feats)
assert f.shape == (g.number_of_edges(), num_heads, out_edge_feats)
_, _, attn = egat(g, nfeat, efeat, True)
assert attn.shape == (g.number_of_edges(), num_heads, 1)
@parametrize_idtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_node_feats', [1, 5])
@pytest.mark.parametrize('out_edge_feats', [1, 5])
@pytest.mark.parametrize('num_heads', [1, 4])
def test_egat_conv_bi(g, idtype, out_node_feats, out_edge_feats, num_heads):
g = g.astype(idtype).to(F.ctx())
ctx = F.ctx()
egat = nn.EGATConv(in_node_feats=(10, 15),
in_edge_feats=7,
out_node_feats=out_node_feats,
out_edge_feats=out_edge_feats,
num_heads=num_heads)
nfeat = (F.randn((g.number_of_src_nodes(), 10)), F.randn((g.number_of_dst_nodes(), 15)))
efeat = F.randn((g.number_of_edges(), 7))
egat = egat.to(ctx)
h, f = egat(g, nfeat, efeat)
th.save(egat, tmp_buffer)
assert h.shape == (g.number_of_dst_nodes(), num_heads, out_node_feats)
assert f.shape == (g.number_of_edges(), num_heads, out_edge_feats)
_, _, attn = egat(g, nfeat, efeat, True)
assert attn.shape == (g.number_of_edges(), num_heads, 1)
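# Messages flow from source to destination nodes only, so the updated node
# states are defined on the destination side: h has number_of_dst_nodes() rows.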
@parametrize_idtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn', 'lstm'])