Unverified Commit 99831073 authored by Quan (Andy) Gan, committed by GitHub

[Feature] Support multidimensional features for GAT (#2912)

* support multidimensional features for GAT

* docstring

* lint

* fix
parent caa6d607
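
Before this change, GATConv.forward accepted only node features of shape :math:`(N, D_{in})`. The diffs below generalize all three backends (MXNet, PyTorch, TensorFlow) to accept any number of extra dimensions between the node axis and the feature axis, i.e. :math:`(N, *, D_{in})`. A minimal sketch of the resulting call, using the PyTorch backend; the graph and all sizes here are illustrative:

import dgl
import torch
from dgl.nn.pytorch import GATConv

# A 3-node cycle: every node has in-degree 1, so the zero-in-degree
# check in forward() does not trigger.
g = dgl.graph(([0, 1, 2], [1, 2, 0]))
conv = GATConv(in_feats=5, out_feats=4, num_heads=3)

# Node features may now carry extra dimensions between N and D_in.
feat = torch.randn(3, 7, 5)              # (N=3, *=7, D_in=5)
out = conv(g, feat)                      # (N, *, H, D_out)
print(out.shape)                         # torch.Size([3, 7, 3, 4])

# With get_attention=True, the per-edge attention is (E, *, H, 1).
out, attn = conv(g, feat, get_attention=True)
print(attn.shape)                        # torch.Size([3, 7, 3, 1])
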
@@ -213,20 +213,20 @@ class GATConv(nn.Block):
         graph : DGLGraph
             The graph.
         feat : mxnet.NDArray or pair of mxnet.NDArray
-            If a mxnet.NDArray is given, the input feature of shape :math:`(N, D_{in})` where
+            If a mxnet.NDArray is given, the input feature of shape :math:`(N, *, D_{in})` where
             :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
             If a pair of mxnet.NDArray is given, the pair must contain two tensors of shape
-            :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
+            :math:`(N_{in}, *, D_{in_{src}})` and :math:`(N_{out}, *, D_{in_{dst}})`.
         get_attention : bool, optional
             Whether to return the attention values. Default to False.
 
         Returns
         -------
         mxnet.NDArray
-            The output feature of shape :math:`(N, H, D_{out})` where :math:`H`
+            The output feature of shape :math:`(N, *, H, D_{out})` where :math:`H`
             is the number of heads, and :math:`D_{out}` is size of output feature.
         mxnet.NDArray, optional
-            The attention values of shape :math:`(E, H, 1)`, where :math:`E` is the number of
+            The attention values of shape :math:`(E, *, H, 1)`, where :math:`E` is the number of
             edges. This is returned only when :attr:`get_attention` is ``True``.
 
         Raises
@@ -250,18 +250,23 @@ class GATConv(nn.Block):
                                'suppress the check and let the code run.')
             if isinstance(feat, tuple):
+                src_prefix_shape = feat[0].shape[:-1]
+                dst_prefix_shape = feat[1].shape[:-1]
+                feat_dim = feat[0].shape[-1]
                 h_src = self.feat_drop(feat[0])
                 h_dst = self.feat_drop(feat[1])
                 if not hasattr(self, 'fc_src'):
                     self.fc_src, self.fc_dst = self.fc, self.fc
-                feat_src = self.fc_src(h_src).reshape(
-                    -1, self._num_heads, self._out_feats)
-                feat_dst = self.fc_dst(h_dst).reshape(
-                    -1, self._num_heads, self._out_feats)
+                feat_src = self.fc_src(h_src.reshape(-1, feat_dim)).reshape(
+                    *src_prefix_shape, self._num_heads, self._out_feats)
+                feat_dst = self.fc_dst(h_dst.reshape(-1, feat_dim)).reshape(
+                    *dst_prefix_shape, self._num_heads, self._out_feats)
             else:
+                src_prefix_shape = dst_prefix_shape = feat.shape[:-1]
+                feat_dim = feat.shape[-1]
                 h_src = h_dst = self.feat_drop(feat)
-                feat_src = feat_dst = self.fc(h_src).reshape(
-                    -1, self._num_heads, self._out_feats)
+                feat_src = feat_dst = self.fc(h_src.reshape(-1, feat_dim)).reshape(
+                    *src_prefix_shape, self._num_heads, self._out_feats)
             if graph.is_block:
                 feat_dst = feat_src[:graph.number_of_dst_nodes()]
             # NOTE: GAT paper uses "first concatenation then linear projection"
@@ -288,7 +293,8 @@ class GATConv(nn.Block):
             rst = graph.dstdata['ft']
             # residual
             if self.res_fc is not None:
-                resval = self.res_fc(h_dst).reshape(h_dst.shape[0], -1, self._out_feats)
+                resval = self.res_fc(h_dst.reshape(-1, feat_dim)).reshape(
+                    *dst_prefix_shape, -1, self._out_feats)
                 rst = rst + resval
             # activation
             if self.activation:
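Note one MXNet-specific wrinkle in the hunks above: the input is flattened to 2-D before the linear projection (h_src.reshape(-1, feat_dim)) and the prefix shape is restored afterwards, rather than projecting the N-D tensor directly. A standalone sketch of that flatten-project-restore idiom; the Dense layer and all sizes are assumptions for illustration:

import mxnet as mx
from mxnet.gluon import nn

proj = nn.Dense(3 * 4, use_bias=False)     # num_heads * out_feats = 12
proj.initialize()

feat = mx.nd.random.randn(6, 7, 5)         # (N, *, D_in)
prefix_shape = feat.shape[:-1]             # (6, 7)
feat_dim = feat.shape[-1]                  # 5

# Flatten every prefix dimension into one batch axis, project,
# then restore the prefix shape in front of (H, D_out).
flat = feat.reshape(-1, feat_dim)          # (42, 5)
out = proj(flat).reshape(*prefix_shape, 3, 4)
print(out.shape)                           # (6, 7, 3, 4)
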
@@ -229,20 +229,20 @@ class GATConv(nn.Module):
         graph : DGLGraph
             The graph.
         feat : torch.Tensor or pair of torch.Tensor
-            If a torch.Tensor is given, the input feature of shape :math:`(N, D_{in})` where
+            If a torch.Tensor is given, the input feature of shape :math:`(N, *, D_{in})` where
             :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
             If a pair of torch.Tensor is given, the pair must contain two tensors of shape
-            :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
+            :math:`(N_{in}, *, D_{in_{src}})` and :math:`(N_{out}, *, D_{in_{dst}})`.
         get_attention : bool, optional
             Whether to return the attention values. Default to False.
 
         Returns
         -------
         torch.Tensor
-            The output feature of shape :math:`(N, H, D_{out})` where :math:`H`
+            The output feature of shape :math:`(N, *, H, D_{out})` where :math:`H`
             is the number of heads, and :math:`D_{out}` is size of output feature.
         torch.Tensor, optional
-            The attention values of shape :math:`(E, H, 1)`, where :math:`E` is the number of
+            The attention values of shape :math:`(E, *, H, 1)`, where :math:`E` is the number of
             edges. This is returned only when :attr:`get_attention` is ``True``.
 
         Raises
@@ -266,18 +266,25 @@ class GATConv(nn.Module):
                                'suppress the check and let the code run.')
             if isinstance(feat, tuple):
+                src_prefix_shape = feat[0].shape[:-1]
+                dst_prefix_shape = feat[1].shape[:-1]
                 h_src = self.feat_drop(feat[0])
                 h_dst = self.feat_drop(feat[1])
                 if not hasattr(self, 'fc_src'):
-                    feat_src = self.fc(h_src).view(-1, self._num_heads, self._out_feats)
-                    feat_dst = self.fc(h_dst).view(-1, self._num_heads, self._out_feats)
+                    feat_src = self.fc(h_src).view(
+                        *src_prefix_shape, self._num_heads, self._out_feats)
+                    feat_dst = self.fc(h_dst).view(
+                        *dst_prefix_shape, self._num_heads, self._out_feats)
                 else:
-                    feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats)
-                    feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads, self._out_feats)
+                    feat_src = self.fc_src(h_src).view(
+                        *src_prefix_shape, self._num_heads, self._out_feats)
+                    feat_dst = self.fc_dst(h_dst).view(
+                        *dst_prefix_shape, self._num_heads, self._out_feats)
             else:
+                src_prefix_shape = dst_prefix_shape = feat.shape[:-1]
                 h_src = h_dst = self.feat_drop(feat)
                 feat_src = feat_dst = self.fc(h_src).view(
-                    -1, self._num_heads, self._out_feats)
+                    *src_prefix_shape, self._num_heads, self._out_feats)
             if graph.is_block:
                 feat_dst = feat_src[:graph.number_of_dst_nodes()]
             # NOTE: GAT paper uses "first concatenation then linear projection"
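
The tuple branch above also covers bipartite inputs, where source and destination nodes have different feature widths but must share the same extra dimensions. A hypothetical example; the heterograph, node types, and all sizes are made up for illustration:

import dgl
import torch
from dgl.nn.pytorch import GATConv

# 3 'user' nodes attend over edges into 2 'topic' nodes.
g = dgl.heterograph({('user', 'follows', 'topic'): ([0, 1, 2], [0, 0, 1])})
conv = GATConv(in_feats=(5, 6), out_feats=4, num_heads=3)

feat_src = torch.randn(3, 7, 5)            # (N_in,  *, D_in_src)
feat_dst = torch.randn(2, 7, 6)            # (N_out, *, D_in_dst)
out = conv(g, (feat_src, feat_dst))        # (N_out, *, H, D_out)
print(out.shape)                           # torch.Size([2, 7, 3, 4])
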
@@ -306,11 +313,12 @@ class GATConv(nn.Module):
             # residual
             if self.res_fc is not None:
                 # Use -1 rather than self._num_heads to handle broadcasting
-                resval = self.res_fc(h_dst).view(h_dst.shape[0], -1, self._out_feats)
+                resval = self.res_fc(h_dst).view(*dst_prefix_shape, -1, self._out_feats)
                 rst = rst + resval
             # bias
             if self.bias is not None:
-                rst = rst + self.bias.view(1, self._num_heads, self._out_feats)
+                rst = rst + self.bias.view(
+                    *((1,) * len(dst_prefix_shape)), self._num_heads, self._out_feats)
             # activation
             if self.activation:
                 rst = self.activation(rst)
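The bias change at the end of this file is pure broadcasting bookkeeping: the flat learned bias is reshaped with one singleton axis per prefix dimension so it lines up against rst of shape :math:`(N, *, H, D_{out})`. The same computation in isolation, with illustrative sizes:

import torch

num_heads, out_feats = 3, 4
bias = torch.zeros(num_heads * out_feats)      # stored flat, as in the module

rst = torch.randn(2, 7, num_heads, out_feats)  # (N, *, H, D_out)
dst_prefix_shape = rst.shape[:-2]              # (2, 7)

# One leading singleton per prefix dimension: here (1, 1, H, D_out).
b = bias.view(*((1,) * len(dst_prefix_shape)), num_heads, out_feats)
print((rst + b).shape)                         # torch.Size([2, 7, 3, 4])
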
@@ -207,20 +207,20 @@ class GATConv(layers.Layer):
         graph : DGLGraph
             The graph.
         feat : tf.Tensor or pair of tf.Tensor
-            If a tf.Tensor is given, the input feature of shape :math:`(N, D_{in})` where
+            If a tf.Tensor is given, the input feature of shape :math:`(N, *, D_{in})` where
             :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
             If a pair of tf.Tensor is given, the pair must contain two tensors of shape
-            :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
+            :math:`(N_{in}, *, D_{in_{src}})` and :math:`(N_{out}, *, D_{in_{dst}})`.
         get_attention : bool, optional
             Whether to return the attention values. Default to False.
 
         Returns
         -------
         tf.Tensor
-            The output feature of shape :math:`(N, H, D_{out})` where :math:`H`
+            The output feature of shape :math:`(N, *, H, D_{out})` where :math:`H`
             is the number of heads, and :math:`D_{out}` is size of output feature.
         tf.Tensor, optional
-            The attention values of shape :math:`(E, H, 1)`, where :math:`E` is the number of
+            The attention values of shape :math:`(E, *, H, 1)`, where :math:`E` is the number of
             edges. This is returned only when :attr:`get_attention` is ``True``.
 
         Raises
@@ -244,16 +244,23 @@ class GATConv(layers.Layer):
                                'suppress the check and let the code run.')
             if isinstance(feat, tuple):
+                src_prefix_shape = tuple(feat[0].shape[:-1])
+                dst_prefix_shape = tuple(feat[1].shape[:-1])
                 h_src = self.feat_drop(feat[0])
                 h_dst = self.feat_drop(feat[1])
                 if not hasattr(self, 'fc_src'):
                     self.fc_src, self.fc_dst = self.fc, self.fc
-                feat_src = tf.reshape(self.fc_src(h_src), (-1, self._num_heads, self._out_feats))
-                feat_dst = tf.reshape(self.fc_dst(h_dst), (-1, self._num_heads, self._out_feats))
+                feat_src = tf.reshape(
+                    self.fc_src(h_src),
+                    src_prefix_shape + (self._num_heads, self._out_feats))
+                feat_dst = tf.reshape(
+                    self.fc_dst(h_dst),
+                    dst_prefix_shape + (self._num_heads, self._out_feats))
             else:
+                src_prefix_shape = dst_prefix_shape = tuple(feat.shape[:-1])
                 h_src = h_dst = self.feat_drop(feat)
                 feat_src = feat_dst = tf.reshape(
-                    self.fc(h_src), (-1, self._num_heads, self._out_feats))
+                    self.fc(h_src), src_prefix_shape + (self._num_heads, self._out_feats))
             if graph.is_block:
                 feat_dst = feat_src[:graph.number_of_dst_nodes()]
             # NOTE: GAT paper uses "first concatenation then linear projection"
@@ -282,7 +289,7 @@ class GATConv(layers.Layer):
             # residual
             if self.res_fc is not None:
                 resval = tf.reshape(self.res_fc(
-                    h_dst), (h_dst.shape[0], -1, self._out_feats))
+                    h_dst), dst_prefix_shape + (-1, self._out_feats))
                 rst = rst + resval
             # activation
             if self.activation:
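Unlike the PyTorch and MXNet versions, which splat the prefix shape into view()/reshape(), the TensorFlow version builds the target shape as a plain Python tuple, since tf.reshape takes the shape as a single argument; that is why the hunks above wrap the slices in tuple(...). The same step in isolation, with illustrative sizes:

import tensorflow as tf

num_heads, out_feats = 3, 4
h = tf.random.normal((6, 7, num_heads * out_feats))   # projected, (N, *, H * D_out)
prefix_shape = tuple(h.shape[:-1])                    # (6, 7) as plain ints

out = tf.reshape(h, prefix_shape + (num_heads, out_feats))
print(out.shape)                                      # (6, 7, 3, 4)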