Unverified commit 9314aabd authored by Zihao Ye, committed by GitHub

[Refactor] Interface of nn modules (#798)

* refactor

* upd mpnn
parent 650f6ee1
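The diff below reorders the arguments of every nn module's forward function so that the graph comes first and the node features second. A minimal sketch of the new calling convention, assuming a small path graph and the GraphConv layer from dgl.nn.pytorch.conv (the graph construction here is illustrative, not part of this commit):

import torch
import dgl
import networkx as nx
from dgl.nn.pytorch.conv import GraphConv

# Build a tiny bidirected path graph with 3 nodes and a 5-dim feature per node.
g = dgl.DGLGraph(nx.path_graph(3))
feat = torch.ones(3, 5)
conv = GraphConv(5, 2)

# Before this refactor: h = conv(feat, g)
# After this refactor, the graph is always the first argument:
h = conv(g, feat)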
......@@ -36,5 +36,5 @@ class GCN(gluon.Block):
for i, layer in enumerate(self.layers):
if i != 0:
h = self.dropout(h)
h = layer(h, self.g)
h = layer(self.g, h)
return h
......@@ -53,5 +53,5 @@ class APPNP(nn.Module):
h = self.activation(layer(h))
h = self.layers[-1](self.feat_drop(h))
# propagation step
h = self.propagate(h, self.g)
h = self.propagate(self.g, h)
return h
......@@ -49,7 +49,7 @@ class GAT(nn.Module):
def forward(self, inputs):
h = inputs
for l in range(self.num_layers):
h = self.gat_layers[l](h, self.g).flatten(1)
h = self.gat_layers[l](self.g, h).flatten(1)
# output projection
logits = self.gat_layers[-1](h, self.g).mean(1)
logits = self.gat_layers[-1](self.g, h).mean(1)
return logits
......@@ -35,5 +35,5 @@ class GCN(nn.Module):
for i, layer in enumerate(self.layers):
if i != 0:
h = self.dropout(h)
h = layer(h, self.g)
h = layer(self.g, h)
return h
......@@ -155,7 +155,7 @@ class GIN(nn.Module):
hidden_rep = [h]
for layer in range(self.num_layers - 1):
h = self.ginlayers[layer](h, g)
h = self.ginlayers[layer](g, h)
hidden_rep.append(h)
score_over_layer = 0
......
......@@ -41,7 +41,7 @@ class GraphSAGE(nn.Module):
def forward(self, features):
h = features
for layer in self.layers:
h = layer(h, self.g)
h = layer(self.g, h)
return h
......
......@@ -50,7 +50,7 @@ GIN_CONFIG = {
}
CHEBNET_CONFIG = {
'extra_args': [16, 1, 3, True],
'extra_args': [32, 1, 2, True],
'lr': 1e-2,
'weight_decay': 5e-4,
}
......@@ -30,7 +30,7 @@ class GCN(nn.Module):
for i, layer in enumerate(self.layers):
if i != 0:
h = self.dropout(h)
h = layer(h, self.g)
h = layer(self.g, h)
return h
......@@ -70,9 +70,9 @@ class GAT(nn.Module):
def forward(self, inputs):
h = inputs
for l in range(self.num_layers):
h = self.gat_layers[l](h, self.g).flatten(1)
h = self.gat_layers[l](self.g, h).flatten(1)
# output projection
logits = self.gat_layers[-1](h, self.g).mean(1)
logits = self.gat_layers[-1](self.g, h).mean(1)
return logits
......@@ -101,7 +101,7 @@ class GraphSAGE(nn.Module):
def forward(self, features):
h = features
for layer in self.layers:
h = layer(h, self.g)
h = layer(self.g, h)
return h
......@@ -148,7 +148,7 @@ class APPNP(nn.Module):
h = self.activation(layer(h))
h = self.layers[-1](self.feat_drop(h))
# propagation step
h = self.propagate(h, self.g)
h = self.propagate(self.g, h)
return h
......@@ -178,7 +178,7 @@ class TAGCN(nn.Module):
for i, layer in enumerate(self.layers):
if i != 0:
h = self.dropout(h)
h = layer(h, self.g)
h = layer(self.g, h)
return h
......@@ -210,7 +210,7 @@ class AGNN(nn.Module):
def forward(self, features):
h = self.proj(features)
for layer in self.layers:
h = layer(h, self.g)
h = layer(self.g, h)
return self.cls(h)
......@@ -231,7 +231,7 @@ class SGC(nn.Module):
bias=bias)
def forward(self, features):
return self.net(features, self.g)
return self.net(self.g, features)
class GIN(nn.Module):
......@@ -286,7 +286,7 @@ class GIN(nn.Module):
def forward(self, features):
h = features
for layer in self.layers:
h = layer(h, self.g)
h = layer(self.g, h)
return h
class ChebNet(nn.Module):
......@@ -316,5 +316,5 @@ class ChebNet(nn.Module):
def forward(self, features):
h = features
for layer in self.layers:
h = layer(h, self.g)
h = layer(self.g, h, [2])
return h
\ No newline at end of file
......@@ -19,7 +19,7 @@ from dgl.nn.pytorch.conv import SGConv
def evaluate(model, g, features, labels, mask):
model.eval()
with torch.no_grad():
logits = model(features, g)[mask] # only compute the evaluation set
logits = model(g, features)[mask] # only compute the evaluation set
labels = labels[mask]
_, indices = torch.max(logits, dim=1)
correct = torch.sum(indices == labels)
......@@ -86,7 +86,7 @@ def main(args):
if epoch >= 3:
t0 = time.time()
# forward
logits = model(features, g) # only compute the train set
logits = model(g, features) # only compute the train set
loss = loss_fcn(logits[train_mask], labels[train_mask])
optimizer.zero_grad()
......
......@@ -21,7 +21,7 @@ def normalize(h):
def evaluate(model, features, graph, labels, mask):
model.eval()
with torch.no_grad():
logits = model(features, graph)[mask] # only compute the evaluation set
logits = model(graph, features)[mask] # only compute the evaluation set
labels = labels[mask]
_, indices = torch.max(logits, dim=1)
correct = torch.sum(indices == labels)
......@@ -82,7 +82,7 @@ def main(args):
# define loss closure
def closure():
optimizer.zero_grad()
output = model(features, g)[train_mask]
output = model(g, features)[train_mask]
loss_train = F.cross_entropy(output, labels[train_mask])
loss_train.backward()
return loss_train
......
......@@ -35,5 +35,5 @@ class TAGCN(nn.Module):
for i, layer in enumerate(self.layers):
if i != 0:
h = self.dropout(h)
h = layer(h, self.g)
h = layer(self.g, h)
return h
......@@ -145,7 +145,7 @@ class MPNNModel(nn.Module):
out, h = self.gru(m.unsqueeze(0), h)
out = out.squeeze(0)
out = self.set2set(out, g)
out = self.set2set(g, out)
out = F.relu(self.lin1(out))
out = self.lin2(out)
return out
......@@ -83,7 +83,7 @@ class GraphConv(gluon.Block):
self._activation = activation
def forward(self, feat, graph):
def forward(self, graph, feat):
r"""Compute graph convolution.
Notes
......@@ -95,10 +95,10 @@ class GraphConv(gluon.Block):
Parameters
----------
feat : mxnet.NDArray
The input feature
graph : DGLGraph
The graph.
feat : mxnet.NDArray
The input feature
Returns
-------
......
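The MXNet/Gluon modules follow the same reordering. A rough usage sketch, assuming GraphConv is importable from dgl.nn.mxnet.conv and default initialization (not shown in this diff):

import dgl
import networkx as nx
from mxnet import nd
from dgl.nn.mxnet.conv import GraphConv

g = dgl.DGLGraph(nx.path_graph(3))
feat = nd.ones((3, 5))
conv = GraphConv(5, 2)
conv.initialize()  # Gluon blocks must be initialized before the first forward call

# Graph first, features second under the refactored interface.
h = conv(g, feat)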
......@@ -19,16 +19,16 @@ class SumPooling(nn.Block):
def __init__(self):
super(SumPooling, self).__init__()
def forward(self, feat, graph):
def forward(self, graph, feat):
r"""Compute sum pooling.
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
The graph.
feat : mxnet.NDArray
The input feature with shape :math:`(N, *)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph or BatchedDGLGraph
The graph.
Returns
-------
......@@ -56,16 +56,16 @@ class AvgPooling(nn.Block):
def __init__(self):
super(AvgPooling, self).__init__()
def forward(self, feat, graph):
def forward(self, graph, feat):
r"""Compute average pooling.
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
The graph.
feat : mxnet.NDArray
The input feature with shape :math:`(N, *)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph or BatchedDGLGraph
The graph.
Returns
-------
......@@ -93,16 +93,16 @@ class MaxPooling(nn.Block):
def __init__(self):
super(MaxPooling, self).__init__()
def forward(self, feat, graph):
def forward(self, graph, feat):
r"""Compute max pooling.
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
The graph.
feat : mxnet.NDArray
The input feature with shape :math:`(N, *)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph or BatchedDGLGraph
The graph.
Returns
-------
......@@ -134,16 +134,16 @@ class SortPooling(nn.Block):
super(SortPooling, self).__init__()
self.k = k
def forward(self, feat, graph):
def forward(self, graph, feat):
r"""Compute sort pooling.
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
The graph.
feat : mxnet.NDArray
The input feature with shape :math:`(N, D)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph or BatchedDGLGraph
The graph.
Returns
-------
......@@ -190,16 +190,16 @@ class GlobalAttentionPooling(nn.Block):
self.gate_nn = gate_nn
self.feat_nn = feat_nn
def forward(self, feat, graph):
def forward(self, graph, feat):
r"""Compute global attention pooling.
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
The graph.
feat : mxnet.NDArray
The input feature with shape :math:`(N, D)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph or BatchedDGLGraph
The graph.
Returns
-------
......@@ -258,16 +258,16 @@ class Set2Set(nn.Block):
self.lstm = gluon.rnn.LSTM(
self.input_dim, num_layers=n_layers, input_size=self.output_dim)
def forward(self, feat, graph):
def forward(self, graph, feat):
r"""Compute set2set pooling.
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
The graph.
feat : mxnet.NDArray
The input feature with shape :math:`(N, D)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph or BatchedDGLGraph
The graph.
Returns
-------
......
......@@ -107,7 +107,7 @@ class GraphConv(nn.Module):
if self.bias is not None:
init.zeros_(self.bias)
def forward(self, feat, graph):
def forward(self, graph, feat):
r"""Compute graph convolution.
Notes
......@@ -119,10 +119,10 @@ class GraphConv(nn.Module):
Parameters
----------
feat : torch.Tensor
The input feature
graph : DGLGraph
The graph.
feat : torch.Tensor
The input feature
Returns
-------
......@@ -246,16 +246,16 @@ class GATConv(nn.Module):
if isinstance(self.res_fc, nn.Linear):
nn.init.xavier_normal_(self.res_fc.weight, gain=gain)
def forward(self, feat, graph):
def forward(self, graph, feat):
r"""Compute graph attention network layer.
Parameters
----------
graph : DGLGraph
The graph.
feat : torch.Tensor
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
is size of input feature, :math:`N` is the number of nodes.
graph : DGLGraph
The graph.
Returns
-------
......@@ -338,16 +338,16 @@ class TAGConv(nn.Module):
gain = nn.init.calculate_gain('relu')
nn.init.xavier_normal_(self.lin.weight, gain=gain)
def forward(self, feat, graph):
def forward(self, graph, feat):
r"""Compute topology adaptive graph convolution.
Parameters
----------
graph : DGLGraph
The graph.
feat : torch.Tensor
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
is size of input feature, :math:`N` is the number of nodes.
graph : DGLGraph
The graph.
Returns
-------
......@@ -643,16 +643,16 @@ class SAGEConv(nn.Module):
_, (rst, _) = self.lstm(m, h)
return {'neigh': rst.squeeze(0)}
def forward(self, feat, graph):
def forward(self, graph, feat):
r"""Compute GraphSAGE layer.
Parameters
----------
graph : DGLGraph
The graph.
feat : torch.Tensor
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
is size of input feature, :math:`N` is the number of nodes.
graph : DGLGraph
The graph.
Returns
-------
......@@ -742,11 +742,13 @@ class GatedGraphConv(nn.Module):
self.gru.reset_parameters()
init.xavier_normal_(self.edge_embed.weight, gain=gain)
def forward(self, feat, etypes, graph):
def forward(self, graph, feat, etypes):
"""Compute Gated Graph Convolution layer.
Parameters
----------
graph : DGLGraph
The graph.
feat : torch.Tensor
The input feature of shape :math:`(N, D_{in})` where :math:`N`
is the number of nodes of the graph and :math:`D_{in}` is the
......@@ -754,8 +756,6 @@ class GatedGraphConv(nn.Module):
etypes : torch.LongTensor
The edge type tensor of shape :math:`(E,)` where :math:`E` is
the number of edges of the graph.
graph : DGLGraph
The graph.
Returns
-------
......@@ -856,11 +856,13 @@ class GMMConv(nn.Module):
if self.bias is not None:
init.zeros_(self.bias.data)
def forward(self, feat, pseudo, graph):
def forward(self, graph, feat, pseudo):
"""Compute Gaussian Mixture Model Convolution layer.
Parameters
----------
graph : DGLGraph
The graph.
feat : torch.Tensor
The input feature of shape :math:`(N, D_{in})` where :math:`N`
is the number of nodes of the graph and :math:`D_{in}` is the
......@@ -869,8 +871,6 @@ class GMMConv(nn.Module):
The pseudo coordinate tensor of shape :math:`(E, D_{u})` where
:math:`E` is the number of edges of the graph and :math:`D_{u}`
is the dimensionality of pseudo coordinate.
graph : DGLGraph
The graph.
Returns
-------
......@@ -940,18 +940,18 @@ class GINConv(nn.Module):
else:
self.register_buffer('eps', th.FloatTensor([init_eps]))
def forward(self, feat, graph):
def forward(self, graph, feat):
r"""Compute Graph Isomorphism Network layer.
Parameters
----------
graph : DGLGraph
The graph.
feat : torch.Tensor
The input feature of shape :math:`(N, D)` where :math:`D`
could be any positive integer, :math:`N` is the number
of nodes. If ``apply_func`` is not None, :math:`D` should
fit the input dimensionality requirement of ``apply_func``.
graph : DGLGraph
The graph.
Returns
-------
......@@ -1025,16 +1025,22 @@ class ChebConv(nn.Module):
if module.bias is not None:
init.zeros_(module.bias)
def forward(self, feat, graph, lambda_max=None):
def forward(self, graph, feat, lambda_max=None):
r"""Compute ChebNet layer.
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
The graph.
feat : torch.Tensor
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
is size of input feature, :math:`N` is the number of nodes.
graph : DGLGraph
The graph.
lambda_max : list or tensor or None, optional
A list or tensor of length :math:`B`, storing the largest eigenvalue
of the normalized Laplacian of each individual graph in ``graph``,
where :math:`B` is the batch size of the input graph. Default: None.
If None, this method computes the values by calling
``dgl.laplacian_lambda_max``.
Returns
-------
......@@ -1047,13 +1053,13 @@ class ChebConv(nn.Module):
graph.in_degrees().float().clamp(min=1), -0.5).unsqueeze(-1).to(feat.device)
if lambda_max is None:
lambda_max = laplacian_lambda_max(graph)
if isinstance(lambda_max, list):
lambda_max = th.Tensor(lambda_max).to(feat.device)
if lambda_max.dim() < 1:
lambda_max = lambda_max.unsqueeze(-1) # (B,) to (B, 1)
# broadcast from (B, 1) to (N, 1)
lambda_max = broadcast_nodes(graph, lambda_max)
# T0(X)
Tx_0 = feat
rst = self.fc[0](Tx_0)
# T1(X)
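A hedged usage sketch for the updated ChebConv forward above: lambda_max can be omitted, in which case the layer computes it via dgl.laplacian_lambda_max, or supplied as one value per graph in the batch. The constructor arguments below are assumed, not taken from this diff:

import torch
import dgl
import networkx as nx
from dgl.nn.pytorch.conv import ChebConv

g = dgl.DGLGraph(nx.path_graph(4))
feat = torch.randn(4, 8)
conv = ChebConv(8, 4, 2)  # in_feats, out_feats, Chebyshev order (assumed signature)

# Let the layer compute the largest Laplacian eigenvalue itself ...
h = conv(g, feat)
# ... or pass one value per graph (a single graph counts as a batch of one).
h = conv(g, feat, lambda_max=[2.0])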
......@@ -1125,16 +1131,16 @@ class SGConv(nn.Module):
self._k = k
self.norm = norm
def forward(self, feat, graph):
def forward(self, graph, feat):
r"""Compute Simplifying Graph Convolution layer.
Parameters
----------
graph : DGLGraph
The graph.
feat : torch.Tensor
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
is size of input feature, :math:`N` is the number of nodes.
graph : DGLGraph
The graph.
Returns
-------
......@@ -1241,11 +1247,13 @@ class NNConv(nn.Module):
if isinstance(self.res_fc, nn.Linear):
nn.init.xavier_normal_(self.res_fc.weight, gain=gain)
def forward(self, feat, efeat, graph):
def forward(self, graph, feat, efeat):
r"""Compute MPNN Graph Convolution layer.
Parameters
----------
graph : DGLGraph
The graph.
feat : torch.Tensor
The input feature of shape :math:`(N, D_{in})` where :math:`N`
is the number of nodes of the graph and :math:`D_{in}` is the
......@@ -1253,8 +1261,6 @@ class NNConv(nn.Module):
efeat : torch.Tensor
The edge feature of shape :math:`(E, *)`, which should fit the input
shape requirement of ``edge_nn``.
graph : DGLGraph
The graph.
Returns
-------
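Layers that consume extra per-edge inputs (GatedGraphConv, GMMConv, NNConv) keep those inputs after the node features, so the order becomes graph, features, extras. A sketch with assumed constructor arguments; only the forward-call order is taken from this diff:

import torch
import torch.nn as nn
import dgl
import networkx as nx
from dgl.nn.pytorch.conv import GatedGraphConv, NNConv

g = dgl.DGLGraph(nx.path_graph(4))
feat = torch.randn(4, 10)
etypes = torch.zeros(g.number_of_edges(), dtype=torch.long)  # a single edge type
efeat = torch.randn(g.number_of_edges(), 4)

ggconv = GatedGraphConv(10, 10, 2, 1)     # in_feats, out_feats, n_steps, n_etypes (assumed)
edge_nn = nn.Linear(4, 10 * 10)           # maps each edge feature to a 10x10 weight
nnconv = NNConv(10, 10, edge_nn, 'mean')  # in_feats, out_feats, edge_nn, aggregator (assumed)

h = ggconv(g, feat, etypes)  # graph, node features, edge types
h = nnconv(g, feat, efeat)   # graph, node features, edge features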
......@@ -1309,16 +1315,16 @@ class APPNPConv(nn.Module):
self._alpha = alpha
self.edge_drop = nn.Dropout(edge_drop) if edge_drop > 0 else Identity()
def forward(self, feat, graph):
def forward(self, graph, feat):
r"""Compute APPNP layer.
Parameters
----------
graph : DGLGraph
The graph.
feat : torch.Tensor
The input feature of shape :math:`(N, *)` where :math:`N` is the
number of nodes, and :math:`*` could be of any shape.
graph : DGLGraph
The graph.
Returns
-------
......@@ -1374,16 +1380,16 @@ class AGNNConv(nn.Module):
else:
self.register_buffer('beta', th.Tensor([init_beta]))
def forward(self, feat, graph):
def forward(self, graph, feat):
r"""Compute AGNN layer.
Parameters
----------
graph : DGLGraph
The graph.
feat : torch.Tensor
The input feature of shape :math:`(N, *)` where :math:`N` is the
number of nodes, and :math:`*` could be of any shape.
graph : DGLGraph
The graph.
Returns
-------
......@@ -1452,18 +1458,18 @@ class DenseGraphConv(nn.Module):
if self.bias is not None:
init.zeros_(self.bias)
def forward(self, feat, adj):
def forward(self, adj, feat):
r"""Compute (Dense) Graph Convolution layer.
Parameters
----------
feat : torch.Tensor
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
is size of input feature, :math:`N` is the number of nodes.
adj : torch.Tensor
The adjacency matrix of the graph to apply Graph Convolution on,
should be of shape :math:`(N, N)`, where a row represents the destination
and a column represents the source.
feat : torch.Tensor
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
is size of input feature, :math:`N` is the number of nodes.
Returns
-------
......@@ -1549,18 +1555,18 @@ class DenseSAGEConv(nn.Module):
gain = nn.init.calculate_gain('relu')
nn.init.xavier_uniform_(self.fc.weight, gain=gain)
def forward(self, feat, adj):
def forward(self, adj, feat):
r"""Compute (Dense) Graph SAGE layer.
Parameters
----------
feat : torch.Tensor
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
is size of input feature, :math:`N` is the number of nodes.
adj : torch.Tensor
The adjacency matrix of the graph to apply Graph Convolution on,
should be of shape :math:`(N, N)`, where a row represents the destination
and a column represents the source.
feat : torch.Tensor
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
is size of input feature, :math:`N` is the number of nodes.
Returns
-------
......@@ -1629,18 +1635,21 @@ class DenseChebConv(nn.Module):
for i in range(self._k):
init.xavier_normal_(self.W[i], init.calculate_gain('relu'))
def forward(self, feat, adj):
def forward(self, adj, feat, lambda_max=None):
r"""Compute (Dense) Chebyshev Spectral Graph Convolution layer.
Parameters
----------
feat : torch.Tensor
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
is size of input feature, :math:`N` is the number of nodes.
adj : torch.Tensor
The adjacency matrix of the graph to apply Graph Convolution on,
should be of shape :math:`(N, N)`, where a row represents the destination
and a column represents the source.
feat : torch.Tensor
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
is size of input feature, :math:`N` is the number of nodes.
lambda_max : float or None, optional
A float value indicating the largest eigenvalue of the given graph.
Default: None.
Returns
-------
......@@ -1656,10 +1665,11 @@ class DenseChebConv(nn.Module):
I = th.eye(num_nodes).to(A)
L = I - D_invsqrt @ A @ D_invsqrt
if lambda_max is None:
lambda_ = th.eig(L)[0][:, 0]
lambda_max = lambda_.max()
L_hat = 2 * L / lambda_max - I
L_hat = 2 * L / lambda_max - I
Z = [th.eye(num_nodes).to(A)]
for i in range(1, self._k):
if i == 1:
......
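The dense variants apply the same reordering with an adjacency matrix in place of the graph, and DenseChebConv now accepts an optional lambda_max. A sketch with assumed constructors; the (adj, feat) order is what this diff introduces:

import torch
import dgl
import networkx as nx
from dgl.nn.pytorch.conv import DenseGraphConv, DenseChebConv

g = dgl.DGLGraph(nx.path_graph(4))
adj = g.adjacency_matrix().to_dense()
feat = torch.randn(4, 8)

dense_conv = DenseGraphConv(8, 4)    # assumed constructor
dense_cheb = DenseChebConv(8, 4, 2)  # assumed constructor

h1 = dense_conv(adj, feat)       # adjacency first, features second
h2 = dense_cheb(adj, feat, 2.0)  # optional lambda_max as a plain float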
......@@ -23,17 +23,17 @@ class SumPooling(nn.Module):
def __init__(self):
super(SumPooling, self).__init__()
def forward(self, feat, graph):
def forward(self, graph, feat):
r"""Compute sum pooling.
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
The graph.
feat : torch.Tensor
The input feature with shape :math:`(N, *)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph or BatchedDGLGraph
The graph.
Returns
-------
......@@ -57,16 +57,16 @@ class AvgPooling(nn.Module):
def __init__(self):
super(AvgPooling, self).__init__()
def forward(self, feat, graph):
def forward(self, graph, feat):
r"""Compute average pooling.
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
The graph.
feat : torch.Tensor
The input feature with shape :math:`(N, *)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph or BatchedDGLGraph
The graph.
Returns
-------
......@@ -90,16 +90,16 @@ class MaxPooling(nn.Module):
def __init__(self):
super(MaxPooling, self).__init__()
def forward(self, feat, graph):
def forward(self, graph, feat):
r"""Compute max pooling.
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
The graph.
feat : torch.Tensor
The input feature with shape :math:`(N, *)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph or BatchedDGLGraph
The graph.
Returns
-------
......@@ -127,16 +127,16 @@ class SortPooling(nn.Module):
super(SortPooling, self).__init__()
self.k = k
def forward(self, feat, graph):
def forward(self, graph, feat):
r"""Compute sort pooling.
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
The graph.
feat : torch.Tensor
The input feature with shape :math:`(N, D)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph or BatchedDGLGraph
The graph.
Returns
-------
......@@ -179,16 +179,16 @@ class GlobalAttentionPooling(nn.Module):
self.gate_nn = gate_nn
self.feat_nn = feat_nn
def forward(self, feat, graph):
def forward(self, graph, feat):
r"""Compute global attention pooling.
Parameters
----------
graph : DGLGraph
The graph.
feat : torch.Tensor
The input feature with shape :math:`(N, D)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph
The graph.
Returns
-------
......@@ -252,16 +252,16 @@ class Set2Set(nn.Module):
"""Reinitialize learnable parameters."""
self.lstm.reset_parameters()
def forward(self, feat, graph):
def forward(self, graph, feat):
r"""Compute set2set pooling.
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
The graph.
feat : torch.Tensor
The input feature with shape :math:`(N, D)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph or BatchedDGLGraph
The graph.
Returns
-------
......@@ -568,17 +568,17 @@ class SetTransformerEncoder(nn.Module):
self.layers = nn.ModuleList(layers)
def forward(self, feat, graph):
def forward(self, graph, feat):
"""
Compute the Encoder part of Set Transformer.
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
The graph.
feat : torch.Tensor
The input feature with shape :math:`(N, D)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph or BatchedDGLGraph
The graph.
Returns
-------
......@@ -634,17 +634,17 @@ class SetTransformerDecoder(nn.Module):
self.layers = nn.ModuleList(layers)
def forward(self, feat, graph):
def forward(self, graph, feat):
"""
Compute the decoder part of Set Transformer.
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
The graph.
feat : torch.Tensor
The input feature with shape :math:`(N, D)` where
:math:`N` is the number of nodes in the graph.
graph : DGLGraph or BatchedDGLGraph
The graph.
Returns
-------
......
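The readout (pooling) modules adopt the same order, including on batched graphs. A small sketch, assuming the classes live in dgl.nn.pytorch.glob and the Set2Set constructor takes (input_dim, n_iters, n_layers):

import torch
import dgl
import networkx as nx
from dgl.nn.pytorch.glob import SumPooling, Set2Set

g = dgl.DGLGraph(nx.path_graph(5))
bg = dgl.batch([g, g])
feat = torch.randn(bg.number_of_nodes(), 6)

sumpool = SumPooling()
s2s = Set2Set(6, 2, 1)

hg = sumpool(bg, feat)  # one row per graph in the batch: shape (2, 6)
hs = s2s(bg, feat)      # Set2Set doubles the feature width: shape (2, 12)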
......@@ -24,13 +24,13 @@ def test_graph_conv():
conv.initialize(ctx=ctx)
# test#1: basic
h0 = F.ones((3, 5))
h1 = conv(h0, g)
h1 = conv(g, h0)
assert len(g.ndata) == 0
assert len(g.edata) == 0
check_close(h1, _AXWb(adj, h0, conv.weight, conv.bias))
# test#2: more-dim
h0 = F.ones((3, 5, 5))
h1 = conv(h0, g)
h1 = conv(g, h0)
assert len(g.ndata) == 0
assert len(g.edata) == 0
check_close(h1, _AXWb(adj, h0, conv.weight, conv.bias))
......@@ -40,12 +40,12 @@ def test_graph_conv():
# test#3: basic
h0 = F.ones((3, 5))
h1 = conv(h0, g)
h1 = conv(g, h0)
assert len(g.ndata) == 0
assert len(g.edata) == 0
# test#4: basic
h0 = F.ones((3, 5, 5))
h1 = conv(h0, g)
h1 = conv(g, h0)
assert len(g.ndata) == 0
assert len(g.edata) == 0
......@@ -55,18 +55,18 @@ def test_graph_conv():
with autograd.train_mode():
# test#3: basic
h0 = F.ones((3, 5))
h1 = conv(h0, g)
h1 = conv(g, h0)
assert len(g.ndata) == 0
assert len(g.edata) == 0
# test#4: basic
h0 = F.ones((3, 5, 5))
h1 = conv(h0, g)
h1 = conv(g, h0)
assert len(g.ndata) == 0
assert len(g.edata) == 0
# test not override features
g.ndata["h"] = 2 * F.ones((3, 1))
h1 = conv(h0, g)
h1 = conv(g, h0)
assert len(g.ndata) == 1
assert len(g.edata) == 0
assert "h" in g.ndata
......@@ -82,13 +82,13 @@ def test_set2set():
# test#1: basic
h0 = F.randn((g.number_of_nodes(), 5))
h1 = s2s(h0, g)
h1 = s2s(g, h0)
assert h1.shape[0] == 10 and h1.ndim == 1
# test#2: batched graph
bg = dgl.batch([g, g, g])
h0 = F.randn((bg.number_of_nodes(), 5))
h1 = s2s(h0, bg)
h1 = s2s(bg, h0)
assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.ndim == 2
def test_glob_att_pool():
......@@ -100,13 +100,13 @@ def test_glob_att_pool():
print(gap)
# test#1: basic
h0 = F.randn((g.number_of_nodes(), 5))
h1 = gap(h0, g)
h1 = gap(g, h0)
assert h1.shape[0] == 10 and h1.ndim == 1
# test#2: batched graph
bg = dgl.batch([g, g, g, g])
h0 = F.randn((bg.number_of_nodes(), 5))
h1 = gap(h0, bg)
h1 = gap(bg, h0)
assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.ndim == 2
def test_simple_pool():
......@@ -120,20 +120,20 @@ def test_simple_pool():
# test#1: basic
h0 = F.randn((g.number_of_nodes(), 5))
h1 = sum_pool(h0, g)
h1 = sum_pool(g, h0)
check_close(h1, F.sum(h0, 0))
h1 = avg_pool(h0, g)
h1 = avg_pool(g, h0)
check_close(h1, F.mean(h0, 0))
h1 = max_pool(h0, g)
h1 = max_pool(g, h0)
check_close(h1, F.max(h0, 0))
h1 = sort_pool(h0, g)
h1 = sort_pool(g, h0)
assert h1.shape[0] == 10 * 5 and h1.ndim == 1
# test#2: batched graph
g_ = dgl.DGLGraph(nx.path_graph(5))
bg = dgl.batch([g, g_, g, g_, g])
h0 = F.randn((bg.number_of_nodes(), 5))
h1 = sum_pool(h0, bg)
h1 = sum_pool(bg, h0)
truth = mx.nd.stack(F.sum(h0[:15], 0),
F.sum(h0[15:20], 0),
F.sum(h0[20:35], 0),
......@@ -141,7 +141,7 @@ def test_simple_pool():
F.sum(h0[40:55], 0), axis=0)
check_close(h1, truth)
h1 = avg_pool(h0, bg)
h1 = avg_pool(bg, h0)
truth = mx.nd.stack(F.mean(h0[:15], 0),
F.mean(h0[15:20], 0),
F.mean(h0[20:35], 0),
......@@ -149,7 +149,7 @@ def test_simple_pool():
F.mean(h0[40:55], 0), axis=0)
check_close(h1, truth)
h1 = max_pool(h0, bg)
h1 = max_pool(bg, h0)
truth = mx.nd.stack(F.max(h0[:15], 0),
F.max(h0[15:20], 0),
F.max(h0[20:35], 0),
......@@ -157,7 +157,7 @@ def test_simple_pool():
F.max(h0[40:55], 0), axis=0)
check_close(h1, truth)
h1 = sort_pool(h0, bg)
h1 = sort_pool(bg, h0)
assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.ndim == 2
def uniform_attention(g, shape):
......
......@@ -24,13 +24,13 @@ def test_graph_conv():
print(conv)
# test#1: basic
h0 = F.ones((3, 5))
h1 = conv(h0, g)
h1 = conv(g, h0)
assert len(g.ndata) == 0
assert len(g.edata) == 0
assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))
# test#2: more-dim
h0 = F.ones((3, 5, 5))
h1 = conv(h0, g)
h1 = conv(g, h0)
assert len(g.ndata) == 0
assert len(g.edata) == 0
assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))
......@@ -40,12 +40,12 @@ def test_graph_conv():
conv = conv.to(ctx)
# test#3: basic
h0 = F.ones((3, 5))
h1 = conv(h0, g)
h1 = conv(g, h0)
assert len(g.ndata) == 0
assert len(g.edata) == 0
# test#4: basic
h0 = F.ones((3, 5, 5))
h1 = conv(h0, g)
h1 = conv(g, h0)
assert len(g.ndata) == 0
assert len(g.edata) == 0
......@@ -54,12 +54,12 @@ def test_graph_conv():
conv = conv.to(ctx)
# test#3: basic
h0 = F.ones((3, 5))
h1 = conv(h0, g)
h1 = conv(g, h0)
assert len(g.ndata) == 0
assert len(g.edata) == 0
# test#4: basic
h0 = F.ones((3, 5, 5))
h1 = conv(h0, g)
h1 = conv(g, h0)
assert len(g.ndata) == 0
assert len(g.edata) == 0
......@@ -94,7 +94,7 @@ def test_tagconv():
# test#1: basic
h0 = F.ones((3, 5))
h1 = conv(h0, g)
h1 = conv(g, h0)
assert len(g.ndata) == 0
assert len(g.edata) == 0
shp = norm.shape + (1,) * (h0.dim() - 1)
......@@ -107,7 +107,7 @@ def test_tagconv():
conv = conv.to(ctx)
# test#2: basic
h0 = F.ones((3, 5))
h1 = conv(h0, g)
h1 = conv(g, h0)
assert h1.shape[-1] == 2
# test reset_parameters
......@@ -127,7 +127,7 @@ def test_set2set():
# test#1: basic
h0 = F.randn((g.number_of_nodes(), 5))
h1 = s2s(h0, g)
h1 = s2s(g, h0)
assert h1.shape[0] == 10 and h1.dim() == 1
# test#2: batched graph
......@@ -135,7 +135,7 @@ def test_set2set():
g2 = dgl.DGLGraph(nx.path_graph(5))
bg = dgl.batch([g, g1, g2])
h0 = F.randn((bg.number_of_nodes(), 5))
h1 = s2s(h0, bg)
h1 = s2s(bg, h0)
assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.dim() == 2
def test_glob_att_pool():
......@@ -149,13 +149,13 @@ def test_glob_att_pool():
# test#1: basic
h0 = F.randn((g.number_of_nodes(), 5))
h1 = gap(h0, g)
h1 = gap(g, h0)
assert h1.shape[0] == 10 and h1.dim() == 1
# test#2: batched graph
bg = dgl.batch([g, g, g, g])
h0 = F.randn((bg.number_of_nodes(), 5))
h1 = gap(h0, bg)
h1 = gap(bg, h0)
assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.dim() == 2
def test_simple_pool():
......@@ -176,13 +176,13 @@ def test_simple_pool():
max_pool = max_pool.to(ctx)
sort_pool = sort_pool.to(ctx)
h0 = h0.to(ctx)
h1 = sum_pool(h0, g)
h1 = sum_pool(g, h0)
assert F.allclose(h1, F.sum(h0, 0))
h1 = avg_pool(h0, g)
h1 = avg_pool(g, h0)
assert F.allclose(h1, F.mean(h0, 0))
h1 = max_pool(h0, g)
h1 = max_pool(g, h0)
assert F.allclose(h1, F.max(h0, 0))
h1 = sort_pool(h0, g)
h1 = sort_pool(g, h0)
assert h1.shape[0] == 10 * 5 and h1.dim() == 1
# test#2: batched graph
......@@ -192,7 +192,7 @@ def test_simple_pool():
if F.gpu_ctx():
h0 = h0.to(ctx)
h1 = sum_pool(h0, bg)
h1 = sum_pool(bg, h0)
truth = th.stack([F.sum(h0[:15], 0),
F.sum(h0[15:20], 0),
F.sum(h0[20:35], 0),
......@@ -200,7 +200,7 @@ def test_simple_pool():
F.sum(h0[40:55], 0)], 0)
assert F.allclose(h1, truth)
h1 = avg_pool(h0, bg)
h1 = avg_pool(bg, h0)
truth = th.stack([F.mean(h0[:15], 0),
F.mean(h0[15:20], 0),
F.mean(h0[20:35], 0),
......@@ -208,7 +208,7 @@ def test_simple_pool():
F.mean(h0[40:55], 0)], 0)
assert F.allclose(h1, truth)
h1 = max_pool(h0, bg)
h1 = max_pool(bg, h0)
truth = th.stack([F.max(h0[:15], 0),
F.max(h0[15:20], 0),
F.max(h0[20:35], 0),
......@@ -216,7 +216,7 @@ def test_simple_pool():
F.max(h0[40:55], 0)], 0)
assert F.allclose(h1, truth)
h1 = sort_pool(h0, bg)
h1 = sort_pool(bg, h0)
assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.dim() == 2
def test_set_trans():
......@@ -234,11 +234,11 @@ def test_set_trans():
# test#1: basic
h0 = F.randn((g.number_of_nodes(), 50))
h1 = st_enc_0(h0, g)
h1 = st_enc_0(g, h0)
assert h1.shape == h0.shape
h1 = st_enc_1(h0, g)
h1 = st_enc_1(g, h0)
assert h1.shape == h0.shape
h2 = st_dec(h1, g)
h2 = st_dec(g, h1)
assert h2.shape[0] == 200 and h2.dim() == 1
# test#2: batched graph
......@@ -246,12 +246,12 @@ def test_set_trans():
g2 = dgl.DGLGraph(nx.path_graph(10))
bg = dgl.batch([g, g1, g2])
h0 = F.randn((bg.number_of_nodes(), 50))
h1 = st_enc_0(h0, bg)
h1 = st_enc_0(bg, h0)
assert h1.shape == h0.shape
h1 = st_enc_1(h0, bg)
h1 = st_enc_1(bg, h0)
assert h1.shape == h0.shape
h2 = st_dec(h1, bg)
h2 = st_dec(bg, h1)
assert h2.shape[0] == 3 and h2.shape[1] == 200 and h2.dim() == 2
def uniform_attention(g, shape):
......@@ -375,7 +375,7 @@ def test_gat_conv():
gat = gat.to(ctx)
feat = feat.to(ctx)
h = gat(feat, g)
h = gat(g, feat)
assert h.shape[-1] == 2 and h.shape[-2] == 4
def test_sage_conv():
......@@ -389,7 +389,7 @@ def test_sage_conv():
sage = sage.to(ctx)
feat = feat.to(ctx)
h = sage(feat, g)
h = sage(g, feat)
assert h.shape[-1] == 10
def test_sgc_conv():
......@@ -403,7 +403,7 @@ def test_sgc_conv():
sgc = sgc.to(ctx)
feat = feat.to(ctx)
h = sgc(feat, g)
h = sgc(g, feat)
assert h.shape[-1] == 10
# cached
......@@ -412,8 +412,8 @@ def test_sgc_conv():
if F.gpu_ctx():
sgc = sgc.to(ctx)
h_0 = sgc(feat, g)
h_1 = sgc(feat + 1, g)
h_0 = sgc(g, feat)
h_1 = sgc(g, feat + 1)
assert F.allclose(h_0, h_1)
assert h_0.shape[-1] == 10
......@@ -427,7 +427,7 @@ def test_appnp_conv():
appnp = appnp.to(ctx)
feat = feat.to(ctx)
h = appnp(feat, g)
h = appnp(g, feat)
assert h.shape[-1] == 5
def test_gin_conv():
......@@ -444,7 +444,7 @@ def test_gin_conv():
gin = gin.to(ctx)
feat = feat.to(ctx)
h = gin(feat, g)
h = gin(g, feat)
assert h.shape[-1] == 12
def test_agnn_conv():
......@@ -457,7 +457,7 @@ def test_agnn_conv():
agnn = agnn.to(ctx)
feat = feat.to(ctx)
h = agnn(feat, g)
h = agnn(g, feat)
assert h.shape[-1] == 5
def test_gated_graph_conv():
......@@ -472,7 +472,7 @@ def test_gated_graph_conv():
feat = feat.to(ctx)
etypes = etypes.to(ctx)
h = ggconv(feat, etypes, g)
h = ggconv(g, feat, etypes)
# currently we only do shape check
assert h.shape[-1] == 10
......@@ -489,7 +489,7 @@ def test_nn_conv():
feat = feat.to(ctx)
efeat = efeat.to(ctx)
h = nnconv(feat, efeat, g)
h = nnconv(g, feat, efeat)
# currently we only do shape check
assert h.shape[-1] == 10
......@@ -505,7 +505,7 @@ def test_gmm_conv():
feat = feat.to(ctx)
pseudo = pseudo.to(ctx)
h = gmmconv(feat, pseudo, g)
h = gmmconv(g, feat, pseudo)
# currently we only do shape check
assert h.shape[-1] == 10
......@@ -523,8 +523,8 @@ def test_dense_graph_conv():
dense_conv = dense_conv.to(ctx)
feat = feat.to(ctx)
out_conv = conv(feat, g)
out_dense_conv = dense_conv(feat, adj)
out_conv = conv(g, feat)
out_dense_conv = dense_conv(adj, feat)
assert F.allclose(out_conv, out_dense_conv)
def test_dense_sage_conv():
......@@ -541,8 +541,8 @@ def test_dense_sage_conv():
dense_sage = dense_sage.to(ctx)
feat = feat.to(ctx)
out_sage = sage(feat, g)
out_dense_sage = dense_sage(feat, adj)
out_sage = sage(g, feat)
out_dense_sage = dense_sage(adj, feat)
assert F.allclose(out_sage, out_dense_sage)
def test_dense_cheb_conv():
......@@ -562,8 +562,8 @@ def test_dense_cheb_conv():
dense_cheb = dense_cheb.to(ctx)
feat = feat.to(ctx)
out_cheb = cheb(feat, g)
out_dense_cheb = dense_cheb(feat, adj)
out_cheb = cheb(g, feat, [2.0])
out_dense_cheb = dense_cheb(adj, feat, 2.0)
assert F.allclose(out_cheb, out_dense_cheb)
if __name__ == '__main__':
......