"examples/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "701b4fccc2eed979ae3db801fabb6bf7bc03940c"
Unverified commit 9314aabd authored by Zihao Ye, committed by GitHub

[Refactor] Interface of nn modules (#798)

* refactor

* upd mpnn
parent 650f6ee1
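The gist of the change: every nn module now takes the graph as its first argument, followed by the node features (and any extra inputs such as edge types or pseudo coordinates). A minimal sketch of the new calling convention, using a toy graph and illustrative sizes rather than anything from this commit:

import networkx as nx
import torch
import dgl
from dgl.nn.pytorch.conv import GraphConv

g = dgl.DGLGraph(nx.path_graph(3))   # 3-node toy graph
feat = torch.ones(3, 5)              # (num_nodes, in_feats)
conv = GraphConv(5, 2)

# before this commit: h = conv(feat, g)
# after this commit: the graph comes first
h = conv(g, feat)                    # shape (3, 2)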
@@ -36,5 +36,5 @@ class GCN(gluon.Block):
         for i, layer in enumerate(self.layers):
             if i != 0:
                 h = self.dropout(h)
-            h = layer(h, self.g)
+            h = layer(self.g, h)
         return h
@@ -53,5 +53,5 @@ class APPNP(nn.Module):
             h = self.activation(layer(h))
         h = self.layers[-1](self.feat_drop(h))
         # propagation step
-        h = self.propagate(h, self.g)
+        h = self.propagate(self.g, h)
         return h
@@ -49,7 +49,7 @@ class GAT(nn.Module):
     def forward(self, inputs):
         h = inputs
         for l in range(self.num_layers):
-            h = self.gat_layers[l](h, self.g).flatten(1)
+            h = self.gat_layers[l](self.g, h).flatten(1)
         # output projection
-        logits = self.gat_layers[-1](h, self.g).mean(1)
+        logits = self.gat_layers[-1](self.g, h).mean(1)
         return logits
@@ -35,5 +35,5 @@ class GCN(nn.Module):
         for i, layer in enumerate(self.layers):
             if i != 0:
                 h = self.dropout(h)
-            h = layer(h, self.g)
+            h = layer(self.g, h)
         return h
@@ -155,7 +155,7 @@ class GIN(nn.Module):
         hidden_rep = [h]
         for layer in range(self.num_layers - 1):
-            h = self.ginlayers[layer](h, g)
+            h = self.ginlayers[layer](g, h)
             hidden_rep.append(h)
         score_over_layer = 0
...
@@ -41,7 +41,7 @@ class GraphSAGE(nn.Module):
     def forward(self, features):
         h = features
         for layer in self.layers:
-            h = layer(h, self.g)
+            h = layer(self.g, h)
         return h
...
@@ -50,7 +50,7 @@ GIN_CONFIG = {
 }
 CHEBNET_CONFIG = {
-    'extra_args': [16, 1, 3, True],
+    'extra_args': [32, 1, 2, True],
     'lr': 1e-2,
     'weight_decay': 5e-4,
 }
@@ -30,7 +30,7 @@ class GCN(nn.Module):
         for i, layer in enumerate(self.layers):
             if i != 0:
                 h = self.dropout(h)
-            h = layer(h, self.g)
+            h = layer(self.g, h)
         return h
@@ -70,9 +70,9 @@ class GAT(nn.Module):
     def forward(self, inputs):
         h = inputs
         for l in range(self.num_layers):
-            h = self.gat_layers[l](h, self.g).flatten(1)
+            h = self.gat_layers[l](self.g, h).flatten(1)
         # output projection
-        logits = self.gat_layers[-1](h, self.g).mean(1)
+        logits = self.gat_layers[-1](self.g, h).mean(1)
         return logits
@@ -101,7 +101,7 @@ class GraphSAGE(nn.Module):
     def forward(self, features):
         h = features
         for layer in self.layers:
-            h = layer(h, self.g)
+            h = layer(self.g, h)
         return h
@@ -148,7 +148,7 @@ class APPNP(nn.Module):
             h = self.activation(layer(h))
         h = self.layers[-1](self.feat_drop(h))
         # propagation step
-        h = self.propagate(h, self.g)
+        h = self.propagate(self.g, h)
         return h
@@ -178,7 +178,7 @@ class TAGCN(nn.Module):
         for i, layer in enumerate(self.layers):
             if i != 0:
                 h = self.dropout(h)
-            h = layer(h, self.g)
+            h = layer(self.g, h)
         return h
@@ -210,7 +210,7 @@ class AGNN(nn.Module):
     def forward(self, features):
         h = self.proj(features)
         for layer in self.layers:
-            h = layer(h, self.g)
+            h = layer(self.g, h)
         return self.cls(h)
@@ -231,7 +231,7 @@ class SGC(nn.Module):
                            bias=bias)
     def forward(self, features):
-        return self.net(features, self.g)
+        return self.net(self.g, features)
 class GIN(nn.Module):
@@ -286,7 +286,7 @@ class GIN(nn.Module):
     def forward(self, features):
         h = features
         for layer in self.layers:
-            h = layer(h, self.g)
+            h = layer(self.g, h)
         return h
 class ChebNet(nn.Module):
@@ -316,5 +316,5 @@ class ChebNet(nn.Module):
     def forward(self, features):
         h = features
         for layer in self.layers:
-            h = layer(h, self.g)
+            h = layer(self.g, h, [2])
         return h
\ No newline at end of file
@@ -19,7 +19,7 @@ from dgl.nn.pytorch.conv import SGConv
 def evaluate(model, g, features, labels, mask):
     model.eval()
     with torch.no_grad():
-        logits = model(features, g)[mask] # only compute the evaluation set
+        logits = model(g, features)[mask] # only compute the evaluation set
         labels = labels[mask]
         _, indices = torch.max(logits, dim=1)
         correct = torch.sum(indices == labels)
@@ -86,7 +86,7 @@ def main(args):
         if epoch >= 3:
             t0 = time.time()
         # forward
-        logits = model(features, g) # only compute the train set
+        logits = model(g, features) # only compute the train set
         loss = loss_fcn(logits[train_mask], labels[train_mask])
         optimizer.zero_grad()
...
@@ -21,7 +21,7 @@ def normalize(h):
 def evaluate(model, features, graph, labels, mask):
     model.eval()
     with torch.no_grad():
-        logits = model(features, graph)[mask] # only compute the evaluation set
+        logits = model(graph, features)[mask] # only compute the evaluation set
         labels = labels[mask]
         _, indices = torch.max(logits, dim=1)
         correct = torch.sum(indices == labels)
@@ -82,7 +82,7 @@ def main(args):
     # define loss closure
     def closure():
         optimizer.zero_grad()
-        output = model(features, g)[train_mask]
+        output = model(g, features)[train_mask]
         loss_train = F.cross_entropy(output, labels[train_mask])
         loss_train.backward()
         return loss_train
...
@@ -35,5 +35,5 @@ class TAGCN(nn.Module):
         for i, layer in enumerate(self.layers):
             if i != 0:
                 h = self.dropout(h)
-            h = layer(h, self.g)
+            h = layer(self.g, h)
         return h
@@ -145,7 +145,7 @@ class MPNNModel(nn.Module):
             out, h = self.gru(m.unsqueeze(0), h)
             out = out.squeeze(0)
-        out = self.set2set(out, g)
+        out = self.set2set(g, out)
         out = F.relu(self.lin1(out))
         out = self.lin2(out)
         return out
@@ -83,7 +83,7 @@ class GraphConv(gluon.Block):
         self._activation = activation
-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute graph convolution.
         Notes
@@ -95,10 +95,10 @@ class GraphConv(gluon.Block):
         Parameters
         ----------
-        feat : mxnet.NDArray
-            The input feature
         graph : DGLGraph
             The graph.
+        feat : mxnet.NDArray
+            The input feature
         Returns
         -------
...
@@ -19,16 +19,16 @@ class SumPooling(nn.Block):
     def __init__(self):
         super(SumPooling, self).__init__()
-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute sum pooling.
         Parameters
         ----------
+        graph : DGLGraph or BatchedDGLGraph
+            The graph.
         feat : mxnet.NDArray
             The input feature with shape :math:`(N, *)` where
             :math:`N` is the number of nodes in the graph.
-        graph : DGLGraph or BatchedDGLGraph
-            The graph.
         Returns
         -------
@@ -56,16 +56,16 @@ class AvgPooling(nn.Block):
     def __init__(self):
         super(AvgPooling, self).__init__()
-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute average pooling.
         Parameters
         ----------
+        graph : DGLGraph or BatchedDGLGraph
+            The graph.
         feat : mxnet.NDArray
             The input feature with shape :math:`(N, *)` where
             :math:`N` is the number of nodes in the graph.
-        graph : DGLGraph or BatchedDGLGraph
-            The graph.
         Returns
         -------
@@ -93,16 +93,16 @@ class MaxPooling(nn.Block):
     def __init__(self):
         super(MaxPooling, self).__init__()
-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute max pooling.
         Parameters
         ----------
+        graph : DGLGraph or BatchedDGLGraph
+            The graph.
         feat : mxnet.NDArray
             The input feature with shape :math:`(N, *)` where
             :math:`N` is the number of nodes in the graph.
-        graph : DGLGraph or BatchedDGLGraph
-            The graph.
         Returns
         -------
@@ -134,16 +134,16 @@ class SortPooling(nn.Block):
         super(SortPooling, self).__init__()
         self.k = k
-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute sort pooling.
         Parameters
         ----------
+        graph : DGLGraph or BatchedDGLGraph
+            The graph.
         feat : mxnet.NDArray
             The input feature with shape :math:`(N, D)` where
             :math:`N` is the number of nodes in the graph.
-        graph : DGLGraph or BatchedDGLGraph
-            The graph.
         Returns
         -------
@@ -190,16 +190,16 @@ class GlobalAttentionPooling(nn.Block):
         self.gate_nn = gate_nn
         self.feat_nn = feat_nn
-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute global attention pooling.
         Parameters
         ----------
+        graph : DGLGraph or BatchedDGLGraph
+            The graph.
         feat : mxnet.NDArray
             The input feature with shape :math:`(N, D)` where
             :math:`N` is the number of nodes in the graph.
-        graph : DGLGraph or BatchedDGLGraph
-            The graph.
         Returns
         -------
@@ -258,16 +258,16 @@ class Set2Set(nn.Block):
         self.lstm = gluon.rnn.LSTM(
             self.input_dim, num_layers=n_layers, input_size=self.output_dim)
-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute set2set pooling.
         Parameters
         ----------
+        graph : DGLGraph or BatchedDGLGraph
+            The graph.
         feat : mxnet.NDArray
             The input feature with shape :math:`(N, D)` where
             :math:`N` is the number of nodes in the graph.
-        graph : DGLGraph or BatchedDGLGraph
-            The graph.
         Returns
         -------
...
@@ -107,7 +107,7 @@ class GraphConv(nn.Module):
         if self.bias is not None:
             init.zeros_(self.bias)
-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute graph convolution.
         Notes
@@ -119,10 +119,10 @@ class GraphConv(nn.Module):
         Parameters
         ----------
-        feat : torch.Tensor
-            The input feature
         graph : DGLGraph
             The graph.
+        feat : torch.Tensor
+            The input feature
         Returns
         -------
@@ -246,16 +246,16 @@ class GATConv(nn.Module):
         if isinstance(self.res_fc, nn.Linear):
             nn.init.xavier_normal_(self.res_fc.weight, gain=gain)
-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute graph attention network layer.
         Parameters
         ----------
+        graph : DGLGraph
+            The graph.
         feat : torch.Tensor
             The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
             is size of input feature, :math:`N` is the number of nodes.
-        graph : DGLGraph
-            The graph.
         Returns
         -------
@@ -338,16 +338,16 @@ class TAGConv(nn.Module):
         gain = nn.init.calculate_gain('relu')
         nn.init.xavier_normal_(self.lin.weight, gain=gain)
-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute topology adaptive graph convolution.
         Parameters
         ----------
+        graph : DGLGraph
+            The graph.
         feat : torch.Tensor
             The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
             is size of input feature, :math:`N` is the number of nodes.
-        graph : DGLGraph
-            The graph.
         Returns
         -------
@@ -643,16 +643,16 @@ class SAGEConv(nn.Module):
         _, (rst, _) = self.lstm(m, h)
         return {'neigh': rst.squeeze(0)}
-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute GraphSAGE layer.
         Parameters
         ----------
+        graph : DGLGraph
+            The graph.
         feat : torch.Tensor
             The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
             is size of input feature, :math:`N` is the number of nodes.
-        graph : DGLGraph
-            The graph.
         Returns
         -------
@@ -742,11 +742,13 @@ class GatedGraphConv(nn.Module):
         self.gru.reset_parameters()
         init.xavier_normal_(self.edge_embed.weight, gain=gain)
-    def forward(self, feat, etypes, graph):
+    def forward(self, graph, feat, etypes):
         """Compute Gated Graph Convolution layer.
         Parameters
         ----------
+        graph : DGLGraph
+            The graph.
         feat : torch.Tensor
             The input feature of shape :math:`(N, D_{in})` where :math:`N`
             is the number of nodes of the graph and :math:`D_{in}` is the
@@ -754,8 +756,6 @@ class GatedGraphConv(nn.Module):
         etypes : torch.LongTensor
             The edge type tensor of shape :math:`(E,)` where :math:`E` is
             the number of edges of the graph.
-        graph : DGLGraph
-            The graph.
         Returns
         -------
@@ -856,11 +856,13 @@ class GMMConv(nn.Module):
         if self.bias is not None:
             init.zeros_(self.bias.data)
-    def forward(self, feat, pseudo, graph):
+    def forward(self, graph, feat, pseudo):
         """Compute Gaussian Mixture Model Convolution layer.
         Parameters
         ----------
+        graph : DGLGraph
+            The graph.
         feat : torch.Tensor
             The input feature of shape :math:`(N, D_{in})` where :math:`N`
             is the number of nodes of the graph and :math:`D_{in}` is the
@@ -869,8 +871,6 @@ class GMMConv(nn.Module):
             The pseudo coordinate tensor of shape :math:`(E, D_{u})` where
             :math:`E` is the number of edges of the graph and :math:`D_{u}`
             is the dimensionality of pseudo coordinate.
-        graph : DGLGraph
-            The graph.
         Returns
         -------
@@ -940,18 +940,18 @@ class GINConv(nn.Module):
         else:
             self.register_buffer('eps', th.FloatTensor([init_eps]))
-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute Graph Isomorphism Network layer.
         Parameters
         ----------
+        graph : DGLGraph
+            The graph.
         feat : torch.Tensor
             The input feature of shape :math:`(N, D)` where :math:`D`
             could be any positive integer, :math:`N` is the number
             of nodes. If ``apply_func`` is not None, :math:`D` should
             fit the input dimensionality requirement of ``apply_func``.
-        graph : DGLGraph
-            The graph.
         Returns
         -------
@@ -1025,16 +1025,22 @@ class ChebConv(nn.Module):
             if module.bias is not None:
                 init.zeros_(module.bias)
-    def forward(self, feat, graph, lambda_max=None):
+    def forward(self, graph, feat, lambda_max=None):
         r"""Compute ChebNet layer.
         Parameters
         ----------
+        graph : DGLGraph or BatchedDGLGraph
+            The graph.
         feat : torch.Tensor
             The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
             is size of input feature, :math:`N` is the number of nodes.
-        graph : DGLGraph
-            The graph.
+        lambda_max : list or tensor or None, optional.
+            A list(tensor) with length :math:`B`, stores the largest eigenvalue
+            of the normalized laplacian of each individual graph in ``graph``,
+            where :math:`B` is the batch size of the input graph. Default: None.
+            If None, this method would compute the list by calling
+            ``dgl.laplacian_lambda_max``.
         Returns
         -------
@@ -1047,13 +1053,13 @@ class ChebConv(nn.Module):
             graph.in_degrees().float().clamp(min=1), -0.5).unsqueeze(-1).to(feat.device)
         if lambda_max is None:
             lambda_max = laplacian_lambda_max(graph)
-        lambda_max = th.Tensor(lambda_max).to(feat.device)
+        if isinstance(lambda_max, list):
+            lambda_max = th.Tensor(lambda_max).to(feat.device)
         if lambda_max.dim() < 1:
             lambda_max = lambda_max.unsqueeze(-1) # (B,) to (B, 1)
         # broadcast from (B, 1) to (N, 1)
         lambda_max = broadcast_nodes(graph, lambda_max)
         # T0(X)
         Tx_0 = feat
         rst = self.fc[0](Tx_0)
         # T1(X)
@@ -1125,16 +1131,16 @@ class SGConv(nn.Module):
         self._k = k
         self.norm = norm
-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute Simplifying Graph Convolution layer.
         Parameters
         ----------
+        graph : DGLGraph
+            The graph.
         feat : torch.Tensor
             The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
             is size of input feature, :math:`N` is the number of nodes.
-        graph : DGLGraph
-            The graph.
         Returns
        -------
@@ -1241,11 +1247,13 @@ class NNConv(nn.Module):
         if isinstance(self.res_fc, nn.Linear):
             nn.init.xavier_normal_(self.res_fc.weight, gain=gain)
-    def forward(self, feat, efeat, graph):
+    def forward(self, graph, feat, efeat):
        r"""Compute MPNN Graph Convolution layer.
         Parameters
         ----------
+        graph : DGLGraph
+            The graph.
         feat : torch.Tensor
             The input feature of shape :math:`(N, D_{in})` where :math:`N`
             is the number of nodes of the graph and :math:`D_{in}` is the
@@ -1253,8 +1261,6 @@ class NNConv(nn.Module):
         efeat : torch.Tensor
             The edge feature of shape :math:`(N, *)`, should fit the input
             shape requirement of ``edge_nn``.
-        graph : DGLGraph
-            The graph.
         Returns
         -------
@@ -1309,16 +1315,16 @@ class APPNPConv(nn.Module):
         self._alpha = alpha
         self.edge_drop = nn.Dropout(edge_drop) if edge_drop > 0 else Identity()
-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute APPNP layer.
         Parameters
         ----------
+        graph : DGLGraph
+            The graph.
         feat : torch.Tensor
             The input feature of shape :math:`(N, *)` :math:`N` is the
             number of nodes, and :math:`*` could be of any shape.
-        graph : DGLGraph
-            The graph.
         Returns
         -------
@@ -1374,16 +1380,16 @@ class AGNNConv(nn.Module):
         else:
             self.register_buffer('beta', th.Tensor([init_beta]))
-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute AGNN layer.
         Parameters
         ----------
+        graph : DGLGraph
+            The graph.
         feat : torch.Tensor
             The input feature of shape :math:`(N, *)` :math:`N` is the
             number of nodes, and :math:`*` could be of any shape.
-        graph : DGLGraph
-            The graph.
         Returns
         -------
@@ -1452,18 +1458,18 @@ class DenseGraphConv(nn.Module):
         if self.bias is not None:
             init.zeros_(self.bias)
-    def forward(self, feat, adj):
+    def forward(self, adj, feat):
         r"""Compute (Dense) Graph Convolution layer.
         Parameters
         ----------
-        feat : torch.Tensor
-            The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
-            is size of input feature, :math:`N` is the number of nodes.
         adj : torch.Tensor
             The adjacency matrix of the graph to apply Graph Convolution on,
             should be of shape :math:`(N, N)`, where a row represents the destination
             and a column represents the source.
+        feat : torch.Tensor
+            The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
+            is size of input feature, :math:`N` is the number of nodes.
         Returns
         -------
@@ -1549,18 +1555,18 @@ class DenseSAGEConv(nn.Module):
         gain = nn.init.calculate_gain('relu')
         nn.init.xavier_uniform_(self.fc.weight, gain=gain)
-    def forward(self, feat, adj):
+    def forward(self, adj, feat):
         r"""Compute (Dense) Graph SAGE layer.
         Parameters
         ----------
-        feat : torch.Tensor
-            The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
-            is size of input feature, :math:`N` is the number of nodes.
         adj : torch.Tensor
             The adjacency matrix of the graph to apply Graph Convolution on,
             should be of shape :math:`(N, N)`, where a row represents the destination
             and a column represents the source.
+        feat : torch.Tensor
+            The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
+            is size of input feature, :math:`N` is the number of nodes.
         Returns
         -------
@@ -1629,18 +1635,21 @@ class DenseChebConv(nn.Module):
         for i in range(self._k):
             init.xavier_normal_(self.W[i], init.calculate_gain('relu'))
-    def forward(self, feat, adj):
+    def forward(self, adj, feat, lambda_max=None):
         r"""Compute (Dense) Chebyshev Spectral Graph Convolution layer.
         Parameters
         ----------
-        feat : torch.Tensor
-            The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
-            is size of input feature, :math:`N` is the number of nodes.
         adj : torch.Tensor
             The adjacency matrix of the graph to apply Graph Convolution on,
             should be of shape :math:`(N, N)`, where a row represents the destination
             and a column represents the source.
+        feat : torch.Tensor
+            The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
+            is size of input feature, :math:`N` is the number of nodes.
+        lambda_max : float or None, optional
+            A float value indicates the largest eigenvalue of given graph.
+            Default: None.
         Returns
         -------
@@ -1656,10 +1665,11 @@ class DenseChebConv(nn.Module):
         I = th.eye(num_nodes).to(A)
         L = I - D_invsqrt @ A @ D_invsqrt
-        lambda_ = th.eig(L)[0][:, 0]
-        lambda_max = lambda_.max()
-        L_hat = 2 * L / lambda_max - I
+        if lambda_max is None:
+            lambda_ = th.eig(L)[0][:, 0]
+            lambda_max = lambda_.max()
+        L_hat = 2 * L / lambda_max - I
         Z = [th.eye(num_nodes).to(A)]
         for i in range(1, self._k):
             if i == 1:
...
@@ -23,17 +23,17 @@ class SumPooling(nn.Module):
     def __init__(self):
         super(SumPooling, self).__init__()
-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute sum pooling.
         Parameters
         ----------
+        graph : DGLGraph or BatchedDGLGraph
+            The graph.
         feat : torch.Tensor
             The input feature with shape :math:`(N, *)` where
             :math:`N` is the number of nodes in the graph.
-        graph : DGLGraph or BatchedDGLGraph
-            The graph.
         Returns
         -------
@@ -57,16 +57,16 @@ class AvgPooling(nn.Module):
     def __init__(self):
         super(AvgPooling, self).__init__()
-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute average pooling.
         Parameters
         ----------
+        graph : DGLGraph or BatchedDGLGraph
+            The graph.
         feat : torch.Tensor
             The input feature with shape :math:`(N, *)` where
             :math:`N` is the number of nodes in the graph.
-        graph : DGLGraph or BatchedDGLGraph
-            The graph.
         Returns
         -------
@@ -90,16 +90,16 @@ class MaxPooling(nn.Module):
     def __init__(self):
         super(MaxPooling, self).__init__()
-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute max pooling.
         Parameters
         ----------
+        graph : DGLGraph or BatchedDGLGraph
+            The graph.
         feat : torch.Tensor
             The input feature with shape :math:`(N, *)` where
             :math:`N` is the number of nodes in the graph.
-        graph : DGLGraph or BatchedDGLGraph
-            The graph.
         Returns
         -------
@@ -127,16 +127,16 @@ class SortPooling(nn.Module):
         super(SortPooling, self).__init__()
         self.k = k
-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute sort pooling.
         Parameters
         ----------
+        graph : DGLGraph or BatchedDGLGraph
+            The graph.
         feat : torch.Tensor
             The input feature with shape :math:`(N, D)` where
             :math:`N` is the number of nodes in the graph.
-        graph : DGLGraph or BatchedDGLGraph
-            The graph.
         Returns
         -------
@@ -179,16 +179,16 @@ class GlobalAttentionPooling(nn.Module):
         self.gate_nn = gate_nn
         self.feat_nn = feat_nn
-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute global attention pooling.
         Parameters
         ----------
+        graph : DGLGraph
+            The graph.
         feat : torch.Tensor
             The input feature with shape :math:`(N, D)` where
             :math:`N` is the number of nodes in the graph.
-        graph : DGLGraph
-            The graph.
         Returns
         -------
@@ -252,16 +252,16 @@ class Set2Set(nn.Module):
         """Reinitialize learnable parameters."""
         self.lstm.reset_parameters()
-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute set2set pooling.
         Parameters
         ----------
+        graph : DGLGraph or BatchedDGLGraph
+            The graph.
         feat : torch.Tensor
             The input feature with shape :math:`(N, D)` where
             :math:`N` is the number of nodes in the graph.
-        graph : DGLGraph or BatchedDGLGraph
-            The graph.
         Returns
         -------
@@ -568,17 +568,17 @@ class SetTransformerEncoder(nn.Module):
         self.layers = nn.ModuleList(layers)
-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         """
         Compute the Encoder part of Set Transformer.
         Parameters
         ----------
+        graph : DGLGraph or BatchedDGLGraph
+            The graph.
         feat : torch.Tensor
             The input feature with shape :math:`(N, D)` where
             :math:`N` is the number of nodes in the graph.
-        graph : DGLGraph or BatchedDGLGraph
-            The graph.
         Returns
         -------
@@ -634,17 +634,17 @@ class SetTransformerDecoder(nn.Module):
         self.layers = nn.ModuleList(layers)
-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         """
         Compute the decoder part of Set Transformer.
         Parameters
         ----------
+        graph : DGLGraph or BatchedDGLGraph
+            The graph.
         feat : torch.Tensor
             The input feature with shape :math:`(N, D)` where
             :math:`N` is the number of nodes in the graph.
-        graph : DGLGraph or BatchedDGLGraph
-            The graph.
         Returns
         -------
...
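For the readout modules the same reordering applies. A short sketch of the updated pooling call on a batched graph, with illustrative sizes and the import path as laid out in this tree (not taken from the tests below):

import networkx as nx
import torch
import dgl
from dgl.nn.pytorch.glob import AvgPooling

g = dgl.DGLGraph(nx.path_graph(3))
bg = dgl.batch([g, g])                       # batch of 2 graphs
feat = torch.randn(bg.number_of_nodes(), 5)  # (6, 5)
avg_pool = AvgPooling()
h = avg_pool(bg, feat)                       # graph first, then features
# h has shape (2, 5): one pooled vector per graph in the batch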
@@ -24,13 +24,13 @@ def test_graph_conv():
    conv.initialize(ctx=ctx)
    # test#1: basic
    h0 = F.ones((3, 5))
-    h1 = conv(h0, g)
+    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    check_close(h1, _AXWb(adj, h0, conv.weight, conv.bias))
    # test#2: more-dim
    h0 = F.ones((3, 5, 5))
-    h1 = conv(h0, g)
+    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    check_close(h1, _AXWb(adj, h0, conv.weight, conv.bias))
@@ -40,12 +40,12 @@ def test_graph_conv():
    # test#3: basic
    h0 = F.ones((3, 5))
-    h1 = conv(h0, g)
+    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # test#4: basic
    h0 = F.ones((3, 5, 5))
-    h1 = conv(h0, g)
+    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
@@ -55,18 +55,18 @@ def test_graph_conv():
    with autograd.train_mode():
        # test#3: basic
        h0 = F.ones((3, 5))
-        h1 = conv(h0, g)
+        h1 = conv(g, h0)
        assert len(g.ndata) == 0
        assert len(g.edata) == 0
        # test#4: basic
        h0 = F.ones((3, 5, 5))
-        h1 = conv(h0, g)
+        h1 = conv(g, h0)
        assert len(g.ndata) == 0
        assert len(g.edata) == 0
    # test not override features
    g.ndata["h"] = 2 * F.ones((3, 1))
-    h1 = conv(h0, g)
+    h1 = conv(g, h0)
    assert len(g.ndata) == 1
    assert len(g.edata) == 0
    assert "h" in g.ndata
@@ -82,13 +82,13 @@ def test_set2set():
    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
-    h1 = s2s(h0, g)
+    h1 = s2s(g, h0)
    assert h1.shape[0] == 10 and h1.ndim == 1
    # test#2: batched graph
    bg = dgl.batch([g, g, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
-    h1 = s2s(h0, bg)
+    h1 = s2s(bg, h0)
    assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.ndim == 2
 def test_glob_att_pool():
@@ -100,13 +100,13 @@ def test_glob_att_pool():
    print(gap)
    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
-    h1 = gap(h0, g)
+    h1 = gap(g, h0)
    assert h1.shape[0] == 10 and h1.ndim == 1
    # test#2: batched graph
    bg = dgl.batch([g, g, g, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
-    h1 = gap(h0, bg)
+    h1 = gap(bg, h0)
    assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.ndim == 2
 def test_simple_pool():
@@ -120,20 +120,20 @@ def test_simple_pool():
    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
-    h1 = sum_pool(h0, g)
+    h1 = sum_pool(g, h0)
    check_close(h1, F.sum(h0, 0))
-    h1 = avg_pool(h0, g)
+    h1 = avg_pool(g, h0)
    check_close(h1, F.mean(h0, 0))
-    h1 = max_pool(h0, g)
+    h1 = max_pool(g, h0)
    check_close(h1, F.max(h0, 0))
-    h1 = sort_pool(h0, g)
+    h1 = sort_pool(g, h0)
    assert h1.shape[0] == 10 * 5 and h1.ndim == 1
    # test#2: batched graph
    g_ = dgl.DGLGraph(nx.path_graph(5))
    bg = dgl.batch([g, g_, g, g_, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
-    h1 = sum_pool(h0, bg)
+    h1 = sum_pool(bg, h0)
    truth = mx.nd.stack(F.sum(h0[:15], 0),
                        F.sum(h0[15:20], 0),
                        F.sum(h0[20:35], 0),
@@ -141,7 +141,7 @@ def test_simple_pool():
                        F.sum(h0[40:55], 0), axis=0)
    check_close(h1, truth)
-    h1 = avg_pool(h0, bg)
+    h1 = avg_pool(bg, h0)
    truth = mx.nd.stack(F.mean(h0[:15], 0),
                        F.mean(h0[15:20], 0),
                        F.mean(h0[20:35], 0),
@@ -149,7 +149,7 @@ def test_simple_pool():
                        F.mean(h0[40:55], 0), axis=0)
    check_close(h1, truth)
-    h1 = max_pool(h0, bg)
+    h1 = max_pool(bg, h0)
    truth = mx.nd.stack(F.max(h0[:15], 0),
                        F.max(h0[15:20], 0),
                        F.max(h0[20:35], 0),
@@ -157,7 +157,7 @@ def test_simple_pool():
                        F.max(h0[40:55], 0), axis=0)
    check_close(h1, truth)
-    h1 = sort_pool(h0, bg)
+    h1 = sort_pool(bg, h0)
    assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.ndim == 2
 def uniform_attention(g, shape):
...
@@ -24,13 +24,13 @@ def test_graph_conv():
    print(conv)
    # test#1: basic
    h0 = F.ones((3, 5))
-    h1 = conv(h0, g)
+    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))
    # test#2: more-dim
    h0 = F.ones((3, 5, 5))
-    h1 = conv(h0, g)
+    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))
@@ -40,12 +40,12 @@ def test_graph_conv():
        conv = conv.to(ctx)
    # test#3: basic
    h0 = F.ones((3, 5))
-    h1 = conv(h0, g)
+    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # test#4: basic
    h0 = F.ones((3, 5, 5))
-    h1 = conv(h0, g)
+    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
@@ -54,12 +54,12 @@ def test_graph_conv():
        conv = conv.to(ctx)
    # test#3: basic
    h0 = F.ones((3, 5))
-    h1 = conv(h0, g)
+    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # test#4: basic
    h0 = F.ones((3, 5, 5))
-    h1 = conv(h0, g)
+    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
@@ -94,7 +94,7 @@ def test_tagconv():
    # test#1: basic
    h0 = F.ones((3, 5))
-    h1 = conv(h0, g)
+    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    shp = norm.shape + (1,) * (h0.dim() - 1)
@@ -107,7 +107,7 @@ def test_tagconv():
        conv = conv.to(ctx)
    # test#2: basic
    h0 = F.ones((3, 5))
-    h1 = conv(h0, g)
+    h1 = conv(g, h0)
    assert h1.shape[-1] == 2
    # test reset_parameters
@@ -127,7 +127,7 @@ def test_set2set():
    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
-    h1 = s2s(h0, g)
+    h1 = s2s(g, h0)
    assert h1.shape[0] == 10 and h1.dim() == 1
    # test#2: batched graph
@@ -135,7 +135,7 @@ def test_set2set():
    g2 = dgl.DGLGraph(nx.path_graph(5))
    bg = dgl.batch([g, g1, g2])
    h0 = F.randn((bg.number_of_nodes(), 5))
-    h1 = s2s(h0, bg)
+    h1 = s2s(bg, h0)
    assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.dim() == 2
 def test_glob_att_pool():
@@ -149,13 +149,13 @@ def test_glob_att_pool():
    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
-    h1 = gap(h0, g)
+    h1 = gap(g, h0)
    assert h1.shape[0] == 10 and h1.dim() == 1
    # test#2: batched graph
    bg = dgl.batch([g, g, g, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
-    h1 = gap(h0, bg)
+    h1 = gap(bg, h0)
    assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.dim() == 2
 def test_simple_pool():
@@ -176,13 +176,13 @@ def test_simple_pool():
        max_pool = max_pool.to(ctx)
        sort_pool = sort_pool.to(ctx)
        h0 = h0.to(ctx)
-    h1 = sum_pool(h0, g)
+    h1 = sum_pool(g, h0)
    assert F.allclose(h1, F.sum(h0, 0))
-    h1 = avg_pool(h0, g)
+    h1 = avg_pool(g, h0)
    assert F.allclose(h1, F.mean(h0, 0))
-    h1 = max_pool(h0, g)
+    h1 = max_pool(g, h0)
    assert F.allclose(h1, F.max(h0, 0))
-    h1 = sort_pool(h0, g)
+    h1 = sort_pool(g, h0)
    assert h1.shape[0] == 10 * 5 and h1.dim() == 1
    # test#2: batched graph
@@ -192,7 +192,7 @@ def test_simple_pool():
    if F.gpu_ctx():
        h0 = h0.to(ctx)
-    h1 = sum_pool(h0, bg)
+    h1 = sum_pool(bg, h0)
    truth = th.stack([F.sum(h0[:15], 0),
                      F.sum(h0[15:20], 0),
                      F.sum(h0[20:35], 0),
@@ -200,7 +200,7 @@ def test_simple_pool():
                      F.sum(h0[40:55], 0)], 0)
    assert F.allclose(h1, truth)
-    h1 = avg_pool(h0, bg)
+    h1 = avg_pool(bg, h0)
    truth = th.stack([F.mean(h0[:15], 0),
                      F.mean(h0[15:20], 0),
                      F.mean(h0[20:35], 0),
@@ -208,7 +208,7 @@ def test_simple_pool():
                      F.mean(h0[40:55], 0)], 0)
    assert F.allclose(h1, truth)
-    h1 = max_pool(h0, bg)
+    h1 = max_pool(bg, h0)
    truth = th.stack([F.max(h0[:15], 0),
                      F.max(h0[15:20], 0),
                      F.max(h0[20:35], 0),
@@ -216,7 +216,7 @@ def test_simple_pool():
                      F.max(h0[40:55], 0)], 0)
    assert F.allclose(h1, truth)
-    h1 = sort_pool(h0, bg)
+    h1 = sort_pool(bg, h0)
    assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.dim() == 2
 def test_set_trans():
@@ -234,11 +234,11 @@ def test_set_trans():
    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 50))
-    h1 = st_enc_0(h0, g)
+    h1 = st_enc_0(g, h0)
    assert h1.shape == h0.shape
-    h1 = st_enc_1(h0, g)
+    h1 = st_enc_1(g, h0)
    assert h1.shape == h0.shape
-    h2 = st_dec(h1, g)
+    h2 = st_dec(g, h1)
    assert h2.shape[0] == 200 and h2.dim() == 1
    # test#2: batched graph
@@ -246,12 +246,12 @@ def test_set_trans():
    g2 = dgl.DGLGraph(nx.path_graph(10))
    bg = dgl.batch([g, g1, g2])
    h0 = F.randn((bg.number_of_nodes(), 50))
-    h1 = st_enc_0(h0, bg)
+    h1 = st_enc_0(bg, h0)
    assert h1.shape == h0.shape
-    h1 = st_enc_1(h0, bg)
+    h1 = st_enc_1(bg, h0)
    assert h1.shape == h0.shape
-    h2 = st_dec(h1, bg)
+    h2 = st_dec(bg, h1)
    assert h2.shape[0] == 3 and h2.shape[1] == 200 and h2.dim() == 2
 def uniform_attention(g, shape):
@@ -375,7 +375,7 @@ def test_gat_conv():
        gat = gat.to(ctx)
        feat = feat.to(ctx)
-    h = gat(feat, g)
+    h = gat(g, feat)
    assert h.shape[-1] == 2 and h.shape[-2] == 4
 def test_sage_conv():
@@ -389,7 +389,7 @@ def test_sage_conv():
        sage = sage.to(ctx)
        feat = feat.to(ctx)
-    h = sage(feat, g)
+    h = sage(g, feat)
    assert h.shape[-1] == 10
 def test_sgc_conv():
@@ -403,7 +403,7 @@ def test_sgc_conv():
        sgc = sgc.to(ctx)
        feat = feat.to(ctx)
-    h = sgc(feat, g)
+    h = sgc(g, feat)
    assert h.shape[-1] == 10
    # cached
@@ -412,8 +412,8 @@ def test_sgc_conv():
    if F.gpu_ctx():
        sgc = sgc.to(ctx)
-    h_0 = sgc(feat, g)
-    h_1 = sgc(feat + 1, g)
+    h_0 = sgc(g, feat)
+    h_1 = sgc(g, feat + 1)
    assert F.allclose(h_0, h_1)
    assert h_0.shape[-1] == 10
@@ -427,7 +427,7 @@ def test_appnp_conv():
        appnp = appnp.to(ctx)
        feat = feat.to(ctx)
-    h = appnp(feat, g)
+    h = appnp(g, feat)
    assert h.shape[-1] == 5
 def test_gin_conv():
@@ -444,7 +444,7 @@ def test_gin_conv():
        gin = gin.to(ctx)
        feat = feat.to(ctx)
-    h = gin(feat, g)
+    h = gin(g, feat)
    assert h.shape[-1] == 12
 def test_agnn_conv():
@@ -457,7 +457,7 @@ def test_agnn_conv():
        agnn = agnn.to(ctx)
        feat = feat.to(ctx)
-    h = agnn(feat, g)
+    h = agnn(g, feat)
    assert h.shape[-1] == 5
 def test_gated_graph_conv():
@@ -472,7 +472,7 @@ def test_gated_graph_conv():
        feat = feat.to(ctx)
        etypes = etypes.to(ctx)
-    h = ggconv(feat, etypes, g)
+    h = ggconv(g, feat, etypes)
    # current we only do shape check
    assert h.shape[-1] == 10
@@ -489,7 +489,7 @@ def test_nn_conv():
        feat = feat.to(ctx)
        efeat = efeat.to(ctx)
-    h = nnconv(feat, efeat, g)
+    h = nnconv(g, feat, efeat)
    # currently we only do shape check
    assert h.shape[-1] == 10
@@ -505,7 +505,7 @@ def test_gmm_conv():
        feat = feat.to(ctx)
        pseudo = pseudo.to(ctx)
-    h = gmmconv(feat, pseudo, g)
+    h = gmmconv(g, feat, pseudo)
    # currently we only do shape check
    assert h.shape[-1] == 10
@@ -523,8 +523,8 @@ def test_dense_graph_conv():
        dense_conv = dense_conv.to(ctx)
        feat = feat.to(ctx)
-    out_conv = conv(feat, g)
-    out_dense_conv = dense_conv(feat, adj)
+    out_conv = conv(g, feat)
+    out_dense_conv = dense_conv(adj, feat)
    assert F.allclose(out_conv, out_dense_conv)
 def test_dense_sage_conv():
@@ -541,8 +541,8 @@ def test_dense_sage_conv():
        dense_sage = dense_sage.to(ctx)
        feat = feat.to(ctx)
-    out_sage = sage(feat, g)
-    out_dense_sage = dense_sage(feat, adj)
+    out_sage = sage(g, feat)
+    out_dense_sage = dense_sage(adj, feat)
    assert F.allclose(out_sage, out_dense_sage)
 def test_dense_cheb_conv():
@@ -562,8 +562,8 @@ def test_dense_cheb_conv():
        dense_cheb = dense_cheb.to(ctx)
        feat = feat.to(ctx)
-    out_cheb = cheb(feat, g)
-    out_dense_cheb = dense_cheb(feat, adj)
+    out_cheb = cheb(g, feat, [2.0])
+    out_dense_cheb = dense_cheb(adj, feat, 2.0)
    assert F.allclose(out_cheb, out_dense_cheb)
 if __name__ == '__main__':
...
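The ChebConv change also makes lambda_max an explicit, optionally user-supplied argument (one entry per graph in the batch). A hedged sketch mirroring the updated test above, with illustrative sizes:

import networkx as nx
import torch
import dgl
from dgl.nn.pytorch.conv import ChebConv

g = dgl.DGLGraph(nx.path_graph(3))
feat = torch.ones(3, 5)
cheb = ChebConv(5, 2, 2)    # in_feats, out_feats, k
# pass the largest Laplacian eigenvalue per graph; 2.0 is the usual upper bound,
# and omitting it falls back to dgl.laplacian_lambda_max(graph)
h = cheb(g, feat, [2.0])    # shape (3, 2)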