"src/array/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "83115794c29ef1db47f7e7e2e4fde54c0d7f0a4a"
Unverified Commit 99831073, authored by Quan (Andy) Gan, committed by GitHub

[Feature] Support multidimensional features for GAT (#2912)

* support multidimensional features for GAT

* docstring

* lint

* fix
parent caa6d607
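In user terms, the change relaxes the input feature shape from :math:`(N, D_{in})` to :math:`(N, *, D_{in})`, and the output follows as :math:`(N, *, H, D_{out})`. The snippet below is a minimal sketch of the new behaviour, assuming the PyTorch backend and the shape contract from the updated docstrings; the graph and sizes are made up for illustration.

```python
import torch
import dgl
from dgl.nn import GATConv

# A tiny 3-node cycle graph (every node has in-degree 1).
g = dgl.graph(([0, 1, 2], [1, 2, 0]))

# Node features with an extra leading feature dimension: (N, *, D_in) = (3, 5, 10).
feat = torch.randn(3, 5, 10)

conv = GATConv(in_feats=10, out_feats=4, num_heads=2)
out = conv(g, feat)

# Output keeps the prefix dimensions: (N, *, H, D_out) = (3, 5, 2, 4).
print(out.shape)
```

The MXNet and TensorFlow backends below document the same :math:`(N, *, D_{in}) \rightarrow (N, *, H, D_{out})` contract.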
MXNet backend:

```diff
@@ -213,20 +213,20 @@ class GATConv(nn.Block):
         graph : DGLGraph
             The graph.
         feat : mxnet.NDArray or pair of mxnet.NDArray
-            If a mxnet.NDArray is given, the input feature of shape :math:`(N, D_{in})` where
+            If a mxnet.NDArray is given, the input feature of shape :math:`(N, *, D_{in})` where
             :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
             If a pair of mxnet.NDArray is given, the pair must contain two tensors of shape
-            :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
+            :math:`(N_{in}, *, D_{in_{src}})` and :math:`(N_{out}, *, D_{in_{dst}})`.
         get_attention : bool, optional
             Whether to return the attention values. Default to False.

         Returns
         -------
         mxnet.NDArray
-            The output feature of shape :math:`(N, H, D_{out})` where :math:`H`
+            The output feature of shape :math:`(N, *, H, D_{out})` where :math:`H`
             is the number of heads, and :math:`D_{out}` is size of output feature.
         mxnet.NDArray, optional
-            The attention values of shape :math:`(E, H, 1)`, where :math:`E` is the number of
+            The attention values of shape :math:`(E, *, H, 1)`, where :math:`E` is the number of
             edges. This is returned only when :attr:`get_attention` is ``True``.

         Raises
@@ -250,18 +250,23 @@ class GATConv(nn.Block):
                     'suppress the check and let the code run.')

         if isinstance(feat, tuple):
+            src_prefix_shape = feat[0].shape[:-1]
+            dst_prefix_shape = feat[1].shape[:-1]
+            feat_dim = feat[0].shape[-1]
             h_src = self.feat_drop(feat[0])
             h_dst = self.feat_drop(feat[1])
             if not hasattr(self, 'fc_src'):
                 self.fc_src, self.fc_dst = self.fc, self.fc
-            feat_src = self.fc_src(h_src).reshape(
-                -1, self._num_heads, self._out_feats)
-            feat_dst = self.fc_dst(h_dst).reshape(
-                -1, self._num_heads, self._out_feats)
+            feat_src = self.fc_src(h_src.reshape(-1, feat_dim)).reshape(
+                *src_prefix_shape, self._num_heads, self._out_feats)
+            feat_dst = self.fc_dst(h_dst.reshape(-1, feat_dim)).reshape(
+                *dst_prefix_shape, self._num_heads, self._out_feats)
         else:
+            src_prefix_shape = dst_prefix_shape = feat.shape[:-1]
+            feat_dim = feat[0].shape[-1]
             h_src = h_dst = self.feat_drop(feat)
-            feat_src = feat_dst = self.fc(h_src).reshape(
-                -1, self._num_heads, self._out_feats)
+            feat_src = feat_dst = self.fc(h_src.reshape(-1, feat_dim)).reshape(
+                *src_prefix_shape, self._num_heads, self._out_feats)
             if graph.is_block:
                 feat_dst = feat_src[:graph.number_of_dst_nodes()]
         # NOTE: GAT paper uses "first concatenation then linear projection"
@@ -288,7 +293,8 @@ class GATConv(nn.Block):
         rst = graph.dstdata['ft']
         # residual
         if self.res_fc is not None:
-            resval = self.res_fc(h_dst).reshape(h_dst.shape[0], -1, self._out_feats)
+            resval = self.res_fc(h_dst.reshape(-1, feat_dim)).reshape(
+                *dst_prefix_shape, -1, self._out_feats)
             rst = rst + resval
         # activation
         if self.activation:
```
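The MXNet branch explicitly flattens the feature block to 2-D before the dense projection and restores the prefix shape afterwards. Below is a framework-agnostic numpy sketch of that bookkeeping; the variable names mirror the diff (`src_prefix_shape`, `feat_dim`), but the snippet is purely illustrative, not DGL code.

```python
import numpy as np

# Per-node feature block with extra leading dimensions: (N, *, D_in).
feat = np.random.randn(4, 3, 8)

prefix_shape = feat.shape[:-1]   # (N, *) == (4, 3)
feat_dim = feat.shape[-1]        # D_in == 8
num_heads, out_feats = 2, 5
W = np.random.randn(feat_dim, num_heads * out_feats)

# Flatten everything except the feature axis, project, then restore the prefix.
projected = feat.reshape(-1, feat_dim) @ W
projected = projected.reshape(*prefix_shape, num_heads, out_feats)
print(projected.shape)           # (4, 3, 2, 5)
```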
PyTorch backend:

```diff
@@ -229,20 +229,20 @@ class GATConv(nn.Module):
         graph : DGLGraph
             The graph.
         feat : torch.Tensor or pair of torch.Tensor
-            If a torch.Tensor is given, the input feature of shape :math:`(N, D_{in})` where
+            If a torch.Tensor is given, the input feature of shape :math:`(N, *, D_{in})` where
             :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
             If a pair of torch.Tensor is given, the pair must contain two tensors of shape
-            :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
+            :math:`(N_{in}, *, D_{in_{src}})` and :math:`(N_{out}, *, D_{in_{dst}})`.
         get_attention : bool, optional
             Whether to return the attention values. Default to False.

         Returns
         -------
         torch.Tensor
-            The output feature of shape :math:`(N, H, D_{out})` where :math:`H`
+            The output feature of shape :math:`(N, *, H, D_{out})` where :math:`H`
             is the number of heads, and :math:`D_{out}` is size of output feature.
         torch.Tensor, optional
-            The attention values of shape :math:`(E, H, 1)`, where :math:`E` is the number of
+            The attention values of shape :math:`(E, *, H, 1)`, where :math:`E` is the number of
             edges. This is returned only when :attr:`get_attention` is ``True``.

         Raises
@@ -266,18 +266,25 @@ class GATConv(nn.Module):
                     'suppress the check and let the code run.')

         if isinstance(feat, tuple):
+            src_prefix_shape = feat[0].shape[:-1]
+            dst_prefix_shape = feat[1].shape[:-1]
             h_src = self.feat_drop(feat[0])
             h_dst = self.feat_drop(feat[1])
             if not hasattr(self, 'fc_src'):
-                feat_src = self.fc(h_src).view(-1, self._num_heads, self._out_feats)
-                feat_dst = self.fc(h_dst).view(-1, self._num_heads, self._out_feats)
+                feat_src = self.fc(h_src).view(
+                    *src_prefix_shape, self._num_heads, self._out_feats)
+                feat_dst = self.fc(h_dst).view(
+                    *dst_prefix_shape, self._num_heads, self._out_feats)
             else:
-                feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats)
-                feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads, self._out_feats)
+                feat_src = self.fc_src(h_src).view(
+                    *src_prefix_shape, self._num_heads, self._out_feats)
+                feat_dst = self.fc_dst(h_dst).view(
+                    *dst_prefix_shape, self._num_heads, self._out_feats)
         else:
+            src_prefix_shape = dst_prefix_shape = feat.shape[:-1]
             h_src = h_dst = self.feat_drop(feat)
             feat_src = feat_dst = self.fc(h_src).view(
-                -1, self._num_heads, self._out_feats)
+                *src_prefix_shape, self._num_heads, self._out_feats)
             if graph.is_block:
                 feat_dst = feat_src[:graph.number_of_dst_nodes()]
         # NOTE: GAT paper uses "first concatenation then linear projection"
@@ -306,11 +313,12 @@ class GATConv(nn.Module):
         # residual
         if self.res_fc is not None:
             # Use -1 rather than self._num_heads to handle broadcasting
-            resval = self.res_fc(h_dst).view(h_dst.shape[0], -1, self._out_feats)
+            resval = self.res_fc(h_dst).view(*dst_prefix_shape, -1, self._out_feats)
             rst = rst + resval
         # bias
         if self.bias is not None:
-            rst = rst + self.bias.view(1, self._num_heads, self._out_feats)
+            rst = rst + self.bias.view(
+                *((1,) * len(dst_prefix_shape)), self._num_heads, self._out_feats)
         # activation
         if self.activation:
             rst = self.activation(rst)
```
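One detail in the PyTorch bias hunk is easy to miss: the bias must now broadcast over an arbitrary number of prefix dimensions, which it does by prepending one singleton dimension per prefix dimension via `(1,) * len(dst_prefix_shape)`. A small standalone sketch with invented shapes:

```python
import torch

dst_prefix_shape = (4, 3)        # (N, *) of the destination nodes
num_heads, out_feats = 2, 5

# Aggregated result of shape (N, *, H, D_out) and a flat bias of size H * D_out.
rst = torch.randn(*dst_prefix_shape, num_heads, out_feats)
bias = torch.randn(num_heads * out_feats)

# view(1, 1, H, D_out) broadcasts against (4, 3, H, D_out).
rst = rst + bias.view(*((1,) * len(dst_prefix_shape)), num_heads, out_feats)
print(rst.shape)                 # torch.Size([4, 3, 2, 5])
```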
TensorFlow backend:

```diff
@@ -207,20 +207,20 @@ class GATConv(layers.Layer):
         graph : DGLGraph
             The graph.
         feat : tf.Tensor or pair of tf.Tensor
-            If a tf.Tensor is given, the input feature of shape :math:`(N, D_{in})` where
+            If a tf.Tensor is given, the input feature of shape :math:`(N, *, D_{in})` where
             :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
             If a pair of tf.Tensor is given, the pair must contain two tensors of shape
-            :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
+            :math:`(N_{in}, *, D_{in_{src}})` and :math:`(N_{out}, *, D_{in_{dst}})`.
         get_attention : bool, optional
             Whether to return the attention values. Default to False.

         Returns
         -------
         tf.Tensor
-            The output feature of shape :math:`(N, H, D_{out})` where :math:`H`
+            The output feature of shape :math:`(N, *, H, D_{out})` where :math:`H`
             is the number of heads, and :math:`D_{out}` is size of output feature.
         tf.Tensor, optional
-            The attention values of shape :math:`(E, H, 1)`, where :math:`E` is the number of
+            The attention values of shape :math:`(E, *, H, 1)`, where :math:`E` is the number of
             edges. This is returned only when :attr:`get_attention` is ``True``.

         Raises
@@ -244,16 +244,23 @@ class GATConv(layers.Layer):
                     'suppress the check and let the code run.')

         if isinstance(feat, tuple):
+            src_prefix_shape = tuple(feat[0].shape[:-1])
+            dst_prefix_shape = tuple(feat[1].shape[:-1])
             h_src = self.feat_drop(feat[0])
             h_dst = self.feat_drop(feat[1])
             if not hasattr(self, 'fc_src'):
                 self.fc_src, self.fc_dst = self.fc, self.fc
-            feat_src = tf.reshape(self.fc_src(h_src), (-1, self._num_heads, self._out_feats))
-            feat_dst = tf.reshape(self.fc_dst(h_dst), (-1, self._num_heads, self._out_feats))
+            feat_src = tf.reshape(
+                self.fc_src(h_src),
+                src_prefix_shape + (self._num_heads, self._out_feats))
+            feat_dst = tf.reshape(
+                self.fc_dst(h_dst),
+                dst_prefix_shape + (self._num_heads, self._out_feats))
         else:
+            src_prefix_shape = dst_prefix_shape = tuple(feat.shape[:-1])
             h_src = h_dst = self.feat_drop(feat)
             feat_src = feat_dst = tf.reshape(
-                self.fc(h_src), (-1, self._num_heads, self._out_feats))
+                self.fc(h_src), src_prefix_shape + (self._num_heads, self._out_feats))
             if graph.is_block:
                 feat_dst = feat_src[:graph.number_of_dst_nodes()]
         # NOTE: GAT paper uses "first concatenation then linear projection"
@@ -282,7 +289,7 @@ class GATConv(layers.Layer):
         # residual
         if self.res_fc is not None:
             resval = tf.reshape(self.res_fc(
-                h_dst), (h_dst.shape[0], -1, self._out_feats))
+                h_dst), dst_prefix_shape + (-1, self._out_feats))
             rst = rst + resval
         # activation
         if self.activation:
```
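The TensorFlow variant keeps the prefix shape as a plain Python tuple and concatenates it with the head dimensions, since tf.reshape takes the full target shape as a single argument rather than unpacked sizes. A standalone sketch with invented shapes (again illustrative only, not DGL code):

```python
import tensorflow as tf

# Feature block of shape (N, *, D_in).
feat = tf.random.normal((4, 3, 8))
src_prefix_shape = tuple(feat.shape[:-1])   # (4, 3)
num_heads, out_feats = 2, 5

# Dense applies to the last axis, so no explicit flattening is needed here.
fc = tf.keras.layers.Dense(num_heads * out_feats, use_bias=False)

feat_src = tf.reshape(fc(feat), src_prefix_shape + (num_heads, out_feats))
print(feat_src.shape)                       # (4, 3, 2, 5)
```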